Skip to content

Commit

Permalink
[MLIR][mlir-link] Make the linkage-a test pass (#18)
Browse files Browse the repository at this point in the history
Ensuring the previously passing tests still do. This introduces
the ValueMapper concept in a customized format for mlir-linking.
  • Loading branch information
hbrodin authored Feb 12, 2025
1 parent 2bf5173 commit 65939ae
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 42 deletions.
215 changes: 187 additions & 28 deletions mlir/lib/Linker/IRMover.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,123 @@ using llvm::Expected;

namespace {

// TODO: Should this exist?
enum RemapFlags {
RF_None = 0,
RF_NullMapMissingGlobalValues = 8,
};

inline RemapFlags operator|(RemapFlags LHS, RemapFlags RHS) {
return RemapFlags(unsigned(LHS) | unsigned(RHS));
}

// NOTE: This is a simplified version of the LLVM IR one.
template <typename T>
class ValueMapper {
public:
// TODO: Consider creating an interface for the materializer
ValueMapper(IRMapping &valueMap, T &materializer)
: valueMap(valueMap), materializer(materializer) {}

// TODO: Maybe this should be called remapValue?
void scheduleRemapFunction(Operation *v) { worklist.push_back(v); }

Operation *mapSymbol(StringRef sym) {
// TODO: This is special as we only have a symbol ref and need to get a
// value currently implemented by asking the materializer. Might not be
// exactly how it shall be done, but works for now.
if (auto op = materializer.getSourceOperation(sym))
return mapValue(op);

return nullptr;
}

Operation *mapValue(Operation *v) {
Flusher f(*this);

// If the value already exists in the map, use it.
if (auto op = valueMap.lookupOrNull(v)) {
return op;
}

// If we have a materializer and it can materialize a value, use that.
if (auto newv = materializer.materialize(v, false)) {
valueMap.map(v, newv);
return newv;
}

// Global values do not need to be seeded into the VM if they
// are using the identity mapping.
if (auto gvl = dyn_cast<GlobalValueLinkageOpInterface>(v)) {
if (flags & RF_NullMapMissingGlobalValues)
return nullptr;
valueMap.map(v, v);
return v;
}
// TODO: potentially only value mapping

// TODO: Inline asm

// TODO: MetadataAsValue??

// TODO: Constants etc.
assert(false);
};

bool hasWorkToDo() const { return !worklist.empty(); }

void addFlags(RemapFlags additionalFlags) {
assert(!hasWorkToDo() && "Expected to have flushed the worklist");
flags = flags | additionalFlags;
}

void remapFunction(Operation *f) {
// TODO: Remap operands

f->walk([&](Operation *op) { remapInstruction(op); });

// Remap the metadata attachments.
// TODO: Do we need to do this?
}

void remapInstruction(Operation *op) {
Flusher f(*this);

// TODO: Operands?

if (auto symuser = dyn_cast<SymbolUserOpInterface>(op)) {
if (auto sym = symuser.getUserSymbol()) {
mapSymbol(*sym);
}
}
}

void flush() {
// Flush out the worklist of global values.
while (!worklist.empty()) {
auto e = worklist.pop_back_val();

// TODO: for now, we only handle functions
remapFunction(e);
}
}

private:
struct Flusher {
ValueMapper<T> &vm;
Flusher(ValueMapper<T> &vm) : vm(vm) {}
~Flusher() { vm.flush(); }
};

IRMapping &valueMap;
T &materializer;
RemapFlags flags;

// TODO: The worklist is just a function linkage opfor now.
// Once we add global value init we will need to extend it.
SmallVector<Operation *, 4> worklist;
};

class MLIRLinker {
Operation *composite;
OwningOpRef<Operation *> src;
Expand All @@ -29,12 +146,17 @@ class MLIRLinker {
// the equivalent in mlir.
IRMapping valueMap;

ValueMapper<MLIRLinker> mapper;

DenseSet<GlobalValueLinkageOpInterface> valuesToLink;
std::vector<GlobalValueLinkageOpInterface> worklist;
// Replace-all-uses-with worklist
std::vector<std::pair<Operation *, Operation *>> rauwWorklist;

bool doneLinkingBodies;
// NOTE: This is the ValueMapper flush
void flush();

bool doneLinkingBodies{false};

void maybeAdd(GlobalValueLinkageOpInterface val) {
if (valuesToLink.insert(val).second)
Expand Down Expand Up @@ -99,17 +221,42 @@ class MLIRLinker {
return dgv;
}

void insertUnique(Operation *op, Operation *dst) {
// LLVM does global value renaming automatically. This is a workaround to
// ensure we only insert unique values.
bool needsRename = false;
if (auto gv = dyn_cast<SymbolOpInterface>(op)) {
auto name = gv.getName();
SymbolTable syms(dst);
if (syms.lookup(name)) {
(void)syms.renameToUnique(op, {});
}
}

OpBuilder b(dst->getRegion(0));
b.insert(op);
}

public:
MLIRLinker(Operation *composite, OwningOpRef<Operation *> srcOp,
ArrayRef<GlobalValueLinkageOpInterface> valuesToLink)
: composite{composite}, src{std::move(srcOp)} {
: composite{composite}, src{std::move(srcOp)}, mapper{valueMap, *this} {
for (GlobalValueLinkageOpInterface gvl : valuesToLink)
maybeAdd(gvl);
}

// TODO: Helper function for the materializer to convert a symbol to a
// linkable value
GlobalValueLinkageOpInterface getSourceOperation(StringRef sym) {
SymbolTable syms(*src);
if (auto op = syms.lookup(sym))
return dyn_cast<GlobalValueLinkageOpInterface>(op);
return nullptr;
}

Error run();

Operation *materialize(GlobalValueLinkageOpInterface v,
bool forIndirectSymbol);
Operation *materialize(Operation *v, bool forIndirectSymbol);
};

bool MLIRLinker::shouldLink(Operation *dst, Operation *src) {
Expand All @@ -127,7 +274,7 @@ bool MLIRLinker::shouldLink(Operation *dst, Operation *src) {
return false;
}

if (sgv.isDeclarationForLinkage()) // TODO: DoneLinkingBodies??
if (sgv.isDeclarationForLinkage() || doneLinkingBodies)
return false;

// Callback to the client to give a chance to lazily add the Global to the
Expand All @@ -146,19 +293,17 @@ Operation *MLIRLinker::copyGlobalVariableProto(Operation *src) {
// No linking to be performed or linking from the source: simply create an
// identical version of the symbol over in the dest module... the
// initializer will be filled in later by LinkGlobalInits.
OpBuilder builder(composite->getRegion(0));
// OpBuilder builder(composite->getRegion(0));
Operation *newFunc = src->cloneWithoutRegions();
builder.insert(newFunc);
insertUnique(newFunc, composite);
return newFunc;
}

Operation *MLIRLinker::copyFunctionProto(Operation *src) {
OpBuilder builder(composite->getRegion(0));

// Clone the operation (without regions to ensure it becomes empty as is
// considered a decl)
Operation *newFunc = src->cloneWithoutRegions();
builder.insert(newFunc);
insertUnique(newFunc, composite);
return newFunc;
}
Operation *MLIRLinker::copyGlobalValueProto(Operation *src,
Expand Down Expand Up @@ -202,7 +347,11 @@ MLIRLinker::linkGlobalValueProto(GlobalValueLinkageOpInterface sgv,
if (dgv && !shouldLinkOps) {
newDst = dgv.getOperation();
} else {
// TODO: Done linking bodies?
// If we are done linking global value bodies (i.e. we are performing
// metadata linking), don't link in the global value due to this
// reference, simply map it to null.
if (doneLinkingBodies)
return nullptr;

newDst = copyGlobalValueProto(sgv.getOperation(),
shouldLinkOps); // TODO: || ForIndirectSymbol?
Expand Down Expand Up @@ -233,22 +382,23 @@ MLIRLinker::linkGlobalValueProto(GlobalValueLinkageOpInterface sgv,
return newDst;
}

Operation *MLIRLinker::materialize(GlobalValueLinkageOpInterface v,
bool forIndirectSymbol) {
Operation *op = v.getOperation();
Operation *MLIRLinker::materialize(Operation *v, bool forIndirectSymbol) {
auto sgv = dyn_cast<GlobalValueLinkageOpInterface>(v);
if (!sgv)
return nullptr;

// If v is from dest, it was already materialized when dest was loaded.
if (op->getParentOp() == composite)
if (v->getParentOp() == composite)
return nullptr;

// When linking a global from other modules than source & dest, skip
// materializing it because it would be mapped later when its containing
// module is linked. Linking it now would potentially pull in many types that
// may not be mapped properly.
if (op->getParentOp() != src.get())
if (v->getParentOp() != src.get())
return nullptr;

auto newProto = linkGlobalValueProto(v, false);
auto newProto = linkGlobalValueProto(sgv, false);
if (!newProto) {
setError(newProto.takeError());
return nullptr;
Expand All @@ -267,6 +417,10 @@ Operation *MLIRLinker::materialize(GlobalValueLinkageOpInterface v,
if (!f.isDeclarationForLinkage()) {
return *newProto;
}
} else if (auto var = dyn_cast<GlobalVariableLinkageOpInterface>(
newGvl.getOperation())) {
if (!var.isDeclarationForLinkage() || var.hasAppendingLinkage())
return *newProto;
}
// TODO: Lots of if cases for Function, global variable, global alias.
// for now, just check if it is a declaration, if so, not much more to do.
Expand All @@ -281,8 +435,8 @@ Operation *MLIRLinker::materialize(GlobalValueLinkageOpInterface v,
// new definition for the indirect symbol ("New" will be different).
// TODO: Some indirect symbol thing

if (forIndirectSymbol || shouldLink(newGvl.getOperation(), v.getOperation()))
setError(linkGlobalValueBody(newGvl.getOperation(), v));
if (forIndirectSymbol || shouldLink(newGvl.getOperation(), v))
setError(linkGlobalValueBody(newGvl.getOperation(), sgv));

// TODO: Update attributes
return newGvl.getOperation();
Expand All @@ -292,6 +446,7 @@ Operation *MLIRLinker::materialize(GlobalValueLinkageOpInterface v,
/// referenced are in Dest.
void MLIRLinker::linkGlobalVariable(Operation *dst,
GlobalVariableLinkageOpInterface src) {
// Figure out what the initializer looks like in the dest module.
// TODO: Schedule global init
// TODO: This will likely only need to happen for those that have an
// initializer, not for constants
Expand Down Expand Up @@ -319,6 +474,7 @@ Error MLIRLinker::linkFunctionBody(Operation *dst,

// TODO: several steps here, copy metadata, steal arg list and schedule
// remapfunction. What is needed?
mapper.scheduleRemapFunction(dst);

// auto target = src.getOperation()->clone(mapping);
// dst->replaceAllUsesWith(target);
Expand All @@ -340,10 +496,12 @@ Error MLIRLinker::linkGlobalValueBody(Operation *dst,
}

void MLIRLinker::flushRAUWorklist() {
SymbolTable syms(composite);
for (const auto &elem : rauwWorklist) {
Operation *oldOp, *newOp;
std::tie(oldOp, newOp) = elem;
oldOp->replaceAllUsesWith(newOp);
if (auto sym = dyn_cast<SymbolOpInterface>(newOp))
syms.replaceAllSymbolUses(oldOp, sym.getNameAttr(), composite);
oldOp->erase();
}
rauwWorklist.clear();
Expand Down Expand Up @@ -372,24 +530,25 @@ Error MLIRLinker::run() {

assert(!gvl.isDeclarationForLinkage());

// TODO: Is this the equivalent of Mapper.mapValue?
auto newGvl = materialize(gvl, false);
valueMap.map(gvl.getOperation(), newGvl);
mapper.mapValue(gvl);

if (foundError)
return std::move(*foundError);
flushRAUWorklist();
}

doneLinkingBodies = true;
mapper.addFlags(RF_NullMapMissingGlobalValues);

// Reorder the globals just added to the destination module to match their
// original order in the source module.
src->walk([&](GlobalValueLinkageOpInterface gv) {
src->walk([&](GlobalVariableLinkageOpInterface gv) {
if (gv.hasAppendingLinkage())
return WalkResult::skip();
if (auto op = valueMap.lookupOrNull(gv.getOperation())) {
if (auto newValue = dyn_cast<GlobalValueLinkageOpInterface>(op)) {
newValue->remove();
composite->getRegion(0).back().push_back(newValue);
if (auto newValue = mapper.mapValue(gv)) {
if (auto newGv = dyn_cast<GlobalVariableLinkageOpInterface>(newValue)) {
newGv->remove();
composite->getRegion(0).back().push_back(newGv);
}
}

Expand Down
6 changes: 3 additions & 3 deletions mlir/test/mlir-link/functions-reverse.mlir
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// RUN: mlir-link -split-input-file %s | FileCheck %s

// CHECK: llvm.func @f1() {
// CHECK: llvm.func @f2() {
// CHECK-NEXT: llvm.call @f1() : () -> ()
// CHECK-NEXT: llvm.return
// CHECK-NEXT: }
// CHECK-NEXT: llvm.func @f2() {
// CHECK-NEXT: llvm.call @f1() : () -> ()
// CHECK: llvm.func @f1() {
// CHECK-NEXT: llvm.return
// CHECK-NEXT: }

Expand Down
6 changes: 3 additions & 3 deletions mlir/test/mlir-link/functions.mlir
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// RUN: mlir-link -split-input-file %s | FileCheck %s

// CHECK: llvm.func @f2() {
// CHECK-NEXT: llvm.call @f1() : () -> ()
// CHECK: llvm.func @f1() {
// CHECK-NEXT: llvm.return
// CHECK-NEXT: }
// CHECK-NEXT: llvm.func @f1() {
// CHECK: llvm.func @f2() {
// CHECK-NEXT: llvm.call @f1() : () -> ()
// CHECK-NEXT: llvm.return
// CHECK-NEXT: }

Expand Down
6 changes: 4 additions & 2 deletions mlir/test/mlir-link/linkage-a.mlir
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
// RUN: mlir-link -split-input-file %s | FileCheck %s

// CHECK: llvm.mlir.global external @X() {addr_space = 0 : i32} : i32
// CHECK: llvm.func @foo() -> i32

// CHECK: llvm.func @bar()

// CHECK: llvm.func @foo() -> i32
// CHECK: llvm.mlir.global external @X() {addr_space = 0 : i32} : i32



llvm.mlir.global linkonce @X(5 : i32) {addr_space = 0 : i32} : i32
llvm.func linkonce @foo() -> i32 {
Expand Down
1 change: 1 addition & 0 deletions mlir/test/mlir-link/single-global-usage.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ llvm.mlir.global @number(7 : i32) : i32


// -----

llvm.mlir.global @number() : i32

Loading

0 comments on commit 65939ae

Please sign in to comment.