diff --git a/lib/prelude.dx b/lib/prelude.dx index af4f7155c..fa05fe1da 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -676,38 +676,30 @@ def castPtr (ptr: Ptr a) : Ptr b = (MkPtr rawPtr) = ptr MkPtr rawPtr --- Is there a better way to select the right instance for `storageSize`?? -data TypeVehicle a = MkTypeVehicle -def typeVehicle (a:Type) : TypeVehicle a = MkTypeVehicle - interface Storable a store : Ptr a -> a -> {IO} Unit load : Ptr a -> {IO} a - storageSize_ : TypeVehicle a -> Int - -def storageSize (a:Type) -> (d:Storable a) ?=> : Int = - tv : TypeVehicle a = MkTypeVehicle - storageSize_ tv + storageSize a : Int instance Storable Word8 store = \(MkPtr ptr) x. %ptrStore ptr x load = \(MkPtr ptr) . %ptrLoad ptr - storageSize_ = const 1 + storageSize = 1 instance Storable Int32 store = \(MkPtr ptr) x. %ptrStore (internalCast %Int32Ptr ptr) x load = \(MkPtr ptr) . %ptrLoad (internalCast %Int32Ptr ptr) - storageSize_ = const 4 + storageSize = 4 instance Storable Float32 store = \(MkPtr ptr) x. %ptrStore (internalCast %Float32Ptr ptr) x load = \(MkPtr ptr) . %ptrLoad (internalCast %Float32Ptr ptr) - storageSize_ = const 4 + storageSize = 4 instance Storable (Ptr a) store = \(MkPtr ptr) (MkPtr x). %ptrStore (internalCast %PtrPtr ptr) x load = \(MkPtr ptr) . MkPtr $ %ptrLoad (internalCast %PtrPtr ptr) - storageSize_ = const 8 -- TODO: something more portable? + storageSize = 8 -- TODO: something more portable? -- TODO: Storable instances for other types diff --git a/src/lib/Autodiff.hs b/src/lib/Autodiff.hs index 3524057c2..e33a5c35f 100644 --- a/src/lib/Autodiff.hs +++ b/src/lib/Autodiff.hs @@ -417,7 +417,7 @@ tangentFunAsLambda m = do DerivWrt activeVars effs <- ask let hs = map (Bind . (:>TyKind) . effectRegion) effs liftM (PairVal ans) $ lift $ do - buildNestedLam PureArrow hs \hVals -> do + buildNaryLam PureArrow (toNest hs) \hVals -> do let hVarNames = map (\(Var (v:>_)) -> v) hVals -- TODO: handle exception effect too let effs' = zipWith (\(RWSEffect rws _) v -> RWSEffect rws v) effs hVarNames @@ -425,7 +425,7 @@ tangentFunAsLambda m = do let regionMap = newEnv (map ((:>()) . effectRegion) effs) hVals -- TODO: Only bind tangents for free variables? let activeVarBinders = map (Bind . fmap (tangentRefRegion regionMap)) $ envAsVars activeVars - buildNestedLam PureArrow activeVarBinders \activeVarArgs -> + buildNaryLam PureArrow (toNest activeVarBinders) \activeVarArgs -> buildLam (Ignore UnitTy) (PlainArrow $ EffectRow (S.fromList effs') Nothing) \_ -> runReaderT tanFun $ TangentEnv (newEnv (envNames activeVars) activeVarArgs) hVarNames diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index 56df2e4cf..aa0b7585d 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -18,7 +18,7 @@ module Builder (emit, emitAnn, emitOp, buildDepEffLam, buildLamAux, buildPi, app, add, mul, sub, neg, div', iadd, imul, isub, idiv, ilt, ieq, - fpow, flog, fLitLike, recGetHead, buildImplicitNaryLam, buildNaryLam, + fpow, flog, fLitLike, recGetHead, buildNaryLam, select, substBuilder, substBuilderR, emitUnpack, getUnpacked, fromPair, getFst, getSnd, getProj, getProjRef, naryApp, appReduce, appTryReduce, buildAbs, buildAAbs, buildAAbsAux, @@ -48,6 +48,7 @@ import Control.Monad.Reader import Control.Monad.Writer hiding (Alt) import Control.Monad.Identity import Control.Monad.State.Strict +import Data.Functor ((<&>)) import Data.Foldable (toList) import Data.List (elemIndex) import Data.Maybe (fromJust) @@ -261,34 +262,33 @@ emitSuperclass dataDef idx = do getter <- makeSuperclassGetter dataDef idx emitBinding $ SuperclassName dataDef idx getter -emitMethodType :: MonadBuilder m => ClassDefName -> Int -> m Name -emitMethodType classDef idx = do - getter <- makeMethodGetter classDef idx +emitMethodType :: MonadBuilder m => [Bool] -> ClassDefName -> Int -> m Name +emitMethodType explicit classDef idx = do + getter <- makeMethodGetter explicit classDef idx emitBinding $ MethodName classDef idx getter -makeMethodGetter :: MonadBuilder m => ClassDefName -> Int -> m Atom -makeMethodGetter classDefName methodIdx = do +makeMethodGetter :: MonadBuilder m => [Bool] -> ClassDefName -> Int -> m Atom +makeMethodGetter explicit classDefName methodIdx = do ClassDef def@(_, DataDef _ paramBs _) _ <- getClassDef classDefName - buildImplicitNaryLam paramBs \params -> do + let arrows = explicit <&> \case True -> PureArrow; False -> ImplicitArrow + let bs = toNest $ zip (toList paramBs) arrows + buildNestedLam bs \params -> do buildLam (Bind ("d":> TypeCon def params)) ClassArrow \dict -> do return $ getProjection [methodIdx] $ getProjection [1, 0] dict makeSuperclassGetter :: MonadBuilder m => DataDefName -> Int -> m Atom makeSuperclassGetter classDefName methodIdx = do ClassDef def@(_, DataDef _ paramBs _) _ <- getClassDef classDefName - buildImplicitNaryLam paramBs \params -> do + buildNaryLam ImplicitArrow paramBs \params -> do buildLam (Bind ("d":> TypeCon def params)) PureArrow \dict -> do return $ getProjection [methodIdx] $ getProjection [0, 0] dict -buildImplicitNaryLam :: MonadBuilder m => (Nest Binder) -> ([Atom] -> m Atom) -> m Atom -buildImplicitNaryLam bs body = buildNaryLam ImplicitArrow bs body - buildNaryLam :: MonadBuilder m => Arrow -> (Nest Binder) -> ([Atom] -> m Atom) -> m Atom buildNaryLam _ Empty body = body [] buildNaryLam arr (Nest b bs) body = buildLam b arr \x -> do bs' <- substBuilder (b@>SubstVal x) bs - buildImplicitNaryLam bs' \xs -> body $ x:xs + buildNaryLam arr bs' \xs -> body $ x:xs recGetHead :: Label -> Atom -> Atom recGetHead l x = do @@ -544,10 +544,10 @@ buildNestedFor specs body = go specs [] go [] indices = body $ reverse indices go ((d,b):t) indices = buildFor d b $ \i -> go t (i:indices) -buildNestedLam :: MonadBuilder m => Arrow -> [Binder] -> ([Atom] -> m Atom) -> m Atom -buildNestedLam _ [] f = f [] -buildNestedLam arr (b:bs) f = - buildLam b arr \x -> buildNestedLam arr bs \xs -> f (x:xs) +buildNestedLam :: MonadBuilder m => Nest (Binder, Arrow) -> ([Atom] -> m Atom) -> m Atom +buildNestedLam Empty f = f [] +buildNestedLam (Nest (b, arr) t) f = + buildLam b arr \x -> buildNestedLam t \xs -> f (x:xs) tabGet :: MonadBuilder m => Atom -> Atom -> m Atom tabGet tab idx = emit $ App tab idx diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 3f9bc8b14..e4724e80c 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -367,8 +367,19 @@ inferUDecl (UInterface paramBs superclasses methodTys className methodNames) = d paramBs superclasses methodTys className' <- withNameHint className $ emitClassDef classDef mapM_ (emitSuperclass className') [0..(length superclasses - 1)] - methodNames' <- forM (enumerate (toList methodNames)) \(i, name) -> - withNameHint name $ emitMethodType className' i + methodNames' <- forM (enumerate $ zip (toList methodNames) methodTys) \(i, (name, ty)) -> do + paramVs <- forM (toList paramBs) \(UAnnBinder v _) -> case v of + UBind n -> return $ UInternalVar n + _ -> throw CompilerErr "Unexpected interface binder. Please open a bug report!" + explicits <- case uMethodExplicitBs ty of + [] -> return $ replicate (length paramBs) False + e | e == paramVs -> return $ replicate (length paramBs) True + e -> case unexpected of + [] -> throw CompilerErr "Permuted or incomplete explicit type binders are not supported yet." + (h:_) -> throw TypeErr $ "Explicit type binder `" ++ pprint h ++ "` in method " ++ + pprint name ++ " is not a type parameter of its interface" + where unexpected = filter (not . (`elem` paramVs)) e + withNameHint name $ emitMethodType explicits className' i return $ className @> Rename className' <> newEnv methodNames (map Rename methodNames') inferUDecl (UInstance argBinders ~(UInternalVar className) params methods maybeName) = do @@ -391,11 +402,11 @@ inferDataDef (UDataDef (tyConName, paramBs) dataCons) = return $ DataDef tyConName paramBs' dataCons' inferInterfaceDataDef :: SourceName -> [SourceName] -> Nest UAnnBinder - -> [UType] -> [UType] -> UInferM ClassDef + -> [UType] -> [UMethodType] -> UInferM ClassDef inferInterfaceDataDef className methodNames paramBs superclasses methods = do dictDef <- withNestedBinders paramBs \paramBs' -> do superclasses' <- mapM checkUType superclasses - methods' <- mapM checkUType methods + methods' <- mapM checkUType $ uMethodType <$> methods let dictContents = PairTy (ProdTy superclasses') (ProdTy methods') return $ DataDef className paramBs' [DataConDef ("Mk"<>className) (Nest (Ignore dictContents) Empty)] diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index f939740db..01f8aa48a 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -684,11 +684,13 @@ instance Pretty UDecl where align $ p ann <+> p b <+> "=" <> (nest 2 $ group $ line <> pLowest rhs) pretty (UDataDefDecl (UDataDef bParams dataCons) bTyCon bDataCons) = "data" <+> p bTyCon <+> p bParams - <+> "where" <> nest 2 (hardline <> prettyLines (zip (toList bDataCons) dataCons)) + <+> "where" <> nest 2 (hardline <> prettyLines (zip (toList bDataCons) dataCons)) pretty (UInterface params superclasses methodTys interfaceName methodNames) = - let methods = [UAnnBinder b ty | (b, ty) <- zip (toList methodNames) methodTys] - in "interface" <+> p params <+> p superclasses <+> p interfaceName - <> hardline <> prettyLines methods + "interface" <+> p params <+> p superclasses <+> p interfaceName + <> hardline <> foldMap (<>hardline) methods + where + methods = [hsep (p <$> e) <+> p (UAnnBinder b ty) | + (b, UMethodType e ty) <- zip (toList methodNames) methodTys] pretty (UInstance bs className params methods Nothing) = "instance" <+> p bs <+> p className <+> p params <+> hardline <> prettyLines methods pretty (UInstance bs className params methods (Just v)) = diff --git a/src/lib/Parallelize.hs b/src/lib/Parallelize.hs index 409fa803e..f3988173d 100644 --- a/src/lib/Parallelize.hs +++ b/src/lib/Parallelize.hs @@ -140,7 +140,7 @@ buildParallelBlock ablock@(ABlock decls result) = do unflattenIndexBundle :: MonadBuilder m => [Var] -> Atom -> m Atom unflattenIndexBundle [] arr = return arr unflattenIndexBundle [_] arr = return arr -unflattenIndexBundle ivs arr = buildNestedLam TabArrow (fmap Bind ivs) $ app arr . fst . mkBundle +unflattenIndexBundle ivs arr = buildNaryLam TabArrow (toNest $ fmap Bind ivs) $ app arr . fst . mkBundle type Loop = Abs Binder Block data NestDecision = Emit | Split (Nest Decl) (Binder, Loop) (Nest Decl) diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index bb71e20d6..0537a15d1 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -239,8 +239,9 @@ interfaceDef = do (tyConName, tyConParams) <- tyConDef (methodNames, methodTys) <- unzip <$> onePerLine do v <- anyName + explicit <- many anyName ty <- annot uType - return (fromString v, ty) + return (fromString v, UMethodType (USourceVar <$> explicit) ty) let methodNames' = toNest methodNames let tyConParams' = tyConParams return $ UInterface tyConParams' superclasses methodTys (fromString tyConName) methodNames' @@ -947,7 +948,8 @@ keyWord kw = lexeme $ try $ string s >> notFollowedBy nameTailChar keyWordStrs :: [String] keyWordStrs = ["def", "for", "for_", "rof", "rof_", "case", "of", "llam", "Read", "Write", "Accum", "Except", "IO", "data", "interface", - "instance", "named-instance", "where", "if", "then", "else", "do", "view", "import"] + "instance", "named-instance", "where", "if", "then", "else", + "do", "view", "import"] fieldLabel :: Lexer Label fieldLabel = label "field label" $ lexeme $ diff --git a/src/lib/SourceRename.hs b/src/lib/SourceRename.hs index bc7aaa95a..7b704dafa 100644 --- a/src/lib/SourceRename.hs +++ b/src/lib/SourceRename.hs @@ -188,6 +188,10 @@ instance SourceRenamableB UDecl where instanceName' <- mapM sourceRenameB instanceName return $ UInstance conditions' className' params' methodDefs' instanceName' +instance SourceRenamableE UMethodType where + sourceRenameE (UMethodType expl ty) = + UMethodType <$> traverse sourceRenameE expl <*> sourceRenameE ty + sourceRenameUBinderNest :: Renamer m => (Name -> SourceNameDef) -> Nest UBinder -> WithEnv RenameEnv m (Nest UBinder) sourceRenameUBinderNest _ Empty = return Empty diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index f4bebba43..26003aaf0 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -42,7 +42,7 @@ module Syntax ( monMapSingle, monMapLookup, Direction (..), Limit (..), SourceName, SourceMap (..), UExpr, UExpr' (..), UType, UPatAnn (..), UAnnBinder (..), UVar (..), UBinder (..), UMethodDef (..), - UMethodTypeDef, UPatAnnArrow (..), UVars, + UMethodType (..), UPatAnnArrow (..), UVars, UPat, UPat' (..), SourceUModule (..), SourceNameDef (..), sourceNameDefName, UModule (..), UDecl (..), UDataDef (..), UArrow, arrowEff, UEffect, UEffectRow, UEffArrow, @@ -283,7 +283,7 @@ data UDecl = | UInterface (Nest UAnnBinder) -- parameter binders [UType] -- superclasses - [UType] -- method types + [UMethodType] -- method types UBinder -- class name (Nest UBinder) -- method names | UInstance @@ -299,7 +299,8 @@ type UEffect = EffectP UVar type UEffectRow = EffectRowP UVar type UEffArrow = ArrowP UEffectRow -type UMethodTypeDef = (UBinder, UType) +data UMethodType = UMethodType { uMethodExplicitBs :: [UVar], uMethodType :: UType } + deriving (Show, Generic) data UMethodDef = UMethodDef UVar UExpr deriving (Show, Generic) data UPatAnn = UPatAnn UPat (Maybe UType) deriving (Show, Generic) @@ -823,7 +824,7 @@ instance HasUVars UDecl where freeUVars (UDataDefDecl dataDef bTyCon bDataCons) = freeUVars dataDef <> freeUVars (Abs bTyCon bDataCons) freeUVars (UInterface paramBs superclasses methods _ _) = - freeUVars $ Abs paramBs (superclasses, methods) + freeUVars $ Abs paramBs (superclasses, uMethodType <$> methods) freeUVars (UInstance bs className params methods _) = freeUVars $ Abs bs ((className, params), methods) diff --git a/tests/eval-tests.dx b/tests/eval-tests.dx index 2cc7ede26..641791998 100644 --- a/tests/eval-tests.dx +++ b/tests/eval-tests.dx @@ -880,3 +880,51 @@ def f6 (x:Int) : Int = f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ x -- This will compile extremely slowly if non-inlining is broken :p f6 0 > 100000 + +interface AssociatedInt a + value a : Int + +instance AssociatedInt Int + value = 2 + +instance AssociatedInt Float + value = 4 + +value Int +> 2 +value Float +> 4 + +interface AssociatedWithTwo a c + value2 a c : Int + +instance AssociatedWithTwo Int Float + value2 = 8 + +value2 Int Float +> 8 +value2 Float Int +> Type error:Couldn't synthesize a class dictionary for: (AssociatedWithTwo Float32 Int32) +> +> value2 Float Int +> ^^^^^^^^^^^^^^^^ + +-- TODO: This is a really bad error message +interface BadAssociatedName a + value2 c : Int +> Error: variable not in scope: c + +-- Technically the two tests below are not incorrect, but we don't implement them yet. +interface AssociatedSubsetOfParams a c + badValue1 a : Int +> Compiler bug! +> Please report this at github.com/google-research/dex-lang/issues +> +> Permuted or incomplete explicit type binders are not supported yet. + +interface AssociatedPermutedParams a c + badValue2 c a : Int +> Compiler bug! +> Please report this at github.com/google-research/dex-lang/issues +> +> Permuted or incomplete explicit type binders are not supported yet.