Skip to content

Commit 7d13c56

Browse files
danielsuoGoogle-ML-Automation
authored andcommitted
[jaxlib] Add CompileOnlyPyClient to xla_client.
We have users of CompileOnlyPyClient that use `backend.compile` as we eventually intend it (i.e., return `ExecutableRef`, possibly `PyExecutable` eventually, instead of `PyLoadedExectuable`). PiperOrigin-RevId: 762440439
1 parent e989e23 commit 7d13c56

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

jaxlib/_jax/__init__.pyi

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,17 @@ class Client:
553553
) -> PjRtLayout: ...
554554
def __getattr__(self, name: str) -> Any: ...
555555

556+
557+
class CompileOnlyPyClient(Client):
558+
def compile(
559+
self,
560+
computation: str | bytes,
561+
executable_devices: DeviceList | Sequence[Device],
562+
compile_options: CompileOptions = ...,
563+
host_callbacks: Sequence[Any] = ...,
564+
) -> LoadedExecutable: ...
565+
566+
556567
class CpuCollectives: ...
557568

558569
def make_gloo_tcp_collectives(

jaxlib/xla_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343

4444
# Just an internal arbitrary increasing number to help with backward-compatible
4545
# changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version.
46-
_version = 344
46+
_version = 345
4747

4848
# An internal increasing version number for protecting jaxlib code against
4949
# ifrt changes.

0 commit comments

Comments
 (0)