{- perga/lib/Expr.hs
   2024-12-13 22:45:37 -08:00
   214 lines · 7.1 KiB · Haskell
-}

module Expr where
import qualified Data.List.NonEmpty as NE
import Prettyprinter
import Prettyprinter.Render.Text
import Prelude hiding (group)
-- | Terms of the language. Bound variables use de Bruijn indices: the
-- 'Integer' in a 'Var' counts enclosing binders, and 'Abs', 'Pi' and 'Let'
-- each bind one variable in their body. The 'Text' on variables and binders
-- is a display name only; the 'Eq' and 'Ord' instances ignore it.
data Expr where
  -- | Bound variable: display name and de Bruijn index.
  Var :: Text -> Integer -> Expr
  -- | Reference by name (NOTE(review): presumably a global definition —
  --   confirm against the checker).
  Free :: Text -> Expr
  -- | Named axiom (NOTE(review): presumably a postulated constant — confirm).
  Axiom :: Text -> Expr
  Star :: Expr
  -- | Integer-indexed level; 'prettyExpr' prints index 0 specially.
  Level :: Integer -> Expr
  -- | Application of a function to one argument.
  App :: Expr -> Expr -> Expr
  -- | Lambda: binder name, parameter type, body (under the binder).
  Abs :: Text -> Expr -> Expr -> Expr
  -- | Pi type: binder name, domain, codomain (under the binder). An empty
  --   name is printed as a non-dependent arrow.
  Pi :: Text -> Expr -> Expr -> Expr
  -- | @let@: name, optional type ascription, bound value, body (under the
  --   binder). The ascription and the value live in the outer scope.
  Let :: Text -> Maybe Expr -> Expr -> Expr -> Expr
  deriving (Show)

-- | Total order consistent with the hand-written 'Eq' instance. Display
-- names and 'Let' ascriptions are ignored, so @compare a b == EQ@ exactly
-- when @a == b@. A derived instance compared names and broke that law.
-- Different constructors are ordered by declaration order, as a derived
-- instance would order them.
instance Ord Expr where
  compare (Var _ n) (Var _ m) = compare n m
  compare (Free s) (Free t) = compare s t
  compare (Axiom s) (Axiom t) = compare s t
  compare Star Star = EQ
  compare (Level i) (Level j) = compare i j
  compare (App e1 e2) (App f1 f2) = compare e1 f1 <> compare e2 f2
  compare (Abs _ t1 b1) (Abs _ t2 b2) = compare t1 t2 <> compare b1 b2
  compare (Pi _ t1 b1) (Pi _ t2 b2) = compare t1 t2 <> compare b1 b2
  compare (Let _ _ v1 b1) (Let _ _ v2 b2) = compare v1 v2 <> compare b1 b2
  -- Only reached for different constructors, so the ranks always differ.
  compare e f = compare (exprConstructorRank e) (exprConstructorRank f)

-- | Position of an 'Expr' constructor in its declaration; orders terms
-- built from different constructors.
exprConstructorRank :: Expr -> Int
exprConstructorRank Var{} = 0
exprConstructorRank Free{} = 1
exprConstructorRank Axiom{} = 2
exprConstructorRank Star = 3
exprConstructorRank Level{} = 4
exprConstructorRank App{} = 5
exprConstructorRank Abs{} = 6
exprConstructorRank Pi{} = 7
exprConstructorRank Let{} = 8
-- | Display a term in three steps:
--
-- 1. Suffix display names that collide with an outer binder's name.
-- 2. Blank out the names of Pi binders whose variable is unused.
-- 3. Print the result at the loosest precedence.
instance Pretty Expr where
  pretty expr = prettyExpr 0 (cleanNames (dedupNames expr))
-- | Structural equality up to naming. Display names on variables and
-- binders are ignored, so only de Bruijn indices matter. A 'Let' is
-- compared by its bound value and body only; its type ascription is not
-- inspected.
instance Eq Expr where
  lhs == rhs = case (lhs, rhs) of
    (Var _ i, Var _ j) -> i == j
    (Free a, Free b) -> a == b
    (Axiom a, Axiom b) -> a == b
    (Star, Star) -> True
    (Level a, Level b) -> a == b
    (App f x, App g y) -> f == g && x == y
    (Abs _ s body, Abs _ t body') -> s == t && body == body'
    (Pi _ s body, Pi _ t body') -> s == t && body == body'
    (Let _ _ v body, Let _ _ w body') -> v == w && body == body'
    _ -> False
-- | @occursFree n e@ is 'True' when de Bruijn index @n@ is referenced
-- anywhere in @e@. The index is counted relative to the scope @e@ appears
-- in, and it is bumped by one each time the search goes under a binder.
--
-- A 'Let' ascription lives in the same scope as the bound value, not under
-- the new binder. It is therefore searched at the unshifted index, just
-- like the value.
occursFree :: Integer -> Expr -> Bool
occursFree n (Var _ k) = n == k
occursFree _ (Free _) = False
occursFree _ (Axiom _) = False
occursFree _ Star = False
occursFree _ (Level _) = False
occursFree n (App a b) = occursFree n a || occursFree n b
occursFree n (Abs _ a b) = occursFree n a || occursFree (n + 1) b
occursFree n (Pi _ a b) = occursFree n a || occursFree (n + 1) b
occursFree n (Let _ t v b) =
  maybe False (occursFree n) t || occursFree n v || occursFree (n + 1) b
-- | @shiftIndices d c e@ adds @d@ to every variable index in @e@ that is at
-- least the cutoff @c@. Those are the variables not bound by the @c@
-- innermost binders already passed. The cutoff grows by one under each
-- binder.
--
-- The optional 'Let' ascription is in the same scope as the bound value.
-- It is shifted with the same cutoff; otherwise its indices would go stale
-- relative to the rest of the term.
shiftIndices :: Integer -> Integer -> Expr -> Expr
shiftIndices d c (Var x k)
  | k >= c = Var x (k + d)
  | otherwise = Var x k
shiftIndices _ _ (Free s) = Free s
shiftIndices _ _ (Axiom s) = Axiom s
shiftIndices _ _ Star = Star
shiftIndices _ _ (Level i) = Level i
shiftIndices d c (App m n) = App (shiftIndices d c m) (shiftIndices d c n)
shiftIndices d c (Abs x m n) = Abs x (shiftIndices d c m) (shiftIndices d (c + 1) n)
shiftIndices d c (Pi x m n) = Pi x (shiftIndices d c m) (shiftIndices d (c + 1) n)
shiftIndices d c (Let x t v b) =
  Let x (shiftIndices d c <$> t) (shiftIndices d c v) (shiftIndices d (c + 1) b)
-- | Bump every free variable of a term by one. This is what is needed to
-- compare the term against one that sits under one more binder.
incIndices :: Expr -> Expr
incIndices expr = shiftIndices 1 0 expr
{- --------------------- PRETTY PRINTING ----------------------------- -}
-- | Make display names unambiguous when a binder reuses an outer binder's
-- name.
--
-- The walk keeps a stack of binder names in scope, innermost first. A
-- binder, or a variable referring to it, gets @printLevel c@ appended to
-- its name when @c > 0@ further-out binders carry the same name. Empty
-- names are left untouched, and indices and term structure are unchanged.
dedupNames :: Expr -> Expr
dedupNames = rename []
  where
    -- Display name for something called @name@ whose binder sits at
    -- position @idx@ of @scope@.
    suffixed :: [Text] -> Text -> Int -> Text
    suffixed scope name idx
      | name == "" = name
      | otherwise = maybe name withSuffix (scope !!? idx)
      where
        withSuffix binder =
          let shadowed = length [outer | outer <- drop (idx + 1) scope, outer == binder]
           in if shadowed > 0 then name <> printLevel shadowed else name

    rename :: [Text] -> Expr -> Expr
    rename scope expr = case expr of
      Var x k -> Var (suffixed scope x (fromIntegral k)) k
      App m n -> App (rename scope m) (rename scope n)
      Abs x ty body ->
        Abs (suffixed (x : scope) x 0) (rename scope ty) (rename (x : scope) body)
      Pi x ty body ->
        Pi (suffixed (x : scope) x 0) (rename scope ty) (rename (x : scope) body)
      -- The ascription and value are renamed in the outer scope; only the
      -- body sees the new binder.
      Let x ascr val body ->
        Let
          (suffixed (x : scope) x 0)
          (rename scope <$> ascr)
          (rename scope val)
          (rename (x : scope) body)
      _ -> expr
-- | A single lambda or Pi binder: display name and its type.
data Param where
  Param :: Text -> Expr -> Param

-- | Consecutive binders sharing one type, printed as @(x y z : A)@.
data ParamGroup where
  ParamGroup :: [Text] -> Expr -> ParamGroup

-- | One @let@ binding: name, the grouped lambda parameters peeled off the
-- bound value, and the remaining value.
data Binding where
  Binding :: Text -> [ParamGroup] -> Expr -> Binding
-- | Renders a binder as @(x : A)@, kept on one line when it fits.
instance Pretty Param where
  pretty (Param name ty) = group (parens (hsep [pretty name, ":", pretty ty]))
-- | Renders a run of same-typed binders as @(x y : A)@. The names are
-- aligned so they stack vertically if the group has to break.
instance Pretty ParamGroup where
  pretty (ParamGroup names ty) =
    let binders = align (sep (pretty <$> names))
     in group (parens (binders <+> ":" <+> pretty ty))
-- | Renders a let binding as @(f params := body)@. The parameter list is
-- omitted entirely when there are no parameters.
instance Pretty Binding where
  pretty (Binding var params body) = group (parens (lhs <+> ":=" <+> pretty body))
    where
      lhs
        | null params = pretty var
        | otherwise = pretty var <+> align (sep (pretty <$> params))
-- | Peel off the leading lambda binders of a term. Returns them outermost
-- first, together with the remaining body.
collectLambdas :: Expr -> ([Param], Expr)
collectLambdas expr = case expr of
  Abs x ty body ->
    let (params, rest) = collectLambdas body
     in (Param x ty : params, rest)
  _ -> ([], expr)
-- | Peel off a chain of directly nested @let@s.
--
-- Returns one 'Binding' per @let@, outermost first, plus the final body.
-- Lambdas at the head of each bound value become grouped parameters, and
-- type ascriptions are dropped.
collectLets :: Expr -> ([Binding], Expr)
collectLets expr = case expr of
  Let x _ val body ->
    let (lambdaParams, val') = collectLambdas val
        (bindings, rest) = collectLets body
     in (Binding x (groupParams lambdaParams) val' : bindings, rest)
  _ -> ([], expr)
-- | Peel off the leading /named/ Pi binders, outermost first.
--
-- An anonymous Pi (empty name, i.e. a plain arrow) stops the collection
-- and is returned as part of the remaining body.
collectPis :: Expr -> ([Param], Expr)
collectPis expr = case expr of
  Pi x ty body
    | x /= "" ->
        let (params, rest) = collectPis body
         in (Param x ty : params, rest)
  _ -> ([], expr)
-- | Flatten a right-nested chain of anonymous Pi types @A -> B -> C@ into
-- @A :| [B, C]@. Any other term yields a singleton.
collectArrows :: Expr -> NonEmpty Expr
collectArrows (Pi "" dom cod) = dom NE.<| collectArrows cod
collectArrows other = other :| []
-- | Flatten a left-nested application @f a b c@ into its spine in
-- /reverse/ order, @c :| [b, a, f]@; 'prettyExpr' reverses it to put the
-- head first. A non-application yields a singleton.
collectApps :: Expr -> NonEmpty Expr
collectApps (App fun arg) = arg NE.<| collectApps fun
collectApps other = other :| []
-- | Erase the names of Pi binders whose variable is never used in the
-- codomain, so that they print as plain arrows.
--
-- Every subterm is visited, including:
--
-- * the domain of an erased Pi;
-- * the ascription, value and body of a @let@.
--
-- Nested Pi types are therefore cleaned wherever they appear.
cleanNames :: Expr -> Expr
cleanNames (App m n) = App (cleanNames m) (cleanNames n)
cleanNames (Abs x ty body) = Abs x (cleanNames ty) (cleanNames body)
cleanNames (Pi x ty body) = Pi name (cleanNames ty) (cleanNames body)
  where
    -- Index 0 in the codomain is this binder's own variable.
    name = if occursFree 0 body then x else ""
cleanNames (Let x ascr val body) =
  Let x (cleanNames <$> ascr) (cleanNames val) (cleanNames body)
cleanNames e = e
-- | Merge runs of adjacent parameters with the same type into one
-- 'ParamGroup'.
--
-- Each parameter's type sits one binder deeper than the previous
-- parameter's. A parameter therefore joins the group to its right when its
-- own type, shifted up by one, equals that group's type. The merged group
-- keeps the leftmost (outermost-scope) copy of the type.
groupParams :: [Param] -> [ParamGroup]
groupParams [] = []
groupParams (Param x t : params) = case groupParams params of
  ParamGroup xs s : groups
    | incIndices t == s -> ParamGroup (x : xs) t : groups
  groups -> ParamGroup [x] t : groups
-- | Render an integer digit by digit using per-digit glyphs. It is used
-- for the disambiguating suffixes in 'dedupNames' and for 'Level'.
--
-- NOTE(review): every digit glyph below is an empty string in this copy of
-- the file. Presumably these were non-ASCII (e.g. subscript) digits lost in
-- extraction — confirm against upstream.
--
-- Negative inputs get a leading @"-"@ followed by the magnitude. The
-- magnitude is computed as an 'Integer', so even @minBound@ of a bounded
-- type terminates instead of recursing on itself forever.
printLevel :: (IsString s, Semigroup s, Integral i) => i -> s
printLevel k
  | k < 0 = "-" <> printLevel (negate (toInteger k))
  | k == 0 = ""
  | k == 1 = ""
  | k == 2 = ""
  | k == 3 = ""
  | k == 4 = ""
  | k == 5 = ""
  | k == 6 = ""
  | k == 7 = ""
  | k == 8 = ""
  | k == 9 = ""
  | otherwise = printLevel (k `div` 10) <> printLevel (k `mod` 10)
-- | Pretty-print a term in precedence context @k@. A higher @k@ means a
-- tighter surrounding context, so more constructs need parentheses:
--
--   * 0: top level, and bodies of binders and @let@
--   * 2: final codomain of an arrow chain
--   * 3: any arrow domain, and the head of an application
--   * 4: argument of an application
--
-- Parenthesisation rules:
--
--   * Lambdas when @k >= 1@.
--   * Application spines when @k > 3@.
--   * Arrow chains and dependent Pi types when @k > 2@.
--
-- A dependent Pi extends as far right as it can. Without parentheses, a
-- Pi in an arrow domain or an application head would swallow what follows.
prettyExpr :: Integer -> Expr -> Doc ann
prettyExpr k expr = case expr of
  Var s _ -> pretty s
  Free s -> pretty s
  Axiom s -> pretty s
  Star -> ""
  Level i
    | i == 0 -> ""
    | otherwise -> "" <> printLevel i
  App{}
    | k > 3 -> parens application
    | otherwise -> application
    where
      -- collectApps lists the spine argument-first; reverse to get the head.
      (top :| rest) = NE.reverse $ collectApps expr
      application = group $ hang 4 $ prettyExpr 3 top <> line <> align (sep $ map (prettyExpr 4) rest)
  Pi "" _ _
    | k > 2 -> parens piType
    | otherwise -> piType
    where
      (top :| rest) = collectArrows expr
      -- Every element except the last is the domain of some arrow, so it is
      -- printed at 3. Only the final codomain may stay unparenthesised.
      operands [] = []
      operands [codomain] = [prettyExpr 2 codomain]
      operands (domain : more) = prettyExpr 3 domain : operands more
      piType = group $ hang 4 $ prettyExpr 3 top <+> align (sep $ map ("->" <+>) (operands rest))
  Pi{}
    | k > 2 -> parens forallType
    | otherwise -> forallType
    where
      (params, body) = collectPis expr
      grouped = pretty <$> groupParams params
      forallType = group $ hang 4 $ "" <+> align (sep grouped) <> "," <> line <> align (prettyExpr 0 body)
  Abs{} ->
    if k >= 1
      then parens lambdaForm
      else lambdaForm
    where
      (params, body) = collectLambdas expr
      grouped = pretty <$> groupParams params
      lambdaForm = group $ hang 4 $ "λ" <+> align (sep grouped) <+> "=>" <> line <> align (prettyExpr 0 body)
  Let{} ->
    group $
      vsep
        [ "let" <+> align bindings
        , "in" <+> align (prettyExpr 0 body)
        , "end"
        ]
    where
      (binds, body) = collectLets expr
      bindings = sep $ map pretty binds
-- | Render a term to strict 'Text' with the smart layout algorithm and
-- default layout options.
prettyT :: Expr -> Text
prettyT expr = renderStrict (layoutSmart defaultLayoutOptions (pretty expr))
-- | Like 'prettyT', but returns a 'String'.
prettyS :: Expr -> String
prettyS expr = toString (prettyT expr)