findType passing every test I've thrown at it!
This commit is contained in:
parent
39cab7fd3d
commit
80fb0e8760
4 changed files with 41 additions and 44 deletions
15
app/Check.hs
15
app/Check.hs
|
|
@ -1,4 +1,3 @@
|
|||
{-# LANGUAGE BangPatterns #-}
|
||||
module Check where
|
||||
|
||||
import Control.Monad.Except
|
||||
|
|
@ -6,7 +5,6 @@ import Data.List (intercalate, (!?))
|
|||
|
||||
import Control.Monad (unless)
|
||||
import Expr
|
||||
import Debug.Trace
|
||||
|
||||
type Context = [Expr]
|
||||
|
||||
|
|
@ -22,34 +20,27 @@ matchPi e = Left $ ExpectedFunctionType e
|
|||
showContext :: Context -> String
|
||||
showContext g = "[" ++ intercalate ", " (map show g) ++ "]"
|
||||
|
||||
-- TODO: Debug these problem cases
|
||||
-- λ (S : *) (P : S -> *) (H : forall (x : S), P x) (y : S) => P y
|
||||
findType :: Context -> Expr -> CheckResult Expr
|
||||
findType _ Star = trace "star" $ Right Square
|
||||
findType _ Square = trace "square" $ Left SquareUntyped
|
||||
findType _ Star = Right Square
|
||||
findType _ Square = Left SquareUntyped
|
||||
findType g (Var n _) = do
|
||||
!_ <- trace ("var:\t" ++ showContext g ++ "\n n:\t" ++ show n) (Right Star)
|
||||
t <- maybe (Left UnboundVariable) Right $ g !? fromInteger n
|
||||
s <- findType g t
|
||||
unless (isSort s) $ throwError $ NotASort s 32
|
||||
pure t
|
||||
findType g (App m n) = do
|
||||
!_ <- trace ("app:\t" ++ showContext g ++ "\n m:\t" ++ show m ++ "\n n: \t" ++ show n) (Right Star)
|
||||
(a, b) <- findType g m >>= matchPi
|
||||
a' <- findType g n
|
||||
unless (betaEquiv a a') $ throwError $ NotEquivalent a a'
|
||||
pure $ subst n b
|
||||
pure $ subst 0 n b
|
||||
findType g (Abs x a m) = do
|
||||
!_ <- trace ("abs:\t" ++ showContext g ++ "\n a:\t" ++ show a ++ "\n m:\t" ++ show m) (Right Star)
|
||||
s1 <- findType g a
|
||||
!_ <- trace ("back in abs:\t" ++ showContext g ++ "\n\t\t" ++ show a ++ " : " ++ show s1) (Right Star)
|
||||
unless (isSort s1) $ throwError $ NotASort s1 43
|
||||
b <- findType (incIndices a : map incIndices g) m
|
||||
s2 <- findType g (Pi x a b)
|
||||
unless (isSort s2) $ throwError $ NotASort s2 46
|
||||
pure $ if occursFree 0 b then Pi x a b else Pi "" a b
|
||||
findType g (Pi _ a b) = do
|
||||
!_ <- trace ("pi:\t" ++ showContext g ++ "\n a:\t" ++ show a ++ "\n b:\t" ++ show b) (Right Star)
|
||||
s1 <- findType g a
|
||||
unless (isSort s1) $ throwError $ NotASort s1 51
|
||||
s2 <- findType (incIndices a : map incIndices g) b
|
||||
|
|
|
|||
64
app/Expr.hs
64
app/Expr.hs
|
|
@ -11,7 +11,16 @@ data Expr where
|
|||
App :: Expr -> Expr -> Expr
|
||||
Abs :: String -> Expr -> Expr -> Expr
|
||||
Pi :: String -> Expr -> Expr -> Expr
|
||||
deriving (Show, Eq)
|
||||
deriving (Show)
|
||||
|
||||
instance Eq Expr where
|
||||
(Var n _) == (Var m _) = n == m
|
||||
Star == Star = True
|
||||
Square == Square = True
|
||||
(App e1 e2) == (App f1 f2) = e1 == f1 && e2 == f2
|
||||
(Abs _ t1 b1) == (Abs _ t2 b2) = t1 == t2 && b1 == b2
|
||||
(Pi _ t1 b1) == (Pi _ t2 b2) = t1 == t2 && b1 == b2
|
||||
_ == _ = False
|
||||
|
||||
infixl 4 <.>
|
||||
|
||||
|
|
@ -55,7 +64,7 @@ groupParams = foldr addParam []
|
|||
addParam :: (String, Expr) -> [([String], Expr)] -> [([String], Expr)]
|
||||
addParam (x, t) [] = [([x], t)]
|
||||
addParam (x, t) l@((xs, s) : rest)
|
||||
| t == s = (x : xs, t) : rest
|
||||
| incIndices t == s = (x : xs, t) : rest
|
||||
| otherwise = ([x], t) : l
|
||||
|
||||
showParamGroup :: ([String], Expr) -> String
|
||||
|
|
@ -69,8 +78,8 @@ helper k (App e1 e2) = if k > 3 then parenthesize res else res
|
|||
where
|
||||
res = helper 3 e1 ++ " " ++ helper 4 e2
|
||||
helper k (Pi "" t1 t2) = if k > 2 then parenthesize res else res
|
||||
where
|
||||
res = helper 3 t1 ++ " -> " ++ helper 2 t2
|
||||
where
|
||||
res = helper 3 t1 ++ " -> " ++ helper 2 t2
|
||||
helper k e@(Pi{}) = if k > 2 then parenthesize res else res
|
||||
where
|
||||
(params, body) = collectPis e
|
||||
|
|
@ -92,38 +101,35 @@ isSort Star = True
|
|||
isSort Square = True
|
||||
isSort _ = False
|
||||
|
||||
shiftIndices :: Integer -> Integer -> Expr -> Expr
|
||||
shiftIndices d c (Var k x)
|
||||
| k >= c = Var (k + d) x
|
||||
| otherwise = Var k x
|
||||
shiftIndices _ _ Star = Star
|
||||
shiftIndices _ _ Square = Square
|
||||
shiftIndices d c (App m n) = App (shiftIndices d c m) (shiftIndices d c n)
|
||||
shiftIndices d c (Abs x m n) = Abs x (shiftIndices d c m) (shiftIndices d (c + 1) n)
|
||||
shiftIndices d c (Pi x m n) = Pi x (shiftIndices d c m) (shiftIndices d (c + 1) n)
|
||||
|
||||
incIndices :: Expr -> Expr
|
||||
incIndices (Var n x) = Var (n + 1) x
|
||||
incIndices Star = Star
|
||||
incIndices Square = Square
|
||||
incIndices (App m n) = App (incIndices m) (incIndices n)
|
||||
incIndices (Abs x m n) = Abs x (incIndices m) (incIndices n)
|
||||
incIndices (Pi x m n) = Pi x (incIndices m) (incIndices n)
|
||||
incIndices = shiftIndices 1 0
|
||||
|
||||
-- substitute s for 0 *AND* decrement indices; only use after reducing a redex.
|
||||
subst :: Expr -> Expr -> Expr
|
||||
subst s (Var 0 _) = s
|
||||
subst _ (Var n s) = Var (n - 1) s
|
||||
subst _ Star = Star
|
||||
subst _ Square = Square
|
||||
subst s (App m n) = App (subst s m) (subst s n)
|
||||
subst s (Abs x m n) = Abs x (subst s m) (subst s n)
|
||||
subst s (Pi x m n) = Pi x (subst s m) (subst s n)
|
||||
|
||||
substnd :: Expr -> Expr -> Expr
|
||||
substnd s (Var 0 _) = s
|
||||
substnd _ (Var n s) = Var (n - 1) s
|
||||
substnd _ Star = Star
|
||||
substnd _ Square = Square
|
||||
substnd s (App m n) = App (substnd s m) (substnd s n)
|
||||
substnd s (Abs x m n) = Abs x (substnd s m) (substnd s n)
|
||||
substnd s (Pi x m n) = Pi x (substnd s m) (substnd s n)
|
||||
-- substitute s for k *AND* decrement indices; only use after reducing a redex.
|
||||
subst :: Integer -> Expr -> Expr -> Expr
|
||||
subst k s (Var n x)
|
||||
| k == n = s
|
||||
| otherwise = Var (n - 1) x
|
||||
subst _ _ Star = Star
|
||||
subst _ _ Square = Square
|
||||
subst k s (App m n) = App (subst k s m) (subst k s n)
|
||||
subst k s (Abs x m n) = Abs x (subst k s m) (subst (k + 1) (incIndices s) n)
|
||||
subst k s (Pi x m n) = Pi x (subst k s m) (subst (k + 1) (incIndices s) n)
|
||||
|
||||
betaReduce :: Expr -> Expr
|
||||
betaReduce (Var k s) = Var k s
|
||||
betaReduce Star = Star
|
||||
betaReduce Square = Square
|
||||
betaReduce (App (Abs _ _ v) n) = subst n v
|
||||
betaReduce (App (Abs _ _ v) n) = subst 0 n v
|
||||
betaReduce (App m n) = App (betaReduce m) (betaReduce n)
|
||||
betaReduce (Abs x t v) = Abs x (betaReduce t) (betaReduce v)
|
||||
betaReduce (Pi x t v) = Pi x (betaReduce t) (betaReduce v)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ main = do
|
|||
input <- getLine
|
||||
case pAll input of
|
||||
Left err -> putStrLn err
|
||||
Right expr -> print expr >> case findType [] expr of
|
||||
Right expr -> case findType [] expr of
|
||||
Right ty -> putStrLn $ pretty expr ++ " : " ++ pretty ty
|
||||
Left err -> print err
|
||||
main
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import Data.Functor.Identity
|
|||
import Data.List (elemIndex)
|
||||
import Data.List.NonEmpty (NonEmpty ((:|)))
|
||||
import qualified Data.List.NonEmpty as NE
|
||||
import Expr (Expr (..), (.->), incIndices)
|
||||
import Expr (Expr (..), incIndices, (.->))
|
||||
import Text.Megaparsec hiding (State)
|
||||
import Text.Megaparsec.Char
|
||||
import qualified Text.Megaparsec.Char.Lexer as L
|
||||
|
|
@ -56,7 +56,7 @@ pParamGroup = lexeme $ label "parameter group" $ between (char '(') (char ')') $
|
|||
idents <- some pIdentifier
|
||||
_ <- defChoice $ ":" :| []
|
||||
ty <- pExpr
|
||||
modify (idents ++)
|
||||
modify (flip (foldl $ flip (:)) idents)
|
||||
pure $ zip idents (iterate incIndices ty)
|
||||
|
||||
pParams :: Parser [(String, Expr)]
|
||||
|
|
|
|||
Loading…
Reference in a new issue