basics of definitions!!!!

This commit is contained in:
William Ball 2024-11-17 01:57:53 -08:00
parent f5e79c3225
commit c1ccd50644
11 changed files with 159 additions and 127 deletions

View file

@ -1,8 +1,6 @@
module Main where
import Check
import qualified Data.Text.IO as T
import Expr
import Parser
import Repl
import System.Environment
@ -23,8 +21,4 @@ handleFile fileName =
input <- T.hGetContents fileH
case pAll input of
Left err -> putStrLn err
Right expr -> case findType [] expr of
Left err -> print err
Right ty -> do
putStrLn $ "expr:\t" ++ prettyS expr
putStrLn $ "type:\t" ++ prettyS ty
Right () -> putStrLn "success!"

View file

@ -2,6 +2,7 @@ module Repl (repl) where
import Check
import qualified Data.Text as T
import qualified Data.Map as M
import Expr
import Parser
import System.Console.Haskeline
@ -23,12 +24,7 @@ parseCommand (Just input) = Just (Input input)
handleInput :: ReplState -> String -> InputT IO ()
handleInput state input = case pAll (T.pack input) of
Left err -> outputStrLn err
Right expr -> case findType [] expr of
Left err -> outputStrLn $ show err
Right ty ->
if debugMode state
then printDebugInfo expr ty
else outputStrLn $ prettyS ty
Right () -> pure ()
printDebugInfo :: Expr -> Expr -> InputT IO ()
printDebugInfo expr ty = do

View file

@ -24,13 +24,15 @@ build-type: Simple
extra-doc-files: CHANGELOG.md
, README.md
-- Extra source files to be distributed with the package, such as examples, or a tutorial module.
-- extra-source-files:
common warnings
ghc-options: -Wall
library dependent-lambda-lib
import: warnings
exposed-modules: Check
Parser
Expr
Eval
hs-source-dirs: lib
build-depends: base ^>=4.19.1.0
@ -38,13 +40,11 @@ library dependent-lambda-lib
, text
, parser-combinators
, mtl
, containers
default-language: Haskell2010
default-extensions: OverloadedStrings
, GADTs
common warnings
ghc-options: -Wall
executable dependent-lambda
import: warnings
main-is: Main.hs
@ -53,6 +53,7 @@ executable dependent-lambda
build-depends: base ^>=4.19.1.0
, dependent-lambda-lib
, text
, containers
, haskeline
, directory
, filepath
@ -62,6 +63,7 @@ executable dependent-lambda
, GADTs
test-suite tests
import: warnings
type: exitcode-stdio-1.0
main-is: Tests.hs
other-modules: ExprTests
@ -70,6 +72,7 @@ test-suite tests
build-depends: base ^>=4.19.1.0
, HUnit
, text
, containers
, dependent-lambda-lib
hs-source-dirs: tests
default-language: Haskell2010

View file

@ -1,16 +1,19 @@
module Check (TypeCheckError (..), CheckResult (..), findType) where
module Check (TypeCheckError (..), CheckResult, checkType) where
import Control.Monad (unless)
import Control.Monad.Except (MonadError (throwError))
import Data.List (intercalate, (!?))
import Control.Monad.Reader
import Data.List ((!?))
import qualified Data.Map as M
import Data.Text (Text)
import qualified Data.Text as T
import Control.Monad (unless)
import Eval
import Expr
type Context = [Expr]
data TypeCheckError = SquareUntyped | UnboundVariable Text | NotASort Expr Expr | ExpectedPiType Expr Expr | NotEquivalent Expr Expr Expr deriving (Eq)
data TypeCheckError = SquareUntyped | UnboundVariable Text | NotASort Expr Expr | ExpectedPiType Expr Expr | NotEquivalent Expr Expr Expr deriving (Eq, Ord)
instance Show TypeCheckError where
show SquareUntyped = "□ does not have a type"
@ -21,25 +24,24 @@ instance Show TypeCheckError where
type CheckResult = Either TypeCheckError
matchPi :: Expr -> Expr -> CheckResult (Expr, Expr)
matchPi _ (Pi _ a b) = Right (a, b)
matchPi m e = Left $ ExpectedPiType m e
matchPi :: Expr -> Expr -> ReaderT Env CheckResult (Expr, Expr)
matchPi _ (Pi _ a b) = pure (a, b)
matchPi m e = throwError $ ExpectedPiType m e
showContext :: Context -> String
showContext g = "[" ++ intercalate ", " (map show g) ++ "]"
findType :: Context -> Expr -> CheckResult Expr
findType _ Star = Right Square
findType _ Square = Left SquareUntyped
findType :: Context -> Expr -> ReaderT Env CheckResult Expr
findType _ Star = pure Square
findType _ Square = throwError SquareUntyped
findType g (Var n x) = do
t <- maybe (Left $ UnboundVariable x) Right $ g !? fromInteger n
t <- maybe (throwError $ UnboundVariable x) pure $ g !? fromInteger n
s <- findType g t
unless (isSort s) $ throwError $ NotASort t s
pure t
findType g (Free n) = asks (M.lookup n) >>= maybe (throwError $ UnboundVariable n) (findType g)
findType g e@(App m n) = do
(a, b) <- findType g m >>= matchPi m
a' <- findType g n
unless (betaEquiv a a') $ throwError $ NotEquivalent a a' e
equiv <- asks $ runReader (betaEquiv a a')
unless equiv $ throwError $ NotEquivalent a a' e
pure $ subst 0 n b
findType g (Abs x a m) = do
s1 <- findType g a
@ -54,3 +56,6 @@ findType g (Pi _ a b) = do
s2 <- findType (incIndices a : map incIndices g) b
unless (isSort s2) $ throwError $ NotASort b s2
pure s2
checkType :: Env -> Context -> Expr -> CheckResult Expr
checkType env g t = runReaderT (findType g t) env

39
lib/Eval.hs Normal file
View file

@ -0,0 +1,39 @@
module Eval where
import Control.Monad.Reader
import qualified Data.Map as M
import Data.Maybe
import Data.Text (Text)
import Expr
type Env = M.Map Text Expr
-- substitute s for k *AND* decrement indices; only use after reducing a redex.
subst :: Integer -> Expr -> Expr -> Expr
subst k s (Var n x)
| k == n = s
| n > k = Var (n - 1) x
| otherwise = Var n x
subst _ _ (Free s) = Free s
subst _ _ Star = Star
subst _ _ Square = Square
subst k s (App m n) = App (subst k s m) (subst k s n)
subst k s (Abs x m n) = Abs x (subst k s m) (subst (k + 1) (incIndices s) n)
subst k s (Pi x m n) = Pi x (subst k s m) (subst (k + 1) (incIndices s) n)
whnf :: Expr -> Expr
whnf (App (Abs _ _ v) n) = whnf $ subst 0 n v
whnf e = e
betaEquiv :: Expr -> Expr -> Reader Env Bool
betaEquiv e1 e2
| e1 == e2 = pure True
| otherwise = case (whnf e1, whnf e2) of
(Var k1 _, Var k2 _) -> pure $ k1 == k2
(Free n, Free m) -> pure $ n == m
(Free n, e) -> fromMaybe False <$> (asks (M.lookup n) >>= traverse (`betaEquiv` e))
(e, Free n) -> fromMaybe False <$> (asks (M.lookup n) >>= traverse (`betaEquiv` e))
(Star, Star) -> pure True
(Abs _ t1 v1, Abs _ t2 v2) -> (&&) <$> betaEquiv t1 t2 <*> betaEquiv v1 v2 -- i want idiom brackets
(Pi _ t1 v1, Pi _ t2 v2) -> (&&) <$> betaEquiv t1 t2 <*> betaEquiv v1 v2
_ -> pure False -- remaining cases impossible or false

View file

@ -6,12 +6,13 @@ import qualified Data.Text as T
data Expr where
Var :: Integer -> Text -> Expr
Free :: Text -> Expr
Star :: Expr
Square :: Expr
App :: Expr -> Expr -> Expr
Abs :: Text -> Expr -> Expr -> Expr
Pi :: Text -> Expr -> Expr -> Expr
deriving (Show)
deriving (Show, Ord)
instance Eq Expr where
(Var n _) == (Var m _) = n == m
@ -24,12 +25,32 @@ instance Eq Expr where
occursFree :: Integer -> Expr -> Bool
occursFree n (Var k _) = n == k
occursFree _ (Free _) = False
occursFree _ Star = False
occursFree _ Square = False
occursFree n (App a b) = on (||) (occursFree n) a b
occursFree n (Abs _ a b) = occursFree n a || occursFree (n + 1) b
occursFree n (Pi _ a b) = occursFree n a || occursFree (n + 1) b
isSort :: Expr -> Bool
isSort Star = True
isSort Square = True
isSort _ = False
shiftIndices :: Integer -> Integer -> Expr -> Expr
shiftIndices d c (Var k x)
| k >= c = Var (k + d) x
| otherwise = Var k x
shiftIndices _ _ (Free s) = Free s
shiftIndices _ _ Star = Star
shiftIndices _ _ Square = Square
shiftIndices d c (App m n) = App (shiftIndices d c m) (shiftIndices d c n)
shiftIndices d c (Abs x m n) = Abs x (shiftIndices d c m) (shiftIndices d (c + 1) n)
shiftIndices d c (Pi x m n) = Pi x (shiftIndices d c m) (shiftIndices d (c + 1) n)
incIndices :: Expr -> Expr
incIndices = shiftIndices 1 0
{- --------------------- PRETTY PRINTING ----------------------------- -}
parenthesize :: Text -> Text
@ -62,6 +83,7 @@ showParamGroup (ids, ty) = parenthesize $ T.unwords ids <> " : " <> pretty ty
helper :: Integer -> Expr -> Text
helper _ (Var _ s) = s
helper _ (Free s) = s
helper _ Star = "*"
helper _ Square = ""
helper k (App e1 e2) = if k > 3 then parenthesize res else res
@ -86,63 +108,3 @@ pretty = helper 0
prettyS :: Expr -> String
prettyS = T.unpack . pretty
{- --------------- ACTUAL MATH STUFF ---------------- -}
isSort :: Expr -> Bool
isSort Star = True
isSort Square = True
isSort _ = False
shiftIndices :: Integer -> Integer -> Expr -> Expr
shiftIndices d c (Var k x)
| k >= c = Var (k + d) x
| otherwise = Var k x
shiftIndices _ _ Star = Star
shiftIndices _ _ Square = Square
shiftIndices d c (App m n) = App (shiftIndices d c m) (shiftIndices d c n)
shiftIndices d c (Abs x m n) = Abs x (shiftIndices d c m) (shiftIndices d (c + 1) n)
shiftIndices d c (Pi x m n) = Pi x (shiftIndices d c m) (shiftIndices d (c + 1) n)
incIndices :: Expr -> Expr
incIndices = shiftIndices 1 0
-- substitute s for k *AND* decrement indices; only use after reducing a redex.
subst :: Integer -> Expr -> Expr -> Expr
subst k s (Var n x)
| k == n = s
| n > k = Var (n - 1) x
| otherwise = Var n x
subst _ _ Star = Star
subst _ _ Square = Square
subst k s (App m n) = App (subst k s m) (subst k s n)
subst k s (Abs x m n) = Abs x (subst k s m) (subst (k + 1) (incIndices s) n)
subst k s (Pi x m n) = Pi x (subst k s m) (subst (k + 1) (incIndices s) n)
betaReduce :: Expr -> Expr
betaReduce (Var k s) = Var k s
betaReduce Star = Star
betaReduce Square = Square
betaReduce (App (Abs _ _ v) n) = subst 0 n v
betaReduce (App m n) = App (betaReduce m) (betaReduce n)
betaReduce (Abs x t v) = Abs x (betaReduce t) (betaReduce v)
betaReduce (Pi x t v) = Pi x (betaReduce t) (betaReduce v)
betaNF :: Expr -> Expr
betaNF e = if e == e' then e else betaNF e'
where
e' = betaReduce e
whnf :: Expr -> Expr
whnf (App (Abs _ _ v) n) = whnf $ subst 0 n v
whnf e = e
betaEquiv :: Expr -> Expr -> Bool
betaEquiv e1 e2
| e1 == e2 = True
| otherwise = case (whnf e1, whnf e2) of
(Var k1 _, Var k2 _) -> k1 == k2
(Star, Star) -> True
(Abs _ t1 v1, Abs _ t2 v2) -> betaEquiv t1 t2 && betaEquiv v1 v2
(Pi _ t1 v1, Pi _ t2 v2) -> betaEquiv t1 t2 && betaEquiv v1 v2
_ -> False -- remaining cases impossible or false

View file

@ -1,5 +1,8 @@
{-# LANGUAGE NamedFieldPuns #-}
module Parser (pAll) where
import Check
import Control.Monad
import Control.Monad.State.Strict
import Data.Bifunctor (first)
@ -7,22 +10,30 @@ import Data.Functor.Identity
import Data.List (elemIndex)
import Data.List.NonEmpty (NonEmpty ((:|)))
import qualified Data.List.NonEmpty as NE
import qualified Data.Map as M
import Data.Text (Text)
import qualified Data.Text as T
import Eval
import Expr (Expr (..), incIndices)
import Text.Megaparsec hiding (State)
import Text.Megaparsec.Char
import qualified Text.Megaparsec.Char.Lexer as L
type InnerState = [Text]
data InnerState = IS {_binds :: [Text], _defs :: Env}
data CustomErrors = UnboundVariable Text [Text] deriving (Eq, Ord, Show)
newtype TypeError = TE TypeCheckError
deriving (Eq, Ord, Show)
instance ShowErrorComponent CustomErrors where
showErrorComponent (UnboundVariable var bound) =
"Unbound variable: " ++ T.unpack var ++ ". Did you mean one of: " ++ T.unpack (T.unwords bound) ++ "?"
type Parser = ParsecT TypeError Text (State InnerState)
type Parser = ParsecT CustomErrors Text (State InnerState)
instance ShowErrorComponent TypeError where
showErrorComponent (TE e) = show e
bindsToIS :: ([Text] -> [Text]) -> InnerState -> InnerState
bindsToIS f x@(IS{_binds}) = x{_binds = f _binds}
defsToIS :: (Env -> Env) -> InnerState -> InnerState
defsToIS f x@(IS{_defs}) = x{_defs = f _defs}
skipSpace :: Parser ()
skipSpace =
@ -38,15 +49,15 @@ pIdentifier :: Parser Text
pIdentifier = label "identifier" $ lexeme $ do
firstChar <- letterChar <|> char '_'
rest <- many $ alphaNumChar <|> char '_'
return $ T.pack (firstChar : rest) -- Still need T.pack here as we're building from chars
return $ T.pack (firstChar : rest)
pVar :: Parser Expr
pVar = label "variable" $ lexeme $ do
var <- pIdentifier
binders <- get
case elemIndex var binders of
Just i -> return $ Var (fromIntegral i) var
Nothing -> customFailure $ UnboundVariable var binders
binders <- _binds <$> get
pure $ case elemIndex var binders of
Just i -> Var (fromIntegral i) var
Nothing -> Free var
defChoice :: NE.NonEmpty Text -> Parser ()
defChoice options = lexeme $ label (T.unpack $ NE.head options) $ void $ choice $ NE.map chunk options
@ -56,7 +67,7 @@ pParamGroup = lexeme $ label "parameter group" $ between (char '(') (char ')') $
idents <- some pIdentifier
_ <- defChoice $ ":" :| []
ty <- pExpr
modify (flip (foldl $ flip (:)) idents)
modify $ bindsToIS $ flip (foldl $ flip (:)) idents
pure $ zip idents (iterate incIndices ty)
pParams :: Parser [(Text, Expr)]
@ -68,7 +79,7 @@ pLAbs = lexeme $ label "λ-abstraction" $ do
params <- pParams
_ <- defChoice $ "." :| ["=>", ""]
body <- pExpr
modify (drop $ length params)
modify $ bindsToIS $ drop $ length params
pure $ foldr (uncurry Abs) body params
pPAbs :: Parser Expr
@ -77,7 +88,7 @@ pPAbs = lexeme $ label "Π-abstraction" $ do
params <- pParams
_ <- defChoice $ "." :| [","]
body <- pExpr
modify (drop $ length params)
modify $ bindsToIS $ drop $ length params
pure $ foldr (uncurry Pi) body params
pArrow :: Parser Expr
@ -112,5 +123,20 @@ pAppTerm = lexeme $ pLAbs <|> pPAbs <|> pApp
pExpr :: Parser Expr
pExpr = lexeme $ try pArrow <|> pAppTerm
pAll :: Text -> Either String Expr
pAll input = first errorBundlePretty $ fst $ runIdentity $ runStateT (runParserT pExpr "" input) []
pDef :: Parser ()
pDef = lexeme $ label "definition" $ do
ident <- pIdentifier
_ <- defChoice $ ":=" :| []
value <- pExpr
_ <- defChoice $ ";" :| []
foo <- get
let ty = checkType (_defs foo) [] value
case ty of
Left err -> customFailure $ TE err
Right _ -> modify $ defsToIS $ M.insert ident value
pProgram :: Parser ()
pProgram = void $ many pDef
pAll :: Text -> Either String ()
pAll input = first errorBundlePretty $ fst $ runIdentity $ runStateT (runParserT pProgram "" input) $ IS{_binds = [], _defs = M.empty}

2
test.pg Normal file
View file

@ -0,0 +1,2 @@
id := fun (A : *) (x : A) . x ;
foo := fun (A B : *) (f : A -> B) (x : A) . id (A -> B) f (id A x) ;

View file

@ -1,11 +1,12 @@
module CheckTests (tests) where
import Check
import qualified Data.Map as M
import Expr (Expr (..))
import Test.HUnit
sort :: Test
sort = TestCase $ assertEqual "*" (Right Square) (findType [] Star)
sort = TestCase $ assertEqual "*" (Right Square) (checkType M.empty [] Star)
stlc :: Test
stlc =
@ -13,7 +14,12 @@ stlc =
assertEqual
"fun (x : A) (y : B) . x"
(Right $ Pi "" (Var 0 "A") (Pi "" (Var 2 "B") (Var 2 "A")))
(findType [Star, Star] $ Abs "x" (Var 0 "A") (Abs "y" (Var 2 "B") (Var 1 "x")))
(checkType M.empty [Star, Star] $ Abs "x" (Var 0 "A") (Abs "y" (Var 2 "B") (Var 1 "x")))
freeVar :: Test
freeVar =
TestCase $
assertEqual "{x = *} , [] |- x : □" (Right Square) (checkType (M.singleton "x" Star) [] (Free "x"))
polyIdent :: Test
polyIdent =
@ -21,7 +27,7 @@ polyIdent =
assertEqual
"fun (A : *) (x : A) . x"
(Right $ Pi "A" Star (Pi "" (Var 0 "A") (Var 1 "A")))
(findType [] (Abs "A" Star (Abs "x" (Var 0 "A") (Var 0 "x"))))
(checkType M.empty [] (Abs "A" Star (Abs "x" (Var 0 "A") (Var 0 "x"))))
typeCons :: Test
typeCons =
@ -29,7 +35,7 @@ typeCons =
assertEqual
"fun (A B : *) . A -> B"
(Right $ Pi "" Star (Pi "" Star Star))
(findType [] (Abs "A" Star (Abs "B" Star (Pi "" (Var 1 "A") (Var 1 "B")))))
(checkType M.empty [] (Abs "A" Star (Abs "B" Star (Pi "" (Var 1 "A") (Var 1 "B")))))
useTypeCons :: Test
useTypeCons =
@ -37,7 +43,7 @@ useTypeCons =
assertEqual
"fun (C : * -> *) (A : *) (x : C A) . x"
(Right $ Pi "C" (Pi "" Star Star) (Pi "A" Star (Pi "" (App (Var 1 "C") (Var 0 "A")) (App (Var 2 "C") (Var 1 "A")))))
(findType [] $ Abs "C" (Pi "" Star Star) (Abs "A" Star (Abs "x" (App (Var 1 "C") (Var 0 "A")) (Var 0 "x"))))
(checkType M.empty [] $ Abs "C" (Pi "" Star Star) (Abs "A" Star (Abs "x" (App (Var 1 "C") (Var 0 "A")) (Var 0 "x"))))
dependent :: Test
dependent =
@ -45,7 +51,7 @@ dependent =
assertEqual
"fun (S : *) (x : S) . S -> S"
(Right $ Pi "S" Star (Pi "" (Var 0 "S") Star))
(findType [] $ Abs "S" Star (Abs "x" (Var 0 "S") (Pi "" (Var 1 "S") (Var 2 "S"))))
(checkType M.empty [] $ Abs "S" Star (Abs "x" (Var 0 "S") (Pi "" (Var 1 "S") (Var 2 "S"))))
useDependent :: Test
useDependent =
@ -53,7 +59,7 @@ useDependent =
assertEqual
"fun (S : *) (P : S -> *) (x : S) . P x"
(Right $ Pi "S" Star (Pi "" (Pi "" (Var 0 "S") Star) (Pi "" (Var 1 "S") Star)))
(findType [] $ Abs "S" Star (Abs "P" (Pi "" (Var 0 "S") Star) (Abs "x" (Var 1 "S") (App (Var 1 "P") (Var 0 "x")))))
(checkType M.empty [] $ Abs "S" Star (Abs "P" (Pi "" (Var 0 "S") Star) (Abs "x" (Var 1 "S") (App (Var 1 "P") (Var 0 "x")))))
big :: Test
big =
@ -61,15 +67,15 @@ big =
assertEqual
"fun (S : *) (P Q : S -> *) (H : forall (x : S), P x -> Q x) (G : forall (x : S), P x) (x : S) . H x (G x)"
(Right $ Pi "S" Star (Pi "P" (Pi "" (Var 0 "S") Star) (Pi "Q" (Pi "" (Var 1 "S") Star) (Pi "" (Pi "x" (Var 2 "S") (Pi "" (App (Var 2 "P") (Var 0 "x")) (App (Var 2 "Q") (Var 1 "x")))) (Pi "" (Pi "x" (Var 3 "S") (App (Var 3 "P") (Var 0 "x"))) (Pi "x" (Var 4 "S") (App (Var 3 "Q") (Var 0 "x"))))))))
(findType [] $ Abs "S" Star (Abs "P" (Pi "" (Var 0 "S") Star) (Abs "Q" (Pi "" (Var 1 "S") Star) (Abs "H" (Pi "x" (Var 2 "S") (Pi "" (App (Var 2 "P") (Var 0 "x")) (App (Var 2 "Q") (Var 1 "x")))) (Abs "G" (Pi "x" (Var 3 "S") (App (Var 3 "P") (Var 0 "x"))) (Abs "x" (Var 4 "S") (App (App (Var 2 "H") (Var 0 "x")) (App (Var 1 "G") (Var 0 "x")))))))))
(checkType M.empty [] $ Abs "S" Star (Abs "P" (Pi "" (Var 0 "S") Star) (Abs "Q" (Pi "" (Var 1 "S") Star) (Abs "H" (Pi "x" (Var 2 "S") (Pi "" (App (Var 2 "P") (Var 0 "x")) (App (Var 2 "Q") (Var 1 "x")))) (Abs "G" (Pi "x" (Var 3 "S") (App (Var 3 "P") (Var 0 "x"))) (Abs "x" (Var 4 "S") (App (App (Var 2 "H") (Var 0 "x")) (App (Var 1 "G") (Var 0 "x")))))))))
tests :: Test
tests =
TestList
[ TestLabel "sort" sort
, TestLabel "λ→" stlc
, TestLabel "λ→" $ TestList [stlc, freeVar]
, TestLabel "λ2" polyIdent
, TestLabel "λω" (TestList [typeCons, useTypeCons])
, TestLabel "λP2" (TestList [dependent, useDependent])
, TestLabel "λω" $ TestList [typeCons, useTypeCons]
, TestLabel "λP2" $ TestList [dependent, useDependent]
, TestLabel "λC" big
]

View file

@ -1,6 +1,7 @@
module ExprTests (tests) where
import Expr
import Eval
import Test.HUnit
inner :: Expr
@ -34,13 +35,13 @@ substE1 =
after
(subst 0 (Var 2 "B") inner)
betaNFe1 :: Test
betaNFe1 =
whnfE1 :: Test
whnfE1 =
TestCase $
assertEqual
"e1 B"
after
(betaNF $ App e1 $ Var 2 "B")
(whnf $ App e1 $ Var 2 "B")
tests :: Test
tests =
@ -48,5 +49,5 @@ tests =
[ TestLabel "fFree" fFree
, TestLabel "incE1" incE1
, TestLabel "substE1" substE1
, TestLabel "betaNFe1" betaNFe1
, TestLabel "whnfE1" whnfE1
]

View file

@ -1,7 +1,5 @@
module ParserTests (tests) where
import Data.Text (Text)
import qualified Data.Text as T
import Expr (Expr (..))
import Parser (pAll)
import Test.HUnit