Skip to content

Commit

Permalink
Add explicit interface methods
Browse files Browse the repository at this point in the history
It is often useful to define type -> value mappings, and the standard
way to do that is through interfaces (type classes). However, because
all methods so far had implicit type parameters, any attempt to
associate e.g. an integer with a type was difficult: without explicit
type annotations it often ended up being ambiguous.

This patch allows specifying interface parameters between each method
name and the colon that begins its type annotations, with the mentioned
parameters becoming explicit type parameters of the generated methods.
This lets us remove the awkward `TypeVehicle` abstraction from prelude
and in the future should make it possible to define associated types.
For example, a `Manifold a` interface could have a method declared as
`TangentSpace a : Type`, which would later make it possible to mention
`TangentSpace MyDataType` in types.
  • Loading branch information
apaszke committed Oct 5, 2021
1 parent b80d02d commit 2ef2dcd
Show file tree
Hide file tree
Showing 10 changed files with 106 additions and 46 deletions.
18 changes: 5 additions & 13 deletions lib/prelude.dx
Original file line number Diff line number Diff line change
Expand Up @@ -676,38 +676,30 @@ def castPtr (ptr: Ptr a) : Ptr b =
(MkPtr rawPtr) = ptr
MkPtr rawPtr

-- Is there a better way to select the right instance for `storageSize`??
data TypeVehicle a = MkTypeVehicle
def typeVehicle (a:Type) : TypeVehicle a = MkTypeVehicle

interface Storable a
store : Ptr a -> a -> {IO} Unit
load : Ptr a -> {IO} a
storageSize_ : TypeVehicle a -> Int

def storageSize (a:Type) -> (d:Storable a) ?=> : Int =
tv : TypeVehicle a = MkTypeVehicle
storageSize_ tv
storageSize a : Int

instance Storable Word8
store = \(MkPtr ptr) x. %ptrStore ptr x
load = \(MkPtr ptr) . %ptrLoad ptr
storageSize_ = const 1
storageSize = 1

instance Storable Int32
store = \(MkPtr ptr) x. %ptrStore (internalCast %Int32Ptr ptr) x
load = \(MkPtr ptr) . %ptrLoad (internalCast %Int32Ptr ptr)
storageSize_ = const 4
storageSize = 4

instance Storable Float32
store = \(MkPtr ptr) x. %ptrStore (internalCast %Float32Ptr ptr) x
load = \(MkPtr ptr) . %ptrLoad (internalCast %Float32Ptr ptr)
storageSize_ = const 4
storageSize = 4

instance Storable (Ptr a)
store = \(MkPtr ptr) (MkPtr x). %ptrStore (internalCast %PtrPtr ptr) x
load = \(MkPtr ptr) . MkPtr $ %ptrLoad (internalCast %PtrPtr ptr)
storageSize_ = const 8 -- TODO: something more portable?
storageSize = 8 -- TODO: something more portable?

-- TODO: Storable instances for other types

Expand Down
4 changes: 2 additions & 2 deletions src/lib/Autodiff.hs
Original file line number Diff line number Diff line change
Expand Up @@ -417,15 +417,15 @@ tangentFunAsLambda m = do
DerivWrt activeVars effs <- ask
let hs = map (Bind . (:>TyKind) . effectRegion) effs
liftM (PairVal ans) $ lift $ do
buildNestedLam PureArrow hs \hVals -> do
buildNaryLam PureArrow (toNest hs) \hVals -> do
let hVarNames = map (\(Var (v:>_)) -> v) hVals
-- TODO: handle exception effect too
let effs' = zipWith (\(RWSEffect rws _) v -> RWSEffect rws v) effs hVarNames
-- want to use tangents here, not the original binders
let regionMap = newEnv (map ((:>()) . effectRegion) effs) hVals
-- TODO: Only bind tangents for free variables?
let activeVarBinders = map (Bind . fmap (tangentRefRegion regionMap)) $ envAsVars activeVars
buildNestedLam PureArrow activeVarBinders \activeVarArgs ->
buildNaryLam PureArrow (toNest activeVarBinders) \activeVarArgs ->
buildLam (Ignore UnitTy) (PlainArrow $ EffectRow (S.fromList effs') Nothing) \_ ->
runReaderT tanFun $ TangentEnv
(newEnv (envNames activeVars) activeVarArgs) hVarNames
Expand Down
32 changes: 16 additions & 16 deletions src/lib/Builder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ module Builder (emit, emitAnn, emitOp, buildDepEffLam, buildLamAux, buildPi,
app,
add, mul, sub, neg, div',
iadd, imul, isub, idiv, ilt, ieq,
fpow, flog, fLitLike, recGetHead, buildImplicitNaryLam, buildNaryLam,
fpow, flog, fLitLike, recGetHead, buildNaryLam,
select, substBuilder, substBuilderR, emitUnpack, getUnpacked,
fromPair, getFst, getSnd, getProj, getProjRef,
naryApp, appReduce, appTryReduce, buildAbs, buildAAbs, buildAAbsAux,
Expand Down Expand Up @@ -48,6 +48,7 @@ import Control.Monad.Reader
import Control.Monad.Writer hiding (Alt)
import Control.Monad.Identity
import Control.Monad.State.Strict
import Data.Functor ((<&>))
import Data.Foldable (toList)
import Data.List (elemIndex)
import Data.Maybe (fromJust)
Expand Down Expand Up @@ -261,34 +262,33 @@ emitSuperclass dataDef idx = do
getter <- makeSuperclassGetter dataDef idx
emitBinding $ SuperclassName dataDef idx getter

emitMethodType :: MonadBuilder m => ClassDefName -> Int -> m Name
emitMethodType classDef idx = do
getter <- makeMethodGetter classDef idx
emitMethodType :: MonadBuilder m => [Bool] -> ClassDefName -> Int -> m Name
emitMethodType explicit classDef idx = do
getter <- makeMethodGetter explicit classDef idx
emitBinding $ MethodName classDef idx getter

makeMethodGetter :: MonadBuilder m => ClassDefName -> Int -> m Atom
makeMethodGetter classDefName methodIdx = do
makeMethodGetter :: MonadBuilder m => [Bool] -> ClassDefName -> Int -> m Atom
makeMethodGetter explicit classDefName methodIdx = do
ClassDef def@(_, DataDef _ paramBs _) _ <- getClassDef classDefName
buildImplicitNaryLam paramBs \params -> do
let arrows = explicit <&> \case True -> PureArrow; False -> ImplicitArrow
let bs = toNest $ zip (toList paramBs) arrows
buildNestedLam bs \params -> do
buildLam (Bind ("d":> TypeCon def params)) ClassArrow \dict -> do
return $ getProjection [methodIdx] $ getProjection [1, 0] dict

makeSuperclassGetter :: MonadBuilder m => DataDefName -> Int -> m Atom
makeSuperclassGetter classDefName methodIdx = do
ClassDef def@(_, DataDef _ paramBs _) _ <- getClassDef classDefName
buildImplicitNaryLam paramBs \params -> do
buildNaryLam ImplicitArrow paramBs \params -> do
buildLam (Bind ("d":> TypeCon def params)) PureArrow \dict -> do
return $ getProjection [methodIdx] $ getProjection [0, 0] dict

buildImplicitNaryLam :: MonadBuilder m => (Nest Binder) -> ([Atom] -> m Atom) -> m Atom
buildImplicitNaryLam bs body = buildNaryLam ImplicitArrow bs body

buildNaryLam :: MonadBuilder m => Arrow -> (Nest Binder) -> ([Atom] -> m Atom) -> m Atom
buildNaryLam _ Empty body = body []
buildNaryLam arr (Nest b bs) body =
buildLam b arr \x -> do
bs' <- substBuilder (b@>SubstVal x) bs
buildImplicitNaryLam bs' \xs -> body $ x:xs
buildNaryLam arr bs' \xs -> body $ x:xs

recGetHead :: Label -> Atom -> Atom
recGetHead l x = do
Expand Down Expand Up @@ -544,10 +544,10 @@ buildNestedFor specs body = go specs []
go [] indices = body $ reverse indices
go ((d,b):t) indices = buildFor d b $ \i -> go t (i:indices)

buildNestedLam :: MonadBuilder m => Arrow -> [Binder] -> ([Atom] -> m Atom) -> m Atom
buildNestedLam _ [] f = f []
buildNestedLam arr (b:bs) f =
buildLam b arr \x -> buildNestedLam arr bs \xs -> f (x:xs)
buildNestedLam :: MonadBuilder m => Nest (Binder, Arrow) -> ([Atom] -> m Atom) -> m Atom
buildNestedLam Empty f = f []
buildNestedLam (Nest (b, arr) t) f =
buildLam b arr \x -> buildNestedLam t \xs -> f (x:xs)

tabGet :: MonadBuilder m => Atom -> Atom -> m Atom
tabGet tab idx = emit $ App tab idx
Expand Down
19 changes: 15 additions & 4 deletions src/lib/Inference.hs
Original file line number Diff line number Diff line change
Expand Up @@ -367,8 +367,19 @@ inferUDecl (UInterface paramBs superclasses methodTys className methodNames) = d
paramBs superclasses methodTys
className' <- withNameHint className $ emitClassDef classDef
mapM_ (emitSuperclass className') [0..(length superclasses - 1)]
methodNames' <- forM (enumerate (toList methodNames)) \(i, name) ->
withNameHint name $ emitMethodType className' i
methodNames' <- forM (enumerate $ zip (toList methodNames) methodTys) \(i, (name, ty)) -> do
paramVs <- forM (toList paramBs) \(UAnnBinder v _) -> case v of
UBind n -> return $ UInternalVar n
_ -> throw CompilerErr "Unexpected interface binder. Please open a bug report!"
explicits <- case uMethodExplicitBs ty of
[] -> return $ replicate (length paramBs) False
e | e == paramVs -> return $ replicate (length paramBs) True
e -> case unexpected of
[] -> throw CompilerErr "Permuted or incomplete explicit type binders are not supported yet."
(h:_) -> throw TypeErr $ "Explicit type binder `" ++ pprint h ++ "` in method " ++
pprint name ++ " is not a type parameter of its interface"
where unexpected = filter (not . (`elem` paramVs)) e
withNameHint name $ emitMethodType explicits className' i
return $ className @> Rename className'
<> newEnv methodNames (map Rename methodNames')
inferUDecl (UInstance argBinders ~(UInternalVar className) params methods maybeName) = do
Expand All @@ -391,11 +402,11 @@ inferDataDef (UDataDef (tyConName, paramBs) dataCons) =
return $ DataDef tyConName paramBs' dataCons'

inferInterfaceDataDef :: SourceName -> [SourceName] -> Nest UAnnBinder
-> [UType] -> [UType] -> UInferM ClassDef
-> [UType] -> [UMethodType] -> UInferM ClassDef
inferInterfaceDataDef className methodNames paramBs superclasses methods = do
dictDef <- withNestedBinders paramBs \paramBs' -> do
superclasses' <- mapM checkUType superclasses
methods' <- mapM checkUType methods
methods' <- mapM checkUType $ uMethodType <$> methods
let dictContents = PairTy (ProdTy superclasses') (ProdTy methods')
return $ DataDef className paramBs'
[DataConDef ("Mk"<>className) (Nest (Ignore dictContents) Empty)]
Expand Down
10 changes: 6 additions & 4 deletions src/lib/PPrint.hs
Original file line number Diff line number Diff line change
Expand Up @@ -684,11 +684,13 @@ instance Pretty UDecl where
align $ p ann <+> p b <+> "=" <> (nest 2 $ group $ line <> pLowest rhs)
pretty (UDataDefDecl (UDataDef bParams dataCons) bTyCon bDataCons) =
"data" <+> p bTyCon <+> p bParams
<+> "where" <> nest 2 (hardline <> prettyLines (zip (toList bDataCons) dataCons))
<+> "where" <> nest 2 (hardline <> prettyLines (zip (toList bDataCons) dataCons))
pretty (UInterface params superclasses methodTys interfaceName methodNames) =
let methods = [UAnnBinder b ty | (b, ty) <- zip (toList methodNames) methodTys]
in "interface" <+> p params <+> p superclasses <+> p interfaceName
<> hardline <> prettyLines methods
"interface" <+> p params <+> p superclasses <+> p interfaceName
<> hardline <> foldMap (<>hardline) methods
where
methods = [hsep (p <$> e) <+> p (UAnnBinder b ty) |
(b, UMethodType e ty) <- zip (toList methodNames) methodTys]
pretty (UInstance bs className params methods Nothing) =
"instance" <+> p bs <+> p className <+> p params <+> hardline <> prettyLines methods
pretty (UInstance bs className params methods (Just v)) =
Expand Down
2 changes: 1 addition & 1 deletion src/lib/Parallelize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ buildParallelBlock ablock@(ABlock decls result) = do
unflattenIndexBundle :: MonadBuilder m => [Var] -> Atom -> m Atom
unflattenIndexBundle [] arr = return arr
unflattenIndexBundle [_] arr = return arr
unflattenIndexBundle ivs arr = buildNestedLam TabArrow (fmap Bind ivs) $ app arr . fst . mkBundle
unflattenIndexBundle ivs arr = buildNaryLam TabArrow (toNest $ fmap Bind ivs) $ app arr . fst . mkBundle

type Loop = Abs Binder Block
data NestDecision = Emit | Split (Nest Decl) (Binder, Loop) (Nest Decl)
Expand Down
6 changes: 4 additions & 2 deletions src/lib/Parser.hs
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,9 @@ interfaceDef = do
(tyConName, tyConParams) <- tyConDef
(methodNames, methodTys) <- unzip <$> onePerLine do
v <- anyName
explicit <- many anyName
ty <- annot uType
return (fromString v, ty)
return (fromString v, UMethodType (USourceVar <$> explicit) ty)
let methodNames' = toNest methodNames
let tyConParams' = tyConParams
return $ UInterface tyConParams' superclasses methodTys (fromString tyConName) methodNames'
Expand Down Expand Up @@ -947,7 +948,8 @@ keyWord kw = lexeme $ try $ string s >> notFollowedBy nameTailChar
keyWordStrs :: [String]
keyWordStrs = ["def", "for", "for_", "rof", "rof_", "case", "of", "llam",
"Read", "Write", "Accum", "Except", "IO", "data", "interface",
"instance", "named-instance", "where", "if", "then", "else", "do", "view", "import"]
"instance", "named-instance", "where", "if", "then", "else",
"do", "view", "import"]

fieldLabel :: Lexer Label
fieldLabel = label "field label" $ lexeme $
Expand Down
4 changes: 4 additions & 0 deletions src/lib/SourceRename.hs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,10 @@ instance SourceRenamableB UDecl where
instanceName' <- mapM sourceRenameB instanceName
return $ UInstance conditions' className' params' methodDefs' instanceName'

instance SourceRenamableE UMethodType where
sourceRenameE (UMethodType expl ty) =
UMethodType <$> traverse sourceRenameE expl <*> sourceRenameE ty

sourceRenameUBinderNest :: Renamer m => (Name -> SourceNameDef)
-> Nest UBinder -> WithEnv RenameEnv m (Nest UBinder)
sourceRenameUBinderNest _ Empty = return Empty
Expand Down
9 changes: 5 additions & 4 deletions src/lib/Syntax.hs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ module Syntax (
monMapSingle, monMapLookup, Direction (..), Limit (..),
SourceName, SourceMap (..), UExpr, UExpr' (..), UType, UPatAnn (..),
UAnnBinder (..), UVar (..), UBinder (..), UMethodDef (..),
UMethodTypeDef, UPatAnnArrow (..), UVars,
UMethodType (..), UPatAnnArrow (..), UVars,
UPat, UPat' (..), SourceUModule (..), SourceNameDef (..), sourceNameDefName,
UModule (..), UDecl (..), UDataDef (..), UArrow, arrowEff,
UEffect, UEffectRow, UEffArrow,
Expand Down Expand Up @@ -283,7 +283,7 @@ data UDecl =
| UInterface
(Nest UAnnBinder) -- parameter binders
[UType] -- superclasses
[UType] -- method types
[UMethodType] -- method types
UBinder -- class name
(Nest UBinder) -- method names
| UInstance
Expand All @@ -299,7 +299,8 @@ type UEffect = EffectP UVar
type UEffectRow = EffectRowP UVar
type UEffArrow = ArrowP UEffectRow

type UMethodTypeDef = (UBinder, UType)
data UMethodType = UMethodType { uMethodExplicitBs :: [UVar], uMethodType :: UType }
deriving (Show, Generic)
data UMethodDef = UMethodDef UVar UExpr deriving (Show, Generic)

data UPatAnn = UPatAnn UPat (Maybe UType) deriving (Show, Generic)
Expand Down Expand Up @@ -823,7 +824,7 @@ instance HasUVars UDecl where
freeUVars (UDataDefDecl dataDef bTyCon bDataCons) =
freeUVars dataDef <> freeUVars (Abs bTyCon bDataCons)
freeUVars (UInterface paramBs superclasses methods _ _) =
freeUVars $ Abs paramBs (superclasses, methods)
freeUVars $ Abs paramBs (superclasses, uMethodType <$> methods)
freeUVars (UInstance bs className params methods _) =
freeUVars $ Abs bs ((className, params), methods)

Expand Down
48 changes: 48 additions & 0 deletions tests/eval-tests.dx
Original file line number Diff line number Diff line change
Expand Up @@ -880,3 +880,51 @@ def f6 (x:Int) : Int = f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ x
-- This will compile extremely slowly if non-inlining is broken
:p f6 0
> 100000

interface AssociatedInt a
value a : Int

instance AssociatedInt Int
value = 2

instance AssociatedInt Float
value = 4

value Int
> 2
value Float
> 4

interface AssociatedWithTwo a c
value2 a c : Int

instance AssociatedWithTwo Int Float
value2 = 8

value2 Int Float
> 8
value2 Float Int
> Type error:Couldn't synthesize a class dictionary for: (AssociatedWithTwo Float32 Int32)
>
> value2 Float Int
> ^^^^^^^^^^^^^^^^

-- TODO: This is a really bad error message
interface BadAssociatedName a
value2 c : Int
> Error: variable not in scope: c

-- Technically the two tests below are not incorrect, but we don't implement them yet.
interface AssociatedSubsetOfParams a c
badValue1 a : Int
> Compiler bug!
> Please report this at github.com/google-research/dex-lang/issues
>
> Permuted or incomplete explicit type binders are not supported yet.

interface AssociatedPermutedParams a c
badValue2 c a : Int
> Compiler bug!
> Please report this at github.com/google-research/dex-lang/issues
>
> Permuted or incomplete explicit type binders are not supported yet.

0 comments on commit 2ef2dcd

Please sign in to comment.