From 6cf5b7d3d7ca851f18e0b1f088dc362c9d56bbb6 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Fri, 20 Dec 2024 22:04:45 -0500 Subject: [PATCH] [nnx] fast jit --- benchmarks/nnx_simple_training.py | 57 +++++++++----- flax/nnx/extract.py | 14 ++-- flax/nnx/graph.py | 115 +++++++++++++++++++++++++---- flax/nnx/reprlib.py | 8 ++ flax/nnx/statelib.py | 91 +++++++++++++++++++++-- flax/nnx/transforms/compilation.py | 18 ++++- flax/nnx/variablelib.py | 1 + tests/nnx/graph_utils_test.py | 9 ++- uv.lock | 92 +++++++++++------------ 9 files changed, 309 insertions(+), 96 deletions(-) diff --git a/benchmarks/nnx_simple_training.py b/benchmarks/nnx_simple_training.py index 0cb08066fe..6c040dee51 100644 --- a/benchmarks/nnx_simple_training.py +++ b/benchmarks/nnx_simple_training.py @@ -25,7 +25,9 @@ from absl import app FLAGS = flags.FLAGS -flags.DEFINE_enum('mode', 'nnx', ['nnx', 'jax'], 'Mode to run the script in') +flags.DEFINE_enum( + 'mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in' +) flags.DEFINE_integer('total_steps', 10_000, 'Total number of training steps') flags.DEFINE_integer('batch_size', 32, 'Batch size') flags.DEFINE_integer('width', 32, 'Hidden layer size') @@ -46,6 +48,13 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): def __call__(self, x): return x @ self.w + self.b +class Block(nnx.Module): + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + self.linear = Linear(din, dout, rngs=rngs) + self.bn = nnx.BatchNorm(dout, rngs=rngs) + + def __call__(self, x): + return nnx.relu(self.bn(self.linear(x))) class Count(nnx.Variable): pass @@ -54,11 +63,11 @@ class Count(nnx.Variable): class MLP(nnx.Module): def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs): self.count = Count(jnp.array(0)) - self.linear_in = Linear(din, dhidden, rngs=rngs) + self.linear_in = Block(din, dhidden, rngs=rngs) self.intermediates = [ - Linear(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2) + Block(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2) ] - self.linear_out = Linear(dhidden, dout, rngs=rngs) + self.linear_out = Block(dhidden, dout, rngs=rngs) def __call__(self, x): self.count.value += 1 @@ -79,18 +88,14 @@ def main(argv): print(f'{mode=}, {total_steps=}, {batch_size=}, {width=}') - if mode not in ['nnx', 'jax']: - raise ValueError(f'Invalid mode: {mode}') - X = np.linspace(0, 1, 100)[:, None] Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) - model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) - tx = optax.sgd(1e-3) - optimizer = nnx.Optimizer(model, tx) - t0 = time() - - if mode == 'nnx': + if mode == 'nnx' or mode == 'all': + model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) + tx = optax.sgd(1e-3) + optimizer = nnx.Optimizer(model, tx) + t0 = time() @nnx.jit def train_step_nnx(model: MLP, optimizer: nnx.Optimizer, batch): @@ -115,11 +120,22 @@ def test_step_nnx(model: MLP, batch): if step % 1000 == 0: logs = test_step_nnx(model, (X, Y)) - print(f"step: {step}, loss: {logs['loss']}") if step >= total_steps - 1: break - else: + + print('### NNX ###') + print(f"final loss: {logs['loss']}") + total_time = time() - t0 + print('total time:', total_time) + print(f'time per step: {total_time / total_steps * 1e6:.2f} µs') + print('times called:', model.count.value) + + if mode == 'jax' or mode == 'all': + model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) + tx = optax.sgd(1e-3) + optimizer = nnx.Optimizer(model, tx) + t0 = time() @jax.jit def train_step_jax(graphdef, state, batch): @@ -151,17 +167,18 @@ def test_step_jax(graphdef, state, batch): if step % 1000 == 0: state, logs = test_step_jax(graphdef, state, (X, Y)) - print(f"step: {step}, loss: {logs['loss']}") if step >= total_steps - 1: break model, optimizer = nnx.merge(graphdef, state) - total_time = time() - t0 - print('total time:', total_time) - print(f'time per step: {total_time / total_steps * 1e6:.2f} µs') - print('times called:', model.count.value) + print('### JAX ###') + print(f"final loss: {logs['loss']}") + total_time = time() - t0 + print('total time:', total_time) + print(f'time per step: {total_time / total_steps * 1e6:.2f} µs') + print('times called:', model.count.value) if __name__ == '__main__': diff --git a/flax/nnx/extract.py b/flax/nnx/extract.py index 191a0c195a..e5662e104a 100644 --- a/flax/nnx/extract.py +++ b/flax/nnx/extract.py @@ -254,7 +254,7 @@ class GraphDefState(struct.PyTreeNode): class NodeStates(struct.PyTreeNode): _graphdef: graph.GraphDef[tp.Any] | None - states: tuple[graph.GraphState, ...] + states: tuple[graph.GraphState | graph.GraphFlatState, ...] metadata: tp.Any = struct.field(pytree_node=False) @property @@ -264,7 +264,7 @@ def graphdef(self) -> graph.GraphDef[tp.Any]: return self._graphdef @property - def state(self) -> graph.GraphState: + def state(self) -> graph.GraphState | graph.GraphFlatState: if len(self.states) != 1: raise ValueError( f'Expected exactly one GraphDefState, got {len(self.states)}' @@ -275,15 +275,19 @@ def state(self) -> graph.GraphState: def from_split( cls, graphdef: graph.GraphDef[tp.Any], - state: graph.GraphState, + state: graph.GraphState | graph.GraphFlatState, /, - *states: graph.GraphState, + *states: graph.GraphState | graph.GraphFlatState, metadata: tp.Any = None, ): return cls(_graphdef=graphdef, states=(state, *states), metadata=metadata) @classmethod - def from_states(cls, state: graph.GraphState, *states: graph.GraphState): + def from_states( + cls, + state: graph.GraphState | graph.GraphFlatState, + *states: graph.GraphState | graph.GraphFlatState, + ): return cls(_graphdef=None, states=(state, *states), metadata=None) @classmethod diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index a29999d34f..c18a710b30 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -30,7 +30,7 @@ CallableProxy, DelayedAccessor, ) -from flax.nnx.statelib import State +from flax.nnx.statelib import FlatState, State from flax.nnx import variablelib from flax.nnx.variablelib import Variable, VariableState from flax.typing import Key, PathParts, is_key_like @@ -53,6 +53,7 @@ StateLeaf = VariableState[tp.Any] NodeLeaf = Variable[tp.Any] GraphState = State[Key, StateLeaf] +GraphFlatState = FlatState[StateLeaf] def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]: @@ -377,7 +378,9 @@ def _apply( module = merge(self, state, *states) fn = accessor(module) out = fn(*args, **kwargs) - return out, flatten(module) + graphdef, flat_state = flatten(module) + state_ = State.from_flat_path(flat_state) + return out, (graphdef, state_) return CallableProxy(_apply, accessor) # type: ignore @@ -389,7 +392,7 @@ def _apply( def flatten( node: Node, /, ref_index: RefMap[tp.Any, Index] | None = None -) -> tuple[GraphDef[Node], GraphState]: +) -> tuple[GraphDef[Node], FlatState[tp.Any]]: """Flattens a graph node into a (graphdef, state) pair. Args: @@ -402,7 +405,7 @@ def flatten( ref_index = RefMap() flat_state: list[tuple[PathParts, StateLeaf]] = [] graphdef = _graph_flatten((), ref_index, flat_state, node) - return graphdef, GraphState.from_flat_path(flat_state) + return graphdef, FlatState(flat_state) def _graph_flatten( @@ -811,8 +814,11 @@ def split( ctx = ( current_update_context(self.ctxtag) if self.ctxtag is not None else None ) - graphdef, state = flatten(node, self.ref_index) - states = _split_state(state, filters) + graphdef, flat_state = flatten(node, self.ref_index) + flat_states = _split_state(flat_state, filters) + states = tuple( + State.from_flat_path(flat_state) for flat_state in flat_states + ) if ctx is not None: if ctx.index_ref is not None and isinstance(graphdef, NodeDef): index_to_index = compose_mapping(ctx.index_ref, self.ref_index) @@ -822,6 +828,47 @@ def split( return graphdef, *states + @tp.overload + def flatten( + self, graph_node: A, / + ) -> tuple[GraphDef[A], FlatState[VariableState[tp.Any]]]: ... + @tp.overload + def flatten( + self, graph_node: A, first: filterlib.Filter, / + ) -> tuple[GraphDef[A], FlatState[VariableState[tp.Any]]]: ... + @tp.overload + def flatten( + self, + graph_node: A, + first: filterlib.Filter, + second: filterlib.Filter, + /, + *filters: filterlib.Filter, + ) -> tuple[ + GraphDef[A], + FlatState[VariableState[tp.Any]], + tpe.Unpack[tuple[FlatState[VariableState[tp.Any]], ...]], + ]: ... + def flatten( + self, node: A, *filters: filterlib.Filter + ) -> tuple[ + GraphDef[A], tpe.Unpack[tuple[FlatState[VariableState[tp.Any]], ...]] + ]: + ctx = ( + current_update_context(self.ctxtag) if self.ctxtag is not None else None + ) + graphdef, flat_state = flatten(node, self.ref_index) + flat_states = _split_state(flat_state, filters) + + if ctx is not None: + if ctx.index_ref is not None and isinstance(graphdef, NodeDef): + index_to_index = compose_mapping(ctx.index_ref, self.ref_index) + graphdef = dataclasses.replace( + graphdef, index_mapping=HashableMapping(index_to_index, copy=False) + ) + + return graphdef, *flat_states + @contextlib.contextmanager def split_context(ctxtag: str | None = None): @@ -874,6 +921,39 @@ def merge( ) return node + def unflatten( + self, + graphdef: GraphDef[A], + flat_state: GraphFlatState, + /, + *flat_states: GraphFlatState, + ) -> A: + ctx = ( + current_update_context(self.ctxtag) if self.ctxtag is not None else None + ) + if ( + ctx is not None + and isinstance(graphdef, NodeDef) + and graphdef.index_mapping is not None + ): + # outer merge (4), create index_ref_cache + assert ctx.ref_index is not None + index_ref_cache = compose_mapping_reversed( + ctx.ref_index, graphdef.index_mapping + ) + else: + # inner merge (2) + index_ref_cache = None + + state = FlatState.merge(flat_state, *flat_states).to_nested_state() + node = unflatten( + graphdef, + state, + index_ref=self.index_ref, + index_ref_cache=index_ref_cache, + ) + return node + @contextlib.contextmanager def merge_context(ctxtag: str | None = None): @@ -1001,9 +1081,11 @@ def split( filters are passed, a single :class:`State` is returned. """ ref_index: RefMap[tp.Any, Index] = RefMap() - graphdef, state = flatten(node, ref_index) - states = _split_state(state, filters) - + graphdef, flat_state = flatten(node, ref_index) + states = tuple( + State.from_flat_path(flat_state) + for flat_state in _split_state(flat_state, filters) + ) if self.index_ref is not None and isinstance(graphdef, NodeDef): index_to_index = compose_mapping(self.index_ref, ref_index) graphdef = dataclasses.replace( @@ -1195,13 +1277,13 @@ def current_update_context(tag: str) -> UpdateContext: # -------------------------------------------------------- def _split_state( - state: GraphState, + state: FlatState[tp.Any], filters: tuple[filterlib.Filter, ...], -) -> tuple[GraphState, tpe.Unpack[tuple[GraphState, ...]]]: +) -> tuple[FlatState[tp.Any], tpe.Unpack[tuple[FlatState[tp.Any], ...]]]: if not filters: return (state,) states = state.split(*filters) - if isinstance(states, State): + if not isinstance(states, tuple): return (states,) assert len(states) > 0 return states # type: ignore[return-value] @@ -1292,9 +1374,11 @@ def split( ``GraphDef`` and one or more ``States`` equal to the number of filters passed. If no filters are passed, a single ``State`` is returned. """ - graphdef, state = flatten(node) - states = _split_state(state, filters) - return graphdef, *states + graphdef, flat_state = flatten(node) + flat_states = _split_state(flat_state, filters) + states = tuple(State.from_flat_path(flat_state) for flat_state in flat_states) + return graphdef, *states # type: ignore[return-value] + def merge( graphdef: GraphDef[A], @@ -1486,6 +1570,7 @@ def state( One or more :class:`State` mappings. """ _, state = flatten(node) + state = state.to_nested_state() states: GraphState | tuple[GraphState, ...] if len(filters) == 0: diff --git a/flax/nnx/reprlib.py b/flax/nnx/reprlib.py index 6ed7660cdf..9a36c38651 100644 --- a/flax/nnx/reprlib.py +++ b/flax/nnx/reprlib.py @@ -111,6 +111,14 @@ def __nnx_repr__(self): for key, value in self.items(): yield Attr(repr(key), value) +class SequenceReprMixin(tp.Sequence[A], Representable): + def __nnx_repr__(self): + yield Object(type='', value_sep='', start='[', end=']') + + for value in self: + yield Attr('', value) + + @dataclasses.dataclass(repr=False) class PrettyMapping(Representable): mapping: tp.Mapping diff --git a/flax/nnx/statelib.py b/flax/nnx/statelib.py index 42a2604042..1c1b1b5127 100644 --- a/flax/nnx/statelib.py +++ b/flax/nnx/statelib.py @@ -54,7 +54,7 @@ def __treescope_repr__(self, path, subtree_renderer): # Render as the dictionary itself at the same path. return subtree_renderer(children, path=path) -class FlatState(tp.Sequence[tuple[PathParts, V]], reprlib.PrettySequence): +class FlatState(reprlib.SequenceReprMixin[tuple[PathParts, V]]): _keys: tuple[PathParts, ...] _values: list[V] @@ -83,6 +83,85 @@ def __len__(self) -> int: def __iter__(self) -> tp.Iterator[tuple[PathParts, V]]: return iter(zip(self._keys, self._values)) + def to_nested_state(self) -> State[PathParts, V]: + return State.from_flat_path(self) + + @tp.overload + def split(self, first: filterlib.Filter, /) -> FlatState[V]: ... + + @tp.overload + def split( + self, + first: filterlib.Filter, + second: filterlib.Filter, + /, + *filters: filterlib.Filter, + ) -> tuple[FlatState[V], ...]: ... + + @tp.overload + def split( + self, /, *filters: filterlib.Filter + ) -> tp.Union[FlatState[V], tuple[FlatState[V], ...]]: ... + + def split( # type: ignore[misc] + self, first: filterlib.Filter, /, *filters: filterlib.Filter + ) -> tp.Union[FlatState[V], tuple[FlatState[V], ...]]: + filters = (first, *filters) + *flat_states_, rest = _split_state(self, *filters) + + if rest: + raise ValueError( + 'Non-exhaustive filters, got a non-empty remainder: ' + f'{rest}.\nUse `...` to match all remaining elements.' + ) + + flat_states: FlatState[V] | tuple[FlatState[V], ...] + if len(flat_states_) == 1: + flat_states = flat_states_[0] + else: + flat_states = tuple(flat_states_) + return flat_states # type: ignore + + @tp.overload + def filter(self, first: filterlib.Filter, /) -> FlatState[V]: ... + + @tp.overload + def filter( + self, + first: filterlib.Filter, + second: filterlib.Filter, + /, + *filters: filterlib.Filter, + ) -> tuple[FlatState[V], ...]: ... + + def filter( + self, + first: filterlib.Filter, + /, + *filters: filterlib.Filter, + ) -> tp.Union[FlatState[V], tuple[FlatState[V], ...]]: + *flat_states_, _rest = _split_state(self, first, *filters) + + assert len(flat_states_) == len(filters) + 1 + + flat_states: FlatState[V] | tuple[FlatState[V], ...] + if len(flat_states_) == 1: + flat_states = flat_states_[0] + else: + flat_states = tuple(flat_states_) + + return flat_states # type: ignore + + @staticmethod + def merge( + flat_state: tp.Iterable[tuple[PathParts, V]], + /, + *flat_states: tp.Iterable[tuple[PathParts, V]], + ) -> FlatState[V]: + flat_states = (flat_state, *flat_states) + + return FlatState(elem for flat_state in flat_states for elem in flat_state) + def _flat_state_pytree_flatten(x: FlatState[V]): return x._values, x._keys @@ -291,7 +370,8 @@ def split( # type: ignore[misc] One or more ``States`` equal to the number of filters passed. """ filters = (first, *filters) - *states_, rest = _split_state(self.flat_state(), *filters) + flat_states = _split_state(self.flat_state(), *filters) + *states_, rest = (state.to_nested_state() for state in flat_states) if rest: raise ValueError( @@ -356,7 +436,8 @@ def filter( Returns: One or more ``States`` equal to the number of filters passed. """ - *states_, _rest = _split_state(self.flat_state(), first, *filters) + flat_states = _split_state(self.flat_state(), first, *filters) + *states_, _rest = (state.to_nested_state() for state in flat_states) assert len(states_) == len(filters) + 1 @@ -456,7 +537,7 @@ def _state_unflatten( def _split_state( flat_state: FlatState[V], *filters: filterlib.Filter, -) -> tuple[State[PathParts, V], ...]: +) -> tuple[FlatState[V], ...]: for i, filter_ in enumerate(filters): if filter_ in (..., True) and i != len(filters) - 1: remaining_filters = filters[i + 1 :] @@ -482,7 +563,7 @@ def _split_state( # if we didn't break, set leaf to last state flat_states[-1].append((path, value)) # type: ignore[index] # mypy is wrong here? - return tuple(State.from_flat_path(flat_state) for flat_state in flat_states) + return tuple(FlatState(flat_state) for flat_state in flat_states) def create_path_filters(state: State): diff --git a/flax/nnx/transforms/compilation.py b/flax/nnx/transforms/compilation.py index e5ce20f8e3..d3420dd438 100644 --- a/flax/nnx/transforms/compilation.py +++ b/flax/nnx/transforms/compilation.py @@ -91,9 +91,15 @@ def __hash__(self): def _jit_split_fn(ctx: graph.SplitContext, path, prefix, x): if isinstance(prefix, StateSharding): return extract.NodeStates.from_split( - *ctx.split(x, *prefix.filters), metadata=prefix + *ctx.flatten(x, *prefix.filters), metadata=prefix ) - return extract.NodeStates.from_split(*ctx.split(x)) + return extract.NodeStates.from_split(*ctx.flatten(x)) + + +def _jit_merge_fn(ctx: graph.MergeContext, path, prefix, leaf) -> tp.Any: + if not isinstance(leaf, extract.NodeStates): + raise ValueError(f'Expected TreeNode, got {type(leaf)} at path {path}') + return ctx.unflatten(leaf.graphdef, *leaf.states) # type: ignore @dataclasses.dataclass(eq=False) @@ -107,7 +113,9 @@ def __post_init__(self): functools.update_wrapper(self, self.f) def __call__(self, *pure_args, **pure_kwargs): - args, kwargs = extract.from_tree((pure_args, pure_kwargs), ctxtag='jit') + args, kwargs = extract.from_tree( + (pure_args, pure_kwargs), merge_fn=_jit_merge_fn, ctxtag='jit' + ) out = self.f(*args, **kwargs) @@ -346,7 +354,9 @@ def jit_wrapper(*args, **kwargs): *pure_args, **pure_kwargs ) _args_out, _kwargs_out, out = extract.from_tree( - (pure_args_out, pure_kwargs_out, pure_out), ctxtag='jit' + (pure_args_out, pure_kwargs_out, pure_out), + merge_fn=_jit_merge_fn, + ctxtag='jit', ) return out diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py index 4752a9b7bd..fb3a276e9a 100644 --- a/flax/nnx/variablelib.py +++ b/flax/nnx/variablelib.py @@ -808,6 +808,7 @@ def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): if 'on_remove_axis' in self._var_metadata: self._var_metadata['on_remove_axis'](self, axis_index, axis_name) +GraphVariableState = VariableState[VariableState[tp.Any]] def _variable_state_flatten(x: VariableState[tp.Any], *, with_keys: bool): metadata = tuple(x.get_metadata().items()) diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index a7bbf178cb..85c4f2a4c8 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -64,7 +64,8 @@ def test_flatten(self): g = [a, 3, a, nnx.Param(4)] refmap = nnx.graph.RefMap() - graphdef, state = nnx.graph.flatten(g, ref_index=refmap) + graphdef, flat_state = nnx.graph.flatten(g, ref_index=refmap) + state = flat_state.to_nested_state() state[0]['b'].raw_value = 2 state[3].raw_value = 4 @@ -329,6 +330,7 @@ def f(m: Foo): ref_out_idx_out = nnx.graph.RefMap() graphdef: nnx.graph.GraphDef[Foo] graphdef, state = nnx.graph.flatten(m, ref_index=ref_out_idx_out) + state = state.to_nested_state() @partial(jax.jit, static_argnums=(0,)) def f_pure(graphdef: nnx.graph.GraphDef[Foo], state): @@ -337,6 +339,7 @@ def f_pure(graphdef: nnx.graph.GraphDef[Foo], state): f(m) ref_in_idx_in = nnx.graph.RefMap[Any, int]() graphdef, state = nnx.graph.flatten(m, ref_index=ref_in_idx_in) + state = state.to_nested_state() idx_out_idx_in = nnx.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in) static_out = nnx.graph.Static((graphdef, idx_out_idx_in)) return state, static_out @@ -369,6 +372,7 @@ def f(m: Foo): ref_out_idx_out = nnx.graph.RefMap[Any, int]() graphdef: nnx.graph.GraphDef[Foo] graphdef, state = nnx.graph.flatten(m, ref_index=ref_out_idx_out) + state = state.to_nested_state() @partial(jax.jit, static_argnums=(0,)) def f_pure(graphdef: nnx.graph.GraphDef[Foo], state): @@ -377,6 +381,7 @@ def f_pure(graphdef: nnx.graph.GraphDef[Foo], state): f(m) ref_in_idx_in = nnx.graph.RefMap[Any, int]() graphdef, state = nnx.graph.flatten(m, ref_index=ref_in_idx_in) + state = state.to_nested_state() idx_out_idx_in = nnx.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in) static_out = nnx.graph.Static((graphdef, idx_out_idx_in)) return state, static_out @@ -406,6 +411,7 @@ def f(m: Foo): ref_out_idx_out = nnx.graph.RefMap() graphdef: nnx.graph.GraphDef[Foo] graphdef, state = nnx.graph.flatten(m, ref_index=ref_out_idx_out) + state = state.to_nested_state() @partial(jax.jit, static_argnums=(0,)) def f_pure(graphdef: nnx.graph.GraphDef[Foo], state): @@ -414,6 +420,7 @@ def f_pure(graphdef: nnx.graph.GraphDef[Foo], state): f(m) ref_in_idx_in = nnx.graph.RefMap[Any, int]() graphdef, state = nnx.graph.flatten(m, ref_index=ref_in_idx_in) + state = state.to_nested_state() idx_out_idx_in = nnx.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in) static_out = nnx.graph.Static((graphdef, idx_out_idx_in)) return state, static_out diff --git a/uv.lock b/uv.lock index e08e2dbf53..fb61c0e0e7 100644 --- a/uv.lock +++ b/uv.lock @@ -3,13 +3,13 @@ requires-python = ">=3.10" resolution-markers = [ "python_full_version < '3.11' and platform_system == 'Darwin'", "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", "python_full_version == '3.11.*' and platform_system == 'Darwin'", "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", "python_full_version >= '3.12' and platform_system == 'Darwin'", "python_full_version >= '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", ] [[package]] @@ -641,7 +641,7 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version < '3.11' and platform_system == 'Darwin'", "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", ] sdist = { url = "https://files.pythonhosted.org/packages/99/bc/cfb52b9e8531526604afe8666185d207e4f0cb9c6d90bc76f62fb8746804/etils-1.7.0.tar.gz", hash = "sha256:97b68fd25e185683215286ef3a54e38199b6245f5fe8be6bedc1189be4256350", size = 95695 } wheels = [ @@ -676,10 +676,10 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version == '3.11.*' and platform_system == 'Darwin'", "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", "python_full_version >= '3.12' and platform_system == 'Darwin'", "python_full_version >= '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", ] sdist = { url = "https://files.pythonhosted.org/packages/ba/49/d480aeb4fc441d933acce97261bea002234a45fb847599c9a93c31e51b2e/etils-1.9.2.tar.gz", hash = "sha256:15dcd35ac0c0cc2404b46ac0846af3cc4e876fd3d80f36f57951e27e8b9d6379", size = 101506 } wheels = [ @@ -1202,7 +1202,7 @@ name = "ipython" version = "8.26.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, { name = "decorator" }, { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, { name = "jedi" }, @@ -1246,7 +1246,7 @@ wheels = [ [[package]] name = "jax" -version = "0.4.37" +version = "0.4.38" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxlib" }, @@ -1255,14 +1255,14 @@ dependencies = [ { name = "opt-einsum" }, { name = "scipy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/50/30/ad7617a960c86782587540a179cef676962322d1e5411415b1aa24f02ce0/jax-0.4.37.tar.gz", hash = "sha256:7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b", size = 1915966 } +sdist = { url = "https://files.pythonhosted.org/packages/fb/e5/c4aa9644bb96b7f6747bd7c9f8cda7665ca5e194fa2542b2dea3ff730701/jax-0.4.38.tar.gz", hash = "sha256:43bae65881628319e0a2148e8f81a202fbc2b8d048e35c7cb1df2416672fa4a8", size = 1930034 } wheels = [ - { url = "https://files.pythonhosted.org/packages/5f/3f/6c5553baaa7faa3fa8bae8279b1e46cb54c7ce52360139eae53498786ea5/jax-0.4.37-py3-none-any.whl", hash = "sha256:bdc0686d7e5a944e2d38026eae632214d98dd2d91869cbcedbf1c11298ae3e3e", size = 2221192 }, + { url = "https://files.pythonhosted.org/packages/22/49/b4418a7a892c0dd64442bbbeef54e1cdfe722dfc5a7bf0d611d3f5f90e99/jax-0.4.38-py3-none-any.whl", hash = "sha256:78987306f7041ea8500d99df1a17c33ed92620c2268c4c3677fb24e06712be64", size = 2236864 }, ] [[package]] name = "jaxlib" -version = "0.4.36" +version = "0.4.38" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "ml-dtypes" }, @@ -1270,26 +1270,26 @@ dependencies = [ { name = "scipy" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/23/8d/8a44618f3493f29d769b2b40778d24075689cc8697b98e2c43bafbe50edf/jaxlib-0.4.36-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:d69f991833b6dca794767049843462805936c89553b136a8ebb8485334204457", size = 98648230 }, - { url = "https://files.pythonhosted.org/packages/78/b8/207485eab566dcfbc29bb833714ac1ca47a1665ca605b1ff7d3d5dd2afbe/jaxlib-0.4.36-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:807814c1ba3ec69cffaa93d3f90651c694a9b8a750b43832cc167ed590c821dd", size = 78553787 }, - { url = "https://files.pythonhosted.org/packages/26/42/3c2b0dc86a17aafd8f46ba0e4388f39f55706ee25f6c463c3dadea7a71e2/jaxlib-0.4.36-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:1bc27d9ae09549d7652eafe1fdb10c21546cd2fd02bb24a49a7e6208b69163b0", size = 84008742 }, - { url = "https://files.pythonhosted.org/packages/b9/b2/29be712098342df10075fe085c0b39d783a579bd3325fb0d69c22712cf27/jaxlib-0.4.36-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:3379f03a794d6a30b75765d2786f6e31052f364196fcd49aaae292a3c16f12ec", size = 100263041 }, - { url = "https://files.pythonhosted.org/packages/63/a9/93404a2f1d59647749d4d6dbab7bee9f5a7bfaeb9ade25b7e66c0ca0949a/jaxlib-0.4.36-cp310-cp310-win_amd64.whl", hash = "sha256:63e575ac8a515dee8171dd4a88c460d538bbcc9d959cabc9781e961763678f84", size = 63270658 }, - { url = "https://files.pythonhosted.org/packages/e4/7d/9394ff39af5c23bb98a241c33742a328df5a43c21d569855ea7e096aaf5e/jaxlib-0.4.36-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:213792db3b876206b45f6a9fbea15e4dd22a9e80be25b03136f20c94784fecfa", size = 98669744 }, - { url = "https://files.pythonhosted.org/packages/34/5a/9f3c9e5cec23e60f78bb3c3da108a5ef664601862dbc4e84fc4be3654f5d/jaxlib-0.4.36-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6d7a89adf4c9d3cddd20482931dedc7a9e2669e904196a9599d9a605b3d9e552", size = 78574312 }, - { url = "https://files.pythonhosted.org/packages/ff/5c/bf78ed9b8d0f174a562f6496049a4872e14a3bb3a80de09c4292d04be5f0/jaxlib-0.4.36-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:c395fe8cc5bd6558dd2fbce78e24172b6f27762e17628720ae03d693001283f3", size = 84038323 }, - { url = "https://files.pythonhosted.org/packages/67/af/6a9dd26e8a6bedd4c9fe702059767256b0d9ed18c29a180a4598d5795bb4/jaxlib-0.4.36-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:bc324c6b1c64fe68400934c653e4e622f12576120dcdb451c3b4ea4dcaba2ae9", size = 100285487 }, - { url = "https://files.pythonhosted.org/packages/b7/46/31c3a519a94e84c672ca264c4151998e3e3fd11c481d8fa5af5885b91a1e/jaxlib-0.4.36-cp311-cp311-win_amd64.whl", hash = "sha256:c9e0c45a79e63aea65447f82bd0fa21c17b9afe884aa18dd5362b9965abe9d72", size = 63308064 }, - { url = "https://files.pythonhosted.org/packages/e3/0e/3b4a99c09431ee5820624d4dcf4efa7becd3c83b56ff0f09a078f4c421a2/jaxlib-0.4.36-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:5972aa85f6d771ecc8cc72148c1fa64250ca33cbdf2bf24407cdee8a5299d25d", size = 98718357 }, - { url = "https://files.pythonhosted.org/packages/d3/46/05e70a1236ec3782333b3e9469f971c9d45af2aa0aebf602acd9d76292eb/jaxlib-0.4.36-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5597908cd10418c0b42e9af807fc8112036703533cf501a5255a8fbf4011867e", size = 78596060 }, - { url = "https://files.pythonhosted.org/packages/8e/76/6b969cbf197b8c53c84c2642069722e84a3a260af084a8acbbf90ca444ea/jaxlib-0.4.36-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:fbbabaa287378a78a3cf9cbe4de30a1f6f19a99116feb4bd687ff256415cd442", size = 84053202 }, - { url = "https://files.pythonhosted.org/packages/fe/f2/7624a304426daa7b135b85caf1b8eccf879e7cb10bc074656ce628309cb0/jaxlib-0.4.36-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:be295abc209c980817db0488f21f1fbc0644f87326522895e2b9b64729106357", size = 100325610 }, - { url = "https://files.pythonhosted.org/packages/bb/8b/ded8420cd9198eb677869ffd557d9880af5833c7bf39e604e80b56550e09/jaxlib-0.4.36-cp312-cp312-win_amd64.whl", hash = "sha256:d4bbb5d2970628dcd3dabc28a5b97a1125ad3e06a1be822d340fd9f06f7449b3", size = 63338518 }, - { url = "https://files.pythonhosted.org/packages/5d/22/b72811c61e8b594951d3ee03245cb0932c723ac35e75569005c3c976eec2/jaxlib-0.4.36-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:02df9c0e1323dde01e966c22eb12432905d2d4de8aac7b603cad2083101b0e6b", size = 98719384 }, - { url = "https://files.pythonhosted.org/packages/f1/66/3f4a97097983914899100db9e5312493fe1d6adc924e47a0e47e15c553f5/jaxlib-0.4.36-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:16ec980e85983f41999c4dc84137dec70507d958e23d7eefa104da93053d135f", size = 78596150 }, - { url = "https://files.pythonhosted.org/packages/3a/6f/cf02f56d1532962d8ca77a6548acab8204294b96b5a153ca4a2caf4971fc/jaxlib-0.4.36-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:7ce9368515348d869d6c59d9904c3cb3c81f22ff3e9e969eae0e3563fe472080", size = 84055851 }, - { url = "https://files.pythonhosted.org/packages/28/10/4fc4e9719c065c6455491730011e87fe4b5120a9a008161cc32663feb9ce/jaxlib-0.4.36-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:93f1c502d08e517f842fe7b18428bb086cfd077db0ea9a2418fb21e5b4e06d3d", size = 100325986 }, - { url = "https://files.pythonhosted.org/packages/ba/28/fece5385e736ef2f1b5bed133f8001f0fc66dd0104707381343e047b341a/jaxlib-0.4.36-cp313-cp313-win_amd64.whl", hash = "sha256:bddf436a243e83ec6bc16bcbb74d15b1960a69318c9ea796fb2109492bc52575", size = 63338694 }, + { url = "https://files.pythonhosted.org/packages/ee/d4/e6a0881a88b8f17491c2ee271fd77c348b0221d9e2ec92dad23a2c9e41bc/jaxlib-0.4.38-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:55c19b9d3f33a6fc59f644aa5a21fba02639ccdd776cb4a9b5526625f57839ff", size = 99663603 }, + { url = "https://files.pythonhosted.org/packages/b6/6d/11569ce873f04c82ec22e58d822f4187dccae1d400c0d6dd05ed314d5328/jaxlib-0.4.38-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:30b2f52cb50d74734af2f477c2533a7a583e3bb7b2c8acdeb361ee77d940577a", size = 79475708 }, + { url = "https://files.pythonhosted.org/packages/72/61/1de2405d13089c83b1ad87ec0266479c9d00080659dae2474892ae356306/jaxlib-0.4.38-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:ee19c163a8fdf0839d4c18b88a5fbfb4e731ba7c437416d3e5483e570bb764e4", size = 93219045 }, + { url = "https://files.pythonhosted.org/packages/9c/24/0829decf233c6af9efe7c53888ae8ac72395e0979869cd9cee487e35dac3/jaxlib-0.4.38-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:61aeccb9a27c67fdb8450f6357240019cd4511cb9d62a44e4764756d384853ad", size = 101732107 }, + { url = "https://files.pythonhosted.org/packages/0d/04/120c4caac6151f7297fedf9dd776362aa2d417d3f87bda826050b4da45e8/jaxlib-0.4.38-cp310-cp310-win_amd64.whl", hash = "sha256:d6ab745a89d0fb737a36fe1d8b86659e3fffe6ee8303b20651b26193d5edc0ef", size = 64223924 }, + { url = "https://files.pythonhosted.org/packages/b0/6a/b9fba73eb5e758e40a514919e096a039d27dc0ab4776a6cc977f5153a55f/jaxlib-0.4.38-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:b67fdeabd6dfed08b7768f3bdffb521160085f8305669bd197beef61d08de08b", size = 99679916 }, + { url = "https://files.pythonhosted.org/packages/44/2a/3458130d44d44038fd6974e7c43948f68408f685063203b82229b9b72c1a/jaxlib-0.4.38-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3fb0eaae7369157afecbead50aaf29e73ffddfa77a2335d721bd9794f3c510e4", size = 79488377 }, + { url = "https://files.pythonhosted.org/packages/94/96/7d9a0b9f35af4727df44b68ade4c6f15163840727d1cb47251b1ea515e30/jaxlib-0.4.38-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:43db58c4c427627296366a56c10318e1f00f503690e17f94bb4344293e1995e0", size = 93241543 }, + { url = "https://files.pythonhosted.org/packages/a3/2d/68f85037e60c981b37b18b23ace458c677199dea4722ddce541b48ddfc63/jaxlib-0.4.38-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:2751ff7037d6a997d0be0e77cc4be381c5a9f9bb8b314edb755c13a6fd969f45", size = 101751923 }, + { url = "https://files.pythonhosted.org/packages/cc/24/a9c571c8a189f58e0b54b14d53fc7f5a0a06e4f1d7ab9edcf8d1d91d07e7/jaxlib-0.4.38-cp311-cp311-win_amd64.whl", hash = "sha256:35226968fc9de6873d1571670eac4117f5ed80e955f7a1775204d1044abe16c6", size = 64255189 }, + { url = "https://files.pythonhosted.org/packages/49/df/08b94c593c0867c7eaa334592807ba74495de4be90580f360db8b96221dc/jaxlib-0.4.38-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:3fefea985f0415816f3bbafd3f03a437050275ef9bac9a72c1314e1644ac57c1", size = 99737849 }, + { url = "https://files.pythonhosted.org/packages/ab/b1/c9d2a7ba9ebeabb7ac37082f4c466364f475dc7550a79358c0f0aa89fdf2/jaxlib-0.4.38-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f33bcafe32c97a562ecf6894d7c41674c80c0acdedfa5423d49af51147149874", size = 79509242 }, + { url = "https://files.pythonhosted.org/packages/53/25/dd670d8bdf3799ece76d12cfe6a6a250ea256057aa4b0fcace4753a99d2d/jaxlib-0.4.38-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:496f45b0e001a2341309cd0c74af0b670537dced79c168cb230cfcc773f0aa86", size = 93251503 }, + { url = "https://files.pythonhosted.org/packages/f9/cc/37fce5162f6b9070203fd76cc0f298d9b3bfdf01939a78935a6078d63621/jaxlib-0.4.38-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:dad6c0a96567c06d083c0469fec40f201210b099365bd698be31a6d2ec88fd59", size = 101792792 }, + { url = "https://files.pythonhosted.org/packages/6f/7a/8515950a60a4ea5b13cc98fc0a42e36553b2db5a6eedc00d3bd7836f77b5/jaxlib-0.4.38-cp312-cp312-win_amd64.whl", hash = "sha256:966cdec36cfa978f5b4582bcb4147fe511725b94c1a752dac3a5f52ce46b6fa3", size = 64288223 }, + { url = "https://files.pythonhosted.org/packages/91/03/aee503c7077c6dbbd568842303426c6ec1cef9bff330c418c9e71906cccd/jaxlib-0.4.38-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:41e55ae5818a882e5789e848f6f16687ac132bcfbb5a5fa114a5d18b78d05f2d", size = 99739026 }, + { url = "https://files.pythonhosted.org/packages/cb/bf/fbbf61da319611d88e11c691d5a2077039208ded05e1731dea940f824a59/jaxlib-0.4.38-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:6fe326b8af366387dd47ccf312583b2b17fed12712c9b74a648b18a13cbdbabf", size = 79508735 }, + { url = "https://files.pythonhosted.org/packages/e4/0b/8cbff0b6d62a4694351c49baf53b7ed8deb8a6854d129408c38158e11676/jaxlib-0.4.38-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:248cca3771ebf24b070f49701364ceada33e6139445b06c782cca5ac5ad92bf4", size = 93251882 }, + { url = "https://files.pythonhosted.org/packages/15/57/7f0283273b69c417071bcd2f4c2ed076479ec5ffc22a647f13c21da8d071/jaxlib-0.4.38-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:2ce77ba8cda9259a4bca97afc1c722e4291a6c463a63f8d372c6edc85117d625", size = 101791137 }, + { url = "https://files.pythonhosted.org/packages/de/de/d6c4d234cd426b97459cb070af90792b48643967a0d28641379ee9e10fc9/jaxlib-0.4.38-cp313-cp313-win_amd64.whl", hash = "sha256:4103db0b3a38a5dc132741237453c24d8547290a22079ba1b577d6c88c95300a", size = 64288459 }, ] [[package]] @@ -1431,7 +1431,7 @@ version = "5.7.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "platformdirs" }, - { name = "pywin32", marker = "platform_python_implementation != 'PyPy' and sys_platform == 'win32'" }, + { name = "pywin32", marker = "(platform_machine != 'aarch64' and platform_python_implementation != 'PyPy' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_python_implementation != 'PyPy' and platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, { name = "traitlets" }, ] sdist = { url = "https://files.pythonhosted.org/packages/00/11/b56381fa6c3f4cc5d2cf54a7dbf98ad9aa0b339ef7a601d6053538b079a7/jupyter_core-5.7.2.tar.gz", hash = "sha256:aa5f8d32bbf6b431ac830496da7392035d6f61b4f54872f15c4bd2a9c3f536d9", size = 87629 } @@ -2095,7 +2095,7 @@ name = "nvidia-cudnn-cu12" version = "9.1.0.70" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 }, @@ -2122,9 +2122,9 @@ name = "nvidia-cusolver-cu12" version = "11.4.5.107" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, - { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, - { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 }, @@ -2135,7 +2135,7 @@ name = "nvidia-cusparse-cu12" version = "12.1.0.106" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 }, @@ -2262,7 +2262,7 @@ wheels = [ [[package]] name = "orbax-checkpoint" -version = "0.10.2" +version = "0.10.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "absl-py" }, @@ -2280,9 +2280,9 @@ dependencies = [ { name = "tensorstore" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d1/06/c42e2f1563dbaaf5ed1464d7b634324fb9a2da04021073c45777e61af78d/orbax_checkpoint-0.10.2.tar.gz", hash = "sha256:e575ebe1f94e5cb6353ab8c9df81de0ca7cddc118645c3bfc17b8344f19d42f1", size = 248170 } +sdist = { url = "https://files.pythonhosted.org/packages/87/fd/36b22046aecf155e50494fd7901ecd3e97e0db3ac103d3a0ffd0cafd2d9e/orbax_checkpoint-0.10.3.tar.gz", hash = "sha256:71e3ea47e38d571f27146ee55c8727d7e7c242cf3df31dc499f9b2cb1d67ac8a", size = 252556 } wheels = [ - { url = "https://files.pythonhosted.org/packages/61/19/ed366f8894923f3c8db0370e4bdd57ef843d68011dafa00d8175f4a66e1a/orbax_checkpoint-0.10.2-py3-none-any.whl", hash = "sha256:dcfc425674bd8d4934986143bd22a37cd634d034652c5d30d83c539ef8587941", size = 354306 }, + { url = "https://files.pythonhosted.org/packages/6d/45/12a80b3704ec7d46fb0f79d193f4a089aa4a8297a61e6db183d97d108a4b/orbax_checkpoint-0.10.3-py3-none-any.whl", hash = "sha256:df7fd5f327dfe9c477533f33c20076ae11ba6a15767c5117881b328dece14c7d", size = 359825 }, ] [[package]] @@ -2436,7 +2436,7 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version < '3.11' and platform_system == 'Darwin'", "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", ] sdist = { url = "https://files.pythonhosted.org/packages/55/5b/e3d951e34f8356e5feecacd12a8e3b258a1da6d9a03ad1770f28925f29bc/protobuf-3.20.3.tar.gz", hash = "sha256:2e3427429c9cffebf259491be0af70189607f365c2f41c7c3764af6f337105f2", size = 216768 } wheels = [ @@ -2454,10 +2454,10 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version == '3.11.*' and platform_system == 'Darwin'", "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", "python_full_version >= '3.12' and platform_system == 'Darwin'", "python_full_version >= '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", ] sdist = { url = "https://files.pythonhosted.org/packages/e8/ab/cb61a4b87b2e7e6c312dce33602bd5884797fd054e0e53205f1c27cf0f66/protobuf-4.25.4.tar.gz", hash = "sha256:0dc4a62cc4052a036ee2204d26fe4d835c62827c855c8a03f29fe6da146b380d", size = 380283 } wheels = [ @@ -2606,7 +2606,7 @@ name = "pytest" version = "8.3.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, { name = "iniconfig" }, { name = "packaging" }, @@ -3195,7 +3195,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "alabaster" }, { name = "babel" }, - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, { name = "docutils" }, { name = "imagesize" }, { name = "jinja2" }, @@ -3684,7 +3684,7 @@ name = "triton" version = "3.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "filelock", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "filelock", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/45/27/14cc3101409b9b4b9241d2ba7deaa93535a217a211c86c4cc7151fb12181/triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a", size = 209376304 },