findType passing every test I've thrown at it!

This commit is contained in:
William Ball 2024-11-11 23:38:10 -08:00
parent 39cab7fd3d
commit 80fb0e8760
4 changed files with 41 additions and 44 deletions

View file

@ -1,4 +1,3 @@
{-# LANGUAGE BangPatterns #-}
module Check where module Check where
import Control.Monad.Except import Control.Monad.Except
@ -6,7 +5,6 @@ import Data.List (intercalate, (!?))
import Control.Monad (unless) import Control.Monad (unless)
import Expr import Expr
import Debug.Trace
type Context = [Expr] type Context = [Expr]
@ -22,34 +20,27 @@ matchPi e = Left $ ExpectedFunctionType e
showContext :: Context -> String showContext :: Context -> String
showContext g = "[" ++ intercalate ", " (map show g) ++ "]" showContext g = "[" ++ intercalate ", " (map show g) ++ "]"
-- TODO: Debug these problem cases
-- λ (S : *) (P : S -> *) (H : forall (x : S), P x) (y : S) => P y
findType :: Context -> Expr -> CheckResult Expr findType :: Context -> Expr -> CheckResult Expr
findType _ Star = trace "star" $ Right Square findType _ Star = Right Square
findType _ Square = trace "square" $ Left SquareUntyped findType _ Square = Left SquareUntyped
findType g (Var n _) = do findType g (Var n _) = do
!_ <- trace ("var:\t" ++ showContext g ++ "\n n:\t" ++ show n) (Right Star)
t <- maybe (Left UnboundVariable) Right $ g !? fromInteger n t <- maybe (Left UnboundVariable) Right $ g !? fromInteger n
s <- findType g t s <- findType g t
unless (isSort s) $ throwError $ NotASort s 32 unless (isSort s) $ throwError $ NotASort s 32
pure t pure t
findType g (App m n) = do findType g (App m n) = do
!_ <- trace ("app:\t" ++ showContext g ++ "\n m:\t" ++ show m ++ "\n n: \t" ++ show n) (Right Star)
(a, b) <- findType g m >>= matchPi (a, b) <- findType g m >>= matchPi
a' <- findType g n a' <- findType g n
unless (betaEquiv a a') $ throwError $ NotEquivalent a a' unless (betaEquiv a a') $ throwError $ NotEquivalent a a'
pure $ subst n b pure $ subst 0 n b
findType g (Abs x a m) = do findType g (Abs x a m) = do
!_ <- trace ("abs:\t" ++ showContext g ++ "\n a:\t" ++ show a ++ "\n m:\t" ++ show m) (Right Star)
s1 <- findType g a s1 <- findType g a
!_ <- trace ("back in abs:\t" ++ showContext g ++ "\n\t\t" ++ show a ++ " : " ++ show s1) (Right Star)
unless (isSort s1) $ throwError $ NotASort s1 43 unless (isSort s1) $ throwError $ NotASort s1 43
b <- findType (incIndices a : map incIndices g) m b <- findType (incIndices a : map incIndices g) m
s2 <- findType g (Pi x a b) s2 <- findType g (Pi x a b)
unless (isSort s2) $ throwError $ NotASort s2 46 unless (isSort s2) $ throwError $ NotASort s2 46
pure $ if occursFree 0 b then Pi x a b else Pi "" a b pure $ if occursFree 0 b then Pi x a b else Pi "" a b
findType g (Pi _ a b) = do findType g (Pi _ a b) = do
!_ <- trace ("pi:\t" ++ showContext g ++ "\n a:\t" ++ show a ++ "\n b:\t" ++ show b) (Right Star)
s1 <- findType g a s1 <- findType g a
unless (isSort s1) $ throwError $ NotASort s1 51 unless (isSort s1) $ throwError $ NotASort s1 51
s2 <- findType (incIndices a : map incIndices g) b s2 <- findType (incIndices a : map incIndices g) b

View file

@ -11,7 +11,16 @@ data Expr where
App :: Expr -> Expr -> Expr App :: Expr -> Expr -> Expr
Abs :: String -> Expr -> Expr -> Expr Abs :: String -> Expr -> Expr -> Expr
Pi :: String -> Expr -> Expr -> Expr Pi :: String -> Expr -> Expr -> Expr
deriving (Show, Eq) deriving (Show)
instance Eq Expr where
(Var n _) == (Var m _) = n == m
Star == Star = True
Square == Square = True
(App e1 e2) == (App f1 f2) = e1 == f1 && e2 == f2
(Abs _ t1 b1) == (Abs _ t2 b2) = t1 == t2 && b1 == b2
(Pi _ t1 b1) == (Pi _ t2 b2) = t1 == t2 && b1 == b2
_ == _ = False
infixl 4 <.> infixl 4 <.>
@ -55,7 +64,7 @@ groupParams = foldr addParam []
addParam :: (String, Expr) -> [([String], Expr)] -> [([String], Expr)] addParam :: (String, Expr) -> [([String], Expr)] -> [([String], Expr)]
addParam (x, t) [] = [([x], t)] addParam (x, t) [] = [([x], t)]
addParam (x, t) l@((xs, s) : rest) addParam (x, t) l@((xs, s) : rest)
| t == s = (x : xs, t) : rest | incIndices t == s = (x : xs, t) : rest
| otherwise = ([x], t) : l | otherwise = ([x], t) : l
showParamGroup :: ([String], Expr) -> String showParamGroup :: ([String], Expr) -> String
@ -92,38 +101,35 @@ isSort Star = True
isSort Square = True isSort Square = True
isSort _ = False isSort _ = False
shiftIndices :: Integer -> Integer -> Expr -> Expr
shiftIndices d c (Var k x)
| k >= c = Var (k + d) x
| otherwise = Var k x
shiftIndices _ _ Star = Star
shiftIndices _ _ Square = Square
shiftIndices d c (App m n) = App (shiftIndices d c m) (shiftIndices d c n)
shiftIndices d c (Abs x m n) = Abs x (shiftIndices d c m) (shiftIndices d (c + 1) n)
shiftIndices d c (Pi x m n) = Pi x (shiftIndices d c m) (shiftIndices d (c + 1) n)
incIndices :: Expr -> Expr incIndices :: Expr -> Expr
incIndices (Var n x) = Var (n + 1) x incIndices = shiftIndices 1 0
incIndices Star = Star
incIndices Square = Square
incIndices (App m n) = App (incIndices m) (incIndices n)
incIndices (Abs x m n) = Abs x (incIndices m) (incIndices n)
incIndices (Pi x m n) = Pi x (incIndices m) (incIndices n)
-- substitute s for 0 *AND* decrement indices; only use after reducing a redex. -- substitute s for k *AND* decrement indices; only use after reducing a redex.
subst :: Expr -> Expr -> Expr subst :: Integer -> Expr -> Expr -> Expr
subst s (Var 0 _) = s subst k s (Var n x)
subst _ (Var n s) = Var (n - 1) s | k == n = s
subst _ Star = Star | otherwise = Var (n - 1) x
subst _ Square = Square subst _ _ Star = Star
subst s (App m n) = App (subst s m) (subst s n) subst _ _ Square = Square
subst s (Abs x m n) = Abs x (subst s m) (subst s n) subst k s (App m n) = App (subst k s m) (subst k s n)
subst s (Pi x m n) = Pi x (subst s m) (subst s n) subst k s (Abs x m n) = Abs x (subst k s m) (subst (k + 1) (incIndices s) n)
subst k s (Pi x m n) = Pi x (subst k s m) (subst (k + 1) (incIndices s) n)
substnd :: Expr -> Expr -> Expr
substnd s (Var 0 _) = s
substnd _ (Var n s) = Var (n - 1) s
substnd _ Star = Star
substnd _ Square = Square
substnd s (App m n) = App (substnd s m) (substnd s n)
substnd s (Abs x m n) = Abs x (substnd s m) (substnd s n)
substnd s (Pi x m n) = Pi x (substnd s m) (substnd s n)
betaReduce :: Expr -> Expr betaReduce :: Expr -> Expr
betaReduce (Var k s) = Var k s betaReduce (Var k s) = Var k s
betaReduce Star = Star betaReduce Star = Star
betaReduce Square = Square betaReduce Square = Square
betaReduce (App (Abs _ _ v) n) = subst n v betaReduce (App (Abs _ _ v) n) = subst 0 n v
betaReduce (App m n) = App (betaReduce m) (betaReduce n) betaReduce (App m n) = App (betaReduce m) (betaReduce n)
betaReduce (Abs x t v) = Abs x (betaReduce t) (betaReduce v) betaReduce (Abs x t v) = Abs x (betaReduce t) (betaReduce v)
betaReduce (Pi x t v) = Pi x (betaReduce t) (betaReduce v) betaReduce (Pi x t v) = Pi x (betaReduce t) (betaReduce v)

View file

@ -12,7 +12,7 @@ main = do
input <- getLine input <- getLine
case pAll input of case pAll input of
Left err -> putStrLn err Left err -> putStrLn err
Right expr -> print expr >> case findType [] expr of Right expr -> case findType [] expr of
Right ty -> putStrLn $ pretty expr ++ " : " ++ pretty ty Right ty -> putStrLn $ pretty expr ++ " : " ++ pretty ty
Left err -> print err Left err -> print err
main main

View file

@ -7,7 +7,7 @@ import Data.Functor.Identity
import Data.List (elemIndex) import Data.List (elemIndex)
import Data.List.NonEmpty (NonEmpty ((:|))) import Data.List.NonEmpty (NonEmpty ((:|)))
import qualified Data.List.NonEmpty as NE import qualified Data.List.NonEmpty as NE
import Expr (Expr (..), (.->), incIndices) import Expr (Expr (..), incIndices, (.->))
import Text.Megaparsec hiding (State) import Text.Megaparsec hiding (State)
import Text.Megaparsec.Char import Text.Megaparsec.Char
import qualified Text.Megaparsec.Char.Lexer as L import qualified Text.Megaparsec.Char.Lexer as L
@ -56,7 +56,7 @@ pParamGroup = lexeme $ label "parameter group" $ between (char '(') (char ')') $
idents <- some pIdentifier idents <- some pIdentifier
_ <- defChoice $ ":" :| [] _ <- defChoice $ ":" :| []
ty <- pExpr ty <- pExpr
modify (idents ++) modify (flip (foldl $ flip (:)) idents)
pure $ zip idents (iterate incIndices ty) pure $ zip idents (iterate incIndices ty)
pParams :: Parser [(String, Expr)] pParams :: Parser [(String, Expr)]