diff --git a/configs/uma/training_release/backbone/K2L2.yaml b/configs/uma/training_release/backbone/K2L2.yaml new file mode 100644 index 0000000000..1909b74825 --- /dev/null +++ b/configs/uma/training_release/backbone/K2L2.yaml @@ -0,0 +1,34 @@ +model: fairchem.core.models.uma.escn_moe.eSCNMDMoeBackbone +moe_dropout: 0.05 +moe_layer_type: ${moe_layer_type} +num_experts: ${num_moe_experts} +use_composition_embedding: true +use_global_embedding: false + +max_num_elements: 100 +sphere_channels: 64 +lmax: 2 +mmax: 2 + +otf_graph: ${otf_graph} +max_neighbors: ${max_neighbors} +use_pbc: True + +cutoff: ${cutoff_radius} +edge_channels: 64 +distance_function: gaussian +num_distance_basis: 64 + +regress_forces: True +regress_stress: ${regress_stress} +direct_forces: ${direct_forces} + +num_layers: 2 +hidden_channels: 64 +norm_type: rms_norm_sh +act_type: gate +ff_type: spectral + +chg_spin_emb_type: "rand_emb" +cs_emb_grad: True +dataset_list: ${dataset_list} diff --git a/configs/uma/training_release/uma_omat_direct.yaml b/configs/uma/training_release/uma_omat_direct.yaml new file mode 100644 index 0000000000..1c5a92d62a --- /dev/null +++ b/configs/uma/training_release/uma_omat_direct.yaml @@ -0,0 +1,190 @@ +defaults: + - cluster: h100 + - backbone: K2L2 + - dataset: uma + - element_refs: uma_v1_hof_lin_refs + - _self_ + +job: + device_type: ${cluster.device} + scheduler: + mode: ${cluster.mode} + ranks_per_node: ${cluster.ranks_per_node} + num_nodes: 4 + slurm: + account: ${cluster.account} + qos: ${cluster.qos} + mem_gb: ${cluster.mem_gb} + cpus_per_task: ${cluster.cpus_per_task} + debug: ${cluster.debug} + run_dir: ${cluster.run_dir} + run_name: uma_tiny_omat + logger: + _target_: fairchem.core.common.logger.WandBSingletonLogger.init_wandb + _partial_: true + entity: fairchem + project: uma + +moe_layer_type: pytorch +num_moe_experts: 32 +max_neighbors: 30 +cutoff_radius: 6 +epochs: 4 +steps: null +max_atoms: 700 +bf16: True +cpu_graph: True +otf_graph: False +normalizer_rmsd: 2.83073303546876 +direct_forces_coef: 30 +omat_energy_coef: 10 + +regress_stress: False +direct_forces: True + +omol_forces_key: forces +odac_forces_key: forces +omc_forces_key: forces +oc20_forces_key: forces +omat_forces_key: forces + +dataset_list: ["omat"] + +tasks: +- _target_: fairchem.core.units.mlip_unit.mlip_unit.Task + name: omat_energy + level: system + property: energy + loss_fn: + _target_: fairchem.core.modules.loss.DDPMTLoss + loss_fn: + _target_: fairchem.core.modules.loss.PerAtomMAELoss + coefficient: ${omat_energy_coef} + out_spec: + dim: [1] + dtype: float32 + normalizer: + _target_: fairchem.core.modules.normalization.normalizer.Normalizer + mean: 0.0 + rmsd: ${normalizer_rmsd} + element_references: + _target_: fairchem.core.modules.normalization.element_references.ElementReferences + element_references: + _target_: torch.DoubleTensor + _args_: + - ${element_refs.omat_elem_refs} + datasets: + - omat + metrics: + - mae + - per_atom_mae +- _target_: fairchem.core.units.mlip_unit.mlip_unit.Task + name: forces + level: atom + property: forces + train_on_free_atoms: True + eval_on_free_atoms: True + loss_fn: + _target_: fairchem.core.modules.loss.DDPMTLoss + loss_fn: + _target_: fairchem.core.modules.loss.L2NormLoss + reduction: per_structure + coefficient: ${direct_forces_coef} + out_spec: + dim: [3] + dtype: float32 + normalizer: + _target_: fairchem.core.modules.normalization.normalizer.Normalizer + mean: 0.0 + rmsd: ${normalizer_rmsd} + datasets: + - omat + metrics: + - mae + - cosine_similarity + - magnitude_error + + +train_dataset: + _target_: fairchem.core.datasets.mt_concat_dataset.create_concat_dataset + dataset_configs: + omat: ${dataset.omat_train} + combined_dataset_config: { sampling: {type: temperature, temperature: 1.0} } + +val_dataset: + _target_: fairchem.core.datasets.mt_concat_dataset.create_concat_dataset + dataset_configs: + omat: ${dataset.omat_val} + combined_dataset_config: { sampling: {type: temperature, temperature: 1.0} } + +train_dataloader: + _target_: fairchem.core.components.common.dataloader_builder.get_dataloader + dataset: ${train_dataset} + batch_sampler_fn: + _target_: fairchem.core.datasets.samplers.max_atom_distributed_sampler.MaxAtomDistributedBatchSampler + _partial_: True + max_atoms: ${max_atoms} + shuffle: True + seed: 0 + num_workers: ${cluster.dataloader_workers} + collate_fn: + _target_: fairchem.core.units.mlip_unit.mlip_unit.mt_collater_adapter + tasks: ${tasks} + +eval_dataloader: + _target_: fairchem.core.components.common.dataloader_builder.get_dataloader + dataset: ${val_dataset} + batch_sampler_fn: + _target_: fairchem.core.datasets.samplers.max_atom_distributed_sampler.MaxAtomDistributedBatchSampler + _partial_: True + max_atoms: ${max_atoms} + shuffle: False + seed: 0 + num_workers: ${cluster.dataloader_workers} + collate_fn: + _target_: fairchem.core.units.mlip_unit.mlip_unit.mt_collater_adapter + tasks: ${tasks} + +heads: + omat_energy: + module: fairchem.core.models.uma.escn_md.MLP_Energy_Head + forces: + module: fairchem.core.models.uma.escn_md.Linear_Force_Head + +runner: + _target_: fairchem.core.components.train.train_runner.TrainEvalRunner + train_dataloader: ${train_dataloader} + eval_dataloader: ${eval_dataloader} + train_eval_unit: + _target_: fairchem.core.units.mlip_unit.mlip_unit.MLIPTrainEvalUnit + job_config: ${job} + tasks: ${tasks} + model: + _target_: fairchem.core.models.base.HydraModel + backbone: ${backbone} + heads: ${heads} + optimizer_fn: + _target_: torch.optim.AdamW + _partial_: true + lr: 8e-4 + weight_decay: 1e-3 + cosine_lr_scheduler_fn: + _target_: fairchem.core.units.mlip_unit.mlip_unit._get_consine_lr_scheduler + _partial_: true + warmup_factor: 0.2 + warmup_epochs: 0.01 + lr_min_factor: 0.01 + epochs: ${epochs} + steps: ${steps} + print_every: 10 + clip_grad_norm: 100 + bf16: ${bf16} + max_epochs: ${epochs} + max_steps: ${steps} + evaluate_every_n_steps: 10000 + callbacks: + - _target_: fairchem.core.common.profiler_utils.ProfilerCallback + job_config: ${job} + - _target_: fairchem.core.components.train.train_runner.TrainCheckpointCallback + checkpoint_every_n_steps: 5000 + max_saved_checkpoints: 5