234 lines
8.1 KiB
Haskell
234 lines
8.1 KiB
Haskell
module Expr where
|
||
|
||
import qualified Data.List.NonEmpty as NE
|
||
import Prettyprinter
|
||
import Prettyprinter.Render.Text
|
||
import Prelude hiding (group)
|
||
|
||
-- | Core expression syntax. Variables are de Bruijn indices: the 'Integer'
-- in 'Var' is the index, and the 'Text' carried by 'Var' and by binders is a
-- display name only (both the 'Eq' and 'Ord' instances ignore it).
data Expr where
  -- | Bound variable: display name and de Bruijn index.
  Var :: Text -> Integer -> Expr
  -- | Named free variable, compared by name.
  Free :: Text -> Expr
  -- | Named axiom, compared by name.
  Axiom :: Text -> Expr
  -- | Printed as @★@.
  Star :: Expr
  -- | Printed as @□@ followed by the subscripted level (none for level 0).
  Level :: Integer -> Expr
  -- | Application: function, then argument.
  App :: Expr -> Expr -> Expr
  -- | Lambda: binder name, parameter type, body (under the new binder).
  Abs :: Text -> Expr -> Expr -> Expr
  -- | Pi type: binder name, domain, codomain (under the new binder). An
  -- empty name is printed as a plain arrow.
  Pi :: Text -> Expr -> Expr -> Expr
  -- | Let: name, optional ascription, bound value, body (under the new
  -- binder). The ascription and value are in the outer scope.
  Let :: Text -> Maybe Expr -> Expr -> Expr -> Expr
  -- | Product type, printed with @×@.
  Prod :: Expr -> Expr -> Expr
  -- | Pair constructor.
  Pair :: Expr -> Expr -> Expr
  -- | First projection, printed as @π₁@.
  Pi1 :: Expr -> Expr
  -- | Second projection, printed as @π₂@.
  Pi2 :: Expr -> Expr
  deriving (Show)

-- | Ordering that agrees with the hand-written 'Eq' instance: display names
-- and 'Let' ascriptions are ignored, so @compare a b == EQ@ exactly when
-- @a == b@. A derived instance would also compare the names, breaking the
-- Eq/Ord consistency that ordered containers rely on. Different
-- constructors are ordered by declaration position, as a derived instance
-- would order them.
instance Ord Expr where
  compare lhs rhs =
    case (lhs, rhs) of
      (Var _ n, Var _ m) -> compare n m
      (Free s, Free t) -> compare s t
      (Axiom s, Axiom t) -> compare s t
      (Star, Star) -> EQ
      (Level i, Level j) -> compare i j
      (App f1 a1, App f2 a2) -> compare f1 f2 <> compare a1 a2
      (Abs _ t1 b1, Abs _ t2 b2) -> compare t1 t2 <> compare b1 b2
      (Pi _ t1 b1, Pi _ t2 b2) -> compare t1 t2 <> compare b1 b2
      (Let _ _ v1 b1, Let _ _ v2 b2) -> compare v1 v2 <> compare b1 b2
      (Prod x1 y1, Prod x2 y2) -> compare x1 x2 <> compare y1 y2
      (Pair x1 y1, Pair x2 y2) -> compare x1 x2 <> compare y1 y2
      (Pi1 x, Pi1 y) -> compare x y
      (Pi2 x, Pi2 y) -> compare x y
      _ -> compare (constructorRank lhs) (constructorRank rhs)
    where
      -- Position of the constructor in the declaration above.
      constructorRank :: Expr -> Int
      constructorRank e = case e of
        Var{} -> 0
        Free{} -> 1
        Axiom{} -> 2
        Star -> 3
        Level{} -> 4
        App{} -> 5
        Abs{} -> 6
        Pi{} -> 7
        Let{} -> 8
        Prod{} -> 9
        Pair{} -> 10
        Pi1{} -> 11
        Pi2{} -> 12
|
||
|
||
-- | Render an expression. Steps, in order:
--
-- 1. 'dedupNames' disambiguates shadowed names with subscripts.
-- 2. 'cleanNames' blanks unused Pi binders so they print as arrows.
-- 3. The result is laid out at precedence 0.
instance Pretty Expr where
  pretty expr = prettyExpr 0 (cleanNames (dedupNames expr))
|
||
|
||
-- | Equality up to renaming of bound variables. Variables are compared by
-- de Bruijn index only, binder names are ignored, and a 'Let' ascription
-- does not take part in the comparison.
instance Eq Expr where
  lhs == rhs =
    case (lhs, rhs) of
      (Var _ i, Var _ j) -> i == j
      (Free a, Free b) -> a == b
      (Axiom a, Axiom b) -> a == b
      (Star, Star) -> True
      (Level a, Level b) -> a == b
      (App f1 x1, App f2 x2) -> f1 == f2 && x1 == x2
      (Abs _ ty1 body1, Abs _ ty2 body2) -> ty1 == ty2 && body1 == body2
      (Pi _ ty1 body1, Pi _ ty2 body2) -> ty1 == ty2 && body1 == body2
      (Let _ _ val1 body1, Let _ _ val2 body2) -> val1 == val2 && body1 == body2
      (Prod a1 b1, Prod a2 b2) -> a1 == a2 && b1 == b2
      (Pair a1 b1, Pair a2 b2) -> a1 == a2 && b1 == b2
      (Pi1 a, Pi1 b) -> a == b
      (Pi2 a, Pi2 b) -> a == b
      -- Different constructors are never equal.
      _ -> False
|
||
|
||
-- | @occursFree n e@ reports whether de Bruijn index @n@ (counted from the
-- root of @e@) is referenced anywhere in @e@. The index is incremented each
-- time the traversal enters the scope of a binder ('Abs', 'Pi', 'Let').
occursFree :: Integer -> Expr -> Bool
occursFree n (Var _ k) = n == k
occursFree _ (Free _) = False
occursFree _ (Axiom _) = False
occursFree _ Star = False
occursFree _ (Level _) = False
occursFree n (App a b) = occursFree n a || occursFree n b
occursFree n (Abs _ a b) = occursFree n a || occursFree (n + 1) b
occursFree n (Pi _ a b) = occursFree n a || occursFree (n + 1) b
-- The optional ascription sits outside the new binder, in the same scope as
-- the bound value, so it is checked at the unshifted index. Without this, a
-- variable used only in an ascription would be reported as unused.
occursFree n (Let _ t v b) =
  maybe False (occursFree n) t || occursFree n v || occursFree (n + 1) b
occursFree n (Prod x y) = occursFree n x || occursFree n y
occursFree n (Pair x y) = occursFree n x || occursFree n y
occursFree n (Pi1 x) = occursFree n x
occursFree n (Pi2 x) = occursFree n x
|
||
|
||
-- | @shiftIndices d c e@ adds @d@ to every de Bruijn index in @e@ that is at
-- least the cutoff @c@. Indices below the cutoff refer to binders inside
-- @e@ and are left alone. The cutoff grows by one under each binder.
shiftIndices :: Integer -> Integer -> Expr -> Expr
shiftIndices d c (Var x k)
  | k >= c = Var x (k + d)
  | otherwise = Var x k
shiftIndices _ _ (Free s) = Free s
shiftIndices _ _ (Axiom s) = Axiom s
shiftIndices _ _ Star = Star
shiftIndices _ _ (Level i) = Level i
shiftIndices d c (App m n) = App (shiftIndices d c m) (shiftIndices d c n)
shiftIndices d c (Abs x m n) = Abs x (shiftIndices d c m) (shiftIndices d (c + 1) n)
shiftIndices d c (Pi x m n) = Pi x (shiftIndices d c m) (shiftIndices d (c + 1) n)
-- The ascription lives outside the new binder, like the bound value, so it
-- must be shifted with the same cutoff. Copying it unchanged would leave its
-- indices pointing at the wrong binders.
shiftIndices d c (Let x t v b) =
  Let x (shiftIndices d c <$> t) (shiftIndices d c v) (shiftIndices d (c + 1) b)
shiftIndices d c (Prod m n) = Prod (shiftIndices d c m) (shiftIndices d c n)
shiftIndices d c (Pair m n) = Pair (shiftIndices d c m) (shiftIndices d c n)
shiftIndices d c (Pi1 x) = Pi1 (shiftIndices d c x)
shiftIndices d c (Pi2 x) = Pi2 (shiftIndices d c x)
|
||
|
||
-- | Raise every free de Bruijn index (cutoff 0) by one, e.g. to compare a
-- type against one that lives under one more binder.
incIndices :: Expr -> Expr
incIndices expr = shiftIndices 1 0 expr
|
||
|
||
{- --------------------- PRETTY PRINTING ----------------------------- -}
|
||
|
||
-- | Disambiguate shadowed names for display.
--
-- A binder, or a reference to one, gets a subscript equal to the number of
-- enclosing binders further out that carry the same name. For example, the
-- inner of two nested @x@ binders is shown as @x₁@. Empty names are left
-- untouched.
dedupNames :: Expr -> Expr
dedupNames = go []
  where
    -- Display name for the variable @x@ that refers to position @k@ of the
    -- binder stack @bs@ (innermost binder first).
    varName :: [Text] -> Text -> Int -> Text
    varName bs x k =
      if x == ""
        then x
        else case bs !!? k of
          Nothing -> x
          Just x' ->
            -- Same-named binders strictly outside the referenced one.
            let count = length $ filter (== x') $ drop (k + 1) bs
             in if count > 0
                  then x <> printLevel count
                  else x

    -- Walk the term, pushing each binder name onto the stack before
    -- entering its scope. Every constructor is listed so that terms nested
    -- inside products, pairs and projections are renamed too.
    go :: [Text] -> Expr -> Expr
    go bs (Var x k) = Var (varName bs x (fromIntegral k)) k
    go _ e@(Free _) = e
    go _ e@(Axiom _) = e
    go _ Star = Star
    go _ e@(Level _) = e
    go bs (App m n) = App (go bs m) (go bs n)
    go bs (Abs x ty b) = Abs (varName (x : bs) x 0) (go bs ty) (go (x : bs) b)
    go bs (Pi x ty b) = Pi (varName (x : bs) x 0) (go bs ty) (go (x : bs) b)
    go bs (Let x ascr val b) = Let (varName (x : bs) x 0) (go bs <$> ascr) (go bs val) (go (x : bs) b)
    go bs (Prod m n) = Prod (go bs m) (go bs n)
    go bs (Pair m n) = Pair (go bs m) (go bs n)
    go bs (Pi1 m) = Pi1 (go bs m)
    go bs (Pi2 m) = Pi2 (go bs m)
|
||
|
||
-- | A single binder: its name and its type.
data Param where
  Param :: Text -> Expr -> Param

-- | Several adjacent binders sharing one type. The type is that of the
-- first binder, i.e. in the scope outside the whole group.
data ParamGroup where
  ParamGroup :: [Text] -> Expr -> ParamGroup

-- | A @let@ binding: name, grouped lambda parameters, and right-hand side.
data Binding where
  Binding :: Text -> [ParamGroup] -> Expr -> Binding
|
||
|
||
-- | Rendered as @(name : type)@.
instance Pretty Param where
  pretty (Param name ty) =
    group . parens $ pretty name <+> ":" <+> pretty ty
|
||
|
||
-- | Rendered as @(x y z : type)@, with the names aligned when the group
-- does not fit on one line.
instance Pretty ParamGroup where
  pretty (ParamGroup names ty) =
    group . parens $ align (sep (pretty <$> names)) <+> ":" <+> pretty ty
|
||
|
||
-- | Rendered as @(name params := body)@. The parameter list is omitted
-- entirely when there are no parameters.
instance Pretty Binding where
  pretty (Binding var params body) =
    group . parens $ pretty var <+> withParams (":=" <+> pretty body)
    where
      withParams rhs
        | null params = rhs
        | otherwise = align (sep (map pretty params)) <+> rhs
|
||
|
||
-- | Peel off a run of nested lambdas, returning their binders
-- (outermost first) and the first body that is not itself a lambda.
collectLambdas :: Expr -> ([Param], Expr)
collectLambdas expr = case expr of
  Abs x ty body ->
    let (inner, rest) = collectLambdas body
     in (Param x ty : inner, rest)
  _ -> ([], expr)
|
||
|
||
-- | Flatten a chain of nested @let@s into bindings (outermost first) plus
-- the final body. Leading lambdas of each bound value become that binding's
-- grouped parameters.
--
-- NOTE(review): the 'Let' ascription is discarded here, so it is never
-- printed.
collectLets :: Expr -> ([Binding], Expr)
collectLets expr = case expr of
  Let name _ val body ->
    let (laterBindings, finalBody) = collectLets body
        (lambdaParams, rhs) = collectLambdas val
     in (Binding name (groupParams lambdaParams) rhs : laterBindings, finalBody)
  _ -> ([], expr)
|
||
|
||
-- | Peel off consecutive named Pi binders (outermost first). Stops at the
-- first non-Pi term or at an anonymous Pi, which is left whole so it can be
-- printed as an arrow.
collectPis :: Expr -> ([Param], Expr)
collectPis expr = case expr of
  Pi x ty body
    | x /= "" ->
        let (inner, rest) = collectPis body
         in (Param x ty : inner, rest)
  _ -> ([], expr)
|
||
|
||
-- | Flatten a right-nested chain of anonymous Pis (arrows) into its
-- components: @A -> B -> C@ becomes @A :| [B, C]@. A non-arrow term yields
-- a singleton.
collectArrows :: Expr -> NonEmpty Expr
collectArrows expr = case expr of
  Pi "" dom cod -> dom NE.<| collectArrows cod
  _ -> expr :| []
|
||
|
||
-- | Unwind an application spine, last argument first: @f a b@ becomes
-- @b :| [a, f]@. A non-application yields a singleton.
collectApps :: Expr -> NonEmpty Expr
collectApps expr = case expr of
  App fun arg -> arg NE.<| collectApps fun
  _ -> expr :| []
|
||
|
||
-- | Blank out the name of every Pi binder whose variable is not used in its
-- codomain, so the printer renders it as a plain arrow.
--
-- The traversal covers every subterm, including the domain of a Pi that is
-- being anonymised and the parts of @let@s, products, pairs and
-- projections. Those are printed directly by 'prettyExpr' and would
-- otherwise keep unused binder names.
cleanNames :: Expr -> Expr
cleanNames e@(Var _ _) = e
cleanNames e@(Free _) = e
cleanNames e@(Axiom _) = e
cleanNames Star = Star
cleanNames e@(Level _) = e
cleanNames (App m n) = App (cleanNames m) (cleanNames n)
cleanNames (Abs x ty body) = Abs x (cleanNames ty) (cleanNames body)
cleanNames (Pi x ty body)
  | occursFree 0 body = Pi x ty' body'
  | otherwise = Pi "" ty' body'
  where
    ty' = cleanNames ty
    body' = cleanNames body
cleanNames (Let x ascr val body) =
  Let x (cleanNames <$> ascr) (cleanNames val) (cleanNames body)
cleanNames (Prod a b) = Prod (cleanNames a) (cleanNames b)
cleanNames (Pair a b) = Pair (cleanNames a) (cleanNames b)
cleanNames (Pi1 a) = Pi1 (cleanNames a)
cleanNames (Pi2 a) = Pi2 (cleanNames a)
|
||
|
||
-- | Merge runs of adjacent parameters that share a type, so that
-- @(x : A) (y : A)@ can be printed as @(x y : A)@.
--
-- Each later parameter's type sits under one more binder. It is therefore
-- compared against the earlier type shifted by 'incIndices', and the
-- merged group keeps the earlier (outermost) type.
groupParams :: [Param] -> [ParamGroup]
groupParams = foldr merge []
  where
    merge :: Param -> [ParamGroup] -> [ParamGroup]
    merge (Param name ty) groups = case groups of
      ParamGroup names nextTy : others
        | incIndices ty == nextTy -> ParamGroup (name : names) ty : others
      _ -> ParamGroup [name] ty : groups
|
||
|
||
-- | Render an integer with Unicode subscript digits, e.g. @12@ becomes
-- @"₁₂"@. A negative number gets a leading subscript minus (@"₋"@).
-- Previously a negative number recursed forever.
printLevel :: (IsString s, Semigroup s, Integral i) => i -> s
printLevel k
  -- Widen to 'Integer' before negating so 'minBound' of a bounded type
  -- cannot overflow back to a negative value.
  | k < 0 = "₋" <> printLevel (negate (toInteger k))
  | k >= 10 = printLevel (k `div` 10) <> printLevel (k `mod` 10)
  | otherwise = case k of
      0 -> "₀"
      1 -> "₁"
      2 -> "₂"
      3 -> "₃"
      4 -> "₄"
      5 -> "₅"
      6 -> "₆"
      7 -> "₇"
      8 -> "₈"
      -- Only 9 remains here, since 0 <= k < 10.
      _ -> "₉"
|
||
|
||
-- | Lay out an expression at precedence @k@.
--
-- The levels used internally are:
--
-- * 0: top level, and the body after a binder
-- * 2: the final codomain of an arrow
-- * 3: an arrow domain, or the head of an application
-- * 4: an application argument
--
-- A construct is parenthesised when @k@ exceeds the level at which it may
-- appear unbracketed.
prettyExpr :: Integer -> Expr -> Doc ann
prettyExpr k expr = case expr of
  Var s _ -> pretty s
  Free s -> pretty s
  Axiom s -> pretty s
  Star -> "★"
  Level i
    | i == 0 -> "□"
    | otherwise -> "□" <> printLevel i
  App{}
    | k > 3 -> parens application
    | otherwise -> application
    where
      (top :| rest) = NE.reverse $ collectApps expr
      application = group $ hang 4 $ prettyExpr 3 top <> line <> align (sep $ map (prettyExpr 4) rest)
  Pi "" _ _
    | k > 2 -> parens piType
    | otherwise -> piType
    where
      (top :| rest) = collectArrows expr
      -- Every component except the last is itself a domain. It must be
      -- printed at the domain level (3) so that a nested arrow keeps its
      -- brackets: @A -> (B -> C) -> D@. Only the final codomain is printed
      -- at level 2.
      levels = replicate (length rest - 1) 3 ++ [2]
      piType = group $ hang 4 $ prettyExpr 3 top <+> align (sep $ zipWith (\lvl e -> "->" <+> prettyExpr lvl e) levels rest)
  Pi{}
    -- A named Pi extends as far right as possible, so it needs brackets in
    -- an arrow domain, an application head or an application argument.
    | k > 2 -> parens quantified
    | otherwise -> quantified
    where
      (params, body) = collectPis expr
      grouped = pretty <$> groupParams params
      quantified = group $ hang 4 $ "∏" <+> align (sep grouped) <> "," <> line <> align (prettyExpr 0 body)
  Abs{} ->
    if k >= 1
      then parens lambdaForm
      else lambdaForm
    where
      (params, body) = collectLambdas expr
      grouped = pretty <$> groupParams params
      lambdaForm = group $ hang 4 $ "λ" <+> align (sep grouped) <+> "=>" <> line <> align (prettyExpr 0 body)
  Let{} ->
    group $
      vsep
        [ "let" <+> align bindings
        , "in" <+> align (prettyExpr 0 body)
        , "end"
        ]
    where
      (binds, body) = collectLets expr
      bindings = sep $ map pretty binds
  -- NOTE(review): the cases below use 'pretty' on subterms, which re-runs
  -- 'dedupNames' / 'cleanNames' on each subterm without its enclosing binders.
  Prod x y -> parens $ parens (pretty x) <+> "×" <+> parens (pretty y)
  Pair x y -> parens $ pretty x <> "," <+> pretty y
  Pi1 x -> parens $ "π₁" <+> parens (pretty x)
  Pi2 x -> parens $ "π₂" <+> parens (pretty x)
|
||
|
||
-- | Render an expression to strict 'Text' using the smart layout algorithm
-- with default layout options.
prettyT :: Expr -> Text
prettyT expr = renderStrict (layoutSmart defaultLayoutOptions (pretty expr))
|
||
|
||
-- | Same rendering as 'prettyT', converted to a 'String'.
prettyS :: Expr -> String
prettyS expr = toString (prettyT expr)
|