@@ -397,6 +397,10 @@ def run_eval(config: DictConfig, model_provider: ModelProvider):
     else:
         logger.info("No benchmarks enabled.")
 
+    # Make a random seed that will be used for both pre- and post-finetune eval
+    # This is required to keep the FIM examples the same for both runs
+    eval_seed = random.randint(0, 2**32 - 1)
+
     # Load name of all eval tasks
     task_metrics = []
     for task_name, task_info in load_task_info():
@@ -409,6 +413,7 @@ def run_eval(config: DictConfig, model_provider: ModelProvider):
         ### Run eval ###
 
         logger.info(f"Running pre-finetune eval...")
+        set_seed(eval_seed)
         eval_metrics = run_task_eval(config, task_info, model, tokenizer)
         eval_metrics['task_name'] = task_name
         eval_metrics['finetuned'] = False
@@ -432,6 +437,7 @@ def run_eval(config: DictConfig, model_provider: ModelProvider):
         ### Run eval again ###
 
         logger.info(f"Running post-finetune eval...")
+        set_seed(eval_seed)
        eval_metrics = run_task_eval(config, task_info, model, tokenizer)
         eval_metrics['task_name'] = task_name
         eval_metrics['finetuned'] = True
@@ -522,6 +528,14 @@ def configure_logging():
     lite_llm.setLevel(logging.INFO)
 
 
+def set_seed(seed: Optional[int] = None):
+    """Set the random seed for reproducibility."""
+    if seed is not None:
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+
+
 @hydra.main(version_base=None, config_path='../src/conf', config_name='eval')
 def main(config: DictConfig):
 
@@ -535,11 +549,8 @@ def main(config: DictConfig):
     model_provider = ModelProvider.get_instance(config.model)
 
     # Set random seed
-    if config.get('seed') is not None:
-        random.seed(config.seed)
-        np.random.seed(config.seed)
-        torch.manual_seed(config.seed)
-
+    set_seed(config.get('seed'))
+
     # Run eval
     logger.info("Running eval...")
     eval_results = run_eval(config, model_provider)