hm/src/Toplevel.hs

106 lines
2.9 KiB
Haskell

{-# LANGUAGE LambdaCase, TupleSections #-}
module Toplevel (check) where
import Control.Monad.Except hiding (guard)
import Control.Monad.RWS hiding (guard)
import Data.Map (Map)
import Data.Set (Set)
import qualified Data.Map as M
import qualified Data.Set as S
import Data.Maybe (mapMaybe)
import TC
import TC.Helpers
import Solve
import Type
import Misc
import PostProcess
import qualified Hm.Abs as H
import Prelude hiding (map)
-- | Type-check a list of parsed top-level definitions.
--
-- Three stages run in order, and the first 'TypeError' from any stage is
-- returned:
--
--   1. inference over all definitions ('traverseInfer'),
--   2. solving the collected constraints ('runSolve'),
--   3. the reference-cycle check ('detectMutualRecursion').
--
-- On success the preprocessed top-level definitions are returned as
-- produced by inference. The solver's substitution is only used to decide
-- success; it is not applied to the result.
check :: [H.Def] -> Either TypeError [TL]
check defs = do
  (toplevels, _, constraints) <- runInfer M.empty (traverseInfer defs)
  _solution <- runSolve (emptySubst, constraints)
  runExcept (detectMutualRecursion toplevels)
  pure toplevels
-- | Preprocess the definitions, collect every top-level binding into the
-- environment, then check each definition's body under that full
-- environment. Every body therefore sees all top-level names.
traverseInfer :: [H.Def] -> Infer [TL]
traverseInfer defs = do
  toplevels <- preprocess defs
  globalEnv <- accumulateEnv toplevels
  local (const globalEnv) (traverse checkVar toplevels)
-- | Check one top-level definition against its declared type.
--
-- Type definitions are returned as-is. For a variable definition:
--
--   * the declared type is instantiated,
--   * the body's type is inferred,
--   * the two are passed to 'uniR' (inferred first, declared second).
--
-- The definition itself is always returned unchanged.
checkVar :: TL -> Infer TL
checkVar tl = case tl of
  TypeDef{} -> pure tl
  VarDef _ _ declared body -> do
    expected <- instantiate declared
    actual <- infer body
    uniR actual expected
    pure tl
-- | Run the 'postprocess' pass over the parsed definitions via
-- 'runProcess', starting from an empty set.
preprocess :: [H.Def] -> Infer [TL]
preprocess defs = runProcess (postprocess defs) S.empty
-- | Extend the current typing environment with the top-level bindings, one
-- definition at a time, and return the final environment.
--
-- Each binding is checked against the environment built so far and
-- rejected with 'AlreadyDefined' if its name already exists:
--
--   * a 'TypeDef' (only the form whose third field is @[]@) contributes its
--     whole binding map;
--   * a 'VarDef' contributes its name with its declared type.
--
-- Any other definition shape is reported as 'Oop'.
accumulateEnv :: [TL] -> Infer TypeEnv
accumulateEnv tls = case tls of
  [] -> ask
  TypeDef p _ [] defEnv : rest -> do
    current <- ask
    -- every name the type definition introduces must be new
    let clashes = M.keysSet (M.intersection defEnv current)
    mapM_ (throwError . AlreadyDefined p) clashes
    local (M.union defEnv) (accumulateEnv rest)
  VarDef p i declared _ : rest -> do
    current <- ask
    -- the variable's name must not already be bound
    guard (throwError (AlreadyDefined p i)) (not (M.member i current))
    local (M.insert i declared) (accumulateEnv rest)
  _ : _ -> throwError Oop
-- | Build the reference graph of the variable definitions and fail with
-- 'MutuallyRecursive' if it contains a cycle.
detectMutualRecursion :: [TL] -> Except TypeError ()
detectMutualRecursion tls = detectLoops (createGraph tls)
-- | Repeatedly start a cycle search from the smallest remaining definition.
--
-- After each search, every definition it visited is removed and the rest of
-- the graph is processed, until the graph is empty. The first cycle found
-- aborts the whole check.
detectLoops :: RefGraph -> Except TypeError ()
detectLoops g = case M.lookupMin g of
  Nothing -> pure ()
  Just (start, _) -> do
    visited <- detectLoop S.empty g start
    detectLoops (M.withoutKeys g visited)
-- | Depth-first search from @i@ over the reference graph @g@. It fails with
-- 'MutuallyRecursive' as soon as it reaches a definition that is already on
-- the current search path.
--
-- Arguments and result:
--
--   * @v@ is the set of definitions on the path leading to @i@; callers
--     start with 'S.empty'.
--   * On success the result is @v@ plus every definition in @g@ reachable
--     from @i@, including @i@ itself when it is a key of @g@.
--   * References to names that are not keys of @g@ (for example, names the
--     caller has already removed) are ignored.
--
-- The error carries the whole path set at the point of detection, so it may
-- also list definitions that lead into the cycle. A definition that
-- references itself is reported too, because it is on the path when its own
-- reference is followed.
--
-- Fully explored nodes are remembered, so each node is expanded at most once
-- rather than once per path reaching it. Without this, shared dependencies
-- cost exponential time. Skipping a finished node cannot hide a cycle: if it
-- could reach the current path, its own exploration would already have
-- failed.
detectLoop :: Set Id -> RefGraph -> Id -> Except TypeError (Set Id)
detectLoop v g i = S.union v <$> visit v S.empty i
  where
    -- visit path done n: explore n below the given ancestor path and return
    -- the updated set of fully explored nodes.
    visit path done n
      | S.member n path = throwError (MutuallyRecursive path) -- back edge: cycle
      | S.member n done = pure done                           -- already explored
      | otherwise = case M.lookup n g of
          Nothing -> pure done -- not (or no longer) in the graph
          Just next ->
            S.insert n <$> visitAll (S.insert n path) done (S.toList next)

    -- Explore successors left to right, threading the done set through.
    visitAll _ done [] = pure done
    visitAll path done (n : ns) =
      visit path done n >>= \done' -> visitAll path done' ns
-- | Build the reference graph: one node per definition for which
-- 'createNode' yields a node. Type definitions produce none.
createGraph :: [TL] -> RefGraph
createGraph tls = M.fromList [node | tl <- tls, Just node <- [createNode tl]]
createNode :: TL -> Maybe RefNode
createNode TypeDef{} = Nothing
createNode (VarDef _ i _ e) = Just (i, refs e)
where
refs :: Exp -> Set Id
refs = \case
Let _ [] e2 -> refs e2
Let p ((i,e1):ies) e2 -> refs e1 <> refs (Let p ies e2) \- i
Abs _ is e2 -> refs e2 \\ S.fromList is
App _ e1 e2s -> refs e1 <> foldMap refs e2s
Var _ i -> S.singleton i