diff --git a/app/Main.hs b/app/Main.hs index 21e56b1..1f04bd9 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -16,6 +16,6 @@ main = getArgs >>= \case [] -> exitSuccess f:_ -> readFile f >>= \t -> case runAlex' parse f t >>= sanityCheck f t >>= convert f t of Right r -> runChecks r >>= \case - Right (_, env) -> print env - Left err -> print err + Right env -> print env + Left err -> putStrLn (errorStr f t err) Left e -> putStrLn e diff --git a/readme.simple.txt b/readme.simple.txt index 3d519e4..32890a7 100644 --- a/readme.simple.txt +++ b/readme.simple.txt @@ -1,22 +1,28 @@ -- simple test -(data Bool ([Bool] True) - ([Bool] False)) +(data 𝔹 ([𝔹] 𝕋) + ([𝔹] 𝔽)) -[Bool -> Bool] -(def (((not True) False) - ((not False) True))) +[𝔹 → 𝔹] +(def (((not 𝕋) 𝔽) + ((not 𝔽) 𝕋))) -[Integer -> Bool] -(def (((is1 1) True) - ((is1 _) False))) +[ℤ → 𝔹] +(def (((is1 1) 𝕋) + ((is1 _) 𝔽))) -(data IntList ([IntList] IEmpty) - ([Integer -> IntList -> IntList] ICons)) +(data ℤList ([ℤList] IEmpty) + ([ℤ → ℤList → ℤList] ICons)) -(data BoolList ([BoolList] BEmpty) - ([Bool -> BoolList -> BoolList] BCons)) +(data 𝔹List ([𝔹List] BEmpty) + ([𝔹 → 𝔹List → 𝔹List] BCons)) -[IntList -> BoolList] +[ℤList → 𝔹List] (def (((is1l IEmpty) BEmpty) ((is1l (ICons x xs)) (BCons (is1 x) (is1l xs))))) + +[ℤList → 𝔹List] +(def is1l2 is1l) + +[𝔹 → ℤList → 𝔹List] +(def is1l3 (lambda [𝔹] x is1l)) diff --git a/sexprml.cabal b/sexprml.cabal index cd54ef2..320fc18 100644 --- a/sexprml.cabal +++ b/sexprml.cabal @@ -30,6 +30,8 @@ library , Simple.AST , Simple.Convert , Simple.TC + , Simple.TC.Types + , Simple.TC.TypeOps build-tool-depends: alex:alex >= 3.0, happy:happy >= 1.19.5 build-depends: base ^>=4.14.1.0 , array diff --git a/src/Lexer/Lexer.x b/src/Lexer/Lexer.x index 21a942f..e454149 100644 --- a/src/Lexer/Lexer.x +++ b/src/Lexer/Lexer.x @@ -46,10 +46,10 @@ $ident = $printable # $special # $white @string = \" (@rawchar) * \" @Npl = $digit+pl -@typetype = "Type" -- macro cause I might want to change it -@stringtype = 𝕊 | "String" -@chartype = ℂ | "Char" @inttype = ℤ | "Integer" +@typetype = "Type" +@stringtype = "String" +@chartype = "Char" @lambdacase = λcase | "lambdacase" diff --git a/src/Misc.hs b/src/Misc.hs index b588a62..bcc27f3 100644 --- a/src/Misc.hs +++ b/src/Misc.hs @@ -9,3 +9,7 @@ module Misc where infixl 6 <~> (<~>) :: (Traversable t, Applicative f) => (a -> f b) -> t a -> f (t b) (<~>) = traverse + +-- | Swap the two elements of a pair +swap :: (a,b) -> (b,a) +swap (a,b) = (b,a) diff --git a/src/Parser/Types.hs b/src/Parser/Types.hs index 7b7612f..2e51980 100644 --- a/src/Parser/Types.hs +++ b/src/Parser/Types.hs @@ -119,7 +119,7 @@ instance Positioned Arg1 where pos (EArg e) = pos e pos (Wild p) = p --- | Body +-- | Body, explicitly or implicitly typed data Expr = TExp TYSG Expr1 | EExp Expr1 @@ -129,6 +129,7 @@ instance Positioned Expr where pos (TExp _ e) = pos e pos (EExp e) = pos e +-- | Body data Expr1 = Apply PN Expr [Arg] | ECase PN Expr [(Pattern, Expr)] diff --git a/src/Simple/AST.hs b/src/Simple/AST.hs index c5b2901..3eac013 100644 --- a/src/Simple/AST.hs +++ b/src/Simple/AST.hs @@ -8,6 +8,7 @@ module Simple.AST where import Types +-- | A type in the language data Type = FuncType PN Type Type | ChrType PN @@ -16,6 +17,16 @@ data Type | UsrType PN Identifier deriving Show +-- | Prints a pretty type signature for a given type. +pretty :: Type -> String +pretty (FuncType _ t1@(FuncType _ _ _) t2) = "(" <> pretty t1 <> ") → " <> pretty t2 +pretty (FuncType _ t1 t2) = pretty t1 <> " → " <> pretty t2 +pretty (ChrType _) = "Char" +pretty (StrType _) = "String" +pretty (IntType _) = "ℤ" +pretty (UsrType _ i) = unId i + +-- | Only compares type information, ignores position information. instance Eq Type where FuncType _ t11 t12 == FuncType _ t21 t22 = t11 == t21 && t12 == t22 ChrType _ == ChrType _ = True @@ -31,6 +42,7 @@ instance Positioned Type where pos (IntType p) = p pos (UsrType p _) = p +-- | A top level declaration. data TopLevel = Def PN Type Identifier [([Pattern], Expr)] -- ^ Variable and function definition | Dat PN Identifier [(Type, Identifier)] -- ^ Data definition @@ -42,6 +54,7 @@ instance Positioned TopLevel where pos (Dat p _ _) = p pos (Rec p _ _) = p +-- | A pattern to match on. data Pattern = Wild PN | PVar PN Identifier @@ -55,6 +68,7 @@ instance Positioned Pattern where pos (PLit l) = pos l pos (PApp p _ _) = p +-- | A literal value of a builtin type. data Literal = LInt PN Integer | LChr PN Char @@ -66,6 +80,7 @@ instance Positioned Literal where pos (LChr p _) = p pos (LStr p _) = p +-- | A term data Term = TLit Literal | TLambda PN [(Type, Identifier)] Expr @@ -79,6 +94,7 @@ instance Positioned Term where pos (TLit l) = pos l pos (TVar p _) = p +-- | An expression which is either explicitly or implicitly typed data Expr = ExpExpr Type Expr1 | ImpExpr Expr1 @@ -88,6 +104,7 @@ instance Positioned Expr where pos (ExpExpr _ e) = pos e pos (ImpExpr e) = pos e +-- | An expression data Expr1 = Apply PN Expr [Expr] | Case PN Expr [(Pattern, Expr)] diff --git a/src/Simple/Convert.hs b/src/Simple/Convert.hs index af70846..e66ee78 100644 --- a/src/Simple/Convert.hs +++ b/src/Simple/Convert.hs @@ -15,13 +15,16 @@ import Misc import qualified Parser.Types as P import qualified Simple.AST as S +-- | Convert from the original parse tree to the simpler subset convert :: FilePath -> String -> [P.TL] -> Either String [S.TopLevel] convert fp tx = traverse convertTL where - -- | create an error message at a provided position + -- | Create an error message at a provided position lerror :: PN -> String -> Either String a lerror p = Left . errorMessage p fp tx + -- All the following are functions to convert different types from the P module to the S module + convertTL :: P.TL -> Either String S.TopLevel convertTL (P.ExDef p t i pes) = S.Def p <$> convertTYSG t <*> pure i <*> convertPE2 <~> pes convertTL (P.DtDef p i (_:_) _) = lerror p ("Type arguments in definition of type '" <> unId i <> "'") diff --git a/src/Simple/TC.hs b/src/Simple/TC.hs index ac1ab44..cb02750 100644 --- a/src/Simple/TC.hs +++ b/src/Simple/TC.hs @@ -6,139 +6,67 @@ A type checker for a simple subset of the language. TODO: major refactor, split code accros modules -} -{-# LANGUAGE GeneralizedNewtypeDeriving, LambdaCase, TupleSections, TemplateHaskell #-} -module Simple.TC (runChecks) where - -import Simple.AST -import Types +{-# LANGUAGE LambdaCase, TupleSections #-} +module Simple.TC (runChecks, errorStr) where import Control.Monad.State hiding (guard) import Control.Monad.Except hiding (guard) -import Data.Map (Map) -import Data.Set (Set) import qualified Data.Map as M import qualified Data.Set as S -import Lens.Micro -import Lens.Micro.TH - -data TypeError - = Urk - | TypeMismatch PN Type Type -- ^ expected, got - | ArityMismatch PN - | UnboundVar PN Identifier - | NoCase PN - | UnknownPattern PN - | UndefinedType PN Identifier - | InvalidRecordField PN Identifier - | IncompleteInstance PN Identifier - deriving (Show) - -type Env a = Map Identifier a - -data CheckEnv = CheckEnv { _defs :: Env Type - , _types :: Set Identifier - , _rec :: Env (Env Type) - } deriving Show - -makeLenses ''CheckEnv - -initialState :: CheckEnv -initialState = CheckEnv M.empty S.empty M.empty - -newtype Check a = Check { runCheck :: StateT CheckEnv (ExceptT TypeError IO) a } - deriving (Functor, Applicative, Monad, MonadError TypeError, MonadState CheckEnv) - -swap :: (a,b) -> (b,a) -swap (a,b) = (b,a) - -local :: (CheckEnv -> CheckEnv) -> Check a -> Check a -local f ca = get >>= \s -> put (f s) >> ca >>= \r -> put s >> pure r - -modState :: (CheckEnv -> CheckEnv) -> Check () -modState f = get >>= put . f - -getEnv :: Check (Env Type) -getEnv = (^. defs) <$> get - -localEnv :: (Env Type -> Env Type) -> Check a -> Check a -localEnv f = local (\env -> env & defs %~ f) - -modifyEnv :: (Env Type -> Env Type) -> Check () -modifyEnv f = modState (\env -> env & defs %~ f) - -modifyTypeEnv :: (Set Identifier -> Set Identifier) -> Check () -modifyTypeEnv f = modState (\env -> env & types %~ f) - -modifyRecEnv :: (Env (Env Type) -> Env (Env Type)) -> Check () -modifyRecEnv f = modState (\env -> env & rec %~ f) - -getTypeEnv :: Check (Set Identifier) -getTypeEnv = (^. types) <$> get - -getRecEnv :: Check (Env (Env Type)) -getRecEnv = (^. rec) <$> get - -localTypeEnv :: (Set Identifier -> Set Identifier) -> Check a -> Check a -localTypeEnv f = local (\env -> env & types %~ f) +import Simple.TC.Types +import Simple.TC.TypeOps +import Simple.AST +import Types +import Misc +-- | Error out on false. guard :: TypeError -> Bool -> Check () guard _ True = pure () guard e False = throwError e +-- | Error out on false and return given type on true. guardt :: Type -> TypeError -> Bool -> Check Type guardt t _ True = pure t guardt _ e False = throwError e -withBindings :: [(Type, Identifier)] -> Check a -> Check a -withBindings = localEnv . (flip insertMapBindings) - -insertBindings :: [(Type, Identifier)] -> Check () -insertBindings = modifyEnv . (flip insertMapBindings) - -insertTypes :: [Identifier] -> Check () -insertTypes = modifyTypeEnv . (flip insertSetTypes) - -insertMapBindings :: Env a -> [(a, Identifier)] -> Env a -insertMapBindings = foldr (uncurry (flip M.insert)) - -insertSetTypes :: Ord a => Set a -> [a] -> Set a -insertSetTypes = foldr S.insert - -insertRecCons :: Identifier -> [(Type, Identifier)] -> Check () -insertRecCons i = modifyRecEnv . M.insert i . M.fromList . map swap - +-- | Convert a list of identifiers and expressions to the type +-- associated with each identifier. iesTotis :: [(Identifier, Expr)] -> Check [(Type, Identifier)] -iesTotis ies = getEnv >>= \env -> traverse (\(i,e) -> (,i) <$> inferExpr e) ies +iesTotis ies = traverse (\(i,e) -> (,i) <$> inferExpr e) ies --- ^ Create function to third arg +-- | Create function to third arg tsToFunc :: PN -> [Type] -> Type -> Type tsToFunc _ [] b = b tsToFunc p (t:ts) b = FuncType p t (tsToFunc p ts b) --- ^ Constructors of.. maybe I don't need this lol +-- | Returns the constructors of a type, currently not used. might be used for +-- totality checking of patterns in the future. consOf :: Type -> Check [(Type, Identifier)] consOf t = pure . filter (constructs t . fst) . map swap . M.toList =<< getEnv where - -- ^ checks if second arg constructs first arg + -- | checks if second arg constructs first arg constructs :: Type -> Type -> Bool - constructs t1 t2 | t1 == t2 = True - constructs t1 (FuncType _ t2 t3) = constructs t1 t3 - constructs _ _ = False + constructs t1 t2 | t1 == t2 = True + constructs t1 (FuncType _ _ t2) = constructs t1 t2 + constructs _ _ = False +-- | Check that expression has certain type. Error on fail, return type on success. checkExpr :: Expr -> Type -> Check Type checkExpr expr t = checkType t >> inferExpr expr >>= \t2 -> guardt t (TypeMismatch (pos expr) t t2) (t == t2) +-- | Infer the type of an expression inferExpr :: Expr -> Check Type -inferExpr = \case - ExpExpr t e -> checkExpr1 e t >> pure t - ImpExpr e -> inferExpr1 e +inferExpr (ExpExpr t e) = checkExpr1 e t >> pure t +inferExpr (ImpExpr e) = inferExpr1 e +-- | Check that expression1 has certain type. Error on fail, return type on success. checkExpr1 :: Expr1 -> Type -> Check Type checkExpr1 expr t = checkType t >> inferExpr1 expr >>= \t2 -> guardt t (TypeMismatch (pos expr) t t2) (t == t2) +-- | Infer the type of a function application inferFunc :: PN -> Expr -> [Expr] -> Check Type inferFunc p _ [] = throwError (ArityMismatch p) inferFunc p e es = inferExpr e >>= \t -> go p t es @@ -150,26 +78,24 @@ inferFunc p e es = inferExpr e >>= \t -> go p t es FuncType _ t1 t2 -> checkExpr e t1 >> go p t2 es _ -> throwError (ArityMismatch p) +-- | Infer the type of an expression1 inferExpr1 :: Expr1 -> Check Type inferExpr1 = \case - Apply p e es -> inferFunc p e es - Case p e pes -> inferExpr e >>= \t -> inferCase p t pes - Let _ ies e -> iesTotis ies >>= \tis -> withBindings tis (inferExpr e) - Inst p i ies -> getRecEnv >>= \renv -> - case M.lookup i renv of - Just rcons -> guard (IncompleteInstance p i) (M.size rcons == length ies) - >> mapM_ (\(i',e) -> case M.lookup i' rcons of - Just t -> checkExpr e t - Nothing -> throwError (InvalidRecordField p i')) ies - Nothing -> throwError (UndefinedType p i) - >> pure (UsrType p i) - - Term t -> inferTerm t - -checkTerm :: Term -> Type -> Check Type -checkTerm term t = checkType t >> inferTerm term >>= - \t2 -> guardt t (TypeMismatch (pos term) t t2) (t == t2) - + Apply p e es -> inferFunc p e es + Case p e pes -> inferExpr e >>= \t -> inferCase p t pes + Let _ ies e -> iesTotis ies >>= \tis -> withBindings tis (inferExpr e) + Inst p i ies -> getRecEnv >>= \renv -> + case M.lookup i renv of + Just rcons -> guard (IncompleteInstance p) (M.size rcons == length ies) + >> mapM_ (\(i',e) -> case M.lookup i' rcons of + Just t -> checkExpr e t + Nothing -> throwError (InvalidRecordField p i')) ies + Nothing -> throwError (UndefinedType p i) + >> pure (UsrType p i) + + Term t -> inferTerm t + +-- | Infer the type of a term inferTerm :: Term -> Check Type inferTerm = \case TLit l -> inferLiteral l @@ -179,23 +105,21 @@ inferTerm = \case Just t -> pure t Nothing -> throwError (UnboundVar p i) -checkLiteral :: Literal -> Type -> Check Type -checkLiteral lit t = checkType t >> inferLiteral lit >>= - \t2 -> guardt t (TypeMismatch (pos lit) t t2) (t == t2) - +-- | "Infer" the type of a literal inferLiteral :: Literal -> Check Type inferLiteral = \case LInt p _ -> pure (IntType p) LChr p _ -> pure (ChrType p) LStr p _ -> pure (StrType p) --- no totality checking :/ +-- | Infer the type of a case statement. Unfortunately does not perform any totality checks. inferCase :: PN -> Type -> [(Pattern, Expr)] -> Check Type inferCase p _ [] = throwError (NoCase p) -inferCase p t ((pt,e):[]) = checkPattern pt t >> inferExpr e +inferCase _ t ((pt,e):[]) = checkPattern pt t >> inferExpr e inferCase p t ((pt,e):pes) = checkPattern pt t >> inferCase p t pes >>= \t -> patternBindings t pt >>= \bs -> withBindings bs (checkExpr e t) +-- | Return the bound type variables in a pattern. patternBindings :: Type -> Pattern -> Check [(Type, Identifier)] patternBindings target = \case Wild _ -> pure [] @@ -220,14 +144,17 @@ patternBindings target = \case _ -> go p t2 ps _ -> throwError (ArityMismatch p) +-- | Check that a pattern matches on a specified type. checkPattern :: Pattern -> Type -> Check Type checkPattern pt t = checkType t >> inferPattern t pt >>= \t2 -> guardt t (TypeMismatch (pos pt) t t2) (t == t2) +-- | Infer which type a pattern matches on. Note that the type we are matching on must be +-- provided in order to have meaningful type variables and wild cards. inferPattern :: Type -> Pattern -> Check Type inferPattern target = \case - Wild p -> pure target - PVar p i -> getEnv >>= \env -> case M.lookup i env of + Wild _ -> pure target + PVar _ i -> getEnv >>= \env -> case M.lookup i env of Just t -> pure t Nothing -> pure target -- variable PLit l -> inferLiteral l @@ -244,12 +171,13 @@ inferPattern target = \case _ -> throwError (ArityMismatch p) --- This goes on the users given declerations, if they've been wrong it'll error out later +-- | Insert bind a top level binding. This does not check for type correctness of expressions, but rather +-- only serves to prepare for checking our top levels. insertTLBinding :: TopLevel -> Check () -insertTLBinding (Def _ t i _ ) = insertBindings [(t,i)] -insertTLBinding (Dat _ i tis) = insertTypes [i] >> insertBindings tis -insertTLBinding (Rec p i tis) = insertTypes [i] >> insertBindings (constructor : deconstructors) - >> insertRecCons i tis +insertTLBinding (Def p t i _ ) = insertUniqueBindings p [(t,i)] +insertTLBinding (Dat p i tis) = insertUniqueTypes p [i] >> insertUniqueBindings p tis +insertTLBinding (Rec p i tis) = insertUniqueTypes p [i] >> insertUniqueBindings p (constructor : deconstructors) + >> insertRecCons p i tis where constructor :: (Type, Identifier) constructor = (tsToFunc p (fst <$> tis) curType, i) @@ -266,7 +194,7 @@ insertTLBinding (Rec p i tis) = insertTypes [i] >> insertBindings (constructor go1 :: Type -> Type -> Type go1 t1 t2 = FuncType (pos t2) t1 t2 --- Assumes you've already bound all top levels +-- | Checks a top level. Assumes you've already bound all top levels. checkTopLevel :: TopLevel -> Check () checkTopLevel (Def _ _ _ []) = throwError Urk -- this should not be able to be (use non-empty lists you fool) checkTopLevel (Def p _ i pes) = getEnv >>= \env -> @@ -274,30 +202,31 @@ checkTopLevel (Def p _ i pes) = getEnv >>= \env -> Just t -> mapM_ (checkCases p t) pes Nothing -> throwError Urk -- should be in env at this point -checkTopLevel (Dat p i tis) = checkTypes tis -checkTopLevel (Rec p i tis) = checkTypes tis +-- TODO verify types +checkTopLevel (Dat _ _ tis) = checkTypes tis +checkTopLevel (Rec _ _ tis) = checkTypes tis --- if we have no patterns, simply check type, otherwise match pattern with function type and recurse +-- | Reduces a given type by matching patterns on it, removing arrows. checkCases :: PN -> Type -> ([Pattern], Expr) -> Check Type checkCases _ t ([],e) = checkExpr e t -checkCases p (FuncType _ t1 t2) (pe:ps,e) = checkPattern pe t1 >> - patternBindings t1 pe >>= \b -> withBindings b (checkCases p t2 (ps,e)) +checkCases p (FuncType _ t1 t2) (pe:ps,e) = checkPattern pe t1 >> patternBindings t1 pe >>= + \b -> withBindings b (checkCases p t2 (ps,e)) checkCases p _ _ = throwError (ArityMismatch p) --- checkPattern :: Pattern -> Type -> Check Type --- checkPattern pt t = checkType t >> inferPattern t pt >>= - --- check that shit is in scope +-- | Check that any user types are defined checkType :: Type -> Check Type checkType (FuncType p t1 t2) = FuncType p <$> checkType t1 <*> checkType t2 checkType t@(UsrType p i) = getTypeEnv >>= guardt t (UndefinedType p i) . S.member i checkType t = pure t +-- | Check a bunch of types checkTypes :: [(Type, a)] -> Check () checkTypes = mapM_ (checkType . fst) +-- | Check an AST checkAST :: [TopLevel] -> Check () checkAST tls = mapM_ insertTLBinding tls >> mapM_ checkTopLevel tls -runChecks :: [TopLevel] -> IO (Either TypeError ((), CheckEnv)) -runChecks tls = runExceptT (runStateT (runCheck (checkAST tls)) initialState) +-- | Check an AST in the IO monad +runChecks :: [TopLevel] -> IO (Either TypeError CheckEnv) +runChecks tls = (fmap . fmap) snd (runExceptT (runStateT (runCheck (checkAST tls)) initialState)) diff --git a/src/Simple/TC/TypeOps.hs b/src/Simple/TC/TypeOps.hs new file mode 100644 index 0000000..8a9e8a5 --- /dev/null +++ b/src/Simple/TC/TypeOps.hs @@ -0,0 +1,94 @@ +{-| +Module: Simple.TC.TypeOps +Description: Defines operations used on typechecker types + +This module defines a lot of useful operation over our typechecker monad +-} +module Simple.TC.TypeOps where + +import Control.Monad.State hiding (guard) +import Control.Monad.Except hiding (guard) +import Data.Set (Set) +import qualified Data.Map as M +import qualified Data.Set as S + +import Lens.Micro + +import Simple.TC.Types +import Simple.AST +import Types +import Misc + +-- | Change state only for some computation +local :: (CheckEnv -> CheckEnv) -> Check a -> Check a +local f ca = get >>= \s -> put (f s) >> ca >>= \r -> put s >> pure r + +-- | Change the state somehow +modState :: (CheckEnv -> CheckEnv) -> Check () +modState f = get >>= put . f + +-- | Get the value level enviornment +getEnv :: Check (Env Type) +getEnv = (^. defs) <$> get + +-- | Change value environment only for some computation. +localEnv :: (Env Type -> Env Type) -> Check a -> Check a +localEnv f = local (\env -> env & defs %~ f) + +-- | Modify the value environment +modifyEnv :: (Env Type -> Env Type) -> Check () +modifyEnv f = modState (\env -> env & defs %~ f) + +-- | Modify the type environment +modifyTypeEnv :: (Set Identifier -> Set Identifier) -> Check () +modifyTypeEnv f = modState (\env -> env & types %~ f) + +-- | Modify the record environment +modifyRecEnv :: (Env (Env Type) -> Env (Env Type)) -> Check () +modifyRecEnv f = modState (\env -> env & rec %~ f) + +-- | Return the type environment +getTypeEnv :: Check (Set Identifier) +getTypeEnv = (^. types) <$> get + +-- | Return the record environment +getRecEnv :: Check (Env (Env Type)) +getRecEnv = (^. rec) <$> get + +-- | Change type environment only for some computation. +localTypeEnv :: (Set Identifier -> Set Identifier) -> Check a -> Check a +localTypeEnv f = local (\env -> env & types %~ f) + +-- | Used to bind local variables. +withBindings :: [(Type, Identifier)] -> Check a -> Check a +withBindings = localEnv . (flip insertMapBindings) + +-- | Used to permanently bind variables. +insertBindings :: [(Type, Identifier)] -> Check () +insertBindings = modifyEnv . (flip insertMapBindings) + +-- | Bind some variables and make sure the names were not taken +insertUniqueBindings :: PN -> [(Type, Identifier)] -> Check () +insertUniqueBindings p ais = getEnv >>= \env -> foldM go env ais >>= modifyEnv . const + where + go :: Env Type -> (Type, Identifier) -> Check (Env Type) + go e (a, i) | M.member i e = throwError (AlreadyBound p i) + | otherwise = pure (M.insert i a e) + +-- | Bind some type names and make sure the names were not taken +insertUniqueTypes :: PN -> [Identifier] -> Check () +insertUniqueTypes p is = getTypeEnv >>= \env -> foldM go env is >>= modifyTypeEnv . const + where + go :: Set Identifier -> Identifier -> Check (Set Identifier) + go s i | S.member i s = throwError (TypeAlreadyBound p i) + | otherwise = pure (S.insert i s) + +-- | Insert some stuff into an environment +insertMapBindings :: Env a -> [(a, Identifier)] -> Env a +insertMapBindings = foldr (uncurry (flip M.insert)) + +-- | Insert record constructors into the environment +insertRecCons :: PN -> Identifier -> [(Type, Identifier)] -> Check () +insertRecCons p i tis = getRecEnv >>= \renv -> if M.member i renv + then throwError (TypeAlreadyBound p i) + else (modifyRecEnv . M.insert i . M.fromList . map swap) tis diff --git a/src/Simple/TC/Types.hs b/src/Simple/TC/Types.hs new file mode 100644 index 0000000..6e98788 --- /dev/null +++ b/src/Simple/TC/Types.hs @@ -0,0 +1,71 @@ +{-| +Module: Simple.TC.Types +Description: Types for the simple typechecker. + +Types for the simple typechecker. +-} +{-# LANGUAGE GeneralizedNewtypeDeriving, TemplateHaskell, LambdaCase #-} +module Simple.TC.Types where + +import Simple.AST +import Types +import Error + +import Control.Monad.State +import Control.Monad.Except + +import Data.Map (Map) +import Data.Set (Set) +import qualified Data.Map as M +import qualified Data.Set as S + +import Lens.Micro.TH + +-- | The different kinds of errors which can occur during checking +data TypeError + = Urk + | TypeMismatch PN Type Type -- ^ expected, got + | ArityMismatch PN + | UnboundVar PN Identifier + | AlreadyBound PN Identifier + | TypeAlreadyBound PN Identifier + | NoCase PN + | UnknownPattern PN + | UndefinedType PN Identifier + | InvalidRecordField PN Identifier + | IncompleteInstance PN + deriving (Show) + +-- | Generates a fancy error string from a TypeError +errorStr :: FilePath -> String -> TypeError -> String +errorStr fp tx = \case + Urk -> "urk" + TypeMismatch p e g -> errorMessage p fp tx ("Type mismatch, expected '" <> pretty e <> "' got '" <> pretty g <> "'") + ArityMismatch p -> errorMessage p fp tx "Arity mismatch" + UnboundVar p i -> errorMessage p fp tx ("Unbound variable '" <> unId i <> "'") + AlreadyBound p i -> errorMessage p fp tx ("Identifier '" <> unId i <> "' already bound") + TypeAlreadyBound p i -> errorMessage p fp tx ("A type of name'" <> unId i <> "' already exists") + NoCase p -> errorMessage p fp tx "No case" + UnknownPattern p -> errorMessage p fp tx "Unknown pattern" + UndefinedType p i -> errorMessage p fp tx ("Undefined type '" <> unId i <> "'") + InvalidRecordField p i -> errorMessage p fp tx ("Invalid record field '" <> unId i <> "'") + IncompleteInstance p -> errorMessage p fp tx "Incomplete instance" + +-- | An enviornment simply maps an identifier to something +type Env a = Map Identifier a + +-- | The complete enviornment of the checker +data CheckEnv = CheckEnv { _defs :: Env Type + , _types :: Set Identifier + , _rec :: Env (Env Type) + } deriving Show + +makeLenses ''CheckEnv + +-- | The initial enviornment of the checker +initialState :: CheckEnv +initialState = CheckEnv M.empty S.empty M.empty + +-- | The monad which checking is performed in +newtype Check a = Check { runCheck :: StateT CheckEnv (ExceptT TypeError IO) a } + deriving (Functor, Applicative, Monad, MonadError TypeError, MonadState CheckEnv)