{-# LANGUAGE GADTs #-}

-- | Expressions of a dependently typed lambda calculus with the sorts @*@
-- and @□@, lambda abstraction and dependent products. Variables are
-- de Bruijn indices; names are carried along only for pretty printing.
module Expr where

import Data.Function (on)

-- | Core expression syntax.
--
-- A variable's 'Integer' is its de Bruijn index: 0 refers to the innermost
-- enclosing 'Abs' / 'Pi'. The body (third field) of each binder is one
-- binder deeper than its type annotation (second field).
--
-- The derived 'Eq' compares names as well as structure; use 'alphaEquiv'
-- to compare terms up to renaming.
data Expr where
  -- | Variable: de Bruijn index and display name.
  Var :: Integer -> String -> Expr
  -- | The sort @*@.
  Star :: Expr
  -- | The sort @□@.
  Square :: Expr
  -- | Application.
  App :: Expr -> Expr -> Expr
  -- | @Abs x ty body@: lambda binding @x : ty@ in @body@.
  Abs :: String -> Expr -> Expr -> Expr
  -- | @Pi x ty body@: dependent product binding @x : ty@ in @body@.
  -- The empty name marks a non-dependent arrow (see '.->').
  Pi :: String -> Expr -> Expr -> Expr
  deriving (Show, Eq)

infixl 4 <.>

-- | Infix synonym for 'App'.
(<.>) :: Expr -> Expr -> Expr
(<.>) = App

infixr 2 .->

-- | Non-dependent function type @a -> b@.
--
-- @b@ is written in the outer context, so its free indices are shifted
-- past the new anonymous binder. Variables bound inside @b@ are untouched.
(.->) :: Expr -> Expr -> Expr
a .-> b = Pi "" a (incIndices b)

-- | @occursFree n e@ is 'True' when de Bruijn index @n@ (relative to the
-- top of @e@) occurs in @e@. The index is incremented when descending into
-- a binder's body.
occursFree :: Integer -> Expr -> Bool
occursFree n (Var k _) = n == k
occursFree _ Star = False
occursFree _ Square = False
occursFree n (App a b) = on (||) (occursFree n) a b
occursFree n (Abs _ a b) = occursFree n a || occursFree (n + 1) b
occursFree n (Pi _ a b) = occursFree n a || occursFree (n + 1) b

{- --------------------- PRETTY PRINTING ----------------------------- -}

-- | Wrap a string in parentheses.
parenthesize :: String -> String
parenthesize s = "(" ++ s ++ ")"

-- | Peel off consecutive lambdas.
--
-- Returns their @(name, type)@ telescope and the innermost body. Each type
-- lives in the context extended by the binders before it.
collectLambdas :: Expr -> ([(String, Expr)], Expr)
collectLambdas (Abs x ty body) = ((x, ty) : params, final)
  where
    (params, final) = collectLambdas body
collectLambdas e = ([], e)

-- | Peel off consecutive named 'Pi's, like 'collectLambdas'.
--
-- Stops at an anonymous @Pi ""@ so that it is printed as an arrow instead.
collectPis :: Expr -> ([(String, Expr)], Expr)
collectPis p@(Pi "" _ _) = ([], p)
collectPis (Pi x ty body) = ((x, ty) : params, final)
  where
    (params, final) = collectPis body
collectPis e = ([], e)

-- | Merge adjacent telescope entries that have the same type into one group,
-- e.g. @(x : a) (y : a)@ becomes @(x y : a)@.
--
-- Each entry's type is one binder deeper than the previous entry's. So a
-- type @t@ written for @x@ appears as @incIndices t@ in the following entry.
-- A group's stored type is the type of its first name, in that name's context.
groupParams :: [(String, Expr)] -> [([String], Expr)]
groupParams = foldr addParam []
  where
    addParam :: (String, Expr) -> [([String], Expr)] -> [([String], Expr)]
    addParam (x, t) ((xs, s) : rest)
      -- s lives under x's binder, so compare it with t shifted past x.
      | s == incIndices t = (x : xs, t) : rest
    addParam (x, t) groups = ([x], t) : groups

-- | Render one parameter group as @(x y : ty)@.
showParamGroup :: ([String], Expr) -> String
showParamGroup (ids, ty) = parenthesize $ unwords ids ++ " : " ++ pretty ty

-- | Precedence-aware printer. The 'Integer' is the context precedence:
--
-- * 0: top level or a binder body
-- * 2: right of an arrow
-- * 3: left of an arrow, or the function in an application
-- * 4: the argument of an application
--
-- Applications are parenthesized above 3, arrows and products above 2, and
-- lambdas at 1 or more.
helper :: Integer -> Expr -> String
helper _ (Var _ s) = s
helper _ Star = "*"
helper _ Square = "□"
helper k (App e1 e2) = if k > 3 then parenthesize res else res
  where
    res = helper 3 e1 ++ " " ++ helper 4 e2
helper k (Pi "" t1 t2) = if k > 2 then parenthesize res else res
  where
    res = helper 3 t1 ++ " -> " ++ helper 2 t2
helper k e@(Pi {}) = if k > 2 then parenthesize res else res
  where
    (params, body) = collectPis e
    grouped = showParamGroup <$> groupParams params
    res = "∏ " ++ unwords grouped ++ " . " ++ pretty body
helper k e@(Abs {}) = if k >= 1 then parenthesize res else res
  where
    (params, body) = collectLambdas e
    grouped = showParamGroup <$> groupParams params
    res = "λ " ++ unwords grouped ++ " . " ++ pretty body

-- | Render an expression at top-level precedence.
pretty :: Expr -> String
pretty = helper 0

{- --------------- ACTUAL MATH STUFF ---------------- -}

-- | Is the expression one of the sorts @*@ or @□@?
isSort :: Expr -> Bool
isSort Star = True
isSort Square = True
isSort _ = False

-- | @shiftIndices d c e@ adds @d@ to every variable index in @e@ that is at
-- least @c@, i.e. every variable that is free relative to cutoff @c@.
--
-- The cutoff grows by one under each binder body, so variables bound inside
-- @e@ are left alone.
shiftIndices :: Integer -> Integer -> Expr -> Expr
shiftIndices d = go
  where
    go c (Var k x)
      | k >= c = Var (k + d) x
      | otherwise = Var k x
    go _ Star = Star
    go _ Square = Square
    go c (App m n) = App (go c m) (go c n)
    go c (Abs x t b) = Abs x (go c t) (go (c + 1) b)
    go c (Pi x t b) = Pi x (go c t) (go (c + 1) b)

-- | Shift every free index up by one, e.g. when moving a term under one new
-- binder. Bound variables are unaffected.
incIndices :: Expr -> Expr
incIndices = shiftIndices 1 0

-- | @subst s body@ replaces the variable bound by the redex's binder (index 0
-- at the top of @body@) with @s@. It also decrements the remaining free
-- indices, since that binder disappears. Only use this after reducing a redex.
--
-- At binder depth @j@:
--
-- * index @j@ is the substituted variable; @s@ is shifted by @j@ to avoid
--   capture;
-- * indices above @j@ are outer variables and drop by one;
-- * indices below @j@ are locally bound and unchanged.
subst :: Expr -> Expr -> Expr
subst s = go 0
  where
    go j (Var k x)
      | k == j = shiftIndices j 0 s
      | k > j = Var (k - 1) x
      | otherwise = Var k x
    go _ Star = Star
    go _ Square = Square
    go j (App m n) = App (go j m) (go j n)
    go j (Abs x t b) = Abs x (go j t) (go (j + 1) b)
    go j (Pi x t b) = Pi x (go j t) (go (j + 1) b)

-- | Same as 'subst'.
--
-- NOTE(review): this was a byte-for-byte copy of 'subst'. The name hints at a
-- non-decrementing variant, but the original also decremented. Confirm the
-- intended semantics with its callers.
substnd :: Expr -> Expr -> Expr
substnd = subst

-- | One pass of beta reduction.
--
-- A redex @App (Abs _ _ v) n@ at the visited position is contracted; its
-- result is not reduced further in the same pass. Every other node recurses
-- into all of its subterms.
betaReduce :: Expr -> Expr
betaReduce (Var k s) = Var k s
betaReduce Star = Star
betaReduce Square = Square
betaReduce (App (Abs _ _ v) n) = subst n v
betaReduce (App m n) = App (betaReduce m) (betaReduce n)
betaReduce (Abs x t v) = Abs x (betaReduce t) (betaReduce v)
betaReduce (Pi x t v) = Pi x (betaReduce t) (betaReduce v)

-- | Iterate 'betaReduce' until the term stops changing.
--
-- This may not terminate on terms without a beta normal form.
betaNF :: Expr -> Expr
betaNF e = if e == e' then e else betaNF e'
  where
    e' = betaReduce e

-- | Structural equality that ignores binder and variable names.
--
-- With de Bruijn indices this is exactly alpha-equivalence.
alphaEquiv :: Expr -> Expr -> Bool
alphaEquiv (Var j _) (Var k _) = j == k
alphaEquiv Star Star = True
alphaEquiv Square Square = True
alphaEquiv (App f a) (App g b) = alphaEquiv f g && alphaEquiv a b
alphaEquiv (Abs _ t b) (Abs _ u c) = alphaEquiv t u && alphaEquiv b c
alphaEquiv (Pi _ t b) (Pi _ u c) = alphaEquiv t u && alphaEquiv b c
alphaEquiv _ _ = False

-- | Beta-equivalence: the beta normal forms agree up to alpha-renaming.
betaEquiv :: Expr -> Expr -> Bool
betaEquiv = alphaEquiv `on` betaNF