33Run `python -m pytest tests/test_spyre_warmup_shapes.py`.
44"""
55
6+ import os
7+
68import pytest
79from spyre_util import (compare_results , generate_hf_output ,
810 generate_spyre_vllm_output , get_spyre_backend_list ,
911 get_spyre_model_list )
1012from vllm import SamplingParams
1113
1214
15+ # temporary for filtering until bug with caching gets fixed
16+ @pytest .mark .skipif (
17+ os .environ .get ("TORCH_SENDNN_CACHE_ENABLE" ) == "1" ,
18+ reason = "torch_sendnn caching is currently broken with this configuration" )
1319@pytest .mark .parametrize ("model" , get_spyre_model_list ())
1420@pytest .mark .parametrize ("prompts" , [
1521 7 * [
@@ -34,7 +40,6 @@ def test_output(
3440 warmup_shapes : list [tuple [int , int , int ]],
3541 backend : str ,
3642 vllm_version : str ,
37- monkeypatch : pytest .MonkeyPatch ,
3843) -> None :
3944 '''
4045 The warmup is based on two shapes, that 'overlap' each
@@ -55,8 +60,6 @@ def test_output(
5560 test using 'pytest --capture=no tests/spyre/test_spyre_warmup_shapes.py'
5661 After debugging, DISABLE_ASSERTS should be reset to 'False'.
5762 '''
58- # temporary until bug with caching gets fixed
59- monkeypatch .setenv ("TORCH_SENDNN_CACHE_ENABLE" , "0" )
6063
6164 max_new_tokens = max ([t [1 ] for t in warmup_shapes ])
6265
0 commit comments