added detection for mutual recursion

This commit is contained in:
Rachel Lambda Samuelsson 2022-01-29 14:01:25 +01:00
parent cf55e12391
commit b8336ed81d
4 changed files with 57 additions and 3 deletions

View File

@ -12,10 +12,14 @@ import Prelude hiding (map)
lookupDefault :: Ord k => a -> k -> Map k a -> a
lookupDefault d k m = fromMaybe d (M.lookup k m)
infixl \\
infixl 3 \\
(\\) :: Ord a => Set a -> Set a -> Set a
(\\) = S.difference
infixl 3 \-
(\-) :: Ord a => Set a -> a -> Set a
(\-) = flip S.delete
infix 5 <~>
(<~>) :: (Traversable t, Applicative f) => (a -> f b) -> t a -> f (t b)
(<~>) = traverse

View File

@ -1,27 +1,35 @@
{-# LANGUAGE LambdaCase, TupleSections #-}
module Toplevel (check) where
import Control.Monad.Except hiding (guard)
import Control.Monad.RWS hiding (guard)
import Data.Map (Map)
import Data.Set (Set)
import qualified Data.Map as M
import qualified Data.Set as S
import Data.Maybe (mapMaybe)
import TC
import TC.Helpers
import Solve
import Type
import Misc
import PostProcess
import qualified Hm.Abs as H
import Prelude hiding (map)
check :: [H.Def] -> Either TypeError [TL]
check defs = case runInfer M.empty (traverseInfer defs) of
Left err -> throwError err
Right (tls,_,cs) -> case runSolve (emptySubst, cs) of
Left err -> throwError err
Right sub -> pure tls
Right sub -> case runExcept (detectMutualRecursion tls) of
Left err -> throwError err
Right () -> pure tls
traverseInfer :: [H.Def] -> Infer [TL]
@ -61,3 +69,37 @@ accumulateEnv (t:ts) = case t of
_ -> throwError Oop
detectMutualRecursion :: [TL] -> Except TypeError ()
detectMutualRecursion = detectLoops . createGraph
detectLoops :: RefGraph -> Except TypeError ()
detectLoops g = case M.minViewWithKey g of -- idk a better way to just get a key
Nothing -> pure ()
Just ((i,_),_) -> detectLoop S.empty g i >>= detectLoops . foldr M.delete g -- remove all visited nodes
detectLoop :: Set Id -> RefGraph -> Id -> Except TypeError (Set Id)
detectLoop v g i =
if S.member i v
then throwError (MutuallyRecursive v) -- already visited, loop
else case M.lookup i g of
Nothing -> pure v -- this has already been removed, meaning this branch is safe
Just is -> if is == S.empty
then pure v -- leaf \/ Set not Traversable
else S.unions <$> detectLoop (S.insert i v) g <~> S.toList is
createGraph :: [TL] -> RefGraph
createGraph = M.fromList . mapMaybe createNode
createNode :: TL -> Maybe RefNode
createNode TypeDef{} = Nothing
createNode (VarDef _ i _ e) = Just (i, refs e)
where
refs :: Exp -> Set Id
refs = \case
Let _ [] e2 -> refs e2
Let p ((i,e1):ies) e2 -> refs e1 <> refs (Let p ies e2) \- i
Abs _ is e2 -> refs e2 \\ S.fromList is
App _ e1 e2s -> refs e1 <> foldMap refs e2s
Var _ i -> S.singleton i

View File

@ -124,6 +124,7 @@ data TypeError
| InvalidConstructor
| ArityMismatch
| AlreadyDefined Pos Id
| MutuallyRecursive (Set Id)
deriving Show
type TypeEnv = Map Id PolyT
@ -173,3 +174,6 @@ newtype Solve a = Solve { getSolve :: ReaderT Unifier (Except TypeError) a}
instance MonadFail Solve where
fail _ = throwError Oop
type RefNode = (Id, Set Id)
type RefGraph = Map Id (Set Id)

View File

@ -51,3 +51,7 @@ eval : Expr → Nat
-- here's some other functions
isEven : Nat → Bool
:= rec[Nat] true not
type bot
absurd : bot -> A := rec[bot]