77 HfArgumentParser ,
88 set_seed ,
99)
10- from transformers import TrainingArguments
1110
12- from tevatron .reranker .arguments import ModelArguments , DataArguments , \
13- TevatronTrainingArguments
11+ from tevatron .reranker .arguments import ModelArguments , DataArguments , TevatronTrainingArguments
1412from tevatron .reranker .modeling import RerankerModel
1513from tevatron .reranker .dataset import RerankerTrainDataset
1614from tevatron .reranker .collator import RerankerTrainCollator
1715from tevatron .reranker .trainer import RerankerTrainer
16+ from tevatron .reranker .gc_trainer import GradCacheTrainer
1817
1918logger = logging .getLogger (__name__ )
2019
2120def main ():
2221 parser = HfArgumentParser ((ModelArguments , DataArguments , TevatronTrainingArguments ))
2322
2423 if len (sys .argv ) == 2 and sys .argv [1 ].endswith (".json" ):
25- model_args , data_args , training_args , tevatron_args = parser .parse_json_file (json_file = os .path .abspath (sys .argv [1 ]))
24+ model_args , data_args , training_args = parser .parse_json_file (json_file = os .path .abspath (sys .argv [1 ]))
2625 else :
27- model_args , data_args , training_args , tevatron_args = parser .parse_args_into_dataclasses ()
28- model_args : ModelArguments
29- data_args : DataArguments
30- training_args : TrainingArguments
31- tevatron_args : TevatronTrainingArguments
32-
33- # Combine TrainingArguments and TevatronTrainingArguments
34- for key , value in vars (tevatron_args ).items ():
35- setattr (training_args , key , value )
26+ model_args , data_args , training_args = parser .parse_args_into_dataclasses ()
3627
3728 if (
3829 os .path .exists (training_args .output_dir )
@@ -60,7 +51,6 @@ def main():
6051 )
6152 logger .info ("Training/evaluation parameters %s" , training_args )
6253 logger .info ("MODEL parameters %s" , model_args )
63- logger .info ("Tevatron parameters %s" , tevatron_args )
6454
6555 set_seed (training_args .seed )
6656
@@ -81,7 +71,9 @@ def main():
8171 train_dataset = RerankerTrainDataset (data_args )
8272 train_collator = RerankerTrainCollator (data_args , tokenizer )
8373
84- trainer = RerankerTrainer (
74+ # Choose the appropriate trainer based on the grad_cache flag
75+ trainer_cls = GradCacheTrainer if training_args .grad_cache else RerankerTrainer
76+ trainer = trainer_cls (
8577 model = model ,
8678 args = training_args ,
8779 train_dataset = train_dataset ,
0 commit comments