@@ -291,6 +291,19 @@ def backend_compile(
291
291
executable_devices : xc .DeviceList ,
292
292
options : xc .CompileOptions ,
293
293
host_callbacks : Sequence [Any ],
294
+ ) -> xc .LoadedExecutable :
295
+ return backend_compile_and_load (
296
+ backend , module , executable_devices , options , host_callbacks
297
+ )
298
+
299
+
300
+ @profiler .annotate_function
301
+ def backend_compile_and_load (
302
+ backend : xc .Client ,
303
+ module : ir .Module ,
304
+ executable_devices : xc .DeviceList ,
305
+ options : xc .CompileOptions ,
306
+ host_callbacks : Sequence [Any ],
294
307
) -> xc .LoadedExecutable :
295
308
sym_name = module .operation .attributes ['sym_name' ]
296
309
module_name = ir .StringAttr (sym_name ).value
@@ -315,18 +328,35 @@ def backend_compile(
315
328
try :
316
329
# we use a separate function call to ensure that XLA compilation appears
317
330
# separately in Python profiling results
318
- if host_callbacks :
331
+ elif jaxlib_extension_version < 342 or isinstance (backend , xc .CompileOnlyPyClient ):
332
+ if host_callbacks :
333
+ return backend .compile (
334
+ built_c ,
335
+ executable_devices = executable_devices , # type: ignore
336
+ compile_options = options ,
337
+ host_callbacks = host_callbacks ,
338
+ )
339
+ # Some backends don't have `host_callbacks` option yet
340
+ # TODO(sharadmv): remove this fallback when all backends allow `compile`
341
+ # to take in `host_callbacks`
319
342
return backend .compile (
343
+ built_c , executable_devices = executable_devices , compile_options = options ) # type: ignore
344
+ else :
345
+ if host_callbacks :
346
+ return backend .compile_and_load (
347
+ built_c ,
348
+ executable_devices = executable_devices ,
349
+ compile_options = options ,
350
+ host_callbacks = host_callbacks ,
351
+ )
352
+ # Some backends don't have `host_callbacks` option yet
353
+ # TODO(sharadmv): remove this fallback when all backends allow `compile_and_load`
354
+ # to take in `host_callbacks`
355
+ return backend .compile_and_load (
320
356
built_c ,
321
- executable_devices = executable_devices , # type: ignore
357
+ executable_devices = executable_devices ,
322
358
compile_options = options ,
323
- host_callbacks = host_callbacks ,
324
359
)
325
- # Some backends don't have `host_callbacks` option yet
326
- # TODO(sharadmv): remove this fallback when all backends allow `compile`
327
- # to take in `host_callbacks`
328
- return backend .compile (
329
- built_c , executable_devices = executable_devices , compile_options = options ) # type: ignore
330
360
except xc .XlaRuntimeError as e :
331
361
for error_handler in _XLA_RUNTIME_ERROR_HANDLERS :
332
362
handler_result = error_handler (e )
@@ -391,7 +421,7 @@ def compile_or_get_cached(
391
421
)
392
422
393
423
if cache_key is None :
394
- return backend_compile (
424
+ return backend_compile_and_load (
395
425
backend , computation , executable_devices , compile_options ,
396
426
host_callbacks )
397
427
@@ -419,7 +449,7 @@ def compile_or_get_cached(
419
449
config .share_binary_between_hosts .value
420
450
and is_multi_process
421
451
and distributed .global_state .client is not None
422
- # Host callbacks are currently baked into the HLO module so we cant share
452
+ # Host callbacks are currently baked into the HLO module so we can't share
423
453
# them.
424
454
and len (host_callbacks ) == 0
425
455
):
@@ -705,7 +735,7 @@ def _compile_and_write_cache(
705
735
cache_key : str ,
706
736
) -> xc .LoadedExecutable :
707
737
start_time = time .monotonic ()
708
- executable = backend_compile (
738
+ executable = backend_compile_and_load (
709
739
backend , computation , executable_devices , compile_options , host_callbacks
710
740
)
711
741
compile_time = time .monotonic () - start_time
0 commit comments