37 changes: 30 additions & 7 deletions examples/graph/gqa_training.cpp
@@ -317,13 +317,33 @@ bool bench_gqa_backward(engine::kind ekind, logical_tensor::data_type dt,
bmm_do_v.add_inputs({doutput, value});
bmm_do_v.add_outputs({dprobs});

// compute dmasked_score = dsoftmax(dprobs)
// compute dmasked_score = P * (dprobs - ReduceSum(O * dO))
// decomposed softmax backward: dS = P * (dP - rowsum(O * dO))
auto o_do_out
= logical_tensor(id++, dt_inter, output_sz, layout_type::strided);
auto o_do_mul = op(id++, op::kind::Multiply, "mul_o_do");
o_do_mul.add_inputs({output, doutput});
o_do_mul.add_outputs({o_do_out});

auto correction_out
= logical_tensor(id++, dt_inter, stats_sz, layout_type::strided);
auto correction = op(id++, op::kind::ReduceSum, "reducesum_correction");
correction.set_attr<std::vector<int64_t>>(op::attr::axes, {4});
correction.set_attr<bool>(op::attr::keep_dims, true);
correction.add_inputs({o_do_out});
correction.add_outputs({correction_out});

auto dp_corrected_out
= logical_tensor(id++, dt_inter, score_sz, layout_type::strided);
auto dp_corrected_op = op(id++, op::kind::Subtract, "sub_dp_corrected");
dp_corrected_op.add_inputs({dprobs, correction_out});
dp_corrected_op.add_outputs({dp_corrected_out});

auto dmasked_score
= logical_tensor(id++, dt_inter, score_sz, layout_type::strided);
auto softmax_grad = op(id++, op::kind::SoftMaxBackward, "softmax_bwd");
softmax_grad.set_attr<int64_t>(op::attr::axis, -1);
softmax_grad.add_inputs({dprobs, probs});
softmax_grad.add_outputs({dmasked_score});
auto softmax_bwd_mul = op(id++, op::kind::Multiply, "mul_softmax_bwd");
softmax_bwd_mul.add_inputs({probs, dp_corrected_out});
softmax_bwd_mul.add_outputs({dmasked_score});
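The Multiply/ReduceSum/Subtract/Multiply chain above replaces the previous single SoftMaxBackward op with the decomposition dS = P * (dP - rowsum(O * dO)). The per-row correction rowsum(O * dO) equals the classic softmax-backward term rowsum(P * dP), since O = P x V and dP = dO x V^T. A minimal scalar sketch of the same arithmetic, assuming plain row-major matrices rather than the graph API:

```cpp
// Standalone sketch of the decomposed softmax backward (illustrative only).
// P, dP: [rows x cols] softmax probabilities and their gradient
// O, dO: [rows x d]    attention output and its gradient
#include <cstddef>
#include <vector>

using Mat = std::vector<std::vector<float>>;

Mat softmax_bwd_decomposed(
        const Mat &P, const Mat &dP, const Mat &O, const Mat &dO) {
    const size_t rows = P.size(), cols = P[0].size(), d = O[0].size();
    Mat dS(rows, std::vector<float>(cols));
    for (size_t i = 0; i < rows; ++i) {
        // per-row correction D_i = rowsum(O * dO) (the ReduceSum above)
        float D = 0.f;
        for (size_t k = 0; k < d; ++k)
            D += O[i][k] * dO[i][k];
        // dS = P * (dP - D) (the Subtract and final Multiply above)
        for (size_t j = 0; j < cols; ++j)
            dS[i][j] = P[i][j] * (dP[i][j] - D);
    }
    return dS;
}
```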

// compute dscored_score = dmasked_score / scale
auto dscaled_score
@@ -372,10 +392,13 @@ bool bench_gqa_backward(engine::kind ekind, logical_tensor::data_type dt,
gqa_bwd.add_op(exp);
gqa_bwd.add_op(bmm_p_do);
gqa_bwd.add_op(bmm_do_v);
gqa_bwd.add_op(softmax_grad);
gqa_bwd.add_op(o_do_mul);
gqa_bwd.add_op(correction);
gqa_bwd.add_op(dp_corrected_op);
gqa_bwd.add_op(softmax_bwd_mul);
gqa_bwd.add_op(scale_div2);
gqa_bwd.add_op(bmm_dscaled_score_k);
gqa_bwd.add_op(bmm_dscaled_score_q);
gqa_bwd.add_op(bmm_dscaled_score_k);
gqa_bwd.add_op(reduce_dv);
gqa_bwd.add_op(reduce_dk);
if (dt != dt_inter) {
46 changes: 45 additions & 1 deletion src/graph/backend/dnnl/dnnl_op_def.hpp
@@ -1009,11 +1009,14 @@ DNNL_GRAPH_OP_SCHEMA(dnnl_matmul, 1,
DNNL_GRAPH_OP_SCHEMA(dnnl_softmax, 1,
op_schema_t()
.set_inputs_option(op_schema_t::param_num_option::variadic)
.set_outputs_option(op_schema_t::param_num_option::optional)
.set_num_inputs(std::set<size_t>({1, 32}))
.set_num_outputs(2)
.set_num_outputs(std::set<size_t>({2, 3}))
.set_input(0, "input")
.set_output(0, "output")
.set_output(1, "scratchpad")
.set_output(2, "stats") // optional
// Attributes inherited from SoftMax
.set_attr(op_attr::axis, false, attribute_kind::i, (int64_t)1)
.set_attr(op_attr::mode, false, attribute_kind::s, "none",
@@ -1023,7 +1026,7 @@ DNNL_GRAPH_OP_SCHEMA(dnnl_softmax, 1,
.set_attr(op_attr::fusion_info, false,
attribute_kind::fusion_info)
// Analysis rules
.set_shape_inference_function(infer_identity_output_shape)
.set_shape_inference_function(infer_dnnl_softmax_output_shape)
.SET_LAYOUT_PROPAGATOR(layout_propagator_for_softmax)
.SET_EXECUTABLE_CREATOR(
executable_creator<softmax_executable_t>)
@@ -1177,20 +1180,25 @@ DNNL_GRAPH_OP_SCHEMA(dnnl_mask, 1,
DNNL_GRAPH_OP_SCHEMA(dnnl_sdpa, 1,
op_schema_t()
.set_inputs_option(op_schema_t::param_num_option::variadic)
.set_outputs_option(op_schema_t::param_num_option::optional)
.set_num_inputs(std::set<size_t>({3, 32}))
.set_num_outputs(2)
.set_num_outputs(std::set<size_t>({2, 3}))
.set_input(0, "query")
.set_input(1, "key")
.set_input(2, "value")
.set_input(3, "scale") // optional
.set_input(4, "mask") // optional
.set_output(0, "output")
.set_output(1, "scratchpad")
.set_output(2,
"softmax_stats") // optional, only used for sdpa training
.set_attr(op_attr::fusion_info, false,
attribute_kind::fusion_info)
.set_attr(op_attr::with_scale, true, attribute_kind::b)
.set_attr(op_attr::is_invert_scale, false, attribute_kind::b,
false)
.set_attr(op_attr::is_training, false, attribute_kind::b)
// mask_type attribute indicates existence of explicit mask,
// top-left implicit causal mask or bottom-right implicit causal mask
.set_attr(op_attr::mask_type, true, attribute_kind::i)
@@ -1202,6 +1210,42 @@ DNNL_GRAPH_OP_SCHEMA(dnnl_sdpa, 1,
.SET_EXECUTABLE_CREATOR(executable_creator<sdpa_executable_t>)
.SET_ARG_INDICES_GETTER(sdpa_executable_t))

// Backward op for SDPA
DNNL_GRAPH_OP_SCHEMA(dnnl_sdpa_bwd, 1,
op_schema_t()
.set_inputs_option(op_schema_t::param_num_option::variadic)
.set_outputs_option(op_schema_t::param_num_option::optional)
// Inputs: query, key, value, dst, stats, diff_dst, [scale], [mask]
.set_num_inputs(std::set<size_t>({5, 32}))
.set_num_outputs(std::set<size_t>({4, 5}))
.set_input(0, "query")
.set_input(1, "key")
.set_input(2, "value")
.set_input(3, "dst")
.set_input(4, "stats")
.set_input(5, "diff_dst")
.set_input(6, "scale") // optional
.set_input(7, "mask") // optional
// Outputs: diff_query, diff_key, diff_value, scratchpad, [diff_mask]
.set_output(0, "diff_query")
.set_output(1, "diff_key")
.set_output(2, "diff_value")
.set_output(3, "scratchpad")
.set_output(4, "diff_mask") // optional
.set_attr(op_attr::fusion_info, false,
attribute_kind::fusion_info)
.set_attr(op_attr::with_scale, true, attribute_kind::b)
.set_attr(op_attr::is_invert_scale, false, attribute_kind::b,
false)
.set_attr(op_attr::mask_type, true, attribute_kind::i)
.set_attr(op_attr::qk_acc_mode, true, attribute_kind::s)
.set_attr(op_attr::vs_acc_mode, true, attribute_kind::s)
.set_shape_inference_function(infer_dnnl_sdpa_bwd_output_shape)
.SET_LAYOUT_PROPAGATOR(layout_propagator_for_sdpa_bwd)
.SET_EXECUTABLE_CREATOR(
executable_creator<sdpa_bwd_executable_t>)
.SET_ARG_INDICES_GETTER(sdpa_bwd_executable_t))
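For orientation, the shapes that the shape-inference function registered above (see dnnl_shape_infer.cpp below) works with can be summarized as follows; the dimension names are illustrative, and H_q, H_k, H_v coincide for plain SDPA:

```cpp
// Illustrative tensor shapes for one dnnl_sdpa_bwd instance:
//   query, diff_query : [B, H_q, S_q,  D_qk]
//   key,   diff_key   : [B, H_k, D_qk, S_kv]  // key held transposed
//   value, diff_value : [B, H_v, S_kv, D_v]
//   dst,   diff_dst   : [B, H_q, S_q,  D_v]
//   stats             : [B, H_q, S_q,  1]     // softmax stats from forward
//   diff_mask         : [B, H_q, S_q,  S_kv]  // optional
```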

} // namespace dnnl_impl
} // namespace graph
} // namespace impl
1 change: 1 addition & 0 deletions src/graph/backend/dnnl/dnnl_opset.hpp
@@ -101,6 +101,7 @@ class dnnl_opset_t {
fn(get_op_schema<DNNL_GRAPH_OP_SCHEMA_CLASS_NAME(dnnl_reorder, 1)>());
fn(get_op_schema<DNNL_GRAPH_OP_SCHEMA_CLASS_NAME(dnnl_groupnorm, 1)>());
fn(get_op_schema<DNNL_GRAPH_OP_SCHEMA_CLASS_NAME(dnnl_sdpa, 1)>());
fn(get_op_schema<DNNL_GRAPH_OP_SCHEMA_CLASS_NAME(dnnl_sdpa_bwd, 1)>());
}
};

106 changes: 106 additions & 0 deletions src/graph/backend/dnnl/dnnl_shape_infer.cpp
@@ -589,6 +589,21 @@ status_t infer_dnnl_sdpa_output_shape(op_t *n,
}

set_shape_and_strides(*outputs[0], inferred_output_shape);

if (outputs.size() > 2) {
auto out_stats = logical_tensor_wrapper_t(outputs[2]);
dims inferred_stats_shape
= {query_dims[0], query_dims[1], query_dims[2], 1};

if (out_stats.ndims() != -1) {
VCHECK_INVALID_SHAPE(validate(inferred_stats_shape, out_stats.vdims()),
"%s, given stats shape is not compatible with inferred",
op_t::kind2str(n->get_kind()).c_str());
}

set_shape_and_strides(*outputs[2], inferred_stats_shape);
}

return status::success;
}

@@ -617,6 +632,97 @@ status_t infer_dnnl_layernorm_output_shape(op_t *n,
return status::success;
}

status_t infer_dnnl_softmax_output_shape(op_t *n,
std::vector<logical_tensor_t *> &inputs,
std::vector<logical_tensor_t *> &outputs) {
auto out0 = logical_tensor_wrapper_t(outputs[0]);
auto in0 = logical_tensor_wrapper_t(inputs[0]);

// check if partial set shape aligns with inferred shape
if (out0.ndims() != -1) {
VCHECK_INVALID_SHAPE(validate(in0.vdims(), out0.vdims()),
"%s, input and output shapes are not compatible",
op_t::kind2str(n->get_kind()).c_str());
}

// We should compute output dense strides instead of directly copying input
// strides to it
set_shape_and_strides(*outputs[0], in0.vdims());
if (outputs.size() == 2) return status::success;

// infer the optional stats output shape: the input shape with the
// softmax axis reduced to 1
auto out_stats = logical_tensor_wrapper_t(outputs[2]);
dims stats_dims = in0.vdims();
int64_t axis = n->get_attr<int64_t>(op_attr::axis);
if (axis < 0) { axis += in0.ndims(); }
stats_dims[axis] = 1;

if (out_stats.ndims() != -1) {
VCHECK_INVALID_SHAPE(validate(stats_dims, out_stats.vdims()),
"%s, given stats shape is not compatible with inferred",
op_t::kind2str(n->get_kind()).c_str());
}

set_shape_and_strides(*outputs[2], stats_dims);

return status::success;
}
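The stats shape above is just the input shape with the (possibly negative) softmax axis collapsed to 1. A standalone mirror of that rule, assuming plain std::vector dims rather than the internal dims type:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Stats shape = input shape with the normalized softmax axis set to 1,
// e.g. {2, 16, 384, 384} with axis = -1 becomes {2, 16, 384, 1}.
std::vector<int64_t> softmax_stats_shape(
        std::vector<int64_t> dims, int64_t axis) {
    if (axis < 0) axis += static_cast<int64_t>(dims.size());
    dims[static_cast<std::size_t>(axis)] = 1;
    return dims;
}
```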

status_t infer_dnnl_sdpa_bwd_output_shape(op_t *n,
std::vector<logical_tensor_t *> &inputs,
std::vector<logical_tensor_t *> &outputs) {
// [batch_size, num_heads_q, seq_len_q, head_size_qk]
auto query = ltw(inputs[0]);
// [batch_size, num_heads_k, head_size_qk, seq_len_kv]
auto key = ltw(inputs[1]);
// [batch_size, num_heads_v, seq_len_kv, head_size_v]
auto value = ltw(inputs[2]);

auto dquery = ltw(outputs[0]);
auto dkey = ltw(outputs[1]);
auto dvalue = ltw(outputs[2]);

if (dquery.ndims() != -1) {
VCHECK_INVALID_SHAPE(validate(dquery.vdims(), query.vdims()),
"%s, inferred out shape and output shape are not compatible",
op_t::kind2str(n->get_kind()).c_str());
}
set_shape_and_strides(*outputs[0], query.vdims());

if (dkey.ndims() != -1) {
VCHECK_INVALID_SHAPE(validate(dkey.vdims(), key.vdims()),
"%s, inferred out shape and output shape are not compatible",
op_t::kind2str(n->get_kind()).c_str());
}
set_shape_and_strides(*outputs[1], key.vdims());

if (dvalue.ndims() != -1) {
VCHECK_INVALID_SHAPE(validate(dvalue.vdims(), value.vdims()),
"%s, inferred out shape and output shape are not compatible",
op_t::kind2str(n->get_kind()).c_str());
}
set_shape_and_strides(*outputs[2], value.vdims());

if (outputs.size() > 4) {
// dmask exists
auto dmask = ltw(outputs[4]);
dims inferred_dmask_shape = query.vdims();
size_t ndims = query.ndims();
// [batch_size, num_heads_q, seq_len_q, seq_len_kv]: seq_len_kv is the
// second-to-last dim of value
inferred_dmask_shape[ndims - 1] = value.vdims()[ndims - 2];

if (dmask.ndims() != -1) {
VCHECK_INVALID_SHAPE(validate(inferred_dmask_shape, dmask.vdims()),
"%s, given dmask shape is not compatible with inferred",
op_t::kind2str(n->get_kind()).c_str());
}

set_shape_and_strides(*outputs[4], inferred_dmask_shape);
}

return status::success;
}
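The diff_mask shape follows the score matrix: query dims with the trailing head_size_qk swapped for seq_len_kv, taken from the second-to-last dim of value. The same rule as a standalone sketch, under the same std::vector assumption as above:

```cpp
#include <cstdint>
#include <vector>

// diff_mask = [B, H_q, S_q, S_kv]: query dims with the last dim replaced
// by seq_len_kv (the second-to-last dim of value).
std::vector<int64_t> sdpa_bwd_dmask_shape(std::vector<int64_t> q_dims,
        const std::vector<int64_t> &v_dims) {
    q_dims.back() = v_dims[v_dims.size() - 2];
    return q_dims;
}
```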

} // namespace dnnl_impl
} // namespace graph
} // namespace impl
9 changes: 9 additions & 0 deletions src/graph/backend/dnnl/dnnl_shape_infer.hpp
@@ -118,6 +118,15 @@ status_t infer_dnnl_host_scalar_output_shape(op_t *n,
status_t infer_dnnl_layernorm_output_shape(op_t *n,
std::vector<logical_tensor_t *> &inputs,
std::vector<logical_tensor_t *> &outputs);

status_t infer_dnnl_softmax_output_shape(op_t *n,
std::vector<logical_tensor_t *> &inputs,
std::vector<logical_tensor_t *> &outputs);

status_t infer_dnnl_sdpa_bwd_output_shape(op_t *n,
std::vector<logical_tensor_t *> &inputs,
std::vector<logical_tensor_t *> &outputs);

} // namespace dnnl_impl
} // namespace graph
} // namespace impl