Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Extending forwarding of fusion segmenter #3659

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 78 additions & 19 deletions csrc/fusion_segmenter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4016,10 +4016,11 @@ void SegmentCandidateFinder::findSegments() {
}

// Decides whether we should forward an input (or a forwarded input) of a
// fusion. Currently, we forward an input only when its single use is a UnaryOp.
// Therefore, this function returns `v`'s single unary use or nullptr if it
// fusion. Currently, we forward an input only when its single use is
// a UnaryOp or a set-like op.
// Therefore, this function returns `v`'s single use or nullptr if it
// decides not to forward.
UnaryOp* shouldForward(Val* v) {
Expr* shouldForward(Val* v) {
const std::vector<Expr*>& uses = v->uses();
// Just allow stripping out input with single use.
// Stripping out multi-used inputs can lead to:
Expand All @@ -4029,23 +4030,74 @@ UnaryOp* shouldForward(Val* v) {
return nullptr;
}

auto* unary_use = dynamic_cast<UnaryOp*>(uses.front());
if (unary_use == nullptr) {
auto* use_of_v = uses.front();
if (!use_of_v
->isOneOf<UnaryOp, LoadStoreOp, BroadcastOp, ExpandOp, ViewOp>()) {
return nullptr;
}

// For LoadStoreOp, only allow trivial set
if (auto load_store = dynamic_cast<LoadStoreOp*>(use_of_v)) {
if (load_store->opType() != LoadStoreOpType::Set) {
return nullptr;
}
// Don't allow anything with root-logical transforms or reordering
if (auto out_tv = dynamic_cast<TensorView*>(load_store->out());
out_tv != nullptr && out_tv->hasRoot()) {
return nullptr;
}
}

if (auto reshape = dynamic_cast<ViewOp*>(use_of_v)) {
// Don't forward reshape with split since the fusion would not be
// able to see the connection between split output IDs, which
// might result in missing fusion opportunities. Note that
// merge should be fine, although merge after reduction may
// potentially result in an unschedulable fusion, since the
// condition is already enforced by all of the reduction-related
// schedulers. See NVFuserTest.ForwardReshapePostReduction for a
// concrete example.
auto reshape_out = reshape->out();
auto reshape_exprs = DependencyCheck::getAllExprsBetween(
{reshape_out->getRootDomain().begin(),
reshape_out->getRootDomain().end()},
{reshape_out->getLogicalDomain().begin(),
reshape_out->getLogicalDomain().end()});
if (std::any_of(reshape_exprs.begin(), reshape_exprs.end(), [](Expr* expr) {
return expr->isA<Split>();
})) {
return nullptr;
}
}

auto consumer_of_v = use_of_v->output(0);

// Don't forward if the input or the output is DID parallelized
if (auto input_tv = dynamic_cast<TensorView*>(v); input_tv != nullptr) {
auto output_tv = dynamic_cast<TensorView*>(consumer_of_v);
NVF_ERROR(output_tv != nullptr);
for (auto tv : {input_tv, output_tv}) {
if (std::any_of(
tv->getLoopDomain().begin(),
tv->getLoopDomain().end(),
[](auto loop_id) { return loop_id->isDeviceDim(); })) {
return nullptr;
}
}
}

// Don't forward an input to an output yet. Doing that would lead to an empty
// group that ought to work in theory but doesn't work in practice with the
// downstream logic. See #1813 for an example.
if (unary_use->out()->isFusionOutput()) {
if (consumer_of_v->isFusionOutput()) {
return nullptr;
}

// prevent forward to a SegmenterSet, which could cause unary op forward to a
// no-op segment. See issue: https://github.com/NVIDIA/Fuser/issues/2658
if (std::any_of(
unary_use->out()->uses().begin(),
unary_use->out()->uses().end(),
consumer_of_v->uses().begin(),
consumer_of_v->uses().end(),
[](const Expr* next_use) {
if (const LoadStoreOp* use =
dynamic_cast<const LoadStoreOp*>(next_use)) {
Expand All @@ -4058,7 +4110,7 @@ UnaryOp* shouldForward(Val* v) {
return nullptr;
}

return unary_use;
return use_of_v;
}

void SegmentCandidateFinder::forwardInputs() {
Expand All @@ -4069,27 +4121,27 @@ void SegmentCandidateFinder::forwardInputs() {
// treated as complete fusion inputs.
VectorOfUniqueEntries<Val*> forwarded_inputs;
{
std::deque<UnaryOp*> to_visit;
std::deque<Expr*> to_visit;
for (Val* inp : completeFusion()->inputs()) {
if (UnaryOp* unary_use = shouldForward(inp)) {
to_visit.push_back(unary_use);
if (Expr* use_of_inp = shouldForward(inp)) {
to_visit.push_back(use_of_inp);
}
}

while (!to_visit.empty()) {
UnaryOp* uop = to_visit.front();
Expr* expr = to_visit.front();
to_visit.pop_front();

if (UnaryOp* unary_use = shouldForward(uop->out())) {
to_visit.push_back(unary_use);
if (Expr* use_of_out = shouldForward(expr->output(0))) {
to_visit.push_back(use_of_out);
} else {
// We cannot extend the chain of unary ops, so we finalize this chain by
// saving its output as a forwarded input.
forwarded_inputs.pushBack(uop->out());
forwarded_inputs.pushBack(expr->output(0));
}
// Either way, `uop` is excluded from merging until
// Either way, `expr` is excluded from merging until
// `resolveNonscalarForwardedInput` adds it back to one of the segments.
excluded_inp_unary_exprs_.pushBack(uop);
excluded_inp_unary_exprs_.pushBack(expr);
}
}

Expand Down Expand Up @@ -4346,7 +4398,14 @@ void SegmentCandidateFinder::resolveScalarsInGroup(SegmentedGroup* group) {

SegmentedGroup* SegmentCandidateFinder::createInputGroup(Val* forwarded_input) {
SegmentedGroup* group = segmented_fusion_->newGroup();
group->input_vals = IterVisitor::getInputsTo({forwarded_input});
auto inputs = IterVisitor::getInputsTo({forwarded_input});
for (auto inp : inputs) {
// Don't add scalars here as they are added elsewhere
if (inp->isScalar()) {
continue;
}
group->input_vals.push_back(inp);
}
group->exprs_ = StmtSort::getExprsTo({forwarded_input});
return group;
}
Expand Down
54 changes: 54 additions & 0 deletions tests/cpp/test_gpu3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9349,6 +9349,60 @@ TEST_F(NVFuserTest, RepeatBroadcastAndNonBroadcast) {
testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
}

// Testing the forwarding of the fusion segmenter. The reshape of tv1,
// which merges two IDs, is forwarded. Those two IDs are mapped to a
// non-reduction ID and a reduction ID, respectively, so the reshape
// should be fused with the reduction. This condition is enforced by
// the reduction-related schedulers (in this case the inner
// normalization scheduler), so the fusion should be segmented before
// the reshape of tv4. For the remaining ops, forwarding the tv1
// reshape should not be a problem.
TEST_F(NVFuserTest, ForwardReshapePostReduction) {
auto fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr;
FusionGuard fg(fusion_ptr.get());

auto tv0 = makeSymbolicTensor(2);
fusion.addInput(tv0);
auto tv1 = makeSymbolicTensor(2);
fusion.addInput(tv1);

auto tv2 = sum(tv0, {1});
auto tv3 = broadcast(tv2, {false, true});
auto tv4 = add(tv3, tv0);
auto tv5 = reshape(tv4, {IrBuilder::create<Val>(-1)});
auto tv6 = reshape(tv1, {IrBuilder::create<Val>(-1)});
auto tv7 = add(tv5, tv6);

fusion.addOutput(tv7);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor t0 = at::ones({10, 20}, options);
at::Tensor t1 = at::ones({10, 20}, options);
std::vector<c10::IValue> inputs = {t0, t1};

FusionExecutorCache executor_cache(std::move(fusion_ptr));
auto outputs = executor_cache.runFusionWithInputs(inputs);
testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);

// The fusion should be segmented into two kernels, one of which
// should have both of the two view ops.
auto kernel_runtime = executor_cache.getMostRecentKernelRuntime();
const auto num_segments = kernel_runtime->fusionSegments()->groups().size();
EXPECT_EQ(num_segments, 2);
for (const auto i : c10::irange(num_segments)) {
const auto& exec = kernel_runtime->executors().at(i);
int64_t num_view_ops = 0;
for (auto expr : exec->as<KernelExecutor>()->fusion()->exprs()) {
if (expr->isA<ViewOp>()) {
++num_view_ops;
}
}
EXPECT_TRUE(num_view_ops == 0 || num_view_ops == 2)
<< "Unexpected number of ViewOp: " << num_view_ops;
}
}

// Test file size should be up to 10K LoC. Create a new file for more tests.

} // namespace nvfuser
22 changes: 13 additions & 9 deletions tests/cpp/test_pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -975,18 +975,22 @@ TEST_F(PointwiseTest, DomainMapFactory) {
FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime();
SegmentedFusion* segmented_fusion = runtime->fusionSegments();
// This fusion currently cannot be scheduled as a single kernel. It is
// expected to be segmented as: g{(pointwise)
// inputs: tv0, tv1
// outputs: tv2, tv3
// expected to be segmented as:
//
// g{(pointwise)
// inputs: tv0
// outputs: tv6
// tv2 = broadcast(tv0)
// tv3 = add (tv2, broadcast(tv1))
// tv5 = full({4, 1, i0})
// tv6 = add (tv2, tv5)
// }
//
// g{(pointwise)
// inputs: tv2
// outputs: tv5
// tv4 = full({4, 1, i0})
// tv5 = mul(tv2, tv4)
// inputs: tv0, tv1
// outputs: tv4
// tv3 = broadcast(tv1)
// tv2 = broadcast(tv0)
// tv4 = add(tv2, tv3)
// }
EXPECT_EQ(segmented_fusion->groups().size(), 2);

Expand All @@ -998,7 +1002,7 @@ TEST_F(PointwiseTest, DomainMapFactory) {
});
if (num_full != 0) {
// this is the segment contains the factory op.
EXPECT_EQ(exprs.size(), 2);
EXPECT_EQ(exprs.size(), 3);
EXPECT_EQ(num_full, 1);
auto binary_op_iter =
std::find_if(exprs.begin(), exprs.end(), [](Expr* expr) {
Expand Down
Loading