Skip to content

Commit 10c30ed

Browse files
danielsuoGoogle-ML-Automation
authored andcommitted
Rename backend.compile to backend.compile_and_load.
Part of a larger refactor. Today, `compile` returns a loaded executable i.e., fuses the compile and load functions. Eventually, `compile` should return an unloaded executable and `load` should return a loaded exectuable; the default jit path will still return a loaded executable. PiperOrigin-RevId: 761951858
1 parent a827a27 commit 10c30ed

File tree

3 files changed

+60
-12
lines changed

3 files changed

+60
-12
lines changed

jax/_src/compiler.py

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@
3434
from jax._src import profiler
3535
from jax._src import traceback_util
3636
from jax._src.interpreters import mlir
37+
from jax._src.lib import jaxlib_extension_version
3738
from jax._src.lib import xla_client as xc
39+
from jax._src.lib import _jax
3840
from jax._src.lib.mlir import ir
3941
import numpy as np
4042

@@ -291,6 +293,19 @@ def backend_compile(
291293
executable_devices: xc.DeviceList,
292294
options: xc.CompileOptions,
293295
host_callbacks: Sequence[Any],
296+
) -> xc.LoadedExecutable:
297+
return backend_compile_and_load(
298+
backend, module, executable_devices, options, host_callbacks
299+
)
300+
301+
302+
@profiler.annotate_function
303+
def backend_compile_and_load(
304+
backend: xc.Client,
305+
module: ir.Module,
306+
executable_devices: xc.DeviceList,
307+
options: xc.CompileOptions,
308+
host_callbacks: Sequence[Any],
294309
) -> xc.LoadedExecutable:
295310
sym_name = module.operation.attributes['sym_name']
296311
module_name = ir.StringAttr(sym_name).value
@@ -315,18 +330,40 @@ def backend_compile(
315330
try:
316331
# we use a separate function call to ensure that XLA compilation appears
317332
# separately in Python profiling results
318-
if host_callbacks:
333+
# TODO(dsuo): Simplify this logic once backend_compile actually returns an
334+
# unloaded executable.
335+
if jaxlib_extension_version < 343 or (
336+
jaxlib_extension_version >= 343
337+
and isinstance(backend, _jax.CompileOnlyPyClient)
338+
):
339+
if host_callbacks:
340+
return backend.compile(
341+
built_c,
342+
executable_devices=executable_devices, # type: ignore
343+
compile_options=options,
344+
host_callbacks=host_callbacks,
345+
)
346+
# Some backends don't have `host_callbacks` option yet
347+
# TODO(sharadmv): remove this fallback when all backends allow `compile`
348+
# to take in `host_callbacks`
319349
return backend.compile(
350+
built_c, executable_devices=executable_devices, compile_options=options) # type: ignore
351+
else:
352+
if host_callbacks:
353+
return backend.compile_and_load(
354+
built_c,
355+
executable_devices=executable_devices,
356+
compile_options=options,
357+
host_callbacks=host_callbacks,
358+
)
359+
# Some backends don't have `host_callbacks` option yet
360+
# TODO(sharadmv): remove this fallback when all backends allow `compile`
361+
# to take in `host_callbacks`
362+
return backend.compile_and_load(
320363
built_c,
321-
executable_devices=executable_devices, # type: ignore
364+
executable_devices=executable_devices,
322365
compile_options=options,
323-
host_callbacks=host_callbacks,
324366
)
325-
# Some backends don't have `host_callbacks` option yet
326-
# TODO(sharadmv): remove this fallback when all backends allow `compile`
327-
# to take in `host_callbacks`
328-
return backend.compile(
329-
built_c, executable_devices=executable_devices, compile_options=options) # type: ignore
330367
except xc.XlaRuntimeError as e:
331368
for error_handler in _XLA_RUNTIME_ERROR_HANDLERS:
332369
handler_result = error_handler(e)
@@ -391,7 +428,7 @@ def compile_or_get_cached(
391428
)
392429

393430
if cache_key is None:
394-
return backend_compile(
431+
return backend_compile_and_load(
395432
backend, computation, executable_devices, compile_options,
396433
host_callbacks)
397434

@@ -419,7 +456,7 @@ def compile_or_get_cached(
419456
config.share_binary_between_hosts.value
420457
and is_multi_process
421458
and distributed.global_state.client is not None
422-
# Host callbacks are currently baked into the HLO module so we cant share
459+
# Host callbacks are currently baked into the HLO module so we can't share
423460
# them.
424461
and len(host_callbacks) == 0
425462
):
@@ -705,7 +742,7 @@ def _compile_and_write_cache(
705742
cache_key: str,
706743
) -> xc.LoadedExecutable:
707744
start_time = time.monotonic()
708-
executable = backend_compile(
745+
executable = backend_compile_and_load(
709746
backend, computation, executable_devices, compile_options, host_callbacks
710747
)
711748
compile_time = time.monotonic() - start_time

jaxlib/_jax/__init__.pyi

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,17 @@ class Client:
551551
) -> PjRtLayout: ...
552552
def __getattr__(self, name: str) -> Any: ...
553553

554+
555+
class CompileOnlyPyClient(Client):
556+
def compile(
557+
self,
558+
computation: str | bytes,
559+
executable_devices: DeviceList | Sequence[Device],
560+
compile_options: CompileOptions = ...,
561+
host_callbacks: Sequence[Any] = ...,
562+
) -> LoadedExecutable: ...
563+
564+
554565
class CpuCollectives: ...
555566

556567
def make_gloo_tcp_collectives(

jaxlib/xla_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343

4444
# Just an internal arbitrary increasing number to help with backward-compatible
4545
# changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version.
46-
_version = 342
46+
_version = 343
4747

4848
# An internal increasing version number for protecting jaxlib code against
4949
# ifrt changes.

0 commit comments

Comments
 (0)