move to constrain generation -> solving model. TODO: move code between modules, clean up

This commit is contained in:
Rachel Lambda Samuelsson 2022-01-28 13:46:42 +01:00
parent 413f7d3a21
commit 687b65cd4e
8 changed files with 184 additions and 166 deletions

View File

@ -9,8 +9,11 @@ import Data.Text (Text)
import qualified Data.Text.IO as T import qualified Data.Text.IO as T
import qualified Data.Set as S import qualified Data.Set as S
import qualified Data.Map as M
import TC (initialState, runCheck, infer, generalize) import Type (initialState, emptySubst, apply)
import TC (runInfer, infer, generalize)
import Solve (runSolve)
import PostProcess (expToExp, runProcess) import PostProcess (expToExp, runProcess)
import Pretty import Pretty
@ -26,12 +29,17 @@ inferType s = case pExp ts of
putStrLn "\nParse Successful!" putStrLn "\nParse Successful!"
putStrLn (printTree tree) putStrLn (printTree tree)
let action = runProcess (expToExp tree) S.empty >>= infer >>= generalize . snd let action = runProcess (expToExp tree) S.empty >>= infer
let result = fst (runCheck initialState action) let result = runInfer M.empty action
case result of case result of
Left err -> print err Left err -> print err
Right res -> T.putStrLn (pretty res) Right (t,_,c) -> case runSolve (emptySubst, c) of
Left err -> print err
Right subst -> case runInfer M.empty (generalize (apply subst t)) of
Left err -> print err
Right (t,_,_) -> T.putStrLn (pretty t)
where where
ts = init (resolveLayout True (myLexer s)) ts = init (resolveLayout True (myLexer s))
showPosToken ((l,c),t) = concat [ show l, ":", show c, "\t", show t ] showPosToken ((l,c),t) = concat [ show l, ":", show c, "\t", show t ]

View File

@ -32,6 +32,7 @@ library
, Misc , Misc
, PostProcess , PostProcess
, Pretty , Pretty
, Solve
other-modules: Hm.ErrM other-modules: Hm.ErrM
build-tool-depends: alex:alex >= 3.0, happy:happy >= 1.19.5 build-tool-depends: alex:alex >= 3.0, happy:happy >= 1.19.5

View File

@ -7,6 +7,8 @@ import qualified Data.Map as M
import Data.Maybe (fromMaybe) import Data.Maybe (fromMaybe)
import Prelude hiding (map)
lookupDefault :: Ord k => a -> k -> Map k a -> a lookupDefault :: Ord k => a -> k -> Map k a -> a
lookupDefault d k m = fromMaybe d (M.lookup k m) lookupDefault d k m = fromMaybe d (M.lookup k m)
@ -21,3 +23,6 @@ infix 5 <~>
infixr 9 .: infixr 9 .:
(.:) :: (c -> d) -> (a -> b -> c) -> a -> b -> d (.:) :: (c -> d) -> (a -> b -> c) -> a -> b -> d
(.:) = (.) . (.) (.:) = (.) . (.)
map :: Functor f => (a -> b) -> f a -> f b
map = fmap

View File

@ -19,12 +19,12 @@ import TC
import Prelude hiding (map) import Prelude hiding (map)
-- Type env for parsing type signatures -- Type env for parsing type signatures
type Process = StateT (Set Id) Check type Process = StateT (Set Id) Infer
insertType :: Id -> Process () insertType :: Id -> Process ()
insertType i = get >>= put . S.insert i insertType i = get >>= put . S.insert i
runProcess :: Process a -> Set Id -> Check a runProcess :: Process a -> Set Id -> Infer a
runProcess = (fst <$>) .: runStateT runProcess = (fst <$>) .: runStateT
postprocess :: [H.Def] -> Process [TL] postprocess :: [H.Def] -> Process [TL]
@ -38,10 +38,8 @@ addDef = \case
-- add type before typesig to id params -- add type before typesig to id params
defToTL :: H.Def -> Process TL defToTL :: H.Def -> Process TL
defToTL (H.VarDef p i t e) = VarDef <$> lift (setPos p) <*> pure i <*> typeSigToPolyT t <*> expToExp e defToTL (H.VarDef p i t e) = VarDef p i <$> typeSigToPolyT t <*> expToExp e
defToTL (H.TypeDef p t ds) = do defToTL (H.TypeDef p t ds) = do
_ <- lift (setPos p)
(i,_) <- typeSigToIdParams t (i,_) <- typeSigToIdParams t
let (Id s) = i let (Id s) = i
@ -72,7 +70,7 @@ defToTL (H.TypeDef p t ds) = do
typeSigToIdParams :: H.TypeSig -> Process (Id, [Id]) typeSigToIdParams :: H.TypeSig -> Process (Id, [Id])
typeSigToIdParams = lift . setPos >=> \case typeSigToIdParams = \case
H.TypeFun{} -> throwError InvalidTypeDecl H.TypeFun{} -> throwError InvalidTypeDecl
H.TypeApp{} -> throwError (Unimplemented "Type parameters") H.TypeApp{} -> throwError (Unimplemented "Type parameters")
H.TypeVar _ i -> pure (i, []) H.TypeVar _ i -> pure (i, [])

View File

@ -8,7 +8,6 @@ import qualified Data.Map as M
import Type import Type
import Data.List (sort) import Data.List (sort)
import TC (initialState, apply, free)
class Pretty a where class Pretty a where
pretty :: a -> Text pretty :: a -> Text
@ -39,7 +38,7 @@ instance Normalize MonoT where
normalize t = apply (goS t) t normalize t = apply (goS t) t
go :: MonoT -> [(Id, Id)] go :: MonoT -> [(Id, Id)]
go t = zip (sort (S.toList (free t))) (variables initialState) go t = zip (sort (S.toList (free t))) initialState
goS :: MonoT -> Subst goS :: MonoT -> Subst
goS = M.fromList . map (\(x,y) -> (x, TVar y)) . go goS = M.fromList . map (\(x,y) -> (x, TVar y)) . go

37
src/Solve.hs Normal file
View File

@ -0,0 +1,37 @@
{-# LANGUAGE LambdaCase #-}
module Solve where
import Control.Monad.Reader
import Control.Monad.Except
import qualified Data.Map as M
import qualified Data.Set as S
import Type
unify :: MonoT -> MonoT -> Solve Unifier
unify t1 t2 | t1 == t2 = pure emptyUnifier
unify (l1 `TArr` r1) (l2 `TArr` r2) = do
(s1,c1) <- unify l1 l2
(s2,c2) <- unify (apply s1 r1) (apply s1 r2)
pure (s1 <&> s2, c1 ++ c2)
unify (TVar i) t = bind i t
unify t (TVar i) = bind i t
unify t1 t2 = throwError (UnificationFailure t1 t2)
bind :: Id -> MonoT -> Solve Unifier
bind i1 (TVar i2) | i1 == i2 = pure emptyUnifier
bind i t | S.member i (free t) = throwError (InfiniteType i t)
| otherwise = pure (M.singleton i t, [])
solver :: Solve Subst
solver = ask >>= \case
(subst,[]) -> pure subst
(s0, (t1, t2) : cs) -> do
(s1, c1) <- unify t1 t2
local (const (s1 <&> s0, c1 ++ apply s1 cs)) solver
runSolve :: Unifier -> Either TypeError Subst
runSolve = runExcept . runReaderT (getSolve solver)

181
src/TC.hs
View File

@ -2,9 +2,9 @@
{-# LANGUAGE TupleSections, FlexibleInstances #-} {-# LANGUAGE TupleSections, FlexibleInstances #-}
module TC where module TC where
import Control.Monad.Reader hiding (guard) import Control.Monad.Identity hiding (guard)
import Control.Monad.State hiding (guard)
import Control.Monad.Except hiding (guard) import Control.Monad.Except hiding (guard)
import Control.Monad.RWS hiding (guard)
import Data.Set (Set) import Data.Set (Set)
import qualified Data.Set as S import qualified Data.Set as S
@ -17,129 +17,40 @@ import Misc
import Prelude hiding (map) import Prelude hiding (map)
map :: Functor f => (a -> b) -> f a -> f b runInfer :: TypeEnv -> Infer a -> Either TypeError (a, [Id], [Constraint])
map = fmap runInfer r = runIdentity . runExceptT . (\i -> runRWST i r initialState) . getInfer
runCheck :: CheckState -> Check a -> (Either TypeError a, CheckState) localEnv :: Id -> PolyT -> Infer a -> Infer a
runCheck s = (flip runState) s . runExceptT . getCheck localEnv i t = local (M.insert i t)
-- I'm still not quite sure how replicateM works, but in this instance it is
-- used to generate a list of strings "a", "b" ... "z", "aa", "ab" ... so on
--
-- Does it make sense to start with an empty state?
initialState :: CheckState
initialState = CS ([1..] >>= map (Id . T.pack) . flip replicateM ['a'..'z']) Nothing M.empty
getVars :: Check [Id]
getVars = variables <$> get
setVars :: [Id] -> Check ()
setVars ids = get >>= \s -> put (CS ids (lastPos s) (typeEnv s))
getEnv :: Check TypeEnv
getEnv = typeEnv <$> get
setEnv :: TypeEnv -> Check ()
setEnv env = get >>= \s -> put (CS (variables s) (lastPos s) env)
addEnv :: Id -> PolyT -> Check ()
addEnv i p = getEnv >>= setEnv . M.insert i p
localEnv :: TypeEnv -> Check a -> Check a
localEnv e m = getEnv >>= \o -> setEnv e >> m >>= \r -> setEnv o >> pure r
localEnv' :: Check a -> Check a
localEnv' m = getEnv >>= \o -> m >>= \r -> setEnv o >> pure r
-- returns p again to allow chaining into lambdacase
setPos :: Positioned p => p -> Check p
setPos p = get >>= \s -> put (CS (variables s) (pos p) (typeEnv s)) >> pure p
guard :: Applicative f => f () -> Bool -> f () guard :: Applicative f => f () -> Bool -> f ()
guard _ True = pure () guard _ True = pure ()
guard f False = f guard f False = f
class Substitutable a where uni :: MonoT -> MonoT -> Infer ()
apply :: Subst -> a -> a -- ^ apply a substitution uni t1 t2 = tell [(t1, t2)]
free :: a -> Set Id -- ^ free type variables
instance Substitutable MonoT where fresh :: Infer MonoT
apply s = \case
TCon i -> TCon i
TVar i -> lookupDefault (TVar i) i s
(t1 `TArr` t2) -> apply s t1 `TArr` apply s t2
free = \case
TCon{} -> S.empty
TVar i -> S.singleton i
(t1 `TArr` t2) -> free t1 <> free t2
instance Substitutable PolyT where
apply s = \case
Forall as t -> Forall as (apply (foldr M.delete s as) t)
Mono t -> Mono (apply s t)
free = \case
Forall as t -> free t \\ as
Mono t -> free t
instance Substitutable TypeEnv where
apply s = map (apply s)
free = free . M.elems
instance Substitutable a => Substitutable [a] where
apply = map . apply
free = foldMap free
applyEnv :: Subst -> Check ()
applyEnv s = getEnv >>= setEnv . apply s
-- This substution, and that one
(<&>) :: Subst -> Subst -> Subst
(<&>) s1 s2 = map (apply s1) s2 <> s1
emptySubst :: Subst
emptySubst = M.empty
unify :: MonoT -> MonoT -> Check Subst
unify (l1 `TArr` r1) (l2 `TArr` r2) = do
s1 <- unify l1 l2
s2 <- unify (apply s1 r1) (apply s1 r2)
pure (s1 <&> s2)
unify (TVar i) t = bind i t
unify t (TVar i) = bind i t
unify (TCon i1) (TCon i2) | i1 == i2 = pure emptySubst
unify t1 t2 = throwError (UnificationFailure t1 t2)
bind :: Id -> MonoT -> Check Subst
bind i1 (TVar i2) | i1 == i2 = pure emptySubst
bind i t | S.member i (free t) = throwError (InfiniteType i t)
| otherwise = pure (M.singleton i t)
fresh :: Check MonoT
fresh = do fresh = do
(var:vars) <- getVars (var:vars) <- get
setVars vars put vars
pure (TVar var) pure (TVar var)
-- replace polymorphic type variables with monomorphic ones -- replace polymorphic type variables with monomorphic ones
instantiate :: PolyT -> Check MonoT instantiate :: PolyT -> Infer MonoT
instantiate (Mono t) = pure t instantiate (Mono t) = pure t
instantiate (Forall is t) = foldM freshInsert emptySubst is >>= pure . (flip apply) t instantiate (Forall is t) = foldM freshInsert emptySubst is >>= pure . (flip apply) t
where where
freshInsert :: Subst -> Id -> Check Subst freshInsert :: Subst -> Id -> Infer Subst
freshInsert s k = (\a -> M.insert k a s) <$> fresh freshInsert s k = (\a -> M.insert k a s) <$> fresh
generalize :: MonoT -> Check PolyT generalize :: MonoT -> Infer PolyT
generalize t = getEnv >>= \env -> pure (Forall (free t \\ free env) t) generalize t = ask >>= \env -> pure (Forall (free t \\ free env) t)
lookupType :: Id -> Check MonoT lookupType :: Pos -> Id -> Infer MonoT
lookupType i = getEnv >>= \env -> lookupType p i = ask >>= \env ->
case M.lookup i env of case M.lookup i env of
Nothing -> throwError (UnboundVariable i) Nothing -> throwError (UnboundVariable p i)
Just t -> instantiate t Just t -> instantiate t
constructs :: Id -> MonoT -> Bool constructs :: Id -> MonoT -> Bool
@ -147,52 +58,38 @@ constructs i (TArr _ t) = constructs i t
constructs i1 (TCon i2) = i1 == i2 constructs i1 (TCon i2) = i1 == i2
constructs _ _ = False constructs _ _ = False
infer :: Exp -> Check (Subst, MonoT) infer :: Exp -> Infer MonoT
infer = setPos >=> \case infer = \case
Var _ i -> (emptySubst,) <$> lookupType i Var p i -> lookupType p i
Let _ [] e -> infer e Let _ [] e -> infer e
Let p ((i,e1):ies) e2 -> do Let p ((i,e1):ies) e2 -> do
(s1, t1) <- infer e1 t1 <- generalize =<< infer e1
apply s1 <$> getEnv >>= \e -> localEnv e $ do localEnv i t1 (infer (Let p ies e2))
t1g <- generalize t1
addEnv i t1g
(s2, t2) <- infer (Let p ies e2)
pure (s2 <&> s1, t2)
Abs _ [] e -> infer e Abs _ [] e -> infer e
Abs p (i:is) e -> localEnv' $ do Abs p (i:is) e -> do
tv <- fresh tv <- fresh
addEnv i (Forall S.empty tv) t <- localEnv i (Forall S.empty tv) (infer (Abs p is e))
(s, t) <- infer (Abs p is e) pure (tv `TArr` t)
pure (s, apply s tv `TArr` t)
App _ e es -> go e (reverse es) App p e es -> go p e (reverse es)
where where
go :: Exp -> [Exp] -> Check (Subst, MonoT) go :: Pos -> Exp -> [Exp] -> Infer MonoT
go _ [] = throwError Oop go _ _ [] = throwError Oop
go e1 [e2] = localEnv' $ do go p e1 [e2] = do
t1 <- infer e1
(s1, t1) <- infer e1 t2 <- infer e2
applyEnv s1
(s2, t2) <- infer e2
tv <- fresh tv <- fresh
uni t1 (t2 `TArr` tv)
s3 <- unify (apply s2 t1) (t2 `TArr` tv) pure tv
pure (s3 <&> s2 <&> s1, apply s3 tv)
go e1 (e2:es) = localEnv' $ do
(s1, t1) <- go e1 es
applyEnv s1
(s2, t2) <- infer e2
go p e1 (e2:es) = do
t1 <- go p e1 es
t2 <- infer e2
tv <- fresh tv <- fresh
uni t1 (t2 `TArr` tv)
s3 <- unify (apply s2 t1) (t2 `TArr` tv) pure tv
pure (s3 <&> s2 <&> s1, apply s3 tv)

View File

@ -5,17 +5,25 @@ module Type
, Id(..) , Id(..)
) where ) where
import Control.Monad.RWS
import Control.Monad.Reader import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Except import Control.Monad.Except
import Data.Map (Map) import Data.Map (Map)
import Data.Set (Set) import Data.Set (Set)
import Data.Text (Text) import Data.Text (Text)
import qualified Data.Set as S
import qualified Data.Map as M
import qualified Data.Text as T
import Hm.Abs (Id(..)) import Hm.Abs (Id(..))
import qualified Hm.Abs as H (TypeSig'(..), TypeSig(..)) import qualified Hm.Abs as H (TypeSig'(..), TypeSig)
import Misc
import Prelude hiding (map)
data PolyT data PolyT
= Forall (Set Id) MonoT -- ^ ∀ σ₁ σ₂ … σₙ. τ = Forall (Set Id) MonoT -- ^ ∀ σ₁ σ₂ … σₙ. τ
@ -27,7 +35,7 @@ data MonoT
= TArr MonoT MonoT -- ^ function = TArr MonoT MonoT -- ^ function
| TVar Id -- ^ variable | TVar Id -- ^ variable
| TCon Id -- ^ constant | TCon Id -- ^ constant
deriving Show deriving (Eq, Show)
type Pos = Maybe (Int, Int) type Pos = Maybe (Int, Int)
@ -69,11 +77,48 @@ instance Positioned H.TypeSig where
H.TypeApp p _ _ -> p H.TypeApp p _ _ -> p
H.TypeVar p _ -> p H.TypeVar p _ -> p
class Substitutable a where
apply :: Subst -> a -> a -- ^ apply a substitution
free :: a -> Set Id -- ^ free type variables
instance Substitutable MonoT where
apply s = \case
TCon i -> TCon i
TVar i -> lookupDefault (TVar i) i s
(t1 `TArr` t2) -> apply s t1 `TArr` apply s t2
free = \case
TCon{} -> S.empty
TVar i -> S.singleton i
(t1 `TArr` t2) -> free t1 <> free t2
instance Substitutable PolyT where
apply s = \case
Forall as t -> Forall as (apply (foldr M.delete s as) t)
Mono t -> Mono (apply s t)
free = \case
Forall as t -> free t \\ as
Mono t -> free t
instance Substitutable TypeEnv where
apply s = map (apply s)
free = free . M.elems
instance Substitutable a => Substitutable [a] where
apply = map . apply
free = foldMap free
instance (Substitutable a, Substitutable b) => Substitutable (a, b) where
apply s (a, b) = (apply s a, apply s b)
free (a, b) = free a <> free b
data TypeError data TypeError
= Oop -- ^ compiler error (oops) = Oop -- ^ compiler error (oops)
| UnificationFailure MonoT MonoT | UnificationFailure MonoT MonoT
| InfiniteType Id MonoT | InfiniteType Id MonoT
| UnboundVariable Id | UnboundVariable Pos Id
| Unimplemented Text | Unimplemented Text
| InvalidTypeDecl | InvalidTypeDecl
| InvalidConstructor | InvalidConstructor
@ -84,13 +129,41 @@ type TypeEnv = Map Id PolyT
type Subst = Map Id MonoT type Subst = Map Id MonoT
data CheckState = CS { variables :: [Id] emptySubst :: Subst
, lastPos :: Pos emptySubst = M.empty
, typeEnv :: TypeEnv
} deriving Show
newtype Check a = Check { getCheck :: ExceptT TypeError (State CheckState) a } -- This substution, and that one
deriving (Functor, Applicative, Monad, MonadError TypeError, MonadState CheckState) (<&>) :: Subst -> Subst -> Subst
(<&>) s1 s2 = map (apply s1) s2 <> s1
instance MonadFail Check where type Constraint = (MonoT, MonoT)
type CheckState = [Id]
initialState :: [Id]
initialState = [1..] >>= map (Id . T.pack) . flip replicateM ['a'..'z']
newtype Infer a = Infer { getInfer :: RWST TypeEnv [Constraint] CheckState (Except TypeError) a }
deriving ( Functor, Applicative, Monad
, MonadError TypeError
, MonadState CheckState
, MonadReader TypeEnv
, MonadWriter [Constraint]
)
instance MonadFail Infer where
fail _ = throwError Oop
type Unifier = (Subst, [Constraint])
emptyUnifier :: Unifier
emptyUnifier = (emptySubst, [])
newtype Solve a = Solve { getSolve :: ReaderT Unifier (Except TypeError) a}
deriving ( Functor, Applicative, Monad
, MonadError TypeError
, MonadReader Unifier
)
instance MonadFail Solve where
fail _ = throwError Oop fail _ = throwError Oop