Skip to content

Commit

Permalink
Further changes to use jax.extend.core instead of jax.core.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 705429546
  • Loading branch information
tomhennigan authored and copybara-github committed Dec 12, 2024
1 parent cfe8480 commit b3541fa
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 6 deletions.
8 changes: 5 additions & 3 deletions haiku/_src/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -342,13 +342,13 @@ hk_py_library(
name = "dot",
srcs = ["dot.py"],
deps = [
":config",
":data_structures",
":module",
":utils",
# pip: jax
# pip: jax:extend
# pip: tree
# pip: jax/extend:core
# pip: tree # build_cleaner: keep
],
)

Expand Down Expand Up @@ -579,14 +579,14 @@ hk_py_test(
name = "jaxpr_info_test",
srcs = ["jaxpr_info_test.py"],
deps = [
":config",
":conv",
":jaxpr_info",
":module",
":transform",
# pip: absl/logging
# pip: absl/testing:absltest
# pip: jax
# pip: jax/extend:core
# pip: numpy
],
)
Expand Down Expand Up @@ -870,6 +870,7 @@ hk_py_test(
# pip: absl/testing:absltest
# pip: absl/testing:parameterized
# pip: jax
# pip: jax/extend:core
# pip: numpy
],
)
Expand All @@ -894,6 +895,7 @@ hk_py_test(
# pip: absl/testing:absltest
# pip: absl/testing:parameterized
# pip: jax
# pip: jax/extend:core
# pip: numpy
],
)
Expand Down
3 changes: 2 additions & 1 deletion haiku/_src/dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import jax
import jax.core
from jax.experimental import pjit
from jax.extend import core as jax_core
from jax.extend import linear_util as lu


Expand Down Expand Up @@ -204,7 +205,7 @@ def process_primitive(self, primitive, tracers, params):
vals = [self.to_val(t) for t in tracers]
val_out = primitive.bind_with_trace(self.parent_trace, vals, params)
if primitive is pjit.pjit_p:
f = jax.core.jaxpr_as_fun(params['jaxpr'])
f = jax_core.jaxpr_as_fun(params['jaxpr'])
f.__name__ = params['name']
fun = lu.wrap_init(f)
return self.process_call(primitive, fun, tracers, params)
Expand Down
3 changes: 2 additions & 1 deletion haiku/_src/jaxpr_info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from haiku._src import module
from haiku._src import transform
import jax
from jax.extend import core as jax_core
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -62,7 +63,7 @@ def add(x, y):

def test_compute_flops(self):

def _compute_flops(eqn: jax.core.JaxprEqn,
def _compute_flops(eqn: jax_core.JaxprEqn,
expression: jaxpr_info.Expression) -> int:
del expression
return max(np.prod(var.aval.shape) for var in eqn.invars) # pytype: disable=attribute-error
Expand Down
3 changes: 2 additions & 1 deletion haiku/_src/stateful_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from haiku._src import transform

import jax
from jax.extend import core as jax_core
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -876,7 +877,7 @@ def b_impl(x):
backward()
return (x,)

prim = jax.core.Primitive("hk_callback")
prim = jax_core.Primitive("hk_callback")
prim.def_impl(f_impl)
prim.def_abstract_eval(f_impl)
jax.interpreters.ad.deflinear(prim, b_impl)
Expand Down

0 comments on commit b3541fa

Please sign in to comment.