34
34
from jax ._src import profiler
35
35
from jax ._src import traceback_util
36
36
from jax ._src .interpreters import mlir
37
+ from jax ._src .lib import jaxlib_extension_version
37
38
from jax ._src .lib import xla_client as xc
39
+ from jax ._src .lib import _jax
38
40
from jax ._src .lib .mlir import ir
39
41
import numpy as np
40
42
@@ -291,6 +293,19 @@ def backend_compile(
291
293
executable_devices : xc .DeviceList ,
292
294
options : xc .CompileOptions ,
293
295
host_callbacks : Sequence [Any ],
296
+ ) -> xc .LoadedExecutable :
297
+ return backend_compile_and_load (
298
+ backend , module , executable_devices , options , host_callbacks
299
+ )
300
+
301
+
302
+ @profiler .annotate_function
303
+ def backend_compile_and_load (
304
+ backend : xc .Client ,
305
+ module : ir .Module ,
306
+ executable_devices : xc .DeviceList ,
307
+ options : xc .CompileOptions ,
308
+ host_callbacks : Sequence [Any ],
294
309
) -> xc .LoadedExecutable :
295
310
sym_name = module .operation .attributes ['sym_name' ]
296
311
module_name = ir .StringAttr (sym_name ).value
@@ -315,18 +330,40 @@ def backend_compile(
315
330
try :
316
331
# we use a separate function call to ensure that XLA compilation appears
317
332
# separately in Python profiling results
318
- if host_callbacks :
333
+ # TODO(dsuo): Simplify this logic once backend_compile actually returns an
334
+ # unloaded executable.
335
+ if jaxlib_extension_version < 343 or (
336
+ jaxlib_extension_version >= 343
337
+ and isinstance (backend , _jax .CompileOnlyPyClient )
338
+ ):
339
+ if host_callbacks :
340
+ return backend .compile (
341
+ built_c ,
342
+ executable_devices = executable_devices , # type: ignore
343
+ compile_options = options ,
344
+ host_callbacks = host_callbacks ,
345
+ )
346
+ # Some backends don't have `host_callbacks` option yet
347
+ # TODO(sharadmv): remove this fallback when all backends allow `compile`
348
+ # to take in `host_callbacks`
319
349
return backend .compile (
350
+ built_c , executable_devices = executable_devices , compile_options = options ) # type: ignore
351
+ else :
352
+ if host_callbacks :
353
+ return backend .compile_and_load (
354
+ built_c ,
355
+ executable_devices = executable_devices ,
356
+ compile_options = options ,
357
+ host_callbacks = host_callbacks ,
358
+ )
359
+ # Some backends don't have `host_callbacks` option yet
360
+ # TODO(sharadmv): remove this fallback when all backends allow `compile_and_load`
361
+ # to take in `host_callbacks`
362
+ return backend .compile_and_load (
320
363
built_c ,
321
- executable_devices = executable_devices , # type: ignore
364
+ executable_devices = executable_devices ,
322
365
compile_options = options ,
323
- host_callbacks = host_callbacks ,
324
366
)
325
- # Some backends don't have `host_callbacks` option yet
326
- # TODO(sharadmv): remove this fallback when all backends allow `compile`
327
- # to take in `host_callbacks`
328
- return backend .compile (
329
- built_c , executable_devices = executable_devices , compile_options = options ) # type: ignore
330
367
except xc .XlaRuntimeError as e :
331
368
for error_handler in _XLA_RUNTIME_ERROR_HANDLERS :
332
369
handler_result = error_handler (e )
@@ -391,7 +428,7 @@ def compile_or_get_cached(
391
428
)
392
429
393
430
if cache_key is None :
394
- return backend_compile (
431
+ return backend_compile_and_load (
395
432
backend , computation , executable_devices , compile_options ,
396
433
host_callbacks )
397
434
@@ -419,7 +456,7 @@ def compile_or_get_cached(
419
456
config .share_binary_between_hosts .value
420
457
and is_multi_process
421
458
and distributed .global_state .client is not None
422
- # Host callbacks are currently baked into the HLO module so we cant share
459
+ # Host callbacks are currently baked into the HLO module so we can't share
423
460
# them.
424
461
and len (host_callbacks ) == 0
425
462
):
@@ -705,7 +742,7 @@ def _compile_and_write_cache(
705
742
cache_key : str ,
706
743
) -> xc .LoadedExecutable :
707
744
start_time = time .monotonic ()
708
- executable = backend_compile (
745
+ executable = backend_compile_and_load (
709
746
backend , computation , executable_devices , compile_options , host_callbacks
710
747
)
711
748
compile_time = time .monotonic () - start_time
0 commit comments