perga/lib/Expr.hs

235 lines
8.2 KiB
Haskell
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

module Expr where
import qualified Data.List.NonEmpty as NE
import Prettyprinter
import Prettyprinter.Render.Text
import Prelude hiding (group)
-- | Core terms. 'Var' carries a display name together with a de Bruijn
-- index; the index is what 'occursFree' and 'shiftIndices' operate on.
-- Binders ('Abs', 'Pi', 'Let', 'Sigma') keep a name for display.
--
-- The 'Eq' instance ignores all display names and the optional 'Let'
-- ascription, so 'Ord' is written by hand below to agree with it.
data Expr where
  Var :: Text -> Integer -> Expr
  Free :: Text -> Expr
  Axiom :: Text -> Expr
  Star :: Expr
  Level :: Integer -> Expr
  App :: Expr -> Expr -> Expr
  Abs :: Text -> Expr -> Expr -> Expr
  Pi :: Text -> Expr -> Expr -> Expr
  Let :: Text -> Maybe Expr -> Expr -> Expr -> Expr
  Sigma :: Text -> Expr -> Expr -> Expr
  Pair :: Expr -> Expr -> Expr
  Pi1 :: Expr -> Expr
  Pi2 :: Expr -> Expr
  deriving (Show)

-- | Ordering consistent with the hand-written 'Eq' instance: display names
-- and 'Let' ascriptions are ignored, so @compare a b == EQ@ exactly when
-- @a == b@. (A derived instance compared names too, so terms that were
-- '==' could still compare as 'LT'/'GT', breaking 'Data.Map'/'Data.Set'.)
-- Different constructors are ordered by declaration order, as a derived
-- instance would do.
instance Ord Expr where
  compare (Var _ n) (Var _ m) = compare n m
  compare (Free s) (Free t) = compare s t
  compare (Axiom s) (Axiom t) = compare s t
  compare Star Star = EQ
  compare (Level i) (Level j) = compare i j
  compare (App e1 e2) (App f1 f2) = compare e1 f1 <> compare e2 f2
  compare (Abs _ t1 b1) (Abs _ t2 b2) = compare t1 t2 <> compare b1 b2
  compare (Pi _ t1 b1) (Pi _ t2 b2) = compare t1 t2 <> compare b1 b2
  compare (Let _ _ v1 b1) (Let _ _ v2 b2) = compare v1 v2 <> compare b1 b2
  compare (Sigma _ x1 y1) (Sigma _ x2 y2) = compare x1 x2 <> compare y1 y2
  compare (Pair x1 y1) (Pair x2 y2) = compare x1 x2 <> compare y1 y2
  compare (Pi1 x) (Pi1 y) = compare x y
  compare (Pi2 x) (Pi2 y) = compare x y
  -- Only reached for two different constructors.
  compare e1 e2 = compare (conIndex e1) (conIndex e2)
    where
      conIndex :: Expr -> Int
      conIndex e = case e of
        Var{} -> 0
        Free{} -> 1
        Axiom{} -> 2
        Star -> 3
        Level{} -> 4
        App{} -> 5
        Abs{} -> 6
        Pi{} -> 7
        Let{} -> 8
        Sigma{} -> 9
        Pair{} -> 10
        Pi1{} -> 11
        Pi2{} -> 12
-- | Pretty-print an expression. Shadowed binders first get distinct display
-- names ('dedupNames'). Next, non-dependent 'Pi' binders lose their names so
-- they render as arrows ('cleanNames'). The result is then laid out at the
-- top-level precedence.
instance Pretty Expr where
  pretty e = prettyExpr 0 (cleanNames (dedupNames e))
-- | Structural equality up to renaming. Variables are compared by de Bruijn
-- index only; binder display names are ignored, and so is the optional type
-- ascription of a 'Let'.
instance Eq Expr where
  lhs == rhs = case (lhs, rhs) of
    (Var _ i, Var _ j) -> i == j
    (Free s, Free t) -> s == t
    (Axiom s, Axiom t) -> s == t
    (Star, Star) -> True
    (Level i, Level j) -> i == j
    (App f x, App g y) -> f == g && x == y
    (Abs _ dom1 body1, Abs _ dom2 body2) -> dom1 == dom2 && body1 == body2
    (Pi _ dom1 body1, Pi _ dom2 body2) -> dom1 == dom2 && body1 == body2
    (Let _ _ val1 body1, Let _ _ val2 body2) -> val1 == val2 && body1 == body2
    (Sigma _ fst1 snd1, Sigma _ fst2 snd2) -> fst1 == fst2 && snd1 == snd2
    (Pair a1 b1, Pair a2 b2) -> a1 == a2 && b1 == b2
    (Pi1 p, Pi1 q) -> p == q
    (Pi2 p, Pi2 q) -> p == q
    _ -> False
-- | @occursFree n e@ is 'True' when the de Bruijn index @n@ (counted from
-- the root of @e@) occurs somewhere in @e@. The index is incremented when
-- descending into the body of an 'Abs', 'Pi' or 'Let'.
occursFree :: Integer -> Expr -> Bool
occursFree n expr = case expr of
  Var _ k -> k == n
  Free _ -> False
  Axiom _ -> False
  Star -> False
  Level _ -> False
  App f a -> occursFree n f || occursFree n a
  Abs _ ty body -> occursFree n ty || occursFree (n + 1) body
  Pi _ ty body -> occursFree n ty || occursFree (n + 1) body
  -- The optional ascription is in the same scope as the bound value, so a
  -- mention of @n@ there counts as an occurrence too.
  Let _ ascr val body ->
    any (occursFree n) ascr || occursFree n val || occursFree (n + 1) body
  -- NOTE(review): both components of 'Sigma' are checked at the same depth,
  -- although the printer shows @Sigma x t m@ as @Σ x : t, m@, which suggests
  -- @m@ is under the binder @x@. Confirm the intended scoping.
  Sigma _ x y -> occursFree n x || occursFree n y
  Pair x y -> occursFree n x || occursFree n y
  Pi1 x -> occursFree n x
  Pi2 x -> occursFree n x
-- | @shiftIndices d c e@ adds @d@ to every de Bruijn index in @e@ that is
-- at least the cutoff @c@, i.e. every index that points outside the binders
-- already passed. Indices below the cutoff are bound locally and left
-- alone. The cutoff grows by one under each 'Abs', 'Pi' and 'Let' body.
shiftIndices :: Integer -> Integer -> Expr -> Expr
shiftIndices d c (Var x k)
  | k >= c = Var x (k + d)
  | otherwise = Var x k
shiftIndices _ _ (Free s) = Free s
shiftIndices _ _ (Axiom s) = Axiom s
shiftIndices _ _ Star = Star
shiftIndices _ _ (Level i) = Level i
shiftIndices d c (App m n) = App (shiftIndices d c m) (shiftIndices d c n)
shiftIndices d c (Abs x m n) = Abs x (shiftIndices d c m) (shiftIndices d (c + 1) n)
shiftIndices d c (Pi x m n) = Pi x (shiftIndices d c m) (shiftIndices d (c + 1) n)
-- The ascription lives in the same scope as the value, so it is shifted
-- with the same cutoff. Leaving it untouched would make its free indices
-- refer to the wrong binders after the shift.
shiftIndices d c (Let x t v b) =
  Let x (shiftIndices d c <$> t) (shiftIndices d c v) (shiftIndices d (c + 1) b)
-- NOTE(review): the second 'Sigma' component uses the same cutoff as the
-- first; the printer's @Σ x : t, m@ form suggests @m@ may be under the
-- binder. Confirm the intended scoping.
shiftIndices d c (Sigma x m n) = Sigma x (shiftIndices d c m) (shiftIndices d c n)
shiftIndices d c (Pair m n) = Pair (shiftIndices d c m) (shiftIndices d c n)
shiftIndices d c (Pi1 x) = Pi1 (shiftIndices d c x)
shiftIndices d c (Pi2 x) = Pi2 (shiftIndices d c x)
-- | Shift every free de Bruijn index up by one (cutoff 0), so that a term
-- written for one context can be compared against a term that sits under
-- one additional binder.
incIndices :: Expr -> Expr
incIndices e = shiftIndices 1 0 e
{- --------------------- PRETTY PRINTING ----------------------------- -}
-- | Give shadowing binders distinct display names.
--
-- A variable whose binder's name is also used by @count@ binders further
-- out gets @printLevel count@ appended, so the outermost binder with a
-- given name keeps the plain name. Names equal to @""@ are left as they
-- are. Only display names change; de Bruijn indices are untouched.
dedupNames :: Expr -> Expr
dedupNames = go []
  where
    -- @bs@ holds the names of the enclosing binders, innermost first, so
    -- entry @k@ is the binder that de Bruijn index @k@ refers to.
    varName :: [Text] -> Text -> Int -> Text
    varName bs x k =
      if x == ""
        then x
        else case bs !!? k of
          Nothing -> x
          Just x' ->
            let count = (length $ filter (== x') $ drop (k + 1) bs)
             in if count > 0
                  then x <> printLevel count
                  else x
    go :: [Text] -> Expr -> Expr
    go bs (Var x k) = Var (varName bs x (fromIntegral k)) k
    go bs (App m n) = App (go bs m) (go bs n)
    go bs (Abs x ty b) = Abs (varName (x : bs) x 0) (go bs ty) (go (x : bs) b)
    go bs (Pi x ty b) = Pi (varName (x : bs) x 0) (go bs ty) (go (x : bs) b)
    go bs (Let x ascr val b) = Let (varName (x : bs) x 0) (go bs <$> ascr) (go bs val) (go (x : bs) b)
    -- Pairs and projections bind nothing: recurse with the same context so
    -- that shadowed variables inside them are renamed too.
    go bs (Pair m n) = Pair (go bs m) (go bs n)
    go bs (Pi1 m) = Pi1 (go bs m)
    go bs (Pi2 m) = Pi2 (go bs m)
    -- NOTE(review): 'Sigma' is deliberately not entered. It is unclear from
    -- this module whether its second component is under the binder
    -- ('occursFree'/'shiftIndices' treat it as unbound, the printer shows
    -- @Σ x : t, m@), and renaming only part of it would give the same
    -- variable two different names. Confirm before descending.
    go _ e = e
-- | One binder name together with its type, as peeled off a chain of
-- 'Abs' or 'Pi' nodes.
data Param where
  Param :: Text -> Expr -> Param

-- | Consecutive binders sharing a single type, printed as @(x y : A)@.
data ParamGroup where
  ParamGroup :: [Text] -> Expr -> ParamGroup

-- | A @let@ binding for printing: the bound name, the parameter groups
-- taken from the leading lambdas of its value, and the remaining value.
data Binding where
  Binding :: Text -> [ParamGroup] -> Expr -> Binding
-- | Renders as @(x : A)@.
instance Pretty Param where
  pretty (Param name ty) = group (parens (pretty name <+> ":" <+> pretty ty))

-- | Renders as @(x y z : A)@, with the names aligned if the group breaks.
instance Pretty ParamGroup where
  pretty (ParamGroup names ty) = group (parens (namesDoc <+> ":" <+> pretty ty))
    where
      namesDoc = align (sep (pretty <$> names))

-- | Renders as @(f := v)@, or @(f (x : A) ... := v)@ when the binding has
-- parameters.
instance Pretty Binding where
  pretty (Binding name params body) = group (parens (pretty name <+> afterName))
    where
      definition = ":=" <+> pretty body
      afterName
        | null params = definition
        | otherwise = align (sep (map pretty params)) <+> definition
-- | Split off the leading chain of lambdas. Returns their binders (outermost
-- first) and the first body that is not itself a lambda.
collectLambdas :: Expr -> ([Param], Expr)
collectLambdas expr = case expr of
  Abs x ty body ->
    let (inner, final) = collectLambdas body
     in (Param x ty : inner, final)
  _ -> ([], expr)
-- | Split off a chain of nested 'Let's. Each one becomes a 'Binding' whose
-- value has its leading lambdas turned into grouped parameters. The
-- optional type ascription is not kept.
collectLets :: Expr -> ([Binding], Expr)
collectLets expr = case expr of
  Let x _ val body ->
    let (lams, val') = collectLambdas val
        (others, final) = collectLets body
     in (Binding x (groupParams lams) val' : others, final)
  _ -> ([], expr)
-- | Split off the leading chain of named 'Pi's. Stops at the first non-'Pi'
-- or at an anonymous @Pi ""@ (which is printed as an arrow instead).
collectPis :: Expr -> ([Param], Expr)
collectPis expr = case expr of
  Pi x ty body
    | x /= "" ->
        let (inner, final) = collectPis body
         in (Param x ty : inner, final)
  _ -> ([], expr)
-- | Flatten a right-nested chain of anonymous arrows @A -> B -> C@ into
-- @A :| [B, C]@. Anything else becomes a one-element chain.
collectArrows :: Expr -> NonEmpty Expr
collectArrows expr = case expr of
  Pi "" dom cod -> NE.cons dom (collectArrows cod)
  _ -> expr :| []
-- | Spine of an application with the arguments in reverse: @f a b@ yields
-- @b :| [a, f]@ ('prettyExpr' reverses it to get the head first).
collectApps :: Expr -> NonEmpty Expr
collectApps expr = case expr of
  App fun arg -> NE.cons arg (collectApps fun)
  _ -> expr :| []
-- | Erase the binder name of every 'Pi' whose body does not mention the
-- bound variable, so that it prints as a plain arrow.
--
-- Whether a 'Pi' is dependent is decided locally (@occursFree 0@ on its own
-- body), so no binder context is needed and every subterm can be visited.
cleanNames :: Expr -> Expr
cleanNames (App m n) = App (cleanNames m) (cleanNames n)
cleanNames (Abs x ty body) = Abs x (cleanNames ty) (cleanNames body)
cleanNames (Pi x ty body)
  | occursFree 0 body = Pi x ty' body'
  | otherwise = Pi "" ty' body'
  where
    -- The domain is cleaned in both branches: the domain of an arrow can
    -- itself contain 'Pi's that need cleaning.
    ty' = cleanNames ty
    body' = cleanNames body
-- 'prettyExpr' prints a let body directly (without re-cleaning), so 'Let'
-- has to be traversed here.
cleanNames (Let x ascr val body) =
  Let x (cleanNames <$> ascr) (cleanNames val) (cleanNames body)
cleanNames (Sigma x t m) = Sigma x (cleanNames t) (cleanNames m)
cleanNames (Pair m n) = Pair (cleanNames m) (cleanNames n)
cleanNames (Pi1 m) = Pi1 (cleanNames m)
cleanNames (Pi2 m) = Pi2 (cleanNames m)
cleanNames e = e
-- | Merge consecutive parameters with the same type, e.g. @(x : A) (y : A)@
-- into @(x y : A)@. Each later parameter's type lives under one more
-- binder, so it is compared with the earlier type shifted up by one. Each
-- group records the type of its first (outermost) parameter.
groupParams :: [Param] -> [ParamGroup]
groupParams [] = []
groupParams (Param x t : more) = case groupParams more of
  ParamGroup xs s : rest
    | incIndices t == s -> ParamGroup (x : xs) t : rest
  grouped -> ParamGroup [x] t : grouped
-- | Render an integer with Unicode subscript digits, e.g. @12@ becomes
-- @"₁₂"@ and @-3@ becomes @"₋₃"@. Used for de-duplicated variable-name
-- suffixes and universe levels.
--
-- Previously the negative branch recursed on the same value forever, and
-- every digit rendered as an empty string. Subscripts match the @π₁@/@π₂@
-- spelling used elsewhere in this module.
printLevel :: (IsString s, Semigroup s, Integral i) => i -> s
printLevel = render . toInteger
  where
    -- Work in 'Integer' so that negating 'minBound' of a bounded type cannot
    -- overflow back to a negative number.
    render n
      | n < 0 = fromString "₋" <> render (negate n)
      | n < 10 = digit n
      | otherwise = render (n `div` 10) <> digit (n `mod` 10)
    -- Subscript digits U+2080..U+2089 are contiguous, so offset from '₀'.
    digit d = fromString [toEnum (fromEnum '₀' + fromInteger d)]
-- | Render an expression in precedence context @k@. The larger @k@, the
-- tighter the surrounding context and the more constructs need
-- parentheses. Contexts used below:
--
--   * @0@: top level, and the bodies of lambdas, dependent products and lets
--   * @2@: arrow components to the right of @->@
--   * @3@: the head of an application and the left-most arrow component
--   * @4@: the arguments of an application
prettyExpr :: Integer -> Expr -> Doc ann
prettyExpr k expr = case expr of
  Var s _ -> pretty s
  Free s -> pretty s
  Axiom s -> pretty s
  -- NOTE(review): the literals for 'Star', 'Level' and the dependent 'Pi'
  -- below are empty in this copy of the file. The source viewer flagged
  -- "ambiguous Unicode characters", so a glyph may have been lost here;
  -- confirm the intended symbols.
  Star -> ""
  Level i
    | i == 0 -> ""
    | otherwise -> "" <> printLevel i
  App{}
    | k > 3 -> parens application
    | otherwise -> application
    where
      (top :| rest) = NE.reverse $ collectApps expr
      application = group $ hang 4 $ prettyExpr 3 top <> line <> align (sep $ map (prettyExpr 4) rest)
  Pi "" _ _
    | k > 2 -> parens piType
    | otherwise -> piType
    where
      (top :| rest) = collectArrows expr
      piType = group $ hang 4 $ prettyExpr 3 top <+> align (sep $ map (("->" <+>) . prettyExpr 2) rest)
  -- As with a lambda, a dependent product's body extends as far right as
  -- possible. It therefore needs parentheses in any non-top-level context
  -- (arrow component, application head or argument). Without them, a
  -- product used as an arrow's domain would swallow the rest of the arrow.
  Pi{}
    | k >= 1 -> parens dependent
    | otherwise -> dependent
    where
      (params, body) = collectPis expr
      grouped = pretty <$> groupParams params
      dependent = group $ hang 4 $ "" <+> align (sep grouped) <> "," <> line <> align (prettyExpr 0 body)
  Abs{} ->
    if k >= 1
      then parens lambdaForm
      else lambdaForm
    where
      (params, body) = collectLambdas expr
      grouped = pretty <$> groupParams params
      lambdaForm = group $ hang 4 $ "λ" <+> align (sep grouped) <+> "=>" <> line <> align (prettyExpr 0 body)
  Let{} ->
    group $
      vsep
        [ "let" <+> align bindings
        , "in" <+> align (prettyExpr 0 body)
        , "end"
        ]
    where
      (binds, body) = collectLets expr
      bindings = sep $ map pretty binds
  -- NOTE(review): the components below are printed with 'pretty', which
  -- re-runs 'dedupNames' and 'cleanNames' on each subterm without the
  -- enclosing binders as context.
  (Sigma "" x y) -> parens $ parens (pretty x) <+> "×" <+> parens (pretty y)
  (Sigma x t m) -> parens $ "Σ" <+> pretty x <+> ":" <+> pretty t <> "," <+> pretty m
  (Pair x y) -> parens $ pretty x <> "," <+> pretty y
  (Pi1 x) -> parens $ "π₁" <+> parens (pretty x)
  (Pi2 x) -> parens $ "π₂" <+> parens (pretty x)
-- | Render an expression to strict 'Text' using 'layoutSmart' with
-- 'defaultLayoutOptions'.
prettyT :: Expr -> Text
prettyT e = renderStrict (layoutSmart defaultLayoutOptions (pretty e))

-- | 'prettyT', converted to a 'String'.
prettyS :: Expr -> String
prettyS e = toString (prettyT e)