33Run `python -m pytest tests/test_spyre_warmup_shapes.py`.
44"""
55
6- import os
7-
86import pytest
97from spyre_util import (compare_results , generate_hf_output ,
108 generate_spyre_vllm_output , get_spyre_backend_list ,
119 get_spyre_model_list )
1210from vllm import SamplingParams
1311
1412
15- # temporary for filtering until bug with caching gets fixed
16- @pytest .mark .skipif (
17- os .environ .get ("TORCH_SENDNN_CACHE_ENABLE" ) == "1" ,
18- reason = "torch_sendnn caching is currently broken with this configuration" )
1913@pytest .mark .parametrize ("model" , get_spyre_model_list ())
2014@pytest .mark .parametrize ("prompts" , [
2115 7 * [
@@ -40,6 +34,7 @@ def test_output(
4034 warmup_shapes : list [tuple [int , int , int ]],
4135 backend : str ,
4236 vllm_version : str ,
37+ monkeypatch : pytest .MonkeyPatch ,
4338) -> None :
4439 '''
4540 The warmup is based on two shapes, that 'overlap' each
@@ -60,6 +55,8 @@ def test_output(
6055 test using 'pytest --capture=no tests/spyre/test_spyre_warmup_shapes.py'
6156 After debugging, DISABLE_ASSERTS should be reset to 'False'.
6257 '''
58+ # temporary workaround: force torch_sendnn caching off for this test until the caching bug gets fixed
59+ monkeypatch .setenv ("TORCH_SENDNN_CACHE_ENABLE" , "0" )
6360
6461 max_new_tokens = max ([t [1 ] for t in warmup_shapes ])
6562
0 commit comments