Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions configs/uma/training_release/backbone/K2L2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
model: fairchem.core.models.uma.escn_moe.eSCNMDMoeBackbone
moe_dropout: 0.05
moe_layer_type: ${moe_layer_type}
num_experts: ${num_moe_experts}
use_composition_embedding: true
use_global_embedding: false

max_num_elements: 100
sphere_channels: 64
lmax: 2
mmax: 2

otf_graph: ${otf_graph}
max_neighbors: ${max_neighbors}
use_pbc: True

cutoff: ${cutoff_radius}
edge_channels: 64
distance_function: gaussian
num_distance_basis: 64

regress_forces: True
regress_stress: ${regress_stress}
direct_forces: ${direct_forces}

num_layers: 2
hidden_channels: 64
norm_type: rms_norm_sh
act_type: gate
ff_type: spectral

chg_spin_emb_type: "rand_emb"
cs_emb_grad: True
dataset_list: ${dataset_list}
190 changes: 190 additions & 0 deletions configs/uma/training_release/uma_omat_direct.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
defaults:
- cluster: h100
- backbone: K2L2
- dataset: uma
- element_refs: uma_v1_hof_lin_refs
- _self_

job:
device_type: ${cluster.device}
scheduler:
mode: ${cluster.mode}
ranks_per_node: ${cluster.ranks_per_node}
num_nodes: 4
slurm:
account: ${cluster.account}
qos: ${cluster.qos}
mem_gb: ${cluster.mem_gb}
cpus_per_task: ${cluster.cpus_per_task}
debug: ${cluster.debug}
run_dir: ${cluster.run_dir}
run_name: uma_tiny_omat
logger:
_target_: fairchem.core.common.logger.WandBSingletonLogger.init_wandb
_partial_: true
entity: fairchem
project: uma

moe_layer_type: pytorch
num_moe_experts: 32
max_neighbors: 30
cutoff_radius: 6
epochs: 4
steps: null
max_atoms: 700
bf16: True
cpu_graph: True
otf_graph: False
normalizer_rmsd: 2.83073303546876
direct_forces_coef: 30
omat_energy_coef: 10

regress_stress: False
direct_forces: True

omol_forces_key: forces
odac_forces_key: forces
omc_forces_key: forces
oc20_forces_key: forces
omat_forces_key: forces

dataset_list: ["omat"]

tasks:
- _target_: fairchem.core.units.mlip_unit.mlip_unit.Task
name: omat_energy
level: system
property: energy
loss_fn:
_target_: fairchem.core.modules.loss.DDPMTLoss
loss_fn:
_target_: fairchem.core.modules.loss.PerAtomMAELoss
coefficient: ${omat_energy_coef}
out_spec:
dim: [1]
dtype: float32
normalizer:
_target_: fairchem.core.modules.normalization.normalizer.Normalizer
mean: 0.0
rmsd: ${normalizer_rmsd}
element_references:
_target_: fairchem.core.modules.normalization.element_references.ElementReferences
element_references:
_target_: torch.DoubleTensor
_args_:
- ${element_refs.omat_elem_refs}
datasets:
- omat
metrics:
- mae
- per_atom_mae
- _target_: fairchem.core.units.mlip_unit.mlip_unit.Task
name: forces
level: atom
property: forces
train_on_free_atoms: True
eval_on_free_atoms: True
loss_fn:
_target_: fairchem.core.modules.loss.DDPMTLoss
loss_fn:
_target_: fairchem.core.modules.loss.L2NormLoss
reduction: per_structure
coefficient: ${direct_forces_coef}
out_spec:
dim: [3]
dtype: float32
normalizer:
_target_: fairchem.core.modules.normalization.normalizer.Normalizer
mean: 0.0
rmsd: ${normalizer_rmsd}
datasets:
- omat
metrics:
- mae
- cosine_similarity
- magnitude_error


train_dataset:
_target_: fairchem.core.datasets.mt_concat_dataset.create_concat_dataset
dataset_configs:
omat: ${dataset.omat_train}
combined_dataset_config: { sampling: {type: temperature, temperature: 1.0} }

val_dataset:
_target_: fairchem.core.datasets.mt_concat_dataset.create_concat_dataset
dataset_configs:
omat: ${dataset.omat_val}
combined_dataset_config: { sampling: {type: temperature, temperature: 1.0} }

train_dataloader:
_target_: fairchem.core.components.common.dataloader_builder.get_dataloader
dataset: ${train_dataset}
batch_sampler_fn:
_target_: fairchem.core.datasets.samplers.max_atom_distributed_sampler.MaxAtomDistributedBatchSampler
_partial_: True
max_atoms: ${max_atoms}
shuffle: True
seed: 0
num_workers: ${cluster.dataloader_workers}
collate_fn:
_target_: fairchem.core.units.mlip_unit.mlip_unit.mt_collater_adapter
tasks: ${tasks}

eval_dataloader:
_target_: fairchem.core.components.common.dataloader_builder.get_dataloader
dataset: ${val_dataset}
batch_sampler_fn:
_target_: fairchem.core.datasets.samplers.max_atom_distributed_sampler.MaxAtomDistributedBatchSampler
_partial_: True
max_atoms: ${max_atoms}
shuffle: False
seed: 0
num_workers: ${cluster.dataloader_workers}
collate_fn:
_target_: fairchem.core.units.mlip_unit.mlip_unit.mt_collater_adapter
tasks: ${tasks}

heads:
omat_energy:
module: fairchem.core.models.uma.escn_md.MLP_Energy_Head
forces:
module: fairchem.core.models.uma.escn_md.Linear_Force_Head

runner:
_target_: fairchem.core.components.train.train_runner.TrainEvalRunner
train_dataloader: ${train_dataloader}
eval_dataloader: ${eval_dataloader}
train_eval_unit:
_target_: fairchem.core.units.mlip_unit.mlip_unit.MLIPTrainEvalUnit
job_config: ${job}
tasks: ${tasks}
model:
_target_: fairchem.core.models.base.HydraModel
backbone: ${backbone}
heads: ${heads}
optimizer_fn:
_target_: torch.optim.AdamW
_partial_: true
lr: 8e-4
weight_decay: 1e-3
cosine_lr_scheduler_fn:
_target_: fairchem.core.units.mlip_unit.mlip_unit._get_consine_lr_scheduler
_partial_: true
warmup_factor: 0.2
warmup_epochs: 0.01
lr_min_factor: 0.01
epochs: ${epochs}
steps: ${steps}
print_every: 10
clip_grad_norm: 100
bf16: ${bf16}
max_epochs: ${epochs}
max_steps: ${steps}
evaluate_every_n_steps: 10000
callbacks:
- _target_: fairchem.core.common.profiler_utils.ProfilerCallback
job_config: ${job}
- _target_: fairchem.core.components.train.train_runner.TrainCheckpointCallback
checkpoint_every_n_steps: 5000
max_saved_checkpoints: 5
Loading