diff --git a/src/Misc.hs b/src/Misc.hs index 160928f..633eb14 100644 --- a/src/Misc.hs +++ b/src/Misc.hs @@ -12,10 +12,14 @@ import Prelude hiding (map) lookupDefault :: Ord k => a -> k -> Map k a -> a lookupDefault d k m = fromMaybe d (M.lookup k m) -infixl \\ +infixl 3 \\ (\\) :: Ord a => Set a -> Set a -> Set a (\\) = S.difference +infixl 3 \- +(\-) :: Ord a => Set a -> a -> Set a +(\-) = flip S.delete + infix 5 <~> (<~>) :: (Traversable t, Applicative f) => (a -> f b) -> t a -> f (t b) (<~>) = traverse diff --git a/src/Toplevel.hs b/src/Toplevel.hs index 23e2c33..981d5ce 100644 --- a/src/Toplevel.hs +++ b/src/Toplevel.hs @@ -1,27 +1,35 @@ +{-# LANGUAGE LambdaCase, TupleSections #-} module Toplevel (check) where import Control.Monad.Except hiding (guard) import Control.Monad.RWS hiding (guard) import Data.Map (Map) +import Data.Set (Set) import qualified Data.Map as M import qualified Data.Set as S +import Data.Maybe (mapMaybe) + import TC import TC.Helpers import Solve import Type - +import Misc import PostProcess import qualified Hm.Abs as H +import Prelude hiding (map) + check :: [H.Def] -> Either TypeError [TL] check defs = case runInfer M.empty (traverseInfer defs) of Left err -> throwError err Right (tls,_,cs) -> case runSolve (emptySubst, cs) of Left err -> throwError err - Right sub -> pure tls + Right sub -> case runExcept (detectMutualRecursion tls) of + Left err -> throwError err + Right () -> pure tls traverseInfer :: [H.Def] -> Infer [TL] @@ -61,3 +69,37 @@ accumulateEnv (t:ts) = case t of _ -> throwError Oop + +detectMutualRecursion :: [TL] -> Except TypeError () +detectMutualRecursion = detectLoops . createGraph + +detectLoops :: RefGraph -> Except TypeError () +detectLoops g = case M.minViewWithKey g of -- idk a better way to just get a key + Nothing -> pure () + Just ((i,_),_) -> detectLoop S.empty g i >>= detectLoops . foldr M.delete g -- remove all visited nodes + +detectLoop :: Set Id -> RefGraph -> Id -> Except TypeError (Set Id) +detectLoop v g i = + if S.member i v + then throwError (MutuallyRecursive v) -- already visited, loop + else case M.lookup i g of + Nothing -> pure v -- this has already been removed, meaning this branch is safe + Just is -> if is == S.empty + then pure v -- leaf \/ Set not Traversable + else S.unions <$> detectLoop (S.insert i v) g <~> S.toList is + +createGraph :: [TL] -> RefGraph +createGraph = M.fromList . mapMaybe createNode + +createNode :: TL -> Maybe RefNode +createNode TypeDef{} = Nothing +createNode (VarDef _ i _ e) = Just (i, refs e) + where + refs :: Exp -> Set Id + refs = \case + Let _ [] e2 -> refs e2 + Let p ((i,e1):ies) e2 -> refs e1 <> refs (Let p ies e2) \- i + + Abs _ is e2 -> refs e2 \\ S.fromList is + App _ e1 e2s -> refs e1 <> foldMap refs e2s + Var _ i -> S.singleton i diff --git a/src/Type.hs b/src/Type.hs index 117d011..129511d 100644 --- a/src/Type.hs +++ b/src/Type.hs @@ -124,6 +124,7 @@ data TypeError | InvalidConstructor | ArityMismatch | AlreadyDefined Pos Id + | MutuallyRecursive (Set Id) deriving Show type TypeEnv = Map Id PolyT @@ -173,3 +174,6 @@ newtype Solve a = Solve { getSolve :: ReaderT Unifier (Except TypeError) a} instance MonadFail Solve where fail _ = throwError Oop + +type RefNode = (Id, Set Id) +type RefGraph = Map Id (Set Id) diff --git a/test.hm b/test.hm index 4c1aff2..f44dd8a 100644 --- a/test.hm +++ b/test.hm @@ -51,3 +51,7 @@ eval : Expr → Nat -- here's some other functions isEven : Nat → Bool := rec[Nat] true not + +type bot + +absurd : bot -> A := rec[bot]