Skip to content

Commit acd57a3

Browse files
authored
Support fake_quantize_per_tensor_affine_cachemask (#3477)
Add a new op with shape/dtypes and decompose into `fake_quantize_per_tensor_affine` when the second result is unused. The xfail_set change is on ONNX because torch cannot export this op to ONNX.
1 parent 83bfb6f commit acd57a3

File tree

8 files changed

+129
-0
lines changed

8 files changed

+129
-0
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4595,6 +4595,34 @@ def Torch_AtenFakeQuantizePerTensorAffineOp : Torch_Op<"aten.fake_quantize_per_t
45954595
}];
45964596
}
45974597

4598+
def Torch_AtenFakeQuantizePerTensorAffineCachemaskOp : Torch_Op<"aten.fake_quantize_per_tensor_affine_cachemask", [
4599+
AllowsTypeRefinement,
4600+
HasValueSemantics,
4601+
ReadOnly
4602+
]> {
4603+
let summary = "Generated op for `aten::fake_quantize_per_tensor_affine_cachemask : (Tensor, float, int, int, int) -> (Tensor, Tensor)`";
4604+
let arguments = (ins
4605+
AnyTorchTensorType:$self,
4606+
Torch_FloatType:$scale,
4607+
Torch_IntType:$zero_point,
4608+
Torch_IntType:$quant_min,
4609+
Torch_IntType:$quant_max
4610+
);
4611+
let results = (outs
4612+
AnyTorchOptionalTensorType:$output,
4613+
AnyTorchOptionalTensorType:$mask
4614+
);
4615+
let hasCustomAssemblyFormat = 1;
4616+
let extraClassDefinition = [{
4617+
ParseResult AtenFakeQuantizePerTensorAffineCachemaskOp::parse(OpAsmParser &parser, OperationState &result) {
4618+
return parseDefaultTorchOp(parser, result, 5, 2);
4619+
}
4620+
void AtenFakeQuantizePerTensorAffineCachemaskOp::print(OpAsmPrinter &printer) {
4621+
printDefaultTorchOp(printer, *this, 5, 2);
4622+
}
4623+
}];
4624+
}
4625+
45984626
def Torch_AtenMaximumOp : Torch_Op<"aten.maximum", [
45994627
AllowsTypeRefinement,
46004628
HasValueSemantics,

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6328,6 +6328,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
63286328
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
63296329
" return %0 : !torch.list<int>\n"
63306330
" }\n"
6331+
" func.func @\"__torch_mlir_shape_fn.aten.fake_quantize_per_tensor_affine_cachemask\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
6332+
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
6333+
" %1 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
6334+
" %2 = torch.prim.TupleConstruct %0, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
6335+
" return %2 : !torch.tuple<list<int>, list<int>>\n"
6336+
" }\n"
63316337
" func.func @\"__torch_mlir_shape_fn.aten.sin\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
63326338
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
63336339
" return %0 : !torch.list<int>\n"
@@ -10189,6 +10195,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1018910195
" %0 = torch.prim.ListConstruct %int5, %int15, %int6, %int7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
1019010196
" return %0 : !torch.list<int>\n"
1019110197
" }\n"
10198+
" func.func @\"__torch_mlir_dtype_fn.aten.fake_quantize_per_tensor_affine_cachemask\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<int, int> {\n"
10199+
" %int11 = torch.constant.int 11\n"
10200+
" %int15 = torch.constant.int 15\n"
10201+
" %none = torch.constant.none\n"
10202+
" %str = torch.constant.str \"AssertionError: \"\n"
10203+
" %int1 = torch.constant.int 1\n"
10204+
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
10205+
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
10206+
" torch.prim.If %1 -> () {\n"
10207+
" torch.prim.If.yield\n"
10208+
" } else {\n"
10209+
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
10210+
" torch.prim.If.yield\n"
10211+
" }\n"
10212+
" %2 = torch.aten.ne.int %0#1, %int15 : !torch.int, !torch.int -> !torch.bool\n"
10213+
" torch.prim.If %2 -> () {\n"
10214+
" torch.prim.If.yield\n"
10215+
" } else {\n"
10216+
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
10217+
" torch.prim.If.yield\n"
10218+
" }\n"
10219+
" %3 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
10220+
" %4 = torch.prim.TupleConstruct %3, %int11 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
10221+
" return %4 : !torch.tuple<int, int>\n"
10222+
" }\n"
1019210223
" func.func @\"__torch_mlir_dtype_fn.aten.cosh\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
1019310224
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
1019410225
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8146,6 +8146,31 @@ class DecomposeAtenFakeQuantizePerTensorAffineOp
81468146
};
81478147
} // namespace
81488148

8149+
namespace {
8150+
// Decompose aten.fake_quantize_per_tensor_affine_cachemask
8151+
// into aten.fake_quantize_per_tensor_affine
8152+
// when the second result is unused.
8153+
class DecomposeAtenFakeQuantizePerTensorAffineCachemaskOp
8154+
: public OpRewritePattern<AtenFakeQuantizePerTensorAffineCachemaskOp> {
8155+
public:
8156+
using OpRewritePattern<
8157+
AtenFakeQuantizePerTensorAffineCachemaskOp>::OpRewritePattern;
8158+
LogicalResult matchAndRewrite(AtenFakeQuantizePerTensorAffineCachemaskOp op,
8159+
PatternRewriter &rewriter) const override {
8160+
if (!op->getResult(1).use_empty())
8161+
return failure();
8162+
8163+
auto newOp = rewriter.create<AtenFakeQuantizePerTensorAffineOp>(
8164+
op.getLoc(), op->getResult(0).getType(), op.getSelf(), op.getScale(),
8165+
op.getZeroPoint(), op.getQuantMin(), op.getQuantMax());
8166+
8167+
rewriter.replaceAllUsesWith(op->getResult(0), newOp);
8168+
rewriter.eraseOp(op);
8169+
return success();
8170+
}
8171+
};
8172+
} // namespace
8173+
81498174
namespace {
81508175
class DecomposeComplexOpsPass
81518176
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -8375,6 +8400,8 @@ class DecomposeComplexOpsPass
83758400
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeAsOp>(patterns);
83768401
addPatternIfTargetOpIsIllegal<DecomposeAtenTriuOp>(patterns);
83778402
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgNormOp>(patterns);
8403+
addPatternIfTargetOpIsIllegal<
8404+
DecomposeAtenFakeQuantizePerTensorAffineCachemaskOp>(patterns);
83788405
// More specific conv ops
83798406
addPatternIfTargetOpIsIllegal<DecomposeAtenConvTbcOp>(patterns);
83808407
addPatternIfTargetOpIsIllegal<DecomposeAtenConv1dOp>(patterns);

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
460460
target.addIllegalOp<AtenRelu6Op>();
461461
target.addIllegalOp<AtenEluOp>();
462462
target.addIllegalOp<AtenFakeQuantizePerTensorAffineOp>();
463+
target.addIllegalOp<AtenFakeQuantizePerTensorAffineCachemaskOp>();
463464
target.addIllegalOp<AtenGluOp>();
464465
target.addIllegalOp<AtenSeluOp>();
465466
target.addIllegalOp<AtenHardswishOp>();

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,7 @@
396396
"ElementwiseRreluTrainStaticModule_basic",
397397
"ElementwiseToDtypeI64ToUI8Module_basic",
398398
"EqIntModule_basic",
399+
"FakeQuantizePerTensorAffineCachemaskModule_basic",
399400
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
400401
"FakeQuantizePerTensorAffineModule_basic",
401402
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",
@@ -1055,6 +1056,7 @@
10551056
"EmptyStridedModule_basic",
10561057
"EqIntModule_basic",
10571058
"ExpandAsIntModule_basic",
1059+
"FakeQuantizePerTensorAffineCachemaskModule_basic",
10581060
"FakeQuantizePerTensorAffineModule_basic",
10591061
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",
10601062
"Fill_TensorFloat64WithFloat32Static_basic",
@@ -2400,6 +2402,7 @@
24002402
"EmptyStridedSizeIntStrideModule_basic",
24012403
"EqIntModule_basic",
24022404
"ExponentialModule_basic",
2405+
"FakeQuantizePerTensorAffineCachemaskModule_basic",
24032406
"FloatImplicitModule_basic",
24042407
"GeFloatIntModule_basic",
24052408
"GeFloatModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,9 @@ def aten〇diagonal〡shape(self: List[int], offset: int = 0, dim1: int = 0, dim
118118
def aten〇fake_quantize_per_tensor_affine〡shape(self: List[int], scale: float, zero_point: int, quant_min: int, quant_max: int) -> List[int]:
119119
return upstream_shape_functions.unary(self)
120120

121+
def aten〇fake_quantize_per_tensor_affine_cachemask〡shape(self: List[int], scale: float, zero_point: int, quant_min: int, quant_max: int) -> Tuple[List[int], List[int]]:
122+
return (upstream_shape_functions.unary(self), upstream_shape_functions.unary(self))
123+
121124
def aten〇sin〡shape(self: List[int]) -> List[int]:
122125
return upstream_shape_functions.unary(self)
123126

@@ -2162,6 +2165,14 @@ def aten〇fake_quantize_per_tensor_affine〡dtype(self_rank_dtype: Tuple[int, i
21622165
assert self_dtype != torch.bfloat16
21632166
return self_dtype
21642167

2168+
# note: fake_quantize_per_tensor_affine doesn't support "meta" device, use "cpu" instead.
2169+
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, tensor_device="cpu", scale=0.1, zero_point=0, quant_min=0, quant_max=255, error_types={torch.complex128, torch.complex64, torch.bfloat16, torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool}))
2170+
def aten〇fake_quantize_per_tensor_affine_cachemask〡dtype(self_rank_dtype: Tuple[int, int], scale: float, zero_point: int, quant_min: int, quant_max: int) -> Tuple[int, int]:
2171+
self_rank, self_dtype = self_rank_dtype
2172+
assert is_float_dtype(self_dtype)
2173+
assert self_dtype != torch.bfloat16
2174+
return (self_rank_dtype[1], torch.bool)
2175+
21652176
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
21662177
def aten〇cosh〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
21672178
self_rank, self_dtype = self_rank_dtype

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,9 @@ def emit_with_mutating_variants(key, **kwargs):
458458
emit(
459459
"aten::fake_quantize_per_tensor_affine : (Tensor, float, int, int, int) -> (Tensor)"
460460
)
461+
emit(
462+
"aten::fake_quantize_per_tensor_affine_cachemask : (Tensor, float, int, int, int) -> (Tensor, Tensor)"
463+
)
461464
emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
462465
emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
463466
emit("aten::mish : (Tensor) -> (Tensor)")

projects/pt1/python/torch_mlir_e2e_test/test_suite/quantized_models.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,28 @@ def get_quantized_mlp():
181181
@register_test_case(module_factory=get_quantized_mlp)
182182
def QuantizedMLP_basic(module, tu: TestUtils):
183183
module.forward(get_quant_model_input())
184+
185+
186+
# ==============================================================================
187+
188+
189+
class FakeQuantizePerTensorAffineCachemaskModule(torch.nn.Module):
190+
def __init__(self):
191+
super().__init__()
192+
193+
@export
194+
@annotate_args(
195+
[
196+
None,
197+
([6, 4], torch.float32, True),
198+
]
199+
)
200+
def forward(self, a):
201+
return torch.ops.aten.fake_quantize_per_tensor_affine_cachemask(
202+
a, 2.0, 0, -128, 127
203+
)[0]
204+
205+
206+
@register_test_case(module_factory=lambda: FakeQuantizePerTensorAffineCachemaskModule())
207+
def FakeQuantizePerTensorAffineCachemaskModule_basic(module, tu: TestUtils):
208+
module.forward(tu.rand(6, 4))

0 commit comments

Comments
 (0)