Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0 # Use the ref you want to point at
rev: v6.0.0 # Use the ref you want to point at
hooks:
- id: trailing-whitespace
- id: check-ast
Expand All @@ -17,7 +17,7 @@ repos:
- id: check-toml

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: 'v0.9.10'
rev: 'v0.12.11'
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand All @@ -26,7 +26,7 @@ repos:
types_or: [python, jupyter]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.15.0
rev: v1.17.1
hooks:
- id: mypy
entry: python3 -m mypy --config-file pyproject.toml
Expand All @@ -35,7 +35,7 @@ repos:
exclude: "tests"

- repo: https://github.com/crate-ci/typos
rev: v1
rev: v1.35.6
hooks:
- id: typos
args: [--force-exclude]
Expand Down
2 changes: 1 addition & 1 deletion atomgen/models/configuration_atomformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(
cls_token_id: int = 122,
**kwargs: Any,
) -> None:
super().__init__(**kwargs) # type: ignore[no-untyped-call]
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.dim = dim
self.num_heads = num_heads
Expand Down
26 changes: 13 additions & 13 deletions atomgen/models/modeling_atomformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2547,10 +2547,10 @@ def forward(
return input_embeds, pos_embeds


class AtomformerPreTrainedModel(PreTrainedModel):
class AtomformerPreTrainedModel(PreTrainedModel): # type: ignore[no-untyped-call]
"""Base class for all transformer models."""

config_class = AtomformerConfig # type: ignore[assignment]
config_class = AtomformerConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["ParallelBlock"]
Expand All @@ -2562,7 +2562,7 @@ def _set_gradient_checkpointing( # type: ignore[override]
module.gradient_checkpointing = value


class AtomformerModel(AtomformerPreTrainedModel):
class AtomformerModel(AtomformerPreTrainedModel): # type: ignore[no-untyped-call]
"""Atomformer model for atom modeling."""

def __init__(self, config: AtomformerConfig):
Expand All @@ -2581,7 +2581,7 @@ def forward(
return output


class AtomformerForMaskedAM(AtomformerPreTrainedModel):
class AtomformerForMaskedAM(AtomformerPreTrainedModel): # type: ignore[no-untyped-call]
"""Atomformer with an atom modeling head on top for masked atom modeling."""

def __init__(self, config: AtomformerConfig):
Expand Down Expand Up @@ -2611,7 +2611,7 @@ def forward(
return loss, logits


class AtomformerForCoordinateAM(AtomformerPreTrainedModel):
class AtomformerForCoordinateAM(AtomformerPreTrainedModel): # type: ignore[no-untyped-call]
"""Atomformer with an atom coordinate head on top for coordinate denoising."""

def __init__(self, config: AtomformerConfig):
Expand Down Expand Up @@ -2641,7 +2641,7 @@ def forward(
return loss, coords_pred


class InitialStructure2RelaxedStructure(AtomformerPreTrainedModel):
class InitialStructure2RelaxedStructure(AtomformerPreTrainedModel): # type: ignore[no-untyped-call]
"""Atomformer with an coordinate head on top for relaxed structure prediction."""

def __init__(self, config: AtomformerConfig):
Expand Down Expand Up @@ -2674,7 +2674,7 @@ def forward(
return loss, coords_pred


class InitialStructure2RelaxedEnergy(AtomformerPreTrainedModel):
class InitialStructure2RelaxedEnergy(AtomformerPreTrainedModel): # type: ignore[no-untyped-call]
"""Atomformer with an energy head on top for relaxed energy prediction."""

def __init__(self, config: AtomformerConfig):
Expand Down Expand Up @@ -2704,7 +2704,7 @@ def forward(
return loss, energy


class InitialStructure2RelaxedStructureAndEnergy(AtomformerPreTrainedModel):
class InitialStructure2RelaxedStructureAndEnergy(AtomformerPreTrainedModel): # type: ignore[no-untyped-call]
"""Atomformer with an coordinate and energy head."""

def __init__(self, config: AtomformerConfig):
Expand Down Expand Up @@ -2757,7 +2757,7 @@ def forward(
return loss, (formation_energy_pred, coords_pred)


class Structure2Energy(AtomformerPreTrainedModel):
class Structure2Energy(AtomformerPreTrainedModel): # type: ignore[no-untyped-call]
"""Atomformer with an atom modeling head on top for masked atom modeling."""

def __init__(self, config: AtomformerConfig):
Expand Down Expand Up @@ -2799,7 +2799,7 @@ def forward(
)


class Structure2Forces(AtomformerPreTrainedModel):
class Structure2Forces(AtomformerPreTrainedModel): # type: ignore[no-untyped-call]
"""Atomformer with a forces head on top for forces prediction."""

def __init__(self, config: AtomformerConfig):
Expand Down Expand Up @@ -2841,7 +2841,7 @@ def forward(
)


class Structure2EnergyAndForces(AtomformerPreTrainedModel):
class Structure2EnergyAndForces(AtomformerPreTrainedModel): # type: ignore[no-untyped-call]
"""Atomformer with an energy and forces head for energy and forces prediction."""

def __init__(self, config: AtomformerConfig):
Expand Down Expand Up @@ -2892,7 +2892,7 @@ def forward(
return loss, (formation_energy_pred, forces_pred, attention_mask)


class Structure2TotalEnergyAndForces(AtomformerPreTrainedModel):
class Structure2TotalEnergyAndForces(AtomformerPreTrainedModel): # type: ignore[no-untyped-call]
"""Atomformer with an energy and forces head for energy and forces prediction."""

def __init__(self, config: AtomformerConfig):
Expand Down Expand Up @@ -2949,7 +2949,7 @@ def forward(
return loss, (total_energy_pred, forces_pred, attention_mask)


class AtomFormerForSystemClassification(AtomformerPreTrainedModel):
class AtomFormerForSystemClassification(AtomformerPreTrainedModel): # type: ignore[no-untyped-call]
"""Atomformer with a classification head for system classification."""

def __init__(self, config: AtomformerConfig):
Expand Down
8 changes: 4 additions & 4 deletions atomgen/models/schnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def __init__(
cls_token_id: int = 122,
**kwargs: Any,
):
super().__init__(**kwargs) # type: ignore[no-untyped-call]
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.hidden_channels = hidden_channels
self.num_filters = num_filters
Expand All @@ -126,20 +126,20 @@ def __init__(
self.cls_token_id = cls_token_id


class SchNetPreTrainedModel(PreTrainedModel):
class SchNetPreTrainedModel(PreTrainedModel): # type: ignore[no-untyped-call]
"""
A base class for all SchNet models.

An abstract class to handle weights initialization and a
simple interface for loading and exporting models.
"""

config_class = SchNetConfig # type: ignore[assignment]
config_class = SchNetConfig
base_model_prefix = "model"
supports_gradient_checkpointing = False


class SchNetModel(SchNetPreTrainedModel):
class SchNetModel(SchNetPreTrainedModel): # type: ignore[no-untyped-call]
"""
SchNet model for energy prediction.

Expand Down
6 changes: 3 additions & 3 deletions atomgen/models/tokengt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2351,7 +2351,7 @@ def __init__(
gradient_checkpointing: bool = False,
**kwargs: Any,
):
super().__init__(**kwargs) # type: ignore[no-untyped-call]
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.dim = dim
self.num_heads = num_heads
Expand Down Expand Up @@ -2507,15 +2507,15 @@ def custom_forward(*inputs: Any) -> Any:
return input_embeds


class TransformerPreTrainedModel(PreTrainedModel):
class TransformerPreTrainedModel(PreTrainedModel): # type: ignore[no-untyped-call]
"""Base class for all transformer models."""

config_class = TransformerConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["ParallelBlock"]

def _set_gradient_checkpointing(
def _set_gradient_checkpointing( # type: ignore[override]
self, module: nn.Module, value: bool = False
) -> None:
if isinstance(module, (TransformerEncoder)):
Expand Down
8 changes: 4 additions & 4 deletions scripts/training/pretrain_s2ef.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def train(args: argparse.Namespace) -> None:
config.gradient_checkpointing = (
args.gradient_checkpointing if args.gradient_checkpointing else False
)
model = Structure2EnergyAndForces(config) # type: ignore[arg-type]
model = Structure2EnergyAndForces(config)

tokenizer = AtomTokenizer(vocab_file=args.tokenizer_json)
data_collator = DataCollatorForAtomModeling(
Expand All @@ -173,7 +173,7 @@ def train(args: argparse.Namespace) -> None:
return_edge_indices=False,
)

local_rank = int(os.environ.get("LOCAL_RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
if local_rank == 0:
wandb.login(key=os.environ["WANDB_API_KEY"])
wandb.init(project=args.project, config=vars(args), name=args.name)
Expand Down Expand Up @@ -207,14 +207,14 @@ def train(args: argparse.Namespace) -> None:
weight_decay=args.weight_decay,
)

trainer = Trainer( # type: ignore[no-untyped-call]
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset,
data_collator=data_collator,
)

trainer.train(resume_from_checkpoint=args.checkpoint_exists) # type: ignore[attr-defined]
trainer.train(resume_from_checkpoint=args.checkpoint_exists)

model.save_pretrained(args.output_dir)

Expand Down
12 changes: 6 additions & 6 deletions scripts/training/run_atom3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def run_atom3d(args: argparse.Namespace) -> None:
else False,
problem_type=task_config["problem_type"],
)
model = AtomFormerForSystemClassification(config) # type: ignore[arg-type]
model = AtomFormerForSystemClassification(config)
else:
config = AtomformerConfig.from_pretrained(
args.model,
Expand Down Expand Up @@ -157,7 +157,7 @@ def run_atom3d(args: argparse.Namespace) -> None:
return_edge_indices=False,
)

local_rank = int(os.environ.get("LOCAL_RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
if local_rank == 0:
wandb.login(key=os.environ.get("WANDB_API_KEY"))
wandb.init(project=args.project, config=vars(args), name=args.name)
Expand All @@ -182,7 +182,7 @@ def run_atom3d(args: argparse.Namespace) -> None:
)

# Initialize trainer
trainer = Trainer( # type: ignore[no-untyped-call]
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset["train"],
Expand All @@ -192,12 +192,12 @@ def run_atom3d(args: argparse.Namespace) -> None:
)

# Train the model
trainer.train() # type: ignore[attr-defined]
trainer.train()

trainer.evaluate(dataset["test"]) # type: ignore[attr-defined]
trainer.evaluate(dataset["test"])

# Save the model
trainer.save_model(args.output_dir) # type: ignore[attr-defined]
trainer.save_model(args.output_dir)


if __name__ == "__main__":
Expand Down
Loading