You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
232 lines
9.4 KiB
232 lines
9.4 KiB
{-| |
|
Module: Simple.TC |
|
Description: Simple type checker. |
|
|
|
A type checker for a simple subset of the language. |
|
|
|
TODO: major refactor, split code across modules |
|
-} |
|
{-# LANGUAGE LambdaCase, TupleSections #-} |
|
module Simple.TC (runChecks, errorStr) where |
|
|
|
import Control.Monad.State hiding (guard) |
|
import Control.Monad.Except hiding (guard) |
|
import qualified Data.Map as M |
|
import qualified Data.Set as S |
|
|
|
import Simple.TC.Types |
|
import Simple.TC.TypeOps |
|
import Simple.AST |
|
import Types |
|
import Misc |
|
|
|
-- | Succeed with @()@ when the condition holds; otherwise throw the given
-- 'TypeError'.
guard :: TypeError -> Bool -> Check ()
guard err ok
  | ok = pure ()
  | otherwise = throwError err
|
|
|
-- | Like 'guard', but yield the supplied type when the condition holds;
-- otherwise throw the given 'TypeError'.
guardt :: Type -> TypeError -> Bool -> Check Type
guardt ty err ok = if ok then pure ty else throwError err
|
|
|
-- | Infer the type of each bound expression (left to right, in the current
-- environment) and pair it with its identifier, as @(type, identifier)@.
iesTotis :: [(Identifier, Expr)] -> Check [(Type, Identifier)]
iesTotis = mapM inferBinding
  where
    inferBinding (ident, expr) = (\ty -> (ty, ident)) <$> inferExpr expr
|
|
|
-- | Build a curried function type from argument types and a result type:
-- @tsToFunc p [a, b] r@ is @a -> b -> r@. Every arrow is tagged with position
-- @p@. With no argument types the result type is returned unchanged.
tsToFunc :: PN -> [Type] -> Type -> Type
tsToFunc p args result = foldr (FuncType p) result args
|
|
|
-- | List every binding in the environment that constructs @target@. A binding
-- constructs @target@ when its type is @target@ itself, or is a function type
-- whose final result, after stripping any number of argument arrows, is
-- @target@. Results are @(type, identifier)@ pairs in the environment's key
-- order.
--
-- Not referenced elsewhere in this module; it may be useful for totality
-- checking of patterns in the future.
consOf :: Type -> Check [(Type, Identifier)]
consOf target = do
  env <- getEnv
  pure [ (ty, ident) | (ident, ty) <- M.toList env, constructs target ty ]
  where
    -- Does the second type construct the first?
    constructs :: Type -> Type -> Bool
    constructs want got
      | want == got = True
      | FuncType _ _ res <- got = constructs want res
      | otherwise = False
|
|
|
-- | Check that an expression has the expected type.
--
-- The expected type is first validated with 'checkType'. The function then
-- throws 'TypeMismatch' (at the expression's position) if the inferred type
-- differs; otherwise it returns the expected type.
checkExpr :: Expr -> Type -> Check Type
checkExpr expr expected = do
  _ <- checkType expected
  actual <- inferExpr expr
  guardt expected (TypeMismatch (pos expr) expected actual) (expected == actual)
|
|
|
-- | Infer the type of an expression.
--
-- For 'ExpExpr', the carried type is checked against the inner 'Expr1' and
-- then returned. For 'ImpExpr', the inner 'Expr1' is inferred directly.
inferExpr :: Expr -> Check Type
inferExpr = \case
  ExpExpr annotated inner -> annotated <$ checkExpr1 inner annotated
  ImpExpr inner -> inferExpr1 inner
|
|
|
-- | Check that an 'Expr1' has the expected type.
--
-- The expected type is first validated with 'checkType'. The function then
-- throws 'TypeMismatch' (at the expression's position) if the inferred type
-- differs; otherwise it returns the expected type.
checkExpr1 :: Expr1 -> Type -> Check Type
checkExpr1 expr expected =
  checkType expected *> (inferExpr1 expr >>= matches)
  where
    matches actual =
      guardt expected (TypeMismatch (pos expr) expected actual) (expected == actual)
|
|
|
-- | Infer the type of applying a function expression to argument expressions.
--
-- Each argument is checked against the next argument type of the function's
-- type, one arrow per argument. The result is the type left after all
-- arguments are consumed. Throws 'ArityMismatch' when there are no arguments,
-- or when an argument remains but the type is no longer a function type.
inferFunc :: PN -> Expr -> [Expr] -> Check Type
inferFunc p _ [] = throwError (ArityMismatch p)
inferFunc p fun args = inferExpr fun >>= feed args
  where
    feed [] ty = pure ty
    feed (arg : rest) (FuncType _ argTy resTy) = checkExpr arg argTy >> feed rest resTy
    feed _ _ = throwError (ArityMismatch p)
|
|
|
-- | Infer the type of an 'Expr1'.
--
-- * 'Apply' is handled by 'inferFunc'.
-- * 'Case' infers the scrutinee, then checks the arms with 'inferCase'.
-- * 'Let' infers every bound expression in the current scope (the bindings
--   do not see each other), then infers the body with them in scope.
-- * 'Inst' is a record instance; see @inferInst@ below.
-- * 'Term' is handled by 'inferTerm'.
inferExpr1 :: Expr1 -> Check Type
inferExpr1 = \case
  Apply p e es -> inferFunc p e es
  Case p e pes -> inferExpr e >>= \t -> inferCase p t pes
  Let _ ies e -> iesTotis ies >>= \tis -> withBindings tis (inferExpr e)
  Inst p i ies -> inferInst p i ies
  Term t -> inferTerm t
  where
    -- Check a record instance of record type @i@ with the given field
    -- assignments, returning @UsrType p i@ on success.
    --   * Throws 'UndefinedType' if @i@ is not a known record.
    --   * Throws 'IncompleteInstance' if the number of assignments differs
    --     from the number of record fields.
    --   * Throws 'InvalidRecordField' for an unknown field name.
    --   * Each assigned expression is checked against its field's type.
    inferInst p i ies = getRecEnv >>= \renv -> case M.lookup i renv of
      Nothing -> throwError (UndefinedType p i)
      Just rcons -> do
        guard (IncompleteInstance p) (M.size rcons == length ies)
        mapM_ (checkField p rcons) ies
        -- The count matches and every name is valid, so a repeated field means
        -- some other field was never assigned; the count check alone let
        -- e.g. `{a = 1, a = 2}` through for a record with fields a and b.
        guard (IncompleteInstance p) (distinct (map fst ies))
        pure (UsrType p i)

    -- Check one field assignment against the record's field types.
    checkField p rcons (field, e) = case M.lookup field rcons of
      Just t -> checkExpr e t
      Nothing -> throwError (InvalidRecordField p field)

    -- True when no element appears twice.
    distinct xs = S.size (S.fromList xs) == length xs
|
|
|
-- | Infer the type of a term.
--
-- * Literals are handled by 'inferLiteral'.
-- * A lambda's parameter types are validated, its body is inferred with the
--   parameters in scope, and the result is the curried function type.
-- * A lambda-case has type @argument -> result@, where the result type comes
--   from 'inferCase'.
-- * A variable is looked up in the environment; 'UnboundVar' is thrown if it
--   is missing.
inferTerm :: Term -> Check Type
inferTerm (TLit lit) = inferLiteral lit
inferTerm (TLambda p params body) = do
  checkTypes params
  bodyTy <- withBindings params (inferExpr body)
  pure (tsToFunc p (map fst params) bodyTy)
inferTerm (TLambdaCase p argTy arms) = do
  argTy' <- checkType argTy
  resTy <- inferCase p argTy arms
  pure (FuncType p argTy' resTy)
inferTerm (TVar p ident) =
  getEnv >>= maybe (throwError (UnboundVar p ident)) pure . M.lookup ident
|
|
|
-- | The type of a literal is fixed by its kind (int, char or string) and
-- carries the literal's position.
inferLiteral :: Literal -> Check Type
inferLiteral lit = pure $ case lit of
  LInt p _ -> IntType p
  LChr p _ -> ChrType p
  LStr p _ -> StrType p
|
|
|
-- | Infer the result type of case arms matched against a scrutinee of type
-- @scrut@.
--
-- * Every pattern must match @scrut@ (via 'checkPattern').
-- * The variables a pattern binds, typed against @scrut@, are in scope in
--   that arm's body.
-- * All arm bodies must have the type of the last arm, which is returned.
-- * Throws 'NoCase' for an empty arm list.
-- * No totality (exhaustiveness) checking is performed.
--
-- This mirrors how 'checkCases' binds pattern variables. Previously the last
-- arm's pattern variables were never bound. The other arms' variables were
-- typed with the result type instead of the scrutinee type, because the
-- scrutinee type was shadowed.
inferCase :: PN -> Type -> [(Pattern, Expr)] -> Check Type
inferCase p _ [] = throwError (NoCase p)
inferCase _ scrut [(pt, e)] = do
  _ <- checkPattern pt scrut
  bs <- patternBindings scrut pt
  withBindings bs (inferExpr e)
inferCase p scrut ((pt, e) : pes) = do
  _ <- checkPattern pt scrut
  res <- inferCase p scrut pes
  bs <- patternBindings scrut pt
  withBindings bs (checkExpr e res)
|
|
|
-- | Collect the variables a pattern binds, each paired with its type, when the
-- pattern is matched against a value of type @target@.
--
-- * An identifier already bound in the environment is treated as a
--   constructor and binds nothing; any other identifier binds @target@.
-- * Wildcards and literals bind nothing.
-- * For a constructor application, each argument pattern is handled
--   recursively against the matching argument type of the constructor.
--
-- Previously a nested constructor pattern dropped the remaining sibling
-- patterns (e.g. @xs@ in @Cons (Just x) xs@). A nested identifier naming a
-- constructor was also bound as a variable.
--
-- Throws 'UnknownPattern' for an unknown constructor or an application with
-- no arguments. Throws 'ArityMismatch' when a constructor is given more
-- patterns than its type has arguments.
patternBindings :: Type -> Pattern -> Check [(Type, Identifier)]
patternBindings target = \case
  Wild _ -> pure []
  PLit _ -> pure []
  PVar _ i -> getEnv >>= \env -> case M.lookup i env of
    Just _ -> pure []
    Nothing -> pure [(target, i)]
  PApp p _ [] -> throwError (UnknownPattern p)
  PApp p i ps -> getEnv >>= \env -> case M.lookup i env of
    Just t -> go p t ps
    Nothing -> throwError (UnknownPattern p)
  where
    -- Peel one arrow off the constructor type per argument pattern and
    -- collect that argument's bindings against the peeled argument type.
    go :: PN -> Type -> [Pattern] -> Check [(Type, Identifier)]
    go _ _ [] = pure []
    go p (FuncType _ argTy resTy) (pt : rest) =
      (++) <$> patternBindings argTy pt <*> go p resTy rest
    go p _ _ = throwError (ArityMismatch p)
|
|
|
-- | Check that a pattern matches values of the expected type.
--
-- The expected type is first validated with 'checkType'. The function then
-- throws 'TypeMismatch' (at the pattern's position) if the pattern's inferred
-- type differs; otherwise it returns the expected type.
checkPattern :: Pattern -> Type -> Check Type
checkPattern pt expected =
  checkType expected *> (inferPattern expected pt >>= matches)
  where
    matches actual =
      guardt expected (TypeMismatch (pos pt) expected actual) (expected == actual)
|
|
|
-- | Infer which type a pattern matches on.
--
-- The type being matched (@target@) must be supplied so that wildcards and
-- pattern variables get a meaningful type:
--
-- * A wildcard matches @target@.
-- * An identifier found in the environment has its environment type
--   (i.e. it is treated as a constructor); any other identifier is a
--   variable and matches @target@.
-- * A literal has its literal type.
-- * A constructor application checks each argument pattern against the
--   constructor's argument types and yields the remaining type.
--
-- Throws 'UnknownPattern' for an unknown constructor or an application with
-- no arguments. Throws 'ArityMismatch' for too many argument patterns.
inferPattern :: Type -> Pattern -> Check Type
inferPattern target (Wild _) = pure target
inferPattern target (PVar _ i) = maybe target id . M.lookup i <$> getEnv
inferPattern _ (PLit l) = inferLiteral l
inferPattern _ (PApp p _ []) = throwError (UnknownPattern p)
inferPattern _ (PApp p ctor args) = do
  env <- getEnv
  case M.lookup ctor env of
    Nothing -> throwError (UnknownPattern p)
    Just ctorTy -> consume ctorTy args
  where
    consume ty [] = pure ty
    consume (FuncType _ argTy resTy) (a : rest) = checkPattern a argTy >> consume resTy rest
    consume _ _ = throwError (ArityMismatch p)
|
|
|
|
|
-- | Register the names introduced by a top-level declaration. This does not
-- check that expressions are well typed; it only prepares the environment so
-- the top levels can be checked afterwards.
--
-- * 'Def' binds its name at its declared type.
-- * 'Dat' registers the type name and binds each listed constructor at its
--   listed type.
-- * 'Rec' registers the type name and binds a constructor named after the
--   record. The constructor takes every field type in order and returns the
--   record type. 'Rec' also binds one accessor per field, of type
--   @record -> field type@, and records the fields with 'insertRecCons'.
insertTLBinding :: TopLevel -> Check ()
insertTLBinding (Def p t i _) = insertUniqueBindings p [(t, i)]
insertTLBinding (Dat p i tis) = do
  insertUniqueTypes p [i]
  insertUniqueBindings p tis
insertTLBinding (Rec p i fields) = do
  insertUniqueTypes p [i]
  insertUniqueBindings p (ctor : accessors)
  insertRecCons p i fields
  where
    recTy = UsrType p i
    ctor = (tsToFunc p (map fst fields) recTy, i)
    -- Each accessor arrow is positioned at its field type.
    accessors = [ (FuncType (pos fty) recTy fty, name) | (fty, name) <- fields ]
|
|
|
-- | Check a single top-level declaration. Assumes every top level has already
-- been registered with 'insertTLBinding'.
--
-- * A 'Def' has each of its clauses checked with 'checkCases' against the
--   type bound for its name.
-- * 'Dat' and 'Rec' have their listed types validated with 'checkTypes'.
--
-- 'Urk' is thrown for states that should be impossible: a 'Def' with no
-- clauses (a non-empty list type would rule this out), or a 'Def' whose name
-- is missing from the environment.
checkTopLevel :: TopLevel -> Check ()
checkTopLevel tl = case tl of
  Def _ _ _ [] -> throwError Urk
  Def p _ i pes ->
    getEnv >>= maybe (throwError Urk) (\t -> mapM_ (checkCases p t) pes) . M.lookup i
  -- TODO verify types
  Dat _ _ tis -> checkTypes tis
  Rec _ _ tis -> checkTypes tis
|
|
|
-- | Check one clause (argument patterns plus body) of a definition against
-- its type.
--
-- Each pattern consumes one arrow: it must match that argument type, and its
-- bindings are in scope for the rest of the clause. Once the patterns run
-- out, the body must have the remaining type. Throws 'ArityMismatch' when
-- there are more patterns than arrows.
checkCases :: PN -> Type -> ([Pattern], Expr) -> Check Type
checkCases p ty (pats, body) = case (pats, ty) of
  ([], _) -> checkExpr body ty
  (pe : rest, FuncType _ argTy resTy) -> do
    _ <- checkPattern pe argTy
    bs <- patternBindings argTy pe
    withBindings bs (checkCases p resTy (rest, body))
  _ -> throwError (ArityMismatch p)
|
|
|
-- | Check that every user-defined type mentioned in a type is registered in
-- the type environment, walking both sides of function arrows. Returns the
-- type on success; throws 'UndefinedType' otherwise.
checkType :: Type -> Check Type
checkType ty = case ty of
  FuncType p dom cod -> do
    dom' <- checkType dom
    cod' <- checkType cod
    pure (FuncType p dom' cod')
  UsrType p i ->
    getTypeEnv >>= \tenv -> guardt ty (UndefinedType p i) (i `S.member` tenv)
  _ -> pure ty
|
|
|
-- | Run 'checkType' on the type component of each pair, in order, discarding
-- the results.
checkTypes :: [(Type, a)] -> Check ()
checkTypes pairs = mapM_ checkType (map fst pairs)
|
|
|
-- | Check a whole program in two passes.
checkAST :: [TopLevel] -> Check ()
checkAST tls = do
  -- Pass 1: register every top-level name so that definitions may refer to
  -- each other regardless of order.
  mapM_ insertTLBinding tls
  -- Pass 2: check each top level against the populated environment.
  mapM_ checkTopLevel tls
|
|
|
-- | Check a program in 'IO', starting from 'initialState'. Returns the first
-- 'TypeError' encountered, or the final checker environment on success.
runChecks :: [TopLevel] -> IO (Either TypeError CheckEnv)
runChecks tls = runExceptT (execStateT (runCheck (checkAST tls)) initialState)
|
|
|