diff --git a/src/dl_workshop/answers.py b/src/dl_workshop/answers.py index 378f518..20edd5f 100644 --- a/src/dl_workshop/answers.py +++ b/src/dl_workshop/answers.py @@ -81,7 +81,7 @@ def logistic_loss(params, model, x, y): def f(w): - return w ** 2 + 3 * w - 5 + return w**2 + 3 * w - 5 def df(w): diff --git a/src/dl_workshop/jax_idioms.py b/src/dl_workshop/jax_idioms.py index e3b141e..4910bf1 100644 --- a/src/dl_workshop/jax_idioms.py +++ b/src/dl_workshop/jax_idioms.py @@ -93,7 +93,7 @@ def randomness_ex_3(key, num_realizations: int, grw_draw: Callable): def goldfield(x, y): """All credit to https://www.analyzemath.com/calculus/multivariable/maxima_minima.html for this function.""" - return (2 * x ** 2) - (4 * x * y) + (y ** 4 + 2) + return (2 * x**2) - (4 * x * y) + (y**4 + 2) def grad_ex_1():