106 lines
2.9 KiB
Haskell
106 lines
2.9 KiB
Haskell
{-# LANGUAGE LambdaCase, TupleSections #-}
|
|
module Toplevel (check) where
|
|
|
|
import Control.Monad.Except hiding (guard)
|
|
import Control.Monad.RWS hiding (guard)
|
|
|
|
import Data.Map (Map)
|
|
import Data.Set (Set)
|
|
import qualified Data.Map as M
|
|
import qualified Data.Set as S
|
|
|
|
import Data.Maybe (mapMaybe)
|
|
|
|
import TC
|
|
import TC.Helpers
|
|
import Solve
|
|
import Type
|
|
import Misc
|
|
import PostProcess
|
|
|
|
import qualified Hm.Abs as H
|
|
|
|
import Prelude hiding (map)
|
|
|
|
-- | Type-check a list of parsed top-level definitions.
--
-- Three stages run in order, and each can fail with a 'TypeError':
--
--   1. inference over every definition ('traverseInfer'), yielding the
--      checked definitions and a constraint list;
--   2. solving those constraints, starting from 'emptySubst';
--   3. rejecting reference cycles between definitions
--      ('detectMutualRecursion').
--
-- On success the definitions produced by stage 1 are returned. The solved
-- substitution is not applied to them; only whether solving succeeds
-- matters here.
check :: [H.Def] -> Either TypeError [TL]
check defs = do
  (tls, _, constraints) <- runInfer M.empty (traverseInfer defs)
  _solution <- runSolve (emptySubst, constraints)
  runExcept (detectMutualRecursion tls)
  pure tls
|
|
|
|
|
|
-- | Preprocess the raw definitions, then check every one of them.
--
-- The typing environment is built from all preprocessed definitions before
-- any body is checked. As far as typing is concerned, a body may therefore
-- mention any top-level name, wherever it is defined. The checked
-- definitions are returned in their original order.
traverseInfer :: [H.Def] -> Infer [TL]
traverseInfer defs = preprocess defs >>= checkAll
  where
    -- check each definition under the full top-level environment
    checkAll tls = do
      topLevel <- accumulateEnv tls
      local (const topLevel) (traverse checkVar tls)
|
|
|
|
-- | Check a single top-level item and return it unchanged.
--
-- Type definitions are accepted as they are. For a variable definition:
--
--   1. the declared type is instantiated;
--   2. the body's type is inferred;
--   3. the two are unified with 'uniR', inferred type first and declared
--      type second.
checkVar :: TL -> Infer TL
checkVar tl = case tl of
  TypeDef{} -> pure tl
  VarDef _ _ declared body -> do
    expected <- instantiate declared
    actual <- infer body
    uniR actual expected
    pure tl
|
|
|
|
-- | Run the 'postprocess' pass over the raw definitions, starting
-- 'runProcess' from an empty set.
preprocess :: [H.Def] -> Infer [TL]
preprocess defs = runProcess (postprocess defs) S.empty
|
|
|
|
-- | Extend the typing environment with each top-level item in turn and
-- return the environment in force after the last one.
--
-- * A 'TypeDef' whose third field is empty (presumably "no type
--   parameters" — confirm) adds its whole binding map.
-- * A 'VarDef' adds its declared type under its name.
-- * Any name that is already in scope is reported as 'AlreadyDefined'.
-- * Every other item raises 'Oop'.
accumulateEnv :: [TL] -> Infer TypeEnv
accumulateEnv tls = case tls of
  [] -> ask
  TypeDef p _ [] bindings : rest -> do
    scope <- ask
    -- throwError is sequenced over the clashing names in ascending order,
    -- so the first clash (if any) is the one reported
    mapM_ (throwError . AlreadyDefined p) (M.keysSet (M.intersection bindings scope))
    local (M.union bindings) (accumulateEnv rest)
  VarDef p name declared _ : rest -> do
    scope <- ask
    -- NOTE(review): this `guard` is the project's own helper (Control.Monad's
    -- is hidden by the imports); presumably it runs the error action when
    -- the condition is False — confirm against its definition
    guard (throwError (AlreadyDefined p name)) (not (M.member name scope))
    local (M.insert name declared) (accumulateEnv rest)
  _ : _ -> throwError Oop
|
|
|
|
-- | Fail with 'MutuallyRecursive' when the top-level variable definitions
-- refer to one another in a cycle.
--
-- NOTE(review): as 'detectLoop' is written, a definition that refers
-- directly to itself is reported too; confirm that this is intended.
detectMutualRecursion :: [TL] -> Except TypeError ()
detectMutualRecursion tls = detectLoops (createGraph tls)
|
|
|
|
-- | Run 'detectLoop' from each remaining node until the graph is empty.
--
-- Each search starts at the smallest remaining key. Every node the search
-- reaches is then removed from the graph. Those nodes were explored without
-- finding a cycle, so later searches treat references to them as safe.
detectLoops :: RefGraph -> Except TypeError ()
detectLoops graph = case M.lookupMin graph of
  Nothing -> pure ()
  Just (start, _) -> do
    visited <- detectLoop S.empty graph start
    detectLoops (M.withoutKeys graph visited)
|
|
|
|
-- | Depth-first search from @i@ for a reference cycle.
--
-- @v@ is the set of names on the current search path.
--
-- * Reaching a name that is already on the path throws 'MutuallyRecursive'
--   carrying that whole path. The path may include definitions that lead
--   into the cycle without being part of it.
-- * Because @i@ is on the path while its own references are explored, a
--   definition that refers to itself is reported as well.
-- * Names with no entry in @g@ (already cleared by an earlier search, or
--   not a node of the graph at all) are treated as safe.
--
-- Returns @v@ plus every graph node reachable from @i@, so the caller can
-- drop them from the graph.
--
-- Children are searched in ascending order. Each child is searched on a
-- graph with everything already explored removed, so a subgraph shared by
-- several branches is explored only once rather than once per path (which
-- is exponential on graphs with many shared dependencies). The path check
-- runs before the lookup, so back-references to the path are still caught.
-- The returned set and the first error thrown are the same as with a naive
-- re-exploring search.
detectLoop :: Set Id -> RefGraph -> Id -> Except TypeError (Set Id)
detectLoop v g i
  | S.member i v = throwError (MutuallyRecursive v) -- already on the path: loop
  | otherwise = case M.lookup i g of
      Nothing -> pure v -- already removed (or not a graph node): safe
      Just is -> visitAll path (S.toList is)
  where
    path = S.insert i v

    -- Search the children left to right. The explored set is threaded
    -- through so a finished subgraph is never searched a second time.
    visitAll seen [] = pure seen
    visitAll seen (j : js) = do
      reached <- detectLoop path (M.withoutKeys g seen) j
      visitAll (S.union seen reached) js
|
|
|
|
-- | Build the reference graph with 'createNode'.
--
-- There is one entry per variable definition, keyed by its name. Type
-- definitions produce no entry. If a name occurs twice, the later entry
-- wins ('M.fromList').
createGraph :: [TL] -> RefGraph
createGraph tls = M.fromList [node | tl <- tls, Just node <- [createNode tl]]
|
|
|
|
createNode :: TL -> Maybe RefNode
|
|
createNode TypeDef{} = Nothing
|
|
createNode (VarDef _ i _ e) = Just (i, refs e)
|
|
where
|
|
refs :: Exp -> Set Id
|
|
refs = \case
|
|
Let _ [] e2 -> refs e2
|
|
Let p ((i,e1):ies) e2 -> refs e1 <> refs (Let p ies e2) \- i
|
|
|
|
Abs _ is e2 -> refs e2 \\ S.fromList is
|
|
App _ e1 e2s -> refs e1 <> foldMap refs e2s
|
|
Var _ i -> S.singleton i
|