Skip to content

Commit e328124

Browse files
authored
Add ability to include a plugin when creating driver (#740)
`vmfbRunner` now takes an optional `extra_plugin` argument to load an executable plugin while the driver is getting created. This option might be used, for example, when loading a vmfb that has an external dependency on a native shared library. The implementation of this new feature takes advantage of the pre-existing `iree.runtime.flags` feature and a new IREE python API function. Normally, drivers are managed in a cache. However, setting a flag to specify the plugin has no effect on existing drivers. The API now has a function for creating a driver independent of the cache, to guarantee that any flags are sure to take effect. This PR also includes a fix for the problem of the CI using old cached wheels for iree, as recommended by @monorimet. --------- Signed-off-by: Dave Liddell <[email protected]> Signed-off-by: daveliddell <[email protected]>
1 parent 815c857 commit e328124

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

.github/workflows/test_models.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ jobs:
5252
pip install --no-compile -r ${{ github.workspace }}/iree-turbine/pytorch-cpu-requirements.txt
5353
pip install --no-compile --pre --upgrade -r ${{ github.workspace }}/iree-turbine/requirements.txt
5454
pip install --no-compile --pre -e ${{ github.workspace }}/iree-turbine[testing]
55+
pip install --upgrade --pre --no-cache-dir iree-compiler iree-runtime -f https://iree.dev/pip-release-links.html
5556
pip install --no-compile --pre --upgrade -e models -r models/requirements.txt
5657
5758
- name: Show current free memory

models/turbine_models/model_runner.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,24 @@
11
import argparse
22
import sys
33
from iree import runtime as ireert
4+
from iree.runtime._binding import create_hal_driver
45

56

67
class vmfbRunner:
7-
def __init__(self, device, vmfb_path, external_weight_path=None):
8+
def __init__(self, device, vmfb_path, external_weight_path=None, extra_plugin=None):
89
flags = []
9-
haldriver = ireert.get_driver(device)
10+
11+
# If an extra plugin is requested, add a global flag to load the plugin
12+
# and create the driver using the non-caching creation function, as
13+
# the caching creation function may ignore the flag.
14+
if extra_plugin:
15+
ireert.flags.parse_flags(f"--executable_plugin={extra_plugin}")
16+
haldriver = create_hal_driver(device)
17+
18+
# No plugin requested: create the driver with the caching create
19+
# function.
20+
else:
21+
haldriver = ireert.get_driver(device)
1022
if "://" in device:
1123
try:
1224
device_idx = int(device.split("://")[-1])

0 commit comments

Comments
 (0)