Skip to content

Commit

Permalink
Migrate from jax.core to jax.extend.core for several deprecated symbols
Browse files Browse the repository at this point in the history
A number of symbols from jax.core are deprecated as of recent JAX releases; some of them are newly available in jax.extend.core.

PiperOrigin-RevId: 707926331
  • Loading branch information
Jake VanderPlas authored and KfacJaxDev committed Dec 19, 2024
1 parent aaf3064 commit 729bb56
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 74 deletions.
7 changes: 4 additions & 3 deletions kfac_jax/_src/curvature_blocks/curvature_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Any, Sequence

import jax
import jax.extend as jex
import jax.numpy as jnp
import jax.scipy
from kfac_jax._src import layers_and_loss_tags as tags
Expand Down Expand Up @@ -81,22 +82,22 @@ def name(self) -> str:

@property
def layer_tag_primitive(self) -> tags.LayerTag:
"""The :class:`jax.core.Primitive` corresponding to the block's tag equation."""
"""The :class:`jex.core.Primitive` corresponding to the block's tag equation."""

primitive = self._layer_tag_eq.primitive
assert isinstance(primitive, tgm.tags.LayerTag)

return primitive

@property
def parameter_variables(self) -> tuple[jax.core.Var, ...]:
def parameter_variables(self) -> tuple[jex.core.Var, ...]:
"""The parameter variables of the underlying Jax equation."""

param_vars = []

for p in tags.layer_eqn_data(self._layer_tag_eq).params:

assert isinstance(p, jax.core.Var)
assert isinstance(p, jex.core.Var)
param_vars.append(p)

return tuple(param_vars)
Expand Down
26 changes: 13 additions & 13 deletions kfac_jax/_src/layers_and_loss_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import Any, Generic, Sequence, TypeVar

import jax
from jax import core
import jax.extend as jex


# Types for annotation
Expand Down Expand Up @@ -94,7 +94,7 @@ def get_loss_outputs(
return tuple(kwargs[name] for name in meta.parameter_dependants)


class LossTag(core.Primitive):
class LossTag(jex.core.Primitive):
"""A Jax primitive for tagging K-FAC losses.
The primitive is no-op at runtime, however its goal is to tag (annotate) the
Expand All @@ -103,7 +103,7 @@ class LossTag(core.Primitive):
curvature matrix.
"""

# Whether the primitive returns multiple outputs (from core.Primitive)
# Whether the primitive returns multiple outputs (from jex.core.Primitive)
multiple_results = True

def __init__(self):
Expand Down Expand Up @@ -175,9 +175,9 @@ def _batching(


def loss_eqn_parameter_dependants(
eqn: jax.core.JaxprEqn,
eqn: jex.core.JaxprEqn,
raise_an_error: bool = True,
) -> list[jax.core.Var]:
) -> list[jex.core.Var]:
"""Returns the parameter dependants variables from the give loss equation."""
if not isinstance(eqn.primitive, LossTag):
if raise_an_error:
Expand All @@ -192,7 +192,7 @@ def loss_eqn_parameter_dependants(


def loss_eqn_construct_loss(
eqn: jax.core.JaxprEqn,
eqn: jex.core.JaxprEqn,
*args: Array,
) -> Any:
"""Constructs an instance of the corresponding :class:`~LossFunction` class."""
Expand All @@ -206,7 +206,7 @@ def loss_eqn_construct_loss(
return meta.loss_class(**kwargs)


def loss_eqn_class_name(eqn: jax.core.JaxprEqn) -> str:
def loss_eqn_class_name(eqn: jex.core.JaxprEqn) -> str:
"""The name of the underlying `~LossFunction` class."""

if not isinstance(eqn.primitive, LossTag):
Expand Down Expand Up @@ -253,7 +253,7 @@ def get_and_verify_layer_meta(
return meta


class LayerTag(core.Primitive):
class LayerTag(jex.core.Primitive):
"""A Jax primitive for tagging K-FAC layers.
The primitive is no-op at runtime, however its goal is to tag (annotate) the
Expand Down Expand Up @@ -347,9 +347,9 @@ def _batching(


def layer_eqn_data( # pytype: disable=invalid-annotation
eqn: jax.core.JaxprEqn,
eqn: jex.core.JaxprEqn,
raise_an_error: bool = True,
) -> LayerData[jax.core.Var]:
) -> LayerData[jex.core.Var]:

if isinstance(eqn.primitive, LayerTag):
return eqn.primitive.layer_data(eqn.invars, eqn.params, str(eqn))
Expand All @@ -360,7 +360,7 @@ def layer_eqn_data( # pytype: disable=invalid-annotation
return LayerData(inputs=(), outputs=(), params=())


def layer_eqn_name(eqn: jax.core.JaxprEqn) -> str:
def layer_eqn_name(eqn: jex.core.JaxprEqn) -> str:
meta = get_and_verify_layer_meta(eqn.invars, eqn.params)
if meta.name is None:
raise ValueError("Layer name must be provided at this stage.")
Expand Down Expand Up @@ -460,11 +460,11 @@ def register_scale_and_shift(
)


class LossTagEqn(core.JaxprEqn):
class LossTagEqn(jex.core.JaxprEqn):
"""A class used only for annotation purposes."""
primitive: LossTag


class LayerTagEqn(core.JaxprEqn):
class LayerTagEqn(jex.core.JaxprEqn):
"""A class used only for annotation purposes."""
primitive: LayerTag
61 changes: 26 additions & 35 deletions kfac_jax/_src/tag_graph_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,7 @@
from absl import logging
import immutabledict
import jax

jax_version = (
jax.__version_info__ if hasattr(jax, "__version_info__")
else tuple(map(int, jax.__version__.split("."))))

if jax_version > (0, 4, 11):
import jax.extend as jax_extend # pylint: disable=g-import-not-at-top
import jax.extend as jex

import jax.numpy as jnp # pylint: disable=g-import-not-at-top
from kfac_jax._src import layers_and_loss_tags as tags
Expand All @@ -42,11 +36,11 @@
# Types for annotation
Array = utils.Array
PyTreeDef = utils.PyTreeDef
Var = jax.core.Var
Var = jex.core.Var
Vars = Sequence[Var]
Jaxpr = jax.core.Jaxpr
ClosedJaxpr = jax.core.ClosedJaxpr
JaxprEqn = jax.core.JaxprEqn
Jaxpr = jex.core.Jaxpr
ClosedJaxpr = jex.core.ClosedJaxpr
JaxprEqn = jex.core.JaxprEqn
JaxprEqns = Sequence[JaxprEqn]
T = TypeVar("T")
J = TypeVar("J", Jaxpr, ClosedJaxpr)
Expand All @@ -64,10 +58,7 @@ def eval_jaxpr_eqn(eqn: JaxprEqn, in_values: list[T]) -> list[T]:

subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)

if jax_version > (0, 4, 11):
user_context = jax_extend.source_info_util.user_context
else:
user_context = jax.core.source_info_util.user_context # pytype: disable=module-attr
user_context = jex.source_info_util.user_context

with user_context(eqn.source_info.traceback):
output = eqn.primitive.bind(*subfuns, *in_values, **bind_params)
Expand Down Expand Up @@ -245,9 +236,9 @@ class JaxprGraph:
it.
manual_registrations: Any layer tag equations that have been manually
registered.
jaxpr: The underlying :class:`jax.core.Jaxpr` part of ``self.closed_jaxpr``.
jaxpr: The underlying :class:`Jaxpr` part of ``self.closed_jaxpr``.
consts: The underlying constants part ``self.closed_jaxpr``.
outvars: The output variables of the underlying :class:`jax.core.Jaxpr` part
outvars: The output variables of the underlying :class:`Jaxpr` part
of ``self.closed_jaxpr``.
"""
name: str
Expand Down Expand Up @@ -294,7 +285,7 @@ def sub_graph_eqns(
eqns.append(next_eqn)

for v in next_eqn.invars:
if (not isinstance(v, jax.core.Literal) and v not in root_vars and
if (not isinstance(v, jex.core.Literal) and v not in root_vars and
v not in processed_vars and v in self.var_to_creation_op):
to_process_eqns.append(self.var_to_creation_op[v])
processed_vars.add(v)
Expand Down Expand Up @@ -383,7 +374,7 @@ def make_jax_graph(
eqns.append(eqn)

sub_graph_vars.update(
v for v in eqn.invars if not isinstance(v, jax.core.Literal)
v for v in eqn.invars if not isinstance(v, jex.core.Literal)
)

consts_i = [
Expand Down Expand Up @@ -461,8 +452,8 @@ class GraphPattern:
in_values_preprocessor: A function that can optionally modify the in_vals
passed to the tag_primitive, from those that are usually the input to
the jaxpr.
jaxpr: The underlying :class:`jax.core.Jaxpr` represented by the pattern.
param_vars: The list of :class:`jax.core.Var` that correspond to parameters
jaxpr: The underlying :class:`Jaxpr` represented by the pattern.
param_vars: The list of :class:`Var` that correspond to parameters
in the pattern.
graph: A :class:`JaxprGraph` representation of the pattern.
"""
Expand Down Expand Up @@ -633,7 +624,7 @@ def add_vars_if_possible(
If at least one of the pattern variables is a parameter, but the
corresponding graph variable is not or vise-versa, the method does not
update the current variables map and returns ``False``. Similarly, if at
least one of the graph variables is a :class:`~jax.core.Literal` (meaning a
least one of the graph variables is a :class:`iteral` (meaning a
constant, independent of the function inputs) and the corresponding
pattern variable is not an input to the pattern, it returns ``False``. In
all other cases it updates the map and returns ``True``.
Expand All @@ -648,12 +639,12 @@ def add_vars_if_possible(
"""
for var1, var2 in zip(eqn_vars, graph_vars):

var2_matchable = isinstance(var2, jax.core.Var) and (
var2_matchable = isinstance(var2, jex.core.Var) and (
var2 in matchable_graph_params)

if (var1 in param_variables and not var2_matchable or
var1 not in param_variables and var2_matchable or
(isinstance(var2, jax.core.Literal) and var1 not in input_vars)):
(isinstance(var2, jex.core.Literal) and var1 not in input_vars)):
return False

current_variables_map.update(zip(eqn_vars, graph_vars))
Expand Down Expand Up @@ -788,7 +779,7 @@ def match_pattern(
for k, v in match_variables_map.items():

if (k not in pattern.graph.jaxpr.invars and
not isinstance(v, jax.core.Literal)):
not isinstance(v, jex.core.Literal)):

creation_op = graph.var_to_creation_op[v]

Expand Down Expand Up @@ -883,14 +874,14 @@ def find_layer_tags_and_patterns(


def read_env(
env: dict[jax.core.Var, T],
env: dict[jex.core.Var, T],
variables: list[jax.core.Atom],
) -> list[T]:
"""Reads from the variable-to-array environment during tracing."""
result = []
assert isinstance(variables, list)
for v in variables:
if isinstance(v, jax.core.Literal):
if isinstance(v, jex.core.Literal):
# Literals are values baked into the Jaxpr
result.append(v.val)
elif isinstance(v, jax.core.DropVar):
Expand All @@ -901,8 +892,8 @@ def read_env(


def write_env(
env: dict[jax.core.Var, T],
variables: list[jax.core.Var],
env: dict[jex.core.Var, T],
variables: list[jex.core.Var],
values: list[T],
) -> None:
"""Writes to the variable-to-array environment during tracing."""
Expand Down Expand Up @@ -979,7 +970,7 @@ def clean_jaxpr(

final_outvars.append(var)

if not isinstance(var, jax.core.Literal):
if not isinstance(var, jex.core.Literal):
dependants.add(var)

for eqn in reversed(closed_jaxpr.jaxpr.eqns):
Expand Down Expand Up @@ -1035,7 +1026,7 @@ def clean_jaxpr(
if check:
eqns.append(eqn)
new_dependants = set(v for v in eqn.invars
if not isinstance(v, jax.core.Literal))
if not isinstance(v, jex.core.Literal))
dependants = dependants.union(new_dependants)

# Dependants should only be invars
Expand Down Expand Up @@ -1112,7 +1103,7 @@ def merge_broadcasts_jaxpr(jaxpr: J) -> J:

# We ignore broadcasting of constants
if (eqn.primitive.name == "broadcast_in_dim" and
not all(isinstance(v, jax.core.Literal) for v in eqn.invars)):
not all(isinstance(v, jex.core.Literal) for v in eqn.invars)):

if eqn.invars[0] in broadcasts_outputs:
# Construct a merged equation from the previous and current one
Expand All @@ -1139,7 +1130,7 @@ def merge_broadcasts_jaxpr(jaxpr: J) -> J:

else:
for v in eqn.invars:
if not isinstance(v, jax.core.Literal) and v in broadcasts_outputs:
if not isinstance(v, jex.core.Literal) and v in broadcasts_outputs:
eqns.append(broadcasts_outputs[v])

eqns.append(eqn)
Expand Down Expand Up @@ -1688,7 +1679,7 @@ def __init__(
):
self._func_graph = func_graph
self._tag_locations = tag_locations
self._flat_func = jax.core.jaxpr_as_fun(func_graph.closed_jaxpr)
self._flat_func = jex.core.jaxpr_as_fun(func_graph.closed_jaxpr)
self._param_labels = self._compute_parameter_labels()

def __call__(self, *args, **kwargs):
Expand Down Expand Up @@ -1770,7 +1761,7 @@ def _auto_register_tags(

eqns_for_registration.append(eqn)
sub_graph_vars.update(
v for v in eqn.invars if not isinstance(v, jax.core.Literal))
v for v in eqn.invars if not isinstance(v, jex.core.Literal))

eqns_for_registration = eqns_for_registration[::-1]

Expand Down
15 changes: 8 additions & 7 deletions kfac_jax/_src/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from absl import logging
import jax
import jax.extend as jex
import jax.numpy as jnp
from kfac_jax._src import layers_and_loss_tags as tags
from kfac_jax._src import loss_functions
Expand All @@ -32,7 +33,7 @@
Params = utils.Params
FuncArgs = utils.FuncArgs
FuncOuts = utils.FuncOuts
Var = jax.core.Var
Var = jex.core.Var
LossFunction = loss_functions.LossFunction
LossFunctionInputs = loss_functions.LossFunctionInputs

Expand Down Expand Up @@ -80,7 +81,7 @@ def tree_unflatten(cls, aux_data, children):
tuple[LayerVjpData[Array], ...], # pytype: disable=invalid-annotation
],
]
JaxprOrClosedJaxpr = jax.core.Jaxpr | jax.core.ClosedJaxpr
JaxprOrClosedJaxpr = jex.core.Jaxpr | jex.core.ClosedJaxpr


def shape_and_type(x: Array) -> tuple[Shape, jnp.dtype]:
Expand All @@ -99,7 +100,7 @@ def make_cache_key(


def extract_tags(
jaxpr: jax.core.Jaxpr,
jaxpr: jex.core.Jaxpr,
) -> tuple[tuple[tags.LayerTagEqn, ...], tuple[tags.LossTagEqn, ...]]:
"""Extracts the layer and the loss tags from the given Jaxpr."""

Expand Down Expand Up @@ -199,7 +200,7 @@ class ProcessedJaxpr(utils.Finalizable):

def __init__(
self,
jaxpr: jax.core.Jaxpr,
jaxpr: jex.core.Jaxpr,
consts: list[Any],
in_tree: utils.PyTreeDef,
params_index: int,
Expand Down Expand Up @@ -819,16 +820,16 @@ def forward_aux(
own_func_args = primal_func_args

# Mapping from variable -> value
env: dict[jax.core.Var, Array] = {}
env: dict[jex.core.Var, Array] = {}
read = functools.partial(tgm.read_env, env)

def write(variables: list[jax.core.Var], values: list[Array]) -> None:
def write(variables: list[jex.core.Var], values: list[Array]) -> None:
# if not isinstance(variables, list):
# variables = [variables]
tgm.write_env(env, variables, values)

for v in variables:
if not isinstance(v, jax.core.Literal) and v in aux:
if not isinstance(v, jex.core.Literal) and v in aux:
env[v] = env[v] + aux[v]

# Bind args and consts to environment
Expand Down
Loading

0 comments on commit 729bb56

Please sign in to comment.