hm/src/Solve.hs

43 lines
1.2 KiB
Haskell

{-# LANGUAGE LambdaCase #-}
module Solve where
import Control.Monad.Reader
import Control.Monad.Except
import qualified Data.Map as M
import qualified Data.Set as S
import Type
-- | Unify two monotypes under the given mode, returning the resulting
-- substitution together with any deferred constraints.
--
-- * 'Unify' may bind a type variable appearing on either side.
-- * 'UnifyRight' binds a type variable only when it appears as the
--   left-hand type (the @t (TVar i)@ equation is restricted to 'Unify').
--   For arrow types it additionally rejects the case where the argument
--   and result substitutions bind a common variable.
--
-- Throws 'UnificationFailure' for structurally incompatible types,
-- 'UnificationRight' when the 'UnifyRight' overlap check fails, and
-- 'InfiniteType' (via 'bind') when the occurs check fails.
unify :: CT -> MonoT -> MonoT -> Solve Unifier
unify _ t1 t2
  | t1 == t2 = pure emptyUnifier
unify d t1@(l1 `TArr` r1) t2@(l2 `TArr` r2) = do
  (s1, c1) <- unify d l1 l2
  (s2, c2) <- unify d (apply s1 r1) (apply s1 r2)
  -- s2 was solved against types already rewritten by s1, so it must be
  -- applied after s1: compose newest-on-the-left, matching the
  -- @s1 <&> s0@ convention used by 'solver'.
  -- NOTE(review): assumes '<&>' applies its right operand first, as
  -- 'solver' relies on; confirm against its definition in "Type".
  let combined = (s2 <&> s1, c1 ++ c2)
  case d of
    Unify -> pure combined
    UnifyRight
      | M.null (M.intersection s1 s2) -> pure combined
      | otherwise -> throwError (UnificationRight t1 t2)
unify _ (TVar i) t = bind i t
unify Unify t (TVar i) = bind i t
unify _ t1 t2 = throwError (UnificationFailure t1 t2)
-- | Bind a type variable to a monotype.
--
-- Binding a variable to itself yields 'emptyUnifier'. Binding it to a type
-- in which it occurs free fails with 'InfiniteType' (the occurs check).
-- Otherwise the result is the singleton substitution @var |-> ty@ with no
-- deferred constraints.
bind :: Id -> MonoT -> Solve Unifier
bind var ty = case ty of
  TVar other
    | var == other -> pure emptyUnifier
  _
    | var `S.member` free ty -> throwError (InfiniteType var ty)
    | otherwise -> pure (M.singleton var ty, [])
-- | Solve the constraints held in the reader environment.
--
-- The environment pairs the substitution accumulated so far with the
-- constraints still to process. Each constraint @(lhs, rhs, mode)@ is
-- unified under its own mode. The resulting substitution is composed onto
-- the accumulator (newest on the left) and applied to the remaining
-- constraints before recursing with the updated environment. When no
-- constraints remain, the accumulated substitution is returned.
solver :: Solve Subst
solver = do
  (acc, pending) <- ask
  case pending of
    [] -> pure acc
    (lhs, rhs, mode) : rest -> do
      (sub, extra) <- unify mode lhs rhs
      let next = (sub <&> acc, extra ++ apply sub rest)
      local (\_ -> next) solver
-- | Run 'solver' from an initial unifier (starting substitution plus the
-- constraint list). Returns either the first 'TypeError' thrown or the
-- final substitution.
runSolve :: Unifier -> Either TypeError Subst
runSolve initial = runExcept (runReaderT (getSolve solver) initial)