1212from torch .testing import FileCheck
1313from torch .testing ._internal .common_utils import (
1414 instantiate_parametrized_tests ,
15- < << << << HEAD
16- patch_test_members ,
17- NAVI3_ARCH ,
18- is_arch ,
19- == == == =
2015 is_navi3_arch ,
21- > >> >> >> upstream / main
2216 parametrize ,
2317 patch_test_members ,
2418 TEST_XPU ,
@@ -79,11 +73,7 @@ def forward(
7973)
8074@instantiate_parametrized_tests
8175class TestDecomposeMemMM (TestCase ):
82- < << << << HEAD
83- def __init__ (self , method_name = 'runTest' , methodName = 'runTest' ):
84- == == == =
8576 def __init__ (self , method_name = "runTest" , methodName = "runTest" ):
86- >> >> >> > upstream / main
8777 super ().__init__ (method_name , methodName )
8878 self .atol = 1e-3
8979 self .rtol = 1e-3
@@ -92,13 +82,9 @@ def setup_tolerance(self, rtol=None, atol=None):
9282 if rtol is None :
9383 rtol = self .rtol
9484 if atol is None :
95- < << << << HEAD
96- atol = self .rtol
97- == == == =
9885 atol = self .atol
9986 self .rtol = rtol
10087 self .atol = atol
101- >> >> >> > upstream / main
10288
10389 def compare_dict_tensors (self , ref_dict , res_dict , rtol = None , atol = None ):
10490 self .setup_tolerance (rtol , atol )
@@ -107,13 +93,9 @@ def compare_dict_tensors(self, ref_dict, res_dict, rtol=None, atol=None):
10793 for key1 in ref_dict .keys ():
10894 key2 = "_orig_mod." + key1
10995 assert key2 in res_dict , f"{ key1 } does not exist in traced module"
110- < << << << HEAD
111- if not torch .allclose (ref_dict [key1 ], res_dict [key2 ], rtol = self .rtol , atol = self .atol ):
112- == == == =
11396 if not torch .allclose (
11497 ref_dict [key1 ], res_dict [key2 ], rtol = self .rtol , atol = self .atol
11598 ):
116- > >> >> >> upstream / main
11799 return False
118100 return True
119101
@@ -127,28 +109,20 @@ def compare_parameters(self, module, traced, rtol=None, atol=None):
127109 self .setup_tolerance (rtol , atol )
128110 ref_params = dict (module .named_parameters ())
129111 res_params = dict (traced .named_parameters ())
130- < << << << HEAD
131- self .assertTrue (self .compare_dict_tensors (ref_params , res_params , rtol = self .rtol , atol = self .atol ))
132- == == == =
133112 self .assertTrue (
134113 self .compare_dict_tensors (
135114 ref_params , res_params , rtol = self .rtol , atol = self .atol
136115 )
137116 )
138- >> >> >> > upstream / main
139117
140118 def compare_gradients (self , module , traced , rtol = None , atol = None ):
141119 self .setup_tolerance (rtol , atol )
142120 ref_grad = {key : param .grad for key , param in module .named_parameters ()}
143121 res_grad = {key : param .grad for key , param in traced .named_parameters ()}
144122 self .assertTrue (
145- << << << < HEAD
146- self .compare_dict_tensors (ref_grad , res_grad , rtol = self .rtol , atol = self .atol )
147- == == == =
148123 self .compare_dict_tensors (
149124 ref_grad , res_grad , rtol = self .rtol , atol = self .atol
150125 )
151- >> >> >> > upstream / main
152126 )
153127
154128 @parametrize (
@@ -257,19 +231,12 @@ def test_decompose_linear(self, m, n, k, has_bias, should_decompose):
257231
258232 # We have to increase tolerance for navi3 because all fp16, bf16
259233 # GEMMs operations have an accuracy issue caused by hardware limitation
260- << << << < HEAD
261- @patch_test_members ({
262- "atol" : 2e-3 if is_arch (NAVI3_ARCH ) else 1e-3 ,
263- "rtol" : 2e-3 if is_arch (NAVI3_ARCH ) else 1e-3
264- })
265- == == == =
266234 @patch_test_members (
267235 {
268236 "atol" : 2e-3 if is_navi3_arch () else 1e-3 ,
269237 "rtol" : 2e-3 if is_navi3_arch () else 1e-3 ,
270238 }
271239 )
272- >> >> >> > upstream / main
273240 @parametrize (
274241 "m,k,n, should_decompose" ,
275242 [(20480 , 5 , 2 , True ), (20480 , 32 , 2 , False ), (2048 , 2 , 2 , False )],
@@ -380,19 +347,12 @@ def test_decompose_mm_cpu(self, m, n, k, should_decompose):
380347
381348 # We have to increase tolerance for navi3 because all fp16, bf16
382349 # GEMMs operations have an accuracy issue caused by hardware limitation
383- << << << < HEAD
384- @patch_test_members ({
385- "atol" : 3e-3 if is_arch (NAVI3_ARCH ) else 1e-3 ,
386- "rtol" : 4e-3 if is_arch (NAVI3_ARCH ) else 1e-3
387- })
388- == == == =
389350 @patch_test_members (
390351 {
391352 "atol" : 3e-3 if is_navi3_arch () else 1e-3 ,
392353 "rtol" : 4e-3 if is_navi3_arch () else 1e-3 ,
393354 }
394355 )
395- >> >> >> > upstream / main
396356 @parametrize (
397357 "m,k,n, should_decompose" ,
398358 [(20480 , 5 , 2 , True ), (20480 , 32 , 2 , False ), (2048 , 2 , 2 , False )],
0 commit comments