Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update calculators #31

Merged
merged 1 commit into from
Mar 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions jrystal/calc/calc_band_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def get_occupation(params):
)

# Define the objective function for band structure calculation.
def hamiltonian_trace(params_pw_band, kpts):
def hamiltonian_trace(params_pw_band, kpts, g_vec=g_vec):
coeff_band = pw.coeff(params_pw_band, freq_mask)
energy = hamiltonian.hamiltonian_matrix_trace(
coeff_band,
Expand All @@ -139,8 +139,10 @@ def hamiltonian_trace(params_pw_band, kpts):

# define update function
@jax.jit
def update(params, opt_state, kpts):
hamil_trace, grad = jax.value_and_grad(hamiltonian_trace)(params, kpts)
def update(params, opt_state, kpts, g_vec):
hamil_trace, grad = jax.value_and_grad(hamiltonian_trace)(
params, kpts, g_vec
)

updates, opt_state = optimizer.update(grad, opt_state)
params = optax.apply_updates(params, updates)
Expand All @@ -163,7 +165,7 @@ def update(params, opt_state, kpts):
start = time.time()
for i in iters:
params_pw_band, opt_state, hamil_trace = update(
params_pw_band, opt_state, k_path[0:1]
params_pw_band, opt_state, k_path[0:1], g_vec
)
iters.set_description(f"Hamiltonian trace: {hamil_trace:.4E}")

Expand All @@ -185,7 +187,7 @@ def update(params, opt_state, kpts):
for i in iters:
for _ in range(config.k_path_fine_tuning_epoch):
params_pw_band, opt_state, hamil_trace = update(
params_pw_band, opt_state, k_path[i:(i+1)]
params_pw_band, opt_state, k_path[i:(i+1)], g_vec
)
iters.set_description(f" Loss(the {i+1}th k point): {hamil_trace:.4E}")
params_kpoint_list.append(params_pw_band)
Expand All @@ -197,8 +199,12 @@ def update(params, opt_state, kpts):
logging.info("===> Diagonalizing the Hamiltonian matrix...")

@jax.jit
def eig_fn(param, k):
coeff_k = pw.coeff(param, freq_mask)
def eig_fn(
coeff_k,
k,
ground_state_density_grid,
g_vec,
):
hamil_matrix = hamiltonian.hamiltonian_matrix(
coeff_k,
crystal.positions,
Expand All @@ -219,7 +225,8 @@ def eig_fn(param, k):
prm = params_kpoint_list[i]
k = k_path[i:(i + 1), :]
iters.set_description(f"Diagonolizing the {i+1}th k points")
eig = eig_fn(prm, k)
coeff_k = pw.coeff(prm, freq_mask)
eig = eig_fn(coeff_k, k, ground_state_density_grid, g_vec)
eigen_values.append(eig)

# eigen_values = jnp.vstack(eigen_values)
Expand Down
15 changes: 9 additions & 6 deletions jrystal/calc/calc_ground_state_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from .._src.crystal import Crystal
from .._src import energy, entropy, occupation, pw
from .._src.grid import proper_grid_size
from ..config import JrystalConfigDict
from .opt_utils import (
create_crystal,
Expand Down Expand Up @@ -82,6 +83,8 @@ def calc(config: JrystalConfigDict) -> GroundStateEnergyOutput:

g_vec, r_vec, k_vec = create_grids(config)
num_kpts = k_vec.shape[0]
logging.info(f"Number of G-vectors: {proper_grid_size(config.grid_sizes)}")
logging.info(f"Number of k-vectors: {proper_grid_size(config.k_grid_sizes)}")
num_bands = ceil(crystal.num_electron / 2) + config.empty_bands
freq_mask = create_freq_mask(config)
ew = get_ewald_coulomb_repulsion(config)
Expand All @@ -95,7 +98,7 @@ def get_occupation(params):
params, crystal.num_electron, num_kpts, crystal.spin
)

def total_energy(params_pw, params_occ):
def total_energy(params_pw, params_occ, g_vec=g_vec):
coeff = pw.coeff(params_pw, freq_mask)
occ = get_occupation(params_occ)
return energy.total_energy(
Expand All @@ -113,8 +116,8 @@ def get_entropy(params_occ):
occ = get_occupation(params_occ)
return entropy.fermi_dirac(occ, eps=EPS)

def free_energy(params_pw, params_occ, temp):
total = total_energy(params_pw, params_occ)
def free_energy(params_pw, params_occ, temp, g_vec=g_vec):
total = total_energy(params_pw, params_occ, g_vec)
etro = get_entropy(params_occ)
free = total + temp * etro
return free, (total, etro)
Expand All @@ -128,8 +131,8 @@ def free_energy(params_pw, params_occ, temp):

# Define update function.
@jax.jit
def update(params, opt_state, temp):
loss = lambda x: free_energy(x["pw"], x["occ"], temp)
def update(params, opt_state, temp, g_vec):
loss = lambda x: free_energy(x["pw"], x["occ"], temp, g_vec)
(loss_val, es), grad = jax.value_and_grad(loss, has_aux=True)(params)
updates, opt_state = optimizer.update(grad, opt_state)
params = optax.apply_updates(params, updates)
Expand Down Expand Up @@ -160,7 +163,7 @@ def temperature_scheduler(i):
for i in iters:
temp = temperature_scheduler(i)
start = time.time()
params, opt_state, loss_val, es = update(params, opt_state, temp)
params, opt_state, loss_val, es = update(params, opt_state, temp, g_vec)
etot, entro = es
etot = jax.block_until_ready(etot)
train_time += time.time() - start
Expand Down
2 changes: 1 addition & 1 deletion jrystal/calc/opt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@


def set_env_params(config: JrystalConfigDict):
jax.config.update("jax_debug_nans", True)
jax.config.update("jax_debug_nans", False)

if config.verbose:
logging.set_verbosity(logging.INFO)
Expand Down
59 changes: 30 additions & 29 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,39 @@
import argparse
import jrystal as jr

parser = argparse.ArgumentParser(
prog='Jrystal', description='Command for Jrystal package.'
)

parser.add_argument(
"-m",
"--mode",
choices=["energy", "band"],
default='energy',
help="Set the computation mode. For total enrgy minimization, please use "
"\'energy\'. For band structure calculation, please use \'band\'. "
)
def main():
parser = argparse.ArgumentParser(
prog='Jrystal', description='Command for Jrystal package.'
)

parser.add_argument(
"-c",
"--config",
default='config.yaml',
help="Set the configuration file path."
)
parser.add_argument(
"-m",
"--mode",
choices=["energy", "band"],
default='energy',
help="Set the computation mode. For total enrgy minimization, please use "
"\'energy\'. For band structure calculation, please use \'band\'. "
)

args = parser.parse_args()
parser.add_argument(
"-c",
"--config",
default='config.yaml',
help="Set the configuration file path."
)

config = jr.config.get_config("config.yaml")
args = parser.parse_args()

if args.mode == "energy":
if config.use_pseudopotential:
jr.calc.energy_normcons(config)
else:
jr.calc.energy(config)
elif args.mode == "band":
if config.use_pseudopotential:
jr.calc.band_normcons(config)
else:
jr.calc.band(config)
config = jr.config.get_config("config.yaml")

if args.mode == "energy":
if config.use_pseudopotential:
jr.calc.energy_normcons(config)
else:
jr.calc.energy(config)
elif args.mode == "band":
if config.use_pseudopotential:
jr.calc.band_normcons(config)
else:
jr.calc.band(config)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _read_requirements():
install_requires=_read_requirements(),
entry_points={
'console_scripts': [
'jrystal=main',
'jrystal=main:main',
],
},
)