From 687b65cd4e966b4c1e3efc3c476604271e91981c Mon Sep 17 00:00:00 2001 From: depsterr Date: Fri, 28 Jan 2022 13:46:42 +0100 Subject: [PATCH] move to constrain generation -> solving model. TODO: move code between modules, clean up --- app/Main.hs | 16 +++- hm.cabal | 1 + src/Misc.hs | 5 ++ src/PostProcess.hs | 10 +-- src/Pretty.hs | 3 +- src/Solve.hs | 37 +++++++++ src/TC.hs | 183 ++++++++++----------------------------------- src/Type.hs | 95 ++++++++++++++++++++--- 8 files changed, 184 insertions(+), 166 deletions(-) create mode 100644 src/Solve.hs diff --git a/app/Main.hs b/app/Main.hs index 1926cab..c8d8c89 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -9,8 +9,11 @@ import Data.Text (Text) import qualified Data.Text.IO as T import qualified Data.Set as S +import qualified Data.Map as M -import TC (initialState, runCheck, infer, generalize) +import Type (initialState, emptySubst, apply) +import TC (runInfer, infer, generalize) +import Solve (runSolve) import PostProcess (expToExp, runProcess) import Pretty @@ -26,12 +29,17 @@ inferType s = case pExp ts of putStrLn "\nParse Successful!" putStrLn (printTree tree) - let action = runProcess (expToExp tree) S.empty >>= infer >>= generalize . snd - let result = fst (runCheck initialState action) + let action = runProcess (expToExp tree) S.empty >>= infer + let result = runInfer M.empty action case result of Left err -> print err - Right res -> T.putStrLn (pretty res) + Right (t,_,c) -> case runSolve (emptySubst, c) of + Left err -> print err + Right subst -> case runInfer M.empty (generalize (apply subst t)) of + Left err -> print err + Right (t,_,_) -> T.putStrLn (pretty t) + where ts = init (resolveLayout True (myLexer s)) showPosToken ((l,c),t) = concat [ show l, ":", show c, "\t", show t ] diff --git a/hm.cabal b/hm.cabal index 3a24acd..ca8ecd6 100644 --- a/hm.cabal +++ b/hm.cabal @@ -32,6 +32,7 @@ library , Misc , PostProcess , Pretty + , Solve other-modules: Hm.ErrM build-tool-depends: alex:alex >= 3.0, happy:happy >= 1.19.5 diff --git a/src/Misc.hs b/src/Misc.hs index 8098016..160928f 100644 --- a/src/Misc.hs +++ b/src/Misc.hs @@ -7,6 +7,8 @@ import qualified Data.Map as M import Data.Maybe (fromMaybe) +import Prelude hiding (map) + lookupDefault :: Ord k => a -> k -> Map k a -> a lookupDefault d k m = fromMaybe d (M.lookup k m) @@ -21,3 +23,6 @@ infix 5 <~> infixr 9 .: (.:) :: (c -> d) -> (a -> b -> c) -> a -> b -> d (.:) = (.) . (.) + +map :: Functor f => (a -> b) -> f a -> f b +map = fmap diff --git a/src/PostProcess.hs b/src/PostProcess.hs index 15fbc13..496f8ce 100644 --- a/src/PostProcess.hs +++ b/src/PostProcess.hs @@ -19,12 +19,12 @@ import TC import Prelude hiding (map) -- Type env for parsing type signatures -type Process = StateT (Set Id) Check +type Process = StateT (Set Id) Infer insertType :: Id -> Process () insertType i = get >>= put . S.insert i -runProcess :: Process a -> Set Id -> Check a +runProcess :: Process a -> Set Id -> Infer a runProcess = (fst <$>) .: runStateT postprocess :: [H.Def] -> Process [TL] @@ -38,10 +38,8 @@ addDef = \case -- add type before typesig to id params defToTL :: H.Def -> Process TL -defToTL (H.VarDef p i t e) = VarDef <$> lift (setPos p) <*> pure i <*> typeSigToPolyT t <*> expToExp e +defToTL (H.VarDef p i t e) = VarDef p i <$> typeSigToPolyT t <*> expToExp e defToTL (H.TypeDef p t ds) = do - _ <- lift (setPos p) - (i,_) <- typeSigToIdParams t let (Id s) = i @@ -72,7 +70,7 @@ defToTL (H.TypeDef p t ds) = do typeSigToIdParams :: H.TypeSig -> Process (Id, [Id]) -typeSigToIdParams = lift . setPos >=> \case +typeSigToIdParams = \case H.TypeFun{} -> throwError InvalidTypeDecl H.TypeApp{} -> throwError (Unimplemented "Type parameters") H.TypeVar _ i -> pure (i, []) diff --git a/src/Pretty.hs b/src/Pretty.hs index a92e083..7890363 100644 --- a/src/Pretty.hs +++ b/src/Pretty.hs @@ -8,7 +8,6 @@ import qualified Data.Map as M import Type import Data.List (sort) -import TC (initialState, apply, free) class Pretty a where pretty :: a -> Text @@ -39,7 +38,7 @@ instance Normalize MonoT where normalize t = apply (goS t) t go :: MonoT -> [(Id, Id)] -go t = zip (sort (S.toList (free t))) (variables initialState) +go t = zip (sort (S.toList (free t))) initialState goS :: MonoT -> Subst goS = M.fromList . map (\(x,y) -> (x, TVar y)) . go diff --git a/src/Solve.hs b/src/Solve.hs new file mode 100644 index 0000000..9d49f41 --- /dev/null +++ b/src/Solve.hs @@ -0,0 +1,37 @@ +{-# LANGUAGE LambdaCase #-} +module Solve where + +import Control.Monad.Reader +import Control.Monad.Except + +import qualified Data.Map as M +import qualified Data.Set as S + +import Type + +unify :: MonoT -> MonoT -> Solve Unifier +unify t1 t2 | t1 == t2 = pure emptyUnifier +unify (l1 `TArr` r1) (l2 `TArr` r2) = do + (s1,c1) <- unify l1 l2 + (s2,c2) <- unify (apply s1 r1) (apply s1 r2) + pure (s1 <&> s2, c1 ++ c2) + +unify (TVar i) t = bind i t +unify t (TVar i) = bind i t + +unify t1 t2 = throwError (UnificationFailure t1 t2) + +bind :: Id -> MonoT -> Solve Unifier +bind i1 (TVar i2) | i1 == i2 = pure emptyUnifier +bind i t | S.member i (free t) = throwError (InfiniteType i t) + | otherwise = pure (M.singleton i t, []) + +solver :: Solve Subst +solver = ask >>= \case + (subst,[]) -> pure subst + (s0, (t1, t2) : cs) -> do + (s1, c1) <- unify t1 t2 + local (const (s1 <&> s0, c1 ++ apply s1 cs)) solver + +runSolve :: Unifier -> Either TypeError Subst +runSolve = runExcept . runReaderT (getSolve solver) diff --git a/src/TC.hs b/src/TC.hs index 6eaafd0..62d1fb5 100644 --- a/src/TC.hs +++ b/src/TC.hs @@ -2,9 +2,9 @@ {-# LANGUAGE TupleSections, FlexibleInstances #-} module TC where -import Control.Monad.Reader hiding (guard) -import Control.Monad.State hiding (guard) -import Control.Monad.Except hiding (guard) +import Control.Monad.Identity hiding (guard) +import Control.Monad.Except hiding (guard) +import Control.Monad.RWS hiding (guard) import Data.Set (Set) import qualified Data.Set as S @@ -17,129 +17,40 @@ import Misc import Prelude hiding (map) -map :: Functor f => (a -> b) -> f a -> f b -map = fmap +runInfer :: TypeEnv -> Infer a -> Either TypeError (a, [Id], [Constraint]) +runInfer r = runIdentity . runExceptT . (\i -> runRWST i r initialState) . getInfer -runCheck :: CheckState -> Check a -> (Either TypeError a, CheckState) -runCheck s = (flip runState) s . runExceptT . getCheck - --- I'm still not quite sure how replicateM works, but in this instance it is --- used to generate a list of strings "a", "b" ... "z", "aa", "ab" ... so on --- --- Does it make sense to start with an empty state? -initialState :: CheckState -initialState = CS ([1..] >>= map (Id . T.pack) . flip replicateM ['a'..'z']) Nothing M.empty - -getVars :: Check [Id] -getVars = variables <$> get - -setVars :: [Id] -> Check () -setVars ids = get >>= \s -> put (CS ids (lastPos s) (typeEnv s)) - -getEnv :: Check TypeEnv -getEnv = typeEnv <$> get - -setEnv :: TypeEnv -> Check () -setEnv env = get >>= \s -> put (CS (variables s) (lastPos s) env) - -addEnv :: Id -> PolyT -> Check () -addEnv i p = getEnv >>= setEnv . M.insert i p - -localEnv :: TypeEnv -> Check a -> Check a -localEnv e m = getEnv >>= \o -> setEnv e >> m >>= \r -> setEnv o >> pure r - -localEnv' :: Check a -> Check a -localEnv' m = getEnv >>= \o -> m >>= \r -> setEnv o >> pure r - --- returns p again to allow chaining into lambdacase -setPos :: Positioned p => p -> Check p -setPos p = get >>= \s -> put (CS (variables s) (pos p) (typeEnv s)) >> pure p +localEnv :: Id -> PolyT -> Infer a -> Infer a +localEnv i t = local (M.insert i t) guard :: Applicative f => f () -> Bool -> f () guard _ True = pure () guard f False = f -class Substitutable a where - apply :: Subst -> a -> a -- ^ apply a substitution - free :: a -> Set Id -- ^ free type variables +uni :: MonoT -> MonoT -> Infer () +uni t1 t2 = tell [(t1, t2)] -instance Substitutable MonoT where - apply s = \case - TCon i -> TCon i - TVar i -> lookupDefault (TVar i) i s - (t1 `TArr` t2) -> apply s t1 `TArr` apply s t2 - - free = \case - TCon{} -> S.empty - TVar i -> S.singleton i - (t1 `TArr` t2) -> free t1 <> free t2 - -instance Substitutable PolyT where - apply s = \case - Forall as t -> Forall as (apply (foldr M.delete s as) t) - Mono t -> Mono (apply s t) - - free = \case - Forall as t -> free t \\ as - Mono t -> free t - -instance Substitutable TypeEnv where - apply s = map (apply s) - free = free . M.elems - -instance Substitutable a => Substitutable [a] where - apply = map . apply - free = foldMap free - -applyEnv :: Subst -> Check () -applyEnv s = getEnv >>= setEnv . apply s - --- This substution, and that one -(<&>) :: Subst -> Subst -> Subst -(<&>) s1 s2 = map (apply s1) s2 <> s1 - -emptySubst :: Subst -emptySubst = M.empty - -unify :: MonoT -> MonoT -> Check Subst -unify (l1 `TArr` r1) (l2 `TArr` r2) = do - s1 <- unify l1 l2 - s2 <- unify (apply s1 r1) (apply s1 r2) - pure (s1 <&> s2) - -unify (TVar i) t = bind i t -unify t (TVar i) = bind i t - -unify (TCon i1) (TCon i2) | i1 == i2 = pure emptySubst - -unify t1 t2 = throwError (UnificationFailure t1 t2) - -bind :: Id -> MonoT -> Check Subst -bind i1 (TVar i2) | i1 == i2 = pure emptySubst -bind i t | S.member i (free t) = throwError (InfiniteType i t) - | otherwise = pure (M.singleton i t) - -fresh :: Check MonoT +fresh :: Infer MonoT fresh = do - (var:vars) <- getVars - setVars vars + (var:vars) <- get + put vars pure (TVar var) -- replace polymorphic type variables with monomorphic ones -instantiate :: PolyT -> Check MonoT +instantiate :: PolyT -> Infer MonoT instantiate (Mono t) = pure t instantiate (Forall is t) = foldM freshInsert emptySubst is >>= pure . (flip apply) t where - freshInsert :: Subst -> Id -> Check Subst + freshInsert :: Subst -> Id -> Infer Subst freshInsert s k = (\a -> M.insert k a s) <$> fresh -generalize :: MonoT -> Check PolyT -generalize t = getEnv >>= \env -> pure (Forall (free t \\ free env) t) +generalize :: MonoT -> Infer PolyT +generalize t = ask >>= \env -> pure (Forall (free t \\ free env) t) -lookupType :: Id -> Check MonoT -lookupType i = getEnv >>= \env -> +lookupType :: Pos -> Id -> Infer MonoT +lookupType p i = ask >>= \env -> case M.lookup i env of - Nothing -> throwError (UnboundVariable i) + Nothing -> throwError (UnboundVariable p i) Just t -> instantiate t constructs :: Id -> MonoT -> Bool @@ -147,52 +58,38 @@ constructs i (TArr _ t) = constructs i t constructs i1 (TCon i2) = i1 == i2 constructs _ _ = False -infer :: Exp -> Check (Subst, MonoT) -infer = setPos >=> \case +infer :: Exp -> Infer MonoT +infer = \case - Var _ i -> (emptySubst,) <$> lookupType i + Var p i -> lookupType p i Let _ [] e -> infer e Let p ((i,e1):ies) e2 -> do - (s1, t1) <- infer e1 - apply s1 <$> getEnv >>= \e -> localEnv e $ do - t1g <- generalize t1 - addEnv i t1g - (s2, t2) <- infer (Let p ies e2) - pure (s2 <&> s1, t2) + t1 <- generalize =<< infer e1 + localEnv i t1 (infer (Let p ies e2)) Abs _ [] e -> infer e - Abs p (i:is) e -> localEnv' $ do + Abs p (i:is) e -> do tv <- fresh - addEnv i (Forall S.empty tv) - (s, t) <- infer (Abs p is e) - pure (s, apply s tv `TArr` t) + t <- localEnv i (Forall S.empty tv) (infer (Abs p is e)) + pure (tv `TArr` t) - App _ e es -> go e (reverse es) + App p e es -> go p e (reverse es) where - go :: Exp -> [Exp] -> Check (Subst, MonoT) - go _ [] = throwError Oop - go e1 [e2] = localEnv' $ do - - (s1, t1) <- infer e1 - applyEnv s1 - - (s2, t2) <- infer e2 - + go :: Pos -> Exp -> [Exp] -> Infer MonoT + go _ _ [] = throwError Oop + go p e1 [e2] = do + t1 <- infer e1 + t2 <- infer e2 tv <- fresh + uni t1 (t2 `TArr` tv) - s3 <- unify (apply s2 t1) (t2 `TArr` tv) - pure (s3 <&> s2 <&> s1, apply s3 tv) - - go e1 (e2:es) = localEnv' $ do - - (s1, t1) <- go e1 es - applyEnv s1 - - (s2, t2) <- infer e2 + pure tv + go p e1 (e2:es) = do + t1 <- go p e1 es + t2 <- infer e2 tv <- fresh + uni t1 (t2 `TArr` tv) - s3 <- unify (apply s2 t1) (t2 `TArr` tv) - - pure (s3 <&> s2 <&> s1, apply s3 tv) + pure tv diff --git a/src/Type.hs b/src/Type.hs index 3ef2ff6..8208ba8 100644 --- a/src/Type.hs +++ b/src/Type.hs @@ -5,17 +5,25 @@ module Type , Id(..) ) where +import Control.Monad.RWS import Control.Monad.Reader -import Control.Monad.State import Control.Monad.Except import Data.Map (Map) import Data.Set (Set) import Data.Text (Text) +import qualified Data.Set as S +import qualified Data.Map as M +import qualified Data.Text as T + import Hm.Abs (Id(..)) -import qualified Hm.Abs as H (TypeSig'(..), TypeSig(..)) +import qualified Hm.Abs as H (TypeSig'(..), TypeSig) + +import Misc + +import Prelude hiding (map) data PolyT = Forall (Set Id) MonoT -- ^ ∀ σ₁ σ₂ … σₙ. τ @@ -27,7 +35,7 @@ data MonoT = TArr MonoT MonoT -- ^ function | TVar Id -- ^ variable | TCon Id -- ^ constant - deriving Show + deriving (Eq, Show) type Pos = Maybe (Int, Int) @@ -69,11 +77,48 @@ instance Positioned H.TypeSig where H.TypeApp p _ _ -> p H.TypeVar p _ -> p +class Substitutable a where + apply :: Subst -> a -> a -- ^ apply a substitution + free :: a -> Set Id -- ^ free type variables + +instance Substitutable MonoT where + apply s = \case + TCon i -> TCon i + TVar i -> lookupDefault (TVar i) i s + (t1 `TArr` t2) -> apply s t1 `TArr` apply s t2 + + free = \case + TCon{} -> S.empty + TVar i -> S.singleton i + (t1 `TArr` t2) -> free t1 <> free t2 + +instance Substitutable PolyT where + apply s = \case + Forall as t -> Forall as (apply (foldr M.delete s as) t) + Mono t -> Mono (apply s t) + + free = \case + Forall as t -> free t \\ as + Mono t -> free t + +instance Substitutable TypeEnv where + apply s = map (apply s) + free = free . M.elems + +instance Substitutable a => Substitutable [a] where + apply = map . apply + free = foldMap free + +instance (Substitutable a, Substitutable b) => Substitutable (a, b) where + apply s (a, b) = (apply s a, apply s b) + free (a, b) = free a <> free b + + data TypeError = Oop -- ^ compiler error (oops) | UnificationFailure MonoT MonoT | InfiniteType Id MonoT - | UnboundVariable Id + | UnboundVariable Pos Id | Unimplemented Text | InvalidTypeDecl | InvalidConstructor @@ -84,13 +129,41 @@ type TypeEnv = Map Id PolyT type Subst = Map Id MonoT -data CheckState = CS { variables :: [Id] - , lastPos :: Pos - , typeEnv :: TypeEnv - } deriving Show +emptySubst :: Subst +emptySubst = M.empty -newtype Check a = Check { getCheck :: ExceptT TypeError (State CheckState) a } - deriving (Functor, Applicative, Monad, MonadError TypeError, MonadState CheckState) +-- This substution, and that one +(<&>) :: Subst -> Subst -> Subst +(<&>) s1 s2 = map (apply s1) s2 <> s1 -instance MonadFail Check where +type Constraint = (MonoT, MonoT) + +type CheckState = [Id] + +initialState :: [Id] +initialState = [1..] >>= map (Id . T.pack) . flip replicateM ['a'..'z'] + +newtype Infer a = Infer { getInfer :: RWST TypeEnv [Constraint] CheckState (Except TypeError) a } + deriving ( Functor, Applicative, Monad + , MonadError TypeError + , MonadState CheckState + , MonadReader TypeEnv + , MonadWriter [Constraint] + ) + +instance MonadFail Infer where + fail _ = throwError Oop + +type Unifier = (Subst, [Constraint]) + +emptyUnifier :: Unifier +emptyUnifier = (emptySubst, []) + +newtype Solve a = Solve { getSolve :: ReaderT Unifier (Except TypeError) a} + deriving ( Functor, Applicative, Monad + , MonadError TypeError + , MonadReader Unifier + ) + +instance MonadFail Solve where fail _ = throwError Oop