You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

232 lines
9.4 KiB

{-|
Module: Simple.TC
Description: Simple type checker.
A type checker for a simple subset of the language.
TODO: major refactor, split code across modules
-}
{-# LANGUAGE LambdaCase, TupleSections #-}
module Simple.TC (runChecks, errorStr) where
import Control.Monad.State hiding (guard)
import Control.Monad.Except hiding (guard)
import qualified Data.Map as M
import qualified Data.Set as S
import Simple.TC.Types
import Simple.TC.TypeOps
import Simple.AST
import Types
import Misc
-- | Succeed with unit when the condition holds; otherwise abort the check
-- by throwing the supplied error.
guard :: TypeError -> Bool -> Check ()
guard err ok
  | ok = pure ()
  | otherwise = throwError err
-- | Yield the given type when the condition holds; otherwise abort the check
-- by throwing the supplied error.
guardt :: Type -> TypeError -> Bool -> Check Type
guardt ty err ok = if ok then pure ty else throwError err
-- | Infer the type of every bound expression and pair each inferred type
-- with the identifier it is bound to. The order of the input list is kept.
iesTotis :: [(Identifier, Expr)] -> Check [(Type, Identifier)]
iesTotis = traverse typeOfBinding
  where
    -- Infer one binding's expression and attach its identifier.
    typeOfBinding :: (Identifier, Expr) -> Check (Type, Identifier)
    typeOfBinding (ident, expr) = do
      ty <- inferExpr expr
      pure (ty, ident)
-- | Build a curried function type: every type in the list becomes an
-- argument (in order) and the third argument is the final result type.
-- Each arrow carries the supplied 'PN'. An empty argument list yields the
-- result type unchanged.
tsToFunc :: PN -> [Type] -> Type -> Type
tsToFunc p ts b = foldr (FuncType p) b ts
-- | Returns the constructors of a type. Currently not used; might be used
-- for totality checking of patterns in the future.
--
-- Every environment entry whose type is the given type, or a function type
-- whose final result is the given type, is reported as a @(type, name)@ pair.
consOf :: Type -> Check [(Type, Identifier)]
consOf t = do
  env <- getEnv
  pure [(ty, name) | (name, ty) <- M.toList env, constructs t ty]
  where
    -- Checks whether the second type constructs the first: either they are
    -- equal, or the second is a chain of arrows whose result is the first.
    constructs :: Type -> Type -> Bool
    constructs wanted candidate
      | wanted == candidate = True
      | otherwise = case candidate of
          FuncType _ _ result -> constructs wanted result
          _ -> False
-- | Check that an expression has the expected type. The expected type is
-- first validated with 'checkType'; the expression's type is then inferred
-- and compared with it. Throws 'TypeMismatch' at the expression's position
-- when they differ, and returns the expected type otherwise.
checkExpr :: Expr -> Type -> Check Type
checkExpr expr t = do
  _ <- checkType t
  actual <- inferExpr expr
  guardt t (TypeMismatch (pos expr) t actual) (t == actual)
-- | Infer the type of an expression. An explicitly annotated expression is
-- checked against its annotation, which is then returned; an unannotated
-- one has its type inferred from the inner expression.
inferExpr :: Expr -> Check Type
inferExpr expr = case expr of
  ExpExpr t e -> t <$ checkExpr1 e t
  ImpExpr e -> inferExpr1 e
-- | Check that an expression1 has the expected type. The expected type is
-- first validated with 'checkType'; the expression's type is then inferred
-- and compared with it. Throws 'TypeMismatch' at the expression's position
-- when they differ, and returns the expected type otherwise.
checkExpr1 :: Expr1 -> Type -> Check Type
checkExpr1 expr t = do
  _ <- checkType t
  actual <- inferExpr1 expr
  guardt t (TypeMismatch (pos expr) t actual) (t == actual)
-- | Infer the type of a function application. The function expression's type
-- is inferred first; each argument is then checked against the next
-- parameter type, and whatever remains of the function type is the result.
-- Throws 'ArityMismatch' when there are no arguments at all, or when an
-- argument is supplied to something that is not a function type.
inferFunc :: PN -> Expr -> [Expr] -> Check Type
inferFunc p _ [] = throwError (ArityMismatch p)
inferFunc p e es = do
  fnType <- inferExpr e
  applyArgs fnType es
  where
    -- Peel one arrow off the type per argument, checking each argument
    -- against the corresponding parameter type.
    applyArgs :: Type -> [Expr] -> Check Type
    applyArgs ty [] = pure ty
    applyArgs (FuncType _ argTy resTy) (arg : rest) =
      checkExpr arg argTy >> applyArgs resTy rest
    applyArgs _ (_ : _) = throwError (ArityMismatch p)
-- | Infer the type of an expression1.
--
-- * Applications are delegated to 'inferFunc'.
-- * A case expression infers its scrutinee and hands the alternatives to
--   'inferCase'.
-- * A let infers every bound expression and then infers the body with those
--   bindings in scope.
-- * A record instance must name a known record, supply exactly as many
--   fields as the record declares, use only declared field names with
--   correctly typed expressions, and not repeat any field. Its type is the
--   record's user type.
-- * Terms are delegated to 'inferTerm'.
inferExpr1 :: Expr1 -> Check Type
inferExpr1 = \case
  Apply p e es -> inferFunc p e es
  Case p e pes -> inferExpr e >>= \t -> inferCase p t pes
  Let _ ies e -> iesTotis ies >>= \tis -> withBindings tis (inferExpr e)
  Inst p i ies -> do
    renv <- getRecEnv
    case M.lookup i renv of
      Nothing -> throwError (UndefinedType p i)
      Just rcons -> do
        guard (IncompleteInstance p) (M.size rcons == length ies)
        mapM_ (checkField p rcons) ies
        -- Equal counts plus valid names still admit a repeated field, in
        -- which case some declared field is necessarily missing.
        guard
          (IncompleteInstance p)
          (S.size (S.fromList (fst <$> ies)) == length ies)
    pure (UsrType p i)
  Term t -> inferTerm t
  where
    -- Check one supplied field against the type the record declares for it.
    checkField p rcons (field, e) = case M.lookup field rcons of
      Just t -> checkExpr e t
      Nothing -> throwError (InvalidRecordField p field)
-- | Infer the type of a term.
--
-- Literals are delegated to 'inferLiteral'. A lambda validates its parameter
-- types, infers its body with the parameters bound, and yields the curried
-- function type. A lambda-case validates its argument type and infers the
-- result type from its alternatives. A variable is looked up in the
-- environment; an unknown name throws 'UnboundVar'.
inferTerm :: Term -> Check Type
inferTerm term = case term of
  TLit l -> inferLiteral l
  TLambda p tis e -> do
    checkTypes tis
    bodyTy <- withBindings tis (inferExpr e)
    pure (tsToFunc p (map fst tis) bodyTy)
  TLambdaCase p t pes -> do
    argTy <- checkType t
    resTy <- inferCase p t pes
    pure (FuncType p argTy resTy)
  TVar p i -> do
    env <- getEnv
    maybe (throwError (UnboundVar p i)) pure (M.lookup i env)
-- | "Infer" the type of a literal: the literal's constructor determines the
-- type directly, and the literal's 'PN' is carried over to the type.
inferLiteral :: Literal -> Check Type
inferLiteral (LInt p _) = pure (IntType p)
inferLiteral (LChr p _) = pure (ChrType p)
inferLiteral (LStr p _) = pure (StrType p)
-- | Infer the type of a case statement. Unfortunately does not perform any
-- totality checks.
--
-- The second argument is the type of the scrutinee. Every pattern is checked
-- against that type, and every body is typed with the variables its pattern
-- binds (at the scrutinee type) in scope. The result type is the type
-- inferred for the last alternative; every earlier alternative is checked
-- against it. An empty list of alternatives throws 'NoCase'.
inferCase :: PN -> Type -> [(Pattern, Expr)] -> Check Type
inferCase p _ [] = throwError (NoCase p)
inferCase _ t [(pt, e)] = do
  _ <- checkPattern pt t
  bs <- patternBindings t pt
  withBindings bs (inferExpr e)
inferCase p t ((pt, e) : pes) = do
  _ <- checkPattern pt t
  resultTy <- inferCase p t pes
  -- The bindings come from matching the scrutinee type, not the result type.
  bs <- patternBindings t pt
  withBindings bs (checkExpr e resultTy)
-- | Return the variables bound by a pattern, each paired with its type. The
-- first argument is the type the pattern is matched against.
--
-- * Wildcards and literals bind nothing.
-- * A variable binds itself at the matched type, unless its name is already
--   in the environment, in which case it binds nothing.
-- * A constructor application looks the constructor up in the environment
--   and collects the bindings of every sub-pattern, each matched against the
--   corresponding argument type of the constructor. Nested patterns are
--   handled by the same rules as top-level ones.
--
-- Throws 'UnknownPattern' for an application with no arguments or an unknown
-- constructor, and 'ArityMismatch' when more sub-patterns are given than the
-- constructor's type has arguments.
patternBindings :: Type -> Pattern -> Check [(Type, Identifier)]
patternBindings target = \case
  Wild _ -> pure []
  PLit _ -> pure []
  PVar _ i -> getEnv >>= \env -> case M.lookup i env of
    Just _ -> pure []
    Nothing -> pure [(target, i)]
  PApp p _ [] -> throwError (UnknownPattern p)
  PApp p i ps -> getEnv >>= \env -> case M.lookup i env of
    Just t -> go p t ps
    Nothing -> throwError (UnknownPattern p)
  where
    -- Walk the constructor's type and the sub-patterns in lockstep,
    -- concatenating the bindings of each sub-pattern in order.
    go :: PN -> Type -> [Pattern] -> Check [(Type, Identifier)]
    go _ _ [] = pure []
    go p (FuncType _ t1 t2) (pt : ps) =
      (++) <$> patternBindings t1 pt <*> go p t2 ps
    go p _ (_ : _) = throwError (ArityMismatch p)
-- | Check that a pattern matches on the specified type. The type is first
-- validated with 'checkType'; the type the pattern matches on is then
-- inferred and compared with it. Throws 'TypeMismatch' at the pattern's
-- position when they differ, and returns the specified type otherwise.
checkPattern :: Pattern -> Type -> Check Type
checkPattern pt t = do
  _ <- checkType t
  matched <- inferPattern t pt
  guardt t (TypeMismatch (pos pt) t matched) (t == matched)
-- | Infer which type a pattern matches on. The type being matched on must be
-- supplied so that wildcards and fresh variables have a meaningful type.
--
-- A wildcard matches the supplied type. A variable already present in the
-- environment has the type recorded there; any other variable matches the
-- supplied type. A literal has its literal type. A constructor application
-- checks each sub-pattern against the constructor's argument types and
-- yields what is left of the constructor's type.
inferPattern :: Type -> Pattern -> Check Type
inferPattern target pat = case pat of
  Wild _ -> pure target
  PVar _ i -> do
    env <- getEnv
    -- A name that is not in the environment is a fresh variable.
    pure (maybe target id (M.lookup i env))
  PLit l -> inferLiteral l
  PApp p _ [] -> throwError (UnknownPattern p)
  PApp p i ps -> do
    env <- getEnv
    case M.lookup i env of
      Nothing -> throwError (UnknownPattern p)
      Just conTy -> matchArgs p conTy ps
  where
    -- Consume one arrow of the constructor type per sub-pattern.
    matchArgs :: PN -> Type -> [Pattern] -> Check Type
    matchArgs _ ty [] = pure ty
    matchArgs p (FuncType _ argTy resTy) (sub : rest) =
      checkPattern sub argTy >> matchArgs p resTy rest
    matchArgs p _ (_ : _) = throwError (ArityMismatch p)
-- | Register the names introduced by a top-level declaration. This does not
-- check the type correctness of any expression; it only prepares the
-- environment so that the top levels can be checked afterwards.
--
-- A definition registers its name at its declared type. A data declaration
-- registers its type name and its listed bindings. A record declaration
-- registers its type name, a constructor taking every field in order, one
-- accessor per field (from the record type to the field type), and its
-- field list via 'insertRecCons'.
insertTLBinding :: TopLevel -> Check ()
insertTLBinding tl = case tl of
  Def p t i _ -> insertUniqueBindings p [(t, i)]
  Dat p i tis -> do
    insertUniqueTypes p [i]
    insertUniqueBindings p tis
  Rec p i tis -> do
    let recTy = UsrType p i
        constructor = (tsToFunc p (map fst tis) recTy, i)
        accessors =
          [ (FuncType (pos fieldTy) recTy fieldTy, field)
          | (fieldTy, field) <- tis
          ]
    insertUniqueTypes p [i]
    insertUniqueBindings p (constructor : accessors)
    insertRecCons p i tis
-- | Check a single top-level declaration. Assumes every top level has
-- already been registered in the environment.
--
-- Each clause of a definition is checked against the type recorded for the
-- definition's name. Data and record declarations have their listed types
-- validated.
checkTopLevel :: TopLevel -> Check ()
checkTopLevel tl = case tl of
  -- A definition without clauses should not be constructible (a non-empty
  -- list would rule this out by construction).
  Def _ _ _ [] -> throwError Urk
  Def p _ i pes -> do
    env <- getEnv
    case M.lookup i env of
      -- The name should already be in the environment at this point.
      Nothing -> throwError Urk
      Just t -> mapM_ (checkCases p t) pes
  -- TODO verify types
  Dat _ _ tis -> checkTypes tis
  Rec _ _ tis -> checkTypes tis
-- | Check one clause of a definition against a type. Each argument pattern
-- is matched against the next parameter of the type, removing one arrow and
-- bringing the pattern's variables into scope; once the patterns run out,
-- the body is checked against whatever type remains. Throws 'ArityMismatch'
-- when a pattern is left over but the type is not a function type.
checkCases :: PN -> Type -> ([Pattern], Expr) -> Check Type
checkCases p ty (pats, body) = case (pats, ty) of
  ([], _) -> checkExpr body ty
  (pt : rest, FuncType _ argTy resTy) -> do
    _ <- checkPattern pt argTy
    bound <- patternBindings argTy pt
    withBindings bound (checkCases p resTy (rest, body))
  _ -> throwError (ArityMismatch p)
-- | Check that every user type mentioned in a type is defined. Function
-- types are checked on both sides; a user type must be a member of the type
-- environment, otherwise 'UndefinedType' is thrown; any other type is
-- accepted as is. Returns the type on success.
checkType :: Type -> Check Type
checkType ty = case ty of
  FuncType p t1 t2 -> do
    argTy <- checkType t1
    resTy <- checkType t2
    pure (FuncType p argTy resTy)
  UsrType p i -> do
    known <- getTypeEnv
    guardt ty (UndefinedType p i) (S.member i known)
  _ -> pure ty
-- | Run 'checkType' on the type component of every pair, discarding the
-- results.
checkTypes :: [(Type, a)] -> Check ()
checkTypes tis = mapM_ checkType [ty | (ty, _) <- tis]
-- | Check an AST in two passes: first register every top-level declaration,
-- then check each of them.
checkAST :: [TopLevel] -> Check ()
checkAST tls = do
  mapM_ insertTLBinding tls
  mapM_ checkTopLevel tls
-- | Check an AST in the IO monad, starting from 'initialState'. Yields the
-- type error that aborted the check, or the final checker state on success.
runChecks :: [TopLevel] -> IO (Either TypeError CheckEnv)
runChecks tls = do
  outcome <- runExceptT (runStateT (runCheck (checkAST tls)) initialState)
  pure (snd <$> outcome)