From bdcc33a56b6049615301f97e68b26f7cd005c71b Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Wed, 22 Jan 2025 18:50:44 -0500 Subject: [PATCH] [nnx] improve Module docs --- .../haiku_migration_guide.rst | 1 + examples/gemma/transformer_test.py | 2 +- flax/linen/linear.py | 22 +- flax/linen/module.py | 12 +- flax/linen/stochastic.py | 4 +- flax/nnx/nn/attention.py | 2 +- flax/nnx/nn/linear.py | 34 +- flax/nnx/nn/lora.py | 4 +- flax/nnx/nn/normalization.py | 8 +- flax/nnx/nn/recurrent.py | 384 +++++++++--------- flax/nnx/nn/stochastic.py | 6 +- flax/nnx/object.py | 8 +- flax/nnx/training/metrics.py | 8 +- flax/nnx/training/optimizer.py | 10 +- flax/training/train_state.py | 4 +- tests/linen/linen_test.py | 2 +- tests/nnx/metrics_test.py | 22 +- tests/nnx/transforms_test.py | 12 +- uv.lock | 84 ++-- 19 files changed, 323 insertions(+), 306 deletions(-) diff --git a/docs/guides/converting_and_upgrading/haiku_migration_guide.rst b/docs/guides/converting_and_upgrading/haiku_migration_guide.rst index ce93f62e..cb01ad38 100644 --- a/docs/guides/converting_and_upgrading/haiku_migration_guide.rst +++ b/docs/guides/converting_and_upgrading/haiku_migration_guide.rst @@ -12,6 +12,7 @@ and highlight the differences between the two libraries. from jax import random import optax import flax.linen as nn + import haiku as hk Basic Example ----------------- diff --git a/examples/gemma/transformer_test.py b/examples/gemma/transformer_test.py index b2c852af..efb0cf22 100644 --- a/examples/gemma/transformer_test.py +++ b/examples/gemma/transformer_test.py @@ -150,7 +150,7 @@ def test_logit_softcap( all_outputs = [] for config in [config_soft_cap, config_no_soft_cap]: transformer = transformer_lib.Transformer( - config=config, rngs=nnx.Rngs(params=0) + config=config, rngs=nnx.Rngs(params=1) ) cache = transformer.init_cache( cache_size=cache_size, diff --git a/flax/linen/linear.py b/flax/linen/linear.py index babe809a..e5b5c3ba 100644 --- a/flax/linen/linear.py +++ b/flax/linen/linear.py @@ -1068,20 +1068,20 @@ class Embed(Module): >>> indices_input = jnp.array([[0, 1, 2], [-1, -2, -3]]) >>> variables = layer.init(jax.random.key(0), indices_input) >>> variables - {'params': {'embedding': Array([[-0.28884724, 0.19018005, -0.414205 ], - [-0.11768015, -0.54618824, -0.3789283 ], - [ 0.30428642, 0.49511626, 0.01706631], - [-0.0982546 , -0.43055868, 0.20654906], - [-0.688412 , -0.46882293, 0.26723292]], dtype=float32)}} + {'params': {'embedding': Array([[ 0.04396089, -0.9328513 , -0.97328115], + [ 0.41147125, 0.66334754, 0.49469155], + [ 0.09719624, 0.49861377, 0.49519277], + [-0.13316602, 0.6697022 , 0.3710195 ], + [-0.5039532 , 0.287319 , 1.4369922 ]], dtype=float32)}} >>> # get the first three and last three embeddings >>> layer.apply(variables, indices_input) - Array([[[-0.28884724, 0.19018005, -0.414205 ], - [-0.11768015, -0.54618824, -0.3789283 ], - [ 0.30428642, 0.49511626, 0.01706631]], + Array([[[ 0.04396089, -0.9328513 , -0.97328115], + [ 0.41147125, 0.66334754, 0.49469155], + [ 0.09719624, 0.49861377, 0.49519277]], - [[-0.688412 , -0.46882293, 0.26723292], - [-0.0982546 , -0.43055868, 0.20654906], - [ 0.30428642, 0.49511626, 0.01706631]]], dtype=float32) + [[-0.5039532 , 0.287319 , 1.4369922 ], + [-0.13316602, 0.6697022 , 0.3710195 ], + [ 0.09719624, 0.49861377, 0.49519277]]], dtype=float32) Attributes: num_embeddings: number of embeddings / vocab size. 
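As a quick cross-check of the regenerated doctest values above: the ``Embed`` lookup is just indexing into the ``embedding`` table, and negative indices wrap around as in NumPy. A minimal standalone sketch, using only the public ``flax.linen`` API shown in the hunk (the exact array values depend on the PRNG key):

import jax
import jax.numpy as jnp
import flax.linen as nn

layer = nn.Embed(num_embeddings=5, features=3)
indices = jnp.array([[0, 1, 2], [-1, -2, -3]])  # negative indices wrap, as in NumPy
variables = layer.init(jax.random.key(0), indices)
out = layer.apply(variables, indices)  # one 3-dim vector per index
assert out.shape == (2, 3, 3)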
diff --git a/flax/linen/module.py b/flax/linen/module.py index 52e5a059..4c78244a 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -2684,18 +2684,18 @@ def perturb( >>> variables = model.init(jax.random.key(0), x) >>> intm_grads = jax.grad(loss, argnums=0)(variables, x, y) >>> print(intm_grads['perturbations']['dense3']) - [[-1.456924 -0.44332537 0.02422847] - [-1.456924 -0.44332537 0.02422847]] + [[-0.04684732 0.06573904 -0.3194327 ] + [-0.04684732 0.06573904 -0.3194327 ]] If perturbations are not passed to ``apply``, ``perturb`` behaves like a no-op so you can easily disable the behavior when not needed:: >>> model.apply(variables, x) # works as expected - Array([[-1.0980128 , -0.67961735], - [-1.0980128 , -0.67961735]], dtype=float32) + Array([[-0.04579116, 0.50412744], + [-0.04579116, 0.50412744]], dtype=float32) >>> model.apply({'params': variables['params']}, x) # behaves like a no-op - Array([[-1.0980128 , -0.67961735], - [-1.0980128 , -0.67961735]], dtype=float32) + Array([[-0.04579116, 0.50412744], + [-0.04579116, 0.50412744]], dtype=float32) >>> intm_grads = jax.grad(loss, argnums=0)({'params': variables['params']}, x, y) >>> 'perturbations' not in intm_grads True diff --git a/flax/linen/stochastic.py b/flax/linen/stochastic.py index 629b5dcf..427d5951 100644 --- a/flax/linen/stochastic.py +++ b/flax/linen/stochastic.py @@ -47,9 +47,9 @@ class Dropout(Module): >>> x = jnp.ones((1, 3)) >>> variables = model.init(jax.random.key(0), x, train=False) # don't use dropout >>> model.apply(variables, x, train=False) # don't use dropout - Array([[-0.88686204, -0.5928178 , -0.5184689 , -0.4345976 ]], dtype=float32) + Array([[-0.17875527, 1.6255447 , -1.2431065 , -0.02554005]], dtype=float32) >>> model.apply(variables, x, train=True, rngs={'dropout': jax.random.key(1)}) # use dropout - Array([[ 0. , -1.1856356, -1.0369378, 0. ]], dtype=float32) + Array([[-0.35751054, 3.2510893 , 0. , 0. ]], dtype=float32) Attributes: rate: the dropout probability. (_not_ the keep rate!) diff --git a/flax/nnx/nn/attention.py b/flax/nnx/nn/attention.py index 185e0bd9..5bc54ad7 100644 --- a/flax/nnx/nn/attention.py +++ b/flax/nnx/nn/attention.py @@ -244,7 +244,7 @@ class MultiHeadAttention(Module): >>> assert (layer(q) == layer(q, q)).all() >>> assert (layer(q) == layer(q, q, q)).all() - Attributes: + Args: num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) should be divisible by the number of heads. in_features: int or tuple with number of input features. diff --git a/flax/nnx/nn/linear.py b/flax/nnx/nn/linear.py index 230f1d35..71bf9313 100644 --- a/flax/nnx/nn/linear.py +++ b/flax/nnx/nn/linear.py @@ -119,7 +119,7 @@ class LinearGeneral(Module): >>> y.shape (16, 4, 5) - Attributes: + Args: in_features: int or tuple with number of input features. out_features: int or tuple with number of output features. axis: int or tuple with axes to apply the transformation on. For instance, @@ -301,7 +301,7 @@ class Linear(Module): ) }) - Attributes: + Args: in_features: the number of input features. out_features: the number of output features. use_bias: whether to add a bias to the output (default: True). @@ -393,7 +393,7 @@ class Einsum(Module): >>> y.shape (16, 11, 8, 4) - Attributes: + Args: einsum_str: a string to denote the einsum equation. The equation must have exactly two operands, the lhs being the input passed in, and the rhs being the learnable kernel. Exactly one of ``einsum_str`` @@ -572,7 +572,7 @@ class Conv(Module): ... 
mask=mask, padding='VALID', rngs=rngs) >>> out = layer(x) - Attributes: + Args: in_features: int or tuple with number of input features. out_features: int or tuple with number of output features. kernel_size: shape of the convolutional kernel. For 1D convolution, @@ -823,7 +823,7 @@ class ConvTranspose(Module): ... mask=mask, padding='VALID', rngs=rngs) >>> out = layer(x) - Attributes: + Args: in_features: int or tuple with number of input features. out_features: int or tuple with number of output features. kernel_size: shape of the convolutional kernel. For 1D convolution, @@ -1065,23 +1065,23 @@ class Embed(Module): State({ 'embedding': VariableState( # 15 (60 B) type=Param, - value=Array([[-0.90411377, -0.3648777 , -1.1083648 ], - [ 0.01070483, 0.27923733, 1.7487359 ], - [ 0.59161806, 0.8660184 , 1.2838588 ], - [-0.748139 , -0.15856352, 0.06061118], - [-0.4769059 , -0.6607095 , 0.46697947]], dtype=float32) + value=Array([[ 0.57966787, -0.523274 , -0.43195742], + [-0.676289 , -0.50300646, 0.33996582], + [ 0.41796115, -0.59212935, 0.95934135], + [-1.0917838 , -0.7441663 , 0.07713798], + [-0.66570747, 0.13815777, 1.007365 ]], dtype=float32) ) }) >>> # get the first three and last three embeddings >>> indices_input = jnp.array([[0, 1, 2], [-1, -2, -3]]) >>> layer(indices_input) - Array([[[-0.90411377, -0.3648777 , -1.1083648 ], - [ 0.01070483, 0.27923733, 1.7487359 ], - [ 0.59161806, 0.8660184 , 1.2838588 ]], + Array([[[ 0.57966787, -0.523274 , -0.43195742], + [-0.676289 , -0.50300646, 0.33996582], + [ 0.41796115, -0.59212935, 0.95934135]], - [[-0.4769059 , -0.6607095 , 0.46697947], - [-0.748139 , -0.15856352, 0.06061118], - [ 0.59161806, 0.8660184 , 1.2838588 ]]], dtype=float32) + [[-0.66570747, 0.13815777, 1.007365 ], + [-1.0917838 , -0.7441663 , 0.07713798], + [ 0.41796115, -0.59212935, 0.95934135]]], dtype=float32) A parameterized function from integers [0, ``num_embeddings``) to ``features``-dimensional vectors. This ``Module`` will create an ``embedding`` @@ -1092,7 +1092,7 @@ class Embed(Module): broadcast the ``embedding`` matrix to input shape with ``features`` dimension appended. - Attributes: + Args: num_embeddings: number of embeddings / vocab size. features: number of feature dimensions for each embedding. dtype: the dtype of the embedding vectors (default: same as embedding). diff --git a/flax/nnx/nn/lora.py b/flax/nnx/nn/lora.py index dbba23fd..3b13951d 100644 --- a/flax/nnx/nn/lora.py +++ b/flax/nnx/nn/lora.py @@ -61,7 +61,7 @@ class LoRA(Module): >>> y.shape (16, 4) - Attributes: + Args: in_features: the number of input features. lora_rank: the rank of the LoRA dimension. out_features: the number of output features. @@ -133,7 +133,7 @@ class LoRALinear(Linear): >>> y.shape (16, 4) - Attributes: + Args: in_features: the number of input features. out_features: the number of output features. lora_rank: the rank of the LoRA dimension. diff --git a/flax/nnx/nn/normalization.py b/flax/nnx/nn/normalization.py index 928d9cf2..921a030f 100644 --- a/flax/nnx/nn/normalization.py +++ b/flax/nnx/nn/normalization.py @@ -236,7 +236,7 @@ class BatchNorm(Module): >>> assert (batch_stats2['mean'].value == batch_stats3['mean'].value).all() >>> assert (batch_stats2['var'].value == batch_stats3['var'].value).all() - Attributes: + Args: num_features: the number of input features. use_running_average: if True, the stored batch statistics will be used instead of computing the batch statistics on the input. 
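Several of these ``Attributes:`` → ``Args:`` hunks touch ``BatchNorm``, whose docstring describes toggling between batch and running statistics. A minimal sketch of that toggle, assuming the ``nnx.BatchNorm`` constructor shown above and the ``Module.train()``/``Module.eval()`` helpers used elsewhere in this patch:

from flax import nnx
import jax.numpy as jnp

x = jnp.ones((4, 6))
layer = nnx.BatchNorm(num_features=6, use_running_average=False, rngs=nnx.Rngs(0))
y_train = layer(x)  # normalizes with batch statistics and updates mean/var
layer.eval()        # flips use_running_average to True
y_eval = layer(x)   # normalizes with the stored running statistics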
@@ -407,7 +407,7 @@ class LayerNorm(Module): >>> y = layer(x) - Attributes: + Args: num_features: the number of input features. epsilon: A small float added to variance to avoid dividing by zero. dtype: the dtype of the result (default: infer from input and params). @@ -539,7 +539,7 @@ class RMSNorm(Module): >>> y = layer(x) - Attributes: + Args: num_features: the number of input features. epsilon: A small float added to variance to avoid dividing by zero. dtype: the dtype of the result (default: infer from input and params). @@ -670,7 +670,7 @@ class GroupNorm(Module): >>> y2 = nnx.LayerNorm(num_features=6, reduction_axes=(1, 2, 3), rngs=nnx.Rngs(0))(x) >>> np.testing.assert_allclose(y, y2) - Attributes: + Args: num_features: the number of input features/channels. num_groups: the total number of channel groups. The default value of 32 is proposed by the original group normalization paper. diff --git a/flax/nnx/nn/recurrent.py b/flax/nnx/nn/recurrent.py index 6ce039c5..0a96aaa6 100644 --- a/flax/nnx/nn/recurrent.py +++ b/flax/nnx/nn/recurrent.py @@ -31,7 +31,7 @@ from flax.nnx.nn.linear import Linear from flax.nnx.nn.activations import sigmoid from flax.nnx.nn.activations import tanh -from flax.nnx.transforms.iteration import Carry, StateAxes +from flax.nnx.transforms import iteration from flax.typing import ( Dtype, Initializer, @@ -44,7 +44,7 @@ A = TypeVar("A") Array = jax.Array Output = Any - +Carry = Any class RNNCellBase(Module): """RNN cell base class.""" @@ -213,7 +213,7 @@ def num_feature_axes(self) -> int: class OptimizedLSTMCell(RNNCellBase): - r"""More efficient LSTM Cell that concatenates state components before matmul. + r"""More efficient LSTM Cell that concatenates state components before matmul. The parameters are compatible with ``LSTMCell``. Note that this cell is often faster than ``LSTMCell`` as long as the hidden size is roughly <= 2048 units. @@ -235,7 +235,7 @@ class OptimizedLSTMCell(RNNCellBase): where x is the input, h is the output of the previous time step, and c is the memory. - Attributes: + Args: gate_fn: activation function used for gates (default: sigmoid). activation_fn: activation function used for output and memory update (default: tanh). @@ -248,107 +248,111 @@ class OptimizedLSTMCell(RNNCellBase): param_dtype: the dtype passed to parameter initializers (default: float32). 
""" - def __init__( - self, - in_features: int, - hidden_features: int, - *, - gate_fn: Callable[..., Any] = sigmoid, - activation_fn: Callable[..., Any] = tanh, - kernel_init: Initializer = default_kernel_init, - recurrent_kernel_init: Initializer = initializers.orthogonal(), - bias_init: Initializer = initializers.zeros_init(), - dtype: Dtype | None = None, - param_dtype: Dtype = jnp.float32, - carry_init: Initializer = initializers.zeros_init(), - rngs: rnglib.Rngs, - ): - self.in_features = in_features - self.hidden_features = hidden_features - self.gate_fn = gate_fn - self.activation_fn = activation_fn - self.kernel_init = kernel_init - self.recurrent_kernel_init = recurrent_kernel_init - self.bias_init = bias_init - self.dtype = dtype - self.param_dtype = param_dtype - self.carry_init = carry_init - self.rngs = rngs + def __init__( + self, + in_features: int, + hidden_features: int, + *, + gate_fn: Callable[..., Any] = sigmoid, + activation_fn: Callable[..., Any] = tanh, + kernel_init: Initializer = default_kernel_init, + recurrent_kernel_init: Initializer = initializers.orthogonal(), + bias_init: Initializer = initializers.zeros_init(), + dtype: Dtype | None = None, + param_dtype: Dtype = jnp.float32, + carry_init: Initializer = initializers.zeros_init(), + rngs: rnglib.Rngs, + ): + self.in_features = in_features + self.hidden_features = hidden_features + self.gate_fn = gate_fn + self.activation_fn = activation_fn + self.kernel_init = kernel_init + self.recurrent_kernel_init = recurrent_kernel_init + self.bias_init = bias_init + self.dtype = dtype + self.param_dtype = param_dtype + self.carry_init = carry_init + self.rngs = rngs - # input and recurrent layers are summed so only one needs a bias. - self.dense_i = Linear( - in_features=in_features, - out_features=4 * hidden_features, - use_bias=False, - kernel_init=self.kernel_init, - dtype=self.dtype, - param_dtype=self.param_dtype, - rngs=rngs, - ) + # input and recurrent layers are summed so only one needs a bias. + self.dense_i = Linear( + in_features=in_features, + out_features=4 * hidden_features, + use_bias=False, + kernel_init=self.kernel_init, + dtype=self.dtype, + param_dtype=self.param_dtype, + rngs=rngs, + ) - self.dense_h = Linear( - in_features=hidden_features, - out_features=4 * hidden_features, - use_bias=True, - kernel_init=self.recurrent_kernel_init, - bias_init=self.bias_init, - dtype=self.dtype, - param_dtype=self.param_dtype, - rngs=rngs, - ) + self.dense_h = Linear( + in_features=hidden_features, + out_features=4 * hidden_features, + use_bias=True, + kernel_init=self.recurrent_kernel_init, + bias_init=self.bias_init, + dtype=self.dtype, + param_dtype=self.param_dtype, + rngs=rngs, + ) - def __call__(self, carry: tuple[Array, Array], inputs: Array) -> tuple[tuple[Array, Array], Array]: # type: ignore[override] - r"""An optimized long short-term memory (LSTM) cell. + def __call__( + self, carry: tuple[Array, Array], inputs: Array + ) -> tuple[tuple[Array, Array], Array]: # type: ignore[override] + r"""An optimized long short-term memory (LSTM) cell. - Args: - carry: the hidden state of the LSTM cell, initialized using - ``LSTMCell.initialize_carry``. - inputs: an ndarray with the input for the current time step. - All dimensions except the final are considered batch dimensions. + Args: + carry: the hidden state of the LSTM cell, initialized using + ``LSTMCell.initialize_carry``. + inputs: an ndarray with the input for the current time step. + All dimensions except the final are considered batch dimensions. 
- Returns: - A tuple with the new carry and the output. - """ - c, h = carry + Returns: + A tuple with the new carry and the output. + """ + c, h = carry - # Compute combined transformations for inputs and hidden state - y = self.dense_i(inputs) + self.dense_h(h) + # Compute combined transformations for inputs and hidden state + y = self.dense_i(inputs) + self.dense_h(h) - # Split the combined transformations into individual gates - i, f, g, o = jnp.split(y, indices_or_sections=4, axis=-1) + # Split the combined transformations into individual gates + i, f, g, o = jnp.split(y, indices_or_sections=4, axis=-1) - # Apply gate activations - i = self.gate_fn(i) - f = self.gate_fn(f) - g = self.activation_fn(g) - o = self.gate_fn(o) + # Apply gate activations + i = self.gate_fn(i) + f = self.gate_fn(f) + g = self.activation_fn(g) + o = self.gate_fn(o) - # Update cell state and hidden state - new_c = f * c + i * g - new_h = o * self.activation_fn(new_c) - return (new_c, new_h), new_h + # Update cell state and hidden state + new_c = f * c + i * g + new_h = o * self.activation_fn(new_c) + return (new_c, new_h), new_h - def initialize_carry(self, input_shape: tuple[int, ...], rngs: rnglib.Rngs | None = None) -> tuple[Array, Array]: # type: ignore[override] - """Initialize the RNN cell carry. + def initialize_carry( + self, input_shape: tuple[int, ...], rngs: rnglib.Rngs | None = None + ) -> tuple[Array, Array]: # type: ignore[override] + """Initialize the RNN cell carry. - Args: - rngs: random number generator passed to the init_fn. - input_shape: a tuple providing the shape of the input to the cell. + Args: + rngs: random number generator passed to the init_fn. + input_shape: a tuple providing the shape of the input to the cell. - Returns: - An initialized carry for the given RNN cell. - """ - batch_dims = input_shape[:-1] - if rngs is None: - rngs = self.rngs - mem_shape = batch_dims + (self.hidden_features,) - c = self.carry_init(rngs.carry(), mem_shape, self.param_dtype) - h = self.carry_init(rngs.carry(), mem_shape, self.param_dtype) - return (c, h) + Returns: + An initialized carry for the given RNN cell. + """ + batch_dims = input_shape[:-1] + if rngs is None: + rngs = self.rngs + mem_shape = batch_dims + (self.hidden_features,) + c = self.carry_init(rngs.carry(), mem_shape, self.param_dtype) + h = self.carry_init(rngs.carry(), mem_shape, self.param_dtype) + return (c, h) - @property - def num_feature_axes(self) -> int: - return 1 + @property + def num_feature_axes(self) -> int: + return 1 class SimpleCell(RNNCellBase): @@ -451,7 +455,7 @@ def num_feature_axes(self) -> int: class GRUCell(RNNCellBase): - r"""GRU cell. + r"""GRU cell. The mathematical definition of the cell is as follows @@ -466,7 +470,7 @@ class GRUCell(RNNCellBase): where x is the input and h is the output of the previous time step. - Attributes: + Args: in_features: number of input features. hidden_features: number of output features. gate_fn: activation function used for gates (default: sigmoid). @@ -481,108 +485,110 @@ class GRUCell(RNNCellBase): param_dtype: the dtype passed to parameter initializers (default: float32). 
""" - def __init__( - self, - in_features: int, - hidden_features: int, - *, - gate_fn: Callable[..., Any] = sigmoid, - activation_fn: Callable[..., Any] = tanh, - kernel_init: Initializer = default_kernel_init, - recurrent_kernel_init: Initializer = initializers.orthogonal(), - bias_init: Initializer = initializers.zeros_init(), - dtype: Dtype | None = None, - param_dtype: Dtype = jnp.float32, - carry_init: Initializer = initializers.zeros_init(), - rngs: rnglib.Rngs, - ): - self.in_features = in_features - self.hidden_features = hidden_features - self.gate_fn = gate_fn - self.activation_fn = activation_fn - self.kernel_init = kernel_init - self.recurrent_kernel_init = recurrent_kernel_init - self.bias_init = bias_init - self.dtype = dtype - self.param_dtype = param_dtype - self.carry_init = carry_init - self.rngs = rngs + def __init__( + self, + in_features: int, + hidden_features: int, + *, + gate_fn: Callable[..., Any] = sigmoid, + activation_fn: Callable[..., Any] = tanh, + kernel_init: Initializer = default_kernel_init, + recurrent_kernel_init: Initializer = initializers.orthogonal(), + bias_init: Initializer = initializers.zeros_init(), + dtype: Dtype | None = None, + param_dtype: Dtype = jnp.float32, + carry_init: Initializer = initializers.zeros_init(), + rngs: rnglib.Rngs, + ): + self.in_features = in_features + self.hidden_features = hidden_features + self.gate_fn = gate_fn + self.activation_fn = activation_fn + self.kernel_init = kernel_init + self.recurrent_kernel_init = recurrent_kernel_init + self.bias_init = bias_init + self.dtype = dtype + self.param_dtype = param_dtype + self.carry_init = carry_init + self.rngs = rngs - # Combine input transformations into a single linear layer - self.dense_i = Linear( - in_features=in_features, - out_features=3 * hidden_features, # r, z, n - use_bias=True, - kernel_init=self.kernel_init, - bias_init=self.bias_init, - dtype=self.dtype, - param_dtype=self.param_dtype, - rngs=rngs, - ) + # Combine input transformations into a single linear layer + self.dense_i = Linear( + in_features=in_features, + out_features=3 * hidden_features, # r, z, n + use_bias=True, + kernel_init=self.kernel_init, + bias_init=self.bias_init, + dtype=self.dtype, + param_dtype=self.param_dtype, + rngs=rngs, + ) - self.dense_h = Linear( - in_features=hidden_features, - out_features=3 * hidden_features, # r, z, n - use_bias=False, - kernel_init=self.recurrent_kernel_init, - dtype=self.dtype, - param_dtype=self.param_dtype, - rngs=rngs, - ) + self.dense_h = Linear( + in_features=hidden_features, + out_features=3 * hidden_features, # r, z, n + use_bias=False, + kernel_init=self.recurrent_kernel_init, + dtype=self.dtype, + param_dtype=self.param_dtype, + rngs=rngs, + ) - def __call__(self, carry: Array, inputs: Array) -> tuple[Array, Array]: # type: ignore[override] - """Gated recurrent unit (GRU) cell. + def __call__(self, carry: Array, inputs: Array) -> tuple[Array, Array]: # type: ignore[override] + """Gated recurrent unit (GRU) cell. - Args: - carry: the hidden state of the GRU cell, - initialized using ``GRUCell.initialize_carry``. - inputs: an ndarray with the input for the current time step. - All dimensions except the final are considered batch dimensions. + Args: + carry: the hidden state of the GRU cell, + initialized using ``GRUCell.initialize_carry``. + inputs: an ndarray with the input for the current time step. + All dimensions except the final are considered batch dimensions. - Returns: - A tuple with the new carry and the output. 
- """ - h = carry + Returns: + A tuple with the new carry and the output. + """ + h = carry - # Compute combined transformations for inputs and hidden state - x_transformed = self.dense_i(inputs) - h_transformed = self.dense_h(h) + # Compute combined transformations for inputs and hidden state + x_transformed = self.dense_i(inputs) + h_transformed = self.dense_h(h) - # Split the combined transformations into individual components - xi_r, xi_z, xi_n = jnp.split(x_transformed, 3, axis=-1) - hh_r, hh_z, hh_n = jnp.split(h_transformed, 3, axis=-1) + # Split the combined transformations into individual components + xi_r, xi_z, xi_n = jnp.split(x_transformed, 3, axis=-1) + hh_r, hh_z, hh_n = jnp.split(h_transformed, 3, axis=-1) - # Compute gates - r = self.gate_fn(xi_r + hh_r) - z = self.gate_fn(xi_z + hh_z) + # Compute gates + r = self.gate_fn(xi_r + hh_r) + z = self.gate_fn(xi_z + hh_z) - # Compute n with an additional linear transformation on h - n = self.activation_fn(xi_n + r * hh_n) + # Compute n with an additional linear transformation on h + n = self.activation_fn(xi_n + r * hh_n) - # Update hidden state - new_h = (1.0 - z) * n + z * h - return new_h, new_h + # Update hidden state + new_h = (1.0 - z) * n + z * h + return new_h, new_h - def initialize_carry(self, input_shape: tuple[int, ...], rngs: rnglib.Rngs | None = None) -> Array: # type: ignore[override] - """Initialize the RNN cell carry. + def initialize_carry( + self, input_shape: tuple[int, ...], rngs: rnglib.Rngs | None = None + ) -> Array: # type: ignore[override] + """Initialize the RNN cell carry. - Args: - rngs: random number generator passed to the init_fn. - input_shape: a tuple providing the shape of the input to the cell. + Args: + rngs: random number generator passed to the init_fn. + input_shape: a tuple providing the shape of the input to the cell. - Returns: - An initialized carry for the given RNN cell. - """ - batch_dims = input_shape[:-1] - if rngs is None: - rngs = self.rngs - mem_shape = batch_dims + (self.hidden_features,) - h = self.carry_init(rngs.carry(), mem_shape, self.param_dtype) - return h + Returns: + An initialized carry for the given RNN cell. + """ + batch_dims = input_shape[:-1] + if rngs is None: + rngs = self.rngs + mem_shape = batch_dims + (self.hidden_features,) + h = self.carry_init(rngs.carry(), mem_shape, self.param_dtype) + return h - @property - def num_feature_axes(self) -> int: - return 1 + @property + def num_feature_axes(self) -> int: + return 1 class RNN(Module): @@ -591,7 +597,7 @@ class RNN(Module): using :func:`flax.nnx.scan`. 
""" - state_axes: Mapping[str, int | type[Carry] | None] + state_axes: dict[str, int | type[iteration.Carry] | None] def __init__( self, @@ -602,7 +608,7 @@ def __init__( keep_order: bool = False, unroll: int = 1, rngs: rnglib.Rngs | None = None, - state_axes: Mapping[str, int | type[Carry] | None] | None = None, + state_axes: Mapping[str, int | type[iteration.Carry] | None] | None = None, broadcast_rngs: filterlib.Filter = None, ): self.cell = cell @@ -614,7 +620,7 @@ def __init__( if rngs is None: rngs = rnglib.Rngs(0) self.rngs = rngs - self.state_axes = state_axes or {...: Carry} # type: ignore + self.state_axes = state_axes or {...: iteration.Carry} # type: ignore self.broadcast_rngs = broadcast_rngs def __call__( @@ -675,14 +681,16 @@ def __call__( slice_carry = seq_lengths is not None and return_carry broadcast_rngs = nnx.All(nnx.RngState, self.broadcast_rngs) - state_axes = StateAxes({broadcast_rngs: None, **self.state_axes}) # type: ignore + state_axes = iteration.StateAxes({broadcast_rngs: None, **self.state_axes}) # type: ignore[misc] # we use split_rngs with splits=1 and squeeze=True to get unique rngs # every time RNN is called @nnx.split_rngs(splits=1, only=self.broadcast_rngs, squeeze=True) @nnx.scan( - in_axes=(state_axes, Carry, time_axis), - out_axes=(Carry, (0, time_axis)) if slice_carry else (Carry, time_axis), + in_axes=(state_axes, iteration.Carry, time_axis), + out_axes=(iteration.Carry, (0, time_axis)) + if slice_carry + else (iteration.Carry, time_axis), unroll=self.unroll, ) def scan_fn( diff --git a/flax/nnx/nn/stochastic.py b/flax/nnx/nn/stochastic.py index add54563..4785149b 100644 --- a/flax/nnx/nn/stochastic.py +++ b/flax/nnx/nn/stochastic.py @@ -53,13 +53,13 @@ class Dropout(Module): >>> model.train() # use dropout >>> model(x) - Array([[-0.9353421, 0. , 1.434417 , 0. ]], dtype=float32) + Array([[ 0. , 0. , -1.592019 , -2.5238838]], dtype=float32) >>> model.eval() # don't use dropout >>> model(x) - Array([[-0.46767104, -0.7213411 , 0.7172085 , -0.31562346]], dtype=float32) + Array([[ 1.0533503, -1.2679932, -0.7960095, -1.2619419]], dtype=float32) - Attributes: + Args: rate: the dropout probability. (_not_ the keep rate!) 
broadcast_dims: dimensions that will share the same dropout mask deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` and diff --git a/flax/nnx/object.py b/flax/nnx/object.py index b1f7478e..b8f35ba7 100644 --- a/flax/nnx/object.py +++ b/flax/nnx/object.py @@ -15,6 +15,7 @@ from __future__ import annotations import dataclasses +import inspect import threading import typing as tp from abc import ABCMeta @@ -24,15 +25,14 @@ import numpy as np import treescope # type: ignore[import-untyped] from treescope import rendering_parts -from flax.nnx import visualization -from flax import errors +from flax import errors, nnx from flax.nnx import ( graph, reprlib, tracers, + visualization, ) -from flax import nnx from flax.nnx.variablelib import Variable, VariableState from flax.typing import SizeBytes, value_stats @@ -157,6 +157,8 @@ def __init_subclass__(cls) -> None: init=cls._graph_node_init, # type: ignore ) + cls.__signature__ = inspect.signature(cls.__init__) + if not tp.TYPE_CHECKING: def __setattr__(self, name: str, value: Any) -> None: diff --git a/flax/nnx/training/metrics.py b/flax/nnx/training/metrics.py index 18044902..eca41bd4 100644 --- a/flax/nnx/training/metrics.py +++ b/flax/nnx/training/metrics.py @@ -224,7 +224,7 @@ class Accuracy(Average): >>> import jax, jax.numpy as jnp >>> logits = jax.random.normal(jax.random.key(0), (5, 2)) - >>> labels = jnp.array([1, 1, 0, 1, 0]) + >>> labels = jnp.array([0, 1, 1, 1, 0]) >>> logits2 = jax.random.normal(jax.random.key(1), (5, 2)) >>> labels2 = jnp.array([0, 1, 1, 1, 1]) @@ -236,7 +236,7 @@ class Accuracy(Average): Array(0.6, dtype=float32) >>> metrics.update(logits=logits2, labels=labels2) >>> metrics.compute() - Array(0.7, dtype=float32) + Array(0.4, dtype=float32) >>> metrics.reset() >>> metrics.compute() Array(nan, dtype=float32) @@ -320,7 +320,7 @@ class MultiMetric(Metric): ) >>> logits = jax.random.normal(jax.random.key(0), (5, 2)) - >>> labels = jnp.array([1, 1, 0, 1, 0]) + >>> labels = jnp.array([0, 1, 1, 1, 0]) >>> logits2 = jax.random.normal(jax.random.key(1), (5, 2)) >>> labels2 = jnp.array([0, 1, 1, 1, 1]) @@ -334,7 +334,7 @@ class MultiMetric(Metric): {'accuracy': Array(0.6, dtype=float32), 'loss': Array(2.5, dtype=float32)} >>> metrics.update(logits=logits2, labels=labels2, values=batch_loss2) >>> metrics.compute() - {'accuracy': Array(0.7, dtype=float32), 'loss': Array(2., dtype=float32)} + {'accuracy': Array(0.4, dtype=float32), 'loss': Array(2., dtype=float32)} >>> metrics.reset() >>> metrics.compute() {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)} diff --git a/flax/nnx/training/optimizer.py b/flax/nnx/training/optimizer.py index 4b85d5a3..5c017dae 100644 --- a/flax/nnx/training/optimizer.py +++ b/flax/nnx/training/optimizer.py @@ -126,11 +126,11 @@ class Optimizer(Object): ... >>> loss_fn = lambda model: ((model(x) - y) ** 2).mean() >>> loss_fn(model) - Array(1.7055722, dtype=float32) + Array(2.3359995, dtype=float32) >>> grads = nnx.grad(loss_fn)(state.model) >>> state.update(grads) >>> loss_fn(model) - Array(1.6925814, dtype=float32) + Array(2.310461, dtype=float32) Note that you can easily extend this class by subclassing it for storing additional data (e.g. adding metrics). 
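The regenerated loss values above come from the usage pattern this docstring walks through. Condensed into a standalone sketch, assuming the two-argument ``nnx.Optimizer(model, tx)`` constructor used in the doctest (the hypothetical ``x``/``y`` data is only for illustration):

from flax import nnx
import jax.numpy as jnp
import optax

x = jnp.ones((1, 2))
y = jnp.zeros((1, 3))
model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
state = nnx.Optimizer(model, optax.adam(1e-3))

loss_fn = lambda m: ((m(x) - y) ** 2).mean()
grads = nnx.grad(loss_fn)(state.model)
state.update(grads)  # applies the adam update to the model in place
# loss_fn(state.model) is now slightly lower, as in the doctest above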
@@ -151,15 +151,15 @@ class Optimizer(Object): >>> grads = nnx.grad(loss_fn)(state.model) >>> state.update(grads=grads, values=loss_fn(state.model)) >>> state.metrics.compute() - Array(1.6925814, dtype=float32) + Array(2.310461, dtype=float32) >>> state.update(grads=grads, values=loss_fn(state.model)) >>> state.metrics.compute() - Array(1.68612, dtype=float32) + Array(2.2978127, dtype=float32) For more exotic usecases (e.g. multiple optimizers) it's probably best to fork the class and modify it. - Attributes: + Args: step: An ``OptState`` :class:`Variable` that tracks the step count. model: The wrapped :class:`Module`. tx: An Optax gradient transformation. diff --git a/flax/training/train_state.py b/flax/training/train_state.py index bbce765c..65d28d5e 100644 --- a/flax/training/train_state.py +++ b/flax/training/train_state.py @@ -48,12 +48,12 @@ class TrainState(struct.PyTreeNode): ... loss = optax.l2_loss(predictions=predictions, targets=y).mean() ... return loss >>> loss_fn(state.params, x, y) - Array(3.3514676, dtype=float32) + Array(1.8136346, dtype=float32) >>> grads = jax.grad(loss_fn)(state.params, x, y) >>> state = state.apply_gradients(grads=grads) >>> loss_fn(state.params, x, y) - Array(3.343844, dtype=float32) + Array(1.8079796, dtype=float32) Note that you can easily extend this dataclass by subclassing it for storing additional data (e.g. additional variable collections). diff --git a/tests/linen/linen_test.py b/tests/linen/linen_test.py index 96539d89..2556146f 100644 --- a/tests/linen/linen_test.py +++ b/tests/linen/linen_test.py @@ -428,7 +428,7 @@ def __call__(self, x): model = Foo() x = random.normal(random.key(1), (2, 4)) (y1, y2), _ = model.init_with_output(key, x) - np.testing.assert_allclose(y1, y2, rtol=1e-4) + np.testing.assert_allclose(y1, y2, rtol=0.005) @parameterized.parameters( {'feature_axes': -1}, diff --git a/tests/nnx/metrics_test.py b/tests/nnx/metrics_test.py index 25833e57..951734f9 100644 --- a/tests/nnx/metrics_test.py +++ b/tests/nnx/metrics_test.py @@ -22,10 +22,14 @@ class TestMetrics(parameterized.TestCase): def test_split_merge(self): - logits = jax.random.normal(jax.random.key(0), (5, 2)) - labels = jnp.array([1, 1, 0, 1, 0]) - logits2 = jax.random.normal(jax.random.key(1), (5, 2)) - labels2 = jnp.array([0, 1, 1, 1, 1]) + logits = jnp.array( + [[-1.0, 1.0], [-1.0, 1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, -1.0]] + ) + labels = jnp.array([1, 1, 1, 1, 1]) + logits2 = jnp.array( + [[-1.0, 1.0], [-1.0, 1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, -1.0]] + ) + labels2 = jnp.array([1, 1, 1, 1, 0]) accuracy = nnx.metrics.Accuracy() accuracy.update(logits=logits, labels=labels) @@ -87,9 +91,13 @@ def test_welford_many(self): self.assertAlmostEqual(computed.standard_deviation, 1.0, places=2) def test_multimetric(self): - logits = jax.random.normal(jax.random.key(0), (5, 2)) + logits = jnp.array( + [[-1.0, 1.0], [-1.0, 1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, -1.0]] + ) labels = jnp.array([1, 1, 0, 1, 0]) - logits2 = jax.random.normal(jax.random.key(1), (5, 2)) + logits2 = jnp.array( + [[-1.0, 1.0], [-1.0, 1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, -1.0]] + ) labels2 = jnp.array([0, 1, 1, 1, 1]) batch_loss = jnp.array([1, 2, 3, 4]) batch_loss2 = jnp.array([3, 2, 1, 0]) @@ -108,7 +116,7 @@ def test_multimetric(self): metrics.update(logits=logits2, labels=labels2, values=batch_loss2) values = metrics.compute() - self.assertEqual(values['accuracy'], 0.7) + self.assertEqual(values['accuracy'], 0.5) self.assertEqual(values['loss'], 2.0) metrics.reset() diff --git 
a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index bfa461be..28abea77 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -2578,8 +2578,8 @@ def forward_block(module, x): def test_basic_demo_single(self): class Block(nnx.Module): def __init__(self, rngs: nnx.Rngs): - self.linear = nnx.Linear(3, 3, rngs=rngs) - self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) + self.linear = nnx.Linear(20, 20, rngs=rngs) + self.dropout = nnx.Dropout(0.2, deterministic=False, rngs=rngs) def __call__(self, x: jax.Array) -> jax.Array: return self.dropout(nnx.relu(self.linear(x))) @@ -2598,14 +2598,14 @@ def forward_block(module: Block, x): module = create_block(rngs) assert rngs.default.count.value == 1 - assert module.linear.kernel.value.shape == (1, 3, 3) - assert module.linear.bias.value.shape == (1, 3) + assert module.linear.kernel.value.shape == (1, 20, 20) + assert module.linear.bias.value.shape == (1, 20) - x = jnp.ones((1, 10, 3)) + x = jnp.ones((1, 10, 20)) y = forward_block(module, x) - assert y.shape == (1, 10, 3) + assert y.shape == (1, 10, 20) assert rngs.default.count.value == 2 y2 = forward_block(module, x) diff --git a/uv.lock b/uv.lock index 48bda4f7..2857313f 100644 --- a/uv.lock +++ b/uv.lock @@ -3,13 +3,13 @@ requires-python = ">=3.10" resolution-markers = [ "python_full_version < '3.11' and platform_system == 'Darwin'", "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", "python_full_version == '3.11.*' and platform_system == 'Darwin'", "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", "python_full_version >= '3.12' and platform_system == 'Darwin'", "python_full_version >= '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", ] [[package]] @@ -641,7 +641,7 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version < '3.11' and platform_system == 'Darwin'", "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version < '3.11' and 
platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", ] sdist = { url = "https://files.pythonhosted.org/packages/99/bc/cfb52b9e8531526604afe8666185d207e4f0cb9c6d90bc76f62fb8746804/etils-1.7.0.tar.gz", hash = "sha256:97b68fd25e185683215286ef3a54e38199b6245f5fe8be6bedc1189be4256350", size = 95695 } wheels = [ @@ -676,10 +676,10 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version == '3.11.*' and platform_system == 'Darwin'", "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", "python_full_version >= '3.12' and platform_system == 'Darwin'", "python_full_version >= '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", ] sdist = { url = "https://files.pythonhosted.org/packages/ba/49/d480aeb4fc441d933acce97261bea002234a45fb847599c9a93c31e51b2e/etils-1.9.2.tar.gz", hash = "sha256:15dcd35ac0c0cc2404b46ac0846af3cc4e876fd3d80f36f57951e27e8b9d6379", size = 101506 } wheels = [ @@ -1202,7 +1202,7 @@ name = "ipython" version = "8.26.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, { name = "decorator" }, { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, { name = "jedi" }, @@ -1246,7 +1246,7 @@ wheels = [ [[package]] name = "jax" -version = "0.4.38" +version = "0.5.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxlib" }, @@ -1255,14 +1255,14 @@ dependencies = [ { name = "opt-einsum" }, { name = "scipy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/fb/e5/c4aa9644bb96b7f6747bd7c9f8cda7665ca5e194fa2542b2dea3ff730701/jax-0.4.38.tar.gz", hash = "sha256:43bae65881628319e0a2148e8f81a202fbc2b8d048e35c7cb1df2416672fa4a8", size = 1930034 } +sdist = { url = "https://files.pythonhosted.org/packages/4a/cb/22d62b26284f08e62d6eb64603d3b010004cfdb7a97ce6cca5c6cf86edab/jax-0.5.0.tar.gz", hash = "sha256:49df70bf293a345a7fb519f71193506d37a024c4f850b358042eb32d502c81c8", size = 1959707 } wheels = [ - { url = "https://files.pythonhosted.org/packages/22/49/b4418a7a892c0dd64442bbbeef54e1cdfe722dfc5a7bf0d611d3f5f90e99/jax-0.4.38-py3-none-any.whl", hash = "sha256:78987306f7041ea8500d99df1a17c33ed92620c2268c4c3677fb24e06712be64", size = 2236864 }, + { url = 
"https://files.pythonhosted.org/packages/f4/58/cc0721a1030fcbab0984beea0bf3c4610ec103f738423cdfa9c4ceb40598/jax-0.5.0-py3-none-any.whl", hash = "sha256:b3907aa87ae2c340b39cdbf80c07a74550369cafcaf7398fb60ba58d167345ab", size = 2270365 }, ] [[package]] name = "jaxlib" -version = "0.4.38" +version = "0.5.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "ml-dtypes" }, @@ -1270,26 +1270,22 @@ dependencies = [ { name = "scipy" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/ee/d4/e6a0881a88b8f17491c2ee271fd77c348b0221d9e2ec92dad23a2c9e41bc/jaxlib-0.4.38-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:55c19b9d3f33a6fc59f644aa5a21fba02639ccdd776cb4a9b5526625f57839ff", size = 99663603 }, - { url = "https://files.pythonhosted.org/packages/b6/6d/11569ce873f04c82ec22e58d822f4187dccae1d400c0d6dd05ed314d5328/jaxlib-0.4.38-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:30b2f52cb50d74734af2f477c2533a7a583e3bb7b2c8acdeb361ee77d940577a", size = 79475708 }, - { url = "https://files.pythonhosted.org/packages/72/61/1de2405d13089c83b1ad87ec0266479c9d00080659dae2474892ae356306/jaxlib-0.4.38-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:ee19c163a8fdf0839d4c18b88a5fbfb4e731ba7c437416d3e5483e570bb764e4", size = 93219045 }, - { url = "https://files.pythonhosted.org/packages/9c/24/0829decf233c6af9efe7c53888ae8ac72395e0979869cd9cee487e35dac3/jaxlib-0.4.38-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:61aeccb9a27c67fdb8450f6357240019cd4511cb9d62a44e4764756d384853ad", size = 101732107 }, - { url = "https://files.pythonhosted.org/packages/0d/04/120c4caac6151f7297fedf9dd776362aa2d417d3f87bda826050b4da45e8/jaxlib-0.4.38-cp310-cp310-win_amd64.whl", hash = "sha256:d6ab745a89d0fb737a36fe1d8b86659e3fffe6ee8303b20651b26193d5edc0ef", size = 64223924 }, - { url = "https://files.pythonhosted.org/packages/b0/6a/b9fba73eb5e758e40a514919e096a039d27dc0ab4776a6cc977f5153a55f/jaxlib-0.4.38-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:b67fdeabd6dfed08b7768f3bdffb521160085f8305669bd197beef61d08de08b", size = 99679916 }, - { url = "https://files.pythonhosted.org/packages/44/2a/3458130d44d44038fd6974e7c43948f68408f685063203b82229b9b72c1a/jaxlib-0.4.38-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3fb0eaae7369157afecbead50aaf29e73ffddfa77a2335d721bd9794f3c510e4", size = 79488377 }, - { url = "https://files.pythonhosted.org/packages/94/96/7d9a0b9f35af4727df44b68ade4c6f15163840727d1cb47251b1ea515e30/jaxlib-0.4.38-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:43db58c4c427627296366a56c10318e1f00f503690e17f94bb4344293e1995e0", size = 93241543 }, - { url = "https://files.pythonhosted.org/packages/a3/2d/68f85037e60c981b37b18b23ace458c677199dea4722ddce541b48ddfc63/jaxlib-0.4.38-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:2751ff7037d6a997d0be0e77cc4be381c5a9f9bb8b314edb755c13a6fd969f45", size = 101751923 }, - { url = "https://files.pythonhosted.org/packages/cc/24/a9c571c8a189f58e0b54b14d53fc7f5a0a06e4f1d7ab9edcf8d1d91d07e7/jaxlib-0.4.38-cp311-cp311-win_amd64.whl", hash = "sha256:35226968fc9de6873d1571670eac4117f5ed80e955f7a1775204d1044abe16c6", size = 64255189 }, - { url = "https://files.pythonhosted.org/packages/49/df/08b94c593c0867c7eaa334592807ba74495de4be90580f360db8b96221dc/jaxlib-0.4.38-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:3fefea985f0415816f3bbafd3f03a437050275ef9bac9a72c1314e1644ac57c1", size = 99737849 }, - { url = 
"https://files.pythonhosted.org/packages/ab/b1/c9d2a7ba9ebeabb7ac37082f4c466364f475dc7550a79358c0f0aa89fdf2/jaxlib-0.4.38-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f33bcafe32c97a562ecf6894d7c41674c80c0acdedfa5423d49af51147149874", size = 79509242 }, - { url = "https://files.pythonhosted.org/packages/53/25/dd670d8bdf3799ece76d12cfe6a6a250ea256057aa4b0fcace4753a99d2d/jaxlib-0.4.38-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:496f45b0e001a2341309cd0c74af0b670537dced79c168cb230cfcc773f0aa86", size = 93251503 }, - { url = "https://files.pythonhosted.org/packages/f9/cc/37fce5162f6b9070203fd76cc0f298d9b3bfdf01939a78935a6078d63621/jaxlib-0.4.38-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:dad6c0a96567c06d083c0469fec40f201210b099365bd698be31a6d2ec88fd59", size = 101792792 }, - { url = "https://files.pythonhosted.org/packages/6f/7a/8515950a60a4ea5b13cc98fc0a42e36553b2db5a6eedc00d3bd7836f77b5/jaxlib-0.4.38-cp312-cp312-win_amd64.whl", hash = "sha256:966cdec36cfa978f5b4582bcb4147fe511725b94c1a752dac3a5f52ce46b6fa3", size = 64288223 }, - { url = "https://files.pythonhosted.org/packages/91/03/aee503c7077c6dbbd568842303426c6ec1cef9bff330c418c9e71906cccd/jaxlib-0.4.38-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:41e55ae5818a882e5789e848f6f16687ac132bcfbb5a5fa114a5d18b78d05f2d", size = 99739026 }, - { url = "https://files.pythonhosted.org/packages/cb/bf/fbbf61da319611d88e11c691d5a2077039208ded05e1731dea940f824a59/jaxlib-0.4.38-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:6fe326b8af366387dd47ccf312583b2b17fed12712c9b74a648b18a13cbdbabf", size = 79508735 }, - { url = "https://files.pythonhosted.org/packages/e4/0b/8cbff0b6d62a4694351c49baf53b7ed8deb8a6854d129408c38158e11676/jaxlib-0.4.38-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:248cca3771ebf24b070f49701364ceada33e6139445b06c782cca5ac5ad92bf4", size = 93251882 }, - { url = "https://files.pythonhosted.org/packages/15/57/7f0283273b69c417071bcd2f4c2ed076479ec5ffc22a647f13c21da8d071/jaxlib-0.4.38-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:2ce77ba8cda9259a4bca97afc1c722e4291a6c463a63f8d372c6edc85117d625", size = 101791137 }, - { url = "https://files.pythonhosted.org/packages/de/de/d6c4d234cd426b97459cb070af90792b48643967a0d28641379ee9e10fc9/jaxlib-0.4.38-cp313-cp313-win_amd64.whl", hash = "sha256:4103db0b3a38a5dc132741237453c24d8547290a22079ba1b577d6c88c95300a", size = 64288459 }, + { url = "https://files.pythonhosted.org/packages/c8/41/3e4ac64df72c4da126df3fd66a2214025a46b6263f7be266728e7b8e473e/jaxlib-0.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1b8a6c4345f137f387650de2dbc488c20251b7412b55dd648e1a4f13bcf507fb", size = 79248968 }, + { url = "https://files.pythonhosted.org/packages/1e/5f/2a16e61f1d54ae5f55fbf3cb3e22ef5bb01bf9d7d6474e0d34fedba19c4d/jaxlib-0.5.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:5b2efe3dfebf18a84c451d3803ac884ee242021c1113b279c13f4bbc378c3dc0", size = 93181077 }, + { url = "https://files.pythonhosted.org/packages/08/c3/573e2f01b99f1247e8fbe1aa46b95a0faa68ef208f9a8e8ef775d607b3e6/jaxlib-0.5.0-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:74440b632107336400d4f97a16481d767f13ea914c53ba14e544c6fda54819b3", size = 101969119 }, + { url = "https://files.pythonhosted.org/packages/6e/38/512f61ea13da41ca47f2411d7c05af0cf74a37f225e16725ed0e6fb58893/jaxlib-0.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:53478a28eee6c2ef01759b05a9491702daef9268c3ed013d6f8e2e5f5cae0887", size = 63883394 }, + { url = 
"https://files.pythonhosted.org/packages/92/4b/8875870ff52ad3fbea876c905228f691f05c8dc8556b226cbfaf0fba7f62/jaxlib-0.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6cd762ed1623132499fa701c4203446102e0a9c82ca23194b87288f746d12a29", size = 79242870 }, + { url = "https://files.pythonhosted.org/packages/a0/0f/00cdfa411d7218e4696c10c5867f7d3c396219adbcaeb02e95108ca802de/jaxlib-0.5.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:63088dbfaa85bb56cd521a925a3472fd7328b18ec93c2d8ffa85af331095c995", size = 93181807 }, + { url = "https://files.pythonhosted.org/packages/58/8e/a5c29db03d5a93b0326e297b556d0e0a9805e9c9c1ae5f82f69557273faa/jaxlib-0.5.0-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:09113ef1582ba34d7cbc440fedb318f4855b59b776711a8aba2473c9727d3025", size = 101969212 }, + { url = "https://files.pythonhosted.org/packages/70/86/ceae20e4f37fa07f1cc95551cc0f49170d0db46d2e82fdf511d26bffd801/jaxlib-0.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:78289fc3ddc1e4e9510de2536a6375df9fe1c50de0ac60826c286b7a5c5090fe", size = 63881994 }, + { url = "https://files.pythonhosted.org/packages/57/d6/d971b40cb156e0637aa3c1522a1e803b641142e9a8f3ade6a574711bb073/jaxlib-0.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:73e335715760c56e635109d61426435a5d7f46f3363a115daea09427d5cd0efd", size = 79246087 }, + { url = "https://files.pythonhosted.org/packages/41/2e/ba9770330077c3e4082cd0353e6a61419f79bff3e2197f904ce70167b9ad/jaxlib-0.5.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:4b4b01afb0ddec96c08356bff2bb685ddbe97fdffe4ed6e2d834b30aba972f22", size = 93179593 }, + { url = "https://files.pythonhosted.org/packages/66/e9/211ba3e46ec22c722c4d61a739cfccf79b0618006d6f5fa53eb4eb93ed6d/jaxlib-0.5.0-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:f980c733e98c998a8da87c9a8cc61b6726d0be667a58bd664c1d717b4b4eae75", size = 101984785 }, + { url = "https://files.pythonhosted.org/packages/2d/cb/11bb92324afb6ba678f388e10b78d6b02196bc8887eb5aa0d85ce398edf9/jaxlib-0.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:5baedbeeb60fa493c7528783254f04c6e986a2826266b198ed37e9336af2ef8c", size = 63899871 }, + { url = "https://files.pythonhosted.org/packages/22/ac/e400473e6a2f405fd6e4dc40a713bb9a3868a3f76a8ffc5eb66f6e686002/jaxlib-0.5.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ed18ea7161d03aa8fd4d1b55494882f21420efdfea68e5f298c4aebcf2ac3f34", size = 79245359 }, + { url = "https://files.pythonhosted.org/packages/44/2d/c210abf4a9b2ce2e0858fcd3567c8773a739114e37d751af6c228901af57/jaxlib-0.5.0-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:7d9b17a7ea19355d45ecdb2ff0db5d707a86f0c5a862d94b89b4568d6c45311a", size = 93180025 }, + { url = "https://files.pythonhosted.org/packages/30/f8/316f7b4797c5eb50c6d70e461724a7cbe08b4505ca4da1bfd260c135895a/jaxlib-0.5.0-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:11eef01d37c0f1c5306265b76f207f1002d13480ded2e31fd63ec76912c93ca2", size = 101982281 }, + { url = "https://files.pythonhosted.org/packages/4c/f2/cfa012a0417c9b13b44c8e1d3ebf5fd04e8bb738b5c93e20c9fc97919880/jaxlib-0.5.0-cp313-cp313-win_amd64.whl", hash = "sha256:61b4d26cd6a0c49ba0b1e4340c7d29198913ee2dc70b65ee90752717d22305bb", size = 63900219 }, ] [[package]] @@ -1431,7 +1427,7 @@ version = "5.7.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "platformdirs" }, - { name = "pywin32", marker = "(platform_machine != 'aarch64' and platform_python_implementation != 'PyPy' and platform_system == 'Linux' and sys_platform == 'win32') or 
(platform_python_implementation != 'PyPy' and platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, + { name = "pywin32", marker = "platform_python_implementation != 'PyPy' and sys_platform == 'win32'" }, { name = "traitlets" }, ] sdist = { url = "https://files.pythonhosted.org/packages/00/11/b56381fa6c3f4cc5d2cf54a7dbf98ad9aa0b339ef7a601d6053538b079a7/jupyter_core-5.7.2.tar.gz", hash = "sha256:aa5f8d32bbf6b431ac830496da7392035d6f61b4f54872f15c4bd2a9c3f536d9", size = 87629 } @@ -2095,7 +2091,7 @@ name = "nvidia-cudnn-cu12" version = "9.1.0.70" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 }, @@ -2122,9 +2118,9 @@ name = "nvidia-cusolver-cu12" version = "11.4.5.107" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, - { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, - { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 }, @@ -2135,7 +2131,7 @@ name = "nvidia-cusparse-cu12" version = "12.1.0.106" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = 
"sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 }, @@ -2436,7 +2432,7 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version < '3.11' and platform_system == 'Darwin'", "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", ] sdist = { url = "https://files.pythonhosted.org/packages/55/5b/e3d951e34f8356e5feecacd12a8e3b258a1da6d9a03ad1770f28925f29bc/protobuf-3.20.3.tar.gz", hash = "sha256:2e3427429c9cffebf259491be0af70189607f365c2f41c7c3764af6f337105f2", size = 216768 } wheels = [ @@ -2454,10 +2450,10 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version == '3.11.*' and platform_system == 'Darwin'", "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", "python_full_version >= '3.12' and platform_system == 'Darwin'", "python_full_version >= '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", ] sdist = { url = "https://files.pythonhosted.org/packages/e8/ab/cb61a4b87b2e7e6c312dce33602bd5884797fd054e0e53205f1c27cf0f66/protobuf-4.25.4.tar.gz", hash = "sha256:0dc4a62cc4052a036ee2204d26fe4d835c62827c855c8a03f29fe6da146b380d", size = 380283 } wheels = [ @@ -2475,6 +2471,8 @@ version = "6.0.0" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/18/c7/8c6872f7372eb6a6b2e4708b88419fb46b857f7a2e1892966b851cc79fc9/psutil-6.0.0.tar.gz", hash = "sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2", size = 508067 } wheels = [ + { url = "https://files.pythonhosted.org/packages/c5/66/78c9c3020f573c58101dc43a44f6855d01bbbd747e24da2f0c4491200ea3/psutil-6.0.0-cp27-none-win32.whl", hash = "sha256:02b69001f44cc73c1c5279d02b30a817e339ceb258ad75997325e0e6169d8b35", size = 249766 }, + { url = "https://files.pythonhosted.org/packages/e1/3f/2403aa9558bea4d3854b0e5e567bc3dd8e9fbc1fc4453c0aa9aafeb75467/psutil-6.0.0-cp27-none-win_amd64.whl", hash = "sha256:21f1fb635deccd510f69f485b87433460a603919b45e2a324ad65b0cc74f8fb1", size = 253024 }, { url = "https://files.pythonhosted.org/packages/0b/37/f8da2fbd29690b3557cca414c1949f92162981920699cd62095a984983bf/psutil-6.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = 
"sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0", size = 250961 }, { url = "https://files.pythonhosted.org/packages/35/56/72f86175e81c656a01c4401cd3b1c923f891b31fbcebe98985894176d7c9/psutil-6.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0", size = 287478 }, { url = "https://files.pythonhosted.org/packages/19/74/f59e7e0d392bc1070e9a70e2f9190d652487ac115bb16e2eff6b22ad1d24/psutil-6.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd", size = 290455 }, @@ -2606,7 +2604,7 @@ name = "pytest" version = "8.3.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, { name = "iniconfig" }, { name = "packaging" }, @@ -3195,7 +3193,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "alabaster" }, { name = "babel" }, - { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, { name = "docutils" }, { name = "imagesize" }, { name = "jinja2" }, @@ -3684,7 +3682,7 @@ name = "triton" version = "3.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "filelock", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "filelock", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/45/27/14cc3101409b9b4b9241d2ba7deaa93535a217a211c86c4cc7151fb12181/triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a", size = 209376304 },