diff --git a/ferminet/base_config.py b/ferminet/base_config.py index d867664..c080ed3 100644 --- a/ferminet/base_config.py +++ b/ferminet/base_config.py @@ -56,6 +56,7 @@ def default() -> ml_collections.ConfigDict: 'objective': 'vmc', # objective type. Either 'vmc' or 'wqmc' 'iterations': 1000000, # number of iterations 'optimizer': 'kfac', # one of adam, kfac, lamb, none + 'laplacian': 'default', # one of default or folx (for forward lapl) 'lr': { 'rate': 0.05, # learning rate 'decay': 1.0, # exponent of learning rate decay diff --git a/ferminet/hamiltonian.py b/ferminet/hamiltonian.py index e0c0ac4..d732dc8 100644 --- a/ferminet/hamiltonian.py +++ b/ferminet/hamiltonian.py @@ -20,6 +20,7 @@ from ferminet import networks from ferminet import pseudopotential as pp from ferminet.utils import utils +import folx import jax from jax import lax import jax.numpy as jnp @@ -80,6 +81,7 @@ def local_kinetic_energy( f: networks.FermiNetLike, use_scan: bool = False, complex_output: bool = False, + laplacian_method: str = 'default', ) -> KineticEnergy: r"""Creates a function to for the local kinetic energy, -1/2 \nabla^2 ln|f|. @@ -88,6 +90,9 @@ def local_kinetic_energy( (sign or phase, log magnitude) tuple. use_scan: Whether to use a `lax.scan` for computing the laplacian. complex_output: If true, the output of f is complex-valued. + laplacian_method: Laplacian calculation method. 
One of: + 'default': take jvp(grad), looping over inputs + 'folx': use Microsoft's implementation of forward laplacian Returns: Callable which evaluates the local kinetic energy, @@ -97,51 +102,77 @@ def local_kinetic_energy( phase_f = utils.select_output(f, 0) logabs_f = utils.select_output(f, 1) - def _lapl_over_f(params, data): - n = data.positions.shape[0] - eye = jnp.eye(n) - grad_f = jax.grad(logabs_f, argnums=1) - def grad_f_closure(x): - return grad_f(params, x, data.spins, data.atoms, data.charges) - - primal, dgrad_f = jax.linearize(grad_f_closure, data.positions) - + if laplacian_method == 'default': + + def _lapl_over_f(params, data): + n = data.positions.shape[0] + eye = jnp.eye(n) + grad_f = jax.grad(logabs_f, argnums=1) + def grad_f_closure(x): + return grad_f(params, x, data.spins, data.atoms, data.charges) + + primal, dgrad_f = jax.linearize(grad_f_closure, data.positions) + + if complex_output: + grad_phase = jax.grad(phase_f, argnums=1) + def grad_phase_closure(x): + return grad_phase(params, x, data.spins, data.atoms, data.charges) + phase_primal, dgrad_phase = jax.linearize( + grad_phase_closure, data.positions) + hessian_diagonal = ( + lambda i: dgrad_f(eye[i])[i] + 1.j * dgrad_phase(eye[i])[i] + ) + else: + hessian_diagonal = lambda i: dgrad_f(eye[i])[i] + + if use_scan: + _, diagonal = lax.scan( + lambda i, _: (i + 1, hessian_diagonal(i)), 0, None, length=n) + result = -0.5 * jnp.sum(diagonal) + else: + result = -0.5 * lax.fori_loop( + 0, n, lambda i, val: val + hessian_diagonal(i), 0.0) + result -= 0.5 * jnp.sum(primal ** 2) + if complex_output: + result += 0.5 * jnp.sum(phase_primal ** 2) + result -= 1.j * jnp.sum(primal * phase_primal) + return result + + elif laplacian_method == 'folx': if complex_output: - grad_phase = jax.grad(phase_f, argnums=1) - def grad_phase_closure(x): - return grad_phase(params, x, data.spins, data.atoms, data.charges) - phase_primal, dgrad_phase = jax.linearize( - grad_phase_closure, data.positions) - 
hessian_diagonal = ( - lambda i: dgrad_f(eye[i])[i] + 1.j * dgrad_phase(eye[i])[i] - ) + raise NotImplementedError('Forward laplacian not yet supported for' + 'complex-valued outputs.') else: - hessian_diagonal = lambda i: dgrad_f(eye[i])[i] - - if use_scan: - _, diagonal = lax.scan( - lambda i, _: (i + 1, hessian_diagonal(i)), 0, None, length=n) - result = -0.5 * jnp.sum(diagonal) - else: - result = -0.5 * lax.fori_loop( - 0, n, lambda i, val: val + hessian_diagonal(i), 0.0) - result -= 0.5 * jnp.sum(primal ** 2) - if complex_output: - result += 0.5 * jnp.sum(phase_primal ** 2) - result -= 1.j * jnp.sum(primal * phase_primal) - return result + def _lapl_over_f(params, data): + f_closure = lambda x: logabs_f(params, + x, + data.spins, + data.atoms, + data.charges) + f_wrapped = folx.forward_laplacian(f_closure, sparsity_threshold=6) + output = f_wrapped(data.positions) + return - (output.laplacian + + jnp.sum(output.jacobian.dense_array ** 2)) / 2 + else: + raise NotImplementedError(f'Laplacian method {laplacian_method} ' + 'not implemented.') return _lapl_over_f -def excited_kinetic_energy_matrix(f: networks.FermiNetLike, - states: int) -> KineticEnergy: +def excited_kinetic_energy_matrix( + f: networks.FermiNetLike, + states: int, + laplacian_method: str = 'default') -> KineticEnergy: """Creates a f'n which evaluates the matrix of local kinetic energies. Args: f: A network which returns a tuple of sign(psi) and log(|psi|) arrays, where each array contains one element per excited state. states: the number of excited states + laplacian_method: Laplacian calculation method. 
One of: + 'default': take jvp(grad), looping over inputs + 'folx': use Microsoft's implementation of forward laplacian Returns: A function which computes the matrices (psi) and (K psi), which are the @@ -166,11 +197,24 @@ def _lapl_over_f(params, data): """Return the kinetic energy (divided by psi) summed over excited states.""" pos_ = jnp.reshape(data.positions, [states, -1]) spins_ = jnp.reshape(data.spins, [states, -1]) - vmap_f = jax.vmap(f, (None, 0, 0, None, None)) - sign_mat, log_mat = vmap_f(params, pos_, spins_, data.atoms, data.charges) - vmap_lapl = jax.vmap(_lapl_all_states, (None, 0, 0, None, None)) - lapl = vmap_lapl(params, pos_, spins_, data.atoms, - data.charges) # K psi_i(r_j) / psi_i(r_j) + + if laplacian_method == 'default': + vmap_f = jax.vmap(f, (None, 0, 0, None, None)) + sign_mat, log_mat = vmap_f(params, pos_, spins_, data.atoms, data.charges) + vmap_lapl = jax.vmap(_lapl_all_states, (None, 0, 0, None, None)) + lapl = vmap_lapl(params, pos_, spins_, data.atoms, + data.charges) # K psi_i(r_j) / psi_i(r_j) + elif laplacian_method == 'folx': + # CAUTION!! Only the first array of spins is being passed! 
+ f_closure = lambda x: f(params, x, spins_[0], data.atoms, data.charges) + f_wrapped = folx.forward_laplacian(f_closure, sparsity_threshold=6) + sign_mat, log_out = folx.batched_vmap(f_wrapped, 1)(pos_) + log_mat = log_out.x + lapl = -(log_out.laplacian + + jnp.sum(log_out.jacobian.dense_array ** 2, axis=-2)) / 2 + else: + raise NotImplementedError(f'Laplacian method {laplacian_method} ' + 'not implemented with excited states.') # subtract off largest value to avoid under/overflow psi_mat = sign_mat * jnp.exp(log_mat - jnp.max(log_mat)) # psi_i(r_j) @@ -239,6 +283,7 @@ def local_energy( nspins: Sequence[int], use_scan: bool = False, complex_output: bool = False, + laplacian_method: str = 'default', states: int = 0, pp_type: str = 'ccecp', pp_symbols: Sequence[str] | None = None, @@ -252,6 +297,9 @@ def local_energy( nspins: Number of particles of each spin. use_scan: Whether to use a `lax.scan` for computing the laplacian. complex_output: If true, the output of f is complex-valued. + laplacian_method: Laplacian calculation method. One of: + 'default': take jvp(grad), looping over inputs + 'folx': use Microsoft's implementation of forward laplacian states: Number of excited states to compute. If 0, compute ground state with default machinery. If 1, compute ground state with excited state machinery pp_type: type of pseudopotential to use. 
Only used if ecp_symbols is @@ -270,11 +318,12 @@ def local_energy( del nspins if states: - ke = excited_kinetic_energy_matrix(f, states) + ke = excited_kinetic_energy_matrix(f, states, laplacian_method) else: ke = local_kinetic_energy(f, use_scan=use_scan, - complex_output=complex_output) + complex_output=complex_output, + laplacian_method=laplacian_method) if not pp_symbols: effective_charges = charges diff --git a/ferminet/tests/hamiltonian_test.py b/ferminet/tests/hamiltonian_test.py index 699e4d0..c816944 100644 --- a/ferminet/tests/hamiltonian_test.py +++ b/ferminet/tests/hamiltonian_test.py @@ -14,6 +14,8 @@ """Tests for ferminet.hamiltonian.""" +import itertools + from absl.testing import absltest from absl.testing import parameterized from ferminet import base_config @@ -59,7 +61,8 @@ def kinetic_operator(params, pos, spins, atoms, charges): class HamiltonianTest(parameterized.TestCase): - def test_local_kinetic_energy(self): + @parameterized.parameters(['default', 'folx']) + def test_local_kinetic_energy(self, laplacian): dummy_params = {} xs = np.random.normal(size=(3,)) @@ -68,7 +71,8 @@ def test_local_kinetic_energy(self): charges = 2 * np.ones(shape=(1,)) expected_kinetic_energy = -(1 - 2 / np.abs(np.linalg.norm(xs))) / 2 - kinetic = hamiltonian.local_kinetic_energy(h_atom_log_psi_signed) + kinetic = hamiltonian.local_kinetic_energy(h_atom_log_psi_signed, + laplacian_method=laplacian) kinetic_energy = kinetic( dummy_params, networks.FermiNetData( @@ -152,7 +156,8 @@ def test_local_energy(self): class LaplacianTest(parameterized.TestCase): - def test_laplacian(self): + @parameterized.parameters(['default', 'folx']) + def test_laplacian(self, laplacian): xs = np.random.uniform(size=(100, 3)) spins = np.ones(shape=(1,)) @@ -163,7 +168,8 @@ def test_laplacian(self): ) dummy_params = {} t_l_fn = jax.vmap( - hamiltonian.local_kinetic_energy(h_atom_log_psi_signed), + hamiltonian.local_kinetic_energy(h_atom_log_psi_signed, + laplacian_method=laplacian), 
in_axes=( None, networks.FermiNetData( @@ -178,8 +184,10 @@ def test_laplacian(self): )(dummy_params, xs, spins, atoms, charges) np.testing.assert_allclose(t_l, hess_t, rtol=1E-5) - @parameterized.parameters([True, False]) - def test_fermi_net_laplacian(self, full_det): + @parameterized.parameters( + itertools.product([True, False], ['default', 'folx']) + ) + def test_fermi_net_laplacian(self, full_det, laplacian): natoms = 2 np.random.seed(12) atoms = np.random.uniform(low=-5.0, high=5.0, size=(natoms, 3)) @@ -209,7 +217,8 @@ def test_fermi_net_laplacian(self, full_det): spins = np.sign(np.random.normal(scale=1, size=(batch, sum(nspins)))) t_l_fn = jax.jit( jax.vmap( - hamiltonian.local_kinetic_energy(network.apply), + hamiltonian.local_kinetic_energy(network.apply, + laplacian_method=laplacian), in_axes=( None, networks.FermiNetData( diff --git a/ferminet/tests/train_test.py b/ferminet/tests/train_test.py index 9e321fe..41efee9 100644 --- a/ferminet/tests/train_test.py +++ b/ferminet/tests/train_test.py @@ -55,13 +55,23 @@ def _config_params(): yield {'system': system, 'optimizer': optimizer, 'complex_': complex_, - 'states': states} + 'states': states, + 'laplacian': 'default'} for optimizer in ('kfac', 'adam', 'lamb', 'none'): yield { 'system': 'H' if optimizer in ('kfac', 'adam') else 'Li', 'optimizer': optimizer, 'complex_': False, 'states': 0, + 'laplacian': 'default', + } + for states, laplacian in itertools.product((0, 2), ('default', 'folx')): + yield { + 'system': 'Li', + 'optimizer': 'kfac', + 'complex_': False, + 'states': states, + 'laplacian': laplacian } @@ -75,7 +85,7 @@ def setUp(self): pyscf.lib.param.TMPDIR = None @parameterized.parameters(_config_params()) - def test_training_step(self, system, optimizer, complex_, states): + def test_training_step(self, system, optimizer, complex_, states, laplacian): if system in ('H', 'Li'): cfg = atom.get_config() cfg.system.atom = system @@ -90,6 +100,7 @@ def test_training_step(self, system, optimizer, 
complex_, states): cfg.pretrain.iterations = 10 cfg.mcmc.burn_in = 10 cfg.optim.optimizer = optimizer + cfg.optim.laplacian = laplacian cfg.optim.iterations = 3 cfg.debug.check_nan = True cfg.observables.s2 = True diff --git a/ferminet/train.py b/ferminet/train.py index 412de4d..996b9b3 100644 --- a/ferminet/train.py +++ b/ferminet/train.py @@ -672,7 +672,11 @@ def log_network(*args, **kwargs): blocks=cfg.mcmc.blocks * num_states, ) # Construct loss and optimizer + laplacian_method = cfg.optim.get('laplacian', 'default') if cfg.system.make_local_energy_fn: + if laplacian_method != 'default': + raise NotImplementedError(f'Laplacian method {laplacian_method}' + 'not yet supported by custom local energy fns.') local_energy_module, local_energy_fn = ( cfg.system.make_local_energy_fn.rsplit('.', maxsplit=1)) local_energy_module = importlib.import_module(local_energy_module) @@ -692,6 +696,7 @@ def log_network(*args, **kwargs): nspins=nspins, use_scan=False, complex_output=cfg.network.get('complex', False), + laplacian_method=laplacian_method, states=cfg.system.get('states', 0), pp_type=cfg.system.get('pp', {'type': 'ccecp'}).get('type'), pp_symbols=pp_symbols if cfg.system.get('use_pp') else None) diff --git a/setup.py b/setup.py index 80bfff0..b9d8c1f 100644 --- a/setup.py +++ b/setup.py @@ -24,6 +24,7 @@ 'attrs', 'chex', 'h5py', + 'folx @ git+https://github.com/microsoft/folx', 'jax', 'jaxlib', # TODO(b/230487443) - use released version of kfac. @@ -49,7 +50,8 @@ def ferminet_test_suite(): setup( name='ferminet', version='0.2', - description='A library to train networks to represent ground state wavefunctions of fermionic systems', + description=('A library to train networks to represent ground ' + 'state wavefunctions of fermionic systems'), url='https://github.com/deepmind/ferminet', author='DeepMind', author_email='no-reply@google.com',