Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][mlir-link] Make the linkage-a test pass #18

Merged
merged 1 commit into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 187 additions & 28 deletions mlir/lib/Linker/IRMover.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,123 @@ using llvm::Expected;

namespace {

// TODO: Should this exist?
enum RemapFlags {
RF_None = 0,
RF_NullMapMissingGlobalValues = 8,
};

inline RemapFlags operator|(RemapFlags LHS, RemapFlags RHS) {
return RemapFlags(unsigned(LHS) | unsigned(RHS));
}

// NOTE: This is a simplified version of the LLVM IR one.
template <typename T>
class ValueMapper {
public:
// TODO: Consider creating an interface for the materializer
ValueMapper(IRMapping &valueMap, T &materializer)
: valueMap(valueMap), materializer(materializer) {}

// TODO: Maybe this should be called remapValue?
void scheduleRemapFunction(Operation *v) { worklist.push_back(v); }

Operation *mapSymbol(StringRef sym) {
// TODO: This is special as we only have a symbol ref and need to get a
// value currently implemented by asking the materializer. Might not be
// exactly how it shall be done, but works for now.
if (auto op = materializer.getSourceOperation(sym))
return mapValue(op);

return nullptr;
}

Operation *mapValue(Operation *v) {
Flusher f(*this);

// If the value already exists in the map, use it.
if (auto op = valueMap.lookupOrNull(v)) {
return op;
}

// If we have a materializer and it can materialize a value, use that.
if (auto newv = materializer.materialize(v, false)) {
valueMap.map(v, newv);
return newv;
}

// Global values do not need to be seeded into the VM if they
// are using the identity mapping.
if (auto gvl = dyn_cast<GlobalValueLinkageOpInterface>(v)) {
if (flags & RF_NullMapMissingGlobalValues)
return nullptr;
valueMap.map(v, v);
return v;
}
// TODO: potentially only value mapping

// TODO: Inline asm

// TODO: MetadataAsValue??

// TODO: Constants etc.
assert(false);
};

bool hasWorkToDo() const { return !worklist.empty(); }

void addFlags(RemapFlags additionalFlags) {
assert(!hasWorkToDo() && "Expected to have flushed the worklist");
flags = flags | additionalFlags;
}

void remapFunction(Operation *f) {
// TODO: Remap operands

f->walk([&](Operation *op) { remapInstruction(op); });

// Remap the metadata attachments.
// TODO: Do we need to do this?
}

void remapInstruction(Operation *op) {
Flusher f(*this);

// TODO: Operands?

if (auto symuser = dyn_cast<SymbolUserOpInterface>(op)) {
if (auto sym = symuser.getUserSymbol()) {
mapSymbol(*sym);
}
}
}

void flush() {
// Flush out the worklist of global values.
while (!worklist.empty()) {
auto e = worklist.pop_back_val();

// TODO: for now, we only handle functions
remapFunction(e);
}
}

private:
struct Flusher {
ValueMapper<T> &vm;
Flusher(ValueMapper<T> &vm) : vm(vm) {}
~Flusher() { vm.flush(); }
};

IRMapping &valueMap;
T &materializer;
RemapFlags flags;

// TODO: The worklist is just a function linkage opfor now.
// Once we add global value init we will need to extend it.
SmallVector<Operation *, 4> worklist;
};

class MLIRLinker {
Operation *composite;
OwningOpRef<Operation *> src;
Expand All @@ -29,12 +146,17 @@ class MLIRLinker {
// the equivalent in mlir.
IRMapping valueMap;

ValueMapper<MLIRLinker> mapper;

DenseSet<GlobalValueLinkageOpInterface> valuesToLink;
std::vector<GlobalValueLinkageOpInterface> worklist;
// Replace-all-uses-with worklist
std::vector<std::pair<Operation *, Operation *>> rauwWorklist;

bool doneLinkingBodies;
// NOTE: This is the ValueMapper flush
void flush();

bool doneLinkingBodies{false};

void maybeAdd(GlobalValueLinkageOpInterface val) {
if (valuesToLink.insert(val).second)
Expand Down Expand Up @@ -99,17 +221,42 @@ class MLIRLinker {
return dgv;
}

void insertUnique(Operation *op, Operation *dst) {
// LLVM does global value renaming automatically. This is a workaround to
// ensure we only insert unique values.
bool needsRename = false;
if (auto gv = dyn_cast<SymbolOpInterface>(op)) {
auto name = gv.getName();
SymbolTable syms(dst);
if (syms.lookup(name)) {
(void)syms.renameToUnique(op, {});
}
}

OpBuilder b(dst->getRegion(0));
b.insert(op);
}

public:
MLIRLinker(Operation *composite, OwningOpRef<Operation *> srcOp,
ArrayRef<GlobalValueLinkageOpInterface> valuesToLink)
: composite{composite}, src{std::move(srcOp)} {
: composite{composite}, src{std::move(srcOp)}, mapper{valueMap, *this} {
for (GlobalValueLinkageOpInterface gvl : valuesToLink)
maybeAdd(gvl);
}

// TODO: Helper function for the materializer to convert a symbol to a
// linkable value
GlobalValueLinkageOpInterface getSourceOperation(StringRef sym) {
SymbolTable syms(*src);
if (auto op = syms.lookup(sym))
return dyn_cast<GlobalValueLinkageOpInterface>(op);
return nullptr;
}

Error run();

Operation *materialize(GlobalValueLinkageOpInterface v,
bool forIndirectSymbol);
Operation *materialize(Operation *v, bool forIndirectSymbol);
};

bool MLIRLinker::shouldLink(Operation *dst, Operation *src) {
Expand All @@ -127,7 +274,7 @@ bool MLIRLinker::shouldLink(Operation *dst, Operation *src) {
return false;
}

if (sgv.isDeclarationForLinkage()) // TODO: DoneLinkingBodies??
if (sgv.isDeclarationForLinkage() || doneLinkingBodies)
return false;

// Callback to the client to give a chance to lazily add the Global to the
Expand All @@ -146,19 +293,17 @@ Operation *MLIRLinker::copyGlobalVariableProto(Operation *src) {
// No linking to be performed or linking from the source: simply create an
// identical version of the symbol over in the dest module... the
// initializer will be filled in later by LinkGlobalInits.
OpBuilder builder(composite->getRegion(0));
// OpBuilder builder(composite->getRegion(0));
Operation *newFunc = src->cloneWithoutRegions();
builder.insert(newFunc);
insertUnique(newFunc, composite);
return newFunc;
}

Operation *MLIRLinker::copyFunctionProto(Operation *src) {
OpBuilder builder(composite->getRegion(0));

// Clone the operation (without regions to ensure it becomes empty as is
// considered a decl)
Operation *newFunc = src->cloneWithoutRegions();
builder.insert(newFunc);
insertUnique(newFunc, composite);
return newFunc;
}
Operation *MLIRLinker::copyGlobalValueProto(Operation *src,
Expand Down Expand Up @@ -202,7 +347,11 @@ MLIRLinker::linkGlobalValueProto(GlobalValueLinkageOpInterface sgv,
if (dgv && !shouldLinkOps) {
newDst = dgv.getOperation();
} else {
// TODO: Done linking bodies?
// If we are done linking global value bodies (i.e. we are performing
// metadata linking), don't link in the global value due to this
// reference, simply map it to null.
if (doneLinkingBodies)
return nullptr;

newDst = copyGlobalValueProto(sgv.getOperation(),
shouldLinkOps); // TODO: || ForIndirectSymbol?
Expand Down Expand Up @@ -233,22 +382,23 @@ MLIRLinker::linkGlobalValueProto(GlobalValueLinkageOpInterface sgv,
return newDst;
}

Operation *MLIRLinker::materialize(GlobalValueLinkageOpInterface v,
bool forIndirectSymbol) {
Operation *op = v.getOperation();
Operation *MLIRLinker::materialize(Operation *v, bool forIndirectSymbol) {
auto sgv = dyn_cast<GlobalValueLinkageOpInterface>(v);
if (!sgv)
return nullptr;

// If v is from dest, it was already materialized when dest was loaded.
if (op->getParentOp() == composite)
if (v->getParentOp() == composite)
return nullptr;

// When linking a global from other modules than source & dest, skip
// materializing it because it would be mapped later when its containing
// module is linked. Linking it now would potentially pull in many types that
// may not be mapped properly.
if (op->getParentOp() != src.get())
if (v->getParentOp() != src.get())
return nullptr;

auto newProto = linkGlobalValueProto(v, false);
auto newProto = linkGlobalValueProto(sgv, false);
if (!newProto) {
setError(newProto.takeError());
return nullptr;
Expand All @@ -267,6 +417,10 @@ Operation *MLIRLinker::materialize(GlobalValueLinkageOpInterface v,
if (!f.isDeclarationForLinkage()) {
return *newProto;
}
} else if (auto var = dyn_cast<GlobalVariableLinkageOpInterface>(
newGvl.getOperation())) {
if (!var.isDeclarationForLinkage() || var.hasAppendingLinkage())
return *newProto;
}
// TODO: Lots of if cases for Function, global variable, global alias.
// for now, just check if it is a declaration, if so, not much more to do.
Expand All @@ -281,8 +435,8 @@ Operation *MLIRLinker::materialize(GlobalValueLinkageOpInterface v,
// new definition for the indirect symbol ("New" will be different).
// TODO: Some indirect symbol thing

if (forIndirectSymbol || shouldLink(newGvl.getOperation(), v.getOperation()))
setError(linkGlobalValueBody(newGvl.getOperation(), v));
if (forIndirectSymbol || shouldLink(newGvl.getOperation(), v))
setError(linkGlobalValueBody(newGvl.getOperation(), sgv));

// TODO: Update attributes
return newGvl.getOperation();
Expand All @@ -292,6 +446,7 @@ Operation *MLIRLinker::materialize(GlobalValueLinkageOpInterface v,
/// referenced are in Dest.
void MLIRLinker::linkGlobalVariable(Operation *dst,
GlobalVariableLinkageOpInterface src) {
// Figure out what the initializer looks like in the dest module.
// TODO: Schedule global init
// TODO: This will likely only need to happen for those that have an
// initializer, not for constants
Expand Down Expand Up @@ -319,6 +474,7 @@ Error MLIRLinker::linkFunctionBody(Operation *dst,

// TODO: several steps here, copy metadata, steal arg list and schedule
// remapfunction. What is needed?
mapper.scheduleRemapFunction(dst);

// auto target = src.getOperation()->clone(mapping);
// dst->replaceAllUsesWith(target);
Expand All @@ -340,10 +496,12 @@ Error MLIRLinker::linkGlobalValueBody(Operation *dst,
}

void MLIRLinker::flushRAUWorklist() {
SymbolTable syms(composite);
for (const auto &elem : rauwWorklist) {
Operation *oldOp, *newOp;
std::tie(oldOp, newOp) = elem;
oldOp->replaceAllUsesWith(newOp);
if (auto sym = dyn_cast<SymbolOpInterface>(newOp))
syms.replaceAllSymbolUses(oldOp, sym.getNameAttr(), composite);
oldOp->erase();
}
rauwWorklist.clear();
Expand Down Expand Up @@ -372,24 +530,25 @@ Error MLIRLinker::run() {

assert(!gvl.isDeclarationForLinkage());

// TODO: Is this the equivalent of Mapper.mapValue?
auto newGvl = materialize(gvl, false);
valueMap.map(gvl.getOperation(), newGvl);
mapper.mapValue(gvl);

if (foundError)
return std::move(*foundError);
flushRAUWorklist();
}

doneLinkingBodies = true;
mapper.addFlags(RF_NullMapMissingGlobalValues);

// Reorder the globals just added to the destination module to match their
// original order in the source module.
src->walk([&](GlobalValueLinkageOpInterface gv) {
src->walk([&](GlobalVariableLinkageOpInterface gv) {
if (gv.hasAppendingLinkage())
return WalkResult::skip();
if (auto op = valueMap.lookupOrNull(gv.getOperation())) {
if (auto newValue = dyn_cast<GlobalValueLinkageOpInterface>(op)) {
newValue->remove();
composite->getRegion(0).back().push_back(newValue);
if (auto newValue = mapper.mapValue(gv)) {
if (auto newGv = dyn_cast<GlobalVariableLinkageOpInterface>(newValue)) {
newGv->remove();
composite->getRegion(0).back().push_back(newGv);
}
}

Expand Down
6 changes: 3 additions & 3 deletions mlir/test/mlir-link/functions-reverse.mlir
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// RUN: mlir-link -split-input-file %s | FileCheck %s

// CHECK: llvm.func @f1() {
// CHECK: llvm.func @f2() {
// CHECK-NEXT: llvm.call @f1() : () -> ()
// CHECK-NEXT: llvm.return
// CHECK-NEXT: }
// CHECK-NEXT: llvm.func @f2() {
// CHECK-NEXT: llvm.call @f1() : () -> ()
// CHECK: llvm.func @f1() {
// CHECK-NEXT: llvm.return
// CHECK-NEXT: }

Expand Down
6 changes: 3 additions & 3 deletions mlir/test/mlir-link/functions.mlir
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// RUN: mlir-link -split-input-file %s | FileCheck %s

// CHECK: llvm.func @f2() {
// CHECK-NEXT: llvm.call @f1() : () -> ()
// CHECK: llvm.func @f1() {
// CHECK-NEXT: llvm.return
// CHECK-NEXT: }
// CHECK-NEXT: llvm.func @f1() {
// CHECK: llvm.func @f2() {
// CHECK-NEXT: llvm.call @f1() : () -> ()
// CHECK-NEXT: llvm.return
// CHECK-NEXT: }

Expand Down
6 changes: 4 additions & 2 deletions mlir/test/mlir-link/linkage-a.mlir
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
// RUN: mlir-link -split-input-file %s | FileCheck %s

// CHECK: llvm.mlir.global external @X() {addr_space = 0 : i32} : i32
// CHECK: llvm.func @foo() -> i32

// CHECK: llvm.func @bar()

// CHECK: llvm.func @foo() -> i32
// CHECK: llvm.mlir.global external @X() {addr_space = 0 : i32} : i32



llvm.mlir.global linkonce @X(5 : i32) {addr_space = 0 : i32} : i32
llvm.func linkonce @foo() -> i32 {
Expand Down
1 change: 1 addition & 0 deletions mlir/test/mlir-link/single-global-usage.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ llvm.mlir.global @number(7 : i32) : i32


// -----

llvm.mlir.global @number() : i32

Loading