diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 5c2d0a2..8948154 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,15 +1,15 @@
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.0.1
+    rev: v4.5.0
     hooks:
       - id: check-yaml
       - id: end-of-file-fixer
       - id: trailing-whitespace
   - repo: https://github.com/psf/black
-    rev: 21.11b1
+    rev: 24.2.0
     hooks:
       - id: black
   - repo: https://github.com/kynan/nbstripout
-    rev: 0.5.0
+    rev: 0.7.1
     hooks:
       - id: nbstripout
diff --git a/notebooks/02-jax-idioms/04-optimized-learning.ipynb b/notebooks/02-jax-idioms/04-optimized-learning.ipynb
index 48e19e3..79a4eed 100644
--- a/notebooks/02-jax-idioms/04-optimized-learning.ipynb
+++ b/notebooks/02-jax-idioms/04-optimized-learning.ipynb
@@ -3,7 +3,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "frequent-field",
+   "id": "0",
    "metadata": {
     "tags": []
    },
@@ -18,7 +18,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "opened-virgin",
+   "id": "1",
    "metadata": {
     "tags": []
    },
@@ -31,7 +31,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "lasting-express",
+   "id": "2",
    "metadata": {},
    "source": [
     "# Optimized Learning\n",
@@ -41,7 +41,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "forward-process",
+   "id": "3",
    "metadata": {},
    "source": [
     "## Autograd to JAX\n",
@@ -52,7 +52,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "correct-cyprus",
+   "id": "4",
    "metadata": {},
    "source": [
     "## Example: Transforming a function into its derivative\n",
@@ -68,7 +68,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "demanding-opportunity",
+   "id": "5",
    "metadata": {
     "tags": []
    },
@@ -90,7 +90,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "forty-lindsay",
+   "id": "6",
    "metadata": {},
    "source": [
     "Here's another example using a polynomial function:\n",
@@ -105,7 +105,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "neutral-neighbor",
+   "id": "7",
    "metadata": {
     "tags": []
    },
@@ -128,7 +128,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "steady-bikini",
+   "id": "8",
    "metadata": {},
    "source": [
     "## Using grad to solve minimization problems\n",
@@ -147,7 +147,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "opponent-modification",
+   "id": "9",
    "metadata": {
     "tags": []
    },
@@ -163,7 +163,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "beautiful-theory",
+   "id": "10",
    "metadata": {},
    "source": [
     "We know from calculus that the sign of the second derivative tells us whether we have a minima or maxima at a point.\n",
@@ -178,7 +178,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "former-syracuse",
+   "id": "11",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -189,7 +189,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "surrounded-plain",
+   "id": "12",
    "metadata": {},
    "source": [
     "Grad is composable an arbitrary number of times. You can keep calling grad as many times as you like."
@@ -197,7 +197,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "brazilian-atlas",
+   "id": "13",
    "metadata": {},
    "source": [
     "## Maximum likelihood estimation\n",
@@ -216,7 +216,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "confidential-sympathy",
+   "id": "14",
    "metadata": {
     "tags": []
    },
@@ -236,7 +236,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "atlantic-excellence",
+   "id": "15",
    "metadata": {},
    "source": [
     "Our estimation task will necessitate calculating the total joint log likelihood of our data under a Gaussian model.\n",
@@ -248,7 +248,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "known-terrain",
+   "id": "16",
    "metadata": {
     "tags": []
    },
@@ -263,7 +263,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "terminal-census",
+   "id": "17",
    "metadata": {},
    "source": [
     "If you're wondering why we use `log_sigma` rather than `sigma`, it is a choice made for practical reasons.\n",
@@ -280,7 +280,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "dominant-delight",
+   "id": "18",
    "metadata": {
     "tags": []
    },
@@ -293,7 +293,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "equal-brazilian",
+   "id": "19",
    "metadata": {},
    "source": [
     "Now, we can create the gradient function of our negative log likelihood.\n",
@@ -307,7 +307,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "meaning-scanning",
+   "id": "20",
    "metadata": {
     "tags": []
    },
@@ -322,7 +322,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "hourly-miller",
+   "id": "21",
    "metadata": {},
    "source": [
     "Now, we can do the gradient descent step!"
@@ -331,7 +331,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "cosmetic-perception",
+   "id": "22",
    "metadata": {
     "tags": []
    },
@@ -347,7 +347,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "defensive-family",
+   "id": "23",
    "metadata": {},
    "source": [
     "And voila! We have gradient descended our way to the maximum likelihood parameters :)."
@@ -355,7 +355,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "constant-account",
+   "id": "24",
    "metadata": {},
    "source": [
     "## Exercise: Where is the gold? It's at the minima!\n",
@@ -368,7 +368,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "focal-climate",
+   "id": "25",
    "metadata": {
     "tags": []
    },
@@ -383,7 +383,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "massive-corps",
+   "id": "26",
    "metadata": {},
    "source": [
     "It should be evident from here that there are two minima in the function.\n",
@@ -398,7 +398,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "opened-beads",
+   "id": "27",
    "metadata": {
     "tags": []
    },
@@ -420,7 +420,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "brown-violation",
+   "id": "28",
    "metadata": {},
    "source": [
     "Now, implement the optimization loop!"
@@ -429,7 +429,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "alternative-wisdom",
+   "id": "29",
    "metadata": {
     "tags": []
    },
@@ -450,7 +450,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "alternative-iraqi",
+   "id": "30",
    "metadata": {},
    "source": [
     "## Exercise: programming a robot that only moves along one axis\n",
@@ -464,7 +464,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "operational-advantage",
+   "id": "31",
    "metadata": {
     "tags": []
    },
@@ -490,7 +490,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "ecological-asian",
+   "id": "32",
    "metadata": {},
    "source": [
     "For your reference we have the function plotted below."
@@ -499,7 +499,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "convenient-optics",
+   "id": "33",
    "metadata": {
     "tags": []
    },
@@ -531,7 +531,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "loaded-labor",
+   "id": "34",
    "metadata": {},
    "outputs": [],
    "source": []
diff --git a/src/dl_workshop/answers.py b/src/dl_workshop/answers.py
index 378f518..a35e8a5 100644
--- a/src/dl_workshop/answers.py
+++ b/src/dl_workshop/answers.py
@@ -1,6 +1,7 @@
 """
 Answers to the main tutorial notebooks.
 """
+
 import jax.numpy as np
 import numpy.random as npr
 from jax import grad
@@ -81,7 +82,7 @@ def logistic_loss(params, model, x, y):
 
 
 def f(w):
-    return w ** 2 + 3 * w - 5
+    return w**2 + 3 * w - 5
 
 
 def df(w):
diff --git a/src/dl_workshop/jax_idioms.py b/src/dl_workshop/jax_idioms.py
index e3b141e..4910bf1 100644
--- a/src/dl_workshop/jax_idioms.py
+++ b/src/dl_workshop/jax_idioms.py
@@ -93,7 +93,7 @@ def randomness_ex_3(key, num_realizations: int, grw_draw: Callable):
 
 
 def goldfield(x, y):
     """All credit to https://www.analyzemath.com/calculus/multivariable/maxima_minima.html for this function."""
-    return (2 * x ** 2) - (4 * x * y) + (y ** 4 + 2)
+    return (2 * x**2) - (4 * x * y) + (y**4 + 2)
 
 
 def grad_ex_1():
diff --git a/src/setup.py b/src/setup.py
index 16cd073..bac1b43 100644
--- a/src/setup.py
+++ b/src/setup.py
@@ -1,4 +1,5 @@
 """Setup script."""
+
 from setuptools import find_packages, setup
 
 setup(name="dl_workshop", version="0.1", packages=find_packages())