From 80fb0e8760daab5e4de3e059b63ec50d9722d2a0 Mon Sep 17 00:00:00 2001 From: William Ball Date: Mon, 11 Nov 2024 23:38:10 -0800 Subject: [PATCH] findType passing every test I've thrown at it! --- app/Check.hs | 15 +++--------- app/Expr.hs | 64 ++++++++++++++++++++++++++++----------------------- app/Main.hs | 2 +- app/Parser.hs | 4 ++-- 4 files changed, 41 insertions(+), 44 deletions(-) diff --git a/app/Check.hs b/app/Check.hs index a70135c..5bc578e 100644 --- a/app/Check.hs +++ b/app/Check.hs @@ -1,4 +1,3 @@ -{-# LANGUAGE BangPatterns #-} module Check where import Control.Monad.Except @@ -6,7 +5,6 @@ import Data.List (intercalate, (!?)) import Control.Monad (unless) import Expr -import Debug.Trace type Context = [Expr] @@ -22,34 +20,27 @@ matchPi e = Left $ ExpectedFunctionType e showContext :: Context -> String showContext g = "[" ++ intercalate ", " (map show g) ++ "]" --- TODO: Debug these problem cases --- λ (S : *) (P : S -> *) (H : forall (x : S), P x) (y : S) => P y findType :: Context -> Expr -> CheckResult Expr -findType _ Star = trace "star" $ Right Square -findType _ Square = trace "square" $ Left SquareUntyped +findType _ Star = Right Square +findType _ Square = Left SquareUntyped findType g (Var n _) = do - !_ <- trace ("var:\t" ++ showContext g ++ "\n n:\t" ++ show n) (Right Star) t <- maybe (Left UnboundVariable) Right $ g !? fromInteger n s <- findType g t unless (isSort s) $ throwError $ NotASort s 32 pure t findType g (App m n) = do - !_ <- trace ("app:\t" ++ showContext g ++ "\n m:\t" ++ show m ++ "\n n: \t" ++ show n) (Right Star) (a, b) <- findType g m >>= matchPi a' <- findType g n unless (betaEquiv a a') $ throwError $ NotEquivalent a a' - pure $ subst n b + pure $ subst 0 n b findType g (Abs x a m) = do - !_ <- trace ("abs:\t" ++ showContext g ++ "\n a:\t" ++ show a ++ "\n m:\t" ++ show m) (Right Star) s1 <- findType g a - !_ <- trace ("back in abs:\t" ++ showContext g ++ "\n\t\t" ++ show a ++ " : " ++ show s1) (Right Star) unless (isSort s1) $ throwError $ NotASort s1 43 b <- findType (incIndices a : map incIndices g) m s2 <- findType g (Pi x a b) unless (isSort s2) $ throwError $ NotASort s2 46 pure $ if occursFree 0 b then Pi x a b else Pi "" a b findType g (Pi _ a b) = do - !_ <- trace ("pi:\t" ++ showContext g ++ "\n a:\t" ++ show a ++ "\n b:\t" ++ show b) (Right Star) s1 <- findType g a unless (isSort s1) $ throwError $ NotASort s1 51 s2 <- findType (incIndices a : map incIndices g) b diff --git a/app/Expr.hs b/app/Expr.hs index 0845d17..cb4bbff 100644 --- a/app/Expr.hs +++ b/app/Expr.hs @@ -11,7 +11,16 @@ data Expr where App :: Expr -> Expr -> Expr Abs :: String -> Expr -> Expr -> Expr Pi :: String -> Expr -> Expr -> Expr - deriving (Show, Eq) + deriving (Show) + +instance Eq Expr where + (Var n _) == (Var m _) = n == m + Star == Star = True + Square == Square = True + (App e1 e2) == (App f1 f2) = e1 == f1 && e2 == f2 + (Abs _ t1 b1) == (Abs _ t2 b2) = t1 == t2 && b1 == b2 + (Pi _ t1 b1) == (Pi _ t2 b2) = t1 == t2 && b1 == b2 + _ == _ = False infixl 4 <.> @@ -55,7 +64,7 @@ groupParams = foldr addParam [] addParam :: (String, Expr) -> [([String], Expr)] -> [([String], Expr)] addParam (x, t) [] = [([x], t)] addParam (x, t) l@((xs, s) : rest) - | t == s = (x : xs, t) : rest + | incIndices t == s = (x : xs, t) : rest | otherwise = ([x], t) : l showParamGroup :: ([String], Expr) -> String @@ -69,8 +78,8 @@ helper k (App e1 e2) = if k > 3 then parenthesize res else res where res = helper 3 e1 ++ " " ++ helper 4 e2 helper k (Pi "" t1 t2) = if k > 2 then parenthesize res else res - where - res = helper 3 t1 ++ " -> " ++ helper 2 t2 + where + res = helper 3 t1 ++ " -> " ++ helper 2 t2 helper k e@(Pi{}) = if k > 2 then parenthesize res else res where (params, body) = collectPis e @@ -92,38 +101,35 @@ isSort Star = True isSort Square = True isSort _ = False +shiftIndices :: Integer -> Integer -> Expr -> Expr +shiftIndices d c (Var k x) + | k >= c = Var (k + d) x + | otherwise = Var k x +shiftIndices _ _ Star = Star +shiftIndices _ _ Square = Square +shiftIndices d c (App m n) = App (shiftIndices d c m) (shiftIndices d c n) +shiftIndices d c (Abs x m n) = Abs x (shiftIndices d c m) (shiftIndices d (c + 1) n) +shiftIndices d c (Pi x m n) = Pi x (shiftIndices d c m) (shiftIndices d (c + 1) n) + incIndices :: Expr -> Expr -incIndices (Var n x) = Var (n + 1) x -incIndices Star = Star -incIndices Square = Square -incIndices (App m n) = App (incIndices m) (incIndices n) -incIndices (Abs x m n) = Abs x (incIndices m) (incIndices n) -incIndices (Pi x m n) = Pi x (incIndices m) (incIndices n) +incIndices = shiftIndices 1 0 --- substitute s for 0 *AND* decrement indices; only use after reducing a redex. -subst :: Expr -> Expr -> Expr -subst s (Var 0 _) = s -subst _ (Var n s) = Var (n - 1) s -subst _ Star = Star -subst _ Square = Square -subst s (App m n) = App (subst s m) (subst s n) -subst s (Abs x m n) = Abs x (subst s m) (subst s n) -subst s (Pi x m n) = Pi x (subst s m) (subst s n) - -substnd :: Expr -> Expr -> Expr -substnd s (Var 0 _) = s -substnd _ (Var n s) = Var (n - 1) s -substnd _ Star = Star -substnd _ Square = Square -substnd s (App m n) = App (substnd s m) (substnd s n) -substnd s (Abs x m n) = Abs x (substnd s m) (substnd s n) -substnd s (Pi x m n) = Pi x (substnd s m) (substnd s n) +-- substitute s for k *AND* decrement indices; only use after reducing a redex. +subst :: Integer -> Expr -> Expr -> Expr +subst k s (Var n x) + | k == n = s + | otherwise = Var (n - 1) x +subst _ _ Star = Star +subst _ _ Square = Square +subst k s (App m n) = App (subst k s m) (subst k s n) +subst k s (Abs x m n) = Abs x (subst k s m) (subst (k + 1) (incIndices s) n) +subst k s (Pi x m n) = Pi x (subst k s m) (subst (k + 1) (incIndices s) n) betaReduce :: Expr -> Expr betaReduce (Var k s) = Var k s betaReduce Star = Star betaReduce Square = Square -betaReduce (App (Abs _ _ v) n) = subst n v +betaReduce (App (Abs _ _ v) n) = subst 0 n v betaReduce (App m n) = App (betaReduce m) (betaReduce n) betaReduce (Abs x t v) = Abs x (betaReduce t) (betaReduce v) betaReduce (Pi x t v) = Pi x (betaReduce t) (betaReduce v) diff --git a/app/Main.hs b/app/Main.hs index c16c480..cea0ec7 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -12,7 +12,7 @@ main = do input <- getLine case pAll input of Left err -> putStrLn err - Right expr -> print expr >> case findType [] expr of + Right expr -> case findType [] expr of Right ty -> putStrLn $ pretty expr ++ " : " ++ pretty ty Left err -> print err main diff --git a/app/Parser.hs b/app/Parser.hs index e79eaa1..d269865 100644 --- a/app/Parser.hs +++ b/app/Parser.hs @@ -7,7 +7,7 @@ import Data.Functor.Identity import Data.List (elemIndex) import Data.List.NonEmpty (NonEmpty ((:|))) import qualified Data.List.NonEmpty as NE -import Expr (Expr (..), (.->), incIndices) +import Expr (Expr (..), incIndices, (.->)) import Text.Megaparsec hiding (State) import Text.Megaparsec.Char import qualified Text.Megaparsec.Char.Lexer as L @@ -56,7 +56,7 @@ pParamGroup = lexeme $ label "parameter group" $ between (char '(') (char ')') $ idents <- some pIdentifier _ <- defChoice $ ":" :| [] ty <- pExpr - modify (idents ++) + modify (flip (foldl $ flip (:)) idents) pure $ zip idents (iterate incIndices ty) pParams :: Parser [(String, Expr)]