Skip to content

Commit

Permalink
lower pow(x,1) to x (#3774)
Browse files Browse the repository at this point in the history
Partially solve #3629

Currently, `genPowerWithMul()` only handles case with factor of 2 & 3,
for `y = pow(x, 1.0)`, it should be lowered to `y = x` which is much
faster than the power version.
  • Loading branch information
liqiangxl authored Jan 30, 2025
1 parent 149c163 commit 765a165
Showing 1 changed file with 20 additions and 14 deletions.
34 changes: 20 additions & 14 deletions csrc/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -995,32 +995,38 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
return false;
}

// Only **2 and **3 are considered
if (!(exponent == 2 || exponent == 3)) {
// Only **1, **2 and **3 are considered
if (!(exponent == 1 || exponent == 2 || exponent == 3)) {
return false;
}

auto lhs = gen(bop->lhs());

if (print_inline_) {
code_ << lhs << " * " << lhs;
if (exponent == 3) {
code_ << " * " << lhs;
for (int i = 0; i < exponent; ++i) {
if (i != 0) {
code_ << " * ";
}
code_ << lhs;
}
} else {
indent() << gen(bop->out());
if (bop->out()->isScalar()) {
code_ << " = " << lhs << " * " << lhs;
if (exponent == 3) {
code_ << " * " << lhs;
for (int i = 0; i < exponent; ++i) {
if (i == 0) {
code_ << " = " << lhs;
} else {
code_ << " * " << lhs;
}
}
} else {
code_ << "\n";
indent() << kTab << "= " << lhs << "\n";
indent() << kTab << "* " << lhs;
if (exponent == 3) {
code_ << "\n";
indent() << kTab << "* " << lhs;
for (int i = 0; i < exponent; ++i) {
if (i == 0) {
code_ << "\n";
indent() << kTab << "= " << lhs;
} else {
indent() << "\n" << kTab << "* " << lhs;
}
}
}
}
Expand Down

0 comments on commit 765a165

Please sign in to comment.