-
Notifications
You must be signed in to change notification settings - Fork 256
/
Copy pathInfer.hs
284 lines (230 loc) · 7.8 KB
/
Infer.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Infer (
Constraint,
TypeError(..),
Subst(..),
inferTop,
constraintsExpr
) where
import Env
import Type
import Syntax
import Control.Monad.Except
import Control.Monad.State
import Control.Monad.RWS
import Control.Monad.Identity
import Data.List (nub)
import qualified Data.Map as Map
import qualified Data.Set as Set
-------------------------------------------------------------------------------
-- Classes
-------------------------------------------------------------------------------
-- | The inference monad.
--
-- * reader: the typing environment ('Env') in scope;
-- * writer: the equality constraints generated so far;
-- * state:  the fresh-variable counter ('InferState');
-- * errors: aborts with a 'TypeError'.
type Infer a = RWST Env [Constraint] InferState (Except TypeError) a
-- | Inference state: the number of fresh type variables handed out so far.
--
-- The field is strict so that bumping the counter on every 'fresh' call
-- does not accumulate a chain of unevaluated @count s + 1@ thunks in the
-- (lazy) RWST state.
data InferState = InferState { count :: !Int }
-- | Starting state for inference: no fresh variables allocated yet.
initInfer :: InferState
initInfer = InferState 0
-- | An equality constraint between two types, as emitted by 'uni'.
type Constraint = (Type, Type)

-- | Solver state: the substitution built so far paired with the
-- constraints still to be processed.
type Unifier = (Subst, [Constraint])

-- | Constraint solver monad: pure, failing with a 'TypeError'.
type Solve a = Except TypeError a
-- | A substitution: a finite map from type variables to the types that
-- replace them.
--
-- 'Semigroup' and 'Monoid' are derived through the underlying 'Map'
-- (GeneralizedNewtypeDeriving), so '<>' is a left-biased map union and
-- 'mempty' is the empty substitution. Note that '<>' is /not/ substitution
-- composition; use 'compose' for that. 'Semigroup' must be listed because
-- it is a superclass of 'Monoid' since base-4.11 (GHC 8.4).
newtype Subst = Subst (Map.Map TVar Type)
  deriving (Eq, Ord, Show, Semigroup, Monoid)
-- | Things that mention type variables: we can list their free type
-- variables and apply a substitution to them.
class Substitutable a where
  -- | Free type variables occurring in the value.
  ftv :: a -> Set.Set TVar
  -- | Replace the free type variables mapped by the substitution.
  apply :: Subst -> a -> a
-- | Substitution on monotypes: variables are looked up in the map (left
-- unchanged when absent), arrows are traversed structurally, and type
-- constructors are left alone.
instance Substitutable Type where
  apply sub@(Subst mapping) ty = case ty of
    TCon c -> TCon c
    TVar v -> Map.findWithDefault ty v mapping
    TArr dom cod -> TArr (apply sub dom) (apply sub cod)

  ftv ty = case ty of
    TCon _ -> Set.empty
    TVar v -> Set.singleton v
    TArr dom cod -> Set.union (ftv dom) (ftv cod)
-- | Substitution on type schemes: quantified variables are bound, so they
-- are removed from the substitution before it is applied to the body, and
-- they are excluded from the free-variable set.
instance Substitutable Scheme where
  apply (Subst mapping) (Forall bound body) = Forall bound (apply unbound body)
    where
      unbound = Subst (foldr Map.delete mapping bound)

  ftv (Forall bound body) = Set.difference (ftv body) (Set.fromList bound)
-- | A constraint is substituted (and scanned for free variables) on both
-- of its sides.
instance Substitutable Constraint where
  apply sub (lhs, rhs) = (apply sub lhs, apply sub rhs)
  ftv (lhs, rhs) = Set.union (ftv lhs) (ftv rhs)
-- | Lists are handled element-wise; their free variables are the union of
-- the elements' free variables.
instance Substitutable a => Substitutable [a] where
  apply sub xs = map (apply sub) xs
  ftv xs = Set.unions (map ftv xs)
-- | Substituting into an environment substitutes into every scheme it
-- binds; its free variables are those of all bound schemes.
instance Substitutable Env where
  apply sub (TypeEnv bindings) = TypeEnv (fmap (apply sub) bindings)
  ftv (TypeEnv bindings) = ftv (Map.elems bindings)
-- | Errors raised during constraint generation and solving.
data TypeError
  = UnificationFail Type Type
    -- ^ Two types of incompatible shape (raised by 'unifies').
  | InfiniteType TVar Type
    -- ^ Occurs-check failure: the variable occurs in the type it would be
    -- bound to (raised by 'bind').
  | UnboundVariable String
    -- ^ A variable missing from the typing environment (raised by
    -- 'lookupEnv').
  | Ambigious [Constraint]
    -- ^ Not constructed by any code in this module as shown. The
    -- misspelled name is kept because the constructor is exported.
  | UnificationMismatch [Type] [Type]
    -- ^ Type lists of different lengths (raised by 'unifyMany').
-------------------------------------------------------------------------------
-- Inference
-------------------------------------------------------------------------------
-- | Run an inference computation in the given environment, starting from
-- a fresh counter, returning the result and the constraints it emitted.
runInfer :: Env -> Infer Type -> Either TypeError (Type, [Constraint])
runInfer env m = runIdentity . runExceptT $ evalRWST m env initInfer
-- | Infer the principal, closed type scheme of an expression: generate
-- constraints, solve them, apply the solution, then generalize and
-- normalize.
inferExpr :: Env -> Expr -> Either TypeError Scheme
inferExpr env ex = do
  (ty, cs) <- runInfer env (infer ex)
  closeOver . flip apply ty <$> runSolve cs
-- | Like 'inferExpr', but also expose the intermediate artefacts: the
-- generated constraints, the solving substitution, the raw (unsubstituted)
-- inferred type, and the final closed scheme.
constraintsExpr :: Env -> Expr -> Either TypeError ([Constraint], Subst, Type, Scheme)
constraintsExpr env ex = do
  (rawTy, constraints) <- runInfer env (infer ex)
  solution <- runSolve constraints
  pure (constraints, solution, rawTy, closeOver (apply solution rawTy))
-- | Quantify over every free variable of a type (relative to the empty
-- environment) and rename the variables canonically.
closeOver :: Type -> Scheme
closeOver ty = normalize (generalize Env.empty ty)
-- | Record that two types must be equal. Nothing is solved here; the
-- constraint is written to the output for the solver.
uni :: Type -> Type -> Infer ()
uni lhs rhs = tell (pure (lhs, rhs))
-- | Run a computation with @x@ bound to @sc@, shadowing any existing
-- binding for @x@.
inEnv :: (Name, Scheme) -> Infer a -> Infer a
inEnv (x, sc) = local (\env -> remove env x `extend` (x, sc))
-- | Look up a variable and instantiate its scheme with fresh type
-- variables; fails with 'UnboundVariable' when it is not in scope.
lookupEnv :: Name -> Infer Type
lookupEnv x = do
  TypeEnv env <- ask
  maybe (throwError (UnboundVariable x)) instantiate (Map.lookup x env)
-- | Infinite supply of variable names: "a".."z", then "aa", "ab", ...,
-- i.e. all lowercase words ordered by length, then lexicographically.
letters :: [String]
letters = [ word | len <- [1 ..], word <- replicateM len ['a' .. 'z'] ]
-- | Allocate a new type variable, named after the current counter value
-- in 'letters', and advance the counter.
fresh :: Infer Type
fresh = do
  n <- gets count
  modify (\st -> st { count = n + 1 })
  return (TVar (TV (letters !! n)))
-- | Replace each quantified variable of a scheme with a fresh type
-- variable, in quantifier order.
instantiate :: Scheme -> Infer Type
instantiate (Forall vars ty) = do
  freshVars <- traverse (const fresh) vars
  let renaming = Subst (Map.fromList (zip vars freshVars))
  pure (apply renaming ty)
-- | Quantify over the type variables free in the type but not free in the
-- environment.
generalize :: Env -> Type -> Scheme
generalize env t = Forall (Set.toList quantified) t
  where
    quantified = Set.difference (ftv t) (ftv env)
-- | Curried types of the built-in binary operators. Arithmetic is
-- @Int -> Int -> Int@; equality is restricted to @Int -> Int -> Bool@.
ops :: Map.Map Binop Type
ops = Map.fromList
  [ (Add, arith)
  , (Mul, arith)
  , (Sub, arith)
  , (Eql, typeInt `TArr` (typeInt `TArr` typeBool))
  ]
  where
    arith = typeInt `TArr` (typeInt `TArr` typeInt)
-- | Generate a type for an expression, emitting equality constraints via
-- 'uni' rather than unifying eagerly. The returned type is only meaningful
-- after the emitted constraints have been solved.
infer :: Expr -> Infer Type
infer expr = case expr of
  Lit (LInt _) -> return typeInt
  Lit (LBool _) -> return typeBool

  Var x -> lookupEnv x

  Lam x e -> do
    tv <- fresh
    t <- inEnv (x, Forall [] tv) (infer e)
    return (tv `TArr` t)

  App e1 e2 -> do
    t1 <- infer e1
    t2 <- infer e2
    tv <- fresh
    uni t1 (t2 `TArr` tv)
    return tv

  Let x e1 e2 -> do
    env <- ask
    -- The binding's constraints must be solved *before* generalizing.
    -- Otherwise variables that those constraints pin to concrete types
    -- (e.g. the result variable of @(\y -> y) 1@) are wrongly quantified.
    -- 'listen' leaves c1 in the accumulated output, so the global solve
    -- still sees it.
    (t1, c1) <- listen (infer e1)
    sub <- either throwError return (runSolve c1)
    let sc = generalize (apply sub env) (apply sub t1)
    -- Infer the body under the substituted environment. This lets nested
    -- lets see which environment variables the solution has already fixed.
    local (apply sub) (inEnv (x, sc) (infer e2))

  Fix e1 -> do
    t1 <- infer e1
    tv <- fresh
    uni (tv `TArr` tv) t1
    return tv

  Op op e1 e2 -> do
    t1 <- infer e1
    t2 <- infer e2
    tv <- fresh
    let u1 = t1 `TArr` (t2 `TArr` tv)
        -- NOTE(review): partial; errors if 'ops' lacks an entry for op.
        u2 = ops Map.! op
    uni u1 u2
    return tv

  If cond tr fl -> do
    t1 <- infer cond
    t2 <- infer tr
    t3 <- infer fl
    uni t1 typeBool
    uni t2 t3
    return t2
-- | Infer a sequence of top-level definitions left to right. Each
-- definition's scheme is added to the environment used for the next;
-- stops at the first error.
inferTop :: Env -> [(String, Expr)] -> Either TypeError Env
inferTop env decls = case decls of
  [] -> Right env
  (name, ex) : rest ->
    inferExpr env ex >>= \sc -> inferTop (extend env (name, sc)) rest
-- | Rename the type variables of a scheme's body to "a", "b", ... in order
-- of first occurrence, and quantify over exactly those variables. The
-- scheme's original quantifier list is ignored.
normalize :: Scheme -> Scheme
normalize (Forall _ body) = Forall (map snd renaming) (rename body)
  where
    renaming = zip (nub (occurrences body)) (map TV letters)

    occurrences ty = case ty of
      TVar v -> [v]
      TArr a b -> occurrences a ++ occurrences b
      TCon _ -> []

    rename ty = case ty of
      TArr a b -> TArr (rename a) (rename b)
      TCon c -> TCon c
      TVar v ->
        -- Unreachable: 'renaming' covers every variable in the body.
        maybe (error "type variable not in signature") TVar (Prelude.lookup v renaming)
-------------------------------------------------------------------------------
-- Constraint Solver
-------------------------------------------------------------------------------
-- | The identity substitution: maps no variables.
emptySubst :: Subst
emptySubst = Subst Map.empty
-- | @s1 `compose` s2@ behaves like applying @s2@ first, then @s1@. @s1@ is
-- pushed into the range of @s2@; on a domain clash @s2@'s (updated)
-- binding wins.
compose :: Subst -> Subst -> Subst
compose s1@(Subst m1) (Subst m2) = Subst (Map.union (fmap (apply s1) m2) m1)
-- | Solve a list of constraints, starting from the empty substitution.
runSolve :: [Constraint] -> Either TypeError Subst
runSolve cs = runExcept (solver (emptySubst, cs))
-- | Pairwise-unify two type lists. Each unifier is threaded into the
-- remaining elements. Lists of different lengths fail with
-- 'UnificationMismatch'.
unifyMany :: [Type] -> [Type] -> Solve Subst
unifyMany lhs rhs = case (lhs, rhs) of
  ([], []) -> return emptySubst
  (t1 : ts1, t2 : ts2) -> do
    headSub <- unifies t1 t2
    tailSub <- unifyMany (apply headSub ts1) (apply headSub ts2)
    return (tailSub `compose` headSub)
  _ -> throwError $ UnificationMismatch lhs rhs
-- | Most general unifier of two types. Equal types need nothing; a
-- variable on either side is bound; arrows unify component-wise; anything
-- else is a 'UnificationFail'.
unifies :: Type -> Type -> Solve Subst
unifies t1 t2
  | t1 == t2 = return emptySubst
  | otherwise = case (t1, t2) of
      (TVar v, t) -> v `bind` t
      (t, TVar v) -> v `bind` t
      (TArr a b, TArr c d) -> unifyMany [a, b] [c, d]
      _ -> throwError $ UnificationFail t1 t2
-- | Unification solver: process constraints in order. Each step's unifier
-- is composed onto the accumulated substitution and applied to the
-- remaining constraints.
solver :: Unifier -> Solve Subst
solver (su, []) = return su
solver (su, (t1, t2) : rest) = do
  step <- unifies t1 t2
  solver (step `compose` su, apply step rest)
-- | Bind a type variable to a type. Binding a variable to itself is a
-- no-op; a binding that would create an infinite type fails with
-- 'InfiniteType'.
bind :: TVar -> Type -> Solve Subst
bind a t =
  if t == TVar a
    then return emptySubst
    else if occursCheck a t
      then throwError $ InfiniteType a t
      else return (Subst (Map.singleton a t))
-- | True when the variable occurs free in the value.
occursCheck :: Substitutable a => TVar -> a -> Bool
occursCheck a t = Set.member a (ftv t)