11import os
22import warnings
33from typing import Dict , Optional
4-
4+ from tqdm import tqdm
55import ase .io
66import torch
77import yaml
@@ -402,6 +402,7 @@ def __init__(self, protocol: Dict):
402402 timestep = timestep ,
403403 damping = tau ,
404404 temperature_handler = self .temp ,
405+ n_steps = self .protocol .get (RUN_KEY ),
405406 ).to (self .device )
406407
407408 else :
@@ -435,6 +436,9 @@ def __init__(self, protocol: Dict):
435436 f"Either set { VELOCITIES_KEY } to True or pass user requirements as dictionary."
436437 )
437438
439+ if isinstance (temperature , list ):
440+ temperature = temperature [0 ]
441+
438442 vel_init = init_velocity (
439443 target_temperature = temperature ,
440444 graph = self .start_graph ,
@@ -449,13 +453,30 @@ def __init__(self, protocol: Dict):
449453 # for writing to file
450454 # if no write frequency is given
451455 write_settings = self .protocol .get (WRITE_TRAJECTORY_KEY )
452- self .write_freq = self .protocol .get (RUN_KEY ) + 1
456+ self .write_freq_xyz = self .protocol .get (RUN_KEY ) + 1
457+ self .write_freq_temp = self .protocol .get (RUN_KEY ) + 1
458+
453459 if write_settings :
454- self .write_freq = write_settings .get ("every" , 1 )
455- if not isinstance (self .write_freq , int ):
456- raise TypeError ("Write frequency must be specified as integer" )
460+ write_freq = write_settings .get ("every" , 1 )
461+
462+ if isinstance (write_freq , int ):
463+ self .write_freq_xyz = write_freq
464+ self .write_freq_temp = write_freq
465+
466+ elif isinstance (write_freq , Dict ):
467+ self .write_freq_xyz = write_freq .get ("xyz" , 1 )
468+ self .write_freq_temp = write_freq .get ("temp" , self .write_freq_xyz )
469+
470+ else :
471+ raise TypeError ("Write frequency must be specified as integer or dict." )
472+ self .save_vels = write_settings .get ("save_velocities" , True )
457473 self .filename = write_settings [FILENAME_KEY ]
458474 self .fileformat = write_settings .get ("format" , "extxyz" )
475+ logdir = os .path .dirname (self .filename )
476+ self .logfile = os .path .join (
477+ logdir , os .path .basename (self .filename ).split ("." )[0 ] + ".log"
478+ )
479+ os .makedirs (logdir , exist_ok = True )
459480 # write initial frame to file
460481 self ._write_frame_to_file (
461482 frame = self .start_graph ,
@@ -480,21 +501,31 @@ def generate_trajectory(self):
480501 n_steps = self .protocol .get (RUN_KEY )
481502 # initialise frame
482503 frame = self .start_graph
504+ temp = self .temp (frame )
505+
506+ with open (self .logfile , "w" ) as f :
507+ f .write ("step,Temperature\n " )
508+ f .write (f"0,{ temp :3.3f} \n " )
483509
484510 # loop over all steps
485511 with torch .no_grad ():
486- for step in range (1 , n_steps + 1 ):
512+ for step in tqdm ( range (1 , n_steps + 1 ) ):
487513 # make a step
488514 frame = self ._make_timestep (frame , step )
515+ temp = self .temp (frame )
489516
490517 # check for writer
491- if step % self .write_freq == 0 :
518+ if step % self .write_freq_xyz == 0 :
492519 self ._write_frame_to_file (
493520 frame = frame ,
494521 step = step ,
495522 append = True ,
496523 )
497524
525+ if step % self .write_freq_temp == 0 :
526+ with open (self .logfile , "a" ) as f :
527+ f .write (f"{ step } ,{ temp :3.3f} \n " )
528+
498529 def _write_frame_to_file (
499530 self ,
500531 frame : AtomicGraph ,
@@ -504,6 +535,9 @@ def _write_frame_to_file(
504535 # get ase.Atoms object
505536 ase_atoms = frame .ASEAtomsObject
506537
538+ if not self .save_vels :
539+ ase_atoms .arrays .pop ("velocities" )
540+
507541 # add timestamp
508542 ase_atoms .info [FRAME_KEY ] = step
509543
@@ -533,7 +567,7 @@ def _make_timestep(self, frame: AtomicGraph, step: int) -> AtomicGraph:
533567
534568 # thermostatting
535569 if self .nvt :
536- frame = self .thermo (frame )
570+ frame = self .thermo (frame , step )
537571
538572 # manipulate momentum if required
539573 if self .momentum and step % self .momentum .adjust_freq == 0 :
0 commit comments