From 2ef2dcd8d81b8f9976a62c05b461cab5e116ef56 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 1 Oct 2021 11:52:39 +0000 Subject: [PATCH] Add explicit interface methods It is often useful to define type -> value mappings, and the standard way to do that is through interfaces (type classes). However, because all methods so far had implicit type parameters, any attempt to associate e.g. an integer with a type was difficult: without explicit type annotations it often ended up being ambiguous. This patch allows specifying interface parameters between each method name and the colon that begins its type annotations, with the mentioned parameters becoming explicit type parameters of the generated methods. This lets us remove the awkward `TypeVehicle` abstraction from prelude and in the future should make it possible to define associated types. For example, a `Manifold a` interface could have a method declared as `TangentSpace a : Type`, which would later make it possible to mention `TangentSpace MyDataType` in types. --- lib/prelude.dx | 18 +++++----------- src/lib/Autodiff.hs | 4 ++-- src/lib/Builder.hs | 32 +++++++++++++-------------- src/lib/Inference.hs | 19 ++++++++++++---- src/lib/PPrint.hs | 10 +++++---- src/lib/Parallelize.hs | 2 +- src/lib/Parser.hs | 6 ++++-- src/lib/SourceRename.hs | 4 ++++ src/lib/Syntax.hs | 9 ++++---- tests/eval-tests.dx | 48 +++++++++++++++++++++++++++++++++++++++++ 10 files changed, 106 insertions(+), 46 deletions(-) diff --git a/lib/prelude.dx b/lib/prelude.dx index af4f7155c..fa05fe1da 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -676,38 +676,30 @@ def castPtr (ptr: Ptr a) : Ptr b = (MkPtr rawPtr) = ptr MkPtr rawPtr --- Is there a better way to select the right instance for `storageSize`?? -data TypeVehicle a = MkTypeVehicle -def typeVehicle (a:Type) : TypeVehicle a = MkTypeVehicle - interface Storable a store : Ptr a -> a -> {IO} Unit load : Ptr a -> {IO} a - storageSize_ : TypeVehicle a -> Int - -def storageSize (a:Type) -> (d:Storable a) ?=> : Int = - tv : TypeVehicle a = MkTypeVehicle - storageSize_ tv + storageSize a : Int instance Storable Word8 store = \(MkPtr ptr) x. %ptrStore ptr x load = \(MkPtr ptr) . %ptrLoad ptr - storageSize_ = const 1 + storageSize = 1 instance Storable Int32 store = \(MkPtr ptr) x. %ptrStore (internalCast %Int32Ptr ptr) x load = \(MkPtr ptr) . %ptrLoad (internalCast %Int32Ptr ptr) - storageSize_ = const 4 + storageSize = 4 instance Storable Float32 store = \(MkPtr ptr) x. %ptrStore (internalCast %Float32Ptr ptr) x load = \(MkPtr ptr) . %ptrLoad (internalCast %Float32Ptr ptr) - storageSize_ = const 4 + storageSize = 4 instance Storable (Ptr a) store = \(MkPtr ptr) (MkPtr x). %ptrStore (internalCast %PtrPtr ptr) x load = \(MkPtr ptr) . MkPtr $ %ptrLoad (internalCast %PtrPtr ptr) - storageSize_ = const 8 -- TODO: something more portable? + storageSize = 8 -- TODO: something more portable? -- TODO: Storable instances for other types diff --git a/src/lib/Autodiff.hs b/src/lib/Autodiff.hs index 3524057c2..e33a5c35f 100644 --- a/src/lib/Autodiff.hs +++ b/src/lib/Autodiff.hs @@ -417,7 +417,7 @@ tangentFunAsLambda m = do DerivWrt activeVars effs <- ask let hs = map (Bind . (:>TyKind) . effectRegion) effs liftM (PairVal ans) $ lift $ do - buildNestedLam PureArrow hs \hVals -> do + buildNaryLam PureArrow (toNest hs) \hVals -> do let hVarNames = map (\(Var (v:>_)) -> v) hVals -- TODO: handle exception effect too let effs' = zipWith (\(RWSEffect rws _) v -> RWSEffect rws v) effs hVarNames @@ -425,7 +425,7 @@ tangentFunAsLambda m = do let regionMap = newEnv (map ((:>()) . effectRegion) effs) hVals -- TODO: Only bind tangents for free variables? let activeVarBinders = map (Bind . fmap (tangentRefRegion regionMap)) $ envAsVars activeVars - buildNestedLam PureArrow activeVarBinders \activeVarArgs -> + buildNaryLam PureArrow (toNest activeVarBinders) \activeVarArgs -> buildLam (Ignore UnitTy) (PlainArrow $ EffectRow (S.fromList effs') Nothing) \_ -> runReaderT tanFun $ TangentEnv (newEnv (envNames activeVars) activeVarArgs) hVarNames diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index 56df2e4cf..aa0b7585d 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -18,7 +18,7 @@ module Builder (emit, emitAnn, emitOp, buildDepEffLam, buildLamAux, buildPi, app, add, mul, sub, neg, div', iadd, imul, isub, idiv, ilt, ieq, - fpow, flog, fLitLike, recGetHead, buildImplicitNaryLam, buildNaryLam, + fpow, flog, fLitLike, recGetHead, buildNaryLam, select, substBuilder, substBuilderR, emitUnpack, getUnpacked, fromPair, getFst, getSnd, getProj, getProjRef, naryApp, appReduce, appTryReduce, buildAbs, buildAAbs, buildAAbsAux, @@ -48,6 +48,7 @@ import Control.Monad.Reader import Control.Monad.Writer hiding (Alt) import Control.Monad.Identity import Control.Monad.State.Strict +import Data.Functor ((<&>)) import Data.Foldable (toList) import Data.List (elemIndex) import Data.Maybe (fromJust) @@ -261,34 +262,33 @@ emitSuperclass dataDef idx = do getter <- makeSuperclassGetter dataDef idx emitBinding $ SuperclassName dataDef idx getter -emitMethodType :: MonadBuilder m => ClassDefName -> Int -> m Name -emitMethodType classDef idx = do - getter <- makeMethodGetter classDef idx +emitMethodType :: MonadBuilder m => [Bool] -> ClassDefName -> Int -> m Name +emitMethodType explicit classDef idx = do + getter <- makeMethodGetter explicit classDef idx emitBinding $ MethodName classDef idx getter -makeMethodGetter :: MonadBuilder m => ClassDefName -> Int -> m Atom -makeMethodGetter classDefName methodIdx = do +makeMethodGetter :: MonadBuilder m => [Bool] -> ClassDefName -> Int -> m Atom +makeMethodGetter explicit classDefName methodIdx = do ClassDef def@(_, DataDef _ paramBs _) _ <- getClassDef classDefName - buildImplicitNaryLam paramBs \params -> do + let arrows = explicit <&> \case True -> PureArrow; False -> ImplicitArrow + let bs = toNest $ zip (toList paramBs) arrows + buildNestedLam bs \params -> do buildLam (Bind ("d":> TypeCon def params)) ClassArrow \dict -> do return $ getProjection [methodIdx] $ getProjection [1, 0] dict makeSuperclassGetter :: MonadBuilder m => DataDefName -> Int -> m Atom makeSuperclassGetter classDefName methodIdx = do ClassDef def@(_, DataDef _ paramBs _) _ <- getClassDef classDefName - buildImplicitNaryLam paramBs \params -> do + buildNaryLam ImplicitArrow paramBs \params -> do buildLam (Bind ("d":> TypeCon def params)) PureArrow \dict -> do return $ getProjection [methodIdx] $ getProjection [0, 0] dict -buildImplicitNaryLam :: MonadBuilder m => (Nest Binder) -> ([Atom] -> m Atom) -> m Atom -buildImplicitNaryLam bs body = buildNaryLam ImplicitArrow bs body - buildNaryLam :: MonadBuilder m => Arrow -> (Nest Binder) -> ([Atom] -> m Atom) -> m Atom buildNaryLam _ Empty body = body [] buildNaryLam arr (Nest b bs) body = buildLam b arr \x -> do bs' <- substBuilder (b@>SubstVal x) bs - buildImplicitNaryLam bs' \xs -> body $ x:xs + buildNaryLam arr bs' \xs -> body $ x:xs recGetHead :: Label -> Atom -> Atom recGetHead l x = do @@ -544,10 +544,10 @@ buildNestedFor specs body = go specs [] go [] indices = body $ reverse indices go ((d,b):t) indices = buildFor d b $ \i -> go t (i:indices) -buildNestedLam :: MonadBuilder m => Arrow -> [Binder] -> ([Atom] -> m Atom) -> m Atom -buildNestedLam _ [] f = f [] -buildNestedLam arr (b:bs) f = - buildLam b arr \x -> buildNestedLam arr bs \xs -> f (x:xs) +buildNestedLam :: MonadBuilder m => Nest (Binder, Arrow) -> ([Atom] -> m Atom) -> m Atom +buildNestedLam Empty f = f [] +buildNestedLam (Nest (b, arr) t) f = + buildLam b arr \x -> buildNestedLam t \xs -> f (x:xs) tabGet :: MonadBuilder m => Atom -> Atom -> m Atom tabGet tab idx = emit $ App tab idx diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 3f9bc8b14..e4724e80c 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -367,8 +367,19 @@ inferUDecl (UInterface paramBs superclasses methodTys className methodNames) = d paramBs superclasses methodTys className' <- withNameHint className $ emitClassDef classDef mapM_ (emitSuperclass className') [0..(length superclasses - 1)] - methodNames' <- forM (enumerate (toList methodNames)) \(i, name) -> - withNameHint name $ emitMethodType className' i + methodNames' <- forM (enumerate $ zip (toList methodNames) methodTys) \(i, (name, ty)) -> do + paramVs <- forM (toList paramBs) \(UAnnBinder v _) -> case v of + UBind n -> return $ UInternalVar n + _ -> throw CompilerErr "Unexpected interface binder. Please open a bug report!" + explicits <- case uMethodExplicitBs ty of + [] -> return $ replicate (length paramBs) False + e | e == paramVs -> return $ replicate (length paramBs) True + e -> case unexpected of + [] -> throw CompilerErr "Permuted or incomplete explicit type binders are not supported yet." + (h:_) -> throw TypeErr $ "Explicit type binder `" ++ pprint h ++ "` in method " ++ + pprint name ++ " is not a type parameter of its interface" + where unexpected = filter (not . (`elem` paramVs)) e + withNameHint name $ emitMethodType explicits className' i return $ className @> Rename className' <> newEnv methodNames (map Rename methodNames') inferUDecl (UInstance argBinders ~(UInternalVar className) params methods maybeName) = do @@ -391,11 +402,11 @@ inferDataDef (UDataDef (tyConName, paramBs) dataCons) = return $ DataDef tyConName paramBs' dataCons' inferInterfaceDataDef :: SourceName -> [SourceName] -> Nest UAnnBinder - -> [UType] -> [UType] -> UInferM ClassDef + -> [UType] -> [UMethodType] -> UInferM ClassDef inferInterfaceDataDef className methodNames paramBs superclasses methods = do dictDef <- withNestedBinders paramBs \paramBs' -> do superclasses' <- mapM checkUType superclasses - methods' <- mapM checkUType methods + methods' <- mapM checkUType $ uMethodType <$> methods let dictContents = PairTy (ProdTy superclasses') (ProdTy methods') return $ DataDef className paramBs' [DataConDef ("Mk"<>className) (Nest (Ignore dictContents) Empty)] diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index f939740db..01f8aa48a 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -684,11 +684,13 @@ instance Pretty UDecl where align $ p ann <+> p b <+> "=" <> (nest 2 $ group $ line <> pLowest rhs) pretty (UDataDefDecl (UDataDef bParams dataCons) bTyCon bDataCons) = "data" <+> p bTyCon <+> p bParams - <+> "where" <> nest 2 (hardline <> prettyLines (zip (toList bDataCons) dataCons)) + <+> "where" <> nest 2 (hardline <> prettyLines (zip (toList bDataCons) dataCons)) pretty (UInterface params superclasses methodTys interfaceName methodNames) = - let methods = [UAnnBinder b ty | (b, ty) <- zip (toList methodNames) methodTys] - in "interface" <+> p params <+> p superclasses <+> p interfaceName - <> hardline <> prettyLines methods + "interface" <+> p params <+> p superclasses <+> p interfaceName + <> hardline <> foldMap (<>hardline) methods + where + methods = [hsep (p <$> e) <+> p (UAnnBinder b ty) | + (b, UMethodType e ty) <- zip (toList methodNames) methodTys] pretty (UInstance bs className params methods Nothing) = "instance" <+> p bs <+> p className <+> p params <+> hardline <> prettyLines methods pretty (UInstance bs className params methods (Just v)) = diff --git a/src/lib/Parallelize.hs b/src/lib/Parallelize.hs index 409fa803e..f3988173d 100644 --- a/src/lib/Parallelize.hs +++ b/src/lib/Parallelize.hs @@ -140,7 +140,7 @@ buildParallelBlock ablock@(ABlock decls result) = do unflattenIndexBundle :: MonadBuilder m => [Var] -> Atom -> m Atom unflattenIndexBundle [] arr = return arr unflattenIndexBundle [_] arr = return arr -unflattenIndexBundle ivs arr = buildNestedLam TabArrow (fmap Bind ivs) $ app arr . fst . mkBundle +unflattenIndexBundle ivs arr = buildNaryLam TabArrow (toNest $ fmap Bind ivs) $ app arr . fst . mkBundle type Loop = Abs Binder Block data NestDecision = Emit | Split (Nest Decl) (Binder, Loop) (Nest Decl) diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index bb71e20d6..0537a15d1 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -239,8 +239,9 @@ interfaceDef = do (tyConName, tyConParams) <- tyConDef (methodNames, methodTys) <- unzip <$> onePerLine do v <- anyName + explicit <- many anyName ty <- annot uType - return (fromString v, ty) + return (fromString v, UMethodType (USourceVar <$> explicit) ty) let methodNames' = toNest methodNames let tyConParams' = tyConParams return $ UInterface tyConParams' superclasses methodTys (fromString tyConName) methodNames' @@ -947,7 +948,8 @@ keyWord kw = lexeme $ try $ string s >> notFollowedBy nameTailChar keyWordStrs :: [String] keyWordStrs = ["def", "for", "for_", "rof", "rof_", "case", "of", "llam", "Read", "Write", "Accum", "Except", "IO", "data", "interface", - "instance", "named-instance", "where", "if", "then", "else", "do", "view", "import"] + "instance", "named-instance", "where", "if", "then", "else", + "do", "view", "import"] fieldLabel :: Lexer Label fieldLabel = label "field label" $ lexeme $ diff --git a/src/lib/SourceRename.hs b/src/lib/SourceRename.hs index bc7aaa95a..7b704dafa 100644 --- a/src/lib/SourceRename.hs +++ b/src/lib/SourceRename.hs @@ -188,6 +188,10 @@ instance SourceRenamableB UDecl where instanceName' <- mapM sourceRenameB instanceName return $ UInstance conditions' className' params' methodDefs' instanceName' +instance SourceRenamableE UMethodType where + sourceRenameE (UMethodType expl ty) = + UMethodType <$> traverse sourceRenameE expl <*> sourceRenameE ty + sourceRenameUBinderNest :: Renamer m => (Name -> SourceNameDef) -> Nest UBinder -> WithEnv RenameEnv m (Nest UBinder) sourceRenameUBinderNest _ Empty = return Empty diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index f4bebba43..26003aaf0 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -42,7 +42,7 @@ module Syntax ( monMapSingle, monMapLookup, Direction (..), Limit (..), SourceName, SourceMap (..), UExpr, UExpr' (..), UType, UPatAnn (..), UAnnBinder (..), UVar (..), UBinder (..), UMethodDef (..), - UMethodTypeDef, UPatAnnArrow (..), UVars, + UMethodType (..), UPatAnnArrow (..), UVars, UPat, UPat' (..), SourceUModule (..), SourceNameDef (..), sourceNameDefName, UModule (..), UDecl (..), UDataDef (..), UArrow, arrowEff, UEffect, UEffectRow, UEffArrow, @@ -283,7 +283,7 @@ data UDecl = | UInterface (Nest UAnnBinder) -- parameter binders [UType] -- superclasses - [UType] -- method types + [UMethodType] -- method types UBinder -- class name (Nest UBinder) -- method names | UInstance @@ -299,7 +299,8 @@ type UEffect = EffectP UVar type UEffectRow = EffectRowP UVar type UEffArrow = ArrowP UEffectRow -type UMethodTypeDef = (UBinder, UType) +data UMethodType = UMethodType { uMethodExplicitBs :: [UVar], uMethodType :: UType } + deriving (Show, Generic) data UMethodDef = UMethodDef UVar UExpr deriving (Show, Generic) data UPatAnn = UPatAnn UPat (Maybe UType) deriving (Show, Generic) @@ -823,7 +824,7 @@ instance HasUVars UDecl where freeUVars (UDataDefDecl dataDef bTyCon bDataCons) = freeUVars dataDef <> freeUVars (Abs bTyCon bDataCons) freeUVars (UInterface paramBs superclasses methods _ _) = - freeUVars $ Abs paramBs (superclasses, methods) + freeUVars $ Abs paramBs (superclasses, uMethodType <$> methods) freeUVars (UInstance bs className params methods _) = freeUVars $ Abs bs ((className, params), methods) diff --git a/tests/eval-tests.dx b/tests/eval-tests.dx index 2cc7ede26..641791998 100644 --- a/tests/eval-tests.dx +++ b/tests/eval-tests.dx @@ -880,3 +880,51 @@ def f6 (x:Int) : Int = f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ x -- This will compile extremely slowly if non-inlining is broken :p f6 0 > 100000 + +interface AssociatedInt a + value a : Int + +instance AssociatedInt Int + value = 2 + +instance AssociatedInt Float + value = 4 + +value Int +> 2 +value Float +> 4 + +interface AssociatedWithTwo a c + value2 a c : Int + +instance AssociatedWithTwo Int Float + value2 = 8 + +value2 Int Float +> 8 +value2 Float Int +> Type error:Couldn't synthesize a class dictionary for: (AssociatedWithTwo Float32 Int32) +> +> value2 Float Int +> ^^^^^^^^^^^^^^^^ + +-- TODO: This is a really bad error message +interface BadAssociatedName a + value2 c : Int +> Error: variable not in scope: c + +-- Technically the two tests below are not incorrect, but we don't implement them yet. +interface AssociatedSubsetOfParams a c + badValue1 a : Int +> Compiler bug! +> Please report this at github.com/google-research/dex-lang/issues +> +> Permuted or incomplete explicit type binders are not supported yet. + +interface AssociatedPermutedParams a c + badValue2 c a : Int +> Compiler bug! +> Please report this at github.com/google-research/dex-lang/issues +> +> Permuted or incomplete explicit type binders are not supported yet.