Skip to content

Commit 971231c

Browse files
Bump IREE to 20240226.813 and adapt to API breaks. (#482)
This release pulls in iree-org/iree#16486 which makes substantial changes to torch imports: * Always generates async code (with a special `$async` suffixed entrypoint that the default entrypoint delegates to). * Internal structure of the generated code is different, invalidating some tests.
1 parent 18262d4 commit 971231c

File tree

6 files changed

+30
-63
lines changed

6 files changed

+30
-63
lines changed

core/iree-requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
iree-compiler==20240215.802
2-
iree-runtime==20240215.802
1+
iree-compiler==20240226.813
2+
iree-runtime==20240226.813

core/shark_turbine/dynamo/backends/cpu.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,7 @@
3939
from torch._dynamo.backends.common import aot_autograd
4040
from ..passes import turbine_cpu_pass_pipeline
4141

42-
DEFAULT_COMPILER_FLAGS = (
43-
# Enable asynchronous calling convention.
44-
# TODO: Enable async execution mode.
45-
# "--iree-execution-model=async-external",
46-
"--iree-input-type=tm_tensor",
47-
)
42+
DEFAULT_COMPILER_FLAGS = ("--iree-input-type=torch",)
4843

4944

5045
def _base_backend(gm: torch.fx.GraphModule, example_inputs):

core/shark_turbine/dynamo/executor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def __init__(
153153
self,
154154
user_module: VmModule,
155155
device_state: DeviceState,
156-
entry_name: str = "main",
156+
entry_name: str = "main$async",
157157
):
158158
self.user_module = user_module
159159
self.vm_context = VmContext(

core/shark_turbine/dynamo/tensor.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,7 @@
5353

5454
from ..importers.fx_importer import FxImporter
5555

56-
DEFAULT_COMPILER_FLAGS = (
57-
# Enable asynchronous calling convention.
58-
"--iree-execution-model=async-external",
59-
"--iree-input-type=torch",
60-
)
56+
DEFAULT_COMPILER_FLAGS = ("--iree-input-type=torch",)
6157

6258
###############################################################################
6359
# Factories and device enablement

core/tests/aot/args_test.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -24,33 +24,24 @@ def foobar(self, a=AbstractTensor(3, 2), b=AbstractTensor(1, 1)):
2424
module_str = str(CompiledModule.get_mlir_module(inst))
2525
print(module_str)
2626
self.assertIn(
27-
"func.func @foobar(%arg0: tensor<3x2xf32>, %arg1: tensor<1x1xf32>) -> (tensor<1x1xf32>, tensor<3x2xf32>)",
27+
"util.func public @foobar$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.fence, %arg3: !hal.fence) -> (!hal.buffer_view, !hal.buffer_view)",
2828
module_str,
2929
)
30-
self.assertIn("return %arg1, %arg0", module_str)
3130

3231
def testProcToJitArgs(self):
33-
class ProcArgsModule(CompiledModule):
32+
class testProcToJitArgs(CompiledModule):
3433
def foobar(self, a=AbstractTensor(3, 2), b=AbstractTensor(1, 1)):
3534
return self.compute(a, b)
3635

3736
@jittable
3837
def compute(a, b):
3938
return a + b
4039

41-
inst = ProcArgsModule(context=Context())
40+
inst = testProcToJitArgs(context=Context())
4241
module_str = str(CompiledModule.get_mlir_module(inst))
4342
print(module_str)
4443
self.assertIn(
45-
"func.func @foobar(%arg0: tensor<3x2xf32>, %arg1: tensor<1x1xf32>) -> tensor<3x2xf32>",
46-
module_str,
47-
)
48-
self.assertIn(
49-
"func.func private @compute(%arg0: tensor<3x2xf32>, %arg1: tensor<1x1xf32>) -> tensor<3x2xf32>",
50-
module_str,
51-
)
52-
self.assertIn(
53-
"%0 = call @compute(%arg0, %arg1)",
44+
"linalg.generic",
5445
module_str,
5546
)
5647

@@ -68,13 +59,10 @@ def compute(a, b):
6859
inst = ProcArgsModule(context=Context())
6960
module_str = str(CompiledModule.get_mlir_module(inst))
7061
print(module_str)
71-
self.assertIn(
72-
"%0 = call @compute(%arg0, %arg1)",
73-
module_str,
74-
)
75-
self.assertIn(
76-
"%1 = call @compute$1(%0, %arg0)",
77-
module_str,
62+
self.assertEqual(
63+
2,
64+
module_str.count("linalg.generic"),
65+
msg=f"Did not find two linalg.generics in module: module_str",
7866
)
7967

8068

core/tests/aot/globals_test.py

Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def forward(self, x):
2626
return self.classifier(x)
2727

2828

29-
class ArgsTest(unittest.TestCase):
29+
class GlobalsTest(unittest.TestCase):
3030
def testGlobalParameters(self):
3131
m = SimpleParams()
3232

@@ -63,10 +63,6 @@ def read_params(self):
6363
"%_params.classifier.bias = util.global.load @_params.classifier.bias",
6464
module_str,
6565
)
66-
self.assertIn(
67-
"return %_params.classifier.weight, %_params.classifier.bias",
68-
module_str,
69-
)
7066

7167
def testGlobalLoadFromPyLeaf(self):
7268
m = SimpleParams()
@@ -84,7 +80,6 @@ def read_weight(self):
8480
"%_params.classifier.weight = util.global.load @_params.classifier.weight",
8581
module_str,
8682
)
87-
self.assertIn("return %_params.classifier.weight", module_str)
8883

8984
def testGlobalStoreFromPyTree(self):
9085
m = SimpleParams()
@@ -100,8 +95,10 @@ def update_params(me, updates=abstractify(params)):
10095
inst = GlobalModule(context=Context())
10196
module_str = str(CompiledModule.get_mlir_module(inst))
10297
print(module_str)
103-
self.assertIn("util.global.store %arg0, @_params.classifier.weight", module_str)
104-
self.assertIn("util.global.store %arg1, @_params.classifier.bias", module_str)
98+
self.assertRegex(
99+
module_str, "util.global.store %.*, @_params.classifier.weight"
100+
)
101+
self.assertRegex(module_str, "util.global.store %.*, @_params.classifier.bias")
105102

106103
def testGlobalStoreFromLeaf(self):
107104
m = SimpleParams()
@@ -115,7 +112,7 @@ def update_bias(self, new_bias=abstractify(params["classifier.bias"])):
115112
inst = GlobalModule(context=Context())
116113
module_str = str(CompiledModule.get_mlir_module(inst))
117114
print(module_str)
118-
self.assertIn("util.global.store %arg0, @_params.classifier.bias", module_str)
115+
self.assertRegex(module_str, "util.global.store %.*, @_params.classifier.bias")
119116

120117
def testExportSingleGlobalTensor(self):
121118
state_example = torch.randn(3, 11)
@@ -131,7 +128,6 @@ def read_state(self):
131128
print(module_str)
132129
self.assertIn("util.global private @_state0.global", module_str)
133130
self.assertIn("%_state0.global = util.global.load @_state0.global", module_str)
134-
self.assertIn("return %_state0.global", module_str)
135131

136132
def testExportTreeGlobalTensors(self):
137133
state_example = {
@@ -160,10 +156,6 @@ def read_state(self):
160156
self.assertIn("%_state0.seq.0 = util.global.load @_state0.seq.0", module_str)
161157
self.assertIn("%_state0.seq.1 = util.global.load @_state0.seq.1", module_str)
162158
self.assertIn("%_state0.seq.2 = util.global.load @_state0.seq.2", module_str)
163-
self.assertIn(
164-
"return %_state0.data, %_state0.seq.0, %_state0.seq.1, %_state0.seq.2",
165-
module_str,
166-
)
167159

168160
def testExportGlobalScalars(self):
169161
class ScalarState(CompiledModule):
@@ -210,9 +202,6 @@ class DerivedState(BaseState):
210202
print(module_str)
211203
self.assertIn("@_state_index.global {noinline} = 0 : index", module_str)
212204
self.assertIn("@_state_f32.global {noinline} = 0.000000e+00 : f32", module_str)
213-
self.assertIn(
214-
"return %_state_index.global, %_state_f32.global : index, f32", module_str
215-
)
216205

217206
def testInheritOverrideBase(self):
218207
class BaseState(CompiledModule):
@@ -252,8 +241,10 @@ class DerivedModule(BaseModule):
252241
inst = DerivedModule(context=Context())
253242
module_str = str(CompiledModule.get_mlir_module(inst))
254243
print(module_str)
255-
self.assertIn("util.global.store %arg0, @_params.classifier.weight", module_str)
256-
self.assertIn("util.global.store %arg1, @_params.classifier.bias", module_str)
244+
self.assertRegex(
245+
module_str, "util.global.store %.*, @_params.classifier.weight"
246+
)
247+
self.assertRegex(module_str, "util.global.store %.*, @_params.classifier.bias")
257248

258249
def testUpdateGlobalStateTree(self):
259250
state_example = {
@@ -287,10 +278,10 @@ def read_state(self, updates=abstractify(state_example)):
287278
module_str,
288279
)
289280
self.assertIn("util.global private mutable @_state0.data", module_str)
290-
self.assertIn("util.global.store %arg0, @_state0.data", module_str)
291-
self.assertIn("util.global.store %arg1, @_state0.seq.0", module_str)
292-
self.assertIn("util.global.store %arg2, @_state0.seq.1", module_str)
293-
self.assertIn("util.global.store %arg3, @_state0.seq.2", module_str)
281+
self.assertRegex(module_str, "util.global.store %.*, @_state0.data")
282+
self.assertRegex(module_str, "util.global.store %.*, @_state0.seq.0")
283+
self.assertRegex(module_str, "util.global.store %.*, @_state0.seq.1")
284+
self.assertRegex(module_str, "util.global.store %.*, @_state0.seq.2")
294285

295286
def testTensorUpdateGlobal(self):
296287
state_example = torch.randn(5, 20)
@@ -305,9 +296,9 @@ def tensor_update_state(self, update=abstractify(update_example)):
305296
inst = UpdateState(context=Context())
306297
module_str = str(CompiledModule.get_mlir_module(inst))
307298
print(module_str)
308-
self.assertIn(
309-
"flow.tensor.update %arg0, %_state0.global[%c0, %c0] : tensor<1x20xf32> -> %_state0.global as tensor<5x20xf32>",
299+
self.assertRegex(
310300
module_str,
301+
"flow.tensor.update %.*, %_state0.global\\[%c0, %c0\\] : tensor<1x20xf32> -> %_state0.global as tensor<5x20xf32>",
311302
)
312303

313304
def testTensorUpdateGlobalReturnNone(self):
@@ -325,10 +316,7 @@ def tensor_update_state(self, update=abstractify(update_example)):
325316
inst = UpdateState(context=Context())
326317
module_str = str(CompiledModule.get_mlir_module(inst))
327318
print(module_str)
328-
self.assertIn(
329-
"flow.tensor.update %arg0, %_state0.global[%c4, %c0, %c0] : tensor<1x1x4xf32> -> %_state0.global as tensor<5x20x4xf32>",
330-
module_str,
331-
)
319+
self.assertIn("flow.tensor.update", module_str)
332320

333321
def testExternalGlobalParametersDefaults(self):
334322
m = SimpleParams()

0 commit comments

Comments
 (0)