added detection for mutual recursion

This commit is contained in:
Rachel Lambda Samuelsson 2022-01-29 14:01:25 +01:00
parent cf55e12391
commit b8336ed81d
4 changed files with 57 additions and 3 deletions

View File

@ -12,10 +12,14 @@ import Prelude hiding (map)
lookupDefault :: Ord k => a -> k -> Map k a -> a lookupDefault :: Ord k => a -> k -> Map k a -> a
lookupDefault d k m = fromMaybe d (M.lookup k m) lookupDefault d k m = fromMaybe d (M.lookup k m)
infixl \\ infixl 3 \\
(\\) :: Ord a => Set a -> Set a -> Set a (\\) :: Ord a => Set a -> Set a -> Set a
(\\) = S.difference (\\) = S.difference
infixl 3 \-
(\-) :: Ord a => Set a -> a -> Set a
(\-) = flip S.delete
infix 5 <~> infix 5 <~>
(<~>) :: (Traversable t, Applicative f) => (a -> f b) -> t a -> f (t b) (<~>) :: (Traversable t, Applicative f) => (a -> f b) -> t a -> f (t b)
(<~>) = traverse (<~>) = traverse

View File

@ -1,27 +1,35 @@
{-# LANGUAGE LambdaCase, TupleSections #-}
module Toplevel (check) where module Toplevel (check) where
import Control.Monad.Except hiding (guard) import Control.Monad.Except hiding (guard)
import Control.Monad.RWS hiding (guard) import Control.Monad.RWS hiding (guard)
import Data.Map (Map) import Data.Map (Map)
import Data.Set (Set)
import qualified Data.Map as M import qualified Data.Map as M
import qualified Data.Set as S import qualified Data.Set as S
import Data.Maybe (mapMaybe)
import TC import TC
import TC.Helpers import TC.Helpers
import Solve import Solve
import Type import Type
import Misc
import PostProcess import PostProcess
import qualified Hm.Abs as H import qualified Hm.Abs as H
import Prelude hiding (map)
check :: [H.Def] -> Either TypeError [TL] check :: [H.Def] -> Either TypeError [TL]
check defs = case runInfer M.empty (traverseInfer defs) of check defs = case runInfer M.empty (traverseInfer defs) of
Left err -> throwError err Left err -> throwError err
Right (tls,_,cs) -> case runSolve (emptySubst, cs) of Right (tls,_,cs) -> case runSolve (emptySubst, cs) of
Left err -> throwError err Left err -> throwError err
Right sub -> pure tls Right sub -> case runExcept (detectMutualRecursion tls) of
Left err -> throwError err
Right () -> pure tls
traverseInfer :: [H.Def] -> Infer [TL] traverseInfer :: [H.Def] -> Infer [TL]
@ -61,3 +69,37 @@ accumulateEnv (t:ts) = case t of
_ -> throwError Oop _ -> throwError Oop
detectMutualRecursion :: [TL] -> Except TypeError ()
detectMutualRecursion = detectLoops . createGraph
detectLoops :: RefGraph -> Except TypeError ()
detectLoops g = case M.minViewWithKey g of -- idk a better way to just get a key
Nothing -> pure ()
Just ((i,_),_) -> detectLoop S.empty g i >>= detectLoops . foldr M.delete g -- remove all visited nodes
detectLoop :: Set Id -> RefGraph -> Id -> Except TypeError (Set Id)
detectLoop v g i =
if S.member i v
then throwError (MutuallyRecursive v) -- already visited, loop
else case M.lookup i g of
Nothing -> pure v -- this has already been removed, meaning this branch is safe
Just is -> if is == S.empty
then pure v -- leaf \/ Set not Traversable
else S.unions <$> detectLoop (S.insert i v) g <~> S.toList is
createGraph :: [TL] -> RefGraph
createGraph = M.fromList . mapMaybe createNode
createNode :: TL -> Maybe RefNode
createNode TypeDef{} = Nothing
createNode (VarDef _ i _ e) = Just (i, refs e)
where
refs :: Exp -> Set Id
refs = \case
Let _ [] e2 -> refs e2
Let p ((i,e1):ies) e2 -> refs e1 <> refs (Let p ies e2) \- i
Abs _ is e2 -> refs e2 \\ S.fromList is
App _ e1 e2s -> refs e1 <> foldMap refs e2s
Var _ i -> S.singleton i

View File

@ -124,6 +124,7 @@ data TypeError
| InvalidConstructor | InvalidConstructor
| ArityMismatch | ArityMismatch
| AlreadyDefined Pos Id | AlreadyDefined Pos Id
| MutuallyRecursive (Set Id)
deriving Show deriving Show
type TypeEnv = Map Id PolyT type TypeEnv = Map Id PolyT
@ -173,3 +174,6 @@ newtype Solve a = Solve { getSolve :: ReaderT Unifier (Except TypeError) a}
instance MonadFail Solve where instance MonadFail Solve where
fail _ = throwError Oop fail _ = throwError Oop
type RefNode = (Id, Set Id)
type RefGraph = Map Id (Set Id)

View File

@ -51,3 +51,7 @@ eval : Expr → Nat
-- here's some other functions -- here's some other functions
isEven : Nat → Bool isEven : Nat → Bool
:= rec[Nat] true not := rec[Nat] true not
type bot
absurd : bot -> A := rec[bot]