43 lines
1.2 KiB
Haskell
43 lines
1.2 KiB
Haskell
{-# LANGUAGE LambdaCase #-}
|
|
module Solve where
|
|
|
|
import Control.Monad.Reader
|
|
import Control.Monad.Except
|
|
|
|
import qualified Data.Map as M
|
|
import qualified Data.Set as S
|
|
|
|
import Type
|
|
|
|
-- | Unify two monotypes under the given constraint mode, returning the
-- resulting substitution together with any residual constraints.
--
-- * 'Unify' is symmetric: a type variable on either side may be bound.
-- * 'UnifyRight' only binds type variables appearing on the left-hand type;
--   a variable on the right facing a non-variable falls through to
--   'UnificationFailure'. For arrow types it additionally throws
--   'UnificationRight' when the argument and result unifications bind a
--   common variable.
--
-- Throws 'UnificationFailure' for incompatible types, and 'InfiniteType'
-- (via 'bind') when the occurs check fails.
unify :: CT -> MonoT -> MonoT -> Solve Unifier
unify _ t1 t2
  | t1 == t2 = pure emptyUnifier
unify d t1@(l1 `TArr` r1) t2@(l2 `TArr` r2) = do
  (s1, c1) <- unify d l1 l2
  -- The result types are unified only after s1 has been applied, so s2 is
  -- computed over types already rewritten by s1.
  (s2, c2) <- unify d (apply s1 r1) (apply s1 r2)
  -- Compose newest-first, matching how 'solver' composes @s1 <&> s0@:
  -- s2 was produced after s1, so it is the left operand.
  -- NOTE(review): assumes @a <&> b@ applies @a@ over @b@ (WYAH-style
  -- 'compose'); confirm against the definition in "Type".
  let composed = s2 <&> s1
      residual = c1 ++ c2
  case d of
    Unify -> pure (composed, residual)
    UnifyRight
      | M.null (M.intersection s1 s2) -> pure (composed, residual)
      | otherwise -> throwError (UnificationRight t1 t2)
unify _ (TVar i) t = bind i t
unify Unify t (TVar i) = bind i t
unify _ t1 t2 = throwError (UnificationFailure t1 t2)
|
|
|
|
|
|
-- | Bind a type variable to a monotype, producing a singleton substitution
-- and no residual constraints.
--
-- Binding a variable to itself yields 'emptyUnifier'. Binding it to a type
-- whose free variables include it (the occurs check) throws 'InfiniteType'.
bind :: Id -> MonoT -> Solve Unifier
bind var ty = case ty of
  TVar other
    | var == other -> pure emptyUnifier
  _
    | var `S.member` free ty -> throwError (InfiniteType var ty)
    | otherwise -> pure (M.singleton var ty, [])
|
|
|
|
-- | Solve the constraints held in the reader environment.
--
-- The environment pairs the substitution accumulated so far with the queue
-- of constraints still to process. Each constraint is unified under its own
-- mode, and its substitution is composed onto the accumulator. Constraints
-- produced by that unification are queued ahead of the remaining ones, which
-- are first rewritten with the new substitution. Returns the accumulated
-- substitution once the queue is empty.
solver :: Solve Subst
solver = do
  (acc, pending) <- ask
  case pending of
    [] -> pure acc
    (lhs, rhs, mode) : rest -> do
      (step, extra) <- unify mode lhs rhs
      let next = (step <&> acc, extra ++ apply step rest)
      local (const next) solver
|
|
|
|
-- | Run 'solver' from the given starting unifier (initial substitution plus
-- pending constraints). Returns either the 'TypeError' that aborted solving
-- or the final substitution.
runSolve :: Unifier -> Either TypeError Subst
runSolve initial = runExcept (runReaderT (getSolve solver) initial)
|