diff --git a/src/generate_model_answers.py b/src/generate_model_answers.py index a998b2a..474364b 100644 --- a/src/generate_model_answers.py +++ b/src/generate_model_answers.py @@ -22,7 +22,8 @@ def parse_args(): choices=LIST_OF_MODELS, required=True) parser.add_argument("--dataset", - choices=LIST_OF_DATASETS) + required=True, + help="Name of the dataset (e.g., triviaqa, triviaqa_test, etc.)") parser.add_argument("--verbose", action='store_true', help='print more information') parser.add_argument("--n_samples", type=int, help='number of examples to use', default=None)