Skip to content

Commit c52f107

Browse files
authored
Updating thermostat with logging and capability for temperature ramp (#9)
* Updating thermostat with logging and capability for temperature ramp Signed-off-by: Fabian Thiemann <fabian.thiemann@ibm.com> * Adjusting github action tests.yaml Signed-off-by: Fabian Thiemann <fabian.thiemann@ibm.com> * Updating unittests and minor cosmetics Signed-off-by: Fabian Thiemann <fabian.thiemann@ibm.com> * Updating temperature in forecast for logging Signed-off-by: Fabian Thiemann <fabian.thiemann@ibm.com> * Fixing cuda problem for notebook Signed-off-by: Fabian Thiemann <fabian.thiemann@ibm.com> * Add PR template Signed-off-by: Fabian Thiemann <fabian.thiemann@ibm.com> * Update changelog Signed-off-by: Fabian Thiemann <fabian.thiemann@ibm.com> --------- Signed-off-by: Fabian Thiemann <fabian.thiemann@ibm.com>
1 parent 5c88b0f commit c52f107

10 files changed

Lines changed: 240 additions & 64 deletions

File tree

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
**Summary**
2+
<!-- Briefly describe the purpose of this PR and what it changes. -->
3+
4+
**Key Features**
5+
<!-- List the main changes or additions introduced by this PR. -->
6+
7+
**Dependencies**
8+
<!-- List any new dependencies added or changes to existing requirements. -->
9+
10+
**Tests**
11+
<!-- Describe testing performed and any tests added. -->
12+
13+
**Related Issues / Branches**
14+
<!-- Link any related issues or pull requests, e.g., “Closes #123”. -->

.github/workflows/tests.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ jobs:
6161
source venv/bin/activate
6262
cd $GITHUB_WORKSPACE
6363
coverage run --source=trajcast -m unittest discover -s tests -p "test_*.py"
64+
exit $?
6465
coverage xml
6566
6667
- name: Generate coverage report

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,19 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

88
## [Unreleased]
9+
10+
## [1.1.0] - 2025-09-23
11+
912
### Added
13+
- Added support for temperature ramps in the thermostat to enable controlled cooling/heating in simulations.
14+
- Improved logging for forecast runs to provide more informative output.
1015
- Added .github/workflows/tests.yaml to handle unittesting on github.
1116
- Added files for unittests which were missing in [1.0.0].
1217
- Added this changelog.
1318

1419
### Fixed
1520
- Fixed small issue in trajcast/data/wrappers/_lammps.py to not save the timestep which results in unittests failing with ase versions >= 3.25.0.
21+
- Fixed a CUDA error in generating inertia tensor reported by a user.
1622

1723
### Removed
1824
- Removed travis.yml as unittest will be handled by .github/workflows/tests.yaml.

examples/inference/forecasting.ipynb

Lines changed: 119 additions & 23 deletions
Large diffs are not rendered by default.

tests/unit/model/test_forecast.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -463,24 +463,23 @@ def test_returns_thermostat_is_initialised_correctly(self):
463463
timestep=5.0,
464464
damping=1000.0,
465465
temperature_handler=temperature,
466+
n_steps=50,
466467
)
467468

468-
self.assertIsInstance(csvr.e_kin_target, torch.Tensor)
469-
self.assertEqual(csvr.e_kin_target.shape, torch.Size([1]))
470-
471469
self.assertIsInstance(csvr.temp.conv_fac, torch.Tensor)
472470
self.assertEqual(csvr.temp.conv_fac.shape, torch.Size([1]))
473471

474472
self.assertTrue(csvr.temp._n_dofs == 90)
475473

476474
def test_returns_sum_noises_are_sampled_correclty_when_ndof_odd(self):
477-
475+
n_steps = 500000
478476
temperature = Temperature(units="real", n_atoms=162, n_extra_dofs=3)
479477
csvr = CSVRThermostat(
480478
target_temp=300.0,
481479
timestep=20.0,
482480
damping=2000,
483481
temperature_handler=temperature,
482+
n_steps=n_steps,
484483
)
485484

486485
gamm = []
@@ -499,18 +498,20 @@ def test_returns_sum_noises_are_sampled_correclty_when_ndof_odd(self):
499498
self.assertTrue(torch.isclose(gamm.std(), gauss.std(), atol=1e-1))
500499

501500
def test_returns_sum_noises_are_sampled_correclty_when_ndof_even(self):
501+
n_steps = 500000
502502
temperature = Temperature(units="real", n_atoms=22, n_extra_dofs=6)
503503
csvr = CSVRThermostat(
504504
target_temp=150.0,
505505
timestep=5.0,
506506
damping=500,
507507
temperature_handler=temperature,
508+
n_steps=n_steps,
508509
)
509510

510511
gamm = []
511512
gauss = []
512513
n_dofs = temperature._n_dofs
513-
for i in range(500000):
514+
for i in range(n_steps):
514515
# gaussian direct
515516
rns = torch.randn(n_dofs - 1)
516517
gauss.append(rns.pow(2).sum())
@@ -540,18 +541,18 @@ def test_returns_forward_produces_rescaled_velocities(self):
540541
)
541542

542543
graph = f.start_graph
543-
544+
n_steps = 500000
544545
temperature = Temperature(units="real", n_atoms=graph.num_nodes, n_extra_dofs=0)
545546
csvr = CSVRThermostat(
546547
target_temp=150.0,
547548
timestep=5.0,
548549
damping=50.0,
549550
temperature_handler=temperature,
551+
n_steps=n_steps,
550552
)
551-
552553
temps = []
553-
for _ in range(500000):
554-
graph = csvr(graph)
554+
for step in range(n_steps):
555+
graph = csvr(graph, step)
555556
T = temperature(graph)
556557
temps.append(T)
557558

@@ -588,18 +589,22 @@ def test_returns_thermostatting_is_reproducible_via_seed(self):
588589
},
589590
}
590591
)
591-
592+
n_steps = 500
592593
graph = f.start_graph
593594
csvr = CSVRThermostat(
594-
target_temp=250.0, timestep=5.0, damping=500.0, temperature_handler=f.temp
595+
target_temp=250.0,
596+
timestep=5.0,
597+
damping=500.0,
598+
temperature_handler=f.temp,
599+
n_steps=n_steps,
595600
)
596601

597602
temps_s1 = []
598603
vel_s1 = []
599-
for steps in range(500):
604+
for steps in range(n_steps):
600605
T = f.temp(graph)
601606
vel_s1.append(graph[VELOCITIES_KEY][0])
602-
graph = csvr(graph)
607+
graph = csvr(graph, steps)
603608

604609
temps_s1.append(T)
605610

@@ -645,7 +650,7 @@ def test_returns_thermostatting_is_reproducible_via_seed(self):
645650
for steps in range(500):
646651
T = f.temp(graph)
647652
vel_s2.append(graph[VELOCITIES_KEY][0])
648-
graph = csvr(graph)
653+
graph = csvr(graph, steps)
649654

650655
temps_s2.append(T)
651656

trajcast/model/forecast.py

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
import warnings
33
from typing import Dict, Optional
4-
4+
from tqdm import tqdm
55
import ase.io
66
import torch
77
import yaml
@@ -402,6 +402,7 @@ def __init__(self, protocol: Dict):
402402
timestep=timestep,
403403
damping=tau,
404404
temperature_handler=self.temp,
405+
n_steps=self.protocol.get(RUN_KEY),
405406
).to(self.device)
406407

407408
else:
@@ -435,6 +436,9 @@ def __init__(self, protocol: Dict):
435436
f"Either set {VELOCITIES_KEY} to True or pass user requirements as dictionary."
436437
)
437438

439+
if isinstance(temperature, list):
440+
temperature = temperature[0]
441+
438442
vel_init = init_velocity(
439443
target_temperature=temperature,
440444
graph=self.start_graph,
@@ -449,13 +453,30 @@ def __init__(self, protocol: Dict):
449453
# for writing to file
450454
# if no write frequency is given
451455
write_settings = self.protocol.get(WRITE_TRAJECTORY_KEY)
452-
self.write_freq = self.protocol.get(RUN_KEY) + 1
456+
self.write_freq_xyz = self.protocol.get(RUN_KEY) + 1
457+
self.write_freq_temp = self.protocol.get(RUN_KEY) + 1
458+
453459
if write_settings:
454-
self.write_freq = write_settings.get("every", 1)
455-
if not isinstance(self.write_freq, int):
456-
raise TypeError("Write frequency must be specified as integer")
460+
write_freq = write_settings.get("every", 1)
461+
462+
if isinstance(write_freq, int):
463+
self.write_freq_xyz = write_freq
464+
self.write_freq_temp = write_freq
465+
466+
elif isinstance(write_freq, Dict):
467+
self.write_freq_xyz = write_freq.get("xyz", 1)
468+
self.write_freq_temp = write_freq.get("temp", self.write_freq_xyz)
469+
470+
else:
471+
raise TypeError("Write frequency must be specified as integer or dict.")
472+
self.save_vels = write_settings.get("save_velocities", True)
457473
self.filename = write_settings[FILENAME_KEY]
458474
self.fileformat = write_settings.get("format", "extxyz")
475+
logdir = os.path.dirname(self.filename)
476+
self.logfile = os.path.join(
477+
logdir, os.path.basename(self.filename).split(".")[0] + ".log"
478+
)
479+
os.makedirs(logdir, exist_ok=True)
459480
# write initial frame to file
460481
self._write_frame_to_file(
461482
frame=self.start_graph,
@@ -480,21 +501,31 @@ def generate_trajectory(self):
480501
n_steps = self.protocol.get(RUN_KEY)
481502
# initialise frame
482503
frame = self.start_graph
504+
temp = self.temp(frame)
505+
506+
with open(self.logfile, "w") as f:
507+
f.write("step,Temperature\n")
508+
f.write(f"0,{temp:3.3f}\n")
483509

484510
# loop over all steps
485511
with torch.no_grad():
486-
for step in range(1, n_steps + 1):
512+
for step in tqdm(range(1, n_steps + 1)):
487513
# make a step
488514
frame = self._make_timestep(frame, step)
515+
temp = self.temp(frame)
489516

490517
# check for writer
491-
if step % self.write_freq == 0:
518+
if step % self.write_freq_xyz == 0:
492519
self._write_frame_to_file(
493520
frame=frame,
494521
step=step,
495522
append=True,
496523
)
497524

525+
if step % self.write_freq_temp == 0:
526+
with open(self.logfile, "a") as f:
527+
f.write(f"{step},{temp:3.3f}\n")
528+
498529
def _write_frame_to_file(
499530
self,
500531
frame: AtomicGraph,
@@ -504,6 +535,9 @@ def _write_frame_to_file(
504535
# get ase.Atoms object
505536
ase_atoms = frame.ASEAtomsObject
506537

538+
if not self.save_vels:
539+
ase_atoms.arrays.pop("velocities")
540+
507541
# add timestamp
508542
ase_atoms.info[FRAME_KEY] = step
509543

@@ -533,7 +567,7 @@ def _make_timestep(self, frame: AtomicGraph, step: int) -> AtomicGraph:
533567

534568
# thermostatting
535569
if self.nvt:
536-
frame = self.thermo(frame)
570+
frame = self.thermo(frame, step)
537571

538572
# manipulate momentum if required
539573
if self.momentum and step % self.momentum.adjust_freq == 0:

trajcast/model/forecast_tools/_temperature.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,11 @@ def __init__(self, units: str, n_atoms: int, n_extra_dofs: int) -> None:
4545

4646
self.register_buffer("kB", torch.tensor(kB))
4747

48-
def to_kinetic_energy(self, temperature: float) -> torch.Tensor:
48+
def to_kinetic_energy(self, temperature: torch.Tensor) -> torch.Tensor:
4949
return torch.tensor(
5050
[0.5 * self._n_dofs * temperature * self.kB],
5151
dtype=torch.get_default_dtype(),
52+
device=temperature.device,
5253
)
5354

5455
def from_kinetic_energy(self, kinetic_energy: torch.Tensor) -> float:

trajcast/model/forecast_tools/_thermostat.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import torch
2-
2+
from typing import Union, List
33
from trajcast.data._keys import ATOMIC_MASSES_KEY, VELOCITIES_KEY
44
from trajcast.data.atomic_graph import AtomicGraph
55
from trajcast.model.forecast_tools import Temperature
@@ -15,10 +15,11 @@ class CSVRThermostat(torch.nn.Module):
1515

1616
def __init__(
1717
self,
18-
target_temp: float,
18+
target_temp: Union[float, int, List[float], List[int]],
1919
timestep: float,
2020
damping: float,
2121
temperature_handler: Temperature,
22+
n_steps: int,
2223
):
2324
super().__init__()
2425

@@ -34,9 +35,23 @@ def __init__(
3435
torch.tensor([-timestep / damping], dtype=torch.get_default_dtype())
3536
),
3637
)
38+
if isinstance(target_temp, float) or isinstance(target_temp, int):
39+
self.register_buffer("start_temp", torch.tensor(float(target_temp)))
40+
self.temp_style = "CONST"
3741

38-
# target kinetic energy, this is in eV
39-
self.register_buffer("e_kin_target", self.temp.to_kinetic_energy(target_temp))
42+
else:
43+
if len(target_temp) != 2:
44+
raise TypeError(
45+
"Temperature should be float or list with two floats (start and stop temperature)."
46+
)
47+
48+
start_temp = float(target_temp[0])
49+
stop_temp = float(target_temp[1])
50+
self.temp_style = "RAMP" if start_temp != stop_temp else "CONST"
51+
self.register_buffer(
52+
"rate", torch.tensor((stop_temp - start_temp) / n_steps)
53+
)
54+
self.register_buffer("start_temp", torch.tensor(start_temp))
4055

4156
# for sampling noise we initialise a gamma distribution
4257
if (self.n_dofs - 1) % 2 == 0:
@@ -48,7 +63,7 @@ def __init__(
4863
concentration=(self.n_dofs - 2) / 2, rate=1.0
4964
)
5065

51-
def forward(self, data: AtomicGraph) -> AtomicGraph:
66+
def forward(self, data: AtomicGraph, step: int) -> AtomicGraph:
5267
masses = data[ATOMIC_MASSES_KEY]
5368
velocities = data[VELOCITIES_KEY]
5469

@@ -58,19 +73,28 @@ def forward(self, data: AtomicGraph) -> AtomicGraph:
5873
* self.temp.conv_fac
5974
)
6075
# compute rescale_factor
61-
alpha = self._get_rescale_factor(e_kin)
76+
alpha = self._get_rescale_factor(e_kin, step)
6277

6378
# rescale velocities accordingly
6479
data[VELOCITIES_KEY] *= alpha
6580
return data
6681

67-
def _get_rescale_factor(self, e_kin_current: torch.Tensor) -> torch.Tensor:
82+
def _get_rescale_factor(
83+
self, e_kin_current: torch.Tensor, step: int
84+
) -> torch.Tensor:
85+
86+
if self.temp_style == "CONST":
87+
temp_target = self.start_temp
88+
89+
else:
90+
temp_target = self.start_temp + self.rate * step
91+
92+
# target kinetic energy, this is in eV
93+
e_kin_target = self.temp.to_kinetic_energy(temp_target)
94+
6895
# compute constant c2 with c1 and kinetic energies
6996
c2 = (
70-
(torch.tensor(1.0) - self.c1)
71-
* self.e_kin_target
72-
/ e_kin_current
73-
/ self.n_dofs
97+
(torch.tensor(1.0) - self.c1) * e_kin_target / e_kin_current / self.n_dofs
7498
).to(self.device)
7599
# draw random number from Gaussian distribution with unitary variance (R1)
76100
r1 = torch.randn(1, device=self.device)

0 commit comments

Comments
 (0)