@@ -26,7 +26,7 @@ def forward(self, x):
26
26
return self .classifier (x )
27
27
28
28
29
- class ArgsTest (unittest .TestCase ):
29
+ class GlobalsTest (unittest .TestCase ):
30
30
def testGlobalParameters (self ):
31
31
m = SimpleParams ()
32
32
@@ -63,10 +63,6 @@ def read_params(self):
63
63
"%_params.classifier.bias = util.global.load @_params.classifier.bias" ,
64
64
module_str ,
65
65
)
66
- self .assertIn (
67
- "return %_params.classifier.weight, %_params.classifier.bias" ,
68
- module_str ,
69
- )
70
66
71
67
def testGlobalLoadFromPyLeaf (self ):
72
68
m = SimpleParams ()
@@ -84,7 +80,6 @@ def read_weight(self):
84
80
"%_params.classifier.weight = util.global.load @_params.classifier.weight" ,
85
81
module_str ,
86
82
)
87
- self .assertIn ("return %_params.classifier.weight" , module_str )
88
83
89
84
def testGlobalStoreFromPyTree (self ):
90
85
m = SimpleParams ()
@@ -100,8 +95,10 @@ def update_params(me, updates=abstractify(params)):
100
95
inst = GlobalModule (context = Context ())
101
96
module_str = str (CompiledModule .get_mlir_module (inst ))
102
97
print (module_str )
103
- self .assertIn ("util.global.store %arg0, @_params.classifier.weight" , module_str )
104
- self .assertIn ("util.global.store %arg1, @_params.classifier.bias" , module_str )
98
+ self .assertRegex (
99
+ module_str , "util.global.store %.*, @_params.classifier.weight"
100
+ )
101
+ self .assertRegex (module_str , "util.global.store %.*, @_params.classifier.bias" )
105
102
106
103
def testGlobalStoreFromLeaf (self ):
107
104
m = SimpleParams ()
@@ -115,7 +112,7 @@ def update_bias(self, new_bias=abstractify(params["classifier.bias"])):
115
112
inst = GlobalModule (context = Context ())
116
113
module_str = str (CompiledModule .get_mlir_module (inst ))
117
114
print (module_str )
118
- self .assertIn ( "util.global.store %arg0 , @_params.classifier.bias" , module_str )
115
+ self .assertRegex ( module_str , "util.global.store %.* , @_params.classifier.bias" )
119
116
120
117
def testExportSingleGlobalTensor (self ):
121
118
state_example = torch .randn (3 , 11 )
@@ -131,7 +128,6 @@ def read_state(self):
131
128
print (module_str )
132
129
self .assertIn ("util.global private @_state0.global" , module_str )
133
130
self .assertIn ("%_state0.global = util.global.load @_state0.global" , module_str )
134
- self .assertIn ("return %_state0.global" , module_str )
135
131
136
132
def testExportTreeGlobalTensors (self ):
137
133
state_example = {
@@ -160,10 +156,6 @@ def read_state(self):
160
156
self .assertIn ("%_state0.seq.0 = util.global.load @_state0.seq.0" , module_str )
161
157
self .assertIn ("%_state0.seq.1 = util.global.load @_state0.seq.1" , module_str )
162
158
self .assertIn ("%_state0.seq.2 = util.global.load @_state0.seq.2" , module_str )
163
- self .assertIn (
164
- "return %_state0.data, %_state0.seq.0, %_state0.seq.1, %_state0.seq.2" ,
165
- module_str ,
166
- )
167
159
168
160
def testExportGlobalScalars (self ):
169
161
class ScalarState (CompiledModule ):
@@ -210,9 +202,6 @@ class DerivedState(BaseState):
210
202
print (module_str )
211
203
self .assertIn ("@_state_index.global {noinline} = 0 : index" , module_str )
212
204
self .assertIn ("@_state_f32.global {noinline} = 0.000000e+00 : f32" , module_str )
213
- self .assertIn (
214
- "return %_state_index.global, %_state_f32.global : index, f32" , module_str
215
- )
216
205
217
206
def testInheritOverrideBase (self ):
218
207
class BaseState (CompiledModule ):
@@ -252,8 +241,10 @@ class DerivedModule(BaseModule):
252
241
inst = DerivedModule (context = Context ())
253
242
module_str = str (CompiledModule .get_mlir_module (inst ))
254
243
print (module_str )
255
- self .assertIn ("util.global.store %arg0, @_params.classifier.weight" , module_str )
256
- self .assertIn ("util.global.store %arg1, @_params.classifier.bias" , module_str )
244
+ self .assertRegex (
245
+ module_str , "util.global.store %.*, @_params.classifier.weight"
246
+ )
247
+ self .assertRegex (module_str , "util.global.store %.*, @_params.classifier.bias" )
257
248
258
249
def testUpdateGlobalStateTree (self ):
259
250
state_example = {
@@ -287,10 +278,10 @@ def read_state(self, updates=abstractify(state_example)):
287
278
module_str ,
288
279
)
289
280
self .assertIn ("util.global private mutable @_state0.data" , module_str )
290
- self .assertIn ( "util.global.store %arg0 , @_state0.data" , module_str )
291
- self .assertIn ( "util.global.store %arg1 , @_state0.seq.0" , module_str )
292
- self .assertIn ( "util.global.store %arg2 , @_state0.seq.1" , module_str )
293
- self .assertIn ( "util.global.store %arg3 , @_state0.seq.2" , module_str )
281
+ self .assertRegex ( module_str , "util.global.store %.* , @_state0.data" )
282
+ self .assertRegex ( module_str , "util.global.store %.* , @_state0.seq.0" )
283
+ self .assertRegex ( module_str , "util.global.store %.* , @_state0.seq.1" )
284
+ self .assertRegex ( module_str , "util.global.store %.* , @_state0.seq.2" )
294
285
295
286
def testTensorUpdateGlobal (self ):
296
287
state_example = torch .randn (5 , 20 )
@@ -305,9 +296,9 @@ def tensor_update_state(self, update=abstractify(update_example)):
305
296
inst = UpdateState (context = Context ())
306
297
module_str = str (CompiledModule .get_mlir_module (inst ))
307
298
print (module_str )
308
- self .assertIn (
309
- "flow.tensor.update %arg0, %_state0.global[%c0, %c0] : tensor<1x20xf32> -> %_state0.global as tensor<5x20xf32>" ,
299
+ self .assertRegex (
310
300
module_str ,
301
+ "flow.tensor.update %.*, %_state0.global\\ [%c0, %c0\\ ] : tensor<1x20xf32> -> %_state0.global as tensor<5x20xf32>" ,
311
302
)
312
303
313
304
def testTensorUpdateGlobalReturnNone (self ):
@@ -325,10 +316,7 @@ def tensor_update_state(self, update=abstractify(update_example)):
325
316
inst = UpdateState (context = Context ())
326
317
module_str = str (CompiledModule .get_mlir_module (inst ))
327
318
print (module_str )
328
- self .assertIn (
329
- "flow.tensor.update %arg0, %_state0.global[%c4, %c0, %c0] : tensor<1x1x4xf32> -> %_state0.global as tensor<5x20x4xf32>" ,
330
- module_str ,
331
- )
319
+ self .assertIn ("flow.tensor.update" , module_str )
332
320
333
321
def testExternalGlobalParametersDefaults (self ):
334
322
m = SimpleParams ()
0 commit comments