[CPU][GPU]Support Rope for GPT-OSS #32910
b036cfe to 587239a
e-ddykim
left a comment
Looks good to me for GPU part
    this->register_matcher(m, callback);
}

ov::pass::RoPEFusionGPTOSS::RoPEFusionGPTOSS() {
Could you please explain why the RoPE fusion for gpt-oss should be done as a separate transformation? Why can't we extend the existing one?
The RoPE of GPT-OSS cannot share a graph pattern with the other models.
In the current RoPE, the cos/sin table has the same size as the input x:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
rotate_x = torch.cat((-x2, x1), dim=-1)
(x * cos) + (rotate_x * sin)
In GPT-OSS RoPE, the cos/sin table is [B, 1, L, S/2], while the layout of x is [B, H, L, S]:
first_half, second_half = torch.chunk(x, 2, dim=-1)
first_ = first_half * cos - second_half * sin
second_ = second_half * cos + first_half * sin
return torch.cat((first_, second_), dim=-1)
On the graph side, none of the current implementations can be extended to cover this pattern.
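For readers comparing the two formulations, here is a minimal pure-Python sketch on a single vector along the last dimension. The function names `rope_full_table` and `rope_half_table` are hypothetical, and the comparison assumes the full-size table is simply the half-size table repeated twice, which the thread does not state. Under that assumption the two variants give the same numbers, even though the operator graphs they produce differ, which is the point made above.

```python
import math

def rope_full_table(x, cos, sin):
    """Existing style: cos/sin have the same length as x (hypothetical helper)."""
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    rotate_x = [-v for v in x2] + x1  # mirrors torch.cat((-x2, x1), dim=-1)
    return [xi * c + ri * s for xi, ri, c, s in zip(x, rotate_x, cos, sin)]

def rope_half_table(x, cos, sin):
    """GPT-OSS style: cos/sin have half the length of x (hypothetical helper)."""
    half = len(x) // 2
    first, second = x[:half], x[half:]
    first_ = [a * c - b * s for a, b, c, s in zip(first, second, cos, sin)]
    second_ = [b * c + a * s for a, b, c, s in zip(first, second, cos, sin)]
    return first_ + second_

# Illustrative values only; the angles are arbitrary.
angles = [0.1, 0.7, 1.3]
cos_half = [math.cos(a) for a in angles]
sin_half = [math.sin(a) for a in angles]
x = [1.0, -2.0, 3.0, 0.5, 4.0, -1.5]

# Assumption: the full-size table is the half-size table duplicated.
full = rope_full_table(x, cos_half * 2, sin_half * 2)
half = rope_half_table(x, cos_half, sin_half)
matches = all(math.isclose(a, b, abs_tol=1e-12) for a, b in zip(full, half))
```

The half-table form reads each cos/sin entry once and applies it to both halves, so no duplicated table has to be materialized.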
@mryzhov I think this is OK for the given transformation. The RoPE patterns are indeed complicated, so we have different RoPEFusion passes for different models.
auto x = pattern::any_input(pattern::rank_equals(4));
auto t_cos = pattern::any_input(pattern::shape_matches("[?, 1, ?, half_ndims]"));
auto t_sin = pattern::any_input(pattern::shape_matches("[?, 1, ?, half_ndims]"));
auto varsplit = pattern::wrap_type<v1::VariadicSplit>({x, -1, {"half_ndims", "?"}});
I'm not sure {"half_ndims", "?"} is gonna work.
Other passes, such as RoPEFusionGPTJ, use this style too. I am not sure whether it was correct there in the first place.
ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
using namespace ov::op::util;
MATCHER_SCOPE(RoPEFusionGPTJ);
auto gather_sin_cos = pattern::any_input(pattern::type_matches(ov::element::f32));
auto varsplit = pattern::wrap_type<opset1::VariadicSplit>({gather_sin_cos, -1, {"ndims/2", "-1"}});
Ok, nvm, my bad. I didn't get what you're trying to express using {"ndims/2", "-1"}, but this is a Constant. Got lost in the API a little bit :)
auto first_half_mul_cos =
    pattern::wrap_type<v1::Multiply>({varsplit->output(0), t_cos}, {{"auto_broadcast", "numpy"}});
auto second_half_mul_sin =
    pattern::wrap_type<v1::Multiply>({varsplit->output(1), t_sin}, {{"auto_broadcast", "numpy"}});
We use predicates for checking output indices: pattern::output_index_matches(0) and pattern::output_index_matches(1).
@evkotov, could you please advise on how to use them here?
Thanks, it now uses output_index_matches and shape_matches:
auto vsplit_out0 = pattern::wrap_type<op::v1::VariadicSplit>(
{x, -1, {"half_ndims", "?"}},
pattern::output_index_matches(0) && pattern::shape_matches("[?, ?, ?, half_ndims]"));
auto second_half_mul_cos =
    pattern::wrap_type<v1::Multiply>({varsplit->output(1), t_cos}, {{"auto_broadcast", "numpy"}});
auto first_half_mul_sin =
    pattern::wrap_type<v1::Multiply>({varsplit->output(0), t_sin}, {{"auto_broadcast", "numpy"}});
@evkotov same question
Sure, I will fix it.
b87afe3 to eb0a573
    new_node);
ov::replace_node(old_node, new_node);

// this new node may match following additional matchers
What do you mean by this comment?
It's not used now, will remove it.
new_args.push_back(x_val);
new_args.push_back(v_cos);
new_args.push_back(pattern_map.at(t_sin));
auto old_node = root;
Any reason for having the old_node variable?
removed
auto seq_len = t_src.size(2);
auto feature_size = t_src.size(3);

auto half_rotary_dims = rotary_dims / 2;
Suggested change:
- auto half_rotary_dims = rotary_dims / 2;
+ const auto half_rotary_dims = rotary_dims / 2;
fixed
// cos[i + sin_cos_offset]
// if con/sin table is not same size with input, it is reused for both halves.
const size_t sin_cos_offset = m_jcp.cos_sin_ndims == half_rotary_ndims ? 0 : half_rotary_ndims;
if (sin_cos_offset) {
    load(vmm_cos, reg_cos, ov::element::f32, step, false, half_rotary_ndims * sizeof(float));
}
// sin[i + sin_cos_offset]
if (sin_cos_offset) {
    load(vmm_sin, reg_sin, ov::element::f32, step, false, half_rotary_ndims * sizeof(float));
}
sin_cos_offset is used as a bool flag. Isn't it more logical to use bool type?
Sure
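As a rough illustration of the suggestion, the table lookup for the second half can be driven by a boolean rather than a numeric offset. This is only a sketch based on the comments in the hunk above (`cos[i + sin_cos_offset]`); the function `second_half_table_index` and the flag `reuse_table` are hypothetical names and do not come from the PR.

```python
def second_half_table_index(i, half_rotary_ndims, cos_sin_ndims):
    """Index into the cos/sin table for element i of the second half.

    Hypothetical helper: assumes 0 <= i < half_rotary_ndims.
    """
    # Boolean flag: a half-size table is shared by both halves of the input.
    reuse_table = cos_sin_ndims == half_rotary_ndims
    return i if reuse_table else i + half_rotary_ndims

# Half-size table (GPT-OSS style): the second half reads the same entry as the first.
shared = second_half_table_index(2, 4, 4)
# Full-size table: the second half reads from the upper part of the table.
separate = second_half_table_index(2, 4, 8)
```

Naming the condition makes the intent visible at the call site, whereas a `size_t` offset tested as a condition hides that only "zero or non-zero" matters.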