Skip to content

Commit 780fb09

Browse files
perf(mlx): rope supports batch offset (#379)
1 parent 39127ce commit 780fb09

11 files changed

Lines changed: 70 additions & 157 deletions

File tree

src/parallax/models/deepseek_v2.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -62,21 +62,14 @@ def __call__(
6262
k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
6363
k_nope = k_nope.transpose(0, 2, 1, 3)
6464

65-
# q_pe = self.rope(q_pe, offset=offset)
66-
# k_pe = self.rope(k_pe, offset=offset)
6765
key_cache_global, value_cache_global = cache.get_cache()
68-
q_pe_list = []
69-
k_pe_list = []
70-
for i in range(batch):
71-
current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0
72-
q_slice = q_pe[i : i + 1]
73-
k_slice = k_pe[i : i + 1]
74-
q_rot = self.rope(q_slice, offset=current_pos)
75-
k_rot = self.rope(k_slice, offset=current_pos)
76-
q_pe_list.append(q_rot)
77-
k_pe_list.append(k_rot)
78-
q_pe = mx.concatenate(q_pe_list, axis=0)
79-
k_pe = mx.concatenate(k_pe_list, axis=0)
66+
67+
if target_len == 1:
68+
current_pos = context_lengths - 1
69+
else:
70+
current_pos = 0
71+
q_pe = self.rope(q_pe, offset=current_pos)
72+
k_pe = self.rope(k_pe, offset=current_pos)
8073

8174
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
8275
queries = mx.concatenate([q_nope, q_pe], axis=-1)

src/parallax/models/deepseek_v3.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -64,23 +64,14 @@ def __call__(
6464
k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
6565
k_nope = k_nope.transpose(0, 2, 1, 3)
6666

67-
# q_pe = self.rope(q_pe, offset=offset)
68-
# k_pe = self.rope(k_pe, offset=offset)
6967
key_cache_global, value_cache_global = cache.get_cache()
7068

71-
q_pe_list = []
72-
k_pe_list = []
73-
for i in range(batch):
74-
current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0
75-
q_slice = q_pe[i : i + 1]
76-
k_slice = k_pe[i : i + 1]
77-
q_rot = self.rope(q_slice, offset=current_pos)
78-
k_rot = self.rope(k_slice, offset=current_pos)
79-
q_pe_list.append(q_rot)
80-
k_pe_list.append(k_rot)
81-
82-
q_pe = mx.concatenate(q_pe_list, axis=0)
83-
k_pe = mx.concatenate(k_pe_list, axis=0)
69+
if target_len == 1:
70+
current_pos = context_lengths - 1
71+
else:
72+
current_pos = 0
73+
q_pe = self.rope(q_pe, offset=current_pos)
74+
k_pe = self.rope(k_pe, offset=current_pos)
8475

8576
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
8677
queries = mx.concatenate([q_nope, q_pe], axis=-1)

src/parallax/models/deepseek_v32.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -141,18 +141,13 @@ def __call__(
141141
k_nope = k_nope.transpose(0, 2, 1, 3)
142142
key_cache_global, value_cache_global = cache.get_cache()
143143
indexer_cache = cache.get_indexer_cache()
144-
q_pe_list = []
145-
k_pe_list = []
146-
for i in range(batch):
147-
current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0
148-
q_slice = q_pe[i : i + 1]
149-
k_slice = k_pe[i : i + 1]
150-
q_rot = self.rope(q_slice, offset=current_pos)
151-
k_rot = self.rope(k_slice, offset=current_pos)
152-
q_pe_list.append(q_rot)
153-
k_pe_list.append(k_rot)
154-
q_pe = mx.concatenate(q_pe_list, axis=0)
155-
k_pe = mx.concatenate(k_pe_list, axis=0)
144+
145+
if target_len == 1:
146+
current_pos = context_lengths - 1
147+
else:
148+
current_pos = 0
149+
q_pe = self.rope(q_pe, offset=current_pos)
150+
k_pe = self.rope(k_pe, offset=current_pos)
156151

157152
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
158153
queries = mx.concatenate([q_nope, q_pe], axis=-1)

src/parallax/models/glm4_moe.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,20 +38,12 @@ def __call__(
3838

3939
key_cache_global, value_cache_global = cache.get_cache()
4040

41-
queries_rotated_list = []
42-
keys_rotated_list = []
43-
44-
for i in range(batch):
45-
current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0
46-
q_slice = queries_new[i : i + 1]
47-
k_slice = keys_new[i : i + 1]
48-
q_rot = self.rope(q_slice, offset=current_pos)
49-
k_rot = self.rope(k_slice, offset=current_pos)
50-
queries_rotated_list.append(q_rot)
51-
keys_rotated_list.append(k_rot)
52-
53-
queries_rotated = mx.concatenate(queries_rotated_list, axis=0)
54-
keys_rotated = mx.concatenate(keys_rotated_list, axis=0)
41+
if target_len == 1:
42+
current_pos = context_lengths - 1
43+
else:
44+
current_pos = 0
45+
queries_rotated = self.rope(queries_new, offset=current_pos)
46+
keys_rotated = self.rope(keys_new, offset=current_pos)
5547

5648
block_size = key_cache_global.shape[3]
5749

src/parallax/models/gpt_oss.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -51,19 +51,12 @@ def __call__(
5151

5252
key_cache_global, value_cache_global = cache.get_cache()
5353

54-
queries_rotated_list = []
55-
keys_rotated_list = []
56-
for i in range(batch):
57-
current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0
58-
q_slice = queries_new[i : i + 1]
59-
k_slice = keys_new[i : i + 1]
60-
q_rot = self.rope(q_slice, offset=current_pos)
61-
k_rot = self.rope(k_slice, offset=current_pos)
62-
queries_rotated_list.append(q_rot)
63-
keys_rotated_list.append(k_rot)
64-
65-
queries_rotated = mx.concatenate(queries_rotated_list, axis=0)
66-
keys_rotated = mx.concatenate(keys_rotated_list, axis=0)
54+
if target_len == 1:
55+
current_pos = context_lengths - 1
56+
else:
57+
current_pos = 0
58+
queries_rotated = self.rope(queries_new, offset=current_pos)
59+
keys_rotated = self.rope(keys_new, offset=current_pos)
6760

6861
# Update Paged Cache
6962
block_size = key_cache_global.shape[3]

src/parallax/models/llama.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -61,20 +61,12 @@ def __call__(
6161

6262
key_cache_global, value_cache_global = cache.get_cache()
6363

64-
queries_rotated_list = []
65-
keys_rotated_list = []
66-
67-
for i in range(batch):
68-
current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0
69-
q_slice = queries_new[i : i + 1]
70-
k_slice = keys_new[i : i + 1]
71-
q_rot = self.rope(q_slice, offset=current_pos)
72-
k_rot = self.rope(k_slice, offset=current_pos)
73-
queries_rotated_list.append(q_rot)
74-
keys_rotated_list.append(k_rot)
75-
76-
queries_rotated = mx.concatenate(queries_rotated_list, axis=0)
77-
keys_rotated = mx.concatenate(keys_rotated_list, axis=0)
64+
if target_len == 1:
65+
current_pos = context_lengths - 1
66+
else:
67+
current_pos = 0
68+
queries_rotated = self.rope(queries_new, offset=current_pos)
69+
keys_rotated = self.rope(keys_new, offset=current_pos)
7870

7971
block_size = key_cache_global.shape[3]
8072

src/parallax/models/minimax.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -45,20 +45,12 @@ def __call__(
4545

4646
key_cache_global, value_cache_global = cache.get_cache()
4747

48-
queries_rotated_list = []
49-
keys_rotated_list = []
50-
51-
for i in range(batch):
52-
current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0
53-
q_slice = queries_new[i : i + 1]
54-
k_slice = keys_new[i : i + 1]
55-
q_rot = self.rope(q_slice, offset=current_pos)
56-
k_rot = self.rope(k_slice, offset=current_pos)
57-
queries_rotated_list.append(q_rot)
58-
keys_rotated_list.append(k_rot)
59-
60-
queries_rotated = mx.concatenate(queries_rotated_list, axis=0)
61-
keys_rotated = mx.concatenate(keys_rotated_list, axis=0)
48+
if target_len == 1:
49+
current_pos = context_lengths - 1
50+
else:
51+
current_pos = 0
52+
queries_rotated = self.rope(queries_new, offset=current_pos)
53+
keys_rotated = self.rope(keys_new, offset=current_pos)
6254

6355
block_size = key_cache_global.shape[3]
6456

src/parallax/models/qwen2.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -56,20 +56,12 @@ def __call__(
5656

5757
key_cache_global, value_cache_global = cache.get_cache()
5858

59-
queries_rotated_list = []
60-
keys_rotated_list = []
61-
62-
for i in range(batch):
63-
current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0
64-
q_slice = queries_new[i : i + 1]
65-
k_slice = keys_new[i : i + 1]
66-
q_rot = self.rope(q_slice, offset=current_pos)
67-
k_rot = self.rope(k_slice, offset=current_pos)
68-
queries_rotated_list.append(q_rot)
69-
keys_rotated_list.append(k_rot)
70-
71-
queries_rotated = mx.concatenate(queries_rotated_list, axis=0)
72-
keys_rotated = mx.concatenate(keys_rotated_list, axis=0)
59+
if target_len == 1:
60+
current_pos = context_lengths - 1
61+
else:
62+
current_pos = 0
63+
queries_rotated = self.rope(queries_new, offset=current_pos)
64+
keys_rotated = self.rope(keys_new, offset=current_pos)
7365

7466
block_size = key_cache_global.shape[3]
7567

src/parallax/models/qwen3.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -67,27 +67,14 @@ def __call__(
6767

6868
key_cache_global, value_cache_global = cache.get_cache()
6969

70-
queries_rotated_list = []
71-
keys_rotated_list = []
72-
73-
for i in range(batch):
74-
# For decode phase: position is context_length - 1
75-
# For prefill phase: position starts at prefix_len (skip cached prefix tokens)
76-
if target_len == 1:
77-
# Decode phase
78-
current_pos = int(context_lengths[i]) - 1
79-
else:
80-
# Prefill phase - start from prefix_len if using prefix cache
81-
current_pos = int(prefix_lens[i]) if prefix_lens is not None else 0
82-
q_slice = queries_new[i : i + 1]
83-
k_slice = keys_new[i : i + 1]
84-
q_rot = self.rope(q_slice, offset=current_pos)
85-
k_rot = self.rope(k_slice, offset=current_pos)
86-
queries_rotated_list.append(q_rot)
87-
keys_rotated_list.append(k_rot)
88-
89-
queries_rotated = mx.concatenate(queries_rotated_list, axis=0)
90-
keys_rotated = mx.concatenate(keys_rotated_list, axis=0)
70+
if target_len == 1:
71+
current_pos = context_lengths - 1
72+
elif prefix_lens is not None:
73+
current_pos = prefix_lens
74+
else:
75+
current_pos = 0
76+
queries_rotated = self.rope(queries_new, offset=current_pos)
77+
keys_rotated = self.rope(keys_new, offset=current_pos)
9178

9279
block_size = key_cache_global.shape[3]
9380

src/parallax/models/qwen3_moe.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -61,20 +61,12 @@ def __call__(
6161

6262
key_cache_global, value_cache_global = cache.get_cache()
6363

64-
queries_rotated_list = []
65-
keys_rotated_list = []
66-
67-
for i in range(batch):
68-
current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0
69-
q_slice = queries_new[i : i + 1]
70-
k_slice = keys_new[i : i + 1]
71-
q_rot = self.rope(q_slice, offset=current_pos)
72-
k_rot = self.rope(k_slice, offset=current_pos)
73-
queries_rotated_list.append(q_rot)
74-
keys_rotated_list.append(k_rot)
75-
76-
queries_rotated = mx.concatenate(queries_rotated_list, axis=0)
77-
keys_rotated = mx.concatenate(keys_rotated_list, axis=0)
64+
if target_len == 1:
65+
current_pos = context_lengths - 1
66+
else:
67+
current_pos = 0
68+
queries_rotated = self.rope(queries_new, offset=current_pos)
69+
keys_rotated = self.rope(keys_new, offset=current_pos)
7870

7971
block_size = key_cache_global.shape[3]
8072

0 commit comments

Comments
 (0)