{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TupleSections #-}

-- | Top-level checking of a whole program: type inference over every
-- definition, constraint solving, and a check that top-level variable
-- definitions do not refer to each other in a cycle.
module Toplevel (check) where

import Control.Monad.Except hiding (guard)
import Control.Monad.RWS hiding (guard)
import Data.Map (Map)
import Data.Set (Set)
import qualified Data.Map as M
import qualified Data.Set as S
import Data.Maybe (mapMaybe)
import TC
import TC.Helpers
import Solve
import Type
import Misc
import PostProcess
import qualified Hm.Abs as H
import Prelude hiding (map)

-- | Type-check a list of parsed top-level definitions.
--
-- Runs inference over all definitions, solves the collected constraints,
-- and rejects reference cycles between top-level variable definitions.
-- On success the post-processed definitions are returned; otherwise the
-- first error produced by any of these stages is returned.
check :: [H.Def] -> Either TypeError [TL]
check defs = do
  (tls, _, cs) <- runInfer M.empty (traverseInfer defs)
  -- Solving only has to succeed; the resulting substitution is not used
  -- here and the returned definitions are not rewritten with it.
  _ <- runSolve (emptySubst, cs)
  runExcept (detectMutualRecursion tls)
  pure tls

-- | Post-process the raw definitions, extend the environment with every
-- top-level binding, and only then check each definition under that
-- environment, so definitions can refer to one another regardless of
-- their order in the source.
traverseInfer :: [H.Def] -> Infer [TL]
traverseInfer defs = do
  tls <- preprocess defs
  env <- accumulateEnv tls
  local (const env) (traverse checkVar tls)

-- | Check a single top-level definition.
--
-- Type definitions are passed through untouched. For a variable
-- definition, its declared type is instantiated, the type of its body is
-- inferred, and the two are related with 'uniR'. The definition itself is
-- returned unchanged.
checkVar :: TL -> Infer TL
checkVar t@TypeDef{} = pure t
checkVar v@(VarDef _ _ declared body) = do
  expected <- instantiate declared
  actual <- infer body
  uniR actual expected
  pure v

-- | Run the post-processing pass over the raw definitions, starting
-- 'runProcess' with an empty set.
preprocess :: [H.Def] -> Infer [TL]
preprocess = flip runProcess S.empty . postprocess

-- | Extend the current environment with the bindings introduced by each
-- top-level definition, in order, and return the final environment.
--
-- Fails with 'AlreadyDefined' when a binding is already present. Only
-- type definitions whose third field is empty (presumably: no type
-- parameters — TODO confirm) are accepted; any other shape raises 'Oop'.
accumulateEnv :: [TL] -> Infer TypeEnv
accumulateEnv [] = ask
accumulateEnv (tl : tls) = case tl of
  TypeDef p _ [] bindings -> do
    -- make sure none of the bindings already exist, then go ahead
    clashes <- M.keysSet . M.intersection bindings <$> ask
    mapM_ (throwError . AlreadyDefined p) clashes
    local (M.union bindings) (accumulateEnv tls)
  VarDef p i t _ -> do
    env <- ask
    -- 'guard' is the project helper (Control.Monad's is hidden above);
    -- presumably it runs the action when the condition is False.
    guard (throwError (AlreadyDefined p i)) (not (M.member i env))
    local (M.insert i t) (accumulateEnv tls)
  _ -> throwError Oop

-- | Fail with 'MutuallyRecursive' if the reference graph of the top-level
-- variable definitions contains a cycle.
detectMutualRecursion :: [TL] -> Except TypeError ()
detectMutualRecursion = detectLoops . createGraph

-- | Search the whole graph for a cycle: start a depth-first search from
-- every key, in ascending order, skipping nodes that an earlier search
-- already finished.
detectLoops :: RefGraph -> Except TypeError ()
detectLoops g = go S.empty (M.keys g)
  where
    go _ [] = pure ()
    go done (i : is) = detectLoop S.empty g done i >>= \done' -> go done' is

-- | Depth-first search starting at node @i@.
--
-- @path@ holds the nodes on the current search path; reaching one of them
-- again closes a cycle, which is reported as 'MutuallyRecursive' carrying
-- that path set. A definition that refers to itself is therefore rejected
-- as well.
--
-- @done@ holds nodes whose searches already finished without finding a
-- cycle; they are never expanded again, so each node is expanded at most
-- once instead of shared subgraphs being re-walked along every path.
--
-- Names that are not keys of the graph (i.e. not top-level variable
-- definitions) are treated as leaves. Returns @done@ extended with every
-- graph node this search finished.
detectLoop :: Set Id -> RefGraph -> Set Id -> Id -> Except TypeError (Set Id)
detectLoop path g done i
  | S.member i path = throwError (MutuallyRecursive path)
  | S.member i done = pure done
  | otherwise = case M.lookup i g of
      Nothing -> pure done
      Just is -> S.insert i <$> visitAll done (S.toList is)
  where
    path' = S.insert i path
    -- Visit the children left to right, threading the finished set.
    visitAll d [] = pure d
    visitAll d (j : js) = detectLoop path' g d j >>= \d' -> visitAll d' js

-- | Build the reference graph: one node per top-level variable
-- definition, mapped to the free names referenced in its body.
createGraph :: [TL] -> RefGraph
createGraph = M.fromList . mapMaybe createNode

-- | The graph node for a variable definition: its name paired with the
-- free names its body refers to. Type definitions contribute no node.
createNode :: TL -> Maybe RefNode
createNode TypeDef{} = Nothing
createNode (VarDef _ name _ body) = Just (name, refs body)
  where
    -- Free names of an expression, assuming the project's '\-' / '\\'
    -- are set deletion / set difference: names bound by a let binding or
    -- a lambda are removed from the scope they bind.
    refs :: Exp -> Set Id
    refs = \case
      Let _ [] e2 -> refs e2
      -- NOTE(review): the grouping below depends on the fixity of '\-';
      -- with the default (infixl 9) @i@ is removed from the later
      -- bindings and the body but not from @e1@ — confirm intended.
      Let p ((i, e1) : ies) e2 -> refs e1 <> refs (Let p ies e2) \- i
      Abs _ is e2 -> refs e2 \\ S.fromList is
      App _ e1 e2s -> refs e1 <> foldMap refs e2s
      Var _ i -> S.singleton i