Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/torch_mlir/extras/fx_decomp_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
torch.ops.aten.nan_to_num.default,
torch.ops.aten.unbind,
torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
torch.ops.aten.diag,
Comment thread
aartbik marked this conversation as resolved.
]


Expand Down
25 changes: 18 additions & 7 deletions test/python/fx_importer/sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import torch.nn as nn
import numpy as np

from torch_mlir.extras.fx_decomp_util import get_decomposition_table
from torch_mlir.extras.fx_importer import FxImporter
from torch_mlir.extras.fx_importer import SparsityMeta
from torch_mlir import ir
Expand Down Expand Up @@ -106,6 +107,9 @@ def sparse_export(
# Build the regular FX traced graph with only dense arguments
# (the current version would crash otherwise, see issue above).
prog = torch.export.export(f, dargs, kwargs)
decomposition_table = get_decomposition_table()
if decomposition_table:
prog = prog.run_decompositions(decomposition_table)
# Annotate sparse arguments in the graph and apply some very
# basic propagation rules for sparsity.
specs = prog.graph_signature.input_specs
Expand All @@ -120,7 +124,6 @@ def sparse_export(
node.meta["sparsity"] = sparse_metadata(args[k])
k = k + 1
elif node.op == "call_function":
# TODO: use upstream _opname implementation when available
opname = node.target._schema.name.split("::")[1]
# Zero preserving elt-wise unary op.
if opname in {"abs", "neg", "relu", "sin"}:
Expand All @@ -131,7 +134,7 @@ def sparse_export(
torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
)
# TODO: Uncomment this to hack sparsity into the network.
# elif opname == "_to_dense":
# elif opname == "_to_dense" or opname == "to_dense":
# # hack (assumes we never really want the to_dense for now)
# node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
elif opname == "select" and node.args[0].meta.get("sparsity", None):
Expand Down Expand Up @@ -176,8 +179,8 @@ def sparse_jit(f, *args, **kwargs):
compiled = backend.compile(module)
invoker = backend.load(compiled)
xargs = []
# Prepare the buffer parameters (assume all dense).
# TODO: filters out scalar arguments, anything else?
# Prepare all the named buffer parameters (assume all dense).
# All scalar arguments are filtered out since they appear inline.
params = dict(f.named_buffers(remove_duplicate=True))
params_flat, params_spec = torch.utils._pytree.tree_flatten(params)
for p in params_flat:
Expand Down Expand Up @@ -339,6 +342,7 @@ def forward(self, x, v):

@run
#
# CHECK-LABEL: test_sparse_SpMM
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[8,8],f32,#[[$COO]]>,
Expand Down Expand Up @@ -440,7 +444,7 @@ def forward(self, x):
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
res1 = net(sparse_input)
res2 = sparse_jit(net, sparse_input)
# TODO: make this work
# TODO: make this work in MLIR
# res3 = sparse_jit(net, batch_input)
print("torch.sparse")
print(res1)
Expand Down Expand Up @@ -657,7 +661,14 @@ def forward(self, X):
# CHECK: [0.1321, 0.2724, 0.2105, 0.3851],
# CHECK: [0.2478, 0.3439, 0.1898, 0.2185],
# CHECK: [0.0222, 0.1683, 0.2928, 0.5167]{{\]}})
#
# TODO: first row looks suspect...
#
# CHECK: torch.mlir
# CHECK: {{\[}}[0. 0. 0. 0. ]
# CHECK: [0.13205223 0.27236593 0.21051763 0.38506418]
# CHECK: [0.24781987 0.34391665 0.18976606 0.2184974 ]
# CHECK: [0.02224578 0.16825409 0.29283574 0.51666445]{{\]}}
#
def test_sparse_feature_scaling():
class Scale(nn.Module):
Expand All @@ -678,11 +689,11 @@ def forward(self, F):

# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
res1 = net(f)
# TODO: make this work
# res2 = sparse_jit(net, f)
res2 = sparse_jit(net, f)
print("torch.sparse")
print(res1)
print("torch.mlir")
print(res2)


@run
Expand Down