vllm_spyre/v1/worker/spyre_model_runner.py (38 additions, 37 deletions)
@@ -432,6 +432,7 @@ def execute_model(

t0 = time.time()

self._update_states(scheduler_output)
# TODO: change to EMPTY_MODEL_RUNNER_OUTPUT, right now this
# will be a breaking change, or clumsy to make retrocompatible
# with conditional import
@@ -446,8 +447,6 @@ def execute_model(
prompt_logprobs_dict={},
)

self._update_states(scheduler_output)

model_input = self.prepare_model_input(scheduler_output)
self._mark_input_tensors(model_input)

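The TODO above mentions replacing this hand-built empty output with EMPTY_MODEL_RUNNER_OUTPUT. A minimal sketch of the retrocompatible conditional import it alludes to, assuming the constant lives in vllm.v1.outputs (the import path is an assumption and may differ across vLLM versions):

    # Hypothetical sketch, not part of this PR: degrade gracefully on
    # vLLM versions that predate EMPTY_MODEL_RUNNER_OUTPUT.
    try:
        from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
    except ImportError:
        EMPTY_MODEL_RUNNER_OUTPUT = None  # build the empty output by hand, as above
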
@@ -916,41 +915,7 @@ def execute_model(
)

model_input = self.prepare_model_input(scheduler_output)

# Marking dimensions static/dynamic
if model_input.is_prompt:

# batch static (batch size 1)
torch._dynamo.mark_static(model_input.input_tokens, 0)
torch._dynamo.mark_static(model_input.slot_mapping, 0)
torch._dynamo.mark_static(model_input.input_positions, 0)
torch._dynamo.mark_static(model_input.input_masks, 0)

# sequence dynamic
torch._dynamo.mark_dynamic(model_input.input_tokens, 1)
torch._dynamo.mark_dynamic(model_input.slot_mapping, 1)
torch._dynamo.mark_dynamic(model_input.input_positions, 1)
torch._dynamo.mark_dynamic(model_input.input_masks, 2)
torch._dynamo.mark_dynamic(model_input.input_masks, 3)

# decode
else:
# mask is no longer used here

# batch dynamic
torch._dynamo.mark_dynamic(model_input.input_tokens, 0)
torch._dynamo.mark_dynamic(model_input.block_table, 0)
torch._dynamo.mark_dynamic(model_input.slot_mapping, 0)
torch._dynamo.mark_dynamic(model_input.input_positions, 0)
torch._dynamo.mark_dynamic(model_input.current_tkv_mask, 0)
torch._dynamo.mark_dynamic(model_input.left_padded_prompt_mask, 0)

# sequence
torch._dynamo.mark_static(model_input.input_tokens, 1) # always 1
torch._dynamo.mark_dynamic(model_input.block_table, 1)
torch._dynamo.mark_static(model_input.slot_mapping, 1) # always 1
torch._dynamo.mark_static(model_input.input_positions,
1) # always 1
self._mark_input_tensors(model_input)
Collaborator:

Do we really need this function to use self? It only interacts with the model_input param.

Collaborator (Author):

You are totally right; I copied it from the static batching code, where self is also unused and can be removed. FYI: @wallashss


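A minimal sketch of the suggestion above, using only names from this PR: since the helper reads nothing from the instance, it could be a @staticmethod (or a module-level function):

    # Sketch, not part of this PR: same body, no instance state needed.
    @staticmethod
    def _mark_input_tensors(model_input: ModelForwardInputs) -> None:
        ...  # identical body to the method added at the bottom of this diff

Call sites written as self._mark_input_tensors(model_input) keep working unchanged, since Python resolves staticmethods through the instance.
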
# Execute the model
hidden_states = self.model(
@@ -1001,3 +966,39 @@ def execute_model(
tkv=self.tkv,
)
return model_output

def _mark_input_tensors(self, model_input: ModelForwardInputs) -> None:
# Marking dimensions static/dynamic
if model_input.is_prompt:

# batch static (batch size 1)
torch._dynamo.mark_static(model_input.input_tokens, 0)
torch._dynamo.mark_static(model_input.slot_mapping, 0)
torch._dynamo.mark_static(model_input.input_positions, 0)
torch._dynamo.mark_static(model_input.input_masks, 0)

# sequence dynamic
torch._dynamo.mark_dynamic(model_input.input_tokens, 1)
torch._dynamo.mark_dynamic(model_input.slot_mapping, 1)
torch._dynamo.mark_dynamic(model_input.input_positions, 1)
torch._dynamo.mark_dynamic(model_input.input_masks, 2)
torch._dynamo.mark_dynamic(model_input.input_masks, 3)

# decode
else:
# mask is no longer used here

# batch dynamic
torch._dynamo.mark_dynamic(model_input.input_tokens, 0)
torch._dynamo.mark_dynamic(model_input.block_table, 0)
torch._dynamo.mark_dynamic(model_input.slot_mapping, 0)
torch._dynamo.mark_dynamic(model_input.input_positions, 0)
torch._dynamo.mark_dynamic(model_input.current_tkv_mask, 0)
torch._dynamo.mark_dynamic(model_input.left_padded_prompt_mask, 0)

# sequence
torch._dynamo.mark_static(model_input.input_tokens, 1) # always 1
torch._dynamo.mark_dynamic(model_input.block_table, 1)
torch._dynamo.mark_static(model_input.slot_mapping, 1) # always 1
torch._dynamo.mark_static(model_input.input_positions,
1) # always 1
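
For context, not part of the PR: torch._dynamo.mark_dynamic marks a dimension as symbolic, so the compiled graph is reused when that size changes, while mark_static pins a dimension so the compiler may specialize on its exact value. A minimal, self-contained illustration:

    import torch

    def double(x):
        return x * 2

    compiled = torch.compile(double)

    x = torch.randn(4, 7)
    # Treat dim 1 as symbolic: a later call with a different size along
    # that dim reuses the compiled graph instead of triggering a recompile.
    torch._dynamo.mark_dynamic(x, 1)
    compiled(x)
    compiled(torch.randn(4, 9))  # same graph, no recompilation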