Skip to content

Commit

Permalink
Add WriteAfterWriteElimination pass (#2572)
Browse files Browse the repository at this point in the history
* Add RemoveUselessStores pass

Signed-off-by: Anna Gringauze <[email protected]>

* Address some CR comments

Signed-off-by: Anna Gringauze <[email protected]>

* Address CR comments

Signed-off-by: Anna Gringauze <[email protected]>

---------

Signed-off-by: Anna Gringauze <[email protected]>
  • Loading branch information
annagrin authored Feb 4, 2025
1 parent 10933d5 commit 0a8c67e
Show file tree
Hide file tree
Showing 4 changed files with 266 additions and 0 deletions.
25 changes: 25 additions & 0 deletions include/cudaq/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1038,4 +1038,29 @@ def UpdateRegisterNames : Pass<"update-register-names"> {
}];
}

// Declares the `write-after-write-elimination` pass. The `summary` and
// `description` strings below are user-facing help text, so they are kept
// verbatim.
// NOTE(review): no anchor operation is given to `Pass<...>`, so this is
// presumably an operation-agnostic pass -- confirm against the MLIR PassBase
// definition used by this project.
def WriteAfterWriteElimination : Pass<"write-after-write-elimination"> {
let summary = "Remove stores that are overridden by subsequent store";
let description = [{
Remove stores to a location on the stack that have a subsequent store
to the same location without a use between them:

Example:
```mlir
%1 = cc.alloca !cc.array<i64 x 1>
%2 = cc.cast %1 : (!cc.ptr<!cc.array<i64 x 1>>) -> !cc.ptr<i64>
cc.store %c0_i64, %2 : !cc.ptr<i64>
// nothing using %2 until the next instruction
cc.store %c1_i64, %2 : !cc.ptr<i64>
```

would be converted to

```mlir
%1 = cc.alloca !cc.array<i64 x 1>
%2 = cc.cast %1 : (!cc.ptr<!cc.array<i64 x 1>>) -> !cc.ptr<i64>
cc.store %c1_i64, %2 : !cc.ptr<i64>
```
}];
}

#endif // CUDAQ_OPT_OPTIMIZER_TRANSFORMS_PASSES
1 change: 1 addition & 0 deletions lib/Optimizer/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ add_cudaq_library(OptTransforms
StatePreparation.cpp
UnitarySynthesis.cpp
WiresToWiresets.cpp
WriteAfterWriteElimination.cpp

DEPENDS
OptTransformsPassIncGen
Expand Down
155 changes: 155 additions & 0 deletions lib/Optimizer/Transforms/WriteAfterWriteElimination.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/*******************************************************************************
* Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. *
* All rights reserved. *
* *
* This source code and the accompanying materials are made available under *
* the terms of the Apache License 2.0 which accompanies this distribution. *
******************************************************************************/

#include "PassDetails.h"
#include "cudaq/Optimizer/Builder/Intrinsics.h"
#include "cudaq/Optimizer/Dialect/CC/CCOps.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h"
#include "cudaq/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

namespace cudaq::opt {
#define GEN_PASS_DEF_WRITEAFTERWRITEELIMINATION
#include "cudaq/Optimizer/Transforms/Passes.h.inc"
} // namespace cudaq::opt

#define DEBUG_TYPE "write-after-write-elimination"

using namespace mlir;

namespace {
/// Remove stores followed by a store to the same pointer
/// if the pointer is not used in between.
/// ```
/// cc.store %c0_i64, %1 : !cc.ptr<i64>
/// // no use of %1 until next line
/// cc.store %0, %1 : !cc.ptr<i64>
/// ───────────────────────────────────────────
/// cc.store %0, %1 : !cc.ptr<i64>
/// ```
class SimplifyWritesAnalysis {
public:
SimplifyWritesAnalysis(DominanceInfo &di, Operation *op) : dom(di) {
for (auto &region : op->getRegions())
for (auto &b : region)
collectBlockInfo(&b);
}

/// Remove stores followed by a store to the same pointer if the pointer is
/// not used in between, using collected block info.
void removeOverriddenStores() {
SmallVector<Operation *> toErase;

for (const auto &[block, ptrToStores] : blockInfo) {
for (const auto &[ptr, stores] : ptrToStores) {
if (stores.size() > 1) {
auto replacement = stores.back();
for (auto it = stores.rend(); it != stores.rbegin(); it++) {
auto store = *it;
if (isReplacement(ptr, *store, *replacement)) {
LLVM_DEBUG(llvm::dbgs() << "replacing store " << store
<< " by: " << replacement << '\n');
toErase.push_back(store->getOperation());
}
}
}
}
}

for (auto *op : toErase)
op->erase();
}

private:
/// Detect if value is used in the op or its nested blocks.
bool isReplacement(Value ptr, cudaq::cc::StoreOp store,
cudaq::cc::StoreOp replacement) const {
// Check that there are no stores dominated by the store and not dominated
// by the replacement (i.e. used in between the store and the replacement)
for (auto *user : ptr.getUsers()) {
if (user != store && user != replacement) {
if (dom.dominates(store, user) && !dom.dominates(replacement, user)) {
LLVM_DEBUG(llvm::dbgs() << "store " << replacement
<< " is used before: " << store << '\n');
return false;
}
}
}
return true;
}

/// Collect all stores to a pointer for a block.
void collectBlockInfo(Block *block) {
for (auto &op : *block) {
for (auto &region : op.getRegions())
for (auto &b : region)
collectBlockInfo(&b);

if (auto store = dyn_cast<cudaq::cc::StoreOp>(&op)) {
auto ptr = store.getPtrvalue();
if (isStoreToStack(store)) {
auto ptrToStores = blockInfo.FindAndConstruct(block).second;
auto stores = ptrToStores.FindAndConstruct(ptr).second;
stores.push_back(&store);
}
}
}
}

/// Detect stores to stack locations, for example:
/// ```
/// %1 = cc.alloca !cc.array<i64 x 2>
///
/// %2 = cc.cast %1 : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
/// cc.store %c0_i64, %2 : !cc.ptr<i64>
///
/// %3 = cc.compute_ptr %1[1] : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
/// cc.store %c0_i64, %3 : !cc.ptr<i64>
/// ```
static bool isStoreToStack(cudaq::cc::StoreOp store) {
auto ptrOp = store.getPtrvalue();
if (auto cast = ptrOp.getDefiningOp<cudaq::cc::CastOp>())
ptrOp = cast.getOperand();

if (auto computePtr = ptrOp.getDefiningOp<cudaq::cc::ComputePtrOp>())
ptrOp = computePtr.getBase();

return isa_and_present<cudaq::cc::AllocaOp>(ptrOp.getDefiningOp());
}

DominanceInfo &dom;
DenseMap<Block *, DenseMap<Value, SmallVector<cudaq::cc::StoreOp *>>>
blockInfo;
};

/// Pass driver: computes dominance information for the operation the pass
/// runs on, then lets `SimplifyWritesAnalysis` erase the overridden stores.
/// The IR is dumped before and after when debug output is enabled.
class WriteAfterWriteEliminationPass
    : public cudaq::opt::impl::WriteAfterWriteEliminationBase<
          WriteAfterWriteEliminationPass> {
public:
  using WriteAfterWriteEliminationBase::WriteAfterWriteEliminationBase;

  void runOnOperation() override {
    auto root = getOperation();

    LLVM_DEBUG(llvm::dbgs()
               << "Before write after write elimination: " << *root << '\n');

    // The analysis keeps a reference to the dominance info, so it is
    // declared first and outlives the analysis object.
    DominanceInfo dominance(root);
    SimplifyWritesAnalysis writes(dominance, root);
    writes.removeOverriddenStores();

    LLVM_DEBUG(llvm::dbgs()
               << "After write after write elimination: " << *root << '\n');
  }
};
} // namespace
85 changes: 85 additions & 0 deletions test/Quake/write_after_write_elimination.qke
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// ========================================================================== //
// Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. //
// All rights reserved. //
// //
// This source code and the accompanying materials are made available under //
// the terms of the Apache License 2.0 which accompanies this distribution. //
// ========================================================================== //

// RUN: cudaq-opt -write-after-write-elimination %s | FileCheck %s

// Two back-to-back stores to the same stack pointer (%4) with no use of the
// pointer in between, followed by a load. The expected output that follows
// this function keeps only the second store.
func.func @test_two_stores_same_pointer() {
%c0_i64 = arith.constant 0 : i64
%0 = quake.alloca !quake.veq<2>
%1 = cc.const_array [1] : !cc.array<i64 x 1>
%2 = cc.extract_value %1[0] : (!cc.array<i64 x 1>) -> i64
%3 = cc.alloca !cc.array<i64 x 1>
%4 = cc.cast %3 : (!cc.ptr<!cc.array<i64 x 1>>) -> !cc.ptr<i64>
cc.store %c0_i64, %4 : !cc.ptr<i64>
cc.store %2, %4 : !cc.ptr<i64>
%5 = cc.load %4 : !cc.ptr<i64>
%6 = quake.extract_ref %0[%5] : (!quake.veq<2>, i64) -> !quake.ref
quake.x %6 : (!quake.ref) -> ()
return
}

// CHECK-LABEL: func.func @test_two_stores_same_pointer() {
// CHECK: %[[VAL_1:.*]] = quake.alloca !quake.veq<2>
// CHECK: %[[VAL_2:.*]] = cc.const_array [1] : !cc.array<i64 x 1>
// CHECK: %[[VAL_3:.*]] = cc.extract_value %[[VAL_2]][0] : (!cc.array<i64 x 1>) -> i64
// CHECK: %[[VAL_4:.*]] = cc.alloca !cc.array<i64 x 1>
// CHECK: %[[VAL_5:.*]] = cc.cast %[[VAL_4]] : (!cc.ptr<!cc.array<i64 x 1>>) -> !cc.ptr<i64>
// CHECK: cc.store %[[VAL_3]], %[[VAL_5]] : !cc.ptr<i64>
// CHECK: %[[VAL_6:.*]] = cc.load %[[VAL_5]] : !cc.ptr<i64>
// CHECK: %[[VAL_7:.*]] = quake.extract_ref %[[VAL_1]][%[[VAL_6]]] : (!quake.veq<2>, i64) -> !quake.ref
// CHECK: quake.x %[[VAL_7]] : (!quake.ref) -> ()
// CHECK: return
// CHECK: }

// One store to each of two distinct allocas (%2 and %3). Neither store is
// overwritten, and the expected output that follows keeps both.
func.func @test_two_stores_different_pointers() {
%c0_i64 = arith.constant 0 : i64
%c1_i64 = arith.constant 1 : i64
%0 = quake.alloca !quake.veq<2>
%1 = cc.alloca !cc.array<i64 x 1>
%2 = cc.alloca i64
cc.store %c0_i64, %2 : !cc.ptr<i64>
%3 = cc.alloca i64
cc.store %c1_i64, %3 : !cc.ptr<i64>
return
}

// CHECK-LABEL: func.func @test_two_stores_different_pointers() {
// CHECK: %[[VAL_0:.*]] = arith.constant 0 : i64
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : i64
// CHECK: %[[VAL_2:.*]] = quake.alloca !quake.veq<2>
// CHECK: %[[VAL_3:.*]] = cc.alloca !cc.array<i64 x 1>
// CHECK: %[[VAL_4:.*]] = cc.alloca i64
// CHECK: cc.store %[[VAL_0]], %[[VAL_4]] : !cc.ptr<i64>
// CHECK: %[[VAL_5:.*]] = cc.alloca i64
// CHECK: cc.store %[[VAL_1]], %[[VAL_5]] : !cc.ptr<i64>
// CHECK: return
// CHECK: }

// Stores to two pointers derived from one alloca (%2 via cc.cast, %3 via
// cc.compute_ptr [1]) are interleaved: each pointer is stored to twice. The
// expected output that follows keeps only the last store to each pointer.
func.func @test_two_stores_same_pointer_interleaving() {
%c0_i64 = arith.constant 0 : i64
%c1_i64 = arith.constant 1 : i64
%1 = cc.alloca !cc.array<i64 x 2>
%2 = cc.cast %1 : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
cc.store %c0_i64, %2 : !cc.ptr<i64>
%3 = cc.compute_ptr %1[1] : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
cc.store %c0_i64, %3 : !cc.ptr<i64>
cc.store %c1_i64, %2 : !cc.ptr<i64>
cc.store %c1_i64, %3 : !cc.ptr<i64>
return
}

// CHECK-LABEL: func.func @test_two_stores_same_pointer_interleaving() {
// CHECK: %[[VAL_0:.*]] = arith.constant 1 : i64
// CHECK: %[[VAL_1:.*]] = cc.alloca !cc.array<i64 x 2>
// CHECK: %[[VAL_2:.*]] = cc.cast %[[VAL_1]] : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
// CHECK: %[[VAL_3:.*]] = cc.compute_ptr %[[VAL_1]][1] : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
// CHECK: cc.store %[[VAL_0]], %[[VAL_2]] : !cc.ptr<i64>
// CHECK: cc.store %[[VAL_0]], %[[VAL_3]] : !cc.ptr<i64>
// CHECK: return
// CHECK: }

0 comments on commit 0a8c67e

Please sign in to comment.