Skip to content

Commit

Permalink
Use a record environment in inference
Browse files Browse the repository at this point in the history
This is a non-functional change that should make it easier to thread
the names of builtin interfaces through the inference process.
  • Loading branch information
apaszke committed Sep 21, 2021
1 parent 43b30f9 commit 1ad4fe7
Showing 1 changed file with 22 additions and 15 deletions.
37 changes: 22 additions & 15 deletions src/lib/Inference.hs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ import PPrint
import Cat
import Util

type UInferM = ReaderT SubstEnv (ReaderT SrcCtx ((BuilderT (SolverT (Either Err)))))
data UInferEnv = UInferEnv
{ inferSubst :: SubstEnv
, srcCtx :: SrcCtx
}
type UInferM = ReaderT UInferEnv (BuilderT (SolverT (Either Err)))

type SigmaType = Type -- may start with an implicit lambda
type RhoType = Type -- doesn't start with an implicit lambda
Expand All @@ -53,7 +57,7 @@ inferModule :: Bindings -> UModule -> Except Module
inferModule scope (UModule uDecl sourceMap) = do
(evaluated, ((bindings, synthCandidates), decls)) <- runUInferM mempty scope mempty do
substEnv <- inferUDecl uDecl
extendR substEnv $ substBuilderR $ EvaluatedModule mempty mempty sourceMap
substBuilder substEnv $ EvaluatedModule mempty mempty sourceMap
let evaluated' = addSynthCandidates evaluated synthCandidates
if requiresEvaluation
then return $ Module Typed decls evaluated'
Expand Down Expand Up @@ -86,7 +90,7 @@ runUInferM :: (HasVars a, Subst a, Pretty a)
=> SubstEnv -> Scope -> SynthCandidates
-> UInferM a -> Except (a, ((Scope, SynthCandidates), Nest Decl))
runUInferM env scope scs m = runSolverT do
runBuilderT scope scs $ runReaderT (runReaderT m env) Nothing
runBuilderT scope scs $ runReaderT m $ UInferEnv env Nothing

checkSigma :: UExpr -> (Type -> RequiredTy Type) -> SigmaType -> UInferM Atom
checkSigma expr reqCon sTy = case sTy of
Expand Down Expand Up @@ -198,7 +202,7 @@ checkOrInferRho (WithSrc pos expr) reqTy = do
matchRequirement piTy
UDecl decl body -> do
env <- inferUDecl decl
extendR env $ checkOrInferRho body reqTy
extInferSubst env $ checkOrInferRho body reqTy
UCase scrut alts -> do
scrut' <- inferRho scrut
let scrutTy = getType scrut'
Expand Down Expand Up @@ -333,7 +337,7 @@ checkOrInferRho (WithSrc pos expr) reqTy = do
lookupUVar :: UVar -> UInferM Atom
lookupUVar (USourceVar _) = error "Shouldn't have source names left"
lookupUVar (UInternalVar v) = do
substEnv <- ask
substEnv <- getInferSubst
scope <- getScope
case envLookup substEnv v of
Nothing -> return $ fromJust $ nameToAtom scope v
Expand Down Expand Up @@ -403,7 +407,7 @@ withNestedBinders Empty cont = cont Empty
withNestedBinders (Nest (UAnnBinder b ty) rest) cont = do
ty' <- checkUType ty
withFreshName b ty' \x@(Var v) ->
extendR (b@>SubstVal x) $
extInferSubst (b@>SubstVal x) $
withNestedBinders rest \rest' ->
cont $ Nest (Bind v) rest'

Expand Down Expand Up @@ -442,7 +446,7 @@ checkInstance (Nest (UPatAnnArrow (UPatAnn p ann) arrow) rest) className params
buildLam (Bind $ patNameHint p :> argTy) (fromUArrow arrow) \(Var v) ->
checkLeaks [v] $ withBindPat p v $ checkInstance rest className params methods
checkInstance Empty className params methods = do
substEnv <- ask
substEnv <- getInferSubst
className' <- case envLookup substEnv className of
Nothing -> return className
Just (Rename className') -> return className'
Expand Down Expand Up @@ -480,7 +484,7 @@ checkUEffRow (EffectRow effs t) = do
lookupVarName :: Type -> UVar -> UInferM Name
lookupVarName ty ~(UInternalVar v) = do
-- TODO: more graceful errors on error
SubstVal (Var (v':>ty')) <- asks (!v)
SubstVal (Var (v':>ty')) <- (!v) <$> getInferSubst
constrainEq ty ty'
return v'

Expand Down Expand Up @@ -510,7 +514,7 @@ checkCaseAlt reqTy scrutineeTy (UAlt pat body) = do

lookupDataCon :: Name -> UInferM (NamedDataDef, Int)
lookupDataCon conName = do
substEnv <- ask
substEnv <- getInferSubst
conName' <- case envLookup substEnv conName of
Nothing -> return conName
Just (Rename v) -> return v
Expand Down Expand Up @@ -561,7 +565,7 @@ withBindPats pats body = foldr (uncurry withBindPat) body pats
withBindPat :: UPat -> Var -> UInferM a -> UInferM a
withBindPat pat var m = do
env <- bindPat pat $ Var var
extendR env m
extInferSubst env m

bindPat :: UPat -> Atom -> UInferM SubstEnv
bindPat (WithSrc pos pat) val = addSrcContext pos $ case pat of
Expand Down Expand Up @@ -740,13 +744,16 @@ openEffectRow (EffectRow effs Nothing) = extendEffRow effs <$> freshEff
openEffectRow effRow = return effRow

addSrcContext' :: SrcCtx -> UInferM a -> UInferM a
addSrcContext' pos m = do
env <- ask
addSrcContext pos $ lift $
local (const pos) $ runReaderT m env
addSrcContext' pos = addSrcContext pos . local (\e -> e { srcCtx = pos })

getSrcCtx :: UInferM SrcCtx
getSrcCtx = lift ask
getSrcCtx = asks srcCtx

getInferSubst :: UInferM SubstEnv
getInferSubst = asks inferSubst

extInferSubst :: SubstEnv -> UInferM a -> UInferM a
extInferSubst ext = local (\e -> e { inferSubst = inferSubst e <> ext })

-- === typeclass dictionary synthesizer ===

Expand Down

0 comments on commit 1ad4fe7

Please sign in to comment.