Add ResizeHeuristic (#3674)
Currently the heuristic only has one parameter. A few minor tweaks are also included:

- Previously, gridDim.x was static; it is now symbolic.
- Transpose-like patterns are rejected for now, as they would need scheduling similar to what the transpose scheduler does.
naoyam committed Jan 14, 2025
1 parent c3c2093 commit d6cef9a
Showing 2 changed files with 121 additions and 12 deletions.
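For context, below is a minimal standalone sketch of the decision the new heuristic makes, assuming only what the diff shows (bdimx = 128 and max_gdimx = 2^31 - 1); LaunchSketch and launchShapeFor are illustrative names, not nvfuser API.

#include <cstdint>
#include <iostream>

namespace {

constexpr int64_t kBdimx = 128;                        // mirrors bdimx in computeHeuristics
constexpr int64_t kMaxGdimx = (int64_t{1} << 31) - 1;  // mirrors ResizeParams::max_gdimx

int64_t ceilDiv(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

struct LaunchSketch {
  bool split_grid_x_dim;  // the single ResizeParams parameter
  int64_t gdimx;          // symbolic gridDim.x after the splits
  int64_t serial_extent;  // extra serial loop when the grid would overflow
};

LaunchSketch launchShapeFor(int64_t num_elms) {
  const int64_t ctas = ceilDiv(num_elms, kBdimx);
  if (ctas > kMaxGdimx) {
    // split_grid_x_dim == true: cap gridDim.x and iterate serially over the rest
    return {true, kMaxGdimx, ceilDiv(ctas, kMaxGdimx)};
  }
  return {false, ctas, 1};
}

} // namespace

int main() {
  // The split kicks in only above 128 * (2^31 - 1) elements, i.e. just past 2^38 - 128.
  for (int64_t n : {int64_t{1} << 20, (int64_t{1} << 38) - 128, int64_t{1} << 38}) {
    const auto s = launchShapeFor(n);
    std::cout << n << " elements: split=" << s.split_grid_x_dim
              << " gdimx=" << s.gdimx << " serial=" << s.serial_extent << "\n";
  }
  return 0;
}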
74 changes: 62 additions & 12 deletions csrc/scheduler/resize.cpp
@@ -15,6 +15,7 @@
#include <scheduler/pointwise_utils.h>
#include <scheduler/registry_utils.h>
#include <scheduler/resize.h>
#include <scheduler/resize_heuristic.h>
#include <scheduler/runtime_info.h>
#include <scheduler/tools/inlining.h>
#include <scheduler/tools/loop_domain_scheduler.h>
@@ -30,6 +31,31 @@ TensorView* getReferenceTensor(Fusion* fusion) {
  return pointwise_utils::getReferenceTensor(fusion);
}

// Returns the largest tensor with its number of elements
std::pair<TensorView*, int64_t> getLargestTensor(
    const std::vector<Val*>& vals,
    SchedulerRuntimeInfo& runtime_info) {
  int64_t max_num_elms = -1;
  TensorView* largest_tv = nullptr;
  for (auto tv : ir_utils::filterByType<TensorView>(vals)) {
    int64_t num_elms = 1;
    for (auto logical_id : tv->getLogicalDomain()) {
      auto inferred_val =
          runtime_info.expressionEvaluator().evaluate(logical_id->extent());
      NVF_ERROR(
          inferred_val.hasValue(),
          "Error inferring extent of: ",
          logical_id->toString());
      num_elms *= inferred_val.as<int64_t>();
    }
    if (num_elms > max_num_elms) {
      largest_tv = tv;
      max_num_elms = num_elms;
    }
  }
  return std::make_pair(largest_tv, max_num_elms);
}

} // namespace

bool ResizeScheduler::canScheduleCompileTime(Fusion* fusion) {
@@ -111,12 +137,10 @@ bool ResizeScheduler::canScheduleCompileTime(Fusion* fusion) {
}
}

  // This doesn't work yet due to issue #3571
  auto ref_tv = getReferenceTensor(fusion);
  if (std::any_of(
          ref_tv->getLogicalDomain().begin(),
          ref_tv->getLogicalDomain().end(),
          [](IterDomain* logical_id) { return logical_id->isBroadcast(); })) {
  if (ref_tv == nullptr) {
    scheduler_debug_utils::canScheduleRejectReason(
        schedulerType(), "No reference found");
    return false;
  }

@@ -158,10 +182,12 @@ bool ResizeScheduler::canScheduleCompileTime(Fusion* fusion) {
}
}

  // Disable the scheduler if there's a squeeze op. The loop option
  // may also need to be enabled in that case, but that option is not
  // turned on automatically yet.
  if (ir_utils::hasOpsOfType<SqueezeOp>(fusion)) {
  // Skip transpose-like patterns for now
  scheduler_tools::TransposeDomainMap domain_map(fusion);
  auto grouped_inputs_outputs = domain_map.groupInputsOutputsByInnerDim();
  if (grouped_inputs_outputs.size() >= 2) {
    scheduler_debug_utils::canScheduleRejectReason(
        schedulerType(), "Transpose-like patterns not supported.");
    return false;
  }

@@ -173,15 +199,27 @@ std::unique_ptr<HeuristicParams> ResizeScheduler::computeHeuristics(
    SchedulerRuntimeInfo& runtime_info,
    HeuristicDataCache* data_cache) {
  FUSER_PERF_SCOPE("ResizeScheduler::computeHeuristics");
  auto params = std::make_unique<HeuristicParams>(SchedulerType::Resize);
  auto params = std::make_unique<ResizeParams>(SchedulerType::Resize);
  params->tag = "Resize heuristics";
  params->cparams.index_type = runtime_info.getIndexType();

  const int64_t bdimx = 128;

  const auto& [largest_output, max_num_elms] =
      getLargestTensor(fusion->outputs(), runtime_info);

  params->split_grid_x_dim =
      ceilDiv(max_num_elms, bdimx) > ResizeParams::max_gdimx;

  return params;
}

void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
  FUSER_PERF_SCOPE("ResizeScheduler::schedule");

  FusionGuard fg(fusion);
  const auto resize_params = dynamic_cast<const ResizeParams*>(params);
  NVF_ERROR(resize_params != nullptr);

  scheduler_utils::clearMemorySpace(fusion);

@@ -222,19 +260,31 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
}

  auto ref_tv = getReferenceTensor(fusion);
  NVF_ERROR(ref_tv != nullptr);

  // Just simple scheduling for now.
  // TODO: Do something smarter. Can just use the pointwise scheduler?

  const int64_t bdimx = 128;

  // Make sure the DID ID located at the outermost position
  const auto outermost_pos = scheduler_utils::reorderDevicesToOuter(ref_tv);

  // Schedule only the remaining IDs
  ref_tv->flatten(outermost_pos);
  ref_tv->split(outermost_pos, 128);
  ref_tv->split(outermost_pos, 1 << 14);
  // [..., I0]

  ref_tv->split(-1, bdimx);
  ref_tv->axis(-1)->parallelize(ParallelType::TIDx);
  // [..., I0/bdimx, bdimx(TIDx)]

  if (resize_params->split_grid_x_dim) {
    ref_tv->split(-2, ResizeParams::max_gdimx);
    // [..., I0/bdimx/max_gdimx, max_gdimx, bdimx(TIDx)]
  }
  ref_tv->axis(-2)->parallelize(ParallelType::BIDx);
  // [..., I0/bdimx/max_gdimx, max_gdimx(BIDx), bdimx(TIDx)] or
  // [..., I0/bdimx(BIDx), bdimx(TIDx)]

  // Propagate the reference to the other tensors. Note that the
  // update flag is enabled so to workaround the resize propagation
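As a reading aid for the bracketed loop-domain comments above, here is a minimal sketch of the flat index implied by the final split layout; flatIndex is a hypothetical helper written for illustration, not scheduler code.

#include <cstdint>

// Layout [..., serial (I0/bdimx/max_gdimx), BIDx (max_gdimx), TIDx (bdimx)]:
// each (serial_i, blockIdx.x, threadIdx.x) triple addresses one element of the
// flattened reference domain. Without the grid split, serial_i is always 0 and
// gdimx is simply ceilDiv(num_elms, bdimx).
int64_t flatIndex(int64_t serial_i, int64_t bidx, int64_t tidx,
                  int64_t gdimx, int64_t bdimx = 128) {
  return (serial_i * gdimx + bidx) * bdimx + tidx;
}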
59 changes: 59 additions & 0 deletions csrc/scheduler/resize_heuristic.h
@@ -0,0 +1,59 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#pragma once

#include <c10/util/hash.h>
#include <ir/interface_nodes.h>
#include <scheduler/heuristic.h>
#include <utils.h>

#include <sstream>

namespace nvfuser {

class ResizeParams : public HeuristicParams {
 public:
  ResizeParams() : HeuristicParams(SchedulerType::Resize) {};

  // Split grid x dimension
  bool split_grid_x_dim = false;

  static constexpr int64_t max_gdimx = (1L << 31) - 1L;

  using HeuristicParams::HeuristicParams;

  // Warning: Does not check launch parameters!
  bool sameAs(const HeuristicParams* other_base) const override {
    auto other = dynamic_cast<const ResizeParams*>(other_base);
    if (other == nullptr) {
      return false;
    }
    bool attr_equal = other->cparams == cparams &&
        other->split_grid_x_dim == split_grid_x_dim;
    return attr_equal;
  }

  std::string toString() const override {
    std::stringstream ss;
    ss << "\n===== Resize Parameters ========\n"
       << (tag.empty() ? "" : "Tag: ") << tag << " Resize Characteristics:\n"
       << " split grid x dim: " << split_grid_x_dim << "\n";
    ss << "====================================\n";
    return ss.str();
  }

  size_t hash() const override {
    return c10::get_hash(split_grid_x_dim);
  }

  std::unique_ptr<HeuristicParams> clone() const override {
    return std::make_unique<ResizeParams>(*this);
  }
};

} // namespace nvfuser
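A hypothetical usage sketch of the class above, assuming an nvfuser build where this header is on the include path; resizeParamsDemo is an illustrative name and only members declared above are used.

#include <scheduler/resize_heuristic.h>

#include <iostream>

// Construct two parameter objects that differ only in split_grid_x_dim and
// compare them; per the warning above, sameAs() does not check launch parameters.
void resizeParamsDemo() {
  nvfuser::ResizeParams a;
  a.tag = "Resize heuristics";
  a.split_grid_x_dim = false;

  auto b = a.clone();  // std::unique_ptr<HeuristicParams>
  std::cout << a.toString();
  std::cout << "same as clone: " << a.sameAs(b.get()) << "\n";  // expected: 1

  nvfuser::ResizeParams c;
  c.split_grid_x_dim = true;
  std::cout << "same as flipped: " << a.sameAs(&c) << "\n";  // expected: 0
}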
