diff --git a/mlir/include/mlir/Transforms/LoopFusionUtils.h b/mlir/include/mlir/Transforms/LoopFusionUtils.h index 313cb5171e75e..eafcc03d091f9 100644 --- a/mlir/include/mlir/Transforms/LoopFusionUtils.h +++ b/mlir/include/mlir/Transforms/LoopFusionUtils.h @@ -50,11 +50,9 @@ struct FusionStrategy { None, // Generic fusion. No assumtions are made. ProducerConsumer, // Producer-consumer fusion from AffineLoopFusion pass. Sibling // Sibling fusion from AffineLoopFusion pass. - } strategy; + } value; - Value memref; - FusionStrategy(StrategyEnum strategy, Value memref) - : strategy(strategy), memref(memref) {} + FusionStrategy(StrategyEnum value) : value(value) {} }; /// Checks the feasibility of fusing the loop nest rooted at 'srcForOp' into the @@ -65,10 +63,10 @@ struct FusionStrategy { /// NOTE: This function is not feature complete and should only be used in /// testing. /// TODO: Update comments when this function is fully implemented. -FusionResult -canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, unsigned dstLoopDepth, - ComputationSliceState *srcSlice, - FusionStrategy fusionStrategy = {FusionStrategy::None, Value()}); +FusionResult canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, + unsigned dstLoopDepth, + ComputationSliceState *srcSlice, + FusionStrategy fusionStrategy = FusionStrategy::None); /// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point /// and source slice loop bounds specified in 'srcSlice'. @@ -112,6 +110,11 @@ bool getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, const ComputationSliceState &slice, int64_t *computeCost); +// TODO. +void gatherProducerConsumerMemrefs(ArrayRef srcOps, + ArrayRef dstOps, + DenseSet &producerConsumerMemrefs); + } // end namespace mlir #endif // MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 1075486378467..cc39ae308efe0 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -30,6 +30,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include +#include #include #define DEBUG_TYPE "affine-loop-fusion" @@ -270,45 +271,6 @@ struct MemRefDependenceGraph { return false; } - // Returns the unique AffineWriteOpInterface in `node` that meets all the - // following: - // *) store is the only one that writes to a function-local memref live out - // of `node`, - // *) store is not the source of a self-dependence on `node`. - // Otherwise, returns a null AffineWriteOpInterface. - AffineWriteOpInterface getUniqueOutgoingStore(Node *node) { - AffineWriteOpInterface uniqueStore; - - // Return null if `node` doesn't have any outgoing edges. - auto outEdgeIt = outEdges.find(node->id); - if (outEdgeIt == outEdges.end()) - return nullptr; - - const auto &nodeOutEdges = outEdgeIt->second; - for (auto *op : node->stores) { - auto storeOp = cast(op); - auto memref = storeOp.getMemRef(); - // Skip this store if there are no dependences on its memref. This means - // that store either: - // *) writes to a memref that is only read within the same loop nest - // (self-dependence edges are not represented in graph at the moment), - // *) writes to a function live out memref (function parameter), or - // *) is dead. - if (llvm::all_of(nodeOutEdges, [=](const Edge &edge) { - return (edge.value != memref); - })) - continue; - - if (uniqueStore) - // Found multiple stores to function-local live-out memrefs. - return nullptr; - // Found first store to function-local live-out memref. - uniqueStore = storeOp; - } - - return uniqueStore; - } - // Returns true if node 'id' can be removed from the graph. Returns false // otherwise. A node can be removed from the graph iff the following // conditions are met: @@ -495,42 +457,49 @@ struct MemRefDependenceGraph { return dstNodeInst; } - // Updates edge mappings from node 'srcId' to node 'dstId' after 'oldMemRef' - // has been replaced in node at 'dstId' by a private memref depending - // on the value of 'createPrivateMemRef'. - void updateEdges(unsigned srcId, unsigned dstId, Value oldMemRef, - bool createPrivateMemRef) { + // Updates edge mappings from node 'srcId' to node 'dstId' after fusing them, + // taking into account that: + // *) if 'removeSrcId' is true, 'srcId' will be removed after fusion, + // *) memrefs in 'privateMemRefs' has been replaced in node at 'dstId' by a + // private memref. + void updateEdges(unsigned srcId, unsigned dstId, + const DenseSet &privateMemRefs, bool removeSrcId) { // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'. if (inEdges.count(srcId) > 0) { SmallVector oldInEdges = inEdges[srcId]; for (auto &inEdge : oldInEdges) { - // Add edge from 'inEdge.id' to 'dstId' if not for 'oldMemRef'. - if (inEdge.value != oldMemRef) + // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref. + if (privateMemRefs.count(inEdge.value) == 0) addEdge(inEdge.id, dstId, inEdge.value); } } // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'. + // If 'srcId' is going to be removed, remap all the out edges to 'dstId'. if (outEdges.count(srcId) > 0) { SmallVector oldOutEdges = outEdges[srcId]; for (auto &outEdge : oldOutEdges) { // Remove any out edges from 'srcId' to 'dstId' across memrefs. if (outEdge.id == dstId) removeEdge(srcId, outEdge.id, outEdge.value); + else if (removeSrcId) { + addEdge(dstId, outEdge.id, outEdge.value); + removeEdge(srcId, outEdge.id, outEdge.value); + } } } // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being // replaced by a private memref). These edges could come from nodes // other than 'srcId' which were removed in the previous step. - if (inEdges.count(dstId) > 0 && createPrivateMemRef) { + if (inEdges.count(dstId) > 0 && privateMemRefs.size() > 0) { SmallVector oldInEdges = inEdges[dstId]; for (auto &inEdge : oldInEdges) - if (inEdge.value == oldMemRef) + if (privateMemRefs.count(inEdge.value) > 0) removeEdge(inEdge.id, dstId, inEdge.value); } } // Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion - // of sibling node 'sidId' into node 'dstId'. + // of sibling node 'sibId' into node 'dstId'. void updateEdges(unsigned sibId, unsigned dstId) { // For each edge in 'inEdges[sibId]': // *) Add new edge from source node 'inEdge.id' to 'dstNode'. @@ -624,6 +593,101 @@ struct MemRefDependenceGraph { void dump() const { print(llvm::errs()); } }; +// TODO +bool canRemoveSrcNodeAfterFusion(unsigned srcId, unsigned dstId, + Operation *fusedLoopInsPoint, + MemRefDependenceGraph *mdg) { + if (mdg->writesToLiveInOrEscapingMemrefs(srcId)) + return false; + + // TODO: Use domination information for this analysis when more complex + // scenarios are needed. + Operation *dstNodeOp = mdg->getNode(dstId)->op; + for (auto &outEdge : mdg->outEdges[srcId]) { + Operation *depNodeOp = mdg->getNode(outEdge.id)->op; + // Skip dependence with dstOp since it will be removed after fusion. + if (depNodeOp == dstNodeOp) + continue; + + // Only fusion within the same block is supported. Be conservative for now. + // Use domination when needed. + if (depNodeOp->getBlock() != dstNodeOp->getBlock()) + return false; + + // Check if the insertion point of the fused loop dominates the dependency. + // Otherwise, the src loop can't be removed. + if (fusedLoopInsPoint != depNodeOp && + !fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) + return false; + } + + return true; +} + +// Return in 'srcIdCandidates' the fusion candidates for 'dstId'. 'std::set' +// keeps candidates sorted by node id, which also corresponds to the program +// order of the nodes before fusion. However, this property may not hold after +// fusing any pair of loops. +void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg, + std::set &srcIdCandidates) { + // Skip if no input edges along which to fuse. + if (mdg->inEdges.count(dstId) == 0) + return; + + // Gather memrefs from loads in 'dstId'. + auto *dstNode = mdg->getNode(dstId); + DenseSet consumedMemrefs; + for (Operation *load : dstNode->loads) + consumedMemrefs.insert(cast(load).getMemRef()); + + // TODO. + for (auto &srcEdge : mdg->inEdges[dstId]) { + auto *srcNode = mdg->getNode(srcEdge.id); + // Skip if 'srcNode' is not a loop nest. + if (!isa(srcNode->op)) + continue; + + // Add srcNode candidate if it contains a store to one of the consumed + // memrefs. + if (any_of(srcNode->stores, [&](Operation *op) { + auto storeOp = cast(op); + return consumedMemrefs.count(storeOp.getMemRef()) > 0; + })) { + srcIdCandidates.insert(srcNode->id); + } + } +} + +// TODO +void gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId, + MemRefDependenceGraph *mdg, + DenseSet &producerConsumerMemrefs) { + auto *dstNode = mdg->getNode(dstId); + auto *srcNode = mdg->getNode(srcId); + gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads, + producerConsumerMemrefs); +} + +// TODO +void gatherLiveInOrEscapingMemrefs(unsigned id, + MemRefDependenceGraph *mdg, + DenseSet &liveInOrEscapingMemRefs) { + auto *node = mdg->getNode(id); + for (auto *storeOpInst : node->stores) { + auto memref = cast(storeOpInst).getMemRef(); + if (liveInOrEscapingMemRefs.count(memref)) + continue; + auto *op = memref.getDefiningOp(); + // Return true if 'memref' is a block argument. + if (!op) + liveInOrEscapingMemRefs.insert(memref); + // Return true if any use of 'memref' escapes the function. + for (auto *user : memref.getUsers()) + if (!isMemRefDereferencingOp(*user)) + liveInOrEscapingMemRefs.insert(memref); + } +} + } // end anonymous namespace // Initializes the data dependence graph by walking operations in 'f'. @@ -631,6 +695,7 @@ struct MemRefDependenceGraph { // TODO: Add support for taking a Block arg to construct the // dependence graph at a different depth. bool MemRefDependenceGraph::init(FuncOp f) { + LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n"); DenseMap> memrefAccesses; // TODO: support multi-block functions. @@ -686,6 +751,10 @@ bool MemRefDependenceGraph::init(FuncOp f) { } } + for (auto &idAndNode : nodes) + LLVM_DEBUG(llvm::dbgs() << "Created node " << idAndNode.first << " for:\n" + << *(idAndNode.second.op) << "\n"); + // Add dependence edges between nodes which produce SSA values and their // users. for (auto &idAndNode : nodes) { @@ -725,22 +794,6 @@ bool MemRefDependenceGraph::init(FuncOp f) { return true; } -// Removes load operations from 'srcLoads' which operate on 'memref', and -// adds them to 'dstLoads'. -static void moveLoadsAccessingMemrefTo(Value memref, - SmallVectorImpl *srcLoads, - SmallVectorImpl *dstLoads) { - dstLoads->clear(); - SmallVector srcLoadsToKeep; - for (auto *load : *srcLoads) { - if (cast(load).getMemRef() == memref) - dstLoads->push_back(load); - else - srcLoadsToKeep.push_back(load); - } - srcLoads->swap(srcLoadsToKeep); -} - // Sinks all sequential loops to the innermost levels (while preserving // relative order among them) and moves all parallel loops to the // outermost (while again preserving relative order among them). @@ -1039,6 +1092,7 @@ canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, // *) Compares the total cost of the unfused loop nests to the min cost fused // loop nest computed in the previous step, and returns true if the latter // is lower. +// TODO: Extend this to support multiple stores. static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, ArrayRef dstLoadOpInsts, ArrayRef depthSliceUnions, @@ -1321,8 +1375,6 @@ struct GreedyFusion { MemRefDependenceGraph *mdg; // Worklist of graph nodes visited during the fusion pass. SmallVector worklist; - // Set of graph nodes which are present on the worklist. - llvm::SmallDenseSet worklistSet; // Parameter for local buffer size threshold. unsigned localBufSizeThreshold; // Parameter for fast memory space. @@ -1343,16 +1395,15 @@ struct GreedyFusion { fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion), computeToleranceThreshold(computeToleranceThreshold) {} - // Initializes 'worklist' with nodes from 'mdg' + /// Initializes 'worklist' with nodes from 'mdg' in reverse topological order + /// so that they are visited in topological order. void init() { // TODO: Add a priority queue for prioritizing nodes by different // metrics (e.g. arithmetic intensity/flops-to-bytes ratio). worklist.clear(); - worklistSet.clear(); for (auto &idAndNode : mdg->nodes) { const Node &node = idAndNode.second; worklist.push_back(node.id); - worklistSet.insert(node.id); } } @@ -1371,11 +1422,11 @@ struct GreedyFusion { } void fuseProducerConsumerNodes(unsigned maxSrcUserCount) { + LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n"); init(); while (!worklist.empty()) { unsigned dstId = worklist.back(); worklist.pop_back(); - worklistSet.erase(dstId); // Skip if this node was removed (fused into another node). if (mdg->nodes.count(dstId) == 0) @@ -1385,114 +1436,131 @@ struct GreedyFusion { // Skip if 'dstNode' is not a loop nest. if (!isa(dstNode->op)) continue; + + LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n"); + // Sink sequential loops in 'dstNode' (and thus raise parallel loops) // while preserving relative order. This can increase the maximum loop // depth at which we can fuse a slice of a producer loop nest into a // consumer loop nest. sinkSequentialLoops(dstNode); - - SmallVector loads = dstNode->loads; - SmallVector dstLoadOpInsts; - DenseSet visitedMemrefs; - while (!loads.empty()) { - // Get memref of load on top of the stack. - auto memref = cast(loads.back()).getMemRef(); - if (visitedMemrefs.count(memref) > 0) - continue; - visitedMemrefs.insert(memref); - // Move all loads in 'loads' accessing 'memref' to 'dstLoadOpInsts'. - moveLoadsAccessingMemrefTo(memref, &loads, &dstLoadOpInsts); - // Skip if no input edges along which to fuse. - if (mdg->inEdges.count(dstId) == 0) - continue; - // Iterate through in-edges for 'dstId' and src node id for any - // edges on 'memref'. - SmallVector srcNodeIds; - for (auto &srcEdge : mdg->inEdges[dstId]) { - // Skip 'srcEdge' if not for 'memref'. - if (srcEdge.value != memref) - continue; - srcNodeIds.push_back(srcEdge.id); - } - for (unsigned srcId : srcNodeIds) { - // Skip if this node was removed (fused into another node). - if (mdg->nodes.count(srcId) == 0) - continue; + auto dstAffineForOp = cast(dstNode->op); + + // Try to fuse 'dstNode' with candidate producer loops until a fixed point + // is reached. Fusing two loops may expose new fusion opportunities. + bool dstNodeChanged; + do { + // Gather src loop candidates for 'dstNode' and visit them in "quasi" + // reverse program order to minimize the number of iterations needed to + // reach the fixed point. Note that this is a best effort approach since + // 'getProducerCandidates' does not always guarantee that program order + // in 'srcIdCandidats'. + dstNodeChanged = false; + std::set srcIdCandidates; + getProducerCandidates(dstId, mdg, srcIdCandidates); + + for (unsigned srcId : llvm::reverse(srcIdCandidates)) { // Get 'srcNode' from which to attempt fusion into 'dstNode'. auto *srcNode = mdg->getNode(srcId); - // Skip if 'srcNode' is not a loop nest. - if (!isa(srcNode->op)) - continue; - // Skip if 'srcNode' has more than one live-out store to a - // function-local memref. - // TODO: Support more generic multi-output src loop nests - // fusion. - auto srcStoreOp = mdg->getUniqueOutgoingStore(srcNode); - if (!srcStoreOp) { - // Get the src store op at the deepest loop depth. - // We will use 'LoopFusionUtils::canFuseLoops' to check fusion - // feasibility for loops with multiple stores. - unsigned maxLoopDepth = 0; - for (auto *op : srcNode->stores) { - auto storeOp = cast(op); - if (storeOp.getMemRef() != memref) { - srcStoreOp = nullptr; - break; - } - unsigned loopDepth = getNestingDepth(storeOp); - if (loopDepth > maxLoopDepth) { - maxLoopDepth = loopDepth; - srcStoreOp = storeOp; - } - } - if (!srcStoreOp) - continue; - } + auto srcAffineForOp = cast(srcNode->op); + LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId + << " for dst loop " << dstId << "\n"); - // Unique outgoing store found must write to 'memref' since 'memref' - // is the one that established the producer-consumer relationship - // between 'srcNode' and 'dstNode'. - assert(srcStoreOp.getMemRef() == memref && - "Found store to unexpected memref"); + // TODO. + DenseSet producerConsumerMemrefs; + gatherProducerConsumerMemrefs(srcId, dstId, mdg, + producerConsumerMemrefs); + // TODO: Check if this is still needed. // Skip if 'srcNode' writes to any live in or escaping memrefs, // and cannot be fused. - bool writesToLiveInOrOut = - mdg->writesToLiveInOrEscapingMemrefs(srcNode->id); - if (writesToLiveInOrOut && - !canFuseSrcWhichWritesToLiveOut(srcId, dstId, srcStoreOp, mdg)) - continue; + DenseSet liveInOrOutMemrefs; + gatherLiveInOrEscapingMemrefs(srcNode->id, mdg, liveInOrOutMemrefs); + bool writesToLiveInOrOut = liveInOrOutMemrefs.size() > 0; + if (writesToLiveInOrOut) { + // Skip if 'srcNode' writes to a live in or escaping memref that is + // not involved the the producer-consumer relationship. + if (any_of(liveInOrOutMemrefs, [&](Value memref) { + return !producerConsumerMemrefs.count(memref); + })) + continue; - // Don't create a private memref if 'writesToLiveInOrOut'. - bool createPrivateMemref = !writesToLiveInOrOut; - // Don't create a private memref if 'srcNode' has in edges on - // 'memref', or if 'dstNode' has out edges on 'memref'. - if (mdg->getIncomingMemRefAccesses(srcNode->id, memref) > 0 || - mdg->getOutEdgeCount(dstNode->id, memref) > 0) { - createPrivateMemref = false; + // TODO: Extend this check to properly support multiple stores, if + // needed. + // REVIEW: Not sure I understand the problematic fusion cases that + // we are trying to avoid with the check below since we are only + // fusing loops within the same Block. Any examples of this, please? + if (any_of(srcNode->stores, [&](Operation *op) { + auto storeOp = cast(op); + return !canFuseSrcWhichWritesToLiveOut(srcId, dstId, storeOp, + mdg); + })) + continue; } - // Skip if 'srcNode' out edge count on 'memref' > 'maxSrcUserCount'. - if (mdg->getOutEdgeCount(srcNode->id, memref) > maxSrcUserCount) + // Skip if 'srcNode' out edge count on any memref is greater than + // 'maxSrcUserCount'. + if (any_of(producerConsumerMemrefs, [&](Value memref) { + return mdg->getOutEdgeCount(srcNode->id, memref) > + maxSrcUserCount; + })) continue; // Compute an operation list insertion point for the fused loop // nest which preserves dependences. - Operation *insertPointInst = + Operation *fusedLoopInsPoint = mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id); - if (insertPointInst == nullptr) + if (fusedLoopInsPoint == nullptr) continue; - auto srcAffineForOp = cast(srcNode->op); - auto dstAffineForOp = cast(dstNode->op); + // TODO + // Remove old src loop nest if it no longer has outgoing dependence + // edges, and if it does not write to a memref which escapes the + // function. If 'writesToLiveInOrOut' is true, then 'srcNode' has + // been fused into 'dstNode' and write region of 'dstNode' covers + // the write region of 'srcNode', and 'srcNode' has no other users + // so it is safe to remove. + bool removeSrcNode = + writesToLiveInOrOut || + canRemoveSrcNodeAfterFusion(srcId, dstId, fusedLoopInsPoint, mdg); + + DenseSet privateMemrefs; + for (Value memref : producerConsumerMemrefs) { + // Don't create a private memref if 'writesToLiveInOrOut'. + if (liveInOrOutMemrefs.count(memref) > 0) + continue; + + // Don't create a private memref if 'srcNode' has in edges on + // 'memref' or 'dstNode' has out edges on 'memref'. + if (mdg->getIncomingMemRefAccesses(srcId, memref) > 0 || + mdg->getOutEdgeCount(dstId, memref) > 0) + continue; + + // If 'srcNode' can be removed but it has out edges on 'memref' to + // nodes besides 'dstNode', creating a private memref would prevent + // the removal since stores to the old memref still need to happen. + // We prioritize the 'srcNode' removal and do not create a private + // memref since the original memref won't be eliminated anyways. + if (removeSrcNode && + any_of(mdg->outEdges[srcId], [&](const auto &edge) { + return edge.value == memref && edge.id != dstId; + })) + continue; + + // Create a private version of this memref. + privateMemrefs.insert(memref); + } - // Compute the innermost common loop depth for dstNode loads/stores. + // Compute the innermost common loop depth for dstNode + // producer-consumer loads/stores. SmallVector dstMemrefOps; for (Operation *op : dstNode->loads) - if (cast(op).getMemRef() == memref) + if (producerConsumerMemrefs.count( + cast(op).getMemRef()) > 0) dstMemrefOps.push_back(op); for (Operation *op : dstNode->stores) - if (cast(op).getMemRef() == memref) + if (producerConsumerMemrefs.count( + cast(op).getMemRef())) dstMemrefOps.push_back(op); unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps); @@ -1501,7 +1569,7 @@ struct GreedyFusion { unsigned maxLegalFusionDepth = 0; SmallVector depthSliceUnions; depthSliceUnions.resize(dstLoopDepthTest); - FusionStrategy strategy(FusionStrategy::ProducerConsumer, memref); + FusionStrategy strategy(FusionStrategy::ProducerConsumer); for (unsigned i = 1; i <= dstLoopDepthTest; ++i) { FusionResult result = mlir::canFuseLoops( srcAffineForOp, dstAffineForOp, @@ -1511,19 +1579,17 @@ struct GreedyFusion { maxLegalFusionDepth = i; } - // Skip if fusion is not feasible at any loop depths. - if (maxLegalFusionDepth == 0) - continue; - // Check if fusion would be profitable. We skip profitability analysis // for maximal fusion since we already know the maximal legal depth to // fuse. unsigned bestDstLoopDepth = maxLegalFusionDepth; - if (!maximalFusion && - !isFusionProfitable(srcStoreOp, srcStoreOp, dstLoadOpInsts, - depthSliceUnions, maxLegalFusionDepth, - &bestDstLoopDepth, computeToleranceThreshold)) - continue; + // TODO +// if (!maximalFusion && +// !isFusionProfitable(srcStoreOp, srcStoreOp, dstLoadOpInsts, +// depthSliceUnions, +// maxLegalFusionDepth, &bestDstLoopDepth, +// computeToleranceThreshold)) +// continue; assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth"); assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() && @@ -1532,6 +1598,7 @@ struct GreedyFusion { // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. fuseLoops(srcAffineForOp, dstAffineForOp, depthSliceUnions[bestDstLoopDepth - 1]); + dstNodeChanged = true; LLVM_DEBUG(llvm::dbgs() << "Fused src loop " << srcId << " into dst loop " << dstId @@ -1539,18 +1606,20 @@ struct GreedyFusion { << dstAffineForOp << "\n"); // Move 'dstAffineForOp' before 'insertPointInst' if needed. - if (insertPointInst != dstAffineForOp.getOperation()) - dstAffineForOp.getOperation()->moveBefore(insertPointInst); + if (fusedLoopInsPoint != dstAffineForOp.getOperation()) + dstAffineForOp.getOperation()->moveBefore(fusedLoopInsPoint); // Update edges between 'srcNode' and 'dstNode'. - mdg->updateEdges(srcNode->id, dstNode->id, memref, - createPrivateMemref); + mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs, + removeSrcNode); // Collect slice loop stats. LoopNestStateCollector dstForCollector; dstForCollector.collect(dstAffineForOp); - if (createPrivateMemref) { + for (Value memref : privateMemrefs) { // Create private memref for 'memref' in 'dstAffineForOp'. + // TODO: remove storesForMemref and move the code below to the + // loop-if. SmallVector storesForMemref; for (auto *storeOpInst : dstForCollector.storeOpInsts) { if (cast(storeOpInst).getMemRef() == @@ -1562,7 +1631,6 @@ struct GreedyFusion { auto newMemRef = createPrivateMemRef( dstAffineForOp, storesForMemref[0], bestDstLoopDepth, fastMemorySpace, localBufSizeThreshold); - visitedMemrefs.insert(newMemRef); // Create new node in dependence graph for 'newMemRef' alloc op. unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp()); // Add edge from 'newMemRef' node to dstNode. @@ -1574,58 +1642,21 @@ struct GreedyFusion { LoopNestStateCollector dstLoopCollector; dstLoopCollector.collect(dstAffineForOp.getOperation()); - // Add new load ops to current Node load op list 'loads' to - // continue fusing based on new operands. - for (auto *loadOpInst : dstLoopCollector.loadOpInsts) { - // NOTE: Change 'loads' to a hash set in case efficiency is an - // issue. We still use a vector since it's expected to be small. - if (!llvm::is_contained(loads, loadOpInst)) - loads.push_back(loadOpInst); - } - // Clear visited memrefs after fusion so that previously visited - // src nodes are considered for fusion again in the context of the - // new fused node. - // TODO: This shouldn't be necessary if we visited candidates in - // the dependence graph in post-order or once we fully support - // multi-store producers. Currently, in a multi-store producer - // scenario such as A->B, A->C, B->C, we fail to fuse A+B due to - // the multiple outgoing edges. However, after fusing B+C, A has a - // single outgoing edge and can be fused if we revisit it in the - // context of the new fused B+C node. - visitedMemrefs.clear(); - // Clear and add back loads and stores. mdg->clearNodeLoadAndStores(dstNode->id); mdg->addToNode(dstId, dstLoopCollector.loadOpInsts, dstLoopCollector.storeOpInsts); - // Remove old src loop nest if it no longer has outgoing - // dependence edges, and if it does not write to a memref which - // escapes the function. If 'writesToLiveInOrOut' is true, then - // 'srcNode' has been fused into 'dstNode' and write region of - // 'dstNode' covers the write region of 'srcNode', and 'srcNode' - // has no other users so it is safe to remove. - if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) { - mdg->removeNode(srcNode->id); + + if (removeSrcNode) { + LLVM_DEBUG(llvm::dbgs() + << "Removing src loop " << srcId << " after fusion\n"); + // srcNode is no longer valid after removing it from mdg. srcNode->op->erase(); - } else { - // Add remaining users of 'oldMemRef' back on the worklist (if - // not already there), as its replacement with a local/private - // memref has reduced dependences on 'oldMemRef' which may have - // created new fusion opportunities. - if (mdg->outEdges.count(srcNode->id) > 0) { - SmallVector oldOutEdges = - mdg->outEdges[srcNode->id]; - for (auto &outEdge : oldOutEdges) { - if (outEdge.value == memref && - worklistSet.count(outEdge.id) == 0) { - worklist.push_back(outEdge.id); - worklistSet.insert(outEdge.id); - } - } - } + mdg->removeNode(srcId); + srcNode = nullptr; } } - } + } while (dstNodeChanged); } } @@ -1636,7 +1667,6 @@ struct GreedyFusion { while (!worklist.empty()) { unsigned dstId = worklist.back(); worklist.pop_back(); - worklistSet.erase(dstId); // Skip if this node was removed (fused into another node). if (mdg->nodes.count(dstId) == 0) @@ -1698,7 +1728,7 @@ struct GreedyFusion { SmallVector depthSliceUnions; depthSliceUnions.resize(dstLoopDepthTest); unsigned maxLegalFusionDepth = 0; - FusionStrategy strategy(FusionStrategy::Sibling, memref); + FusionStrategy strategy(FusionStrategy::Sibling); for (unsigned i = 1; i <= dstLoopDepthTest; ++i) { FusionResult result = mlir::canFuseLoops( sibAffineForOp, dstAffineForOp, diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp index 87f6bd7055cc9..bb34af004eb11 100644 --- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp @@ -191,11 +191,8 @@ gatherLoadsAndStores(AffineForOp forOp, // 'srcForOp' into consumer loop 'dstForOp' without violating data dependences. // TODO: Generalize this check for sibling and more generic fusion scenarios. // TODO: Support forward slice fusion. -static unsigned getMaxLoopDepth(ArrayRef dstOps, - FusionStrategy fusionStrategy) { - assert(fusionStrategy.strategy == FusionStrategy::ProducerConsumer && - "Fusion strategy not supported"); - +static unsigned getMaxLoopDepth(ArrayRef srcOps, + ArrayRef dstOps) { if (dstOps.empty()) // Expected at least one memory operation. // TODO: Revisit this case with a specific example. @@ -203,15 +200,14 @@ static unsigned getMaxLoopDepth(ArrayRef dstOps, // Filter out ops in 'dstOps' that do not use the producer-consumer memref so // that they are not considered for analysis. - // TODO: Currently, we pass the producer-consumer memref through - // fusionStrategy. We will retrieve the memrefs from 'srcOps' once we - // generalize the algorithm. + DenseSet producerConsumerMemrefs; + gatherProducerConsumerMemrefs(srcOps, dstOps, producerConsumerMemrefs); SmallVector targetDstOps; for (Operation *dstOp : dstOps) { auto loadOp = dyn_cast(dstOp); Value memref = loadOp ? loadOp.getMemRef() : cast(dstOp).getMemRef(); - if (memref == fusionStrategy.memref) + if (producerConsumerMemrefs.count(memref) > 0) targetDstOps.push_back(dstOp); } @@ -308,10 +304,10 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, // loop dependences. // TODO: Enable this check for sibling and more generic loop fusion // strategies. - if (fusionStrategy.strategy == FusionStrategy::ProducerConsumer) { + if (fusionStrategy.value == FusionStrategy::ProducerConsumer) { // TODO: 'getMaxLoopDepth' does not support forward slice fusion. assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion"); - if (getMaxLoopDepth(opsB, fusionStrategy) < dstLoopDepth) { + if (getMaxLoopDepth(opsA, opsB) < dstLoopDepth) { LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n"); return FusionResult::FailFusionDependence; } @@ -324,7 +320,7 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, // Filter out ops in 'opsA' to compute the slice union based on the // assumptions expected by the fusion strategy. SmallVector strategyOpsA; - switch (fusionStrategy.strategy) { + switch (fusionStrategy.value) { case FusionStrategy::None: // Generic fusion. Take into account all the memory operations to compute // the slice union. @@ -332,19 +328,17 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, break; case FusionStrategy::ProducerConsumer: // Producer-consumer fusion (AffineLoopFusion pass) only takes into - // account stores to 'memref' in 'srcForOp' to compute the slice union. + // account stores in 'srcForOp' to compute the slice union. for (Operation *op : opsA) { - auto store = dyn_cast(op); - if (store && store.getMemRef() == fusionStrategy.memref) + if (isa(op)) strategyOpsA.push_back(op); } break; case FusionStrategy::Sibling: - // Sibling fusion (AffineLoopFusion pass) only takes into account the loads - // to 'memref' in 'srcForOp' to compute the slice union. + // Sibling fusion (AffineLoopFusion pass) only takes into account loads + // in 'srcForOp' to compute the slice union. for (Operation *op : opsA) { - auto load = dyn_cast(op); - if (load && load.getMemRef() == fusionStrategy.memref) + if (isa(op)) strategyOpsA.push_back(op); } break; @@ -628,3 +622,22 @@ bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, /*tripCountOverrideMap=*/nullptr, &computeCostMap); return true; } + +// TODO +void mlir::gatherProducerConsumerMemrefs( + ArrayRef srcOps, ArrayRef dstOps, + DenseSet &producerConsumerMemrefs) { + // Gather memrefs from stores in 'srcOps'. + DenseSet srcStoreMemRefs; + for (Operation *op : srcOps) + if (auto storeOp = dyn_cast(op)) + srcStoreMemRefs.insert(storeOp.getMemRef()); + + // Compute the intersection between memrefs from stores in 'srcOps' and + // memrefs from loads in 'dstOps'. + for (Operation *op : dstOps) + if (auto loadOp = dyn_cast(op)) + if (srcStoreMemRefs.count(loadOp.getMemRef()) > 0) + producerConsumerMemrefs.insert(loadOp.getMemRef()); +} + diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index 1da4ae0566742..d99f956bf6b2c 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -378,14 +378,16 @@ func @should_fuse_with_private_memref_if_top_level_access() { %c0 = constant 4 : index %v1 = affine.load %m[%c0] : memref<10xf32> - // Top-level load to '%{{.*}}' should prevent fusion. - // CHECK: affine.for %{{.*}} = 0 to 10 { - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> - // CHECK-NEXT: } - // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 { - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32> - // CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32> + // Top-level load to '%{{.*}}' should prevent creating a private memref but + // loop nests should be fused and 'i0' should be removed. + // CHECK: %[[alloc:.*]] = alloc() : memref<10xf32> + // CHECK-NOT: alloc + + // CHECK: affine.for %[[i1:.*]] = 0 to 10 { + // CHECK-NEXT: affine.store %{{.*}}, %[[alloc]][%[[i1]]] : memref<10xf32> + // CHECK-NEXT: affine.load %{{.*}}[%[[i1]]] : memref<10xf32> // CHECK-NEXT: } + // CHECK-NOT: affine.for return }