Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torch.compile] consider relevant code in compilation cache #11614

Open
wants to merge 8 commits into
base: main
Choose a base branch
from

Conversation

youkaichao
Copy link
Member

@youkaichao youkaichao commented Dec 30, 2024

example output:

$ vllm serve meta-llama/Meta-Llama-3-8B -O3
...
DEBUG 12-29 21:04:52 backends.py:495] Traced files (to be considered for compilation cache):
DEBUG 12-29 21:04:52 backends.py:495] /data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/container.py
DEBUG 12-29 21:04:52 backends.py:495] /data/youkaichao/vllm/vllm/attention/layer.py
DEBUG 12-29 21:04:52 backends.py:495] /data/youkaichao/vllm/vllm/distributed/communication_op.py
DEBUG 12-29 21:04:52 backends.py:495] /data/youkaichao/vllm/vllm/distributed/parallel_state.py
DEBUG 12-29 21:04:52 backends.py:495] /data/youkaichao/vllm/vllm/model_executor/custom_op.py
DEBUG 12-29 21:04:52 backends.py:495] /data/youkaichao/vllm/vllm/model_executor/layers/activation.py
DEBUG 12-29 21:04:52 backends.py:495] /data/youkaichao/vllm/vllm/model_executor/layers/layernorm.py
DEBUG 12-29 21:04:52 backends.py:495] /data/youkaichao/vllm/vllm/model_executor/layers/linear.py
DEBUG 12-29 21:04:52 backends.py:495] /data/youkaichao/vllm/vllm/model_executor/layers/rotary_embedding.py
DEBUG 12-29 21:04:52 backends.py:495] /data/youkaichao/vllm/vllm/model_executor/layers/vocab_parallel_embedding.py
DEBUG 12-29 21:04:52 backends.py:495] /data/youkaichao/vllm/vllm/model_executor/models/llama.py
...

Looks pretty good.

Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

Signed-off-by: youkaichao <[email protected]>
@youkaichao
Copy link
Member Author

/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/container.py

this line is relevant because we use ModuleList, and the forward pass will iterate over the module list.

Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very cool

Comment on lines +1099 to +1104
def __init__(self, tensors):
# manually define this function, so that
# Dynamo knows `IntermediateTensors()` comes from this file.
# Otherwise, dataclass will generate this function by evaluating
# a string, and we will lose the information about the source file.
self.tensors = tensors
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this something that we'll need to do for every dataclass that's used during model execution?

Comment on lines +508 to +509
hash_key = hashlib.md5(
f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not take the whole hash?

Comment on lines +201 to +213
self.vllm_config.compilation_config.traced_files.add(
self.original_code_object.co_filename)
inline_call = InliningInstructionTranslator.inline_call

def patched_inline_call(parent, func, args, kwargs):
code = func.get_code()
self.vllm_config.compilation_config.traced_files.add(
code.co_filename)
return inline_call(parent, func, args, kwargs)

with patch.object(InliningInstructionTranslator, 'inline_call',
patched_inline_call):
output = self.compiled_callable(*args, **kwargs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment explaining how and why we are adding the file names here?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants