Skip to content

Commit d4a53f3

Browse files
committed
[mlir] call target materialization more in dialect conversion
During dialect conversion, target materialization is triggered to create cast-like operations when a type mismatch occurs between the value that replaces a rewritten operation and the type that another operations expects as operands processed by the type conversion. First, a dummy cast is inserted to make sure the pattern application can proceed. The decision to trigger the user-provided materialization hook is taken later based on the result of the dummy cast having uses. However, it only has uses if other patterns constructed new operations using the casted value as operand. If existing (legal) operations use the replaced value, they may have not been updated to use the casted value yet. The conversion infra would then delete the dummy cast first, and then would replace the uses with now-invalid (null in the bast case) value. When deciding whether to trigger cast materialization, check for liveness the uses not only of the casted value, but also of all the values that it replaces. This was discovered in the finalizing bufferize pass that cleans up mutually-cancelling casts without touching other operations. It is not impossible that there are other scenarios where the dialect converison infra could produce invalid operand uses because of dummy casts erased too eagerly. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D119937
1 parent dd4dde8 commit d4a53f3

File tree

4 files changed

+87
-0
lines changed

4 files changed

+87
-0
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -2588,6 +2588,11 @@ static void computeNecessaryMaterializations(
25882588
return !necessaryMaterializations.count(matIt->second);
25892589
return rewriterImpl.isOpIgnored(user);
25902590
};
2591+
// This value may be replacing another value that has a live user.
2592+
for (Value inv : inverseMapping.lookup(value))
2593+
if (llvm::find_if_not(inv.getUsers(), findFn) != inv.user_end())
2594+
return true;
2595+
// Or have live users itself.
25912596
return llvm::find_if_not(value.getUsers(), findFn) != value.user_end();
25922597
};
25932598

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: mlir-opt -test-target-materialization-with-no-uses %s | FileCheck %s
2+
3+
// The conversion is set up as follows:
4+
// - type_changer ops are illegal;
5+
// - type_changer ops are replaced with their operands;
6+
// - i16 types are converted to i64 by the type conversion;
7+
// - the rest of the types are legal.
8+
// The first type_changer is replaced with its operand. For the pattern to
9+
// apply to the second type_changer, the conversion infra creates a dummy
10+
// cast operation to cast from the i32 to i64 because the original op takes an
11+
// (illegal) i16 that became i64. This dummy operation should be replaced by
12+
// the one produced by the target materialization hook. At the moment when the
13+
// materialization decision is taken, the i64 replacement of the first type
14+
// change (the result of the dummy cast) has no uses, but the value it replaces
15+
// does, so the infra must call the materialization rather than assume the
16+
// dummy cast to be dead.
17+
18+
// CHECK-LABEL: @foo
19+
func @foo() {
20+
%0 = "test.type_producer"() : () -> i32
21+
// CHECK: test.cast
22+
// CHECK-NOT: test.type_changer
23+
%1 = "test.type_changer"(%0) : (i32) -> i16
24+
%2 = "test.type_changer"(%1) : (i16) -> i64
25+
"test.type_consumer"(%2) : (i64) -> ()
26+
return
27+
}

mlir/test/lib/Dialect/Test/TestOps.td

+2
Original file line numberDiff line numberDiff line change
@@ -1603,6 +1603,8 @@ def TestAnotherTypeProducerOp : TEST_Op<"another_type_producer">,
16031603
Results<(outs AnyType)>;
16041604
def TestTypeConsumerOp : TEST_Op<"type_consumer">,
16051605
Arguments<(ins AnyType)>;
1606+
def TestTypeChangerOp : TEST_Op<"type_changer">,
1607+
Arguments<(ins AnyType)>, Results<(outs AnyType)>;
16061608
def TestValidOp : TEST_Op<"valid", [Terminator]>,
16071609
Arguments<(ins Variadic<AnyType>)>;
16081610

mlir/test/lib/Dialect/Test/TestPatterns.cpp

+53
Original file line numberDiff line numberDiff line change
@@ -1135,6 +1135,58 @@ struct TestTypeConversionDriver
11351135
};
11361136
} // namespace
11371137

1138+
//===----------------------------------------------------------------------===//
1139+
// Test Target Materialization With No Uses
1140+
//===----------------------------------------------------------------------===//
1141+
1142+
namespace {
1143+
struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> {
1144+
using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern;
1145+
1146+
LogicalResult
1147+
matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor,
1148+
ConversionPatternRewriter &rewriter) const final {
1149+
rewriter.replaceOp(op, adaptor.getOperands());
1150+
return success();
1151+
}
1152+
};
1153+
1154+
struct TestTargetMaterializationWithNoUses
1155+
: public PassWrapper<TestTargetMaterializationWithNoUses,
1156+
OperationPass<ModuleOp>> {
1157+
StringRef getArgument() const final {
1158+
return "test-target-materialization-with-no-uses";
1159+
}
1160+
StringRef getDescription() const final {
1161+
return "Test a special case of target materialization in DialectConversion";
1162+
}
1163+
1164+
void runOnOperation() override {
1165+
TypeConverter converter;
1166+
converter.addConversion([](Type t) { return t; });
1167+
converter.addConversion([](IntegerType intTy) -> Type {
1168+
if (intTy.getWidth() == 16)
1169+
return IntegerType::get(intTy.getContext(), 64);
1170+
return intTy;
1171+
});
1172+
converter.addTargetMaterialization(
1173+
[](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
1174+
return builder.create<TestCastOp>(loc, type, inputs).getResult();
1175+
});
1176+
1177+
ConversionTarget target(getContext());
1178+
target.addIllegalOp<TestTypeChangerOp>();
1179+
1180+
RewritePatternSet patterns(&getContext());
1181+
patterns.add<ForwardOperandPattern>(converter, &getContext());
1182+
1183+
if (failed(applyPartialConversion(getOperation(), target,
1184+
std::move(patterns))))
1185+
signalPassFailure();
1186+
}
1187+
};
1188+
} // namespace
1189+
11381190
//===----------------------------------------------------------------------===//
11391191
// Test Block Merging
11401192
//===----------------------------------------------------------------------===//
@@ -1317,6 +1369,7 @@ void registerPatternsTestPass() {
13171369
PassRegistration<TestUnknownRootOpDriver>();
13181370

13191371
PassRegistration<TestTypeConversionDriver>();
1372+
PassRegistration<TestTargetMaterializationWithNoUses>();
13201373

13211374
PassRegistration<TestMergeBlocksPatternDriver>();
13221375
PassRegistration<TestSelectiveReplacementPatternDriver>();

0 commit comments

Comments
 (0)