{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}

-- NOTE: this module assumes Relude (or an equivalent custom prelude) is the
-- implicit prelude. It supplies Text, on, unwords over Text, the polymorphic
-- show, and toString.
module Expr where

-- | Core expressions. Bound variables carry a display name and a de Bruijn index.
data Expr where
  Var :: Text -> Integer -> Expr
  Free :: Text -> Expr
  Level :: Integer -> Expr
  App :: Expr -> Expr -> Expr
  Abs :: Text -> Expr -> Expr -> Expr
  Pi :: Text -> Expr -> Expr -> Expr
  Let :: Text -> Maybe Expr -> Expr -> Expr -> Expr
  deriving (Show, Ord)

-- | Equality up to alpha-equivalence: binder names and the display names of
-- bound variables are ignored, as are the optional type annotations on Let.
instance Eq Expr where
  (Var _ n) == (Var _ m) = n == m
  (Free s) == (Free t) = s == t
  (Level i) == (Level j) = i == j
  (App e1 e2) == (App f1 f2) = e1 == f1 && e2 == f2
  (Abs _ t1 b1) == (Abs _ t2 b2) = t1 == t2 && b1 == b2
  (Pi _ t1 b1) == (Pi _ t2 b2) = t1 == t2 && b1 == b2
  (Let _ _ v1 b1) == (Let _ _ v2 b2) = v1 == v2 && b1 == b2
  _ == _ = False

-- | Does the variable with de Bruijn index n occur free in the expression?
occursFree :: Integer -> Expr -> Bool
occursFree n (Var _ k) = n == k
occursFree _ (Free _) = False
occursFree _ (Level _) = False
occursFree n (App a b) = on (||) (occursFree n) a b
occursFree n (Abs _ a b) = occursFree n a || occursFree (n + 1) b
occursFree n (Pi _ a b) = occursFree n a || occursFree (n + 1) b
occursFree n (Let _ _ v b) = occursFree n v || occursFree (n + 1) b

-- | Add d to every variable index that is at least the cutoff c.
-- The cutoff grows by one under each binder.
shiftIndices :: Integer -> Integer -> Expr -> Expr
shiftIndices d c (Var x k)
  | k >= c = Var x (k + d)
  | otherwise = Var x k
shiftIndices _ _ (Free s) = Free s
shiftIndices _ _ (Level i) = Level i
shiftIndices d c (App m n) = App (shiftIndices d c m) (shiftIndices d c n)
shiftIndices d c (Abs x m n) = Abs x (shiftIndices d c m) (shiftIndices d (c + 1) n)
shiftIndices d c (Pi x m n) = Pi x (shiftIndices d c m) (shiftIndices d (c + 1) n)
shiftIndices d c (Let x t v b) = Let x t (shiftIndices d c v) (shiftIndices d (c + 1) b)

-- | Shift every free index up by one.
incIndices :: Expr -> Expr
incIndices = shiftIndices 1 0

{- --------------------- PRETTY PRINTING ----------------------------- -}

parenthesize :: Text -> Text
parenthesize s = "(" <> s <> ")"

type Param = (Text, Expr)

type ParamGroup = ([Text], Expr)

type Binding = (Text, [ParamGroup], Expr)

-- | Peel off a run of nested lambdas, returning their parameters and the body.
collectLambdas :: Expr -> ([Param], Expr)
collectLambdas (Abs x ty body) = ((x, ty) : params, final)
  where
    (params, final) = collectLambdas body
collectLambdas e = ([], e)

-- | Peel off a run of nested lets. Lambdas at the head of each bound value
-- are turned into grouped parameters of the binding.
collectLets :: Expr -> ([Binding], Expr)
collectLets (Let x _ val body) = ((x, params', val') : bindings, final)
  where
    (bindings, final) = collectLets body
    (params, val') = collectLambdas val
    params' = groupParams params
collectLets e = ([], e)

-- | Peel off a run of nested dependent Pis. A Pi with an empty binder name
-- (a plain arrow) stops the run.
collectPis :: Expr -> ([Param], Expr)
collectPis p@(Pi "" _ _) = ([], p)
collectPis (Pi x ty body) = ((x, ty) : params, final)
  where
    (params, final) = collectPis body
collectPis e = ([], e)

-- | Blank out the binder name of every Pi whose body does not use the bound
-- variable, so that it prints as an arrow.
cleanNames :: Expr -> Expr
cleanNames (App m n) = App (cleanNames m) (cleanNames n)
cleanNames (Abs x ty body) = Abs x (cleanNames ty) (cleanNames body)
cleanNames (Pi x ty body)
  | occursFree 0 body = Pi x (cleanNames ty) (cleanNames body)
  -- The domain is cleaned here as well; it was previously passed through
  -- untouched, so arrows nested in the domain kept their binder names.
  | otherwise = Pi "" (cleanNames ty) (cleanNames body)
cleanNames e = e

-- | Merge adjacent parameters that have the same type into one group.
-- The later type lives under one more binder, hence the index shift.
groupParams :: [Param] -> [ParamGroup]
groupParams = foldr addParam []
  where
    addParam :: (Text, Expr) -> [([Text], Expr)] -> [([Text], Expr)]
    addParam (x, t) [] = [([x], t)]
    addParam (x, t) l@((xs, s) : rest)
      | incIndices t == s = (x : xs, t) : rest
      | otherwise = ([x], t) : l

showParamGroup :: ParamGroup -> Text
showParamGroup (ids, ty) = parenthesize $ unwords ids <> " : " <> pretty ty

showBinding :: Binding -> Text
showBinding (ident, params, val) =
  parenthesize $ ident <> " " <> unwords (map showParamGroup params) <> " := " <> pretty val

-- | Precedence-aware printer. The first argument is the precedence of the
-- surrounding context; a result is parenthesized when it binds less tightly.
helper :: Integer -> Expr -> Text
helper _ (Var s _) = s
helper _ (Free s) = s
helper _ (Level i)
  | i == 0 = "*"
  | otherwise = "*" <> show i
helper k (App e1 e2) = if k > 3 then parenthesize res else res
  where
    res = helper 3 e1 <> " " <> helper 4 e2
helper k (Pi "" t1 t2) = if k > 2 then parenthesize res else res
  where
    res = helper 3 t1 <> " -> " <> helper 2 t2
helper k e@(Pi{}) = if k > 2 then parenthesize res else res
  where
    (params, body) = collectPis e
    grouped = showParamGroup <$> groupParams params
    res = "∏ " <> unwords grouped <> " . " <> pretty body
helper k e@(Abs{}) = if k >= 1 then parenthesize res else res
  where
    (params, body) = collectLambdas e
    grouped = showParamGroup <$> groupParams params
    res = "λ " <> unwords grouped <> " . " <> pretty body
helper _ e@(Let{}) = res
  where
    (binds, body) = collectLets e
    res = "let " <> unwords (map showBinding binds) <> " in " <> pretty body <> " end"

pretty :: Expr -> Text
pretty = helper 0 . cleanNames

prettyS :: Expr -> String
prettyS = toString . pretty