Skip to content

Commit cee3bab

Browse files
danielsuoGoogle-ML-Automation
authored andcommitted
Reverts 12e07c9
PiperOrigin-RevId: 761946361
1 parent 7014bde commit cee3bab

File tree

4 files changed

+54
-11
lines changed

4 files changed

+54
-11
lines changed

jax/_src/compiler.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,19 @@ def backend_compile(
291291
executable_devices: xc.DeviceList,
292292
options: xc.CompileOptions,
293293
host_callbacks: Sequence[Any],
294+
) -> xc.LoadedExecutable:
295+
return backend_compile_and_load(
296+
backend, module, executable_devices, options, host_callbacks
297+
)
298+
299+
300+
@profiler.annotate_function
301+
def backend_compile_and_load(
302+
backend: xc.Client,
303+
module: ir.Module,
304+
executable_devices: xc.DeviceList,
305+
options: xc.CompileOptions,
306+
host_callbacks: Sequence[Any],
294307
) -> xc.LoadedExecutable:
295308
sym_name = module.operation.attributes['sym_name']
296309
module_name = ir.StringAttr(sym_name).value
@@ -315,18 +328,35 @@ def backend_compile(
315328
try:
316329
# we use a separate function call to ensure that XLA compilation appears
317330
# separately in Python profiling results
318-
if host_callbacks:
331+
elif jaxlib_extension_version < 342 or isinstance(backend, xc.CompileOnlyPyClient):
332+
if host_callbacks:
333+
return backend.compile(
334+
built_c,
335+
executable_devices=executable_devices, # type: ignore
336+
compile_options=options,
337+
host_callbacks=host_callbacks,
338+
)
339+
# Some backends don't have `host_callbacks` option yet
340+
# TODO(sharadmv): remove this fallback when all backends allow `compile`
341+
# to take in `host_callbacks`
319342
return backend.compile(
343+
built_c, executable_devices=executable_devices, compile_options=options) # type: ignore
344+
else:
345+
if host_callbacks:
346+
return backend.compile_and_load(
347+
built_c,
348+
executable_devices=executable_devices,
349+
compile_options=options,
350+
host_callbacks=host_callbacks,
351+
)
352+
# Some backends don't have `host_callbacks` option yet
353+
# TODO(sharadmv): remove this fallback when all backends allow `compile`
354+
# to take in `host_callbacks`
355+
return backend.compile_and_load(
320356
built_c,
321-
executable_devices=executable_devices, # type: ignore
357+
executable_devices=executable_devices,
322358
compile_options=options,
323-
host_callbacks=host_callbacks,
324359
)
325-
# Some backends don't have `host_callbacks` option yet
326-
# TODO(sharadmv): remove this fallback when all backends allow `compile`
327-
# to take in `host_callbacks`
328-
return backend.compile(
329-
built_c, executable_devices=executable_devices, compile_options=options) # type: ignore
330360
except xc.XlaRuntimeError as e:
331361
for error_handler in _XLA_RUNTIME_ERROR_HANDLERS:
332362
handler_result = error_handler(e)
@@ -391,7 +421,7 @@ def compile_or_get_cached(
391421
)
392422

393423
if cache_key is None:
394-
return backend_compile(
424+
return backend_compile_and_load(
395425
backend, computation, executable_devices, compile_options,
396426
host_callbacks)
397427

@@ -419,7 +449,7 @@ def compile_or_get_cached(
419449
config.share_binary_between_hosts.value
420450
and is_multi_process
421451
and distributed.global_state.client is not None
422-
# Host callbacks are currently baked into the HLO module so we cant share
452+
# Host callbacks are currently baked into the HLO module so we can't share
423453
# them.
424454
and len(host_callbacks) == 0
425455
):
@@ -705,7 +735,7 @@ def _compile_and_write_cache(
705735
cache_key: str,
706736
) -> xc.LoadedExecutable:
707737
start_time = time.monotonic()
708-
executable = backend_compile(
738+
executable = backend_compile_and_load(
709739
backend, computation, executable_devices, compile_options, host_callbacks
710740
)
711741
compile_time = time.monotonic() - start_time

jaxlib/_jax/__init__.pyi

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,17 @@ class Client:
551551
) -> PjRtLayout: ...
552552
def __getattr__(self, name: str) -> Any: ...
553553

554+
555+
class CompileOnlyPyClient(Client):
556+
def compile(
557+
self,
558+
computation: str | bytes,
559+
executable_devices: DeviceList | Sequence[Device],
560+
compile_options: CompileOptions = ...,
561+
host_callbacks: Sequence[Any] = ...,
562+
) -> LoadedExecutable: ...
563+
564+
554565
class CpuCollectives: ...
555566

556567
def make_gloo_tcp_collectives(

jaxlib/xla_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,7 @@ def computation_count():
304304

305305
XlaComputation = _xla.XlaComputation
306306
Client = _xla.Client
307+
CompileOnlyPyClient = _xla.CompileOnlyPyClient
307308
Memory = _xla.Memory
308309
Array = _xla.Array
309310
ArrayImpl = _xla.ArrayImpl

jaxlib/xla_client.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ from jaxlib._jax import ArrayCopySemantics as ArrayCopySemantics
2424
from jaxlib._jax import ArrayImpl as ArrayImpl
2525
from jaxlib._jax import AutotuneCacheMode as AutotuneCacheMode
2626
from jaxlib._jax import Client as Client
27+
from jaxlib._jax import CompileOnlyPyClient as CompileOnlyPyClient
2728
from jaxlib._jax import CompileOptions as CompileOptions
2829
from jaxlib._jax import Device as Device
2930
from jaxlib._jax import DeviceAssignment as DeviceAssignment

0 commit comments

Comments
 (0)