[MLIR] Fix duplicated attribute nodes in MLIR bytecode deserialization #151267
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository, in which case you can instead tag reviewers by name in a comment by using `@` followed by their GitHub username. If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir
Author: Hank (hankluo6)
Changes
Fixes #150163
MLIR bytecode does not preserve alias definitions, so each attribute encountered during deserialization is treated as a new one. This can generate duplicate `DISubprogram` nodes during deserialization.
The patch adds a `StringMap` cache that records attributes and fetches them when encountered again.
Full diff: https://github.com/llvm/llvm-project/pull/151267.diff
4 Files Affected:
diff --git a/mlir/include/mlir/AsmParser/AsmParser.h b/mlir/include/mlir/AsmParser/AsmParser.h
index 33daf7ca26f49..f39b3bd853a2a 100644
--- a/mlir/include/mlir/AsmParser/AsmParser.h
+++ b/mlir/include/mlir/AsmParser/AsmParser.h
@@ -53,7 +53,8 @@ parseAsmSourceFile(const llvm::SourceMgr &sourceMgr, Block *block,
 /// null terminated.
 Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context,
                          Type type = {}, size_t *numRead = nullptr,
-                         bool isKnownNullTerminated = false);
+                         bool isKnownNullTerminated = false,
+                         llvm::StringMap<Attribute> *attributesCache = nullptr);
 /// This parses a single MLIR type to an MLIR context if it was valid. If not,
 /// an error diagnostic is emitted to the context.
diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp
index 8b14e71118c3a..de8e3c1fc1e72 100644
--- a/mlir/lib/AsmParser/DialectSymbolParser.cpp
+++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp
@@ -245,6 +245,14 @@ static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState,
       return nullptr;
   }
+  if constexpr (std::is_same_v<Symbol, Attribute>) {
+    auto &cache = p.getState().symbols.attributesCache;
+
+    auto cacheIt = cache.find(symbolData);
+    if (cacheIt != cache.end()) {
+      return cacheIt->second;
+    }
+  }
   return createSymbol(dialectName, symbolData, loc);
 }
@@ -337,6 +345,7 @@ Type Parser::parseExtendedType() {
 template <typename T, typename ParserFn>
 static T parseSymbol(StringRef inputStr, MLIRContext *context,
                      size_t *numReadOut, bool isKnownNullTerminated,
+                     llvm::StringMap<Attribute> *attributesCache,
                      ParserFn &&parserFn) {
   // Set the buffer name to the string being parsed, so that it appears in error
   // diagnostics.
@@ -348,6 +357,9 @@ static T parseSymbol(StringRef inputStr, MLIRContext *context,
   SourceMgr sourceMgr;
   sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
   SymbolState aliasState;
+  if (attributesCache)
+    aliasState.attributesCache = *attributesCache;
+
   ParserConfig config(context);
   ParserState state(sourceMgr, config, aliasState, /*asmState=*/nullptr,
                     /*codeCompleteContext=*/nullptr);
@@ -358,6 +370,13 @@ static T parseSymbol(StringRef inputStr, MLIRContext *context,
   if (!symbol)
     return T();
+  if constexpr (std::is_same_v<T, Attribute>) {
+    // Cache key is the symbol data without the dialect prefix.
+    StringRef cacheKey = inputStr.split('.').second;
+    if (attributesCache && !cacheKey.empty()) {
+      (*attributesCache)[cacheKey] = symbol;
+    }
+  }
   // Provide the number of bytes that were read.
   Token endTok = parser.getToken();
   size_t numRead =
@@ -374,13 +393,15 @@ static T parseSymbol(StringRef inputStr, MLIRContext *context,
 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context,
                                Type type, size_t *numRead,
-                               bool isKnownNullTerminated) {
+                               bool isKnownNullTerminated,
+                               llvm::StringMap<Attribute> *attributesCache) {
   return parseSymbol<Attribute>(
-      attrStr, context, numRead, isKnownNullTerminated,
+      attrStr, context, numRead, isKnownNullTerminated, attributesCache,
       [type](Parser &parser) { return parser.parseAttribute(type); });
 }
 Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t *numRead,
                      bool isKnownNullTerminated) {
   return parseSymbol<Type>(typeStr, context, numRead, isKnownNullTerminated,
+                           /*attributesCache=*/nullptr,
                            [](Parser &parser) { return parser.parseType(); });
 }
diff --git a/mlir/lib/AsmParser/ParserState.h b/mlir/lib/AsmParser/ParserState.h
index 159058a18fa4e..aa53032107cbf 100644
--- a/mlir/lib/AsmParser/ParserState.h
+++ b/mlir/lib/AsmParser/ParserState.h
@@ -40,6 +40,9 @@ struct SymbolState {
   /// A map from unique integer identifier to DistinctAttr.
   DenseMap<uint64_t, DistinctAttr> distinctAttributes;
+
+  /// A map from unique string identifier to Attribute.
+  llvm::StringMap<Attribute> attributesCache;
 };
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 44458d010c6c8..0f97443433774 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -895,6 +895,10 @@ class AttrTypeReader {
   SmallVector<AttrEntry> attributes;
   SmallVector<TypeEntry> types;
+  /// The map of cached attributes, used to avoid re-parsing the same
+  /// attribute multiple times.
+  llvm::StringMap<Attribute> attributesCache;
+
   /// A location used for error emission.
   Location fileLoc;
@@ -1235,7 +1239,7 @@ LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
         ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true);
   else
     result = ::parseAttribute(asmStr, context, Type(), &numRead,
-                              /*isKnownNullTerminated=*/true, &attributesCache);
+                              /*isKnownNullTerminated=*/true);
+                              /*isKnownNullTerminated=*/true, &attributesCache);
   if (!result)
     return failure();
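The diff derives its cache key with `inputStr.split('.').second`, which the accompanying comment describes as the symbol data without the dialect prefix. The sketch below illustrates that derivation using only the standard library. `splitFirst` and `cacheKeyFor` are hypothetical helpers, not MLIR or LLVM APIs; `splitFirst` imitates the first-occurrence behavior of `llvm::StringRef::split(char)`, and the example input string is illustrative rather than taken from a real bytecode file.

```cpp
#include <cstddef>
#include <string_view>
#include <utility>

// Hypothetical stand-in for llvm::StringRef::split(char): splits at the FIRST
// occurrence of `sep`. If `sep` is absent, the second half is empty.
std::pair<std::string_view, std::string_view> splitFirst(std::string_view s,
                                                         char sep) {
  std::size_t pos = s.find(sep);
  if (pos == std::string_view::npos)
    return {s, std::string_view()};
  return {s.substr(0, pos), s.substr(pos + 1)};
}

// Hypothetical helper mirroring the key computation in the patch: everything
// after the first '.' in the attribute text.
std::string_view cacheKeyFor(std::string_view attrText) {
  return splitFirst(attrText, '.').second;
}
```

Because the split happens at the first `.`, later dots inside the attribute body (such as the one in a file name) stay in the key. Text without any `.` produces an empty key, which the patch skips via its `!cacheKey.empty()` check.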
Thanks for the fix!
Hi @joker-eph, thanks for reviewing! I've added a test.
#di_subprogram1 = #llvm.di_subprogram<recId = distinct[0]<>, id = distinct[2]<>, compileUnit = #di_compile_unit, scope = #di_file1, name = "main", file = #di_file1, line = 1, scopeLine = 1, subprogramFlags = "Definition|Optimized", type = #di_subroutine_type, retainedNodes = #di_local_variable>
#di_local_variable1 = #llvm.di_local_variable<scope = #di_subprogram1, name = "a", file = #di_file1, line = 2, type = #di_basic_type>

module attributes {dlti.dl_spec = #dlti.dl_spec<i64 = dense<64> : vector<2xi64>, !llvm.ptr = dense<64> : vector<4xi64>>, llvm.ident = "MLIR", llvm.target_triple = "x86_64-unknown-linux-gnu"} {
Can we check this with an MLIR-targeted round-trip test (with minimal attributes to show the discrepancy) instead of involving a translation to LLVM IR and relying on the specifics of DI?
Here is a minimal test I believe:
// RUN: mlir-opt -emit-bytecode %s | mlir-opt --mlir-print-debuginfo | FileCheck %s
// CHECK: llvm.di_subprogram
// CHECK-NOT: llvm.di_subprogram
#di_file = #llvm.di_file<"foo.c" in "/mlir/">
#di_subprogram = #llvm.di_subprogram<recId = distinct[0]<>, isRecSelf = true>
#di_basic_type = #llvm.di_basic_type<tag = DW_TAG_base_type, name = "int", sizeInBits = 32, encoding = DW_ATE_signed>
#di_local_variable = #llvm.di_local_variable<scope = #di_subprogram, name = "a", file = #di_file, line = 2, type = #di_basic_type>
module attributes {test.alias = #di_local_variable} {
}loc(fused<#di_subprogram>[])
However, I'm not sure I understand exactly what triggers the issue. Is this specific to the implementation of the LLVM attributes? Can we reproduce this with one of the test dialect attributes and simplify this further?
> Can we check this with an MLIR-targeted round-trip test (with minimal attributes to show the discrepancy) instead of involving a translation to LLVM IR and relying on the specifics of DI?

The issue is only reproducible during `mlir-translate` from MLIR bytecode to LLVM IR, because that’s the only place where the LLVM IR verifier checks that two `DISubprogram` attributes refer to the same object:
llvm-project/llvm/lib/IR/Verifier.cpp
Lines 6936 to 6945 in 68b9bb5
// The scopes for variables and !dbg attachments must agree.
DISubprogram *VarSP = getSubprogram(Var->getRawScope());
DISubprogram *LocSP = getSubprogram(Loc->getRawScope());
if (!VarSP || !LocSP)
  return; // Broken scope chains are checked elsewhere.
CheckDI(VarSP == LocSP,
        "mismatched subprogram between #dbg record variable and DILocation",
        &DVR, BB, F, Var, Var->getScope()->getSubprogram(), Loc,
        Loc->getScope()->getSubprogram(), BB, F);
`mlir-opt` will create multiple identical attribute objects with the same content, and since it doesn't check whether they're the exact same object, no error is triggered.
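The distinction drawn here, equal content versus the same object, can be shown with a small standalone sketch. `FakeSubprogram`, `sameContent`, and `sameObject` are hypothetical names used only for illustration; they are not LLVM or MLIR types. The pointer comparison plays the role of the `VarSP == LocSP` check quoted above.

```cpp
#include <string>

// Hypothetical stand-in for a debug-info node; not LLVM's DISubprogram.
struct FakeSubprogram {
  std::string name;
  int line;
};

// Content comparison: two separately created nodes with the same fields match.
bool sameContent(const FakeSubprogram &a, const FakeSubprogram &b) {
  return a.name == b.name && a.line == b.line;
}

// Identity comparison: only the very same object matches, analogous to a
// pointer-equality check between two scope nodes.
bool sameObject(const FakeSubprogram *a, const FakeSubprogram *b) {
  return a == b;
}
```

A tool that only inspects content accepts duplicated nodes, while a verifier that compares pointers rejects them, which is consistent with the failure showing up only on the path that runs the LLVM IR verifier.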
> However, I'm not sure I understand exactly what triggers the issue. Is this specific to the implementation of the LLVM attributes? Can we reproduce this with one of the test dialect attributes and simplify this further?

This is a potential issue for attributes in general: when parsing MLIR bytecode, the parser creates separate attribute instances even if they are logically equal. That usually isn't a problem if we care only about the content, but when we expect the objects to be the same, as in the case above, it can lead to issues. We can't reproduce it with other dialects using `mlir-translate` or `mlir-opt` since there is no such check.
Since the problem only happens in `mlir-translate` when translating to LLVM IR, I’m not sure what the test should be. Should I put it under `mlir/test/mlir-translate` instead? Thanks!
> The issue is only reproducible during `mlir-translate` from MLIR bytecode to LLVM IR

I'm slightly confused by your claim here, because I provided above a test that fails before your patch, and passes afterward, using only `mlir-opt`.
You’re right, I misunderstood earlier. I thought we needed to trigger an assertion in the test.
We can test this with an `mlir-opt` round trip. I'll update the minimal test.
> However, I'm not sure I understand exactly what triggers the issue. Is this specific to the implementation of the LLVM attributes? Can we reproduce this with one of the test dialect attributes and simplify this further?

Do you understand why we need a location to trigger the issue here?
@gysit for help :)
Trying to understand but I am not fully there yet.
So the IR above seems not correct since there should always be a "normal" subprogram before the self recursive reference when walking from the location down the attribute tree.
However, it seems more like the distinct attribute is the issue here? They are indeed special in the sense that creating a distinct attribute every time produces a new attribute. In the reproducer there is only one of them though, which again makes me wonder if this can be the problem.
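One way to picture the hypothesis about distinct attributes is the toy model below. It assumes, based only on the diff's `SymbolState aliasState;` being constructed inside `parseSymbol` and on the `distinctAttributes` map keyed by integer id, that each separately parsed textual entry starts with an empty id-to-node map. That is an unconfirmed reading of the mechanism, and `FakeDistinct` and `FakeSymbolState` are hypothetical names, not MLIR classes.

```cpp
#include <cstdint>
#include <unordered_map>

// Hypothetical model of a distinct attribute: its identity is a fresh serial
// number handed out at creation time.
struct FakeDistinct {
  std::uint64_t serial;
};

// Hypothetical per-parse state, loosely modeled on the id-to-DistinctAttr map
// in SymbolState. The serial counter is shared so that every newly created
// node gets a new identity.
class FakeSymbolState {
public:
  explicit FakeSymbolState(std::uint64_t &nextSerial)
      : nextSerial(nextSerial) {}

  // Resolves `distinct[id]`: the first use creates a new node, later uses
  // within the same state return the node created earlier.
  FakeDistinct resolve(std::uint64_t id) {
    auto it = byId.find(id);
    if (it != byId.end())
      return it->second;
    FakeDistinct created{nextSerial++};
    byId.emplace(id, created);
    return created;
  }

private:
  std::uint64_t &nextSerial;
  std::unordered_map<std::uint64_t, FakeDistinct> byId;
};
```

In this model, resolving `distinct[0]` twice in one state yields the same node, but resolving it in two independent states yields two nodes. If the same textual attribute is parsed in two separate calls, that would be enough to produce two subprograms that print identically yet are different objects.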
> However, it seems more like the distinct attribute is the issue here?

Yes, somehow. Ideally I'd like to see it isolated in the test dialect.
Could we attach the same distinct attribute to multiple test ops and then check that they round-trip correctly? Or does it only work if they show up in an alias?
Something like:
test_op attributes { distinct[0]<>, distinct[1]<>}
test_op attributes { distinct[1]<>, distinct[0]<>}
Yeah, it seems from the minimal test that an alias that is referenced from two places is necessary. I am not aware of attributes outside of the LLVM dialect that have a distinct attribute and that print as an alias. There would be others in the LLVM dialect, such as the alias_scope attribute, that could be used.
Fixes #150163
MLIR bytecode does not preserve alias definitions, so each attribute encountered during deserialization is treated as a new one. This can generate duplicate `DISubprogram` nodes during deserialization.
The patch adds a `StringMap` cache that records attributes and fetches them when encountered again.
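The caching idea in the description can be sketched without any MLIR dependencies. The block below is a minimal illustration under stated assumptions, not the patch's implementation: `FakeAttr` and `FakeAttrReader` are hypothetical, `std::unordered_map` stands in for `llvm::StringMap`, and the model assumes every uncached parse allocates a fresh node.

```cpp
#include <cstddef>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for a parsed attribute; not MLIR's Attribute class.
struct FakeAttr {
  std::string text;
};

// Hypothetical reader that turns textual entries into attribute nodes.
class FakeAttrReader {
public:
  // Returns the node for `text`. A repeated `text` returns the node created
  // the first time instead of allocating a duplicate.
  const FakeAttr *parse(const std::string &text) {
    auto it = cache.find(text);
    if (it != cache.end())
      return it->second;
    storage.push_back(std::make_unique<FakeAttr>(FakeAttr{text}));
    const FakeAttr *created = storage.back().get();
    cache.emplace(text, created);
    return created;
  }

  // Number of nodes actually allocated so far.
  std::size_t numCreated() const { return storage.size(); }

private:
  std::vector<std::unique_ptr<FakeAttr>> storage;
  std::unordered_map<std::string, const FakeAttr *> cache;
};
```

With the cache in place, two references to the same text resolve to one object, so a later identity check between them succeeds.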