-- | Core expression syntax together with index manipulation helpers and a
-- pretty printer built on the @prettyprinter@ package.
module Expr where

import qualified Data.List.NonEmpty as NE
import Prettyprinter
import Prettyprinter.Render.Text
import Prelude hiding (group)

-- | Expressions. Bound variables carry both a display name and an integer
-- index; binders ('Abs', 'Pi', 'Let') increase the index of enclosing
-- variables by one inside their body (see 'occursFree' and 'shiftIndices').
data Expr where
    -- | Bound variable: display name and index.
    Var :: Text -> Integer -> Expr
    -- | Named free variable.
    Free :: Text -> Expr
    -- | Named axiom.
    Axiom :: Text -> Expr
    -- | Printed as @*@.
    Star :: Expr
    -- | Printed as @□@ followed by a subscript (omitted for level 0).
    Level :: Integer -> Expr
    -- | Application of a function to an argument.
    App :: Expr -> Expr -> Expr
    -- | Lambda: binder name, binder type, an optional extra expression
    -- (its role is not visible in this module and it is ignored by every
    -- function here), and the body.
    Abs :: Text -> Expr -> Maybe Expr -> Expr -> Expr
    -- | Dependent product with the same field layout as 'Abs'.
    Pi :: Text -> Expr -> Maybe Expr -> Expr -> Expr
    -- | Local definition: name, optional annotation, value, body.
    Let :: Text -> Maybe Expr -> Expr -> Expr -> Expr
    deriving (Show, Ord)

-- | Render with binder names of non-dependent products erased first, so
-- that they print as arrows.
instance Pretty Expr where
    pretty = prettyExpr 0 . cleanNames

-- | Equality up to the display names of bound variables and binders.
--
-- For 'Abs' and 'Pi' the binder type and the body are compared; the optional
-- third field is ignored, mirroring how 'Let' ignores its annotation and how
-- 'occursFree' / 'shiftIndices' treat that field.
--
-- NOTE: the derived 'Ord' instance compares every field (names included), so
-- it is finer than this equality.
instance Eq Expr where
    (Var _ n) == (Var _ m) = n == m
    (Free s) == (Free t) = s == t
    (Axiom s) == (Axiom t) = s == t
    Star == Star = True
    (Level i) == (Level j) = i == j
    (App e1 e2) == (App f1 f2) = e1 == f1 && e2 == f2
    (Abs _ ty1 _ b1) == (Abs _ ty2 _ b2) = ty1 == ty2 && b1 == b2
    (Pi _ ty1 _ b1) == (Pi _ ty2 _ b2) = ty1 == ty2 && b1 == b2
    (Let _ _ v1 b1) == (Let _ _ v2 b2) = v1 == v2 && b1 == b2
    _ == _ = False

-- | @occursFree n e@ tells whether the variable with index @n@ (relative to
-- the outside of @e@) occurs in @e@. The index is bumped by one when
-- descending into the body of a binder. The optional field of 'Abs' / 'Pi'
-- and the annotation of 'Let' are not inspected.
occursFree :: Integer -> Expr -> Bool
occursFree n (Var _ k) = n == k
occursFree _ (Free _) = False
occursFree _ (Axiom _) = False
occursFree _ Star = False
occursFree _ (Level _) = False
occursFree n (App a b) = occursFree n a || occursFree n b
occursFree n (Abs _ a _ b) = occursFree n a || occursFree (n + 1) b
occursFree n (Pi _ a _ b) = occursFree n a || occursFree (n + 1) b
occursFree n (Let _ _ v b) = occursFree n v || occursFree (n + 1) b

-- | @shiftIndices d c e@ adds @d@ to every variable index in @e@ that is at
-- least the cutoff @c@; the cutoff grows by one under each binder.
--
-- NOTE(review): the optional field of 'Abs' / 'Pi' and the annotation of
-- 'Let' are passed through unshifted — confirm that this is intended.
shiftIndices :: Integer -> Integer -> Expr -> Expr
shiftIndices d c (Var x k)
    | k >= c = Var x (k + d)
    | otherwise = Var x k
shiftIndices _ _ (Free s) = Free s
shiftIndices _ _ (Axiom s) = Axiom s
shiftIndices _ _ Star = Star
shiftIndices _ _ (Level i) = Level i
shiftIndices d c (App m n) = App (shiftIndices d c m) (shiftIndices d c n)
shiftIndices d c (Abs x m a n) = Abs x (shiftIndices d c m) a (shiftIndices d (c + 1) n)
shiftIndices d c (Pi x m a n) = Pi x (shiftIndices d c m) a (shiftIndices d (c + 1) n)
shiftIndices d c (Let x t v b) = Let x t (shiftIndices d c v) (shiftIndices d (c + 1) b)

-- | Increment every variable index by one (shift by 1 with cutoff 0).
incIndices :: Expr -> Expr
incIndices = shiftIndices 1 0

{- --------------------- PRETTY PRINTING ----------------------------- -}

-- | A single binder: name and type.
data Param = Param Text Expr

-- | Several consecutive binders sharing one type, printed as @(x y z : T)@.
data ParamGroup = ParamGroup [Text] Expr

-- | One @let@ binding: name, parameters peeled off the value, and the
-- remaining value.
data Binding = Binding Text [ParamGroup] Expr

instance Pretty Param where
    pretty (Param x ty) = group $ parens $ pretty x <+> ":" <+> pretty ty

instance Pretty ParamGroup where
    pretty (ParamGroup vars ty) =
        group $ parens $ align (sep (map pretty vars)) <+> ":" <+> pretty ty

instance Pretty Binding where
    pretty (Binding var [] body) =
        group $ parens $ pretty var <+> ":=" <+> pretty body
    pretty (Binding var params body) =
        group $ parens $ pretty var <+> align (sep (map pretty params)) <+> ":=" <+> pretty body

-- | Peel off all leading lambdas, returning their binders and the innermost
-- body.
collectLambdas :: Expr -> ([Param], Expr)
collectLambdas (Abs x ty _ body) = (Param x ty : params, final)
  where
    (params, final) = collectLambdas body
collectLambdas e = ([], e)

-- | Peel off all leading @let@s. Lambdas at the head of each bound value are
-- turned into grouped parameters of the binding.
collectLets :: Expr -> ([Binding], Expr)
collectLets (Let x _ val body) = (Binding x params' val' : bindings, final)
  where
    (bindings, final) = collectLets body
    (params, val') = collectLambdas val
    params' = groupParams params
collectLets e = ([], e)

-- | Peel off leading named products. A product whose binder name is empty
-- (i.e. one that prints as an arrow) stops the collection.
collectPis :: Expr -> ([Param], Expr)
collectPis p@(Pi "" _ _ _) = ([], p)
collectPis (Pi x ty _ body) = (Param x ty : params, final)
  where
    (params, final) = collectPis body
collectPis e = ([], e)

-- | Flatten an application spine. The result is in reverse order: the last
-- argument comes first and the head of the application comes last.
collectApps :: Expr -> NonEmpty Expr
collectApps (App e1 e2) = e2 :| toList (collectApps e1)
collectApps e = pure e

-- | Replace the binder name of every product whose bound variable is unused
-- in its body by the empty name, which makes 'prettyExpr' print it as an
-- arrow. Recurses through applications, lambdas, products and lets; the
-- optional field of 'Abs' / 'Pi' and the 'Let' annotation are left as is.
cleanNames :: Expr -> Expr
cleanNames (App m n) = App (cleanNames m) (cleanNames n)
cleanNames (Abs x ty a body) = Abs x (cleanNames ty) a (cleanNames body)
cleanNames (Pi x ty a body)
    | occursFree 0 body = Pi x cleanTy a cleanBody
    -- The domain must be cleaned here as well: the arrow printer renders it
    -- with 'prettyExpr' directly and never re-runs 'cleanNames' on it.
    | otherwise = Pi "" cleanTy a cleanBody
  where
    cleanTy = cleanNames ty
    cleanBody = cleanNames body
-- The body of a let is printed with 'prettyExpr' directly, so it has to be
-- cleaned here.
cleanNames (Let x t val body) = Let x t (cleanNames val) (cleanNames body)
cleanNames e = e

-- | Merge adjacent binders of the same type into groups. A binder joins the
-- group that follows it when its type, shifted by one index (the following
-- binders live under one more binder), equals that group's type.
groupParams :: [Param] -> [ParamGroup]
groupParams = foldr addParam []
  where
    addParam :: Param -> [ParamGroup] -> [ParamGroup]
    addParam (Param x t) [] = [ParamGroup [x] t]
    addParam (Param x t) l@(ParamGroup xs s : rest)
        | incIndices t == s = ParamGroup (x : xs) t : rest
        | otherwise = ParamGroup [x] t : l

-- | Render a level as Unicode subscript digits. Negative numbers are printed
-- with ordinary digits.
printLevel :: Integer -> Doc ann
printLevel k
    | k < 0 = pretty k
    | k < 10 = digit k
    | otherwise = printLevel (k `div` 10) <> printLevel (k `mod` 10)
  where
    digit :: Integer -> Doc ann
    digit d = case d of
        0 -> "₀"
        1 -> "₁"
        2 -> "₂"
        3 -> "₃"
        4 -> "₄"
        5 -> "₅"
        6 -> "₆"
        7 -> "₇"
        8 -> "₈"
        9 -> "₉"
        _ -> pretty d -- not reached: callers pass 0..9 only

-- | @prettyExpr k e@ renders @e@ in a context of precedence @k@; a larger
-- @k@ means the context binds tighter, so more forms get parenthesised.
-- Applications are parenthesised above 3, arrows and named products above 2,
-- and lambdas from 1 upwards. @let@ blocks are delimited by @end@ and are
-- never parenthesised.
prettyExpr :: Integer -> Expr -> Doc ann
prettyExpr k expr = case expr of
    Var s _ -> pretty s
    Free s -> pretty s
    Axiom s -> pretty s
    Star -> "*"
    Level i
        | i == 0 -> "□"
        | otherwise -> "□" <> printLevel i
    App{}
        | k > 3 -> parens application
        | otherwise -> application
      where
        -- 'collectApps' yields the spine reversed; flip it so the head of
        -- the application comes first.
        (top :| rest) = NE.reverse $ collectApps expr
        application = group $ hang 4 $ prettyExpr 3 top <> line <> align (sep $ map (prettyExpr 4) rest)
    Pi "" t1 _ t2
        | k > 2 -> parens piType
        | otherwise -> piType
      where
        piType = group $ prettyExpr 3 t1 <+> "->" <+> align (prettyExpr 2 t2)
    Pi{}
        -- Same threshold as the arrow form, so a dependent product in an
        -- arrow domain or in argument position is unambiguous.
        | k > 2 -> parens piForm
        | otherwise -> piForm
      where
        (params, body) = collectPis expr
        grouped = pretty <$> groupParams params
        piForm = group $ hang 4 $ "∏" <+> align (sep grouped) <> "," <> line <> align (prettyExpr 0 body)
    Abs{}
        | k >= 1 -> parens lambdaForm
        | otherwise -> lambdaForm
      where
        (params, body) = collectLambdas expr
        grouped = pretty <$> groupParams params
        lambdaForm = group $ hang 4 $ "λ" <+> align (sep grouped) <+> "=>" <> line <> align (prettyExpr 0 body)
    Let{} ->
        group $
            vsep
                [ "let" <+> align bindings
                , "in" <+> align (prettyExpr 0 body)
                , "end"
                ]
      where
        (binds, body) = collectLets expr
        bindings = sep $ map pretty binds

-- | Pretty print an expression to strict 'Text' using the default layout
-- options.
prettyT :: Expr -> Text
prettyT = renderStrict . layoutSmart defaultLayoutOptions . pretty

-- | Pretty print an expression to a 'String'.
prettyS :: Expr -> String
prettyS = toString . prettyT