added detection for mutual recursion
This commit is contained in:
parent
cf55e12391
commit
b8336ed81d
|
@ -12,10 +12,14 @@ import Prelude hiding (map)
|
|||
lookupDefault :: Ord k => a -> k -> Map k a -> a
|
||||
lookupDefault d k m = fromMaybe d (M.lookup k m)
|
||||
|
||||
infixl \\
|
||||
infixl 3 \\
|
||||
(\\) :: Ord a => Set a -> Set a -> Set a
|
||||
(\\) = S.difference
|
||||
|
||||
infixl 3 \-
|
||||
(\-) :: Ord a => Set a -> a -> Set a
|
||||
(\-) = flip S.delete
|
||||
|
||||
infix 5 <~>
|
||||
(<~>) :: (Traversable t, Applicative f) => (a -> f b) -> t a -> f (t b)
|
||||
(<~>) = traverse
|
||||
|
|
|
@ -1,27 +1,35 @@
|
|||
{-# LANGUAGE LambdaCase, TupleSections #-}
|
||||
module Toplevel (check) where
|
||||
|
||||
import Control.Monad.Except hiding (guard)
|
||||
import Control.Monad.RWS hiding (guard)
|
||||
|
||||
import Data.Map (Map)
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Map as M
|
||||
import qualified Data.Set as S
|
||||
|
||||
import Data.Maybe (mapMaybe)
|
||||
|
||||
import TC
|
||||
import TC.Helpers
|
||||
import Solve
|
||||
import Type
|
||||
|
||||
import Misc
|
||||
import PostProcess
|
||||
|
||||
import qualified Hm.Abs as H
|
||||
|
||||
import Prelude hiding (map)
|
||||
|
||||
check :: [H.Def] -> Either TypeError [TL]
|
||||
check defs = case runInfer M.empty (traverseInfer defs) of
|
||||
Left err -> throwError err
|
||||
Right (tls,_,cs) -> case runSolve (emptySubst, cs) of
|
||||
Left err -> throwError err
|
||||
Right sub -> pure tls
|
||||
Right sub -> case runExcept (detectMutualRecursion tls) of
|
||||
Left err -> throwError err
|
||||
Right () -> pure tls
|
||||
|
||||
|
||||
traverseInfer :: [H.Def] -> Infer [TL]
|
||||
|
@ -61,3 +69,37 @@ accumulateEnv (t:ts) = case t of
|
|||
|
||||
|
||||
_ -> throwError Oop
|
||||
|
||||
detectMutualRecursion :: [TL] -> Except TypeError ()
|
||||
detectMutualRecursion = detectLoops . createGraph
|
||||
|
||||
detectLoops :: RefGraph -> Except TypeError ()
|
||||
detectLoops g = case M.minViewWithKey g of -- idk a better way to just get a key
|
||||
Nothing -> pure ()
|
||||
Just ((i,_),_) -> detectLoop S.empty g i >>= detectLoops . foldr M.delete g -- remove all visited nodes
|
||||
|
||||
detectLoop :: Set Id -> RefGraph -> Id -> Except TypeError (Set Id)
|
||||
detectLoop v g i =
|
||||
if S.member i v
|
||||
then throwError (MutuallyRecursive v) -- already visited, loop
|
||||
else case M.lookup i g of
|
||||
Nothing -> pure v -- this has already been removed, meaning this branch is safe
|
||||
Just is -> if is == S.empty
|
||||
then pure v -- leaf \/ Set not Traversable
|
||||
else S.unions <$> detectLoop (S.insert i v) g <~> S.toList is
|
||||
|
||||
createGraph :: [TL] -> RefGraph
|
||||
createGraph = M.fromList . mapMaybe createNode
|
||||
|
||||
createNode :: TL -> Maybe RefNode
|
||||
createNode TypeDef{} = Nothing
|
||||
createNode (VarDef _ i _ e) = Just (i, refs e)
|
||||
where
|
||||
refs :: Exp -> Set Id
|
||||
refs = \case
|
||||
Let _ [] e2 -> refs e2
|
||||
Let p ((i,e1):ies) e2 -> refs e1 <> refs (Let p ies e2) \- i
|
||||
|
||||
Abs _ is e2 -> refs e2 \\ S.fromList is
|
||||
App _ e1 e2s -> refs e1 <> foldMap refs e2s
|
||||
Var _ i -> S.singleton i
|
||||
|
|
|
@ -124,6 +124,7 @@ data TypeError
|
|||
| InvalidConstructor
|
||||
| ArityMismatch
|
||||
| AlreadyDefined Pos Id
|
||||
| MutuallyRecursive (Set Id)
|
||||
deriving Show
|
||||
|
||||
type TypeEnv = Map Id PolyT
|
||||
|
@ -173,3 +174,6 @@ newtype Solve a = Solve { getSolve :: ReaderT Unifier (Except TypeError) a}
|
|||
|
||||
instance MonadFail Solve where
|
||||
fail _ = throwError Oop
|
||||
|
||||
type RefNode = (Id, Set Id)
|
||||
type RefGraph = Map Id (Set Id)
|
||||
|
|
Loading…
Reference in New Issue
Block a user