
Commit facbe5d

[Torch Dialect] support AtenArangeStartOutOp in ReduceOpVariants like AtenBernoulli_FloatOp (#2563)

This fixes cases like:
`%2110 = torch.aten.arange.start_out %int1, %int1517, %int1, %2109 : !torch.int, !torch.int, !torch.int, !torch.tensor -> !torch.tensor`
`aten.arange.start_out` also has no value semantics, which means `%2110` is an alias of `%2109`. So I decompose it into `aten.arange.start_step` + `torch.overwrite.tensor.contents`. The decomposition logic additionally handles cases like views and dtype casts, which I cover in the new e2e tests.
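The aliasing described above is easy to see in eager PyTorch. A minimal sketch (my own illustration, not part of the commit): the tensor returned by an `out=` overload shares storage with `out`, which is exactly why `aten.arange.start_out` cannot be treated as having value semantics.

import torch

x = torch.zeros(12, dtype=torch.int64)
y = torch.arange(0, 12, out=x)  # eager analogue of aten.arange.start_out

# `y` is not an independent tensor: it shares storage with `x`, so the result
# of the op is an alias of its `out` operand, just like %2110 aliases %2109.
assert y.data_ptr() == x.data_ptr()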
1 parent dad1f01 commit facbe5d

File tree

3 files changed: +110 -8 lines

  lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
  projects/pt1/e2e_testing/xfail_sets.py
  projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py

lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp

Lines changed: 54 additions & 8 deletions
@@ -191,26 +191,71 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
 // Reduce Ops without value semantics but the corresponding without trailing
 // underscore variant doesn't exist.
 namespace {
+
+// int(ceil((end - start) / step))
+Value calculateArangeResultNumElements(PatternRewriter &rewriter, Location loc,
+                                       Value start, Value end, Value step) {
+  Value sub = rewriter.create<AtenSubOp>(
+      loc, Torch::NumberType::get(rewriter.getContext()), end, start);
+  Value div = rewriter.create<AtenDivOp>(loc, sub, step);
+  return rewriter.create<AtenCeilFloatOp>(loc, div);
+}
+
 class ReduceNonValueSemanticOps : public RewritePattern {
 public:
   ReduceNonValueSemanticOps(MLIRContext *context)
       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    Operation *newOp;
+    MLIRContext *ctx = op->getContext();
     if (isa<AtenBernoulli_FloatOp>(op)) {
-      newOp = rewriter.create<ValsemVariantAtenBernoulliFloatOp>(
+      Operation *newOp = rewriter.create<ValsemVariantAtenBernoulliFloatOp>(
           loc, op->getResultTypes(), op->getOperands());
+      auto tensor =
+          rewriter.create<CopyToValueTensorOp>(loc, newOp->getResult(0));
+      createOverwriteTensorContents(rewriter, loc, tensor, op->getOperand(0));
+      rewriter.replaceOp(op, op->getOperand(0));
+      return success();
+    } else if (auto arangeOutOp = dyn_cast<AtenArangeStartOutOp>(op)) {
+      Value start = arangeOutOp.getStart();
+      Value end = arangeOutOp.getEnd();
+      Value step = arangeOutOp.getStep();
+      Value out = arangeOutOp.getOut();
+
+      // `overwrite.tensor.contents` cannot change the tensor shape,
+      // so the `out` tensor must have the same num_elements as the result
+      // tensor. This means we don't support code like:
+      //   `x = torch.randn(12)`
+      //   `y = torch.arange(13, out=x)`
+      Value resultNumElements =
+          calculateArangeResultNumElements(rewriter, loc, start, end, step);
+      Value outNumElements = rewriter.create<AtenNumelOp>(loc, out);
+      Value eqOrNot =
+          rewriter.create<AtenEqIntOp>(loc, resultNumElements, outNumElements);
+      rewriter.create<RuntimeAssertOp>(
+          loc, eqOrNot,
+          rewriter.getStringAttr("`out` tensor should have the same "
+                                 "num_elements as the result tensor"));
+
+      auto dtype = rewriter.create<PrimDtypeOp>(loc, out);
+      auto device = rewriter.create<PrimDeviceOp>(loc, out);
+      auto shape = rewriter.create<AtenSizeOp>(
+          loc, Torch::ListType::get(Torch::IntType::get(ctx)), out);
+      auto none = rewriter.create<ConstantNoneOp>(loc);
+      Value newArange = rewriter.create<AtenArangeStartStepOp>(
+          loc, arangeOutOp.getResult().getType(), start, end, step, dtype,
+          /*layout=*/none, device, /*pin_memory=*/none);
+      Value reshape = rewriter.create<AtenReshapeOp>(
+          loc, arangeOutOp.getResult().getType(), newArange, shape);
+
+      auto vtensor = rewriter.create<CopyToValueTensorOp>(loc, reshape);
+      createOverwriteTensorContents(rewriter, loc, vtensor, out);
+      rewriter.replaceOp(arangeOutOp, out);
+      return success();
     } else {
       return failure();
     }
-
-    auto tensor =
-        rewriter.create<CopyToValueTensorOp>(loc, newOp->getResult(0));
-    createOverwriteTensorContents(rewriter, loc, tensor, op->getOperand(0));
-    rewriter.replaceOp(op, op->getOperand(0));
-    return success();
   }
 };
 } // namespace
@@ -309,6 +354,7 @@ struct ReduceOpVariantsPass
     ConversionTarget target(*context);
     target.addIllegalOp<NonValueTensorLiteralOp>();
     target.addIllegalOp<AtenBernoulli_FloatOp>();
+    target.addIllegalOp<AtenArangeStartOutOp>();
     target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable](
                                              Operation *op) {
       if (op->hasTrait<Torch::OpTrait::HasValueSemantics>() ||
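
For readers not fluent in MLIR rewrite patterns, here is a rough Python analogue (a hand-written sketch under the same assumptions as the pattern above, not code from this commit) of what the `AtenArangeStartOutOp` case does: verify the element count, build a value-semantic arange with `out`'s dtype and device, reshape it to `out`'s shape, then overwrite `out`'s contents and return `out`.

import math
import torch

def arange_start_out_decomposed(start, end, step, out):
    # int(ceil((end - start) / step)), as in calculateArangeResultNumElements.
    result_numel = int(math.ceil((end - start) / step))
    # Analogue of the RuntimeAssertOp: overwriting contents cannot change the
    # shape, so `out` must already hold the right number of elements.
    assert result_numel == out.numel(), (
        "`out` tensor should have the same num_elements as the result tensor")
    # Analogue of AtenArangeStartStepOp + AtenReshapeOp: a value-semantic
    # arange using `out`'s dtype and device, reshaped to `out`'s shape.
    values = torch.arange(start, end, step, dtype=out.dtype, device=out.device)
    values = values.reshape(out.shape)
    # Analogue of overwrite.tensor.contents followed by replacing the op's
    # result with `out`: mutate `out` in place and hand it back as the result.
    out.copy_(values)
    return out

# The "view" e2e case: 12 values written into a [3, 4] out tensor.
x = torch.zeros(3, 4, dtype=torch.int64)
print(arange_start_out_decomposed(1, 13, 1, x))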

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 6 additions & 0 deletions
@@ -302,6 +302,9 @@
 
     # ERROR: dtype (torch.int64) is not equal to golden dtype (torch.float32)
     "ThresholdBackward2dMixedModule_basic",
+
+    # ERROR: shape (torch.Size([12])) is not equal to golden shape (torch.Size([3, 4]))
+    "ArangeStartOutViewModule_basic",
 }
 
 if torch_version_for_comparison() < version.parse("2.1.0.dev"):
@@ -1303,6 +1306,8 @@
     "AtenEyeModuleFalsePinMemory_basic",
     "AtenEyeModuleFloat2D_basic",
     "MeanModule_basic",
+    "ArangeStartOutModule_basic",
+    "ArangeStartOutViewModule_basic",
 }
 
 MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
@@ -1372,6 +1377,7 @@
     "_ConvolutionDeprecated2DCudnnModule_basic",
     "_ConvolutionDeprecated2DDeterministicModule_basic",
     "AddIntModule_basic",
+    "ArangeStartOutViewModule_basic",
     "AtenIntBoolOpModule_basic",
     "BernoulliTensorModule_basic",
     "BincountMinlengthModule_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py

Lines changed: 50 additions & 0 deletions
@@ -248,3 +248,53 @@ def forward(self):
 @register_test_case(module_factory=lambda: ArangeFalsePinMemoryModule())
 def ArangeFalsePinMemoryModule_basic(module, tu: TestUtils):
     module.forward()
+
+# ==============================================================================
+
+class ArangeStartOutModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([12], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.arange(start=0, end=12, out=x)
+
+@register_test_case(module_factory=lambda: ArangeStartOutModule())
+def ArangeStartOutModule_basic(module, tu: TestUtils):
+    module.forward(torch.zeros(12).to(torch.int64))
+
+class ArangeStartOutViewModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3, 4], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.arange(start=1, end=13, out=x)
+
+@register_test_case(module_factory=lambda: ArangeStartOutViewModule())
+def ArangeStartOutViewModule_basic(module, tu: TestUtils):
+    module.forward(torch.zeros(3, 4).to(torch.int64))
+
+class ArangeStartOutDtypeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([12], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.arange(start=1.1, end=13.1, out=x)
+
+@register_test_case(module_factory=lambda: ArangeStartOutDtypeModule())
+def ArangeStartOutDtypeModule_basic(module, tu: TestUtils):
+    module.forward(torch.zeros(12).to(torch.int64))
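
As a quick sanity check on the `int(ceil((end - start) / step))` element-count formula the rewrite relies on, each of the three new test cases produces exactly as many values as its `out` tensor holds (my own arithmetic below, not output from the e2e harness):

import math

# (start, end, step, out_numel) for ArangeStartOutModule,
# ArangeStartOutViewModule and ArangeStartOutDtypeModule.
cases = [(0, 12, 1, 12), (1, 13, 1, 12), (1.1, 13.1, 1, 12)]
for start, end, step, out_numel in cases:
    assert int(math.ceil((end - start) / step)) == out_numel
print("all three cases satisfy the num_elements runtime assert")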
