21 changes: 7 additions & 14 deletions src/parallax/models/deepseek_v2.py
@@ -62,21 +62,14 @@ def __call__(
         k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
         k_nope = k_nope.transpose(0, 2, 1, 3)

-        # q_pe = self.rope(q_pe, offset=offset)
-        # k_pe = self.rope(k_pe, offset=offset)
         key_cache_global, value_cache_global = cache.get_cache()
-        q_pe_list = []
-        k_pe_list = []
-        for i in range(batch):
-            current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0
-            q_slice = q_pe[i : i + 1]
-            k_slice = k_pe[i : i + 1]
-            q_rot = self.rope(q_slice, offset=current_pos)
-            k_rot = self.rope(k_slice, offset=current_pos)
-            q_pe_list.append(q_rot)
-            k_pe_list.append(k_rot)
-        q_pe = mx.concatenate(q_pe_list, axis=0)
-        k_pe = mx.concatenate(k_pe_list, axis=0)
+
+        if target_len == 1:
+            current_pos = context_lengths - 1
+        else:
+            current_pos = 0
+        q_pe = self.rope(q_pe, offset=current_pos)
+        k_pe = self.rope(k_pe, offset=current_pos)

         k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
         queries = mx.concatenate([q_nope, q_pe], axis=-1)
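The change is the same across every file in this PR: the per-row Python loop that applied RoPE with one integer offset per sequence is replaced by a single call over the whole batch, with the per-sequence positions passed as an array offset. Below is a minimal sketch of the pattern, assuming an MLX version whose mx.fast.rope accepts an mx.array offset with one entry per batch row; the shapes and the rope helper here are illustrative, not taken from the repo.

import mlx.core as mx

# Decode step: (batch, n_heads, 1, head_dim), one new token per sequence.
batch, n_heads, head_dim = 3, 4, 64
x = mx.random.normal((batch, n_heads, 1, head_dim))
context_lengths = mx.array([7, 12, 3])  # tokens already in each sequence

def rope(x, offset):
    return mx.fast.rope(
        x, head_dim, traditional=False, base=10000.0, scale=1.0, offset=offset
    )

# Old approach: rotate each row separately with its own scalar offset.
rows = [rope(x[i : i + 1], int(context_lengths[i]) - 1) for i in range(batch)]
looped = mx.concatenate(rows, axis=0)

# New approach: one vectorized call, offsets supplied as an array.
vectorized = rope(x, context_lengths - 1)

print(mx.allclose(looped, vectorized))  # expected: True

Besides collapsing batch-size-many rope calls into one, this removes the per-row int(...) conversions, each of which forces MLX's lazy graph to evaluate.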
21 changes: 6 additions & 15 deletions src/parallax/models/deepseek_v3.py
@@ -64,23 +64,14 @@ def __call__(
         k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
         k_nope = k_nope.transpose(0, 2, 1, 3)

-        # q_pe = self.rope(q_pe, offset=offset)
-        # k_pe = self.rope(k_pe, offset=offset)
         key_cache_global, value_cache_global = cache.get_cache()

-        q_pe_list = []
-        k_pe_list = []
-        for i in range(batch):
-            current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0
-            q_slice = q_pe[i : i + 1]
-            k_slice = k_pe[i : i + 1]
-            q_rot = self.rope(q_slice, offset=current_pos)
-            k_rot = self.rope(k_slice, offset=current_pos)
-            q_pe_list.append(q_rot)
-            k_pe_list.append(k_rot)
-
-        q_pe = mx.concatenate(q_pe_list, axis=0)
-        k_pe = mx.concatenate(k_pe_list, axis=0)
+        if target_len == 1:
+            current_pos = context_lengths - 1
+        else:
+            current_pos = 0
+        q_pe = self.rope(q_pe, offset=current_pos)
+        k_pe = self.rope(k_pe, offset=current_pos)

         k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
         queries = mx.concatenate([q_nope, q_pe], axis=-1)
19 changes: 7 additions & 12 deletions src/parallax/models/deepseek_v32.py
@@ -141,18 +141,13 @@ def __call__(
         k_nope = k_nope.transpose(0, 2, 1, 3)
         key_cache_global, value_cache_global = cache.get_cache()
         indexer_cache = cache.get_indexer_cache()
-        q_pe_list = []
-        k_pe_list = []
-        for i in range(batch):
-            current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0
-            q_slice = q_pe[i : i + 1]
-            k_slice = k_pe[i : i + 1]
-            q_rot = self.rope(q_slice, offset=current_pos)
-            k_rot = self.rope(k_slice, offset=current_pos)
-            q_pe_list.append(q_rot)
-            k_pe_list.append(k_rot)
-        q_pe = mx.concatenate(q_pe_list, axis=0)
-        k_pe = mx.concatenate(k_pe_list, axis=0)
+
+        if target_len == 1:
+            current_pos = context_lengths - 1
+        else:
+            current_pos = 0
+        q_pe = self.rope(q_pe, offset=current_pos)
+        k_pe = self.rope(k_pe, offset=current_pos)

         k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
         queries = mx.concatenate([q_nope, q_pe], axis=-1)
20 changes: 6 additions & 14 deletions src/parallax/models/glm4_moe.py
@@ -38,20 +38,12 @@ def __call__(

         key_cache_global, value_cache_global = cache.get_cache()

-        queries_rotated_list = []
-        keys_rotated_list = []
-
-        for i in range(batch):
-            current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0
-            q_slice = queries_new[i : i + 1]
-            k_slice = keys_new[i : i + 1]
-            q_rot = self.rope(q_slice, offset=current_pos)
-            k_rot = self.rope(k_slice, offset=current_pos)
-            queries_rotated_list.append(q_rot)
-            keys_rotated_list.append(k_rot)
-
-        queries_rotated = mx.concatenate(queries_rotated_list, axis=0)
-        keys_rotated = mx.concatenate(keys_rotated_list, axis=0)
+        if target_len == 1:
+            current_pos = context_lengths - 1
+        else:
+            current_pos = 0
+        queries_rotated = self.rope(queries_new, offset=current_pos)
+        keys_rotated = self.rope(keys_new, offset=current_pos)

         block_size = key_cache_global.shape[3]

19 changes: 6 additions & 13 deletions src/parallax/models/gpt_oss.py
@@ -51,19 +51,12 @@ def __call__(

         key_cache_global, value_cache_global = cache.get_cache()

-        queries_rotated_list = []
-        keys_rotated_list = []
-        for i in range(batch):
-            current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0
-            q_slice = queries_new[i : i + 1]
-            k_slice = keys_new[i : i + 1]
-            q_rot = self.rope(q_slice, offset=current_pos)
-            k_rot = self.rope(k_slice, offset=current_pos)
-            queries_rotated_list.append(q_rot)
-            keys_rotated_list.append(k_rot)
-
-        queries_rotated = mx.concatenate(queries_rotated_list, axis=0)
-        keys_rotated = mx.concatenate(keys_rotated_list, axis=0)
+        if target_len == 1:
+            current_pos = context_lengths - 1
+        else:
+            current_pos = 0
+        queries_rotated = self.rope(queries_new, offset=current_pos)
+        keys_rotated = self.rope(keys_new, offset=current_pos)

         # Update Paged Cache
         block_size = key_cache_global.shape[3]
20 changes: 6 additions & 14 deletions src/parallax/models/llama.py
@@ -61,20 +61,12 @@ def __call__(

         key_cache_global, value_cache_global = cache.get_cache()

-        queries_rotated_list = []
-        keys_rotated_list = []
-
-        for i in range(batch):
-            current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0
-            q_slice = queries_new[i : i + 1]
-            k_slice = keys_new[i : i + 1]
-            q_rot = self.rope(q_slice, offset=current_pos)
-            k_rot = self.rope(k_slice, offset=current_pos)
-            queries_rotated_list.append(q_rot)
-            keys_rotated_list.append(k_rot)
-
-        queries_rotated = mx.concatenate(queries_rotated_list, axis=0)
-        keys_rotated = mx.concatenate(keys_rotated_list, axis=0)
+        if target_len == 1:
+            current_pos = context_lengths - 1
+        else:
+            current_pos = 0
+        queries_rotated = self.rope(queries_new, offset=current_pos)
+        keys_rotated = self.rope(keys_new, offset=current_pos)

         block_size = key_cache_global.shape[3]

20 changes: 6 additions & 14 deletions src/parallax/models/minimax.py
@@ -45,20 +45,12 @@ def __call__(

         key_cache_global, value_cache_global = cache.get_cache()

-        queries_rotated_list = []
-        keys_rotated_list = []
-
-        for i in range(batch):
-            current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0
-            q_slice = queries_new[i : i + 1]
-            k_slice = keys_new[i : i + 1]
-            q_rot = self.rope(q_slice, offset=current_pos)
-            k_rot = self.rope(k_slice, offset=current_pos)
-            queries_rotated_list.append(q_rot)
-            keys_rotated_list.append(k_rot)
-
-        queries_rotated = mx.concatenate(queries_rotated_list, axis=0)
-        keys_rotated = mx.concatenate(keys_rotated_list, axis=0)
+        if target_len == 1:
+            current_pos = context_lengths - 1
+        else:
+            current_pos = 0
+        queries_rotated = self.rope(queries_new, offset=current_pos)
+        keys_rotated = self.rope(keys_new, offset=current_pos)

         block_size = key_cache_global.shape[3]

20 changes: 6 additions & 14 deletions src/parallax/models/qwen2.py
@@ -56,20 +56,12 @@ def __call__(

         key_cache_global, value_cache_global = cache.get_cache()

-        queries_rotated_list = []
-        keys_rotated_list = []
-
-        for i in range(batch):
-            current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0
-            q_slice = queries_new[i : i + 1]
-            k_slice = keys_new[i : i + 1]
-            q_rot = self.rope(q_slice, offset=current_pos)
-            k_rot = self.rope(k_slice, offset=current_pos)
-            queries_rotated_list.append(q_rot)
-            keys_rotated_list.append(k_rot)
-
-        queries_rotated = mx.concatenate(queries_rotated_list, axis=0)
-        keys_rotated = mx.concatenate(keys_rotated_list, axis=0)
+        if target_len == 1:
+            current_pos = context_lengths - 1
+        else:
+            current_pos = 0
+        queries_rotated = self.rope(queries_new, offset=current_pos)
+        keys_rotated = self.rope(keys_new, offset=current_pos)

         block_size = key_cache_global.shape[3]

29 changes: 8 additions & 21 deletions src/parallax/models/qwen3.py
@@ -67,27 +67,14 @@ def __call__(

         key_cache_global, value_cache_global = cache.get_cache()

-        queries_rotated_list = []
-        keys_rotated_list = []
-
-        for i in range(batch):
-            # For decode phase: position is context_length - 1
-            # For prefill phase: position starts at prefix_len (skip cached prefix tokens)
-            if target_len == 1:
-                # Decode phase
-                current_pos = int(context_lengths[i]) - 1
-            else:
-                # Prefill phase - start from prefix_len if using prefix cache
-                current_pos = int(prefix_lens[i]) if prefix_lens is not None else 0
-            q_slice = queries_new[i : i + 1]
-            k_slice = keys_new[i : i + 1]
-            q_rot = self.rope(q_slice, offset=current_pos)
-            k_rot = self.rope(k_slice, offset=current_pos)
-            queries_rotated_list.append(q_rot)
-            keys_rotated_list.append(k_rot)
-
-        queries_rotated = mx.concatenate(queries_rotated_list, axis=0)
-        keys_rotated = mx.concatenate(keys_rotated_list, axis=0)
+        if target_len == 1:
+            current_pos = context_lengths - 1
+        elif prefix_lens is not None:
+            current_pos = prefix_lens
+        else:
+            current_pos = 0
+        queries_rotated = self.rope(queries_new, offset=current_pos)
+        keys_rotated = self.rope(keys_new, offset=current_pos)

         block_size = key_cache_global.shape[3]

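qwen3.py is the only model here with a third offset case: on prefill with a prefix-cache hit, rotation starts at the first uncached token, so prefix_lens is passed through as the offset instead of 0. The vectorized form relies on the usual RoPE invariant that rotating only the uncached suffix at offset prefix_len matches rotating the full sequence from position 0. A short sketch of that invariant, again calling mx.fast.rope directly with illustrative names and shapes:

import mlx.core as mx

head_dim = 64
full = mx.random.normal((1, 2, 6, head_dim))  # one sequence of 6 tokens
prefix_len = 4                                # first 4 tokens already cached

def rope(x, offset):
    return mx.fast.rope(
        x, head_dim, traditional=False, base=10000.0, scale=1.0, offset=offset
    )

whole = rope(full, 0)                                    # rotate all 6 tokens
suffix = rope(full[:, :, prefix_len:, :], prefix_len)    # rotate tokens 4 and 5

print(mx.allclose(whole[:, :, prefix_len:, :], suffix))  # expected: True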
20 changes: 6 additions & 14 deletions src/parallax/models/qwen3_moe.py
@@ -61,20 +61,12 @@ def __call__(

         key_cache_global, value_cache_global = cache.get_cache()

-        queries_rotated_list = []
-        keys_rotated_list = []
-
-        for i in range(batch):
-            current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0
-            q_slice = queries_new[i : i + 1]
-            k_slice = keys_new[i : i + 1]
-            q_rot = self.rope(q_slice, offset=current_pos)
-            k_rot = self.rope(k_slice, offset=current_pos)
-            queries_rotated_list.append(q_rot)
-            keys_rotated_list.append(k_rot)
-
-        queries_rotated = mx.concatenate(queries_rotated_list, axis=0)
-        keys_rotated = mx.concatenate(keys_rotated_list, axis=0)
+        if target_len == 1:
+            current_pos = context_lengths - 1
+        else:
+            current_pos = 0
+        queries_rotated = self.rope(queries_new, offset=current_pos)
+        keys_rotated = self.rope(keys_new, offset=current_pos)

         block_size = key_cache_global.shape[3]

18 changes: 6 additions & 12 deletions src/parallax/models/qwen3_next.py
@@ -47,18 +47,12 @@ def __call__(

         key_cache_global, value_cache_global = cache.get_cache()

-        queries_rotated_list = []
-        keys_rotated_list = []
-        for i in range(batch):
-            current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0
-            q_slice = queries_new[i : i + 1]
-            k_slice = keys_new[i : i + 1]
-            q_rot = self.rope(q_slice, offset=current_pos)
-            k_rot = self.rope(k_slice, offset=current_pos)
-            queries_rotated_list.append(q_rot)
-            keys_rotated_list.append(k_rot)
-        queries_rotated = mx.concatenate(queries_rotated_list, axis=0)
-        keys_rotated = mx.concatenate(keys_rotated_list, axis=0)
+        if target_len == 1:
+            current_pos = context_lengths - 1
+        else:
+            current_pos = 0
+        queries_rotated = self.rope(queries_new, offset=current_pos)
+        keys_rotated = self.rope(keys_new, offset=current_pos)

         block_size = key_cache_global.shape[3]
         reshape_and_cache(