Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pre-commit.ci] pre-commit autoupdate #21

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1
rev: v4.5.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 21.11b1
rev: 24.2.0
hooks:
- id: black
- repo: https://github.com/kynan/nbstripout
rev: 0.5.0
rev: 0.7.1
hooks:
- id: nbstripout
70 changes: 35 additions & 35 deletions notebooks/02-jax-idioms/04-optimized-learning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "frequent-field",
"id": "0",
"metadata": {
"tags": []
},
Expand All @@ -18,7 +18,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "opened-virgin",
"id": "1",
"metadata": {
"tags": []
},
Expand All @@ -31,7 +31,7 @@
},
{
"cell_type": "markdown",
"id": "lasting-express",
"id": "2",
"metadata": {},
"source": [
"# Optimized Learning\n",
Expand All @@ -41,7 +41,7 @@
},
{
"cell_type": "markdown",
"id": "forward-process",
"id": "3",
"metadata": {},
"source": [
"## Autograd to JAX\n",
Expand All @@ -52,7 +52,7 @@
},
{
"cell_type": "markdown",
"id": "correct-cyprus",
"id": "4",
"metadata": {},
"source": [
"## Example: Transforming a function into its derivative\n",
Expand All @@ -68,7 +68,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "demanding-opportunity",
"id": "5",
"metadata": {
"tags": []
},
Expand All @@ -90,7 +90,7 @@
},
{
"cell_type": "markdown",
"id": "forty-lindsay",
"id": "6",
"metadata": {},
"source": [
"Here's another example using a polynomial function:\n",
Expand All @@ -105,7 +105,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "neutral-neighbor",
"id": "7",
"metadata": {
"tags": []
},
Expand All @@ -128,7 +128,7 @@
},
{
"cell_type": "markdown",
"id": "steady-bikini",
"id": "8",
"metadata": {},
"source": [
"## Using grad to solve minimization problems\n",
Expand All @@ -147,7 +147,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "opponent-modification",
"id": "9",
"metadata": {
"tags": []
},
Expand All @@ -163,7 +163,7 @@
},
{
"cell_type": "markdown",
"id": "beautiful-theory",
"id": "10",
"metadata": {},
"source": [
"We know from calculus that the sign of the second derivative tells us whether we have a minima or maxima at a point.\n",
Expand All @@ -178,7 +178,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "former-syracuse",
"id": "11",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -189,15 +189,15 @@
},
{
"cell_type": "markdown",
"id": "surrounded-plain",
"id": "12",
"metadata": {},
"source": [
"Grad is composable an arbitrary number of times. You can keep calling grad as many times as you like."
]
},
{
"cell_type": "markdown",
"id": "brazilian-atlas",
"id": "13",
"metadata": {},
"source": [
"## Maximum likelihood estimation\n",
Expand All @@ -216,7 +216,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "confidential-sympathy",
"id": "14",
"metadata": {
"tags": []
},
Expand All @@ -236,7 +236,7 @@
},
{
"cell_type": "markdown",
"id": "atlantic-excellence",
"id": "15",
"metadata": {},
"source": [
"Our estimation task will necessitate calculating the total joint log likelihood of our data under a Gaussian model.\n",
Expand All @@ -248,7 +248,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "known-terrain",
"id": "16",
"metadata": {
"tags": []
},
Expand All @@ -263,7 +263,7 @@
},
{
"cell_type": "markdown",
"id": "terminal-census",
"id": "17",
"metadata": {},
"source": [
"If you're wondering why we use `log_sigma` rather than `sigma`, it is a choice made for practical reasons.\n",
Expand All @@ -280,7 +280,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "dominant-delight",
"id": "18",
"metadata": {
"tags": []
},
Expand All @@ -293,7 +293,7 @@
},
{
"cell_type": "markdown",
"id": "equal-brazilian",
"id": "19",
"metadata": {},
"source": [
"Now, we can create the gradient function of our negative log likelihood.\n",
Expand All @@ -307,7 +307,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "meaning-scanning",
"id": "20",
"metadata": {
"tags": []
},
Expand All @@ -322,7 +322,7 @@
},
{
"cell_type": "markdown",
"id": "hourly-miller",
"id": "21",
"metadata": {},
"source": [
"Now, we can do the gradient descent step!"
Expand All @@ -331,7 +331,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "cosmetic-perception",
"id": "22",
"metadata": {
"tags": []
},
Expand All @@ -347,15 +347,15 @@
},
{
"cell_type": "markdown",
"id": "defensive-family",
"id": "23",
"metadata": {},
"source": [
"And voila! We have gradient descended our way to the maximum likelihood parameters :)."
]
},
{
"cell_type": "markdown",
"id": "constant-account",
"id": "24",
"metadata": {},
"source": [
"## Exercise: Where is the gold? It's at the minima!\n",
Expand All @@ -368,7 +368,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "focal-climate",
"id": "25",
"metadata": {
"tags": []
},
Expand All @@ -383,7 +383,7 @@
},
{
"cell_type": "markdown",
"id": "massive-corps",
"id": "26",
"metadata": {},
"source": [
"It should be evident from here that there are two minima in the function.\n",
Expand All @@ -398,7 +398,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "opened-beads",
"id": "27",
"metadata": {
"tags": []
},
Expand All @@ -420,7 +420,7 @@
},
{
"cell_type": "markdown",
"id": "brown-violation",
"id": "28",
"metadata": {},
"source": [
"Now, implement the optimization loop!"
Expand All @@ -429,7 +429,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "alternative-wisdom",
"id": "29",
"metadata": {
"tags": []
},
Expand All @@ -450,7 +450,7 @@
},
{
"cell_type": "markdown",
"id": "alternative-iraqi",
"id": "30",
"metadata": {},
"source": [
"## Exercise: programming a robot that only moves along one axis\n",
Expand All @@ -464,7 +464,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "operational-advantage",
"id": "31",
"metadata": {
"tags": []
},
Expand All @@ -490,7 +490,7 @@
},
{
"cell_type": "markdown",
"id": "ecological-asian",
"id": "32",
"metadata": {},
"source": [
"For your reference we have the function plotted below."
Expand All @@ -499,7 +499,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "convenient-optics",
"id": "33",
"metadata": {
"tags": []
},
Expand Down Expand Up @@ -531,7 +531,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "loaded-labor",
"id": "34",
"metadata": {},
"outputs": [],
"source": []
Expand Down
3 changes: 2 additions & 1 deletion src/dl_workshop/answers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Answers to the main tutorial notebooks.
"""

import jax.numpy as np
import numpy.random as npr
from jax import grad
Expand Down Expand Up @@ -81,7 +82,7 @@ def logistic_loss(params, model, x, y):


def f(w):
return w ** 2 + 3 * w - 5
return w**2 + 3 * w - 5


def df(w):
Expand Down
2 changes: 1 addition & 1 deletion src/dl_workshop/jax_idioms.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def randomness_ex_3(key, num_realizations: int, grw_draw: Callable):

def goldfield(x, y):
"""All credit to https://www.analyzemath.com/calculus/multivariable/maxima_minima.html for this function."""
return (2 * x ** 2) - (4 * x * y) + (y ** 4 + 2)
return (2 * x**2) - (4 * x * y) + (y**4 + 2)


def grad_ex_1():
Expand Down
1 change: 1 addition & 0 deletions src/setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Setup script."""

from setuptools import find_packages, setup

setup(name="dl_workshop", version="0.1", packages=find_packages())