EmbeddingFwdOp node with same functionality as F.embedding (#3649)
This PR adds an `EmbeddingFwdOp` node with the same functionality as
`F.embedding`.
1. I am not using `take_along_axis`. `F.embedding` accepts optional
parameters such as `max_norm` and `padding_idx` that would require further
processing if implemented with `take_along_axis`, so I defaulted to
creating a new node to guarantee performance parity. A reference sketch of
the intended semantics is shown after this list.
2. Thunder uses `prims.EMBEDDING` when the optional parameters
`padding_idx`/`max_norm` are specified and `prims.TAKE` otherwise, which
prevents nvfuser from consuming the embedding operator in the other cases.
In Thunder, nvfuser will therefore execute `ltorch.embedding` directly.
This requires a separate backward API that consumes
`ltorch.embedding_backward` and cannot reuse the grad rules for
`prims.EMBEDDING`, hence the name `EmbeddingFwdOp` rather than
`EmbeddingOp`.
3. I plan to plumb forward-only embedding support through Thunder first
while I draft the backward node, which should be very similar. Thunder
reviews may surface another way of implementing this support.
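
For reference, here is a minimal standalone sketch of the semantics this node targets, using the same ATen call that `EmbeddingFwdOp::evaluate` falls back to (`torch::nn::functional::embedding`). The shapes and option values are illustrative assumptions, not part of this PR.

```cpp
// Sketch only: illustrates the F.embedding behavior that EmbeddingFwdOp mirrors.
// Shapes and option values are hypothetical; requires a linked libtorch.
#include <iostream>
#include <torch/torch.h>

int main() {
  namespace F = torch::nn::functional;
  // indices: [*] (here [2, 3]); weight: [V, embedding_dim] (here [10, 4])
  auto input = torch::randint(/*high=*/10, {2, 3}, torch::kLong);
  auto weight = torch::randn({10, 4});
  auto out = F::embedding(
      input,
      weight,
      F::EmbeddingFuncOptions()
          .padding_idx(1)            // row 1 of weight receives no gradient
          .max_norm(1.0)             // referenced rows are renormalized in place
          .norm_type(2.0)
          .scale_grad_by_freq(false)
          .sparse(false));
  // Output shape is [*, embedding_dim], i.e. [2, 3, 4].
  std::cout << out.sizes() << std::endl;
  return 0;
}
```

Note that `max_norm` renormalizes the referenced rows of `weight` in place, which is one example of the extra processing a `take_along_axis`-based lowering would need.
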
Priya2698 authored Jan 22, 2025
1 parent 2bbdf26 commit 56d1ec5
Showing 16 changed files with 511 additions and 2 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -549,6 +549,7 @@ list(APPEND JIT_TEST_SRCS
${NVFUSER_ROOT}/tests/cpp/test_circular_buffering.cpp
${NVFUSER_ROOT}/tests/cpp/test_abstract_tensor.cpp
${NVFUSER_ROOT}/tests/cpp/test_dynamic_transform.cpp
${NVFUSER_ROOT}/tests/cpp/test_embedding_node.cpp
${NVFUSER_ROOT}/tests/cpp/test_evaluator.cpp
${NVFUSER_ROOT}/tests/cpp/test_exceptions.cpp
${NVFUSER_ROOT}/tests/cpp/test_expr_simplifier.cpp
1 change: 1 addition & 0 deletions csrc/device_lower/utils.cpp
@@ -159,6 +159,7 @@ bool isTvOp(const Expr* expr) {
LinearOp,
SdpaFwdOp,
SdpaBwdOp,
EmbeddingFwdOp,
BroadcastOp,
SqueezeOp,
ExpandOp,
1 change: 1 addition & 0 deletions csrc/dispatch.h
@@ -111,6 +111,7 @@ class Val;
f(LinearOp); \
f(SdpaFwdOp); \
f(SdpaBwdOp); \
f(EmbeddingFwdOp); \
f(Communication); \
f(ForLoop); \
f(P2PCommunication);
75 changes: 75 additions & 0 deletions csrc/ir/internal_nodes.h
@@ -2714,4 +2714,79 @@ class SdpaBwdOp : public Expr {
const std::vector<PolymorphicValue>& inputs) const override;
};

class EmbeddingFwdOp : public Expr {
public:
using Expr::Expr;

EmbeddingFwdOp(
IrBuilderPasskey,
TensorView* output,
TensorView* input,
TensorView* weight,
Val* padding_idx,
Val* max_norm,
Val* norm_type,
Val* scale_grad_by_freq,
Val* sparse);

NVFUSER_DECLARE_CLONE_AND_CREATE

const char* getOpString() const override {
return "EmbeddingFwdOp";
}

std::string toString(int indent_size = 0) const override;
std::string toInlineString(int indent_size = 0) const override;

TensorView* out() const {
return output(0)->as<TensorView>();
}

TensorView* in() const {
return input(0)->as<TensorView>();
}

TensorView* weight() const {
return input(1)->as<TensorView>();
}

Val* norm_type() const {
return input(2);
}

Val* scale_grad_by_freq() const {
return input(3);
}

Val* sparse() const {
return input(4);
}
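// The first five inputs are fixed: input, weight, norm_type,
// scale_grad_by_freq, sparse. padding_idx and max_norm are optional;
// when provided they are appended after these, in that order, and
// boolean attributes 0 and 1 record whether each one is present.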

Val* padding_idx() const {
if (has_padding_idx()) {
return input(5);
}
return nullptr;
}

Val* max_norm() const {
if (has_max_norm()) {
return input(5 + has_padding_idx());
}
return nullptr;
}

bool has_padding_idx() const {
return attribute<bool>(0);
}

bool has_max_norm() const {
return attribute<bool>(1);
}

std::vector<PolymorphicValue> evaluate(
const ExpressionEvaluator& ee,
const std::vector<PolymorphicValue>& inputs) const override;
};

} // namespace nvfuser
91 changes: 91 additions & 0 deletions csrc/ir/nodes.cpp
@@ -25,6 +25,7 @@
#include <type.h>

#include <c10/util/irange.h>
#include <torch/nn/options/embedding.h>

#include <complex>
#include <iterator>
@@ -5304,4 +5305,94 @@ std::vector<PolymorphicValue> SdpaBwdOp::evaluate(
slice_last_dim(grad_value)};
}

EmbeddingFwdOp::EmbeddingFwdOp(
IrBuilderPasskey passkey,
TensorView* output,
TensorView* input,
TensorView* weight,
Val* padding_idx,
Val* max_norm,
Val* norm_type,
Val* scale_grad_by_freq,
Val* sparse)
: Expr(passkey) {
addOutput(output);

addInput(input);
addInput(weight);
addInput(norm_type);
addInput(scale_grad_by_freq);
addInput(sparse);
if (padding_idx != nullptr) {
addInput(padding_idx);
addDataAttribute(true);
} else {
addDataAttribute(false);
}
if (max_norm != nullptr) {
addInput(max_norm);
addDataAttribute(true);
} else {
addDataAttribute(false);
}
}

NVFUSER_DEFINE_CLONE_AND_CREATE(EmbeddingFwdOp)

std::string EmbeddingFwdOp::toString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << out()->toString() << ",\n";
indent(ss, indent_size + 1) << " = embedding(" << in()->toString() << ",\n";
indent(ss, indent_size + 1) << " " << weight()->toString() << ",\n";
if (padding_idx() != nullptr) {
indent(ss, indent_size + 1)
<< " padding_idx = " << padding_idx()->toString() << ",\n";
}
if (max_norm() != nullptr) {
indent(ss, indent_size + 1)
<< " max_norm = " << max_norm()->toString() << ",\n";
}
indent(ss, indent_size + 1)
<< " norm_type = " << norm_type()->toString() << ",\n";
indent(ss, indent_size + 1)
<< " scale_grad_by_freq = "
<< scale_grad_by_freq()->toInlineString() << ",\n";
indent(ss, indent_size + 1)
<< " sparse = " << sparse()->toInlineString() << ")\n";
return ss.str();
}

std::string EmbeddingFwdOp::toInlineString(int indent_size) const {
NVF_CHECK(false, "Tensor op can not be printed inline");
}

std::vector<PolymorphicValue> EmbeddingFwdOp::evaluate(
const ExpressionEvaluator& ee,
const std::vector<PolymorphicValue>& inputs) const {
auto input = inputs.at(0).as<at::Tensor>();
auto weight = inputs.at(1).as<at::Tensor>();
auto norm_type = inputs.at(2).as<double>();
auto scale_grad_by_freq = inputs.at(3).as<bool>();
auto sparse = inputs.at(4).as<bool>();
std::optional<int64_t> padding_idx = std::nullopt;
if (has_padding_idx()) {
padding_idx = inputs.at(5).as<int64_t>();
}
std::optional<double> max_norm = std::nullopt;
if (has_max_norm()) {
auto idx = 5 + has_padding_idx();
max_norm = inputs.at(idx).as<double>();
}

namespace F = torch::nn::functional;
return {F::embedding(
input,
weight,
F::EmbeddingFuncOptions()
.padding_idx(padding_idx)
.max_norm(max_norm)
.norm_type(norm_type)
.scale_grad_by_freq(scale_grad_by_freq)
.sparse(sparse))};
}
} // namespace nvfuser
21 changes: 21 additions & 0 deletions csrc/logical_domain_map.cpp
@@ -325,6 +325,27 @@ std::unordered_map<IterDomain*, IterDomain*> PairwiseLogicalDomainMap::map(
return dom_map;
}

if (EmbeddingFwdOp* op =
dynamic_cast<EmbeddingFwdOp*>(consumer_tv_->definition())) {
// Producers:
// input = [*]
// weight = [V, embedding_dim]
// Consumers:
// output = [*, embedding_dim]
auto ndims_out = consumer_root.size();
if (producer_tv_->sameAs(op->in())) {
for (auto idx : c10::irange(ndims_out - 1)) {
updatePairwiseLogicalDomainMap(
producer_logical.at(idx), consumer_root.at(idx));
}
}
if (producer_tv_->sameAs(op->weight())) {
updatePairwiseLogicalDomainMap(
producer_logical.back(), consumer_root.back());
}
return dom_map;
}

size_t itc = 0, itp = 0;
while (itc < consumer_root.size() && itp < producer_logical.size()) {
IterDomain* producer_id = producer_logical.at(itp);
68 changes: 68 additions & 0 deletions csrc/ops/composite.cpp
@@ -662,4 +662,72 @@ SdpfaBwdResult sdpfa_bwd(
return {grad_query, grad_key, grad_value};
}

TensorView* embedding_fwd(
TensorView* input,
TensorView* weight,
Val* padding_idx,
Val* max_norm,
Val* norm_type,
Val* scale_grad_by_freq,
Val* sparse) {
auto input_domain = TensorDomain::noReductions(input->getLogicalDomain());
auto weight_domain = TensorDomain::noReductions(weight->getLogicalDomain());
NVF_CHECK(
!input_domain.empty(),
"Expected input to be atleast 1D, got: ",
input_domain.size());
NVF_CHECK(
weight_domain.size() == 2,
"Expected weight to be 2D, got: ",
weight_domain.size());

NVF_CHECK(
!padding_idx || padding_idx->isScalar(),
"Expected padding_idx to be a scalar int.");
NVF_CHECK(
!max_norm || max_norm->isScalar(),
"Expected max_norm to be a scalar double.");
NVF_CHECK(
!norm_type || norm_type->isScalar(),
"Expected norm_type to be a scalar double.");
NVF_CHECK(
!scale_grad_by_freq || scale_grad_by_freq->isScalar(),
"Expected scale_grad_by_freq to be a scalar bool.");
NVF_CHECK(
!sparse || sparse->isScalar(), "Expected sparse to be a scalar bool.");

auto ndims_out = input_domain.size() + 1;
std::vector<IterDomain*> out_domain(ndims_out, nullptr);

for (auto idx : c10::irange(ndims_out - 1)) {
out_domain[idx] = ops::newOutputIterDomain({input_domain[idx]});
}
out_domain[ndims_out - 1] = ops::newOutputIterDomain({weight_domain.back()});
TensorDomain* out_td = IrBuilder::create<TensorDomain>(
out_domain, TensorDomain::getContiguityFilledWith(out_domain, true));
TensorView* output = IrBuilder::create<TensorView>(out_td, weight->dtype());

if (norm_type == nullptr) {
norm_type = IrBuilder::create<Val>(2.0, DataType::Double);
}

if (scale_grad_by_freq == nullptr) {
scale_grad_by_freq = input->fusion()->falseVal();
}
if (sparse == nullptr) {
sparse = input->fusion()->falseVal();
}
IrBuilder::create<EmbeddingFwdOp>(
output,
input,
weight,
padding_idx,
max_norm,
norm_type,
scale_grad_by_freq,
sparse);

return output;
}

} // namespace nvfuser
9 changes: 9 additions & 0 deletions csrc/ops/composite.h
@@ -118,4 +118,13 @@ SdpfaBwdResult sdpfa_bwd(
TensorView* philox_offset,
Val* scale);

TensorView* embedding_fwd(
TensorView* input,
TensorView* weight,
Val* padding_idx,
Val* max_norm,
Val* norm_type,
Val* scale_grad_by_freq,
Val* sparse);

} // namespace nvfuser
43 changes: 43 additions & 0 deletions csrc/python_frontend/fusion_record.h
@@ -3052,6 +3052,49 @@ struct SdpaBwdOpRecord : RecordFunctor {
}
};

struct EmbeddingFwdOpRecord : RecordFunctor {
EmbeddingFwdOpRecord(std::vector<State> args, std::vector<State> outputs)
: RecordFunctor(
std::move(args),
std::move(outputs),
"ops.embedding_fwd",
serde::RecordType::EmbeddingFwdOp) {}
~EmbeddingFwdOpRecord() override = default;
RecordFunctor* clone() final {
return new EmbeddingFwdOpRecord(*this);
}

void operator()(FusionState& fd) final {
auto input = fd.getFusionState(args_.at(0).index)->as<TensorView>();
auto weight = fd.getFusionState(args_.at(1).index)->as<TensorView>();
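// padding_idx, max_norm, norm_type, scale_grad_by_freq, and sparse are
// optional; a slot that does not hold a scalar state is forwarded to
// embedding_fwd as nullptr.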
auto padding_idx = (args_.at(2).stype == serde::StateType::Scalar)
? fd.getFusionState(args_.at(2).index)->as<Val>()
: nullptr;
auto max_norm = (args_.at(3).stype == serde::StateType::Scalar)
? fd.getFusionState(args_.at(3).index)->as<Val>()
: nullptr;
auto norm_type = (args_.at(4).stype == serde::StateType::Scalar)
? fd.getFusionState(args_.at(4).index)->as<Val>()
: nullptr;
auto scale_grad_by_freq = (args_.at(5).stype == serde::StateType::Scalar)
? fd.getFusionState(args_.at(5).index)->as<Val>()
: nullptr;
auto sparse = (args_.at(6).stype == serde::StateType::Scalar)
? fd.getFusionState(args_.at(6).index)->as<Val>()
: nullptr;

auto output = embedding_fwd(
input,
weight,
padding_idx,
max_norm,
norm_type,
scale_grad_by_freq,
sparse);
fd.setFusionState(outputs_.at(0).index, output);
}
};

} // namespace nvfuser::python_frontend

//! Creating the template specialized hash and equal_to functions for a