Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ __pycache__/
# C extensions
*.so

# custom
/slurm/train_ldm.sh
/slurm/train_vae.sh
/slurm/test.sh
test.ipynb

# Distribution / packaging
.Python
build/
Expand Down
10 changes: 10 additions & 0 deletions configs/autoencoder_module/vae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,13 @@ visualization:

# compile model for faster training with pytorch 2.0
compile: false

datasets:
mp20:
proportion: ${data.datamodule.datasets.mp20.proportion}

qm9:
proportion: ${data.datamodule.datasets.qm9.proportion}

qmof150:
proportion: ${data.datamodule.datasets.qmof150.proportion}
2 changes: 1 addition & 1 deletion configs/callbacks/model_checkpoint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ model_checkpoint:
monitor: null # name of the logged metric which determines when model is improving
verbose: True # verbosity mode
save_last: True # additionally always save an exact copy of the last checkpoint to a file last.ckpt
save_top_k: 3 # save k best models (determined by above metric)
save_top_k: 2 # save k best models (determined by above metric)
mode: "max" # "max" means higher metric value is better, can be also "min"
auto_insert_metric_name: False # when True, the checkpoints filenames will contain the metric name
save_weights_only: False # if True, then only the model's weights will be saved
Expand Down
10 changes: 10 additions & 0 deletions configs/data/default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
datamodule:
batch_size:
train: 256
val: 256
test: 256

num_workers:
train: 16
val: 16
test: 16
13 changes: 3 additions & 10 deletions configs/data/joint.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
defaults:
- default

datamodule:
_target_: src.data.joint_datamodule.JointDataModule

Expand All @@ -18,13 +21,3 @@ datamodule:
_target_: src.data.components.qmof150_dataset.QMOF150
root: ${paths.data_dir}/qmof
proportion: 0.0

batch_size:
train: 256
val: 256
test: 256

num_workers:
train: 16
val: 16
test: 16
12 changes: 3 additions & 9 deletions configs/data/mp20_only.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
defaults:
- default

datamodule:
_target_: src.data.joint_datamodule.JointDataModule

Expand All @@ -19,12 +22,3 @@ datamodule:
root: ${paths.data_dir}/qmof
proportion: 0.0

batch_size:
train: 256
val: 256
test: 256

num_workers:
train: 16
val: 16
test: 16
12 changes: 3 additions & 9 deletions configs/data/qm9_only.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
defaults:
- default

datamodule:
_target_: src.data.joint_datamodule.JointDataModule

Expand All @@ -19,12 +22,3 @@ datamodule:
root: ${paths.data_dir}/qmof
proportion: 0.0

batch_size:
train: 256
val: 256
test: 256

num_workers:
train: 16
val: 16
test: 16
10 changes: 10 additions & 0 deletions configs/diffusion_module/ldm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,13 @@ scheduler_frequency: ${trainer.check_val_every_n_epoch}

# compile model for faster training with pytorch 2.0
compile: false

datasets:
mp20:
proportion: ${data.datamodule.datasets.mp20.proportion}

qm9:
proportion: ${data.datamodule.datasets.qm9.proportion}

qmof150:
proportion: ${data.datamodule.datasets.qmof150.proportion}
2 changes: 1 addition & 1 deletion configs/logger/wandb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ wandb:
project: "all-atom-diffusion-transformer"
log_model: False # upload lightning ckpts
prefix: "" # a string to put at the beginning of metric keys
entity: "chaitjo" # set to name of your wandb user/team
entity: "" # set to name of your wandb user/team
group: ""
tags: "${tags}"
job_type: ""
2 changes: 1 addition & 1 deletion configs/train_autoencoder.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# specify here default configuration
# order of defaults determines the order in which configs override each other
defaults:
- data: joint # joint / qm9_only / mp20_only
- data: qm9_only # joint / qm9_only / mp20_only
- encoder: transformer
- decoder: transformer
- autoencoder_module: vae
Expand Down
4 changes: 2 additions & 2 deletions configs/train_diffusion.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
# specify here default configuration
# order of defaults determines the order in which configs override each other
defaults:
- data: joint # joint / qm9_only / mp20_only
- data: qm9_only # joint / qm9_only / mp20_only
- diffusion_module: ldm
- callbacks: diffusion_default # diffusion_default / _qm9_only / _mp20_only
- callbacks: diffusion_qm9_only # diffusion_default / _qm9_only / _mp20_only
- logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
- trainer: default
- paths: default
Expand Down
2 changes: 1 addition & 1 deletion configs/trainer/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ devices: 1
# precision: 16

# perform a validation loop every N training epochs
check_val_every_n_epoch: 250
check_val_every_n_epoch: 100

# log metrics every N steps
log_every_n_steps: 100
Expand Down
2 changes: 1 addition & 1 deletion src/data/components/qmof150_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,4 +197,4 @@ def process_one(data_dir, filename):
if np.all(all_check):
return result_dict
else:
return None
return None
86 changes: 32 additions & 54 deletions src/data/joint_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,59 +214,37 @@ def train_dataloader(self) -> DataLoader:
)

def val_dataloader(self) -> Sequence[DataLoader]:
    """Create and return the validation dataloaders, one per non-empty dataset.

    Datasets with zero samples (i.e. ``proportion: 0.0`` in the data config) are
    skipped entirely so that Lightning does not iterate an empty loader. The
    mapping ``self.dataloader_to_dataset`` records which dataset name corresponds
    to each returned dataloader index, so that validation hooks receiving a
    ``dataloader_idx`` can recover the dataset name.

    :return: A list of validation ``DataLoader``s for the non-empty datasets,
        in the fixed order mp20, qm9, qmof150.
    """
    candidates = [
        ("mp20", self.mp20_val_dataset),
        ("qm9", self.qm9_val_dataset),
        ("qmof150", self.qmof150_val_dataset),
    ]
    # Rebuilt on every call so the mapping always matches the returned loaders.
    self.dataloader_to_dataset = {}
    dataloaders = []
    for name, dataset in candidates:
        if not len(dataset):  # skip datasets disabled via proportion 0.0
            continue
        self.dataloader_to_dataset[len(dataloaders)] = name
        dataloaders.append(
            DataLoader(
                dataset=dataset,
                batch_size=self.hparams.batch_size.val,
                num_workers=self.hparams.num_workers.val,
                pin_memory=False,
                shuffle=False,  # deterministic order for validation
            )
        )
    return dataloaders

def test_dataloader(self) -> Sequence[DataLoader]:
"""Create and return the test dataloader.

:return: The test dataloader.
"""
return [
DataLoader(
dataset=self.mp20_test_dataset,
batch_size=self.hparams.batch_size.test,
num_workers=self.hparams.num_workers.test,
pin_memory=False,
shuffle=False,
),
DataLoader(
dataset=self.qm9_test_dataset,
batch_size=self.hparams.batch_size.test,
num_workers=self.hparams.num_workers.test,
pin_memory=False,
shuffle=False,
),
DataLoader(
dataset=self.qmof150_test_dataset,
batch_size=self.hparams.batch_size.test,
num_workers=self.hparams.num_workers.test,
pin_memory=False,
shuffle=False,
),
]
"""Create and return the test dataloader."""
datasets = [("mp20", self.mp20_test_dataset), ("qm9", self.qm9_test_dataset), ("qmof150", self.qmof150_test_dataset)]
# self.testdataloader_to_dataset = {}
dataloaders = []
for name, ds in datasets:
if len(ds) > 0:
dataloaders.append(DataLoader(
dataset=ds,
batch_size=self.hparams.batch_size.test,
num_workers=self.hparams.num_workers.test,
pin_memory=False,
shuffle=False,
))
# self.testdataloader_to_dataset[len(dataloaders)-1] = name

return dataloaders
Loading