diff --git a/docs_nnx/guides/gemma.ipynb b/docs_nnx/guides/gemma.ipynb
index 1ac5dd12f4..1c59c951df 100644
--- a/docs_nnx/guides/gemma.ipynb
+++ b/docs_nnx/guides/gemma.ipynb
@@ -4,22 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "Copyright 2024 The Flax Authors.\n",
- "\n",
- "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n",
- "\n",
- "http://www.apache.org/licenses/LICENSE-2.0\n",
- "\n",
- "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.\n",
- "\n",
- "---"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Getting Started with Gemma Sampling using NNX: A Step-by-Step Guide\n",
+ "# Example: Using Pretrained Gemma\n",
"\n",
"You will find in this colab a detailed tutorial explaining how to use NNX to load a Gemma checkpoint and sample from it."
]
@@ -33,7 +18,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -58,19 +43,22 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 2,
"metadata": {},
"outputs": [
{
- "ename": "ModuleNotFoundError",
- "evalue": "No module named 'kagglehub'",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
- "Cell \u001b[0;32mIn[1], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mkagglehub\u001b[39;00m\n\u001b[1;32m 2\u001b[0m kagglehub\u001b[38;5;241m.\u001b[39mlogin()\n",
- "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'kagglehub'"
- ]
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "2e7cf9f0345845f1a3edc72fa4411eb4",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "VBox(children=(HTML(value='
(()=>{ if (customElements.get('treescope-container') === undefined) { class TreescopeContainer extends HTMLElement { constructor() { super(); this.attachShadow({mode: \"open\"}); this.defns = {}; this.state = {}; } } customElements.define(\"treescope-container\", TreescopeContainer); } if (customElements.get('treescope-run-here') === undefined) { class RunHere extends HTMLElement { constructor() { super() } connectedCallback() { const run = child => { const fn = new Function(child.textContent); child.textContent = \"\"; fn.call(this); this.remove(); }; const child = this.querySelector(\"script\"); if (child) { run(child); } else { new MutationObserver(()=>{ run(this.querySelector(\"script\")); }).observe(this, {childList: true}); } } } customElements.define(\"treescope-run-here\", RunHere); } })();
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
"source": [
"transformer = transformer_lib.Transformer.from_params(params)\n",
"nnx.display(transformer)"
@@ -217,7 +255,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 9,
"metadata": {
"cellView": "form"
},
@@ -227,7 +265,6 @@
"sampler = sampler_lib.Sampler(\n",
" transformer=transformer,\n",
" vocab=vocab,\n",
- " params=params['transformer'],\n",
")"
]
},
@@ -240,16 +277,73 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 10,
"metadata": {
"cellView": "form"
},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Prompt:\n",
+ "\n",
+ "# Python program for implementation of Bubble Sort\n",
+ "\n",
+ "def bubbleSort(arr):\n",
+ "Output:\n",
+ "\n",
+ " for i in range(len(arr)):\n",
+ " for j in range(len(arr) - i - 1):\n",
+ " if arr[j] > arr[j + 1]:\n",
+ " swap(arr, j, j + 1)\n",
+ "\n",
+ "\n",
+ "def swap(arr, i, j):\n",
+ " temp = arr[i]\n",
+ " arr[i] = arr[j]\n",
+ " arr[j] = temp\n",
+ "\n",
+ "\n",
+ "# Driver code\n",
+ "arr = [5, 2, 8, 3, 1, 9]\n",
+ "print(\"Unsorted array:\")\n",
+ "print(arr)\n",
+ "bubbleSort(arr)\n",
+ "print(\"Sorted array:\")\n",
+ "print(arr)\n",
+ "\n",
+ "\n",
+ "# Time complexity of Bubble sort O(n^2)\n",
+ "# where n is the length of the array\n",
+ "\n",
+ "\n",
+ "# Space complexity of Bubble sort O(1)\n",
+ "# as it only requires constant extra space for the swap operation\n",
+ "\n",
+ "\n",
+ "# This program uses the bubble sort algorithm to sort the given array in ascending order.\n",
+ "\n",
+ "```python\n",
+ "# This program uses the bubble sort algorithm to sort the given array in ascending order.\n",
+ "\n",
+ "def bubbleSort(arr):\n",
+ " for i in range(len(arr)):\n",
+ " for j in range(len(arr) - i - 1):\n",
+ " if arr[j] > arr[j + 1]:\n",
+ " swap(arr, j, j + 1)\n",
+ "\n",
+ "\n",
+ "def swap(\n",
+ "\n",
+ "##########\n"
+ ]
+ }
+ ],
"source": [
"input_batch = [\n",
- " \"\\n# Python program for implementation of Bubble Sort\\n\\ndef bubbleSort(arr):\",\n",
- " \"What are the planets of the solar system?\",\n",
- " ]\n",
+ " \"\\n# Python program for implementation of Bubble Sort\\n\\ndef bubbleSort(arr):\",\n",
+ "]\n",
"\n",
"out_data = sampler(\n",
" input_strings=input_batch,\n",
@@ -266,7 +360,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "You should get an implementation of bubble sort and a description of the solar system."
+ "You should get an implementation of bubble sort."
]
}
],
diff --git a/docs_nnx/guides/gemma.md b/docs_nnx/guides/gemma.md
index 1fa07ea3b2..e479201af0 100644
--- a/docs_nnx/guides/gemma.md
+++ b/docs_nnx/guides/gemma.md
@@ -8,19 +8,7 @@ jupytext:
jupytext_version: 1.13.8
---
-Copyright 2024 The Flax Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
-
-http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
-
----
-
-+++
-
-# Getting Started with Gemma Sampling using NNX: A Step-by-Step Guide
+# Example: Using Pretrained Gemma
You will find in this colab a detailed tutorial explaining how to use NNX to load a Gemma checkpoint and sample from it.
@@ -57,10 +45,14 @@ Kaggle credentials successfully validated.
Now select and download the checkpoint you want to try. Note that you will need an A100 runtime for the 7b models.
```{code-cell} ipython3
+from IPython.display import clear_output
+
VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:"string"}
weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')
ckpt_path = f'{weights_dir}/{VARIANT}'
vocab_path = f'{weights_dir}/tokenizer.model'
+
+clear_output()
```
## Python imports
@@ -72,18 +64,19 @@ import sentencepiece as spm
Flax examples are not exposed as packages so you need to use the workaround in the next cells to import from NNX's Gemma example.
-```{code-cell} ipython3
-! git clone https://github.com/google/flax.git flax_examples
-```
-
```{code-cell} ipython3
import sys
-
-sys.path.append("./flax_examples/flax/nnx/examples/gemma")
-import params as params_lib
-import sampler as sampler_lib
-import transformer as transformer_lib
-sys.path.pop();
+import tempfile
+
+with tempfile.TemporaryDirectory() as tmp:
+ # Here we create a temporary directory and clone the flax repo
+ # Then we append the examples/gemma folder to the path to load the gemma modules
+ ! git clone https://github.com/google/flax.git {tmp}/flax
+ sys.path.append(f"{tmp}/flax/examples/gemma")
+ import params as params_lib
+ import sampler as sampler_lib
+ import transformer as transformer_lib
+ sys.path.pop();
```
## Start Generating with Your Model
@@ -122,7 +115,6 @@ Finally, build a sampler on top of your model and your tokenizer.
sampler = sampler_lib.Sampler(
transformer=transformer,
vocab=vocab,
- params=params['transformer'],
)
```
@@ -132,9 +124,8 @@ You're ready to start sampling ! This sampler uses just-in-time compilation, so
:cellView: form
input_batch = [
- "\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):",
- "What are the planets of the solar system?",
- ]
+ "\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):",
+]
out_data = sampler(
input_strings=input_batch,
@@ -147,4 +138,4 @@ for input_string, out_string in zip(input_batch, out_data.text):
print(10*'#')
```
-You should get an implementation of bubble sort and a description of the solar system.
+You should get an implementation of bubble sort.
diff --git a/pyproject.toml b/pyproject.toml
index 8381caaabc..baab1da052 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -76,11 +76,9 @@ docs = [
"sphinx-design",
"jupytext==1.13.8",
"dm-haiku",
-
# Need to pin docutils to 0.16 to make bulleted lists appear correctly on
# ReadTheDocs: https://stackoverflow.com/a/68008428
"docutils==0.16",
-
# The next packages are for notebooks.
"matplotlib",
"scikit-learn",
@@ -88,6 +86,8 @@ docs = [
"ml_collections",
# notebooks
"einops",
+ "kagglehub>=0.3.3",
+ "ipywidgets>=8.1.5",
]
dev = [
"pre-commit>=3.8.0",
diff --git a/uv.lock b/uv.lock
index bb7c89e2db..6c9e67edfa 100644
--- a/uv.lock
+++ b/uv.lock
@@ -504,7 +504,7 @@ wheels = [
[package.optional-dependencies]
toml = [
- { name = "tomli", marker = "python_full_version <= '3.11'" },
+ { name = "tomli", marker = "python_full_version == '3.11'" },
]
[[package]]
@@ -794,7 +794,9 @@ docs = [
{ name = "einops" },
{ name = "ipykernel" },
{ name = "ipython-genutils" },
+ { name = "ipywidgets" },
{ name = "jupytext" },
+ { name = "kagglehub" },
{ name = "matplotlib" },
{ name = "ml-collections" },
{ name = "myst-nb" },
@@ -842,11 +844,13 @@ requires-dist = [
{ name = "gymnasium", extras = ["accept-rom-license", "atari"], marker = "extra == 'testing'" },
{ name = "ipykernel", marker = "extra == 'docs'" },
{ name = "ipython-genutils", marker = "extra == 'docs'" },
+ { name = "ipywidgets", marker = "extra == 'docs'", specifier = ">=8.1.5" },
{ name = "jax", specifier = ">=0.4.27" },
{ name = "jaxlib", marker = "extra == 'testing'" },
{ name = "jaxtyping", marker = "extra == 'testing'" },
{ name = "jraph", marker = "extra == 'testing'", specifier = ">=0.0.6.dev0" },
{ name = "jupytext", marker = "extra == 'docs'", specifier = "==1.13.8" },
+ { name = "kagglehub", marker = "extra == 'docs'", specifier = ">=0.3.3" },
{ name = "matplotlib", marker = "extra == 'all'" },
{ name = "matplotlib", marker = "extra == 'docs'" },
{ name = "ml-collections", marker = "extra == 'docs'" },
@@ -1217,9 +1221,25 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/fa/bc/9bd3b5c2b4774d5f33b2d544f1460be9df7df2fe42f352135381c347c69a/ipython_genutils-0.2.0-py2.py3-none-any.whl", hash = "sha256:72dd37233799e619666c9f639a9da83c34013a73e8bbc79a7a6348d93c61fab8", size = 26343 },
]
+[[package]]
+name = "ipywidgets"
+version = "8.1.5"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "comm" },
+ { name = "ipython" },
+ { name = "jupyterlab-widgets" },
+ { name = "traitlets" },
+ { name = "widgetsnbextension" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/c7/4c/dab2a281b07596a5fc220d49827fe6c794c66f1493d7a74f1df0640f2cc5/ipywidgets-8.1.5.tar.gz", hash = "sha256:870e43b1a35656a80c18c9503bbf2d16802db1cb487eec6fab27d683381dde17", size = 116723 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/22/2d/9c0b76f2f9cc0ebede1b9371b6f317243028ed60b90705863d493bae622e/ipywidgets-8.1.5-py3-none-any.whl", hash = "sha256:3290f526f87ae6e77655555baba4f36681c555b8bdbbff430b70e52c34c86245", size = 139767 },
+]
+
[[package]]
name = "jax"
-version = "0.4.34"
+version = "0.4.35"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "jaxlib" },
@@ -1228,14 +1248,14 @@ dependencies = [
{ name = "opt-einsum" },
{ name = "scipy" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/19/6a/cacfcdf77841a4562e555ef35e0dbc5f8ca79c9f1010aaa4cf3973e79c69/jax-0.4.34.tar.gz", hash = "sha256:44196854f40c5f9cea3142824b9f1051f85afc3fcf7593ec5479fc8db01c58db", size = 1848472 }
+sdist = { url = "https://files.pythonhosted.org/packages/e3/34/21da583b9596e72bb8e95b6197dee0a44b96b9ea2c147fccabd43ca5515b/jax-0.4.35.tar.gz", hash = "sha256:c0c986993026b10bf6f607fecb7417377460254640766ce40f1fef3fd139c12e", size = 1861189 }
wheels = [
- { url = "https://files.pythonhosted.org/packages/06/f3/c499d358dd7f267a63d7d38ef54aadad82e28d2c28bafff15360c3091946/jax-0.4.34-py3-none-any.whl", hash = "sha256:b957ca1fc91f7343f91a186af9f19c7f342c946f95a8c11c7f1e5cdfe2e58d9e", size = 2144294 },
+ { url = "https://files.pythonhosted.org/packages/62/20/6c57c50c0ccc645fea1895950f1e5cd02f961ee44b3ffe83617fa46b0c1d/jax-0.4.35-py3-none-any.whl", hash = "sha256:fa99e909a31424abfec750019a6dd36f6acc18a6e7d40e2c0086b932cc351325", size = 2158621 },
]
[[package]]
name = "jaxlib"
-version = "0.4.34"
+version = "0.4.35"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "ml-dtypes" },
@@ -1243,26 +1263,26 @@ dependencies = [
{ name = "scipy" },
]
wheels = [
- { url = "https://files.pythonhosted.org/packages/24/31/2e254fe2fc23201775a7d0ccd1bcde892cfa349eb805744b81b15e0dcf74/jaxlib-0.4.34-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:b7a212a3cb5c6acc201c32ae4f4b5f5a9ac09457fbb77ba8db5ce7e7d4adc214", size = 87399257 },
- { url = "https://files.pythonhosted.org/packages/1e/67/6a344c357caad33e84b871925cd043b4218fc13a427266d1a1dedcb1c095/jaxlib-0.4.34-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:45d719a2ce0ebf21255a277b71d756f3609b7b5be70cddc5d88fd58c35219de0", size = 67617952 },
- { url = "https://files.pythonhosted.org/packages/dd/ea/12c836126419ca80248228f2236831617eedb1e3640c34c942606f33bb08/jaxlib-0.4.34-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:3e60bc826933082e99b19b87c21818a8d26fcdb01f418d47cedff554746fd6cc", size = 69391770 },
- { url = "https://files.pythonhosted.org/packages/e4/b0/a5bd34643c070e50829beec217189eab1acdfea334df1f9ddb4e5f8bec0f/jaxlib-0.4.34-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:d840e64b85f8865404d6d225b9bb340e158df1457152a361b05680e24792b232", size = 86094116 },
- { url = "https://files.pythonhosted.org/packages/d8/c9/35a4233fe74ddd5aabe89aac1b3992b0e463982564252d21fd263d4d9992/jaxlib-0.4.34-cp310-cp310-win_amd64.whl", hash = "sha256:b0001c8f0e2b1c7bc99e4f314b524a340d25653505c1a1484d4041a9d3617f6f", size = 55206389 },
- { url = "https://files.pythonhosted.org/packages/bf/14/00a3385532d72ab51bd8e9f8c3e19a2e257667955565e9fc10236771dd06/jaxlib-0.4.34-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:8ee3f93836e53c86556ccd9449a4ea43516ee05184d031a71dd692e81259f7d9", size = 87420889 },
- { url = "https://files.pythonhosted.org/packages/66/78/d1535ee73fe505dc6c8831c19c4846afdce7df5acefb9f8ee885aa73d700/jaxlib-0.4.34-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c9d3adcae43a33aad4332be9c2aedc5ef751d1e755f917a5afb30c7872eacaa8", size = 67635880 },
- { url = "https://files.pythonhosted.org/packages/aa/06/3e09e794acf308e170905d732eca0d041449503c47505cc22e8ef78a989d/jaxlib-0.4.34-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:571ef03259835458111596a71a2f4a6fabf4ec34595df4cea555035362ac5bf0", size = 69421901 },
- { url = "https://files.pythonhosted.org/packages/c7/d0/6bc81c0b1d507f403e6085ce76a429e6d7f94749d742199252e299dd1424/jaxlib-0.4.34-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:3bcfa639ca3cfaf86c8ceebd5fc0d47300fd98a078014a1d0cc03133e1523d5f", size = 86114491 },
- { url = "https://files.pythonhosted.org/packages/9d/5d/7e71019af5f6fdebe6c10eab97d01f44b931d94609330da9e142cb155f8c/jaxlib-0.4.34-cp311-cp311-win_amd64.whl", hash = "sha256:133070d4fec5525ffea4dc72956398c1cf647a04dcb37f8a935ee82af78d9965", size = 55241262 },
- { url = "https://files.pythonhosted.org/packages/bc/42/5038983664494dfb50f8669a662d965d7ea62f9250e40d8cd36dcf9ac3dd/jaxlib-0.4.34-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:c7b3e724a30426a856070aba0192b5d199e95b4411070e7ad96ad8b196877b10", size = 87473956 },
- { url = "https://files.pythonhosted.org/packages/87/2e/8a75d3107c019c370c50c01acc205da33f9d6fba830950401a772a8e9f6d/jaxlib-0.4.34-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:096f0ca309d41fa692a9d1f2f9baab1c5c8ca0749876ebb3f748e738a27c7ff4", size = 67650276 },
- { url = "https://files.pythonhosted.org/packages/af/09/cceae2d251a506b4297679d10ee9f5e905a6b992b0687d553c9470ffd1db/jaxlib-0.4.34-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:1a30771d85fa77f9ab8f18e63240f455ab3a3f87660ed7b8d5eea6ceecbe5c1e", size = 69431284 },
- { url = "https://files.pythonhosted.org/packages/e7/0d/4faf839e3c8ce2a5b615df64427be3e870899c72c0ebfb5859348150aba1/jaxlib-0.4.34-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:48272e9034ff868d4328cf0055a07882fd2be93f59dfb6283af7de491f9d1290", size = 86151183 },
- { url = "https://files.pythonhosted.org/packages/a4/bc/a38f99071fca6cc31ae949e508a23b0de5de559da594443bb625a1adb8f3/jaxlib-0.4.34-cp312-cp312-win_amd64.whl", hash = "sha256:901cb4040ed24eae40071d8114ea8d10dff436277fa74a1a5b9e7206f641151c", size = 55278745 },
- { url = "https://files.pythonhosted.org/packages/21/4e/fab0606683af7aa9284a32d2b188ff132cffb0ee3ea04d941a547eb776d1/jaxlib-0.4.34-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:72e22e99a5dc890a64443c3fc12f13f20091f578c405a76de077ba42b4c62cd7", size = 87474367 },
- { url = "https://files.pythonhosted.org/packages/3e/1b/709be16d543a3db5b471ee5e7d089c57484c386b08499923e43bd8da5d0b/jaxlib-0.4.34-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c303f5acaf6c56ce5ff133a923c9b6247bdebedde15bd2c893c24be4d8f71306", size = 67651281 },
- { url = "https://files.pythonhosted.org/packages/85/9e/f3801096cd4a2c764af7a1f6b683c769706602ea72b27ec35bacfcc4cd4f/jaxlib-0.4.34-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:7be673a876ebd1aef440fb7e3ebaf99a91abeb550c9728c644b7d7c7b5d7c108", size = 69432987 },
- { url = "https://files.pythonhosted.org/packages/e6/79/61301f55b24c3a898ef9bc4e13600b66e3f838623fc6f87648ac1ccbca01/jaxlib-0.4.34-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:87f25a477cd279840e53718403f97092eba0e8a945fcab47bcf435b6f9119dda", size = 86152550 },
- { url = "https://files.pythonhosted.org/packages/16/b0/e682d02126e0062b58dec0f0851048592396f74c24b4a4412dce4ddbbadb/jaxlib-0.4.34-cp313-cp313-win_amd64.whl", hash = "sha256:6b43a974c5d91a19912d138f2658dd8dbb7d30dcdff5c961d896c673e872b611", size = 55279410 },
+ { url = "https://files.pythonhosted.org/packages/f4/67/c025520d2c548569f73cd68b885862e56e8946a10c9d43834460007c2671/jaxlib-0.4.35-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:907e548ad6ce53b242a55c5f36c2a2a4c37d38f6cd8c356fc550a2f18ab0e82f", size = 87876323 },
+ { url = "https://files.pythonhosted.org/packages/a8/e7/7962830da208ad3fa6596dc2df77824da9bc0196b549ae549ce53d1d1de1/jaxlib-0.4.35-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f8c499644660aefd0ae2ee31039da6d4df0f26d0ee67ba9fb316183a5304288", size = 68025360 },
+ { url = "https://files.pythonhosted.org/packages/fa/91/2a1a1551845dd634bb1647fd37157f6f4ea71481e63f4100d08923c29d22/jaxlib-0.4.35-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:5d2d8a5b89d334b875ede98d7fcee946bebef1a1b5abd118ff543bcef4ab09f5", size = 70588250 },
+ { url = "https://files.pythonhosted.org/packages/d7/16/6a9053d8b4b2790e330f9143030ab9d456556da5d98887b7e071bd08ffed/jaxlib-0.4.35-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:91a283a72263feebe0d110d1136df96950744e47530f12df42c03f36888c971e", size = 87282292 },
+ { url = "https://files.pythonhosted.org/packages/6c/a9/b6bdff31e21a485190985dccbdd5ae1130fe2e4af826c83c10ae1d0d14a9/jaxlib-0.4.35-cp310-cp310-win_amd64.whl", hash = "sha256:d210bab7e1ce0b2f2e568548b3903ea6aec349019fc1398cd2a0c069e8342e62", size = 56484115 },
+ { url = "https://files.pythonhosted.org/packages/ee/01/4be899cf8d05920877b46b8acf51083dedaba206e951d88ddf7b098bed80/jaxlib-0.4.35-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:7f8bfc90f68857b223b7e38a9bdf466a4f1cb405c9a4aa11698dc9ab7b35c29b", size = 87895891 },
+ { url = "https://files.pythonhosted.org/packages/55/77/ca1e70bc3a161c1043d2e169a618263f865bf959433e5bf40ea56ec13e5e/jaxlib-0.4.35-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:261570c94b169dc90f3af903282eeec856b52736c0944d243504ced93d19b217", size = 68045181 },
+ { url = "https://files.pythonhosted.org/packages/cd/2f/a8f4c441718558406cf27749415d1aa14bdac9dbd06fadb7bb4742c53637/jaxlib-0.4.35-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:e1cee6dc291251f3fb6b0127fdd96c0439ac1ea97e01571d06910df72d6ac6e1", size = 70614621 },
+ { url = "https://files.pythonhosted.org/packages/c8/a6/1abe8d682d46cf2989f9c4928866ae80c30a54d607221a262cff8a5d9366/jaxlib-0.4.35-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:bc9eafba001ff8569cfa252fe7f04ba553622702b4b473b656dd0866edf6b8d4", size = 87309681 },
+ { url = "https://files.pythonhosted.org/packages/7d/7c/73a4c4a34f2bbfce63e8baefee11753b0d58a71e0d2c33f210e00edba3cb/jaxlib-0.4.35-cp311-cp311-win_amd64.whl", hash = "sha256:0fd990354d5623d3a34493fcd7213493390dbf5039bea19b62e2aaee1049eda9", size = 56520062 },
+ { url = "https://files.pythonhosted.org/packages/ef/1c/901a59d9bc051b2a991163c46f58c50724d18ab25e71fa5556e5f68b84a4/jaxlib-0.4.35-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:b44f3e6e9fb748bb43df914356cf9d0d0c9a6e446a12c21fe843db25ed0df65f", size = 87936215 },
+ { url = "https://files.pythonhosted.org/packages/da/ff/38030bc3c96fae50f629830afe9c63a8a040aae332f6e28cd529397ba114/jaxlib-0.4.35-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:504d0a2e2117724359d99d7e3663022686dcdddd85aa14bdad02008d444481ad", size = 68063993 },
+ { url = "https://files.pythonhosted.org/packages/55/27/83b6d2a1b380e20610e1449231c30c948cc4352c9a7e74a0d0d01bff8339/jaxlib-0.4.35-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:187cb6929dc139b75d952d67c33118473c1b4105525a3e5607f064e7b8efdc74", size = 70629159 },
+ { url = "https://files.pythonhosted.org/packages/6d/3f/5ac6dfef795f4f58645ccff0ebd65234cb77d7dbf1bdd2b6c49a677b64b0/jaxlib-0.4.35-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:04d1db3bf0050d120238bfb9b686b58fefcc4d9dd9e2d96aecd3f68a1f1f5e0a", size = 87349348 },
+ { url = "https://files.pythonhosted.org/packages/97/05/093b3c511837ba514f0b97581f7b21e1bb79768b8b9c29013a406b00d484/jaxlib-0.4.35-cp312-cp312-win_amd64.whl", hash = "sha256:dddffce48d7e6057008999aed2d8a9daecc57a48c45a4f8c475e00880eb2e41d", size = 56561679 },
+ { url = "https://files.pythonhosted.org/packages/99/40/aedef37c44797779a01bf71a392145724e3e0fc369e5f08f55c3c82de733/jaxlib-0.4.35-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:14aeac3fea2ca1d5afb1878f72470b159cc89adb2633c5f0686f5d7c39f2ac18", size = 87934299 },
+ { url = "https://files.pythonhosted.org/packages/94/42/62d4d13078886f4d22ca95ca07135f740cf9dd925f4cdb23d7b7d432403b/jaxlib-0.4.35-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e8c9579e20d5ecdc4f61336cdd032710cb8c38d5ae9c4fce0cf9ea031cef21cb", size = 68065641 },
+ { url = "https://files.pythonhosted.org/packages/4d/a0/87a4eae3811ce7014ce2c59b811ad930273bfbbb8252ba78079606f9ec40/jaxlib-0.4.35-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:7b11ad7c13f7f96f36efd303711ecac425f19ca2ddf65cf1be1541167a959ee5", size = 70629568 },
+ { url = "https://files.pythonhosted.org/packages/b3/89/59d6fe10e30ff5a48a73319bafa9a11cd999f91a47e4f08f7dc3651c899c/jaxlib-0.4.35-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:0be3cf9df879d9ae1b5b92fc281f77d21f522fcbae1a48a02661026bbd9b9309", size = 87350315 },
+ { url = "https://files.pythonhosted.org/packages/79/d7/d7600c65fe0412a6584d84ca172816a8cf19965219ee3dd59542447ffe2f/jaxlib-0.4.35-cp313-cp313-win_amd64.whl", hash = "sha256:330c090bb9af413f552d8a92d097e50baec6b75823430fb2966a49f5298d4c43", size = 56562022 },
]
[[package]]
@@ -1412,6 +1432,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c9/fb/108ecd1fe961941959ad0ee4e12ee7b8b1477247f30b1fdfd83ceaf017f0/jupyter_core-5.7.2-py3-none-any.whl", hash = "sha256:4f7315d2f6b4bcf2e3e7cb6e46772eba760ae459cd1f59d29eb57b0a01bd7409", size = 28965 },
]
+[[package]]
+name = "jupyterlab-widgets"
+version = "3.0.13"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/59/73/fa26bbb747a9ea4fca6b01453aa22990d52ab62dd61384f1ac0dc9d4e7ba/jupyterlab_widgets-3.0.13.tar.gz", hash = "sha256:a2966d385328c1942b683a8cd96b89b8dd82c8b8f81dda902bb2bc06d46f5bed", size = 203556 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/a9/93/858e87edc634d628e5d752ba944c2833133a28fa87bb093e6832ced36a3e/jupyterlab_widgets-3.0.13-py3-none-any.whl", hash = "sha256:e3cda2c233ce144192f1e29914ad522b2f4c40e77214b0cc97377ca3d323db54", size = 214392 },
+]
+
[[package]]
name = "jupytext"
version = "1.13.8"
@@ -1428,6 +1457,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/f9/e3/538509410372acd6d41f12c028dfc75ebddfbc4f7544f933bff7b5cc3e97/jupytext-1.13.8-py3-none-any.whl", hash = "sha256:625d2d2012763cc87d3f0dd60383516cec442c11894f53ad0c5ee5aa2a52caa2", size = 297592 },
]
+[[package]]
+name = "kagglehub"
+version = "0.3.3"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "packaging" },
+ { name = "requests" },
+ { name = "tqdm" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/b0/69/3e3d9533b44535903011157102bcf08ad4124f12b5d2c294850e6fad5032/kagglehub-0.3.3.tar.gz", hash = "sha256:0777d4d1ee1e59d4125b14ba62a46b2eadedb68bc6517479f6fb02a522a262f8", size = 60620 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/c4/d1/4ab25019a168f5c414202f124d156e11ac79f07845d67288929311f1b1b2/kagglehub-0.3.3-py3-none-any.whl", hash = "sha256:5370acde855d04b6d8a7bc242edff339266913fffc8b198d31859b25b7d095f7", size = 42852 },
+]
+
[[package]]
name = "keras"
version = "3.5.0"
@@ -2014,6 +2057,7 @@ version = "12.1.3.1"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/37/6d/121efd7382d5b0284239f4ab1fc1590d86d34ed4a4a2fdb13b30ca8e5740/nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728", size = 410594774 },
+ { url = "https://files.pythonhosted.org/packages/c5/ef/32a375b74bea706c93deea5613552f7c9104f961b21df423f5887eca713b/nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906", size = 439918445 },
]
[[package]]
@@ -2022,6 +2066,7 @@ version = "12.1.105"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/7e/00/6b218edd739ecfc60524e585ba8e6b00554dd908de2c9c66c1af3e44e18d/nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e", size = 14109015 },
+ { url = "https://files.pythonhosted.org/packages/d0/56/0021e32ea2848c24242f6b56790bd0ccc8bf99f973ca790569c6ca028107/nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4", size = 10154340 },
]
[[package]]
@@ -2030,6 +2075,7 @@ version = "12.1.105"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b6/9f/c64c03f49d6fbc56196664d05dba14e3a561038a81a638eeb47f4d4cfd48/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2", size = 23671734 },
+ { url = "https://files.pythonhosted.org/packages/ad/1d/f76987c4f454eb86e0b9a0e4f57c3bf1ac1d13ad13cd1a4da4eb0e0c0ce9/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed", size = 19331863 },
]
[[package]]
@@ -2038,6 +2084,7 @@ version = "12.1.105"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/eb/d5/c68b1d2cdfcc59e72e8a5949a37ddb22ae6cade80cd4a57a84d4c8b55472/nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40", size = 823596 },
+ { url = "https://files.pythonhosted.org/packages/9f/e2/7a2b4b5064af56ea8ea2d8b2776c0f2960d95c88716138806121ae52a9c9/nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344", size = 821226 },
]
[[package]]
@@ -2049,6 +2096,7 @@ dependencies = [
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 },
+ { url = "https://files.pythonhosted.org/packages/3f/d0/f90ee6956a628f9f04bf467932c0a25e5a7e706a684b896593c06c82f460/nvidia_cudnn_cu12-9.1.0.70-py3-none-win_amd64.whl", hash = "sha256:6278562929433d68365a07a4a1546c237ba2849852c0d4b2262a486e805b977a", size = 679925892 },
]
[[package]]
@@ -2057,6 +2105,7 @@ version = "11.0.2.54"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/86/94/eb540db023ce1d162e7bea9f8f5aa781d57c65aed513c33ee9a5123ead4d/nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56", size = 121635161 },
+ { url = "https://files.pythonhosted.org/packages/f7/57/7927a3aa0e19927dfed30256d1c854caf991655d847a4e7c01fe87e3d4ac/nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253", size = 121344196 },
]
[[package]]
@@ -2065,6 +2114,7 @@ version = "10.3.2.106"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/44/31/4890b1c9abc496303412947fc7dcea3d14861720642b49e8ceed89636705/nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0", size = 56467784 },
+ { url = "https://files.pythonhosted.org/packages/5c/97/4c9c7c79efcdf5b70374241d48cf03b94ef6707fd18ea0c0f53684931d0b/nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a", size = 55995813 },
]
[[package]]
@@ -2078,6 +2128,7 @@ dependencies = [
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 },
+ { url = "https://files.pythonhosted.org/packages/b8/80/8fca0bf819122a631c3976b6fc517c1b10741b643b94046bd8dd451522c5/nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5", size = 121643081 },
]
[[package]]
@@ -2089,6 +2140,7 @@ dependencies = [
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 },
+ { url = "https://files.pythonhosted.org/packages/0f/95/48fdbba24c93614d1ecd35bc6bdc6087bd17cbacc3abc4b05a9c2a1ca232/nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a", size = 195414588 },
]
[[package]]
@@ -2107,6 +2159,7 @@ source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/81/b3/e456a1b2d499bb84bdc6670bfbcf41ff3bac58bd2fae6880d62834641558/nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_aarch64.whl", hash = "sha256:84fb38465a5bc7c70cbc320cfd0963eb302ee25a5e939e9f512bbba55b6072fb", size = 19252608 },
{ url = "https://files.pythonhosted.org/packages/59/65/7ff0569494fbaea45ad2814972cc88da843d53cc96eb8554fcd0908941d9/nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_x86_64.whl", hash = "sha256:562ab97ea2c23164823b2a89cb328d01d45cb99634b8c65fe7cd60d14562bd79", size = 19724950 },
+ { url = "https://files.pythonhosted.org/packages/cb/ef/8f96c82e1cfcf6d5b770f7b043c3cc24841fc247b37629a7cc643dbf72a1/nvidia_nvjitlink_cu12-12.6.20-py3-none-win_amd64.whl", hash = "sha256:ed3c43a17f37b0c922a919203d2d36cbef24d41cc3e6b625182f8b58203644f6", size = 162012830 },
]
[[package]]
@@ -2115,6 +2168,7 @@ version = "12.1.105"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/da/d3/8057f0587683ed2fcd4dbfbdfdfa807b9160b809976099d36b8f60d08f03/nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5", size = 99138 },
+ { url = "https://files.pythonhosted.org/packages/b8/d7/bd7cb2d95ac6ac6e8d05bfa96cdce69619f1ef2808e072919044c2d47a8c/nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82", size = 66307 },
]
[[package]]
@@ -3403,7 +3457,9 @@ dependencies = [
{ name = "tensorflow", marker = "platform_system != 'Darwin'" },
]
wheels = [
+ { url = "https://files.pythonhosted.org/packages/a2/e3/33fc5957790cf4710e0a9116cf37c0a881eda673e5f8b569bfff5654a48c/tensorflow_text-2.17.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8eba0b5804235519b571c827c97337c332de270107f06af6d2171cdefdc4c6a0", size = 6109587 },
{ url = "https://files.pythonhosted.org/packages/61/59/2090318555d98dc9dc868b3c585ada2e1139be538d954340726aa3d3899a/tensorflow_text-2.17.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89f04c3f478f1885ad4c7380643a768a72a3de79e1f8f40d50b48cc1fbf73893", size = 5205819 },
+ { url = "https://files.pythonhosted.org/packages/92/65/e2d3d9300173a0927e8b7e3cf9a35f9539e9269786c1e1d9d945223fe21a/tensorflow_text-2.17.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a9b9f9c8a06878714a14f4e086fa8122beb2e141f82d0aa5a8f6b8f9b694db51", size = 6109684 },
{ url = "https://files.pythonhosted.org/packages/de/32/182ecf4eb1432942876d9b0b089625564084c5ed4d03c02ddf2872177e95/tensorflow_text-2.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:161c09380b090774ed721cdcce973194458708250d7dfbac7cb9ea8a3e9ac762", size = 5205866 },
]
@@ -3653,6 +3709,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/1b/d1/9babe2ccaecff775992753d8686970b1e2755d21c8a63be73aba7a4e7d77/wheel-0.44.0-py3-none-any.whl", hash = "sha256:2376a90c98cc337d18623527a97c31797bd02bad0033d41547043a1cbfbe448f", size = 67059 },
]
+[[package]]
+name = "widgetsnbextension"
+version = "4.0.13"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/56/fc/238c424fd7f4ebb25f8b1da9a934a3ad7c848286732ae04263661eb0fc03/widgetsnbextension-4.0.13.tar.gz", hash = "sha256:ffcb67bc9febd10234a362795f643927f4e0c05d9342c727b65d2384f8feacb6", size = 1164730 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/21/02/88b65cc394961a60c43c70517066b6b679738caf78506a5da7b88ffcb643/widgetsnbextension-4.0.13-py3-none-any.whl", hash = "sha256:74b2692e8500525cc38c2b877236ba51d34541e6385eeed5aec15a70f88a6c71", size = 2335872 },
+]
+
[[package]]
name = "wrapt"
version = "1.16.0"