
Commit 8e2356f: format

vwxyzjn committed Jan 28, 2025
1 parent e5b9923
Showing 3 changed files with 27 additions and 10 deletions.
7 changes: 5 additions & 2 deletions open_instruct/dpo_tune_cache.py
@@ -57,7 +57,11 @@
     get_scheduler,
 )
 
-from open_instruct.dataset_transformation import CHAT_TEMPLATES, TokenizerConfig, get_cached_dataset_tulu_preference
+from open_instruct.dataset_transformation import (
+    CHAT_TEMPLATES,
+    TokenizerConfig,
+    get_cached_dataset_tulu_preference,
+)
 from open_instruct.dpo_utils import (
     DataCollatorForSeq2SeqDPO,
     concatenated_forward,
@@ -421,7 +425,6 @@ def __post_init__(self):
             raise ValueError("Cannot launch Beaker evaluation jobs without pushing to the Hub.")
 
 
-
 def get_cache_ref_logprobs(
     model: torch.nn.Module,
     active_dataloader: torch.utils.data.DataLoader,
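This import rewrite (and the matching ones in finetune.py and olmo_adapter/__init__.py below) follows isort's parenthesized one-name-per-line style with a trailing comma, so future additions or removals touch a single line per name. A minimal sketch of the convention on a standard-library import; the module chosen here is illustrative, not from the commit:

# Same style as the diff: parentheses, one name per line, trailing comma.
# Names sort in ASCII order (uppercase before lowercase), as isort emits.
from collections import (
    Counter,
    OrderedDict,
    defaultdict,
)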
12 changes: 9 additions & 3 deletions open_instruct/finetune.py
@@ -53,7 +53,11 @@
     get_scheduler,
 )
 
-from open_instruct.dataset_transformation import CHAT_TEMPLATES, TokenizerConfig, get_cached_dataset_tulu_sft
+from open_instruct.dataset_transformation import (
+    CHAT_TEMPLATES,
+    TokenizerConfig,
+    get_cached_dataset_tulu_sft,
+)
 from open_instruct.model_utils import push_folder_to_hub, save_with_accelerate
 from open_instruct.utils import (
     ArgumentParserPlus,
@@ -473,7 +477,7 @@ def main(args: FlatArguments):
         train_dataset.set_format(type="pt")
         if args.cache_dataset_only:
             return
-
+
         # Load pretrained model and tokenizer
         if args.config_name:
             config = AutoConfig.from_pretrained(
@@ -816,7 +820,9 @@ def main(args: FlatArguments):
             os.path.join(get_last_checkpoint_path(args, incomplete=True), "COMPLETED"), "w"
         ) as f:
             f.write("COMPLETED")  # annoyingly, empty files arent uploaded by beaker.
-        if accelerator.is_local_main_process:  # TODO: in mason local model this is gonna error out if using something like output/test; because mason used the same shared file ssytem.
+        if (
+            accelerator.is_local_main_process
+        ):  # TODO: in mason local model this is gonna error out if using something like output/test; because mason used the same shared file ssytem.
             clean_last_n_checkpoints(args.output_dir, args.keep_last_n_checkpoints)
         accelerator.wait_for_everyone()
 
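One detail in the hunk above is worth calling out: the COMPLETED sentinel is written with content because, as the inline comment notes, Beaker does not upload empty files. A standalone sketch of that pattern, using a hypothetical mark_completed helper that is not part of open_instruct:

import os


def mark_completed(checkpoint_dir: str) -> None:
    # The sentinel must be non-empty: per the inline comment in finetune.py,
    # empty files are not uploaded by Beaker.
    os.makedirs(checkpoint_dir, exist_ok=True)
    with open(os.path.join(checkpoint_dir, "COMPLETED"), "w") as f:
        f.write("COMPLETED")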
18 changes: 13 additions & 5 deletions open_instruct/olmo_adapter/__init__.py
@@ -2,10 +2,18 @@
 
 import torch
 from torch import nn
-from torch.nn import MSELoss, CrossEntropyLoss, BCEWithLogitsLoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.modeling_outputs import SequenceClassifierOutputWithPast
-from transformers.models.olmo2.modeling_olmo2 import Olmo2Config, Olmo2Model, Olmo2PreTrainedModel
-from transformers.models.olmoe.modeling_olmoe import OlmoeConfig, OlmoeModel, OlmoePreTrainedModel
+from transformers.models.olmo2.modeling_olmo2 import (
+    Olmo2Config,
+    Olmo2Model,
+    Olmo2PreTrainedModel,
+)
+from transformers.models.olmoe.modeling_olmoe import (
+    OlmoeConfig,
+    OlmoeModel,
+    OlmoePreTrainedModel,
+)
 
 
 class Olmo2ForSequenceClassification(Olmo2PreTrainedModel):
@@ -36,7 +44,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, 'SequenceClassifierOutputWithPast']:
+    ) -> Union[Tuple, "SequenceClassifierOutputWithPast"]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -116,7 +124,7 @@
 
 
 class OlmoeForSequenceClassification(OlmoePreTrainedModel):
-    def __init__(self, config: Olmo2Config):
+    def __init__(self, config: OlmoeConfig):
         super().__init__(config)
         self.num_labels = config.num_labels
         self.model = OlmoeModel(config)
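The last hunk is more than formatting: it corrects the type hint on OlmoeForSequenceClassification.__init__, which previously declared config: Olmo2Config even though the class wraps OlmoeModel. A minimal usage sketch, assuming a transformers release that includes OLMoE and that open_instruct is importable; the tiny config values are made up so the model is cheap to instantiate:

from transformers.models.olmoe.modeling_olmoe import OlmoeConfig

from open_instruct.olmo_adapter import OlmoeForSequenceClassification

# Deliberately tiny, made-up hyperparameters; real checkpoints are far larger.
config = OlmoeConfig(
    vocab_size=256,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    num_experts=4,
    num_experts_per_tok=2,
    num_labels=1,  # a single scalar head, e.g. for reward modeling
)
model = OlmoeForSequenceClassification(config)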
