Uses jnp.square instead of power. (jax-ml#3036)
* Uses multiplication instead of power.

* Uses jnp.square instead of mul and adds a check that jnp.square is implemented via multiplication.
Yusuke Oda authored May 12, 2020
1 parent 28bc4b7 commit ccb8d45
Showing 7 changed files with 24 additions and 12 deletions.
jax/experimental/ode.py (2 changes: 1 addition & 1 deletion)

@@ -120,7 +120,7 @@ def body_fun(i, k):
 def error_ratio(error_estimate, rtol, atol, y0, y1):
   err_tol = atol + rtol * jnp.maximum(jnp.abs(y0), jnp.abs(y1))
   err_ratio = error_estimate / err_tol
-  return jnp.mean(err_ratio ** 2)
+  return jnp.mean(jnp.square(err_ratio))

 def optimal_step_size(last_step, mean_error_ratio, safety=0.9, ifactor=10.0,
                       dfactor=0.2, order=5.0):
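For context, the quantity computed above is the mean squared ratio of the local error estimate to a mixed absolute/relative tolerance; under the usual adaptive-stepping reading, a value below 1 indicates the step is acceptable. A minimal standalone sketch of that reading (the toy arrays below are hypothetical, not from the repository):

import jax.numpy as jnp

def error_ratio(error_estimate, rtol, atol, y0, y1):
  # Same computation as the patched function: mean of squared (error / tolerance).
  err_tol = atol + rtol * jnp.maximum(jnp.abs(y0), jnp.abs(y1))
  return jnp.mean(jnp.square(error_estimate / err_tol))

y0 = jnp.array([1.0, 2.0])
y1 = jnp.array([1.1, 2.1])
err = jnp.array([1e-7, 2e-7])
print(error_ratio(err, rtol=1e-6, atol=1e-8, y0=y0, y1=y1))  # well below 1, so the step would be kept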
jax/experimental/optimizers.py (10 changes: 5 additions & 5 deletions)

@@ -272,7 +272,7 @@ def init(x0):

   def update(i, g, state):
     x, g_sq, m = state
-    g_sq += g**2
+    g_sq += jnp.square(g)
     g_sq_inv_sqrt = jnp.where(g_sq > 0, 1. / jnp.sqrt(g_sq), 0.0)
     m = (1. - momentum) * (g * g_sq_inv_sqrt) + momentum * m
     x = x - step_size(i) * m
@@ -304,7 +304,7 @@ def init(x0):
     return x0, avg_sq_grad
   def update(i, g, state):
     x, avg_sq_grad = state
-    avg_sq_grad = avg_sq_grad * gamma + g**2 * (1. - gamma)
+    avg_sq_grad = avg_sq_grad * gamma + jnp.square(g) * (1. - gamma)
     x = x - step_size(i) * g / jnp.sqrt(avg_sq_grad + eps)
     return x, avg_sq_grad
   def get_params(state):
@@ -337,7 +337,7 @@ def init(x0):
     return x0, avg_sq_grad, mom
   def update(i, g, state):
     x, avg_sq_grad, mom = state
-    avg_sq_grad = avg_sq_grad * gamma + g**2 * (1. - gamma)
+    avg_sq_grad = avg_sq_grad * gamma + jnp.square(g) * (1. - gamma)
     mom = momentum * mom + step_size(i) * g / jnp.sqrt(avg_sq_grad + eps)
     x = x - mom
     return x, avg_sq_grad, mom
@@ -372,7 +372,7 @@ def init(x0):
   def update(i, g, state):
     x, m, v = state
     m = (1 - b1) * g + b1 * m  # First moment estimate.
-    v = (1 - b2) * (g ** 2) + b2 * v  # Second moment estimate.
+    v = (1 - b2) * jnp.square(g) + b2 * v  # Second moment estimate.
     mhat = m / (1 - b1 ** (i + 1))  # Bias correction.
     vhat = v / (1 - b2 ** (i + 1))
     x = x - step_size(i) * mhat / (jnp.sqrt(vhat) + eps)
@@ -450,7 +450,7 @@ def init(x0):
   def update(i, g, state):
     x, m, vs = state
     vs = [broadcast_into(g.ndim, v, i) for i, v in enumerate(vs)]
-    accum = functools.reduce(jnp.minimum, vs) + g ** 2
+    accum = functools.reduce(jnp.minimum, vs) + jnp.square(g)
     accum_inv_sqrt = jnp.where(accum > 0, 1. / jnp.sqrt(accum), 0)
     m = (1. - momentum) * (g * accum_inv_sqrt) + momentum * m
     x = x - step_size(i) * m
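The five hunks above all touch second-moment accumulators in jax.experimental.optimizers. As a usage sketch only (the quadratic loss and step size are made up for illustration), the (opt_init, opt_update, get_params) triple returned by these optimizer factories is exercised like this:

import jax.numpy as jnp
from jax import grad
from jax.experimental import optimizers

def loss(params):
  return jnp.sum(jnp.square(params))  # toy quadratic objective

opt_init, opt_update, get_params = optimizers.adam(step_size=0.1)
state = opt_init(jnp.array([1.0, -2.0]))
for i in range(100):
  g = grad(loss)(get_params(state))
  state = opt_update(i, g, state)  # runs the patched update(), including jnp.square(g)
print(get_params(state))  # parameters move toward zero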
jax/experimental/optix.py (5 changes: 3 additions & 2 deletions)

@@ -95,7 +95,8 @@ def update_fn(updates, state):


 def global_norm(updates: Updates) -> Updates:
-  return jnp.sqrt(jnp.sum([jnp.sum(x**2) for x in tree_leaves(updates)]))
+  return jnp.sqrt(
+      jnp.sum([jnp.sum(jnp.square(x)) for x in tree_leaves(updates)]))


 class ClipByGlobalNormState(OptState):
@@ -222,7 +223,7 @@ def update_fn(updates, state):
     mu = _update_moment(updates, state.mu, decay, 1)
     nu = _update_moment(updates, state.nu, decay, 2)
     updates = tree_multimap(
-        lambda g, m, n: g / jnp.sqrt(n - m**2 + eps), updates, mu, nu)
+        lambda g, m, n: g / jnp.sqrt(n - jnp.square(m) + eps), updates, mu, nu)
     return updates, ScaleByRStdDevState(mu=mu, nu=nu)

   return InitUpdate(init_fn, update_fn)
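As a standalone numeric check of what global_norm computes (the toy gradient pytree here is hypothetical), the function returns the l2 norm taken over all leaves of the update pytree:

import jax.numpy as jnp
from jax.tree_util import tree_leaves

def global_norm(updates):
  # Mirrors the patched optix.global_norm: sqrt of the summed squared leaf entries.
  return jnp.sqrt(
      jnp.sum(jnp.array([jnp.sum(jnp.square(x)) for x in tree_leaves(updates)])))

grads = {'b': jnp.array(4.0), 'w': jnp.array([3.0, 0.0])}
print(global_norm(grads))  # 5.0, i.e. sqrt(3**2 + 4**2)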
jax/nn/functions.py (7 changes: 5 additions & 2 deletions)

@@ -182,7 +182,10 @@ def gelu(x):
   <https://arxiv.org/abs/1606.08415>`_, section 2.
   """
   sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype)
-  cdf = 0.5 * (1.0 + jnp.tanh(sqrt_2_over_pi * (x + 0.044715 * x**3)))
+  # Does not use the power operator here.
+  # See https://github.com/google/jax/pull/3036
+  x_cubed = x * x * x
+  cdf = 0.5 * (1.0 + jnp.tanh(sqrt_2_over_pi * (x + 0.044715 * x_cubed)))
   return x * cdf

 def glu(x, axis=-1):
@@ -237,7 +240,7 @@ def normalize(x, axis=-1, mean=None, variance=None, epsilon=1e-5):
     # mean((x - mean(x))**2) but may be faster and even, given typical
     # activation distributions and low-precision arithmetic, more accurate
     # when used in neural network normalization layers
-    variance = jnp.mean(x**2, axis, keepdims=True) - mean**2
+    variance = jnp.mean(jnp.square(x), axis, keepdims=True) - jnp.square(mean)
   return (x - mean) * lax.rsqrt(variance + epsilon)

 def one_hot(x, num_classes, *, dtype=jnp.float64):
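The normalize hunk relies on the identity Var[x] = E[x^2] - (E[x])^2, which the surrounding comment alludes to. A tiny check of that identity (the sample values are arbitrary):

import jax.numpy as jnp

x = jnp.array([0.5, -1.25, 2.0, 0.0])
mean = jnp.mean(x)
# E[x^2] - (E[x])^2, as used in normalize() after the patch ...
print(jnp.mean(jnp.square(x)) - jnp.square(mean))
# ... agrees with the usual mean((x - mean)^2) definition up to rounding.
print(jnp.var(x))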
jax/scipy/sparse/linalg.py (2 changes: 1 addition & 1 deletion)

@@ -56,7 +56,7 @@ def _cg_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):

   # tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.cg
   bs = _vdot_tree(b, b)
-  atol2 = jnp.maximum(tol ** 2 * bs, atol ** 2)
+  atol2 = jnp.maximum(jnp.square(tol) * bs, jnp.square(atol))

   # https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method

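The changed line feeds the conjugate-gradient stopping rule: iteration stops once the squared residual norm falls below max(tol^2 * ||b||^2, atol^2). A usage sketch of the public wrapper (the 2x2 system is made up; A is passed here as a matvec callable):

import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

A = jnp.array([[4.0, 1.0], [1.0, 3.0]])  # symmetric positive definite
b = jnp.array([1.0, 2.0])

x, _ = cg(lambda v: A @ v, b, tol=1e-5, atol=0.0)
print(A @ x)  # close to b within the tolerance above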
jax/scipy/stats/multivariate_normal.py (3 changes: 2 additions & 1 deletion)

@@ -27,7 +27,8 @@
 def logpdf(x, mean, cov):
   x, mean, cov = _promote_dtypes_inexact(x, mean, cov)
   if not mean.shape:
-    return -1/2 * (x - mean) ** 2 / cov - 1/2 * (np.log(2*np.pi) + jnp.log(cov))
+    return (-1/2 * jnp.square(x - mean) / cov
+            - 1/2 * (np.log(2*np.pi) + jnp.log(cov)))
   else:
     n = mean.shape[-1]
     if not np.shape(cov):
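The scalar branch above is just the log-density of a one-dimensional normal, so it can be cross-checked against jax.scipy.stats.norm (the sample point, mean, and variance below are arbitrary):

import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal, norm

x, mean, cov = 1.5, 0.0, 4.0
print(multivariate_normal.logpdf(x, mean, cov))       # scalar branch of the patched logpdf
print(norm.logpdf(x, loc=mean, scale=jnp.sqrt(cov)))  # same value: a 1-D normal with variance cov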
tests/lax_numpy_test.py (7 changes: 7 additions & 0 deletions)

@@ -3302,6 +3302,13 @@ def testTraceMethod(self):
     self.assertAllClose(x.trace(), api.jit(lambda y: y.trace())(x),
                         check_dtypes=True)

+  def testSquareOfIntegers(self):
+    # See https://github.com/google/jax/pull/3036
+    # Checks if the squares of float32 integers have no numerical errors.
+    # It should be satisfied with all integers less than sqrt(2**24).
+    x = jnp.arange(2**12, dtype=jnp.int32)
+    onp.testing.assert_array_equal(jnp.square(x.astype(jnp.float32)), x * x)
+
 # Most grad tests are at the lax level (see lax_test.py), but we add some here
 # as needed for e.g. particular compound ops of interest.

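Why 2**12 in the new test: float32 carries a 24-bit significand, so every integer magnitude up to 2**24 is exactly representable, and the squares of integers below 2**12 stay within that range. A small numpy-only sketch of that boundary (independent of JAX):

import numpy as np

n = 2 ** 12 - 1
# The exact product n*n is below 2**24, so float32 multiplication reproduces it exactly.
print(np.float32(n) * np.float32(n) == n * n)  # True
# Just past the significand limit, integers stop being exactly representable.
print(np.float32(2 ** 24 + 1) == 2 ** 24 + 1)  # False: rounds to 2**24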
