Integrate folx for forward Laplacian computation
PiperOrigin-RevId: 592595762
Change-Id: I2209521359899fe2c2fda232c9e0a2c68359331d
dpfau authored and jsspencer committed Dec 20, 2023
1 parent bb979a0 commit e94435b
Showing 6 changed files with 127 additions and 50 deletions.
1 change: 1 addition & 0 deletions ferminet/base_config.py
@@ -56,6 +56,7 @@ def default() -> ml_collections.ConfigDict:
'objective': 'vmc', # objective type. Either 'vmc' or 'wqmc'
'iterations': 1000000, # number of iterations
'optimizer': 'kfac', # one of adam, kfac, lamb, none
'laplacian': 'default', # one of default or folx (for forward laplacian)
'lr': {
'rate': 0.05, # learning rate
'decay': 1.0, # exponent of learning rate decay
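The new laplacian option sits alongside the other optim settings, so switching implementations is a one-line config change. A minimal usage sketch (not part of the diff), assuming the ferminet.configs.atom config module and the top-level train.train entry point used in the tests below; the Li system and iteration count are illustrative only:

import ml_collections  # noqa: F401 (config dicts come from base_config.default)
from ferminet import train
from ferminet.configs import atom  # assumed config module, as used in train_test.py

cfg = atom.get_config()
cfg.system.atom = 'Li'            # illustrative system
cfg.optim.optimizer = 'kfac'
cfg.optim.laplacian = 'folx'      # 'default' keeps the existing jvp(grad) loop
cfg.optim.iterations = 10         # small value, for illustration only
train.train(cfg)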
129 changes: 89 additions & 40 deletions ferminet/hamiltonian.py
@@ -20,6 +20,7 @@
from ferminet import networks
from ferminet import pseudopotential as pp
from ferminet.utils import utils
import folx
import jax
from jax import lax
import jax.numpy as jnp
@@ -80,6 +81,7 @@ def local_kinetic_energy(
f: networks.FermiNetLike,
use_scan: bool = False,
complex_output: bool = False,
laplacian_method: str = 'default',
) -> KineticEnergy:
r"""Creates a function for the local kinetic energy, -1/2 \nabla^2 ln|f|.
@@ -88,6 +90,9 @@
(sign or phase, log magnitude) tuple.
use_scan: Whether to use a `lax.scan` for computing the laplacian.
complex_output: If true, the output of f is complex-valued.
laplacian_method: Laplacian calculation method. One of:
'default': take jvp(grad), looping over inputs
'folx': use Microsoft's implementation of forward laplacian
Returns:
Callable which evaluates the local kinetic energy,
@@ -97,51 +102,77 @@
phase_f = utils.select_output(f, 0)
logabs_f = utils.select_output(f, 1)

def _lapl_over_f(params, data):
n = data.positions.shape[0]
eye = jnp.eye(n)
grad_f = jax.grad(logabs_f, argnums=1)
def grad_f_closure(x):
return grad_f(params, x, data.spins, data.atoms, data.charges)

primal, dgrad_f = jax.linearize(grad_f_closure, data.positions)

if laplacian_method == 'default':

def _lapl_over_f(params, data):
n = data.positions.shape[0]
eye = jnp.eye(n)
grad_f = jax.grad(logabs_f, argnums=1)
def grad_f_closure(x):
return grad_f(params, x, data.spins, data.atoms, data.charges)

primal, dgrad_f = jax.linearize(grad_f_closure, data.positions)

if complex_output:
grad_phase = jax.grad(phase_f, argnums=1)
def grad_phase_closure(x):
return grad_phase(params, x, data.spins, data.atoms, data.charges)
phase_primal, dgrad_phase = jax.linearize(
grad_phase_closure, data.positions)
hessian_diagonal = (
lambda i: dgrad_f(eye[i])[i] + 1.j * dgrad_phase(eye[i])[i]
)
else:
hessian_diagonal = lambda i: dgrad_f(eye[i])[i]

if use_scan:
_, diagonal = lax.scan(
lambda i, _: (i + 1, hessian_diagonal(i)), 0, None, length=n)
result = -0.5 * jnp.sum(diagonal)
else:
result = -0.5 * lax.fori_loop(
0, n, lambda i, val: val + hessian_diagonal(i), 0.0)
result -= 0.5 * jnp.sum(primal ** 2)
if complex_output:
result += 0.5 * jnp.sum(phase_primal ** 2)
result -= 1.j * jnp.sum(primal * phase_primal)
return result

elif laplacian_method == 'folx':
if complex_output:
grad_phase = jax.grad(phase_f, argnums=1)
def grad_phase_closure(x):
return grad_phase(params, x, data.spins, data.atoms, data.charges)
phase_primal, dgrad_phase = jax.linearize(
grad_phase_closure, data.positions)
hessian_diagonal = (
lambda i: dgrad_f(eye[i])[i] + 1.j * dgrad_phase(eye[i])[i]
)
raise NotImplementedError('Forward laplacian not yet supported for '
'complex-valued outputs.')
else:
hessian_diagonal = lambda i: dgrad_f(eye[i])[i]

if use_scan:
_, diagonal = lax.scan(
lambda i, _: (i + 1, hessian_diagonal(i)), 0, None, length=n)
result = -0.5 * jnp.sum(diagonal)
else:
result = -0.5 * lax.fori_loop(
0, n, lambda i, val: val + hessian_diagonal(i), 0.0)
result -= 0.5 * jnp.sum(primal ** 2)
if complex_output:
result += 0.5 * jnp.sum(phase_primal ** 2)
result -= 1.j * jnp.sum(primal * phase_primal)
return result
def _lapl_over_f(params, data):
f_closure = lambda x: logabs_f(params,
x,
data.spins,
data.atoms,
data.charges)
f_wrapped = folx.forward_laplacian(f_closure, sparsity_threshold=6)
output = f_wrapped(data.positions)
return - (output.laplacian +
jnp.sum(output.jacobian.dense_array ** 2)) / 2
else:
raise NotImplementedError(f'Laplacian method {laplacian_method} '
'not implemented.')

return _lapl_over_f


def excited_kinetic_energy_matrix(f: networks.FermiNetLike,
states: int) -> KineticEnergy:
def excited_kinetic_energy_matrix(
f: networks.FermiNetLike,
states: int,
laplacian_method: str = 'default') -> KineticEnergy:
"""Creates a f'n which evaluates the matrix of local kinetic energies.
Args:
f: A network which returns a tuple of sign(psi) and log(|psi|) arrays, where
each array contains one element per excited state.
states: the number of excited states
laplacian_method: Laplacian calculation method. One of:
'default': take jvp(grad), looping over inputs
'folx': use Microsoft's implementation of forward laplacian
Returns:
A function which computes the matrices (psi) and (K psi), which are the
@@ -166,11 +197,24 @@ def _lapl_over_f(params, data):
"""Return the kinetic energy (divided by psi) summed over excited states."""
pos_ = jnp.reshape(data.positions, [states, -1])
spins_ = jnp.reshape(data.spins, [states, -1])
vmap_f = jax.vmap(f, (None, 0, 0, None, None))
sign_mat, log_mat = vmap_f(params, pos_, spins_, data.atoms, data.charges)
vmap_lapl = jax.vmap(_lapl_all_states, (None, 0, 0, None, None))
lapl = vmap_lapl(params, pos_, spins_, data.atoms,
data.charges) # K psi_i(r_j) / psi_i(r_j)

if laplacian_method == 'default':
vmap_f = jax.vmap(f, (None, 0, 0, None, None))
sign_mat, log_mat = vmap_f(params, pos_, spins_, data.atoms, data.charges)
vmap_lapl = jax.vmap(_lapl_all_states, (None, 0, 0, None, None))
lapl = vmap_lapl(params, pos_, spins_, data.atoms,
data.charges) # K psi_i(r_j) / psi_i(r_j)
elif laplacian_method == 'folx':
# CAUTION!! Only the first array of spins is being passed!
f_closure = lambda x: f(params, x, spins_[0], data.atoms, data.charges)
f_wrapped = folx.forward_laplacian(f_closure, sparsity_threshold=6)
sign_mat, log_out = folx.batched_vmap(f_wrapped, 1)(pos_)
log_mat = log_out.x
lapl = -(log_out.laplacian +
jnp.sum(log_out.jacobian.dense_array ** 2, axis=-2)) / 2
else:
raise NotImplementedError(f'Laplacian method {laplacian_method} '
'not implemented with excited states.')

# subtract off largest value to avoid under/overflow
psi_mat = sign_mat * jnp.exp(log_mat - jnp.max(log_mat)) # psi_i(r_j)
@@ -239,6 +283,7 @@ def local_energy(
nspins: Sequence[int],
use_scan: bool = False,
complex_output: bool = False,
laplacian_method: str = 'default',
states: int = 0,
pp_type: str = 'ccecp',
pp_symbols: Sequence[str] | None = None,
@@ -252,6 +297,9 @@
nspins: Number of particles of each spin.
use_scan: Whether to use a `lax.scan` for computing the laplacian.
complex_output: If true, the output of f is complex-valued.
laplacian_method: Laplacian calculation method. One of:
'default': take jvp(grad), looping over inputs
'folx': use Microsoft's implementation of forward laplacian
states: Number of excited states to compute. If 0, compute ground state with
default machinery. If 1, compute ground state with excited state machinery
pp_type: type of pseudopotential to use. Only used if ecp_symbols is
@@ -270,11 +318,12 @@
del nspins

if states:
ke = excited_kinetic_energy_matrix(f, states)
ke = excited_kinetic_energy_matrix(f, states, laplacian_method)
else:
ke = local_kinetic_energy(f,
use_scan=use_scan,
complex_output=complex_output)
complex_output=complex_output,
laplacian_method=laplacian_method)

if not pp_symbols:
effective_charges = charges
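The folx path relies on the identity (1/psi) lapl(psi) = lapl(log|psi|) + |grad(log|psi|)|^2, so the local kinetic energy is -(lapl(log|psi|) + |grad(log|psi|)|^2) / 2; folx.forward_laplacian returns the Laplacian and Jacobian of log|psi| in a single forward pass instead of looping over Hessian diagonal entries. A minimal standalone sketch (not part of the diff) of the same pattern on a toy hydrogen-like log|psi|(x) = -|x|; the toy function is illustrative, while the folx call and the .laplacian / .jacobian.dense_array attributes mirror the new code above:

import folx
import jax.numpy as jnp

def log_abs_psi(x):
  # Toy log|psi| for a 1s-like orbital centred at the origin: log|exp(-|x|)| = -|x|.
  return -jnp.linalg.norm(x)

lapl_op = folx.forward_laplacian(log_abs_psi, sparsity_threshold=6)
out = lapl_op(jnp.array([0.1, 0.2, 0.3]))
local_kinetic = -(out.laplacian + jnp.sum(out.jacobian.dense_array ** 2)) / 2
# For this toy wavefunction the result should match the analytic value
# -(1 - 2/|x|)/2 used in hamiltonian_test.py.

For the excited-state matrix, the same wrapped function is mapped over the per-state position sets with folx.batched_vmap, and only the first row of spins is forwarded, as the CAUTION comment above notes.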
23 changes: 16 additions & 7 deletions ferminet/tests/hamiltonian_test.py
@@ -14,6 +14,8 @@

"""Tests for ferminet.hamiltonian."""

import itertools

from absl.testing import absltest
from absl.testing import parameterized
from ferminet import base_config
@@ -59,7 +61,8 @@ def kinetic_operator(params, pos, spins, atoms, charges):

class HamiltonianTest(parameterized.TestCase):

def test_local_kinetic_energy(self):
@parameterized.parameters(['default', 'folx'])
def test_local_kinetic_energy(self, laplacian):

dummy_params = {}
xs = np.random.normal(size=(3,))
@@ -68,7 +71,8 @@ def test_local_kinetic_energy(self):
charges = 2 * np.ones(shape=(1,))
expected_kinetic_energy = -(1 - 2 / np.abs(np.linalg.norm(xs))) / 2

kinetic = hamiltonian.local_kinetic_energy(h_atom_log_psi_signed)
kinetic = hamiltonian.local_kinetic_energy(h_atom_log_psi_signed,
laplacian_method=laplacian)
kinetic_energy = kinetic(
dummy_params,
networks.FermiNetData(
@@ -152,7 +156,8 @@ def test_local_energy(self):

class LaplacianTest(parameterized.TestCase):

def test_laplacian(self):
@parameterized.parameters(['default', 'folx'])
def test_laplacian(self, laplacian):

xs = np.random.uniform(size=(100, 3))
spins = np.ones(shape=(1,))
@@ -163,7 +168,8 @@ def test_laplacian(self):
)
dummy_params = {}
t_l_fn = jax.vmap(
hamiltonian.local_kinetic_energy(h_atom_log_psi_signed),
hamiltonian.local_kinetic_energy(h_atom_log_psi_signed,
laplacian_method=laplacian),
in_axes=(
None,
networks.FermiNetData(
@@ -178,8 +184,10 @@ def test_fermi_net_laplacian(self, full_det):
)(dummy_params, xs, spins, atoms, charges)
np.testing.assert_allclose(t_l, hess_t, rtol=1E-5)

@parameterized.parameters([True, False])
def test_fermi_net_laplacian(self, full_det):
@parameterized.parameters(
itertools.product([True, False], ['default', 'folx'])
)
def test_fermi_net_laplacian(self, full_det, laplacian):
natoms = 2
np.random.seed(12)
atoms = np.random.uniform(low=-5.0, high=5.0, size=(natoms, 3))
@@ -209,7 +217,8 @@ def test_fermi_net_laplacian(self, full_det):
spins = np.sign(np.random.normal(scale=1, size=(batch, sum(nspins))))
t_l_fn = jax.jit(
jax.vmap(
hamiltonian.local_kinetic_energy(network.apply),
hamiltonian.local_kinetic_energy(network.apply,
laplacian_method=laplacian),
in_axes=(
None,
networks.FermiNetData(
15 changes: 13 additions & 2 deletions ferminet/tests/train_test.py
@@ -55,13 +55,23 @@ def _config_params():
yield {'system': system,
'optimizer': optimizer,
'complex_': complex_,
'states': states}
'states': states,
'laplacian': 'default'}
for optimizer in ('kfac', 'adam', 'lamb', 'none'):
yield {
'system': 'H' if optimizer in ('kfac', 'adam') else 'Li',
'optimizer': optimizer,
'complex_': False,
'states': 0,
'laplacian': 'default',
}
for states, laplacian in itertools.product((0, 2), ('default', 'folx')):
yield {
'system': 'Li',
'optimizer': 'kfac',
'complex_': False,
'states': states,
'laplacian': laplacian
}


@@ -75,7 +85,7 @@ def setUp(self):
pyscf.lib.param.TMPDIR = None

@parameterized.parameters(_config_params())
def test_training_step(self, system, optimizer, complex_, states):
def test_training_step(self, system, optimizer, complex_, states, laplacian):
if system in ('H', 'Li'):
cfg = atom.get_config()
cfg.system.atom = system
@@ -90,6 +100,7 @@ def test_training_step(self, system, optimizer, complex_, states):
cfg.pretrain.iterations = 10
cfg.mcmc.burn_in = 10
cfg.optim.optimizer = optimizer
cfg.optim.laplacian = laplacian
cfg.optim.iterations = 3
cfg.debug.check_nan = True
cfg.observables.s2 = True
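The new loop in _config_params adds a Li/kfac case for every combination of excited-state count and Laplacian method; a quick illustration (not part of the diff) of the four cases it generates:

import itertools

print(list(itertools.product((0, 2), ('default', 'folx'))))
# [(0, 'default'), (0, 'folx'), (2, 'default'), (2, 'folx')]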
5 changes: 5 additions & 0 deletions ferminet/train.py
@@ -672,7 +672,11 @@ def log_network(*args, **kwargs):
blocks=cfg.mcmc.blocks * num_states,
)
# Construct loss and optimizer
laplacian_method = cfg.optim.get('laplacian', 'default')
if cfg.system.make_local_energy_fn:
if laplacian_method != 'default':
raise NotImplementedError(f'Laplacian method {laplacian_method} '
'not yet supported by custom local energy fns.')
local_energy_module, local_energy_fn = (
cfg.system.make_local_energy_fn.rsplit('.', maxsplit=1))
local_energy_module = importlib.import_module(local_energy_module)
@@ -692,6 +696,7 @@ def log_network(*args, **kwargs):
nspins=nspins,
use_scan=False,
complex_output=cfg.network.get('complex', False),
laplacian_method=laplacian_method,
states=cfg.system.get('states', 0),
pp_type=cfg.system.get('pp', {'type': 'ccecp'}).get('type'),
pp_symbols=pp_symbols if cfg.system.get('use_pp') else None)
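train.py now simply forwards cfg.optim.laplacian to the Hamiltonian factory. A self-contained sketch (not part of the diff) of calling hamiltonian.local_energy with the new argument; the toy wavefunction, charges and spin counts are illustrative, and the f= and charges= keyword names are assumed since those parameters sit just above the hunk shown here:

import jax.numpy as jnp
import numpy as np
from ferminet import hamiltonian

def toy_signed_log_psi(params, pos, spins, atoms, charges):
  # Illustrative (sign, log|psi|) callable standing in for the real network.apply.
  del params, spins, atoms, charges
  return jnp.ones(()), -jnp.linalg.norm(pos)

local_energy_fn = hamiltonian.local_energy(
    f=toy_signed_log_psi,       # assumed keyword name for the signed network
    charges=np.array([1.0]),    # assumed keyword name; single proton, illustrative
    nspins=(1, 0),              # one spin-up electron, illustrative
    use_scan=False,
    complex_output=False,
    laplacian_method='folx',    # or 'default'
    states=0,
    pp_symbols=None,
)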
4 changes: 3 additions & 1 deletion setup.py
@@ -24,6 +24,7 @@
'attrs',
'chex',
'h5py',
'folx @ git+https://github.com/microsoft/folx',
'jax',
'jaxlib',
# TODO(b/230487443) - use released version of kfac.
@@ -49,7 +50,8 @@ def ferminet_test_suite():
setup(
name='ferminet',
version='0.2',
description='A library to train networks to represent ground state wavefunctions of fermionic systems',
description=('A library to train networks to represent ground '
'state wavefunctions of fermionic systems'),
url='https://github.com/deepmind/ferminet',
author='DeepMind',
author_email='[email protected]',
