diff --git a/jrystal/calc/calc_band_structure.py b/jrystal/calc/calc_band_structure.py index 6adbf79..0caa11d 100644 --- a/jrystal/calc/calc_band_structure.py +++ b/jrystal/calc/calc_band_structure.py @@ -118,7 +118,7 @@ def get_occupation(params): ) # Define the objective function for band structure calculation. - def hamiltonian_trace(params_pw_band, kpts): + def hamiltonian_trace(params_pw_band, kpts, g_vec=g_vec): coeff_band = pw.coeff(params_pw_band, freq_mask) energy = hamiltonian.hamiltonian_matrix_trace( coeff_band, @@ -139,8 +139,10 @@ def hamiltonian_trace(params_pw_band, kpts): # define update function @jax.jit - def update(params, opt_state, kpts): - hamil_trace, grad = jax.value_and_grad(hamiltonian_trace)(params, kpts) + def update(params, opt_state, kpts, g_vec): + hamil_trace, grad = jax.value_and_grad(hamiltonian_trace)( + params, kpts, g_vec + ) updates, opt_state = optimizer.update(grad, opt_state) params = optax.apply_updates(params, updates) @@ -163,7 +165,7 @@ def update(params, opt_state, kpts): start = time.time() for i in iters: params_pw_band, opt_state, hamil_trace = update( - params_pw_band, opt_state, k_path[0:1] + params_pw_band, opt_state, k_path[0:1], g_vec ) iters.set_description(f"Hamiltonian trace: {hamil_trace:.4E}") @@ -185,7 +187,7 @@ def update(params, opt_state, kpts): for i in iters: for _ in range(config.k_path_fine_tuning_epoch): params_pw_band, opt_state, hamil_trace = update( - params_pw_band, opt_state, k_path[i:(i+1)] + params_pw_band, opt_state, k_path[i:(i+1)], g_vec ) iters.set_description(f" Loss(the {i+1}th k point): {hamil_trace:.4E}") params_kpoint_list.append(params_pw_band) @@ -197,8 +199,12 @@ def update(params, opt_state, kpts): logging.info("===> Diagonalizing the Hamiltonian matrix...") @jax.jit - def eig_fn(param, k): - coeff_k = pw.coeff(param, freq_mask) + def eig_fn( + coeff_k, + k, + ground_state_density_grid, + g_vec, + ): hamil_matrix = hamiltonian.hamiltonian_matrix( coeff_k, crystal.positions, @@ -219,7 +225,8 @@ def eig_fn(param, k): prm = params_kpoint_list[i] k = k_path[i:(i + 1), :] iters.set_description(f"Diagonolizing the {i+1}th k points") - eig = eig_fn(prm, k) + coeff_k = pw.coeff(prm, freq_mask) + eig = eig_fn(coeff_k, k, ground_state_density_grid, g_vec) eigen_values.append(eig) # eigen_values = jnp.vstack(eigen_values) diff --git a/jrystal/calc/calc_ground_state_energy.py b/jrystal/calc/calc_ground_state_energy.py index b3690ab..bab6fbb 100644 --- a/jrystal/calc/calc_ground_state_energy.py +++ b/jrystal/calc/calc_ground_state_energy.py @@ -24,6 +24,7 @@ from .._src.crystal import Crystal from .._src import energy, entropy, occupation, pw +from .._src.grid import proper_grid_size from ..config import JrystalConfigDict from .opt_utils import ( create_crystal, @@ -82,6 +83,8 @@ def calc(config: JrystalConfigDict) -> GroundStateEnergyOutput: g_vec, r_vec, k_vec = create_grids(config) num_kpts = k_vec.shape[0] + logging.info(f"Number of G-vectors: {proper_grid_size(config.grid_sizes)}") + logging.info(f"Number of k-vectors: {proper_grid_size(config.k_grid_sizes)}") num_bands = ceil(crystal.num_electron / 2) + config.empty_bands freq_mask = create_freq_mask(config) ew = get_ewald_coulomb_repulsion(config) @@ -95,7 +98,7 @@ def get_occupation(params): params, crystal.num_electron, num_kpts, crystal.spin ) - def total_energy(params_pw, params_occ): + def total_energy(params_pw, params_occ, g_vec=g_vec): coeff = pw.coeff(params_pw, freq_mask) occ = get_occupation(params_occ) return energy.total_energy( @@ -113,8 +116,8 @@ def get_entropy(params_occ): occ = get_occupation(params_occ) return entropy.fermi_dirac(occ, eps=EPS) - def free_energy(params_pw, params_occ, temp): - total = total_energy(params_pw, params_occ) + def free_energy(params_pw, params_occ, temp, g_vec=g_vec): + total = total_energy(params_pw, params_occ, g_vec) etro = get_entropy(params_occ) free = total + temp * etro return free, (total, etro) @@ -128,8 +131,8 @@ def free_energy(params_pw, params_occ, temp): # Define update function. @jax.jit - def update(params, opt_state, temp): - loss = lambda x: free_energy(x["pw"], x["occ"], temp) + def update(params, opt_state, temp, g_vec): + loss = lambda x: free_energy(x["pw"], x["occ"], temp, g_vec) (loss_val, es), grad = jax.value_and_grad(loss, has_aux=True)(params) updates, opt_state = optimizer.update(grad, opt_state) params = optax.apply_updates(params, updates) @@ -160,7 +163,7 @@ def temperature_scheduler(i): for i in iters: temp = temperature_scheduler(i) start = time.time() - params, opt_state, loss_val, es = update(params, opt_state, temp) + params, opt_state, loss_val, es = update(params, opt_state, temp, g_vec) etot, entro = es etot = jax.block_until_ready(etot) train_time += time.time() - start diff --git a/jrystal/calc/opt_utils.py b/jrystal/calc/opt_utils.py index 56df15c..a591410 100644 --- a/jrystal/calc/opt_utils.py +++ b/jrystal/calc/opt_utils.py @@ -32,7 +32,7 @@ def set_env_params(config: JrystalConfigDict): - jax.config.update("jax_debug_nans", True) + jax.config.update("jax_debug_nans", False) if config.verbose: logging.set_verbosity(logging.INFO) diff --git a/main.py b/main.py index 8edfdb2..12c6a0e 100644 --- a/main.py +++ b/main.py @@ -1,38 +1,39 @@ import argparse import jrystal as jr -parser = argparse.ArgumentParser( - prog='Jrystal', description='Command for Jrystal package.' -) -parser.add_argument( - "-m", - "--mode", - choices=["energy", "band"], - default='energy', - help="Set the computation mode. For total enrgy minimization, please use " - "\'energy\'. For band structure calculation, please use \'band\'. " -) +def main(): + parser = argparse.ArgumentParser( + prog='Jrystal', description='Command for Jrystal package.' + ) -parser.add_argument( - "-c", - "--config", - default='config.yaml', - help="Set the configuration file path." -) + parser.add_argument( + "-m", + "--mode", + choices=["energy", "band"], + default='energy', + help="Set the computation mode. For total enrgy minimization, please use " + "\'energy\'. For band structure calculation, please use \'band\'. " + ) -args = parser.parse_args() + parser.add_argument( + "-c", + "--config", + default='config.yaml', + help="Set the configuration file path." + ) -config = jr.config.get_config("config.yaml") + args = parser.parse_args() -if args.mode == "energy": - if config.use_pseudopotential: - jr.calc.energy_normcons(config) - else: - jr.calc.energy(config) -elif args.mode == "band": - if config.use_pseudopotential: - jr.calc.band_normcons(config) - else: - jr.calc.band(config) + config = jr.config.get_config("config.yaml") + if args.mode == "energy": + if config.use_pseudopotential: + jr.calc.energy_normcons(config) + else: + jr.calc.energy(config) + elif args.mode == "band": + if config.use_pseudopotential: + jr.calc.band_normcons(config) + else: + jr.calc.band(config) diff --git a/setup.py b/setup.py index c1328d1..73d68fd 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,7 @@ def _read_requirements(): install_requires=_read_requirements(), entry_points={ 'console_scripts': [ - 'jrystal=main', + 'jrystal=main:main', ], }, )