Skip to content

Commit

Permalink
Add explicit interface methods
Browse files Browse the repository at this point in the history
It is often useful to define type -> value mappings, and the standard
way to do that is through interfaces (type classes). However, because
all methods so far had implicit type parameters, any attempt to
associate e.g. an integer with a type was difficult: without explicit
type annotations it often ended up being ambiguous.

This patch adds a new keyword `explicit` that makes it possible to
require that all types parameterizing a type class are to be taken as
explicit arguments of a given method. This lets us remove the awkward
`TypeVehicle` abstraction from prelude and in the future should make it
possible to define associated types. For example, a `Manifold` interface
could have a method declared as `explicit TangentSpace : Type`, which
would later make it possible to mention e.g. `TangentSpace MyDataType`
in types.
  • Loading branch information
apaszke committed Oct 1, 2021
1 parent b80d02d commit a4a2a08
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 29 deletions.
18 changes: 5 additions & 13 deletions lib/prelude.dx
Original file line number Diff line number Diff line change
Expand Up @@ -676,38 +676,30 @@ def castPtr (ptr: Ptr a) : Ptr b =
(MkPtr rawPtr) = ptr
MkPtr rawPtr

-- Is there a better way to select the right instance for `storageSize`??
data TypeVehicle a = MkTypeVehicle
def typeVehicle (a:Type) : TypeVehicle a = MkTypeVehicle

interface Storable a
store : Ptr a -> a -> {IO} Unit
load : Ptr a -> {IO} a
storageSize_ : TypeVehicle a -> Int

def storageSize (a:Type) -> (d:Storable a) ?=> : Int =
tv : TypeVehicle a = MkTypeVehicle
storageSize_ tv
explicit storageSize : Int

instance Storable Word8
store = \(MkPtr ptr) x. %ptrStore ptr x
load = \(MkPtr ptr) . %ptrLoad ptr
storageSize_ = const 1
storageSize = 1

instance Storable Int32
store = \(MkPtr ptr) x. %ptrStore (internalCast %Int32Ptr ptr) x
load = \(MkPtr ptr) . %ptrLoad (internalCast %Int32Ptr ptr)
storageSize_ = const 4
storageSize = 4

instance Storable Float32
store = \(MkPtr ptr) x. %ptrStore (internalCast %Float32Ptr ptr) x
load = \(MkPtr ptr) . %ptrLoad (internalCast %Float32Ptr ptr)
storageSize_ = const 4
storageSize = 4

instance Storable (Ptr a)
store = \(MkPtr ptr) (MkPtr x). %ptrStore (internalCast %PtrPtr ptr) x
load = \(MkPtr ptr) . MkPtr $ %ptrLoad (internalCast %PtrPtr ptr)
storageSize_ = const 8 -- TODO: something more portable?
storageSize = 8 -- TODO: something more portable?

-- TODO: Storable instances for other types

Expand Down
13 changes: 7 additions & 6 deletions src/lib/Builder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -261,15 +261,16 @@ emitSuperclass dataDef idx = do
getter <- makeSuperclassGetter dataDef idx
emitBinding $ SuperclassName dataDef idx getter

emitMethodType :: MonadBuilder m => ClassDefName -> Int -> m Name
emitMethodType classDef idx = do
getter <- makeMethodGetter classDef idx
emitMethodType :: MonadBuilder m => MethodKind -> ClassDefName -> Int -> m Name
emitMethodType kind classDef idx = do
getter <- makeMethodGetter kind classDef idx
emitBinding $ MethodName classDef idx getter

makeMethodGetter :: MonadBuilder m => ClassDefName -> Int -> m Atom
makeMethodGetter classDefName methodIdx = do
makeMethodGetter :: MonadBuilder m => MethodKind -> ClassDefName -> Int -> m Atom
makeMethodGetter kind classDefName methodIdx = do
ClassDef def@(_, DataDef _ paramBs _) _ <- getClassDef classDefName
buildImplicitNaryLam paramBs \params -> do
let arrow = case kind of ImplicitMethod -> ImplicitArrow; ExplicitMethod -> PureArrow
buildNaryLam arrow paramBs \params -> do
buildLam (Bind ("d":> TypeCon def params)) ClassArrow \dict -> do
return $ getProjection [methodIdx] $ getProjection [1, 0] dict

Expand Down
8 changes: 4 additions & 4 deletions src/lib/Inference.hs
Original file line number Diff line number Diff line change
Expand Up @@ -367,8 +367,8 @@ inferUDecl (UInterface paramBs superclasses methodTys className methodNames) = d
paramBs superclasses methodTys
className' <- withNameHint className $ emitClassDef classDef
mapM_ (emitSuperclass className') [0..(length superclasses - 1)]
methodNames' <- forM (enumerate (toList methodNames)) \(i, name) ->
withNameHint name $ emitMethodType className' i
methodNames' <- forM (enumerate $ zip (toList methodNames) (fmap fst methodTys)) \(i, (name, k)) -> do
withNameHint name $ emitMethodType k className' i
return $ className @> Rename className'
<> newEnv methodNames (map Rename methodNames')
inferUDecl (UInstance argBinders ~(UInternalVar className) params methods maybeName) = do
Expand All @@ -391,11 +391,11 @@ inferDataDef (UDataDef (tyConName, paramBs) dataCons) =
return $ DataDef tyConName paramBs' dataCons'

inferInterfaceDataDef :: SourceName -> [SourceName] -> Nest UAnnBinder
-> [UType] -> [UType] -> UInferM ClassDef
-> [UType] -> [UMethodType] -> UInferM ClassDef
inferInterfaceDataDef className methodNames paramBs superclasses methods = do
dictDef <- withNestedBinders paramBs \paramBs' -> do
superclasses' <- mapM checkUType superclasses
methods' <- mapM checkUType methods
methods' <- mapM checkUType $ snd <$> methods
let dictContents = PairTy (ProdTy superclasses') (ProdTy methods')
return $ DataDef className paramBs'
[DataConDef ("Mk"<>className) (Nest (Ignore dictContents) Empty)]
Expand Down
8 changes: 6 additions & 2 deletions src/lib/PPrint.hs
Original file line number Diff line number Diff line change
Expand Up @@ -686,15 +686,19 @@ instance Pretty UDecl where
"data" <+> p bTyCon <+> p bParams
<+> "where" <> nest 2 (hardline <> prettyLines (zip (toList bDataCons) dataCons))
pretty (UInterface params superclasses methodTys interfaceName methodNames) =
let methods = [UAnnBinder b ty | (b, ty) <- zip (toList methodNames) methodTys]
let methods = [p k <+> p (UAnnBinder b ty) | (b, (k, ty)) <- zip (toList methodNames) methodTys]
in "interface" <+> p params <+> p superclasses <+> p interfaceName
<> hardline <> prettyLines methods
<> hardline <> foldMap (<>hardline) methods
pretty (UInstance bs className params methods Nothing) =
"instance" <+> p bs <+> p className <+> p params <+> hardline <> prettyLines methods
pretty (UInstance bs className params methods (Just v)) =
"named-instance" <+> p v <+> ":" <+> p bs <+> p className <+> p params
<> hardline <> prettyLines methods

instance Pretty MethodKind where
pretty ImplicitMethod = ""
pretty ExplicitMethod = "explicit"

instance Pretty UPatAnnArrow where
pretty (UPatAnnArrow b arr) = p b <> ":" <> p arr

Expand Down
8 changes: 6 additions & 2 deletions src/lib/Parser.hs
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,10 @@ interfaceDef = do
superclasses <- superclassConstraints
(tyConName, tyConParams) <- tyConDef
(methodNames, methodTys) <- unzip <$> onePerLine do
k <- (keyWord ExplicitKW $> ExplicitMethod) <|> pure ImplicitMethod
v <- anyName
ty <- annot uType
return (fromString v, ty)
return (fromString v, (k, ty))
let methodNames' = toNest methodNames
let tyConParams' = tyConParams
return $ UInterface tyConParams' superclasses methodTys (fromString tyConName) methodNames'
Expand Down Expand Up @@ -895,6 +896,7 @@ data KeyWord = DefKW | ForKW | For_KW | RofKW | Rof_KW | CaseKW | OfKW
| ReadKW | WriteKW | StateKW | DataKW | InterfaceKW
| InstanceKW | WhereKW | IfKW | ThenKW | ElseKW | DoKW
| ExceptKW | IOKW | ViewKW | ImportKW | NamedInstanceKW
| ExplicitKW

upperName :: Lexer SourceName
upperName = label "upper-case name" $ lexeme $
Expand Down Expand Up @@ -943,11 +945,13 @@ keyWord kw = lexeme $ try $ string s >> notFollowedBy nameTailChar
DoKW -> "do"
ViewKW -> "view"
ImportKW -> "import"
ExplicitKW -> "explicit"

keyWordStrs :: [String]
keyWordStrs = ["def", "for", "for_", "rof", "rof_", "case", "of", "llam",
"Read", "Write", "Accum", "Except", "IO", "data", "interface",
"instance", "named-instance", "where", "if", "then", "else", "do", "view", "import"]
"instance", "named-instance", "where", "if", "then", "else",
"do", "view", "import", "explicit"]

fieldLabel :: Lexer Label
fieldLabel = label "field label" $ lexeme $
Expand Down
3 changes: 3 additions & 0 deletions src/lib/SourceRename.hs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,9 @@ instance SourceRenamableB UDecl where
instanceName' <- mapM sourceRenameB instanceName
return $ UInstance conditions' className' params' methodDefs' instanceName'

instance SourceRenamableE MethodKind where
sourceRenameE = return

sourceRenameUBinderNest :: Renamer m => (Name -> SourceNameDef)
-> Nest UBinder -> WithEnv RenameEnv m (Nest UBinder)
sourceRenameUBinderNest _ Empty = return Empty
Expand Down
10 changes: 8 additions & 2 deletions src/lib/Syntax.hs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ module Syntax (
monMapSingle, monMapLookup, Direction (..), Limit (..),
SourceName, SourceMap (..), UExpr, UExpr' (..), UType, UPatAnn (..),
UAnnBinder (..), UVar (..), UBinder (..), UMethodDef (..),
UMethodTypeDef, UPatAnnArrow (..), UVars,
MethodKind (..), UMethodType, UMethodTypeDef, UPatAnnArrow (..), UVars,
UPat, UPat' (..), SourceUModule (..), SourceNameDef (..), sourceNameDefName,
UModule (..), UDecl (..), UDataDef (..), UArrow, arrowEff,
UEffect, UEffectRow, UEffArrow,
Expand Down Expand Up @@ -274,6 +274,9 @@ data UDataDef = UDataDef
[(SourceName, Nest UAnnBinder)] -- data constructor types
deriving (Show, Generic)

data MethodKind = ExplicitMethod | ImplicitMethod
deriving (Show, Generic)
type UMethodType = (MethodKind, UType) -- NOTE: Pairs are functors in second component
data UDecl =
ULet LetAnn UPatAnn UExpr
| UDataDefDecl
Expand All @@ -283,7 +286,7 @@ data UDecl =
| UInterface
(Nest UAnnBinder) -- parameter binders
[UType] -- superclasses
[UType] -- method types
[UMethodType] -- method types
UBinder -- class name
(Nest UBinder) -- method names
| UInstance
Expand Down Expand Up @@ -827,6 +830,9 @@ instance HasUVars UDecl where
freeUVars (UInstance bs className params methods _) =
freeUVars $ Abs bs ((className, params), methods)

instance HasUVars MethodKind where
freeUVars _ = mempty

instance (BindsUVars b1, HasUVars b1, HasUVars b2) => HasUVars (NestPair b1 b2) where
freeUVars (NestPair b1 b2) =
freeUVars b1 <> (freeUVars b2 `envDiff` boundUVars b1)
Expand Down

0 comments on commit a4a2a08

Please sign in to comment.