Skip to content

Conversation

@zhangYiIntel
Copy link
Contributor

@zhangYiIntel zhangYiIntel commented Nov 18, 2025

Details:

  • Support GPT-OSS style Rope fusion
  • Add Support in CPU kernel
  • Add Support in GPU OCL kernel

Tickets:

@zhangYiIntel zhangYiIntel requested review from a team as code owners November 18, 2025 14:42
@zhangYiIntel zhangYiIntel requested review from mryzhov and removed request for a team November 18, 2025 14:42
@github-actions github-actions bot added category: IE Tests OpenVINO Test: plugins and common category: CPU OpenVINO CPU plugin category: transformations OpenVINO Runtime library - Transformations labels Nov 18, 2025
@zhangYiIntel zhangYiIntel requested review from a team as code owners November 19, 2025 05:44
@github-actions github-actions bot added the category: GPU OpenVINO GPU plugin label Nov 19, 2025
@zhangYiIntel zhangYiIntel changed the title [CPU]Support Rope for GPT-OSS [CPU][GPU]Support Rope for GPT-OSS Nov 19, 2025
Copy link
Contributor

@e-ddykim e-ddykim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me for GPU part

this->register_matcher(m, callback);
}

ov::pass::RoPEFusionGPTOSS::RoPEFusionGPTOSS() {
Copy link
Contributor

@mryzhov mryzhov Nov 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please explain why the RoPE fusion for gpt-oss should be done as a separate transformation? Why can’t we extend the existing one?”

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Rope of GPT-OSS cannot share the graph with others.
Current Rope, cos/sin table has same size with input x

    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    rotate_x = torch.cat((-x2, x1), dim=-1)
   (x * cos) + (rotate_x  * sin)

Gpt-Oss Rope, cos/sin table is [B, 1, L, S/2], the layout of x is [B, H, L, S]

    first_half, second_half = torch.chunk(x, 2, dim=-1)
    first_ = first_half * cos - second_half * sin
    second_ = second_half * cos + first_half * sin
    return torch.cat((first_, second_), dim=-1)

From graph side, no current implementation could be extended to contain this pattern.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mryzhov I think this is ok for the given transformation. The RoPE patterns are indeed complicated so we have different RoPEFusion for different models

auto x = pattern::any_input(pattern::rank_equals(4));
auto t_cos = pattern::any_input(pattern::shape_matches("[?, 1, ?, half_ndims]"));
auto t_sin = pattern::any_input(pattern::shape_matches("[?, 1, ?, half_ndims]"));
auto varsplit = pattern::wrap_type<v1::VariadicSplit>({x, -1, {"half_ndims", "?"}});
Copy link
Contributor

@CuriousPanCake CuriousPanCake Nov 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure {"half_ndims", "?"} is gonna work.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other pass uses this style too like RoPEFusionGPTJ. I am not sure if this is correct before.

ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
    using namespace ov::op::util;
    MATCHER_SCOPE(RoPEFusionGPTJ);

    auto gather_sin_cos = pattern::any_input(pattern::type_matches(ov::element::f32));
    auto varsplit = pattern::wrap_type<opset1::VariadicSplit>({gather_sin_cos, -1, {"ndims/2", "-1"}});

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, nvm, my bad. I didn't get what you're trying to express using {"ndims/2", "-1"}, but this is a Constant. Got lost in the API a little bit :)

Comment on lines 1131 to 1134
auto first_half_mul_cos =
pattern::wrap_type<v1::Multiply>({varsplit->output(0), t_cos}, {{"auto_broadcast", "numpy"}});
auto second_half_mul_sin =
pattern::wrap_type<v1::Multiply>({varsplit->output(1), t_sin}, {{"auto_broadcast", "numpy"}});
Copy link
Contributor

@CuriousPanCake CuriousPanCake Nov 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use predicates for checking output indices. pattern::output_index_matches(0) and pattern::output_index_matches(1)

@evkotov, please advise on how to use it here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, it is changed to use output_index_matches and shape_match.

    auto vsplit_out0 = pattern::wrap_type<op::v1::VariadicSplit>(
            {x, -1, {"half_ndims", "?"}},
            pattern::output_index_matches(0) && pattern::shape_matches("[?, ?, ?, half_ndims]"));

Comment on lines 1138 to 1141
auto second_half_mul_cos =
pattern::wrap_type<v1::Multiply>({varsplit->output(1), t_cos}, {{"auto_broadcast", "numpy"}});
auto first_half_mul_sin =
pattern::wrap_type<v1::Multiply>({varsplit->output(0), t_sin}, {{"auto_broadcast", "numpy"}});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@evkotov same question

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I will fix it.

@maxnick maxnick added this to the 2026.0 milestone Nov 19, 2025
@maxnick maxnick self-assigned this Nov 19, 2025
@zhangYiIntel zhangYiIntel force-pushed the yi3/gpt-oss-rope branch 2 times, most recently from b87afe3 to eb0a573 Compare November 20, 2025 02:39
new_node);
ov::replace_node(old_node, new_node);

// this new node may match following additional matchers
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean by this comment?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not used now, will remove it.

new_args.push_back(x_val);
new_args.push_back(v_cos);
new_args.push_back(pattern_map.at(t_sin));
auto old_node = root;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason for having the old_node variable?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed

auto seq_len = t_src.size(2);
auto feature_size = t_src.size(3);

auto half_rotary_dims = rotary_dims / 2;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
auto half_rotary_dims = rotary_dims / 2;
const auto half_rotary_dims = rotary_dims / 2;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

Comment on lines 92 to 101
// cos[i + sin_cos_offset]
// if con/sin table is not same size with input, it is reused for both halves.
const size_t sin_cos_offset = m_jcp.cos_sin_ndims == half_rotary_ndims ? 0 : half_rotary_ndims;
if (sin_cos_offset) {
load(vmm_cos, reg_cos, ov::element::f32, step, false, half_rotary_ndims * sizeof(float));
}
// sin[i + sin_cos_offset]
if (sin_cos_offset) {
load(vmm_sin, reg_sin, ov::element::f32, step, false, half_rotary_ndims * sizeof(float));
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sin_cos_offset is used as a bool flag. Isn't it more logical to use bool type?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

category: CPU OpenVINO CPU plugin category: GPU OpenVINO GPU plugin category: IE Tests OpenVINO Test: plugins and common category: transformations OpenVINO Runtime library - Transformations

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants