Skip to content

Commit 79bf926

Browse files
committed
fix: Consistent FIM examples between evals
1 parent 50ff01c commit 79bf926

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

eval/run_eval.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,10 @@ def run_eval(config: DictConfig, model_provider: ModelProvider):
397397
else:
398398
logger.info("No benchmarks enabled.")
399399

400+
# Make a random seed that will be used for both pre- and post-finetune eval
401+
# This is required to keep the FIM examples the same for both runs
402+
eval_seed = random.randint(0, 2**32 - 1)
403+
400404
# Load name of all eval tasks
401405
task_metrics = []
402406
for task_name, task_info in load_task_info():
@@ -409,6 +413,7 @@ def run_eval(config: DictConfig, model_provider: ModelProvider):
409413
### Run eval ###
410414

411415
logger.info(f"Running pre-finetune eval...")
416+
set_seed(eval_seed)
412417
eval_metrics = run_task_eval(config, task_info, model, tokenizer)
413418
eval_metrics['task_name'] = task_name
414419
eval_metrics['finetuned'] = False
@@ -432,6 +437,7 @@ def run_eval(config: DictConfig, model_provider: ModelProvider):
432437
### Run eval again ###
433438

434439
logger.info(f"Running post-finetune eval...")
440+
set_seed(eval_seed)
435441
eval_metrics = run_task_eval(config, task_info, model, tokenizer)
436442
eval_metrics['task_name'] = task_name
437443
eval_metrics['finetuned'] = True
@@ -522,6 +528,14 @@ def configure_logging():
522528
lite_llm.setLevel(logging.INFO)
523529

524530

531+
def set_seed(seed: Optional[int] = None) -> None:
    """Seed the Python, NumPy, and PyTorch RNGs for reproducible runs.

    Args:
        seed: Seed value to apply to all three generators. When ``None``,
            the current RNG state is left untouched (deliberate no-op).
    """
    if seed is None:
        return
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
525539
@hydra.main(version_base=None, config_path='../src/conf', config_name='eval')
526540
def main(config: DictConfig):
527541

@@ -535,11 +549,8 @@ def main(config: DictConfig):
535549
model_provider = ModelProvider.get_instance(config.model)
536550

537551
# Set random seed
538-
if config.get('seed') is not None:
539-
random.seed(config.seed)
540-
np.random.seed(config.seed)
541-
torch.manual_seed(config.seed)
542-
552+
set_seed(config.get('seed'))
553+
543554
# Run eval
544555
logger.info("Running eval...")
545556
eval_results = run_eval(config, model_provider)

0 commit comments

Comments
 (0)