85 lines
2.3 KiB
Haskell
85 lines
2.3 KiB
Haskell
{-# LANGUAGE LambdaCase, TypeSynonymInstances #-}
|
|
{-# LANGUAGE TupleSections, FlexibleInstances #-}
|
|
module TC (runInfer, infer) where
|
|
|
|
import Control.Monad.Identity hiding (guard)
|
|
import Control.Monad.Except hiding (guard)
|
|
import Control.Monad.RWS hiding (guard)
|
|
|
|
import Data.Set (Set)
|
|
import qualified Data.Set as S
|
|
import qualified Data.Map as M
|
|
|
|
import qualified Data.Text as T
|
|
|
|
import TC.Helpers
|
|
import Type
|
|
import Misc
|
|
import Solve
|
|
|
|
import Prelude hiding (map)
|
|
|
|
-- | Run an inference action in the given type environment, starting from
-- 'initialState'.
--
-- On success the result is a triple of the action's value, the final state
-- and the constraints the action emitted.
runInfer :: TypeEnv -> Infer a -> Either TypeError (a, [Id], [Constraint])
runInfer env action = runInfer' initialState env action
|
|
|
|
-- | Run an inference action with an explicit starting state and environment.
--
-- Unwraps 'Infer' down to a plain 'Either': the RWST layer is run with the
-- environment @r@ and state @s@, then the 'ExceptT' and 'Identity' layers are
-- peeled off.
runInfer' :: [Id] -> TypeEnv -> Infer a -> Either TypeError (a, [Id], [Constraint])
runInfer' s r action =
  runIdentity (runExceptT (runRWST (getInfer action) r s))
|
|
|
|
-- | Run an action with the binding @i :: t@ inserted into the environment
-- (replacing any existing entry for @i@) for the duration of that action.
localEnv :: Id -> PolyT -> Infer a -> Infer a
localEnv i t action = local extend action
  where
    extend env = M.insert i t env
|
|
|
|
-- | Run @m@ as a separate inference run, solve the constraints it produced,
-- and hand the resulting substitution together with @m@'s result to @f@.
--
-- @m@ is executed via 'runInfer'' using the current environment and state,
-- so the constraints it emits are consumed by the solver here rather than
-- being added to the enclosing action's output. The state reached by @m@ is
-- installed before @f@ runs. Errors from either the sub-run or the solver
-- are re-thrown unchanged.
solveFor :: Infer a -> (Subst -> a -> Infer b) -> Infer b
solveFor m f = do
  env <- ask
  before <- get
  (result, after, constraints) <- rethrow (runInfer' before env m)
  subst <- rethrow (runSolve (emptySubst, constraints))
  put after
  f subst result
  where
    -- Lift an 'Either' into the current monad, re-throwing a 'Left'.
    rethrow = either throwError pure
|
|
|
|
-- | Emit a single 'Unify' constraint between the two given types.
uni :: MonoT -> MonoT -> Infer ()
uni t1 t2 = tell [constraint]
  where
    constraint = (t1, t2, Unify)
|
|
|
|
-- | Look up an identifier in the current environment and instantiate the
-- type found there.
--
-- Throws 'UnboundVariable' (carrying the position and the identifier) when
-- the identifier has no entry.
lookupType :: Pos -> Id -> Infer MonoT
lookupType p i = do
  env <- ask
  maybe (throwError (UnboundVariable p i)) instantiate (M.lookup i env)
|
|
|
|
-- | Infer the type of an expression, emitting unification constraints as a
-- side effect.
--
-- Multi-binding 'Let' and multi-parameter 'Abs' are handled one binder at a
-- time by recursing on the same constructor with the remaining binders.
infer :: Exp -> Infer MonoT
infer = \case
  Var pos name -> lookupType pos name

  -- No bindings left: the type is that of the body.
  Let _ [] body -> infer body
  -- Infer and solve the first binding on its own, apply the resulting
  -- substitution to the environment, generalize the binding's type, and
  -- continue with the remaining bindings under the extended environment.
  Let pos ((name, rhs) : rest) body ->
    solveFor (infer rhs) $ \subst rhsTy ->
      local (apply subst) $ do
        scheme <- generalize (apply subst rhsTy)
        -- NOTE(review): open question carried over from the original author:
        -- should the substitution also be applied to the remaining bindings?
        localEnv name scheme (infer (Let pos rest body))

  -- No parameters left: the type is that of the body.
  Abs _ [] body -> infer body
  -- Bind the first parameter to a fresh type under an empty quantifier set,
  -- then infer the rest of the abstraction beneath it.
  Abs pos (param : rest) body -> do
    paramTy <- fresh
    bodyTy <- localEnv param (Forall S.empty paramTy) (infer (Abs pos rest body))
    pure (paramTy `TArr` bodyTy)

  -- An application with no arguments is rejected outright.
  App _ _ [] -> throwError Oop
  -- Otherwise infer the function, then apply it to the arguments in order.
  App _ fn args -> infer fn >>= \fnTy -> applyAll fnTy args
  where
    -- Thread a function type through its arguments from left to right: for
    -- each argument, infer its type, allocate a fresh result type, and
    -- require the function type to be @argument -> result@.
    applyAll :: MonoT -> [Exp] -> Infer MonoT
    applyAll fnTy [] = pure fnTy
    applyAll fnTy (arg : rest) = do
      argTy <- infer arg
      resTy <- fresh
      uni fnTy (argTy `TArr` resTy)
      applyAll resTy rest
|