diff --git a/src/parallax/models/deepseek_v2.py b/src/parallax/models/deepseek_v2.py
index fe55ab60..9287d95a 100644
--- a/src/parallax/models/deepseek_v2.py
+++ b/src/parallax/models/deepseek_v2.py
@@ -62,21 +62,14 @@ def __call__(
         k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
         k_nope = k_nope.transpose(0, 2, 1, 3)
 
-        # q_pe = self.rope(q_pe, offset=offset)
-        # k_pe = self.rope(k_pe, offset=offset)
         key_cache_global, value_cache_global = cache.get_cache()
-        q_pe_list = []
-        k_pe_list = []
-        for i in range(batch):
-            current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0
-            q_slice = q_pe[i : i + 1]
-            k_slice = k_pe[i : i + 1]
-            q_rot = self.rope(q_slice, offset=current_pos)
-            k_rot = self.rope(k_slice, offset=current_pos)
-            q_pe_list.append(q_rot)
-            k_pe_list.append(k_rot)
-        q_pe = mx.concatenate(q_pe_list, axis=0)
-        k_pe = mx.concatenate(k_pe_list, axis=0)
+
+        if target_len == 1:
+            current_pos = context_lengths - 1
+        else:
+            current_pos = 0
+        q_pe = self.rope(q_pe, offset=current_pos)
+        k_pe = self.rope(k_pe, offset=current_pos)
 
         k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
         queries = mx.concatenate([q_nope, q_pe], axis=-1)
diff --git a/src/parallax/models/deepseek_v3.py b/src/parallax/models/deepseek_v3.py
index 01049ab9..f228238c 100644
--- a/src/parallax/models/deepseek_v3.py
+++ b/src/parallax/models/deepseek_v3.py
@@ -64,23 +64,14 @@ def __call__(
         k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
         k_nope = k_nope.transpose(0, 2, 1, 3)
 
-        # q_pe = self.rope(q_pe, offset=offset)
-        # k_pe = self.rope(k_pe, offset=offset)
         key_cache_global, value_cache_global = cache.get_cache()
 
-        q_pe_list = []
-        k_pe_list = []
-        for i in range(batch):
-            current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0
-            q_slice = q_pe[i : i + 1]
-            k_slice = k_pe[i : i + 1]
-            q_rot = self.rope(q_slice, offset=current_pos)
-            k_rot = self.rope(k_slice, offset=current_pos)
-            q_pe_list.append(q_rot)
-            k_pe_list.append(k_rot)
-
-        q_pe = mx.concatenate(q_pe_list, axis=0)
-        k_pe = mx.concatenate(k_pe_list, axis=0)
+        if target_len == 1:
+            current_pos = context_lengths - 1
+        else:
+            current_pos = 0
+        q_pe = self.rope(q_pe, offset=current_pos)
+        k_pe = self.rope(k_pe, offset=current_pos)
 
         k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
         queries = mx.concatenate([q_nope, q_pe], axis=-1)
diff --git a/src/parallax/models/deepseek_v32.py b/src/parallax/models/deepseek_v32.py
index 7982647a..440f12ba 100644
--- a/src/parallax/models/deepseek_v32.py
+++ b/src/parallax/models/deepseek_v32.py
@@ -141,18 +141,13 @@ def __call__(
         k_nope = k_nope.transpose(0, 2, 1, 3)
         key_cache_global, value_cache_global = cache.get_cache()
         indexer_cache = cache.get_indexer_cache()
-        q_pe_list = []
-        k_pe_list = []
-        for i in range(batch):
-            current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0
-            q_slice = q_pe[i : i + 1]
-            k_slice = k_pe[i : i + 1]
-            q_rot = self.rope(q_slice, offset=current_pos)
-            k_rot = self.rope(k_slice, offset=current_pos)
-            q_pe_list.append(q_rot)
-            k_pe_list.append(k_rot)
-        q_pe = mx.concatenate(q_pe_list, axis=0)
-        k_pe = mx.concatenate(k_pe_list, axis=0)
+
+        if target_len == 1:
+            current_pos = context_lengths - 1
+        else:
+            current_pos = 0
+        q_pe = self.rope(q_pe, offset=current_pos)
+        k_pe = self.rope(k_pe, offset=current_pos)
 
         k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
         queries = mx.concatenate([q_nope, q_pe], axis=-1)
diff --git a/src/parallax/models/glm4_moe.py b/src/parallax/models/glm4_moe.py
index 46a69a38..8002f440 100644
--- a/src/parallax/models/glm4_moe.py
+++ b/src/parallax/models/glm4_moe.py
@@ -38,20 +38,12 @@ def __call__(
 
         key_cache_global, value_cache_global = cache.get_cache()
 
-        queries_rotated_list = []
-        keys_rotated_list = []
-
-        for i in range(batch):
-            current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0
-            q_slice = queries_new[i : i + 1]
-            k_slice = keys_new[i : i + 1]
-            q_rot = self.rope(q_slice, offset=current_pos)
-            k_rot = self.rope(k_slice, offset=current_pos)
-            queries_rotated_list.append(q_rot)
-            keys_rotated_list.append(k_rot)
-
-        queries_rotated = mx.concatenate(queries_rotated_list, axis=0)
-        keys_rotated = mx.concatenate(keys_rotated_list, axis=0)
+        if target_len == 1:
+            current_pos = context_lengths - 1
+        else:
+            current_pos = 0
+        queries_rotated = self.rope(queries_new, offset=current_pos)
+        keys_rotated = self.rope(keys_new, offset=current_pos)
 
         block_size = key_cache_global.shape[3]
 
diff --git a/src/parallax/models/gpt_oss.py b/src/parallax/models/gpt_oss.py
index 00d8de59..43723b1e 100644
--- a/src/parallax/models/gpt_oss.py
+++ b/src/parallax/models/gpt_oss.py
@@ -51,19 +51,12 @@ def __call__(
 
         key_cache_global, value_cache_global = cache.get_cache()
 
-        queries_rotated_list = []
-        keys_rotated_list = []
-        for i in range(batch):
-            current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0
-            q_slice = queries_new[i : i + 1]
-            k_slice = keys_new[i : i + 1]
-            q_rot = self.rope(q_slice, offset=current_pos)
-            k_rot = self.rope(k_slice, offset=current_pos)
-            queries_rotated_list.append(q_rot)
-            keys_rotated_list.append(k_rot)
-
-        queries_rotated = mx.concatenate(queries_rotated_list, axis=0)
-        keys_rotated = mx.concatenate(keys_rotated_list, axis=0)
+        if target_len == 1:
+            current_pos = context_lengths - 1
+        else:
+            current_pos = 0
+        queries_rotated = self.rope(queries_new, offset=current_pos)
+        keys_rotated = self.rope(keys_new, offset=current_pos)
 
         # Update Paged Cache
         block_size = key_cache_global.shape[3]
diff --git a/src/parallax/models/llama.py b/src/parallax/models/llama.py
index a4127a14..352e9811 100644
--- a/src/parallax/models/llama.py
+++ b/src/parallax/models/llama.py
@@ -61,20 +61,12 @@ def __call__(
 
         key_cache_global, value_cache_global = cache.get_cache()
 
-        queries_rotated_list = []
-        keys_rotated_list = []
-
-        for i in range(batch):
-            current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0
-            q_slice = queries_new[i : i + 1]
-            k_slice = keys_new[i : i + 1]
-            q_rot = self.rope(q_slice, offset=current_pos)
-            k_rot = self.rope(k_slice, offset=current_pos)
-            queries_rotated_list.append(q_rot)
-            keys_rotated_list.append(k_rot)
-
-        queries_rotated = mx.concatenate(queries_rotated_list, axis=0)
-        keys_rotated = mx.concatenate(keys_rotated_list, axis=0)
+        if target_len == 1:
+            current_pos = context_lengths - 1
+        else:
+            current_pos = 0
+        queries_rotated = self.rope(queries_new, offset=current_pos)
+        keys_rotated = self.rope(keys_new, offset=current_pos)
 
         block_size = key_cache_global.shape[3]
 
diff --git a/src/parallax/models/minimax.py b/src/parallax/models/minimax.py
index 06737402..4e909d14 100644
--- a/src/parallax/models/minimax.py
+++ b/src/parallax/models/minimax.py
@@ -45,20 +45,12 @@ def __call__(
 
         key_cache_global, value_cache_global = cache.get_cache()
 
-        queries_rotated_list = []
-        keys_rotated_list = []
-
-        for i in range(batch):
-            current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0
-            q_slice = queries_new[i : i + 1]
-            k_slice = keys_new[i : i + 1]
-            q_rot = self.rope(q_slice, offset=current_pos)
-            k_rot = self.rope(k_slice, offset=current_pos)
-            queries_rotated_list.append(q_rot)
-            keys_rotated_list.append(k_rot)
-
-        queries_rotated = mx.concatenate(queries_rotated_list, axis=0)
-        keys_rotated = mx.concatenate(keys_rotated_list, axis=0)
+        if target_len == 1:
+            current_pos = context_lengths - 1
+        else:
+            current_pos = 0
+        queries_rotated = self.rope(queries_new, offset=current_pos)
+        keys_rotated = self.rope(keys_new, offset=current_pos)
 
         block_size = key_cache_global.shape[3]
 
diff --git a/src/parallax/models/qwen2.py b/src/parallax/models/qwen2.py
index 0b70fd6b..b770339d 100644
--- a/src/parallax/models/qwen2.py
+++ b/src/parallax/models/qwen2.py
@@ -56,20 +56,12 @@ def __call__(
 
         key_cache_global, value_cache_global = cache.get_cache()
 
-        queries_rotated_list = []
-        keys_rotated_list = []
-
-        for i in range(batch):
-            current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0
-            q_slice = queries_new[i : i + 1]
-            k_slice = keys_new[i : i + 1]
-            q_rot = self.rope(q_slice, offset=current_pos)
-            k_rot = self.rope(k_slice, offset=current_pos)
-            queries_rotated_list.append(q_rot)
-            keys_rotated_list.append(k_rot)
-
-        queries_rotated = mx.concatenate(queries_rotated_list, axis=0)
-        keys_rotated = mx.concatenate(keys_rotated_list, axis=0)
+        if target_len == 1:
+            current_pos = context_lengths - 1
+        else:
+            current_pos = 0
+        queries_rotated = self.rope(queries_new, offset=current_pos)
+        keys_rotated = self.rope(keys_new, offset=current_pos)
 
         block_size = key_cache_global.shape[3]
 
diff --git a/src/parallax/models/qwen3.py b/src/parallax/models/qwen3.py
index 806221af..ab24b1f6 100644
--- a/src/parallax/models/qwen3.py
+++ b/src/parallax/models/qwen3.py
@@ -67,27 +67,14 @@ def __call__(
 
         key_cache_global, value_cache_global = cache.get_cache()
 
-        queries_rotated_list = []
-        keys_rotated_list = []
-
-        for i in range(batch):
-            # For decode phase: position is context_length - 1
-            # For prefill phase: position starts at prefix_len (skip cached prefix tokens)
-            if target_len == 1:
-                # Decode phase
-                current_pos = int(context_lengths[i]) - 1
-            else:
-                # Prefill phase - start from prefix_len if using prefix cache
-                current_pos = int(prefix_lens[i]) if prefix_lens is not None else 0
-            q_slice = queries_new[i : i + 1]
-            k_slice = keys_new[i : i + 1]
-            q_rot = self.rope(q_slice, offset=current_pos)
-            k_rot = self.rope(k_slice, offset=current_pos)
-            queries_rotated_list.append(q_rot)
-            keys_rotated_list.append(k_rot)
-
-        queries_rotated = mx.concatenate(queries_rotated_list, axis=0)
-        keys_rotated = mx.concatenate(keys_rotated_list, axis=0)
+        if target_len == 1:
+            current_pos = context_lengths - 1
+        elif prefix_lens is not None:
+            current_pos = prefix_lens
+        else:
+            current_pos = 0
+        queries_rotated = self.rope(queries_new, offset=current_pos)
+        keys_rotated = self.rope(keys_new, offset=current_pos)
 
         block_size = key_cache_global.shape[3]
 
diff --git a/src/parallax/models/qwen3_moe.py b/src/parallax/models/qwen3_moe.py
index eba35314..024c69e9 100644
--- a/src/parallax/models/qwen3_moe.py
+++ b/src/parallax/models/qwen3_moe.py
@@ -61,20 +61,12 @@ def __call__(
 
         key_cache_global, value_cache_global = cache.get_cache()
 
-        queries_rotated_list = []
-        keys_rotated_list = []
-
-        for i in range(batch):
-            current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0
-            q_slice = queries_new[i : i + 1]
-            k_slice = keys_new[i : i + 1]
-            q_rot = self.rope(q_slice, offset=current_pos)
-            k_rot = self.rope(k_slice, offset=current_pos)
-            queries_rotated_list.append(q_rot)
-            keys_rotated_list.append(k_rot)
-
-        queries_rotated = mx.concatenate(queries_rotated_list, axis=0)
-        keys_rotated = mx.concatenate(keys_rotated_list, axis=0)
+        if target_len == 1:
+            current_pos = context_lengths - 1
+        else:
+            current_pos = 0
+        queries_rotated = self.rope(queries_new, offset=current_pos)
+        keys_rotated = self.rope(keys_new, offset=current_pos)
 
         block_size = key_cache_global.shape[3]
 
diff --git a/src/parallax/models/qwen3_next.py b/src/parallax/models/qwen3_next.py
index 3bdb5e8d..356f77ee 100644
--- a/src/parallax/models/qwen3_next.py
+++ b/src/parallax/models/qwen3_next.py
@@ -47,18 +47,12 @@ def __call__(
 
         key_cache_global, value_cache_global = cache.get_cache()
 
-        queries_rotated_list = []
-        keys_rotated_list = []
-        for i in range(batch):
-            current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0
-            q_slice = queries_new[i : i + 1]
-            k_slice = keys_new[i : i + 1]
-            q_rot = self.rope(q_slice, offset=current_pos)
-            k_rot = self.rope(k_slice, offset=current_pos)
-            queries_rotated_list.append(q_rot)
-            keys_rotated_list.append(k_rot)
-        queries_rotated = mx.concatenate(queries_rotated_list, axis=0)
-        keys_rotated = mx.concatenate(keys_rotated_list, axis=0)
+        if target_len == 1:
+            current_pos = context_lengths - 1
+        else:
+            current_pos = 0
+        queries_rotated = self.rope(queries_new, offset=current_pos)
+        keys_rotated = self.rope(keys_new, offset=current_pos)
 
         block_size = key_cache_global.shape[3]
         reshape_and_cache(