{-# LANGUAGE LambdaCase #-}

-- | Constraint solving: unification of monotypes, and a solver that folds a
-- list of @(type, type, direction)@ constraints into a single substitution.
module Solve where

import Control.Monad.Except
import Control.Monad.Reader
import qualified Data.Map as M
import qualified Data.Set as S
import Type

-- | Unify two monotypes under the given direction.
--
-- * 'Unify': a type variable on either side may be bound.
-- * 'UnifyRight': a variable is only bound when it is the /first/ type; a
--   variable appearing only as the second type falls through to
--   'UnificationFailure'. For arrows, the substitutions produced for the
--   domains and the codomains must additionally have disjoint keys, otherwise
--   'UnificationRight' is thrown.
--
-- Returns the substitution together with the residual constraint list.
-- Throws 'UnificationFailure' for structurally incompatible types and
-- 'InfiniteType' (via 'bind') when the occurs check fails.
unify :: CT -> MonoT -> MonoT -> Solve Unifier
unify _ t1 t2
  | t1 == t2 = pure emptyUnifier
unify d t1@(l1 `TArr` r1) t2@(l2 `TArr` r2) = do
  (s1, c1) <- unify d l1 l2
  -- Codomains are unified with the domain substitution already applied.
  (s2, c2) <- unify d (apply s1 r1) (apply s1 r2)
  -- s2 is the newer substitution, so it goes on the left of '<&>', matching
  -- 'solver', which composes as @new <&> old@.
  -- NOTE(review): assumes '<&>' composes as @newer <&> older@ (apply the
  -- left substitution over the right one, then union) — confirm in Type.
  let combined = (s2 <&> s1, c1 ++ c2)
  case d of
    Unify -> pure combined
    UnifyRight
      | M.null (M.intersection s1 s2) -> pure combined
      | otherwise -> throwError (UnificationRight t1 t2)
unify _ (TVar i) t = bind i t
unify Unify t (TVar i) = bind i t
unify _ t1 t2 = throwError (UnificationFailure t1 t2)

-- | Bind a type variable to a monotype.
--
-- Binding a variable to itself yields 'emptyUnifier'. Otherwise the occurs
-- check rejects bindings where the variable is free in the target type
-- (throwing 'InfiniteType'); a successful binding is a singleton
-- substitution with no residual constraints.
bind :: Id -> MonoT -> Solve Unifier
bind i1 (TVar i2)
  | i1 == i2 = pure emptyUnifier
bind i t
  | S.member i (free t) = throwError (InfiniteType i t)
  | otherwise = pure (M.singleton i t, [])

-- | Solve the constraints held in the reader environment.
--
-- The environment is @(substitution so far, pending constraints)@. Each
-- step unifies the head constraint, composes the new substitution in front
-- of the accumulated one, applies it to the remaining constraints, and
-- recurses with the updated environment via 'local'. Stops with the
-- accumulated substitution once no constraints remain.
solver :: Solve Subst
solver =
  ask >>= \case
    (subst, []) -> pure subst
    (s0, (t1, t2, d) : cs) -> do
      (s1, c1) <- unify d t1 t2
      local (const (s1 <&> s0, c1 ++ apply s1 cs)) solver

-- | Run 'solver' starting from the given @(substitution, constraints)@ pair.
runSolve :: Unifier -> Either TypeError Subst
runSolve = runExcept . runReaderT (getSolve solver)