Skip to content

Commit

Permalink
Use the reference finder of pointwise scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
naoyam committed Dec 10, 2024
1 parent 77ea084 commit 7e7db61
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 19 deletions.
12 changes: 2 additions & 10 deletions csrc/scheduler/pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,19 +403,11 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
return params;
}

// Return reference tensor view.
TensorView* getReferenceTensorView(Fusion* fusion) {
FusionGuard fg(fusion);
pointwise_utils::DomainMap domain_map(fusion);
auto reference_tv = domain_map.findReferenceTensorView();
return reference_tv;
}

//! Utility for canSchedule interface to check if this fusion has
//! a fully broadcasted reference tensor, which is necessary for
//! the pointwise scheduler.
bool hasReferenceTensorView(Fusion* fusion) {
return getReferenceTensorView(fusion) != nullptr;
return pointwise_utils::getReferenceTensor(fusion) != nullptr;
}

bool PointWiseScheduler::canScheduleCompileTime(Fusion* fusion) {
Expand Down Expand Up @@ -512,7 +504,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
return;
}

TensorView* reference_tv = getReferenceTensorView(fusion);
TensorView* reference_tv = pointwise_utils::getReferenceTensor(fusion);

NVF_ERROR(
reference_tv != nullptr,
Expand Down
8 changes: 8 additions & 0 deletions csrc/scheduler/pointwise_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,13 @@ class DomainMap : public scheduler_tools::DomainMap {
}
};

// Return reference tensor view.
inline TensorView* getReferenceTensor(Fusion* fusion) {
FusionGuard fg(fusion);
DomainMap domain_map(fusion);
auto reference_tv = domain_map.findReferenceTensorView();
return reference_tv;
}

} // namespace pointwise_utils
} // namespace nvfuser
11 changes: 2 additions & 9 deletions csrc/scheduler/resize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,6 @@ std::unique_ptr<HeuristicParams> ResizeScheduler::computeHeuristics(
return params;
}

namespace {

TensorView* getReferenceTensor(Fusion* fusion) {
return nullptr;
}

} // namespace

void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
FUSER_PERF_SCOPE("ResizeScheduler::schedule");

Expand All @@ -126,7 +118,8 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
scheduler_tools::propagateResizeToInputs(expr);
}

auto ref_tv = getReferenceTensor(fusion);
// Just use the pointwise version for now
auto ref_tv = pointwise_utils::getReferenceTensor(fusion);

std::cerr << "Reference: " << ref_tv->toString() << "\n";

Expand Down

0 comments on commit 7e7db61

Please sign in to comment.