perga/lib/Elaborator.hs

208 lines
7.8 KiB
Haskell
Raw Normal View History

2024-12-06 00:40:24 -08:00
{-# LANGUAGE TupleSections #-}
2024-11-30 22:36:27 -08:00
module Elaborator where
2024-12-06 00:40:24 -08:00
import Data.List (elemIndex, lookup)
2024-12-06 13:36:14 -08:00
import qualified Data.Set as S
2024-11-30 22:36:27 -08:00
import Expr (Expr)
import qualified Expr as E
2024-12-06 13:36:14 -08:00
import IR (IRDef (..), IRExpr, IRProgram, IRSectionDef (..))
2024-11-30 22:36:27 -08:00
import qualified IR as I
2024-12-05 11:19:23 -08:00
import Relude.Extra.Lens
2024-11-30 22:36:27 -08:00
-- | Names of the binders in scope while converting an 'IRExpr' with
-- 'elaborate'. New binders are consed onto the front, so the innermost
-- binder is at index 0.
type Binders = [Text]
2024-12-05 11:19:23 -08:00
-- | Bookkeeping for the section that is currently being elaborated.
data SectionContext = SectionContext
    { sectionDefs :: [(Text, [(Text, IRExpr)])]
    -- ^ each definition's name, paired with the list of variables (and their
    -- types) that the definition depends on
    , sectionVars :: [(Text, IRExpr)]
    -- ^ variables and their types
    }
2024-12-01 15:28:57 -08:00
2024-12-06 00:40:24 -08:00
-- | State monad that threads the 'SectionContext' through section elaboration.
type ElabMonad = State SectionContext
-- | Find a definition by name in the given context, returning the variables
-- (with their types) recorded for it, or 'Nothing' if the name is not a
-- recorded definition.
lookupDefInCtxt :: Text -> SectionContext -> Maybe [(Text, IRExpr)]
lookupDefInCtxt def = lookup def . sectionDefs
2024-12-06 00:40:24 -08:00
-- | Stateful version of 'lookupDefInCtxt': consult the current context for a
-- definition and return the variables (and their types) it depends on.
lookupDef :: Text -> ElabMonad (Maybe [(Text, IRExpr)])
lookupDef def = gets (lookupDefInCtxt def)
-- | Find a variable by name in the given context and return its type, or
-- 'Nothing' if no such variable is recorded.
lookupVarInCtxt :: Text -> SectionContext -> Maybe IRExpr
lookupVarInCtxt var ctxt = lookup var (sectionVars ctxt)
-- | Stateful version of 'lookupVarInCtxt': consult the current context for a
-- variable and return its type.
lookupVar :: Text -> ElabMonad (Maybe IRExpr)
lookupVar var = gets (lookupVarInCtxt var)
2024-12-05 11:19:23 -08:00
2024-12-06 13:36:14 -08:00
-- | Lens onto the 'sectionDefs' field of a 'SectionContext'.
sectionDefsL :: Lens' SectionContext [(Text, [(Text, IRExpr)])]
sectionDefsL = lens sectionDefs (\ctxt newDefs -> ctxt{sectionDefs = newDefs})
-- | Lens onto the 'sectionVars' field of a 'SectionContext'.
sectionVarsL :: Lens' SectionContext [(Text, IRExpr)]
sectionVarsL = lens sectionVars (\ctxt newVars -> ctxt{sectionVars = newVars})
-- | Run an action and then restore the context to what it was beforehand,
-- so any changes the action made to the 'SectionContext' are discarded.
-- The action's result is returned unchanged.
saveState :: ElabMonad a -> ElabMonad a
saveState action = do
    saved <- get
    result <- action
    put saved
    pure result
2024-12-06 13:36:14 -08:00
-- | Render an 'IRExpr' as a string by elaborating it and pretty-printing
-- the result with 'E.prettyS'.
debugIRExpr :: IRExpr -> String
debugIRExpr expr = E.prettyS (elaborate expr)
-- | Render a definition or axiom as a string, in the form
-- @def name : type := body;@ (the type part is omitted when absent) or
-- @axiom name : type;@.
debugIRDef :: IRDef -> String
debugIRDef (Def name mty body) =
    concat
        [ "def "
        , toString name
        , maybe "" (\ty -> " : " ++ debugIRExpr ty) mty
        , " := "
        , debugIRExpr body
        , ";"
        ]
debugIRDef (Axiom name typ) =
    concat ["axiom ", toString name, " : ", debugIRExpr typ, ";"]
-- | Render a section-level item as a string. Variables and section headers
-- get their own one-line form; wrapped definitions defer to 'debugIRDef'.
-- The contents of a nested section are not printed.
debugIRSectionDef :: IRSectionDef -> String
debugIRSectionDef sectionDef = case sectionDef of
    Variable name typ -> concat ["variable ", toString name, " : ", debugIRExpr typ, ";"]
    Section name _ -> concat ["section ", toString name, ";"]
    IRDef def -> debugIRDef def
-- | Elaborate every item of a section in order and collect the resulting
-- definitions. The context is restored afterwards, so variables and
-- definitions introduced inside the section do not escape it. The section
-- name is currently unused.
elabSection :: Text -> [IRSectionDef] -> ElabMonad [IRDef]
elabSection _ contents = saveState (concat <$> traverse elabDef contents)
2024-12-05 11:19:23 -08:00
2024-12-06 13:36:14 -08:00
-- | Elaborate a whole program by treating it as an unnamed top-level section,
-- starting from a context with no definitions and no variables.
elabProgram :: IRProgram -> [IRDef]
elabProgram prog = evalState (elabSection "" prog) emptyContext
  where
    emptyContext = SectionContext{sectionDefs = [], sectionVars = []}
2024-12-05 11:19:23 -08:00
2024-12-06 00:40:24 -08:00
-- | Record a new variable and its type at the front of the context's
-- variable list.
pushVariable :: Text -> IRExpr -> SectionContext -> SectionContext
pushVariable name ty ctxt = ctxt{sectionVars = (name, ty) : sectionVars ctxt}
2024-12-06 13:36:14 -08:00
-- | Record a new definition, together with the variables it depends on, at
-- the front of the context's definition list.
pushDefinition :: Text -> [(Text, IRExpr)] -> SectionContext -> SectionContext
pushDefinition name defVars ctxt = ctxt{sectionDefs = (name, defVars) : sectionDefs ctxt}
2024-12-05 11:19:23 -08:00
2024-12-06 00:40:24 -08:00
-- | Forget every definition and every variable with the given name in the
-- current context. The traversals call this when they step under a binder
-- of that name.
removeName :: Text -> ElabMonad ()
removeName name = modify (over sectionVarsL dropName . over sectionDefsL dropName)
  where
    -- keep only the entries whose key differs from the name being removed
    dropName :: [(Text, a)] -> [(Text, a)]
    dropName = filter (\(key, _) -> key /= name)
2024-12-05 11:19:23 -08:00
2024-12-06 13:36:14 -08:00
-- | Close a set of variables under "appears in the type of": repeatedly add
-- the variables used by the types of the variables already in the set until
-- nothing new turns up.
extendVars :: Set (Text, IRExpr) -> ElabMonad (Set (Text, IRExpr))
extendVars vars = do
    typeDeps <- S.unions <$> mapM (usedVars . snd) (S.toList vars)
    let grown = vars `S.union` typeDeps
    -- the union equals the input exactly when the types contributed nothing new
    if grown == vars
        then pure vars
        else extendVars grown
-- | Turn a set of variables into a list: keep the context's variables that
-- are members of the set, in the reverse of the order in which the context
-- stores them.
organize :: Set (Text, IRExpr) -> ElabMonad [(Text, IRExpr)]
organize found = gets (reverse . filter (`S.member` found) . sectionVars)
2024-12-06 00:40:24 -08:00
-- | Find all the section variables used in an expression.
--
-- A name contributes itself (with its type) when it is a variable in the
-- context, plus the recorded dependencies when it is a definition in the
-- context. Names bound inside the expression hide context entries of the
-- same name for the extent of their body.
usedVars :: IRExpr -> ElabMonad (Set (Text, IRExpr))
usedVars (I.Var name) = do
    asVar <- lookupVar name
    asDef <- lookupDef name
    pure $ maybe S.empty (\ty -> S.singleton (name, ty)) asVar <> maybe S.empty S.fromList asDef
usedVars I.Star = pure S.empty
usedVars (I.Level _) = pure S.empty
usedVars (I.App m n) = do
    fromFn <- usedVars m
    fromArg <- usedVars n
    pure (fromFn `S.union` fromArg)
usedVars (I.Abs name ty ascr body) = usedUnderBinder name ty ascr body
usedVars (I.Pi name ty ascr body) = usedUnderBinder name ty ascr body
usedVars (I.Let name ascr value body) = usedUnderBinder name value ascr body

-- | Shared worker for the binding forms of 'usedVars': collect from the
-- expression attached to the binder and from the optional ascription, then
-- hide the bound name and collect from the body. The context is restored
-- afterwards.
usedUnderBinder :: Text -> IRExpr -> Maybe IRExpr -> IRExpr -> ElabMonad (Set (Text, IRExpr))
usedUnderBinder name bound ascr body = saveState $ do
    fromBound <- usedVars bound
    fromAscr <- maybe (pure S.empty) usedVars ascr
    removeName name
    fromBody <- usedVars body
    pure $ S.unions [fromBound, fromAscr, fromBody]
2024-12-06 00:40:24 -08:00
-- | Traverse the body of a definition, adding the necessary section arguments
-- to any definitions made within the section: every reference to a
-- definition recorded in the context is applied, in order, to the variables
-- that definition depends on. Names bound inside the expression hide context
-- entries of the same name for the extent of their body.
traverseBody :: IRExpr -> ElabMonad IRExpr
traverseBody (I.Var name) = maybe bare applyToDeps <$> lookupDef name
  where
    bare = I.Var name
    applyToDeps = foldl' (\fn (dep, _) -> I.App fn (I.Var dep)) bare
traverseBody I.Star = pure I.Star
traverseBody (I.Level k) = pure (I.Level k)
traverseBody (I.App m n) = do
    m' <- traverseBody m
    n' <- traverseBody n
    pure (I.App m' n')
traverseBody (I.Abs name ty ascr body) = saveState $ do
    withBody <- I.Abs name <$> traverseBody ty <*> traverse traverseBody ascr
    removeName name
    withBody <$> traverseBody body
traverseBody (I.Pi name ty ascr body) = saveState $ do
    withBody <- I.Pi name <$> traverseBody ty <*> traverse traverseBody ascr
    removeName name
    withBody <$> traverseBody body
traverseBody (I.Let name ascr value body) = saveState $ do
    withBody <- I.Let name <$> traverse traverseBody ascr <*> traverseBody value
    removeName name
    withBody <$> traverseBody body
-- | Wrap an expression in a Pi binder for the given parameter and type,
-- with no ascription.
mkPi :: (Text, IRExpr) -> IRExpr -> IRExpr
mkPi (param, typ) body = I.Pi param typ Nothing body
-- | Wrap an expression in an abstraction over the given parameter and type,
-- with no ascription.
mkAbs :: (Text, IRExpr) -> IRExpr -> IRExpr
mkAbs (param, typ) body = I.Abs param typ Nothing body
-- | Quantify a type over a list of variables with nested Pi binders; the
-- first variable in the list becomes the outermost binder.
generalizeType :: IRExpr -> [(Text, IRExpr)] -> IRExpr
generalizeType ty vars = foldr mkPi ty vars
-- | Abstract a value over a list of variables with nested abstractions; the
-- first variable in the list becomes the outermost binder.
generalizeVal :: IRExpr -> [(Text, IRExpr)] -> IRExpr
generalizeVal val vars = foldr mkAbs val vars
2024-12-06 13:36:14 -08:00
-- | Elaborate a single section-level item, returning the definitions it
-- produces.
--
-- * A definition or axiom is generalized over the section variables it uses
--   (directly, through other section definitions, or through the types of
--   those variables). References to section definitions inside it are
--   rewritten by 'traverseBody' to pass the section variables explicitly.
--   The item is then recorded in the context with its dependencies.
-- * A nested section is elaborated with 'elabSection'; the context is
--   restored afterwards.
-- * A variable declaration produces no definitions; it is only recorded in
--   the context.
elabDef :: IRSectionDef -> ElabMonad [IRDef]
elabDef (IRDef (Def name ty body)) = do
    tyVars <- fromMaybe S.empty <$> traverse usedVars ty
    bodyVars <- usedVars body
    totalVars <- extendVars (tyVars `S.union` bodyVars) >>= organize
    -- rewrite before recording the definition, so the rewrite only sees
    -- definitions recorded earlier
    newBody <- traverseBody body
    newType <- traverse traverseBody ty
    modify $ pushDefinition name totalVars
    pure [Def name (flip generalizeType totalVars <$> newType) (generalizeVal newBody totalVars)]
elabDef (IRDef (Axiom name ty)) = do
    vars <- usedVars ty >>= extendVars >>= organize
    -- The axiom's type may mention section definitions. Like a definition's
    -- type, it must be rewritten so those references receive their section
    -- arguments; otherwise the generalized type refers to them unapplied.
    newType <- traverseBody ty
    modify $ pushDefinition name vars
    pure [Axiom name (generalizeType newType vars)]
elabDef (Section name contents) = saveState $ elabSection name contents
elabDef (Variable name ty) = [] <$ modify' (pushVariable name ty)
2024-12-05 11:19:23 -08:00
2024-12-06 13:36:14 -08:00
-- | Run an action and then restore the binder list to what it was
-- beforehand, so binders pushed by the action do not leak out. The action's
-- result is returned unchanged.
saveBinders :: State Binders a -> State Binders a
saveBinders action = get >>= \saved -> action <* put saved
-- | Convert an 'IRExpr' into an 'Expr'. A variable that names an enclosing
-- binder becomes an 'E.Var' carrying the position of the innermost binder
-- with that name (0 for the nearest); any other variable becomes 'E.Free'.
elaborate :: IRExpr -> Expr
elaborate ir = evalState (go ir) []
  where
    -- the state is the list of enclosing binder names, innermost first
    go :: IRExpr -> State Binders Expr
    go (I.Var n) = gets (maybe (E.Free n) (E.Var n . fromIntegral) . elemIndex n)
    go I.Star = pure E.Star
    go (I.Level level) = pure (E.Level level)
    go (I.App m n) = do
        m' <- go m
        n' <- go n
        pure (E.App m' n')
    go (I.Abs x t a b) = saveBinders $ do
        withBody <- E.Abs x <$> go t <*> traverse go a
        modify (x :)
        withBody <$> go b
    go (I.Pi x t a b) = saveBinders $ do
        withBody <- E.Pi x <$> go t <*> traverse go a
        modify (x :)
        withBody <$> go b
    go (I.Let name ascr val body) = saveBinders $ do
        -- the bound value and ascription are converted before the name is
        -- pushed, so the name is only in scope for the body
        val' <- go val
        ascr' <- traverse go ascr
        modify (name :)
        E.Let name ascr' val' <$> go body