fixed a sneaky parser bug

This commit is contained in:
William Ball 2024-11-11 20:08:21 -08:00
parent 5ce06d1012
commit 39cab7fd3d
3 changed files with 14 additions and 14 deletions

View file

@ -10,7 +10,7 @@ import Debug.Trace
type Context = [Expr] type Context = [Expr]
data TypeCheckError = Err | SquareUntyped | UnboundVariable | NotASort Expr | ExpectedFunctionType Expr | NotEquivalent Expr Expr data TypeCheckError = Err | SquareUntyped | UnboundVariable | NotASort Expr Int | ExpectedFunctionType Expr | NotEquivalent Expr Expr
deriving (Show) deriving (Show)
type CheckResult = Either TypeCheckError type CheckResult = Either TypeCheckError
@ -22,14 +22,16 @@ matchPi e = Left $ ExpectedFunctionType e
showContext :: Context -> String showContext :: Context -> String
showContext g = "[" ++ intercalate ", " (map show g) ++ "]" showContext g = "[" ++ intercalate ", " (map show g) ++ "]"
-- TODO: Debug these problem cases
-- λ (S : *) (P : S -> *) (H : forall (x : S), P x) (y : S) => P y
findType :: Context -> Expr -> CheckResult Expr findType :: Context -> Expr -> CheckResult Expr
findType _ Star = Right Square findType _ Star = trace "star" $ Right Square
findType _ Square = Left SquareUntyped findType _ Square = trace "square" $ Left SquareUntyped
findType g (Var n _) = do findType g (Var n _) = do
!_ <- trace ("var:\t" ++ showContext g ++ "\n n:\t" ++ show n) (Right Star) !_ <- trace ("var:\t" ++ showContext g ++ "\n n:\t" ++ show n) (Right Star)
t <- maybe (Left UnboundVariable) Right $ g !? fromInteger n t <- maybe (Left UnboundVariable) Right $ g !? fromInteger n
s <- findType g t s <- findType g t
unless (isSort s) $ throwError $ NotASort s unless (isSort s) $ throwError $ NotASort s 32
pure t pure t
findType g (App m n) = do findType g (App m n) = do
!_ <- trace ("app:\t" ++ showContext g ++ "\n m:\t" ++ show m ++ "\n n: \t" ++ show n) (Right Star) !_ <- trace ("app:\t" ++ showContext g ++ "\n m:\t" ++ show m ++ "\n n: \t" ++ show n) (Right Star)
@ -40,15 +42,16 @@ findType g (App m n) = do
findType g (Abs x a m) = do findType g (Abs x a m) = do
!_ <- trace ("abs:\t" ++ showContext g ++ "\n a:\t" ++ show a ++ "\n m:\t" ++ show m) (Right Star) !_ <- trace ("abs:\t" ++ showContext g ++ "\n a:\t" ++ show a ++ "\n m:\t" ++ show m) (Right Star)
s1 <- findType g a s1 <- findType g a
unless (isSort s1) $ throwError $ NotASort s1 !_ <- trace ("back in abs:\t" ++ showContext g ++ "\n\t\t" ++ show a ++ " : " ++ show s1) (Right Star)
unless (isSort s1) $ throwError $ NotASort s1 43
b <- findType (incIndices a : map incIndices g) m b <- findType (incIndices a : map incIndices g) m
s2 <- findType g (Pi x a b) s2 <- findType g (Pi x a b)
unless (isSort s2) $ throwError $ NotASort s2 unless (isSort s2) $ throwError $ NotASort s2 46
pure $ if occursFree 0 b then Pi x a b else Pi "" a b pure $ if occursFree 0 b then Pi x a b else Pi "" a b
findType g (Pi _ a b) = do findType g (Pi _ a b) = do
!_ <- trace ("pi:\t" ++ showContext g ++ "\n a:\t" ++ show a ++ "\n b:\t" ++ show b) (Right Star) !_ <- trace ("pi:\t" ++ showContext g ++ "\n a:\t" ++ show a ++ "\n b:\t" ++ show b) (Right Star)
s1 <- findType g a s1 <- findType g a
unless (isSort s1) $ throwError $ NotASort s1 unless (isSort s1) $ throwError $ NotASort s1 51
s2 <- findType (incIndices a : map incIndices g) b s2 <- findType (incIndices a : map incIndices g) b
unless (isSort s2) $ throwError $ NotASort s2 unless (isSort s2) $ throwError $ NotASort s2 53
pure s2 pure s2

View file

@ -12,8 +12,7 @@ main = do
input <- getLine input <- getLine
case pAll input of case pAll input of
Left err -> putStrLn err Left err -> putStrLn err
-- Right expr -> putStrLn (pretty expr) Right expr -> print expr >> case findType [] expr of
Right expr -> case findType [] expr of
Right ty -> putStrLn $ pretty expr ++ " : " ++ pretty ty Right ty -> putStrLn $ pretty expr ++ " : " ++ pretty ty
Left err -> print err Left err -> print err
main main

View file

@ -1,5 +1,3 @@
{-# LANGUAGE TupleSections #-}
module Parser where module Parser where
import Control.Monad import Control.Monad
@ -9,7 +7,7 @@ import Data.Functor.Identity
import Data.List (elemIndex) import Data.List (elemIndex)
import Data.List.NonEmpty (NonEmpty ((:|))) import Data.List.NonEmpty (NonEmpty ((:|)))
import qualified Data.List.NonEmpty as NE import qualified Data.List.NonEmpty as NE
import Expr (Expr (..), (.->)) import Expr (Expr (..), (.->), incIndices)
import Text.Megaparsec hiding (State) import Text.Megaparsec hiding (State)
import Text.Megaparsec.Char import Text.Megaparsec.Char
import qualified Text.Megaparsec.Char.Lexer as L import qualified Text.Megaparsec.Char.Lexer as L
@ -59,7 +57,7 @@ pParamGroup = lexeme $ label "parameter group" $ between (char '(') (char ')') $
_ <- defChoice $ ":" :| [] _ <- defChoice $ ":" :| []
ty <- pExpr ty <- pExpr
modify (idents ++) modify (idents ++)
pure $ (,ty) <$> idents pure $ zip idents (iterate incIndices ty)
pParams :: Parser [(String, Expr)] pParams :: Parser [(String, Expr)]
pParams = concat <$> some pParamGroup pParams = concat <$> some pParamGroup