Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Omit the where predicate of pad when found safe to do so #3708

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
373 changes: 350 additions & 23 deletions csrc/device_lower/pass/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2677,6 +2677,326 @@ void IndexLowering::allocateUniqueFusedReduction(
insertAtTopLevel(fused_reduction_alloc_reduction);
}

namespace {

// Check if the tensor is scheduled in such a way that the
// padded region is included in the loop domain. It should be
// sufficient if there's the mapped resize expr between the
// logical and loop domains of this tensor. Note that, however,
// preceding resizes may make this unsafe.
class PaddingConsistencyAnalysis {
public:
static bool paddedConsistently(
TensorView* dep_tv,
const ExprGroups& pad_resizes) {
PaddingConsistencyAnalysis analysis(dep_tv, pad_resizes);
return analysis.is_consistent_;
}

private:
PaddingConsistencyAnalysis(TensorView* dep_tv, const ExprGroups& pad_resizes)
: pad_resizes_(pad_resizes) {
const auto& tensor_indexer = GpuLower::current()->tensorIndexer();
const auto& index_traversal_graph = tensor_indexer.traversalGraph();

auto dep_tv_def = dep_tv->definition();

const auto logical_domain_indexing_path =
tensor_indexer.getIndexingPath(dep_tv_def, dep_tv->getLogicalDomain());

for (const auto& [path_expr_g, dir] : logical_domain_indexing_path) {
// Since the indexing traversal may use a local graph for resize,
// needs to explicitly find the group in the indexing
// traversal graph.
const auto& expr_g = index_traversal_graph.toGroup(path_expr_g->front());
const auto inputs = getInputsOfExpr(
expr_g,
dir,
ValGraphInputs(index_traversal_graph),
ValGraphOutputs(index_traversal_graph));

const auto outputs = getOutputsOfExpr(
expr_g,
dir,
ValGraphInputs(index_traversal_graph),
ValGraphOutputs(index_traversal_graph));

if (expr_g->front()->isA<Resize>()) {
if (!handleResize(expr_g, inputs, outputs)) {
return;
}
} else {
if (!handleNonResize(expr_g, inputs, outputs)) {
return;
}
}
}

// Traversed the path successfully. Should mean there's no
// conflicting different resize

// Need to make sure all of the pad resizes are found in the
// path
if (detected_resize_exprs_ != pad_resizes.set()) {
return;
}

is_consistent_ = true;
}

bool handleResize(
const ExprGroup& expr_g,
const ValGroups& inputs,
const ValGroups& outputs) {
const auto& resize_input = inputs.at(0);
if (no_more_resize_dep_set.count(resize_input)) {
// No resize allowed
return false;
}

// Check if this is a resize of the pad op.
auto resize_expr_it =
std::find(pad_resizes_.begin(), pad_resizes_.end(), expr_g);
if (resize_expr_it != pad_resizes_.end()) {
detected_resize_exprs_.insert(expr_g);
for (const auto& output_g : outputs) {
NVF_ERROR(pad_resize_dep_map_.emplace(output_g, expr_g).second);
}
return true;
}

const bool depends_on_pad_resize = std::any_of(
inputs.begin(), inputs.end(), [&](const ValGroup& input_group) {
return pad_resize_dep_map_.count(input_group);
});

if (!depends_on_pad_resize) {
return true;
}

// There's a dependency. Check if this is a valid expr. If
// it's another resize, it must have positive expansion
// factors. If not resize, no further resize will be allowed

// Different resize. For padded sides, the resize expand
// factor must be non-negative

const ExprGroup& original_pad_resize = pad_resize_dep_map_.at(inputs.at(0));

// Left expand factor
if (!original_pad_resize->front()->as<Resize>()->leftExpand()->isZero()) {
auto dep_resize_left = expr_g->front()->as<Resize>()->leftExpand();
if (!dep_resize_left->isConstInt() ||
dep_resize_left->evaluate().as<int64_t>() < 0) {
return false;
}
}

if (!original_pad_resize->front()->as<Resize>()->rightExpand()->isZero()) {
auto dep_resize_right = expr_g->front()->as<Resize>()->rightExpand();
if (!dep_resize_right->isConstInt() ||
dep_resize_right->evaluate().as<int64_t>() < 0) {
return false;
}
}

for (const auto& output_g : outputs) {
NVF_ERROR(
pad_resize_dep_map_.emplace(output_g, original_pad_resize).second);
}

return true;
}

bool handleNonResize(
const ExprGroup& expr_g,
const ValGroups& inputs,
const ValGroups& outputs) {
const bool depends_on_pad_resize = std::any_of(
inputs.begin(), inputs.end(), [&](const ValGroup& input_group) {
return pad_resize_dep_map_.count(input_group);
});

// If depends on a pad resize, no further resize is
// allowed. Propagate the no-resize info
if (depends_on_pad_resize ||
std::any_of(
inputs.begin(), inputs.end(), [&](const ValGroup& input_group) {
return no_more_resize_dep_set.count(input_group);
})) {
for (const auto& output_group : outputs) {
no_more_resize_dep_set.emplace(output_group);
}
}

return true;
}

private:
const ExprGroups& pad_resizes_;

bool is_consistent_ = false;

std::unordered_map<ValGroup, ExprGroup> pad_resize_dep_map_;
std::unordered_set<ValGroup> no_more_resize_dep_set;
std::unordered_set<ExprGroup> detected_resize_exprs_;
};

bool canOmitPadPredicate(const PadOp* pad) {
if (isOptionDisabled(DisableOption::PadPredicateElimination)) {
return false;
}

// TensorIndexer and PredicateElimination are required
if (!GpuLower::current()->isTensorIndexerEnabled() ||
isOptionDisabled(DisableOption::PredicateElimination)) {
return false;
}

std::cerr << "Checking " << pad->toString();

auto consumer_tv = pad->out()->as<TensorView>();

const auto& tensor_indexer = GpuLower::current()->tensorIndexer();
const auto& index_traversal_graph = tensor_indexer.traversalGraph();
const auto& pred_info = GpuLower::current()->predicateElimination();

auto resize_exprs = DependencyCheck::getAllExprsBetween(
{consumer_tv->getRootDomain().begin(),
consumer_tv->getRootDomain().end()},
{consumer_tv->getLogicalDomain().begin(),
consumer_tv->getLogicalDomain().end()});

NVF_ERROR(
!resize_exprs.empty() &&
std::all_of(resize_exprs.begin(), resize_exprs.end(), [](Expr* expr) {
return expr->isA<Resize>();
}));

// All resize expansion factors must be static and non-negative
for (auto expr : resize_exprs) {
for (auto expand_val :
{expr->as<Resize>()->leftExpand(),
expr->as<Resize>()->rightExpand()}) {
if (!expand_val->isConstInt()) {
return false;
}
auto expand_int = expand_val->evaluate().as<int64_t>();
if (expand_int < 0) {
return false;
}
}
}

ExprGroups resize_expr_groups = index_traversal_graph.toGroups(resize_exprs);

const auto pad_val = pad->value();

// Contains tensors to check. Starting with the consumer tv and its
// upward dependent tensors are added as necessary. Specifically,
// when an expr is predicated, check if it's initialized to the same
// value as the padding value. If yes and the consumer of the expr
// is found to have the padded region as part of its loop domain, it
// should be safe to omit the pad predicate. If the predicate of the
// expr is omitted, its preceding exprs need to be checked. Since
// reading fusion inputs should never omit predicates, this list
// should never include fusion inputs.

std::deque<TensorView*> tvs_to_check;
tvs_to_check.push_back(consumer_tv);

while (!tvs_to_check.empty()) {
auto tv_to_check = tvs_to_check.front();
tvs_to_check.pop_front();

// tvs_to_check should never include a fusion input
NVF_ERROR(
!tv_to_check->isFusionInput(),
"Not expected to have a fusion input: ",
tv_to_check->toString());

auto tv_expr = tv_to_check->definition();
NVF_ERROR(
tv_expr != nullptr,
"Unexpected to have no definition: ",
tv_to_check->toString());

if (pred_info.canOmitPredicate(tv_expr)) {
// If predicate is omitted and producer values are just
// propagated, check the producers

// Check if the producer value is just moved
if (tv_expr != pad) {
if (!tv_expr->isOneOf<
LoadStoreOp,
BroadcastOp,
ExpandOp,
SqueezeOp,
SliceOp,
CatOp,
ViewOp,
UnaryOp>()) {
std::cerr << "Unsupported op: " << tv_expr->toString();
return false;
}

// For unary op, only cast is allowed for now. Should be able to
// support, e.g., abs, neg, etc. Neg must be careful as the
// negative zero is different from the positive zero, which
// matters for bitwise-or based concat
if (auto uop = dynamic_cast<UnaryOp*>(tv_expr)) {
if (uop->getUnaryOpType() != UnaryOpType::Cast) {
std::cerr << "Unsupported op: " << tv_expr->toString();
return false;
}
}
}

// If there's no producer, i.e., a full op, and the predicate of
// the expr is omitted, can't guarantee anything about the
// padded region
auto producer_tvs = ir_utils::producerTvsOf(tv_to_check);
if (producer_tvs.empty()) {
return false;
}
for (auto producer_tv : producer_tvs) {
// If tv_expr has a fusion input as one of its input, its
// predicate should never be omitted, so producer_tv should
// not be a fusion input
NVF_ERROR(!producer_tv->isFusionInput());
tvs_to_check.push_back(producer_tv);
}
} else {
auto init_val = pred_info.getInitValue(tv_to_check);
if (init_val == nullptr) {
// Can't determine if it's safe to omit without an init value
return false;
}

// Note Val::sameAs may not work as the data types may be
// different (e.g., 0.0f vs 0L)
bool initialized_to_same_value = pad_val->value().hasValue() &&
init_val != nullptr && init_val->value().hasValue() &&
pad_val->value() == init_val->value();

if (!initialized_to_same_value) {
return false;
}

if (!PaddingConsistencyAnalysis::paddedConsistently(
tv_to_check, resize_expr_groups)) {
return false;
}
}
}

std::cerr << "Pad predicate can be safely removed: " << pad->toString();

return true;
}

} // namespace

void IndexLowering::handle(const PadOp* pad) {
// Convert to a where op as:
// consumer[consumer_idx] = (consumer_idx >= left_pad && consumer_idx <
Expand All @@ -2695,30 +3015,37 @@ void IndexLowering::handle(const PadOp* pad) {
const auto pad_val = pad->value();

// Build a predicate for where
auto consumer_root_indices = Index::getConsumerPerDimLogicalIndex(
consumer_tv, for_loops_, getRotatedLoop());
Val* pred = consumer_tv->fusion()->trueVal();
for (auto padded_axis : pad->getPaddedAxes()) {
auto consumer_idx = consumer_root_indices.at(padded_axis);
auto consumer_root_id = consumer_tv->getLogicalDomain().at(padded_axis);
NVF_ERROR(!consumer_root_id->maybePartial());
const auto& pad_widths = pad->getPadWidths(padded_axis);
pred = SimplifyingIrBuilder::logicalAndExpr(
pred,
// idx >= left_pad && idx < extent - right_pad
SimplifyingIrBuilder::logicalAndExpr(
SimplifyingIrBuilder::geExpr(consumer_idx, pad_widths.first),
SimplifyingIrBuilder::ltExpr(
consumer_idx,
SimplifyingIrBuilder::subExpr(
consumer_root_id->getMaybeExpandedExtent(),
pad_widths.second))));
}

pred = GpuLower::current()->commonScalarMap().hoistScalar(pred, for_loops_);
bool can_omit_where_predicate = canOmitPadPredicate(pad);

if (can_omit_where_predicate) {
pushBack(IrBuilder::create<LoadStoreOp>(LoadStoreOpType::Set, out, in));
} else {
auto consumer_root_indices = Index::getConsumerPerDimLogicalIndex(
consumer_tv, for_loops_, getRotatedLoop());
Val* pred = consumer_tv->fusion()->trueVal();
for (auto padded_axis : pad->getPaddedAxes()) {
auto consumer_idx = consumer_root_indices.at(padded_axis);
auto consumer_root_id = consumer_tv->getLogicalDomain().at(padded_axis);
NVF_ERROR(!consumer_root_id->maybePartial());
const auto& pad_widths = pad->getPadWidths(padded_axis);
pred = SimplifyingIrBuilder::logicalAndExpr(
pred,
// idx >= left_pad && idx < extent - right_pad
SimplifyingIrBuilder::logicalAndExpr(
SimplifyingIrBuilder::geExpr(consumer_idx, pad_widths.first),
SimplifyingIrBuilder::ltExpr(
consumer_idx,
SimplifyingIrBuilder::subExpr(
consumer_root_id->getMaybeExpandedExtent(),
pad_widths.second))));
}

pred = GpuLower::current()->commonScalarMap().hoistScalar(pred, for_loops_);

pushBack(IrBuilder::create<TernaryOp>(
TernaryOpType::Where, out, pred, in, pad_val));
}

pushBack(IrBuilder::create<TernaryOp>(
TernaryOpType::Where, out, pred, in, pad_val));
GpuLower::current()->propagateExprInfo(pad, back());
}

Expand Down
Loading
Loading