 import logging
 import os
 import sys
+import torch
 from transformers import AutoTokenizer
 from transformers import (
     HfArgumentParser,
     set_seed,
 )
+from torch.nn.parallel import DistributedDataParallel as DDP
+import torch.distributed as dist
 from tevatron.reranker.arguments import ModelArguments, DataArguments, TevatronTrainingArguments
 from tevatron.reranker.modeling import RerankerModel
 from tevatron.reranker.dataset import RerankerTrainDataset
 from tevatron.reranker.collator import RerankerTrainCollator
-from tevatron.reranker.trainer import RerankerTrainer  # Make sure this is your updated RerankerTrainer
+from tevatron.reranker.trainer import RerankerTrainer
 
 logger = logging.getLogger(__name__)
 
 
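+# setup_ddp() keys off the RANK and WORLD_SIZE environment variables, which launchers
+# such as `torchrun --nproc_per_node=<num_gpus> <this_script> ...` export automatically;
+# a plain `python` invocation leaves them unset, so training stays single-process (rank -1).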
+def setup_ddp():
+    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+        # We're running in a distributed environment
+        import torch.distributed as dist
+        rank = int(os.environ['RANK'])
+        world_size = int(os.environ['WORLD_SIZE'])
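+        # The NCCL backend assumes one CUDA device per process; CPU-only runs would need "gloo".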
+        dist.init_process_group(backend="nccl")
+        return rank
+    else:
+        # We're not running in a distributed environment
+        return -1
+
+
 def main():
     parser = HfArgumentParser((ModelArguments, DataArguments, TevatronTrainingArguments))
 
@@ -23,29 +39,22 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
-    if (
-        os.path.exists(training_args.output_dir)
-        and os.listdir(training_args.output_dir)
-        and training_args.do_train
-        and not training_args.overwrite_output_dir
-    ):
-        raise ValueError(
-            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
-        )
+    local_rank = setup_ddp()
+    training_args.local_rank = local_rank
 
     # Setup logging
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
         datefmt="%m/%d/%Y %H:%M:%S",
-        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
+        level=logging.INFO if local_rank in [-1, 0] else logging.WARN,
     )
     logger.warning(
         "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
-        training_args.local_rank,
+        local_rank,
         training_args.device,
         training_args.n_gpu,
-        bool(training_args.local_rank != -1),
-        training_args.fp16,
+        bool(local_rank != -1),
+        training_args.fp16 or training_args.bf16,
     )
     logger.info("Training/evaluation parameters %s", training_args)
     logger.info("MODEL parameters %s", model_args)
@@ -67,11 +76,16 @@ def main():
         cache_dir=model_args.cache_dir,
     )
 
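+    # Note: the DDP wrapping below uses the global rank from setup_ddp() as the CUDA
+    # device index, which assumes a single-node launch where RANK == LOCAL_RANK.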
+    # Move model to GPU
+    if local_rank != -1:
+        model = model.to(local_rank)
+        model = DDP(model, device_ids=[local_rank], output_device=local_rank)
+
     train_dataset = RerankerTrainDataset(data_args)
     train_collator = RerankerTrainCollator(data_args, tokenizer)
 
-    # Add GradCache-specific arguments to training_args
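+    # GradCache-related defaults (assumed semantics): gc_chunk_size should set the chunk
+    # size used to split each batch when gradient caching is enabled, and grad_cache is
+    # kept off unless the training arguments already enable it.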
     training_args.gc_chunk_size = getattr(training_args, 'gc_chunk_size', 2)
+    training_args.grad_cache = getattr(training_args, 'grad_cache', False)
 
     trainer = RerankerTrainer(
         model=model,
@@ -81,11 +95,11 @@ def main():
     )
     train_dataset.trainer = trainer
 
-    trainer.train()  # TODO: resume training
+    trainer.train()
     trainer.save_model()
     if trainer.is_world_process_zero():
         tokenizer.save_pretrained(training_args.output_dir)
 
 
 if __name__ == "__main__":
-    main()
+    main()