Skip to content

Commit b036cfe

Browse files
committed
[CPU]Support Rope for GPT-OSS
1 parent c18a511 commit b036cfe

File tree

10 files changed

+196
-4
lines changed

10 files changed

+196
-4
lines changed

src/common/transformations/include/ov_ops/rotary_positional_embeddings.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class TRANSFORMATIONS_API RoPE : public Op {
3333
bool is_qwen = false; // Qwen is special which overrides other setting
3434
bool use_rope_cache = false; // use precomputed RoPE cache for trigonometric values (cosine and sine)
3535
bool support_3d_rope = false; // use same logic as RoPEFusionGPTNEOX(4), used by gpu plugin
36+
size_t cos_sin_ndims = 0; // size of the last dimension of the cos/sin table
3637
size_t head_cnt = 0;
3738
size_t head_size = 0;
3839
int gather_position_arg_id =

src/common/transformations/include/transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class TRANSFORMATIONS_API RoPEFusionIOSlicing;
2121
class TRANSFORMATIONS_API RoPEFusionPreprocess;
2222
class TRANSFORMATIONS_API RoPEFusionCosSinPreprocess;
2323
class TRANSFORMATIONS_API RoPEShareCosSin;
24+
class TRANSFORMATIONS_API RoPEFusionGPTOSS;
2425

2526
} // namespace pass
2627
} // namespace ov
@@ -91,6 +92,12 @@ class ov::pass::RoPEShareCosSin : public ov::pass::MatcherPass {
9192
std::vector<std::shared_ptr<ov::Node>> m_shared_inputs{2, nullptr};
9293
};
9394

95+
/**
 * Matcher pass for the GPT-OSS style rotary positional embedding sub-graph.
 * The pattern and the replacement with an internal RoPE node are set up in the
 * constructor.
 */
class ov::pass::RoPEFusionGPTOSS : public ov::pass::MatcherPass {
96+
public:
97+
OPENVINO_MATCHER_PASS_RTTI("RoPEFusionGPTOSS");
98+
RoPEFusionGPTOSS();
99+
};
100+
94101
/**
95102
* @ingroup ov_transformation_common_api
96103
* @brief Fuses special sub-graph into an internal Rotary Positional Embedding operation

src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ bool ov::pass::RoPEFusion::run_on_model(const std::shared_ptr<ov::Model>& model)
6161
symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionGPTNEOX>(4);
6262
symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionGPTNEOX>(3);
6363
symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionGPTJ>();
64+
symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionGPTOSS>();
6465
// optional heads & tails are fused in separate matcher pass,
6566
// after RoPENode has been created.
6667
symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionCosSinPreprocess>();
@@ -1111,3 +1112,74 @@ ov::pass::RoPEShareCosSin::RoPEShareCosSin() {
11111112
auto m = std::make_shared<pattern::Matcher>(result, matcher_name);
11121113
this->register_matcher(m, callback);
11131114
}
1115+
1116+
// Builds the GPT-OSS rotary-embedding pattern (VariadicSplit -> Multiply/Add -> Concat)
// and registers a callback that replaces the matched Concat with an internal RoPE node
// taking three inputs: x, cos and sin.
ov::pass::RoPEFusionGPTOSS::RoPEFusionGPTOSS() {
1117+
using namespace ov::op::util;
1118+
MATCHER_SCOPE(RoPEFusionGPTOSS);
1119+
1120+
// gpt-oss style
1121+
// first_half, second_half = torch.chunk(x, 2, dim=-1)
1122+
// first_ = first_half * cos - second_half * sin
1123+
// second_ = second_half * cos + first_half * sin
1124+
// return torch.cat((first_, second_), dim=-1)
// x must be rank 4; cos and sin must match [?, 1, ?, half_ndims], which binds the
// "half_ndims" symbol that is read back in the callback.
1125+
auto x = pattern::any_input(pattern::rank_equals(4));
1126+
auto t_cos = pattern::any_input(pattern::shape_matches("[?, 1, ?, half_ndims]"));
1127+
auto t_sin = pattern::any_input(pattern::shape_matches("[?, 1, ?, half_ndims]"));
// Split along the last axis; the first split length is tied to half_ndims.
// NOTE(review): the second split length is "?" and is not constrained to equal
// half_ndims, while rotary_ndims below is derived as 2 * half_ndims - confirm intended.
1128+
auto varsplit = pattern::wrap_type<v1::VariadicSplit>({x, -1, {"half_ndims", "?"}});
1129+
varsplit->set_output_size(2);
1130+
// First output half: first_half * cos + (second_half * sin) * -1.
// The subtraction is matched as Add with a Multiply by -1.0f.
1131+
auto first_half_mul_cos = pattern::wrap_type<v1::Multiply>({varsplit->output(0), t_cos}, {{"auto_broadcast", "numpy"}});
1132+
auto second_half_mul_sin = pattern::wrap_type<v1::Multiply>({varsplit->output(1), t_sin}, {{"auto_broadcast", "numpy"}});
1133+
auto neg = pattern::wrap_type<v1::Multiply>({second_half_mul_sin, -1.0f}, {{"auto_broadcast", "numpy"}});
1134+
auto sub_Subtract = pattern::wrap_type<v1::Add>({first_half_mul_cos, neg}, {{"auto_broadcast", "numpy"}});
1135+
// Second output half: second_half * cos + first_half * sin, then both halves are
// concatenated along the last axis.
1136+
auto second_half_mul_cos = pattern::wrap_type<v1::Multiply>({varsplit->output(1), t_cos}, {{"auto_broadcast", "numpy"}});
1137+
auto first_half_mul_sin = pattern::wrap_type<v1::Multiply>({varsplit->output(0), t_sin}, {{"auto_broadcast", "numpy"}});
1138+
auto add_Add = pattern::wrap_type<v1::Add>({second_half_mul_cos, first_half_mul_sin}, {{"auto_broadcast", "numpy"}});
1139+
auto concat_result = pattern::wrap_type<opset1::Concat>({sub_Subtract, add_Add}, {{"axis", -1}});
1140+
1141+
auto result = concat_result;
1142+
// Callback run on every match: returns false (no change) when half_ndims is not a
// known integer, otherwise replaces the matched root with a RoPE node and returns true.
1143+
matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
1144+
const auto& pattern_map = m.get_pattern_value_map();
1145+
auto root = m.get_match_root();
1146+
const auto& x_val = pattern_map.at(x);
1147+
const auto& v_cos = pattern_map.at(t_cos);
1148+
1149+
auto symbols = m.get_symbols();
1150+
const auto& half_ndims = symbols["half_ndims"];
1151+
if (!half_ndims.is_integer()) {
1152+
return false;
1153+
}
1154+
// rotary_ndims covers both halves; cos_sin_ndims records that the cos/sin inputs
// only carry half_ndims values in their last dimension.
1155+
op::internal::RoPE::Config config;
1156+
OutputVector new_args;
1157+
config.rotary_ndims = 2ul * static_cast<size_t>(half_ndims.i());
1158+
config.cos_sin_ndims = static_cast<size_t>(half_ndims.i());
1159+
// RoPE inputs in order: x, cos, sin.
1160+
new_args.push_back(x_val);
1161+
new_args.push_back(v_cos);
1162+
new_args.push_back(pattern_map.at(t_sin));
1163+
auto old_node = root;
1164+
auto new_node = std::make_shared<internal::RoPE>(new_args, config);
1165+
new_node->set_friendly_name(old_node->get_friendly_name());
// Copy runtime info from the listed matched nodes onto the fused node.
1166+
ov::copy_runtime_info({
1167+
pattern_map.at(neg).get_node_shared_ptr(),
1168+
pattern_map.at(sub_Subtract).get_node_shared_ptr(),
1169+
pattern_map.at(first_half_mul_cos).get_node_shared_ptr(),
1170+
pattern_map.at(first_half_mul_sin).get_node_shared_ptr(),
1171+
pattern_map.at(second_half_mul_cos).get_node_shared_ptr(),
1172+
pattern_map.at(second_half_mul_sin).get_node_shared_ptr(),
1173+
pattern_map.at(add_Add).get_node_shared_ptr(),
1174+
pattern_map.at(result).get_node_shared_ptr()},
1175+
new_node);
1176+
ov::replace_node(old_node, new_node);
1177+
1178+
// this new node may match following additional matchers
1179+
register_new_node(new_node);
1180+
return true;
1181+
};
1182+
1183+
auto m = std::make_shared<ov::pass::pattern::Matcher>(result, matcher_name);
1184+
this->register_matcher(m, callback);
1185+
}

src/plugins/intel_cpu/src/nodes/kernels/x64/rope_kernel.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,12 @@ void jit_rotary_kernel<isa>::rotary_half(size_t step) {
9090
store(reg_dst, vmm_dst0, m_jcp.dst_prc, step);
9191

9292
// cos[i + halfRotaryNdims]
93-
load(vmm_cos, reg_cos, ov::element::f32, step, false, half_rotary_ndims * sizeof(float));
93+
// if cos_sin_ndims is set, the cos/sin table does not span the full rotary size, so the values already loaded are reused for both halves.
94+
if (!m_jcp.cos_sin_ndims)
95+
load(vmm_cos, reg_cos, ov::element::f32, step, false, half_rotary_ndims * sizeof(float));
9496
// sin[i + halfRotaryNdims]
95-
load(vmm_sin, reg_sin, ov::element::f32, step, false, half_rotary_ndims * sizeof(float));
97+
if (!m_jcp.cos_sin_ndims)
98+
load(vmm_sin, reg_sin, ov::element::f32, step, false, half_rotary_ndims * sizeof(float));
9699
// cos[i + half_rotary_dims] * src1
97100
uni_vmulps(vmm_dst0, vmm_cos, vmm_src1);
98101
// cos[i + half_rotary_dims] * src1 + sin[i + half_rotary_dims] * src0

src/plugins/intel_cpu/src/nodes/kernels/x64/rope_kernel.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ struct jit_rotary_compile_params {
2727
ov::element::Type src_prc;
2828
ov::element::Type dst_prc;
2929
size_t rotary_ndims = 0UL;
30+
size_t cos_sin_ndims = 0UL;
3031
bool interleave = false;
3132
bool mix_cos_sin = false;
3233
};

src/plugins/intel_cpu/src/nodes/rope.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ struct RoPE::RoPEExecutorRotateHalf : public RoPE::Executor {
111111
jcp.dst_prc = precision_of<T>::value;
112112
jcp.rotary_ndims = config.rotary_ndims;
113113
jcp.interleave = false;
114+
jcp.cos_sin_ndims = config.cos_sin_ndims;
114115
m_rotaryKernel = createJitKernel(jcp);
115116
}
116117

@@ -157,7 +158,8 @@ struct RoPE::RoPEExecutorRotateHalf : public RoPE::Executor {
157158
auto head_cnt = t_src.size(1);
158159
auto seq_len = t_src.size(2);
159160
auto feature_size = t_src.size(3);
160-
161+
auto half_rotary_dims = rotary_dims / 2;
162+
const size_t cos_sin_offset = (m_config.cos_sin_ndims == half_rotary_dims) ? 0 : half_rotary_dims;
161163
parallel_for3d(batch_size, head_cnt, seq_len, [&](size_t b, size_t h, size_t p) {
162164
auto cos_pos = p;
163165
if (gather) {
@@ -181,7 +183,7 @@ struct RoPE::RoPEExecutorRotateHalf : public RoPE::Executor {
181183
auto src0 = src[i];
182184
auto src1 = src[i + half_rotary_dims];
183185
dst[i] = cos[i] * src0 - sin[i] * src1;
184-
dst[i + half_rotary_dims] = cos[i + half_rotary_dims] * src1 + sin[i + half_rotary_dims] * src0;
186+
dst[i + half_rotary_dims] = cos[i + cos_sin_offset] * src1 + sin[i + cos_sin_offset] * src0;
185187
}
186188
}
187189
if (!can_inplace) {

src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,5 +79,12 @@ INSTANTIATE_TEST_SUITE_P(smoke_RoPETestQwenVL,
7979
::testing::ValuesIn(vit_param)),
8080
RoPETestQwenVL::getTestCaseName);
8181

82+
// Instantiates RoPETestGPTOSS for a single parameter combination:
// element type f32 on the CPU device.
INSTANTIATE_TEST_SUITE_P(smoke_RoPETestGPTOSS,
83+
RoPETestGPTOSS,
84+
::testing::Combine(
85+
::testing::Values(ov::element::f32),
86+
::testing::Values(ov::test::utils::DEVICE_CPU)),
87+
RoPETestGPTOSS::getTestCaseName);
88+
8289
} // namespace test
8390
} // namespace ov

src/tests/functional/plugin/shared/include/shared_test_classes/subgraph/rotary_pos_emb.hpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,5 +198,19 @@ class RoPETestChatGLMHF : public SubgraphBaseTest, public testing::WithParamInte
198198
static std::string getTestCaseName(const testing::TestParamInfo<rope_params>& obj);
199199
};
200200

201+
// Subgraph test fixture for the GPT-OSS style rotary embedding pattern,
// parameterized by rope_params.
class RoPETestGPTOSS : public SubgraphBaseTest, public testing::WithParamInterface<rope_params> {
202+
private:
203+
// Builds the test model for the given head count, rotary size and element type.
std::shared_ptr<ov::Model> buildROPE_GPTOSS(int num_head,
204+
int rotary_dims,
205+
ov::element::Type element_type);
206+
207+
protected:
208+
// Overrides from the test base: input tensor generation and test setup.
void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override;
209+
void SetUp() override;
210+
211+
public:
212+
// Produces the test-case name from the test parameters.
static std::string getTestCaseName(const testing::TestParamInfo<rope_params>& obj);
213+
};
214+
201215
} // namespace test
202216
} // namespace ov

src/tests/functional/plugin/shared/include/subgraph_tests/rotary_pos_emb.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,5 +115,12 @@ TEST_P(RoPETestChatGLMHF, CompareWithRefs) {
115115
CheckNumberOfNodesWithType(function, {"RoPE"}, 1);
116116
};
117117

118+
// Runs the GPT-OSS RoPE subgraph test, then checks that the runtime model of the
// compiled model contains exactly one node of type "RoPE".
TEST_P(RoPETestGPTOSS, CompareWithRefs) {
119+
SKIP_IF_CURRENT_TEST_IS_DISABLED();
120+
run();
121+
auto function = compiledModel.get_runtime_model();
122+
CheckNumberOfNodesWithType(function, {"RoPE"}, 1);
123+
};
124+
118125
} // namespace test
119126
} // namespace ov

src/tests/functional/plugin/shared/src/subgraph/rotary_pos_emb.cpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1457,5 +1457,83 @@ std::string RoPETestChatGLMHF::getTestCaseName(const testing::TestParamInfo<rope
14571457
return result.str();
14581458
}
14591459

1460+
// Builds the GPT-OSS style rotary embedding test model.
//   num_head     - size of dimension 2 of the "input" parameter
//   rotary_dims  - size of the last dimension of "input"; cos/sin use rotary_dims/2
//   element_type - element type of all three parameters
// Returns a model with parameters {input, cos, sin} whose result is the final Concat.
std::shared_ptr<ov::Model> RoPETestGPTOSS::buildROPE_GPTOSS(int num_head,
1461+
int rotary_dims,
1462+
ov::element::Type element_type) {
1463+
// NOTE(review): int32_max is not referenced anywhere else in this function.
auto int32_max = std::numeric_limits<std::int32_t>::max();
1464+
auto input = std::make_shared<ov::opset1::Parameter>(element_type, PartialShape{-1, -1, num_head, rotary_dims});
1465+
input->set_friendly_name("input");
// Swap dimensions 1 and 2 of the input before splitting.
1466+
auto permute_Transpose = makeOP<opset1::Transpose>({input, {0, 2, 1, 3}}, {});
1467+
auto cos = std::make_shared<ov::opset1::Parameter>(element_type, PartialShape{-1, 1, -1, rotary_dims/2});
1468+
cos->set_friendly_name("cos");
1469+
auto sin = std::make_shared<ov::opset1::Parameter>(element_type, PartialShape{-1, 1, -1, rotary_dims/2});
1470+
sin->set_friendly_name("sin");
// Split the last axis into a first part of rotary_dims/2 elements and the remainder.
1471+
auto variadicSplit = makeOP<opset1::VariadicSplit>({permute_Transpose, -1, {rotary_dims/2, -1}});
1472+
auto first_half_mul_cos = makeOP<opset1::Multiply>({variadicSplit->output(0), cos}, {{"auto_broadcast", "numpy"}});
1473+
1474+
// first_half * cos - second_half * sin, with the subtraction built as Add of a
// Multiply by -1.
auto second_half_mul_sin = makeOP<opset1::Multiply>({variadicSplit->output(1), sin}, {{"auto_broadcast", "numpy"}});
1475+
auto neg = makeOP<opset1::Multiply>({second_half_mul_sin, -1.000000f}, {{"auto_broadcast", "numpy"}}); // tensor_array<f32[?,64,?,32]> Multiply_9680(__module.model.layers.0.self_attn/aten::mul/Multiply_1, Constant_9679)
1476+
auto sub_Subtract = makeOP<opset1::Add>({first_half_mul_cos, neg}, {{"auto_broadcast", "numpy"}}); // tensor_array<f32[?,64,?,32]> __module.model.layers.0.self_attn/aten::sub/Subtract(__module.model.layers.0.self_attn/aten::mul/Multiply, Multiply_9680)
1477+
1478+
1479+
// second_half * cos + first_half * sin, then concatenate both results on the last axis.
auto second_half_mul_cos = makeOP<opset1::Multiply>({variadicSplit->output(1), cos}, {{"auto_broadcast", "numpy"}});
1480+
auto first_half_mul_sin = makeOP<opset1::Multiply>({variadicSplit->output(0), sin}, {{"auto_broadcast", "numpy"}});
1481+
auto add_Add = makeOP<opset1::Add>({second_half_mul_cos, first_half_mul_sin}, {{"auto_broadcast", "numpy"}});
1482+
auto cat_Concat = makeOP<opset1::Concat>({sub_Subtract, add_Add}, {{"axis", -1}});
1483+
return std::make_shared<ov::Model>(cat_Concat, ov::ParameterVector{input, cos, sin});
1484+
}
1485+
1486+
// Creates and fills the three input tensors (input, cos, sin) for the static shapes
// in targetInputStaticShapes (indices 0, 1, 2) and stores them in `inputs`,
// keyed by the corresponding model input node.
void RoPETestGPTOSS::generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) {
1487+
const auto& funcInputs = function->inputs();
1488+
1489+
auto& input_shape = targetInputStaticShapes[0];
1490+
auto& cos_shape = targetInputStaticShapes[1];
1491+
auto& sin_shape = targetInputStaticShapes[2];
1492+
1493+
// Shared generation settings; presumably values start at -1 and span a range of 2
// - TODO confirm against InputGenerateData semantics.
ov::test::utils::InputGenerateData in_data;
1494+
in_data.start_from = -1;
1495+
in_data.range = 2;
1496+
in_data.resolution = 32768;
// cos and sin reuse the same settings but get their own fixed seeds (10 and 20).
1497+
auto cos_data = in_data;
1498+
cos_data.seed = 10;
1499+
1500+
auto sin_data = in_data;
1501+
sin_data.seed = 20;
1502+
ov::Tensor t_input = utils::create_and_fill_tensor(funcInputs[0].get_element_type(), input_shape, in_data);
1503+
ov::Tensor t_cos_cache =
1504+
utils::create_and_fill_tensor(funcInputs[1].get_element_type(), cos_shape, cos_data);
1505+
ov::Tensor t_sin_cache =
1506+
utils::create_and_fill_tensor(funcInputs[2].get_element_type(), sin_shape, sin_data);
1507+
1508+
// Replace any previously stored inputs with the freshly generated tensors.
inputs.clear();
1509+
inputs.insert({funcInputs[0].get_node_shared_ptr(), t_input});
1510+
inputs.insert({funcInputs[1].get_node_shared_ptr(), t_cos_cache});
1511+
inputs.insert({funcInputs[2].get_node_shared_ptr(), t_sin_cache});
1512+
}
1513+
1514+
// Formats the test-case name as "type=<element_type>_targetDevice=<targetDevice>"
// from the two values unpacked from obj.param.
std::string RoPETestGPTOSS::getTestCaseName(const testing::TestParamInfo<rope_params>& obj) {
1515+
const auto& [element_type, targetDevice] = obj.param;
1516+
std::ostringstream result;
1517+
result << "type=" << element_type << "_"
1518+
<< "targetDevice=" << targetDevice;
1519+
return result.str();
1520+
}
1521+
1522+
// Sets the target device from the test parameters, registers the input shapes for
// input/cos/sin and builds the model under test.
void RoPETestGPTOSS::SetUp() {
1523+
const auto& [element_type, _targetDevice] = this->GetParam();
1524+
targetDevice = _targetDevice;
1525+
1526+
// Fixed dimensions used for this test.
const int batch = 5;
1527+
const int seq_length = 7;
1528+
const int num_head = 8;
1529+
const int rotary_dims = 64;
1530+
1531+
// input is [batch, seq_length, num_head, rotary_dims]; cos and sin are
// [batch, 1, seq_length, rotary_dims/2].
InputShape input = {{batch, seq_length, num_head, rotary_dims}, {{batch, seq_length, num_head, rotary_dims}}};
1532+
InputShape cos = {{batch, 1, seq_length, rotary_dims/2}, {{batch, 1, seq_length, rotary_dims/2}}};
1533+
InputShape sin = {{batch, 1, seq_length, rotary_dims/2}, {{batch, 1, seq_length, rotary_dims/2}}};
1534+
init_input_shapes({input, cos, sin});
1535+
function = buildROPE_GPTOSS(num_head, rotary_dims, element_type);
1536+
}
1537+
14601538
} // namespace test
14611539
} // namespace ov

0 commit comments

Comments
 (0)