module Expr where

import qualified Data.List.NonEmpty as NE
import Prettyprinter
import Prettyprinter.Render.Text
import Prelude hiding (group)

-- | Core expressions. Bound variables use de Bruijn indices (index 0 is the
-- innermost enclosing binder).
--
-- The 'Text' stored in binders and in 'Var' is a display name only: the 'Eq'
-- and 'Ord' instances ignore it (and also ignore a 'Let' type ascription), so
-- terms that differ only in names compare equal.
data Expr where
  -- | Bound variable: display name and de Bruijn index.
  Var :: Text -> Integer -> Expr
  -- | Free variable, identified by its name.
  Free :: Text -> Expr
  -- | Axiom, identified by its name.
  Axiom :: Text -> Expr
  -- | Printed as @★@.
  Star :: Expr
  -- | Printed as @□@ followed by the level as a subscript (none for level 0).
  Level :: Integer -> Expr
  -- | Application.
  App :: Expr -> Expr -> Expr
  -- | Lambda: binder name, parameter type, body (the body is under the binder).
  Abs :: Text -> Expr -> Expr -> Expr
  -- | Pi type: binder name, domain, codomain (the codomain is under the binder).
  Pi :: Text -> Expr -> Expr -> Expr
  -- | Let: binder name, optional ascription, value, body. Only the body is
  -- under the binder; the ascription and value live in the outer scope.
  Let :: Text -> Maybe Expr -> Expr -> Expr -> Expr
  -- | Printed as @(A) × (B)@.
  Prod :: Expr -> Expr -> Expr
  -- | Printed as @(a, b)@.
  Pair :: Expr -> Expr -> Expr
  -- | Printed as @(π₁ (e))@.
  Pi1 :: Expr -> Expr
  -- | Printed as @(π₂ (e))@.
  Pi2 :: Expr -> Expr
  deriving (Show)

-- | Renders an expression after disambiguating shadowed names
-- ('dedupNames') and turning non-dependent 'Pi's into arrows ('cleanNames').
instance Pretty Expr where
  pretty = prettyExpr 0 . cleanNames . dedupNames

-- | Structural equality that ignores display names and 'Let' ascriptions.
instance Eq Expr where
  (Var _ n) == (Var _ m) = n == m
  (Free s) == (Free t) = s == t
  (Axiom s) == (Axiom t) = s == t
  Star == Star = True
  (Level i) == (Level j) = i == j
  (App e1 e2) == (App f1 f2) = e1 == f1 && e2 == f2
  (Abs _ t1 b1) == (Abs _ t2 b2) = t1 == t2 && b1 == b2
  (Pi _ t1 b1) == (Pi _ t2 b2) = t1 == t2 && b1 == b2
  (Let _ _ v1 b1) == (Let _ _ v2 b2) = v1 == v2 && b1 == b2
  (Prod x1 y1) == (Prod x2 y2) = x1 == x2 && y1 == y2
  (Pair x1 y1) == (Pair x2 y2) = x1 == x2 && y1 == y2
  (Pi1 x) == (Pi1 y) = x == y
  (Pi2 x) == (Pi2 y) = x == y
  _ == _ = False

-- | Ordering consistent with the 'Eq' instance: display names and 'Let'
-- ascriptions are ignored, so @compare a b == EQ@ exactly when @a == b@.
-- A derived instance would compare the names too and break that law.
-- Different constructors are ordered by declaration order, as a derived
-- instance would order them.
instance Ord Expr where
  compare (Var _ n) (Var _ m) = compare n m
  compare (Free s) (Free t) = compare s t
  compare (Axiom s) (Axiom t) = compare s t
  compare Star Star = EQ
  compare (Level i) (Level j) = compare i j
  compare (App e1 e2) (App f1 f2) = compare e1 f1 <> compare e2 f2
  compare (Abs _ t1 b1) (Abs _ t2 b2) = compare t1 t2 <> compare b1 b2
  compare (Pi _ t1 b1) (Pi _ t2 b2) = compare t1 t2 <> compare b1 b2
  compare (Let _ _ v1 b1) (Let _ _ v2 b2) = compare v1 v2 <> compare b1 b2
  compare (Prod x1 y1) (Prod x2 y2) = compare x1 x2 <> compare y1 y2
  compare (Pair x1 y1) (Pair x2 y2) = compare x1 x2 <> compare y1 y2
  compare (Pi1 x) (Pi1 y) = compare x y
  compare (Pi2 x) (Pi2 y) = compare x y
  compare e f = compare (tag e) (tag f)
    where
      -- Constructor position in the declaration; only reached when the
      -- constructors differ.
      tag :: Expr -> Int
      tag Var{} = 0
      tag Free{} = 1
      tag Axiom{} = 2
      tag Star = 3
      tag Level{} = 4
      tag App{} = 5
      tag Abs{} = 6
      tag Pi{} = 7
      tag Let{} = 8
      tag Prod{} = 9
      tag Pair{} = 10
      tag Pi1{} = 11
      tag Pi2{} = 12

-- | @occursFree n e@ is 'True' when index @n@ (counted from the top of @e@)
-- occurs in @e@. The index is incremented when descending under a binder.
-- A 'Let' ascription is checked in the same scope as the bound value.
occursFree :: Integer -> Expr -> Bool
occursFree n (Var _ k) = n == k
occursFree _ (Free _) = False
occursFree _ (Axiom _) = False
occursFree _ Star = False
occursFree _ (Level _) = False
occursFree n (App a b) = occursFree n a || occursFree n b
occursFree n (Abs _ a b) = occursFree n a || occursFree (n + 1) b
occursFree n (Pi _ a b) = occursFree n a || occursFree (n + 1) b
occursFree n (Let _ t v b) =
  maybe False (occursFree n) t || occursFree n v || occursFree (n + 1) b
occursFree n (Prod x y) = occursFree n x || occursFree n y
occursFree n (Pair x y) = occursFree n x || occursFree n y
occursFree n (Pi1 x) = occursFree n x
occursFree n (Pi2 x) = occursFree n x

-- | @shiftIndices d c e@ adds @d@ to every variable index @>= c@ in @e@.
-- The cutoff @c@ grows by one under each binder, so variables bound inside
-- @e@ are left alone. A 'Let' ascription is shifted with the same cutoff
-- as the bound value, since both live outside the let binder.
shiftIndices :: Integer -> Integer -> Expr -> Expr
shiftIndices d c (Var x k)
  | k >= c = Var x (k + d)
  | otherwise = Var x k
shiftIndices _ _ (Free s) = Free s
shiftIndices _ _ (Axiom s) = Axiom s
shiftIndices _ _ Star = Star
shiftIndices _ _ (Level i) = Level i
shiftIndices d c (App m n) = App (shiftIndices d c m) (shiftIndices d c n)
shiftIndices d c (Abs x m n) = Abs x (shiftIndices d c m) (shiftIndices d (c + 1) n)
shiftIndices d c (Pi x m n) = Pi x (shiftIndices d c m) (shiftIndices d (c + 1) n)
shiftIndices d c (Let x t v b) =
  Let x (shiftIndices d c <$> t) (shiftIndices d c v) (shiftIndices d (c + 1) b)
shiftIndices d c (Prod m n) = Prod (shiftIndices d c m) (shiftIndices d c n)
shiftIndices d c (Pair m n) = Pair (shiftIndices d c m) (shiftIndices d c n)
shiftIndices d c (Pi1 x) = Pi1 (shiftIndices d c x)
shiftIndices d c (Pi2 x) = Pi2 (shiftIndices d c x)

-- | Shifts every free index of an expression up by one.
incIndices :: Expr -> Expr
incIndices = shiftIndices 1 0

{- --------------------- PRETTY PRINTING ----------------------------- -}

-- | Appends subscripts to names so that shadowing is visible. A binder that
-- shadows @c@ enclosing binders with the same name gets subscript @c@ (none
-- when @c == 0@). A variable gets the subscript computed for the binder its
-- index points at. Empty names are left untouched. The traversal covers
-- every constructor, so variables inside pairs and projections are
-- disambiguated too.
dedupNames :: Expr -> Expr
dedupNames = go []
  where
    -- bs holds the original binder names, innermost first.
    varName :: [Text] -> Text -> Int -> Text
    varName bs x k
      | x == "" = x
      | otherwise = case bs !!? k of
          Nothing -> x
          Just x' ->
            let count = length (filter (== x') (drop (k + 1) bs))
             in if count > 0 then x <> printLevel count else x

    go :: [Text] -> Expr -> Expr
    go bs (Var x k) = Var (varName bs x (fromIntegral k)) k
    go _ e@(Free _) = e
    go _ e@(Axiom _) = e
    go _ Star = Star
    go _ e@(Level _) = e
    go bs (App m n) = App (go bs m) (go bs n)
    go bs (Abs x ty b) = Abs (varName (x : bs) x 0) (go bs ty) (go (x : bs) b)
    go bs (Pi x ty b) = Pi (varName (x : bs) x 0) (go bs ty) (go (x : bs) b)
    go bs (Let x ascr val b) =
      Let (varName (x : bs) x 0) (go bs <$> ascr) (go bs val) (go (x : bs) b)
    go bs (Prod m n) = Prod (go bs m) (go bs n)
    go bs (Pair m n) = Pair (go bs m) (go bs n)
    go bs (Pi1 m) = Pi1 (go bs m)
    go bs (Pi2 m) = Pi2 (go bs m)

-- | A single parameter, printed as @(x : A)@.
data Param = Param Text Expr

-- | Consecutive parameters sharing one type, printed as @(x y : A)@.
data ParamGroup = ParamGroup [Text] Expr

-- | A let binding: name, lambda parameters peeled off the value, and the
-- remaining value.
data Binding = Binding Text [ParamGroup] Expr

-- The contained expressions are printed with 'prettyExpr' directly rather
-- than 'pretty'. 'pretty' would re-run 'dedupNames' on the subterm with no
-- enclosing binders and re-traverse it at every nesting level.
instance Pretty Param where
  pretty (Param x ty) = group $ parens $ pretty x <+> ":" <+> prettyExpr 0 ty

instance Pretty ParamGroup where
  pretty (ParamGroup vars ty) =
    group $ parens $ align (sep (map pretty vars)) <+> ":" <+> prettyExpr 0 ty

instance Pretty Binding where
  pretty (Binding var [] body) =
    group $ parens $ pretty var <+> ":=" <+> prettyExpr 0 body
  pretty (Binding var params body) =
    group $
      parens $
        pretty var <+> align (sep (map pretty params)) <+> ":=" <+> prettyExpr 0 body

-- | Peels consecutive lambdas, returning their parameters (outermost first)
-- and the innermost body.
collectLambdas :: Expr -> ([Param], Expr)
collectLambdas (Abs x ty body) = (Param x ty : params, final)
  where
    (params, final) = collectLambdas body
collectLambdas e = ([], e)

-- | Peels consecutive lets into bindings (outermost first) plus the final
-- body. Any lambdas at the head of a bound value become that binding's
-- parameters.
collectLets :: Expr -> ([Binding], Expr)
collectLets (Let x _ val body) = (Binding x params' val' : bindings, final)
  where
    (bindings, final) = collectLets body
    (params, val') = collectLambdas val
    params' = groupParams params
collectLets e = ([], e)

-- | Peels consecutive named 'Pi's. It stops at an unnamed 'Pi', which is
-- printed as an arrow.
collectPis :: Expr -> ([Param], Expr)
collectPis p@(Pi "" _ _) = ([], p)
collectPis (Pi x ty body) = (Param x ty : params, final)
  where
    (params, final) = collectPis body
collectPis e = ([], e)

-- | Flattens a right-nested chain of unnamed 'Pi's (arrows) into its
-- components, left to right.
collectArrows :: Expr -> NonEmpty Expr
collectArrows (Pi "" t1 t2) = t1 :| toList rest
  where
    rest = collectArrows t2
collectArrows e = pure e

-- | Flattens a left-nested application spine, last argument first; the head
-- is the final element.
collectApps :: Expr -> NonEmpty Expr
collectApps (App e1 e2) = e2 :| toList rest
  where
    rest = collectApps e1
collectApps e = pure e

-- | Erases the binder name of every 'Pi' whose bound variable does not occur
-- in its codomain, so it is printed as an arrow. The traversal descends
-- into every subterm, including the domains of such 'Pi's and let
-- ascriptions, values and bodies.
cleanNames :: Expr -> Expr
cleanNames (App m n) = App (cleanNames m) (cleanNames n)
cleanNames (Abs x ty body) = Abs x (cleanNames ty) (cleanNames body)
cleanNames (Pi x ty body)
  | occursFree 0 body = Pi x (cleanNames ty) (cleanNames body)
  | otherwise = Pi "" (cleanNames ty) (cleanNames body)
cleanNames (Let x ascr val body) =
  Let x (cleanNames <$> ascr) (cleanNames val) (cleanNames body)
cleanNames (Prod m n) = Prod (cleanNames m) (cleanNames n)
cleanNames (Pair m n) = Pair (cleanNames m) (cleanNames n)
cleanNames (Pi1 m) = Pi1 (cleanNames m)
cleanNames (Pi2 m) = Pi2 (cleanNames m)
cleanNames e = e

-- | Merges adjacent parameters into one group when the later parameter's
-- type equals the earlier one's type shifted under one binder. That is, the
-- later type is the same type and does not mention the earlier parameter.
groupParams :: [Param] -> [ParamGroup]
groupParams = foldr addParam []
  where
    addParam :: Param -> [ParamGroup] -> [ParamGroup]
    addParam (Param x t) [] = [ParamGroup [x] t]
    addParam (Param x t) l@(ParamGroup xs s : rest)
      | incIndices t == s = ParamGroup (x : xs) t : rest
      | otherwise = ParamGroup [x] t : l

-- | Renders an integer with Unicode subscript digits. Negative numbers get
-- a subscript minus (@₋@). Values are converted to 'Integer' first so
-- negation cannot overflow for bounded types.
printLevel :: (IsString s, Semigroup s, Integral i) => i -> s
printLevel = go . toInteger
  where
    go n
      | n < 0 = "₋" <> go (negate n)
      | n >= 10 = go (n `div` 10) <> go (n `mod` 10)
      | otherwise = digit n
    digit 0 = "₀"
    digit 1 = "₁"
    digit 2 = "₂"
    digit 3 = "₃"
    digit 4 = "₄"
    digit 5 = "₅"
    digit 6 = "₆"
    digit 7 = "₇"
    digit 8 = "₈"
    digit _ = "₉"

-- | Pretty-prints an expression at precedence @k@.
--
-- @0@ is the top level. Application arguments are printed at 4,
-- application heads and arrow domains at 3, and arrow codomains at 2.
-- Lambdas and named 'Pi's extend as far right as possible, so they are
-- parenthesised at any @k >= 1@.
--
-- This function does not disambiguate names; use the 'Pretty' instance
-- for that.
prettyExpr :: Integer -> Expr -> Doc ann
prettyExpr k expr = case expr of
  Var s _ -> pretty s
  Free s -> pretty s
  Axiom s -> pretty s
  Star -> "★"
  Level i
    | i == 0 -> "□"
    | otherwise -> "□" <> printLevel i
  App{}
    | k > 3 -> parens application
    | otherwise -> application
    where
      (top :| rest) = NE.reverse (collectApps expr)
      application =
        group $ hang 4 $ prettyExpr 3 top <> line <> align (sep (map (prettyExpr 4) rest))
  Pi "" _ _
    | k > 2 -> parens piType
    | otherwise -> piType
    where
      (top :| rest) = collectArrows expr
      piType =
        group $ hang 4 $ prettyExpr 3 top <+> align (sep (map (("->" <+>) . prettyExpr 2) rest))
  Pi{}
    | k >= 1 -> parens piForm
    | otherwise -> piForm
    where
      (params, body) = collectPis expr
      grouped = pretty <$> groupParams params
      piForm =
        group $ hang 4 $ "∏" <+> align (sep grouped) <> "," <> line <> align (prettyExpr 0 body)
  Abs{}
    | k >= 1 -> parens lambdaForm
    | otherwise -> lambdaForm
    where
      (params, body) = collectLambdas expr
      grouped = pretty <$> groupParams params
      lambdaForm =
        group $ hang 4 $ "λ" <+> align (sep grouped) <+> "=>" <> line <> align (prettyExpr 0 body)
  Let{} ->
    group $
      vsep
        [ "let" <+> align bindings
        , "in" <+> align (prettyExpr 0 body)
        , "end"
        ]
    where
      (binds, body) = collectLets expr
      bindings = sep (map pretty binds)
  -- Subterms go through prettyExpr, not pretty, so names already
  -- disambiguated against the enclosing binders are kept.
  Prod x y -> parens $ parens (prettyExpr 0 x) <+> "×" <+> parens (prettyExpr 0 y)
  Pair x y -> parens $ prettyExpr 0 x <> "," <+> prettyExpr 0 y
  Pi1 x -> parens $ "π₁" <+> parens (prettyExpr 0 x)
  Pi2 x -> parens $ "π₂" <+> parens (prettyExpr 0 x)

-- | Renders to strict 'Text' with the default smart layout.
prettyT :: Expr -> Text
prettyT = renderStrict . layoutSmart defaultLayoutOptions . pretty

-- | Renders to a 'String'.
prettyS :: Expr -> String
prettyS = toString . prettyT