diff --git a/open_instruct/dpo_tune_cache.py b/open_instruct/dpo_tune_cache.py
index ee3a37868..6f2975915 100644
--- a/open_instruct/dpo_tune_cache.py
+++ b/open_instruct/dpo_tune_cache.py
@@ -57,7 +57,11 @@
     get_scheduler,
 )
 
-from open_instruct.dataset_transformation import CHAT_TEMPLATES, TokenizerConfig, get_cached_dataset_tulu_preference
+from open_instruct.dataset_transformation import (
+    CHAT_TEMPLATES,
+    TokenizerConfig,
+    get_cached_dataset_tulu_preference,
+)
 from open_instruct.dpo_utils import (
     DataCollatorForSeq2SeqDPO,
     concatenated_forward,
@@ -421,7 +425,6 @@ def __post_init__(self):
             raise ValueError("Cannot launch Beaker evaluation jobs without pushing to the Hub.")
 
 
-
 def get_cache_ref_logprobs(
     model: torch.nn.Module,
     active_dataloader: torch.utils.data.DataLoader,
diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py
index 235037290..db4df123a 100644
--- a/open_instruct/finetune.py
+++ b/open_instruct/finetune.py
@@ -53,7 +53,11 @@
     get_scheduler,
 )
 
-from open_instruct.dataset_transformation import CHAT_TEMPLATES, TokenizerConfig, get_cached_dataset_tulu_sft
+from open_instruct.dataset_transformation import (
+    CHAT_TEMPLATES,
+    TokenizerConfig,
+    get_cached_dataset_tulu_sft,
+)
 from open_instruct.model_utils import push_folder_to_hub, save_with_accelerate
 from open_instruct.utils import (
     ArgumentParserPlus,
@@ -473,7 +477,7 @@ def main(args: FlatArguments):
         train_dataset.set_format(type="pt")
     if args.cache_dataset_only:
         return
-    
+
     # Load pretrained model and tokenizer
     if args.config_name:
         config = AutoConfig.from_pretrained(
@@ -816,7 +820,9 @@ def main(args: FlatArguments):
                             os.path.join(get_last_checkpoint_path(args, incomplete=True), "COMPLETED"), "w"
                         ) as f:
                             f.write("COMPLETED")  # annoyingly, empty files aren't uploaded by beaker.
-                        if accelerator.is_local_main_process: # TODO: in mason local model this is gonna error out if using something like output/test; because mason used the same shared file ssytem.
+                        if (
+                            accelerator.is_local_main_process
+                        ):  # TODO: in mason local mode this will error out if using something like output/test, because mason uses the same shared file system.
                             clean_last_n_checkpoints(args.output_dir, args.keep_last_n_checkpoints)
                         accelerator.wait_for_everyone()
 
diff --git a/open_instruct/olmo_adapter/__init__.py b/open_instruct/olmo_adapter/__init__.py
index f39796b11..71776201d 100644
--- a/open_instruct/olmo_adapter/__init__.py
+++ b/open_instruct/olmo_adapter/__init__.py
@@ -2,10 +2,18 @@
 
 import torch
 from torch import nn
-from torch.nn import MSELoss, CrossEntropyLoss, BCEWithLogitsLoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.modeling_outputs import SequenceClassifierOutputWithPast
-from transformers.models.olmo2.modeling_olmo2 import Olmo2Config, Olmo2Model, Olmo2PreTrainedModel
-from transformers.models.olmoe.modeling_olmoe import OlmoeConfig, OlmoeModel, OlmoePreTrainedModel
+from transformers.models.olmo2.modeling_olmo2 import (
+    Olmo2Config,
+    Olmo2Model,
+    Olmo2PreTrainedModel,
+)
+from transformers.models.olmoe.modeling_olmoe import (
+    OlmoeConfig,
+    OlmoeModel,
+    OlmoePreTrainedModel,
+)
 
 
 class Olmo2ForSequenceClassification(Olmo2PreTrainedModel):
@@ -36,7 +44,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, 'SequenceClassifierOutputWithPast']:
+    ) -> Union[Tuple, "SequenceClassifierOutputWithPast"]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -116,7 +124,7 @@ def forward(
 
 
 class OlmoeForSequenceClassification(OlmoePreTrainedModel):
-    def __init__(self, config: Olmo2Config):
+    def __init__(self, config: OlmoeConfig):
         super().__init__(config)
         self.num_labels = config.num_labels
         self.model = OlmoeModel(config)