added detection for mutual recursion
This commit is contained in:
parent
cf55e12391
commit
b8336ed81d
|
@ -12,10 +12,14 @@ import Prelude hiding (map)
|
||||||
lookupDefault :: Ord k => a -> k -> Map k a -> a
|
lookupDefault :: Ord k => a -> k -> Map k a -> a
|
||||||
lookupDefault d k m = fromMaybe d (M.lookup k m)
|
lookupDefault d k m = fromMaybe d (M.lookup k m)
|
||||||
|
|
||||||
infixl \\
|
infixl 3 \\
|
||||||
(\\) :: Ord a => Set a -> Set a -> Set a
|
(\\) :: Ord a => Set a -> Set a -> Set a
|
||||||
(\\) = S.difference
|
(\\) = S.difference
|
||||||
|
|
||||||
|
infixl 3 \-
|
||||||
|
(\-) :: Ord a => Set a -> a -> Set a
|
||||||
|
(\-) = flip S.delete
|
||||||
|
|
||||||
infix 5 <~>
|
infix 5 <~>
|
||||||
(<~>) :: (Traversable t, Applicative f) => (a -> f b) -> t a -> f (t b)
|
(<~>) :: (Traversable t, Applicative f) => (a -> f b) -> t a -> f (t b)
|
||||||
(<~>) = traverse
|
(<~>) = traverse
|
||||||
|
|
|
@ -1,27 +1,35 @@
|
||||||
|
{-# LANGUAGE LambdaCase, TupleSections #-}
|
||||||
module Toplevel (check) where
|
module Toplevel (check) where
|
||||||
|
|
||||||
import Control.Monad.Except hiding (guard)
|
import Control.Monad.Except hiding (guard)
|
||||||
import Control.Monad.RWS hiding (guard)
|
import Control.Monad.RWS hiding (guard)
|
||||||
|
|
||||||
import Data.Map (Map)
|
import Data.Map (Map)
|
||||||
|
import Data.Set (Set)
|
||||||
import qualified Data.Map as M
|
import qualified Data.Map as M
|
||||||
import qualified Data.Set as S
|
import qualified Data.Set as S
|
||||||
|
|
||||||
|
import Data.Maybe (mapMaybe)
|
||||||
|
|
||||||
import TC
|
import TC
|
||||||
import TC.Helpers
|
import TC.Helpers
|
||||||
import Solve
|
import Solve
|
||||||
import Type
|
import Type
|
||||||
|
import Misc
|
||||||
import PostProcess
|
import PostProcess
|
||||||
|
|
||||||
import qualified Hm.Abs as H
|
import qualified Hm.Abs as H
|
||||||
|
|
||||||
|
import Prelude hiding (map)
|
||||||
|
|
||||||
check :: [H.Def] -> Either TypeError [TL]
|
check :: [H.Def] -> Either TypeError [TL]
|
||||||
check defs = case runInfer M.empty (traverseInfer defs) of
|
check defs = case runInfer M.empty (traverseInfer defs) of
|
||||||
Left err -> throwError err
|
Left err -> throwError err
|
||||||
Right (tls,_,cs) -> case runSolve (emptySubst, cs) of
|
Right (tls,_,cs) -> case runSolve (emptySubst, cs) of
|
||||||
Left err -> throwError err
|
Left err -> throwError err
|
||||||
Right sub -> pure tls
|
Right sub -> case runExcept (detectMutualRecursion tls) of
|
||||||
|
Left err -> throwError err
|
||||||
|
Right () -> pure tls
|
||||||
|
|
||||||
|
|
||||||
traverseInfer :: [H.Def] -> Infer [TL]
|
traverseInfer :: [H.Def] -> Infer [TL]
|
||||||
|
@ -61,3 +69,37 @@ accumulateEnv (t:ts) = case t of
|
||||||
|
|
||||||
|
|
||||||
_ -> throwError Oop
|
_ -> throwError Oop
|
||||||
|
|
||||||
|
detectMutualRecursion :: [TL] -> Except TypeError ()
|
||||||
|
detectMutualRecursion = detectLoops . createGraph
|
||||||
|
|
||||||
|
detectLoops :: RefGraph -> Except TypeError ()
|
||||||
|
detectLoops g = case M.minViewWithKey g of -- idk a better way to just get a key
|
||||||
|
Nothing -> pure ()
|
||||||
|
Just ((i,_),_) -> detectLoop S.empty g i >>= detectLoops . foldr M.delete g -- remove all visited nodes
|
||||||
|
|
||||||
|
detectLoop :: Set Id -> RefGraph -> Id -> Except TypeError (Set Id)
|
||||||
|
detectLoop v g i =
|
||||||
|
if S.member i v
|
||||||
|
then throwError (MutuallyRecursive v) -- already visited, loop
|
||||||
|
else case M.lookup i g of
|
||||||
|
Nothing -> pure v -- this has already been removed, meaning this branch is safe
|
||||||
|
Just is -> if is == S.empty
|
||||||
|
then pure v -- leaf \/ Set not Traversable
|
||||||
|
else S.unions <$> detectLoop (S.insert i v) g <~> S.toList is
|
||||||
|
|
||||||
|
createGraph :: [TL] -> RefGraph
|
||||||
|
createGraph = M.fromList . mapMaybe createNode
|
||||||
|
|
||||||
|
createNode :: TL -> Maybe RefNode
|
||||||
|
createNode TypeDef{} = Nothing
|
||||||
|
createNode (VarDef _ i _ e) = Just (i, refs e)
|
||||||
|
where
|
||||||
|
refs :: Exp -> Set Id
|
||||||
|
refs = \case
|
||||||
|
Let _ [] e2 -> refs e2
|
||||||
|
Let p ((i,e1):ies) e2 -> refs e1 <> refs (Let p ies e2) \- i
|
||||||
|
|
||||||
|
Abs _ is e2 -> refs e2 \\ S.fromList is
|
||||||
|
App _ e1 e2s -> refs e1 <> foldMap refs e2s
|
||||||
|
Var _ i -> S.singleton i
|
||||||
|
|
|
@ -124,6 +124,7 @@ data TypeError
|
||||||
| InvalidConstructor
|
| InvalidConstructor
|
||||||
| ArityMismatch
|
| ArityMismatch
|
||||||
| AlreadyDefined Pos Id
|
| AlreadyDefined Pos Id
|
||||||
|
| MutuallyRecursive (Set Id)
|
||||||
deriving Show
|
deriving Show
|
||||||
|
|
||||||
type TypeEnv = Map Id PolyT
|
type TypeEnv = Map Id PolyT
|
||||||
|
@ -173,3 +174,6 @@ newtype Solve a = Solve { getSolve :: ReaderT Unifier (Except TypeError) a}
|
||||||
|
|
||||||
instance MonadFail Solve where
|
instance MonadFail Solve where
|
||||||
fail _ = throwError Oop
|
fail _ = throwError Oop
|
||||||
|
|
||||||
|
type RefNode = (Id, Set Id)
|
||||||
|
type RefGraph = Map Id (Set Id)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user