diff --git a/python/torch_mlir/extras/fx_decomp_util.py b/python/torch_mlir/extras/fx_decomp_util.py
index 868dc26c6cb9..8dddede2d9cc 100644
--- a/python/torch_mlir/extras/fx_decomp_util.py
+++ b/python/torch_mlir/extras/fx_decomp_util.py
@@ -49,6 +49,7 @@
     torch.ops.aten.nan_to_num.default,
     torch.ops.aten.unbind,
     torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
+    torch.ops.aten.diag,
 ]


diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py
index 7c7198ef6f61..699d57cb2b0d 100644
--- a/test/python/fx_importer/sparse_test.py
+++ b/test/python/fx_importer/sparse_test.py
@@ -12,6 +12,7 @@
 import torch.nn as nn
 import numpy as np

+from torch_mlir.extras.fx_decomp_util import get_decomposition_table
 from torch_mlir.extras.fx_importer import FxImporter
 from torch_mlir.extras.fx_importer import SparsityMeta
 from torch_mlir import ir
@@ -106,6 +107,9 @@ def sparse_export(
     # Build the regular FX traced graph with only dense arguments
     # (the current version would crash otherwise, see issue above).
     prog = torch.export.export(f, dargs, kwargs)
+    decomposition_table = get_decomposition_table()
+    if decomposition_table:
+        prog = prog.run_decompositions(decomposition_table)
     # Annotate sparse arguments in the graph and apply some very
     # basic propagation rules for sparsity.
     specs = prog.graph_signature.input_specs
@@ -120,7 +124,6 @@ def sparse_export(
                 node.meta["sparsity"] = sparse_metadata(args[k])
             k = k + 1
         elif node.op == "call_function":
-            # TODO: use upstream _opname implementation when available
             opname = node.target._schema.name.split("::")[1]
             # Zero preserving elt-wise unary op.
             if opname in {"abs", "neg", "relu", "sin"}:
@@ -131,7 +134,7 @@ def sparse_export(
                     torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
                 )
             # TODO: Uncomment this to hack sparsity into the network.
-            # elif opname == "_to_dense":
+            # elif opname == "_to_dense" or opname == "to_dense":
             #   # hack (assumes we never really want the to_dense for now)
             #   node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
             elif opname == "select" and node.args[0].meta.get("sparsity", None):
@@ -176,8 +179,8 @@ def sparse_jit(f, *args, **kwargs):
     compiled = backend.compile(module)
     invoker = backend.load(compiled)
     xargs = []
-    # Prepare the buffer parameters (assume all dense).
-    # TODO: filters out scalar arguments, anything else?
+    # Prepare all the named buffer parameters (assume all dense).
+    # All scalar arguments are filtered out since they appear inline.
     params = dict(f.named_buffers(remove_duplicate=True))
     params_flat, params_spec = torch.utils._pytree.tree_flatten(params)
     for p in params_flat:
@@ -339,6 +342,7 @@ def forward(self, x, v):

 @run
 #
+# CHECK-LABEL: test_sparse_SpMM
 # CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
 # CHECK: func.func @main(
 # CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[8,8],f32,#[[$COO]]>,
@@ -440,7 +444,7 @@ def forward(self, x):
     # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
     res1 = net(sparse_input)
     res2 = sparse_jit(net, sparse_input)
-    # TODO: make this work
+    # TODO: make this work in MLIR
     # res3 = sparse_jit(net, batch_input)
     print("torch.sparse")
     print(res1)
@@ -657,7 +661,14 @@ def forward(self, X):
 # CHECK: [0.1321, 0.2724, 0.2105, 0.3851],
 # CHECK: [0.2478, 0.3439, 0.1898, 0.2185],
 # CHECK: [0.0222, 0.1683, 0.2928, 0.5167]{{\]}})
+#
+# TODO: first row looks suspect...
+#
 # CHECK: torch.mlir
+# CHECK: {{\[}}[0. 0. 0. 0. ]
+# CHECK: [0.13205223 0.27236593 0.21051763 0.38506418]
+# CHECK: [0.24781987 0.34391665 0.18976606 0.2184974 ]
+# CHECK: [0.02224578 0.16825409 0.29283574 0.51666445]{{\]}}
 #
 def test_sparse_feature_scaling():
     class Scale(nn.Module):
@@ -678,11 +689,11 @@ def forward(self, F):

     # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
     res1 = net(f)
-    # TODO: make this work
-    # res2 = sparse_jit(net, f)
+    res2 = sparse_jit(net, f)
     print("torch.sparse")
     print(res1)
     print("torch.mlir")
+    print(res2)


 @run
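For context, the sparse_export() change above hooks the default decomposition table into the exported program so that ops such as aten.diag (newly added to DEFAULT_DECOMPOSITIONS) are decomposed before the FX importer sees them. A minimal standalone sketch of that same pattern, using the names from this patch; the DiagNet module and the input shape are made up for illustration and are not part of the change:

    import torch
    import torch.export
    import torch.nn as nn

    from torch_mlir.extras.fx_decomp_util import get_decomposition_table


    class DiagNet(nn.Module):
        # Toy module: torch.diag lowers through aten.diag, which the patch
        # adds to the default decomposition list.
        def forward(self, x):
            return torch.diag(x)


    prog = torch.export.export(DiagNet(), (torch.arange(4, dtype=torch.float32),))
    decomposition_table = get_decomposition_table()
    if decomposition_table:
        # Rewrite the exported graph so aten.diag is expressed in terms of
        # simpler ops before import into MLIR.
        prog = prog.run_decompositions(decomposition_table)
    print(prog.graph)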