{-# LANGUAGE TupleSections #-}

-- | Elaboration of section-structured IR programs.
--
-- Sections may declare 'Variable's that stay in scope until the section
-- closes. Elaboration removes sections by abstracting every definition and
-- axiom over the section variables it (transitively) uses, and by applying
-- every reference to such a definition made inside the section to those same
-- variables. 'elaborate' then converts a named 'IRExpr' into an 'Expr' whose
-- bound variables are de Bruijn indices (innermost binder = 0) and whose
-- unbound names become 'E.Free'.
module Elaborator where

import Data.List (elemIndex, lookup)
import qualified Data.Set as S
import Expr (Expr)
import qualified Expr as E
import IR (IRDef (..), IRExpr, IRProgram, IRSectionDef (..))
import qualified IR as I
import Relude.Extra.Lens

-- | Names of the binders currently in scope, innermost first.
type Binders = [Text]

-- | Elaboration state for the sections that are currently open.
data SectionContext = SectionContext
  { sectionDefs :: [(Text, [(Text, IRExpr)])]
    -- ^ definitions (and axioms) with the section variables, and their
    -- types, that each one depends on; most recently added first
  , sectionVars :: [(Text, IRExpr)]
    -- ^ section variables in scope with their types; most recent first
  }

type ElabMonad = State SectionContext

-- | Look up the section variables (with types) a definition depends on.
lookupDefInCtxt :: Text -> SectionContext -> Maybe [(Text, IRExpr)]
lookupDefInCtxt def (SectionContext defs _) = lookup def defs

-- | Look up a definition in the current context and return the section
-- variables (with their types) that it depends on.
lookupDef :: Text -> ElabMonad (Maybe [(Text, IRExpr)])
lookupDef def = lookupDefInCtxt def <$> get

-- | Look up the type of a section variable.
lookupVarInCtxt :: Text -> SectionContext -> Maybe IRExpr
lookupVarInCtxt var = lookup var . sectionVars

-- | Look up a section variable in the current context and return its type.
lookupVar :: Text -> ElabMonad (Maybe IRExpr)
lookupVar var = lookupVarInCtxt var <$> get

-- | Lens onto 'sectionDefs'.
sectionDefsL :: Lens' SectionContext [(Text, [(Text, IRExpr)])]
sectionDefsL = lens sectionDefs setter
  where
    setter ctxt newDefs = ctxt{sectionDefs = newDefs}

-- | Lens onto 'sectionVars'.
sectionVarsL :: Lens' SectionContext [(Text, IRExpr)]
sectionVarsL = lens sectionVars setter
  where
    setter ctxt newVars = ctxt{sectionVars = newVars}

-- | Run an action that elaborates the body of a section, then close it.
--
-- The section variable list is restored to what it was before the action,
-- so variables declared inside go out of scope. Definitions made inside are
-- kept, but their dependencies are narrowed to the variables that are still
-- in scope: outside the inner section, the inner variables must be supplied
-- explicitly.
--
-- Every entry is kept even when its dependency list becomes empty. Dropping
-- it would let a lookup fall through to an older definition with the same
-- name and apply that stale entry's dependencies.
saveStateSection :: ElabMonad a -> ElabMonad a
saveStateSection action = do
  SectionContext _ oldVars <- get
  res <- action
  SectionContext newDefs _ <- get
  put (SectionContext (map (usesFrom oldVars) newDefs) oldVars)
  pure res
  where
    usesFrom ::
      [(Text, IRExpr)] ->
      (Text, [(Text, IRExpr)]) ->
      (Text, [(Text, IRExpr)])
    usesFrom vars (name, uses) = (name, filter (`elem` vars) uses)

-- | Run an action and then restore the state it started from, keeping only
-- the action's result. Used to confine the effect of 'removeName' to a
-- binder's scope.
saveState :: ElabMonad a -> ElabMonad a
saveState action = do
  saved <- get
  res <- action
  put saved
  pure res

-- | Elaborate the contents of a section and close it afterwards (see
-- 'saveStateSection'). The section name is currently unused.
elabSection :: Text -> [IRSectionDef] -> ElabMonad [IRDef]
elabSection _name contents =
  saveStateSection $ concat <$> forM contents elabDef

-- | Elaborate a whole program as an anonymous top-level section, starting
-- from an empty context.
elabProgram :: IRProgram -> [IRDef]
elabProgram prog = evalState (elabSection "" prog) (SectionContext [] [])

-- | Declare a section variable. Its type is first rewritten with
-- 'traverseBody' so that references to earlier section definitions are
-- applied to their section arguments.
pushVariable :: Text -> IRExpr -> ElabMonad ()
pushVariable name ty = do
  newTy <- traverseBody ty
  modify $ over sectionVarsL ((name, newTy) :)

-- | Record a definition together with the section variables it depends on.
pushDefinition :: Text -> [(Text, IRExpr)] -> SectionContext -> SectionContext
pushDefinition name defVars (SectionContext defs vars) =
  SectionContext ((name, defVars) : defs) vars

-- | Forget every definition and section variable named @name@. This is
-- called when a binder of that name shadows them.
removeName :: Text -> ElabMonad ()
removeName name = do
  modify $ over sectionDefsL (filter ((/= name) . fst))
  modify $ over sectionVarsL (filter ((/= name) . fst))

-- | Close a set of section variables under "used by the type of".
--
-- Variables used in the types of the variables in the set are added
-- repeatedly until nothing new appears.
extendVars :: Set (Text, IRExpr) -> ElabMonad (Set (Text, IRExpr))
extendVars vars = do
  vars' <- S.unions <$> traverse (usedVars . snd) (S.toList vars)
  if vars' `S.isSubsetOf` vars
    then pure vars
    else extendVars (vars `S.union` vars')

-- | Keep the section variables that are in @found@ and list them in
-- declaration order (oldest first).
organize :: Set (Text, IRExpr) -> ElabMonad [(Text, IRExpr)]
organize found = do
  vars <- gets sectionVars
  -- sectionVars is newest-first, so reverse to get declaration order.
  pure $ reverse [var | var <- vars, var `S.member` found]

-- | Find the section variables (with their types) used in an expression.
--
-- The result includes the variables that any referenced section definition
-- depends on. Binders shadow section variables and definitions of the same
-- name within their scope; a binder's own type is outside that scope.
--
-- NOTE(review): 'I.PureVar' is treated like 'I.Var' here, but 'traverseBody'
-- does not apply section arguments to it. Confirm that this is intended.
usedVars :: IRExpr -> ElabMonad (Set (Text, IRExpr))
usedVars (I.Var name) = do
  varDeps <- maybe S.empty (S.singleton . (name,)) <$> lookupVar name
  defDeps <- maybe S.empty S.fromList <$> lookupDef name
  pure $ varDeps `S.union` defDeps
usedVars (I.PureVar pvname) = usedVars (I.Var pvname)
usedVars I.Star = pure S.empty
usedVars (I.Level _) = pure S.empty
usedVars (I.App m n) = S.union <$> usedVars m <*> usedVars n
usedVars (I.Abs name ty body) = saveState $ do
  ty' <- usedVars ty
  removeName name
  S.union ty' <$> usedVars body
usedVars (I.Pi name ty body) = saveState $ do
  ty' <- usedVars ty
  removeName name
  S.union ty' <$> usedVars body
usedVars (I.Let name ascr value body) = saveState $ do
  ty' <- usedVars value
  ascr' <- traverse usedVars ascr
  removeName name
  S.union (ty' `S.union` (ascr' ?: S.empty)) <$> usedVars body
usedVars (I.Prod m n) = S.union <$> usedVars m <*> usedVars n
usedVars (I.Pair m n) = S.union <$> usedVars m <*> usedVars n
usedVars (I.Pi1 x) = usedVars x
usedVars (I.Pi2 x) = usedVars x

-- | Traverse the body of a definition and add the missing section arguments.
--
-- Every reference to a definition made in an open section is applied to the
-- section variables that definition depends on, in the recorded order.
-- Binders shadow such definitions within their scope.
traverseBody :: IRExpr -> ElabMonad IRExpr
traverseBody (I.Var name) = do
  mdeps <- lookupDef name
  pure $ case mdeps of
    Nothing -> I.Var name
    Just deps -> foldl' (\acc (var, _) -> I.App acc (I.Var var)) (I.Var name) deps
traverseBody (I.PureVar pvname) = pure $ I.PureVar pvname
traverseBody I.Star = pure I.Star
traverseBody (I.Level k) = pure $ I.Level k
traverseBody (I.App m n) = I.App <$> traverseBody m <*> traverseBody n
traverseBody (I.Abs name ty body) = saveState $ do
  ty' <- traverseBody ty
  removeName name
  I.Abs name ty' <$> traverseBody body
traverseBody (I.Pi name ty body) = saveState $ do
  ty' <- traverseBody ty
  removeName name
  I.Pi name ty' <$> traverseBody body
traverseBody (I.Let name ascr value body) = saveState $ do
  ascr' <- traverse traverseBody ascr
  value' <- traverseBody value
  removeName name
  I.Let name ascr' value' <$> traverseBody body
traverseBody (I.Prod m n) = I.Prod <$> traverseBody m <*> traverseBody n
traverseBody (I.Pair m n) = I.Pair <$> traverseBody m <*> traverseBody n
traverseBody (I.Pi1 x) = I.Pi1 <$> traverseBody x
traverseBody (I.Pi2 x) = I.Pi2 <$> traverseBody x

-- | Wrap an expression in a 'I.Pi' binding the given parameter.
mkPi :: (Text, IRExpr) -> IRExpr -> IRExpr
mkPi (param, typ) = I.Pi param typ

-- | Wrap an expression in an 'I.Abs' binding the given parameter.
mkAbs :: (Text, IRExpr) -> IRExpr -> IRExpr
mkAbs (param, typ) = I.Abs param typ

-- | Abstract a type over the given variables with nested 'I.Pi's. The first
-- variable in the list becomes the outermost binder.
generalizeType :: IRExpr -> [(Text, IRExpr)] -> IRExpr
generalizeType = foldr mkPi

-- | Abstract a value over the given variables with nested 'I.Abs'. The first
-- variable in the list becomes the outermost binder.
generalizeVal :: IRExpr -> [(Text, IRExpr)] -> IRExpr
generalizeVal = foldr mkAbs

-- | Elaborate one section entry.
--
-- * Definitions and axioms are generalized over the section variables they
--   transitively use. Their bodies and types have section arguments applied
--   to references to earlier section definitions. They are then recorded,
--   so that later references get the same treatment.
-- * Sections are elaborated recursively and closed.
-- * Variables are recorded in the context and produce no output.
elabDef :: IRSectionDef -> ElabMonad [IRDef]
elabDef (IRDef (Def name ty body)) = do
  tyVars <- fromMaybe S.empty <$> traverse usedVars ty
  bodyVars <- usedVars body
  totalVars <- extendVars (tyVars `S.union` bodyVars) >>= organize
  newBody <- traverseBody body
  newType <- traverse traverseBody ty
  modify $ pushDefinition name totalVars
  pure
    [ Def
        name
        (flip generalizeType totalVars <$> newType)
        (generalizeVal newBody totalVars)
    ]
elabDef (IRDef (Axiom name ty)) = do
  vars <- usedVars ty >>= extendVars >>= organize
  -- Like a definition's type, the axiom's type must have section arguments
  -- applied to any section definitions it mentions before generalizing.
  newType <- traverseBody ty
  modify $ pushDefinition name vars
  pure [Axiom name (generalizeType newType vars)]
-- elabSection already closes the section via saveStateSection.
elabDef (Section name contents) = elabSection name contents
elabDef (Variable name ty) = [] <$ pushVariable name ty

-- | Run an action and then restore the binder list it started from.
saveBinders :: State Binders a -> State Binders a
saveBinders action = do
  binders <- get
  res <- action
  put binders
  pure res

-- | Convert a named 'IRExpr' into an 'Expr'.
--
-- A bound variable becomes 'E.Var' carrying its de Bruijn index, where the
-- innermost enclosing binder is 0. A name not bound inside the expression
-- becomes 'E.Free'.
elaborate :: IRExpr -> Expr
elaborate ir = evalState (elaborate' ir) []
  where
    elaborate' :: IRExpr -> State Binders Expr
    elaborate' (I.Var n) = do
      binders <- get
      pure $ E.Var n . fromIntegral <$> elemIndex n binders ?: E.Free n
    elaborate' (I.PureVar pvname) = elaborate' $ I.Var pvname
    elaborate' I.Star = pure E.Star
    elaborate' (I.Level level) = pure $ E.Level level
    elaborate' (I.App m n) = E.App <$> elaborate' m <*> elaborate' n
    elaborate' (I.Abs x t b) = saveBinders $ do
      t' <- elaborate' t
      modify (x :)
      E.Abs x t' <$> elaborate' b
    elaborate' (I.Pi x t b) = saveBinders $ do
      t' <- elaborate' t
      modify (x :)
      E.Pi x t' <$> elaborate' b
    elaborate' (I.Let name Nothing val body) = saveBinders $ do
      val' <- elaborate' val
      modify (name :)
      E.Let name Nothing val' <$> elaborate' body
    elaborate' (I.Let name (Just ty) val body) = saveBinders $ do
      val' <- elaborate' val
      ty' <- elaborate' ty
      modify (name :)
      E.Let name (Just ty') val' <$> elaborate' body
    elaborate' (I.Prod m n) = E.Prod <$> elaborate' m <*> elaborate' n
    elaborate' (I.Pair m n) = E.Pair <$> elaborate' m <*> elaborate' n
    elaborate' (I.Pi1 x) = E.Pi1 <$> elaborate' x
    elaborate' (I.Pi2 x) = E.Pi2 <$> elaborate' x