From c86bffca1245899fa374575654ffe8c354feca2d Mon Sep 17 00:00:00 2001
From: Ankur Kumar
Date: Mon, 9 Dec 2024 01:17:59 -0800
Subject: [PATCH 1/4] add infinitebench eval

---
 benchmark/infinitebench/README.md             |  78 +++++++
 benchmark/infinitebench/compute_scores.py     | 199 ++++++++++++++++++
 .../infinitebench/download_infinitebench.sh   |   8 +
 benchmark/infinitebench/eval_long_context.py  | 111 ++++++++++
 benchmark/infinitebench/eval_utils.py         | 122 +++++++++++
 5 files changed, 518 insertions(+)
 create mode 100644 benchmark/infinitebench/README.md
 create mode 100644 benchmark/infinitebench/compute_scores.py
 create mode 100644 benchmark/infinitebench/download_infinitebench.sh
 create mode 100644 benchmark/infinitebench/eval_long_context.py
 create mode 100644 benchmark/infinitebench/eval_utils.py

diff --git a/benchmark/infinitebench/README.md b/benchmark/infinitebench/README.md
new file mode 100644
index 0000000000..cf698e325d
--- /dev/null
+++ b/benchmark/infinitebench/README.md
@@ -0,0 +1,78 @@
## Download dataset
```
bash ./download_infinitebench.sh ~/data
```

## Run benchmark

### SGLang
#### Set up environment
```bash
conda create -n sglang python=3.10
conda activate sglang
pip install --upgrade pip
pip install "sglang[all]" --find-links https://flashinfer.ai/whl/cu121/torch2.4/flashinfer/
```

#### Launch server and run eval
```
python -m sglang.launch_server --model-path gradientai/Llama-3-8B-Instruct-Gradient-1048k --port 30000 --max-total-tokens 131072
```

In a separate terminal, run eval on the first 10 samples as follows:
```
python eval_long_context.py --task passkey --end-idx 10 --data-dir ~/data
```
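Accuracy is printed at the end of the run. Per-sample predictions are also written to `tmp_output_<task>_<backend>.txt` in the working directory, so a saved run can be re-scored offline (a minimal sketch; `<backend>` is a placeholder for whichever backend the common SGLang benchmark args select):
```
python compute_scores.py --task passkey --preds-file tmp_output_passkey_<backend>.txt
```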
### TensorRT
The following TensorRT evaluation has been tested on 1x H100 (80 GB SXM5).

#### Set up environment
```bash
conda create -n tensorrt python=3.10
conda activate tensorrt
conda install -c conda-forge mpi4py openmpi
sudo apt-get -y install libopenmpi-dev && pip install tensorrt_llm==0.15.0
```

```bash
git clone https://github.com/NVIDIA/TensorRT-LLM.git
cd TensorRT-LLM
pip install --upgrade -r requirements-dev.txt
cd examples/llama
```

#### Prepare checkpoint
```bash
sudo apt install git-lfs
git-lfs clone https://huggingface.co/gradientai/Llama-3-8B-Instruct-Gradient-1048k/

python convert_checkpoint.py \
  --model_dir ./Llama-3-8B-Instruct-Gradient-1048k/ \
  --output_dir /tmp/llama-3-8B-1048k/trt_ckpts \
  --dtype float16

```

#### Build engine and run eval
```bash
python -m tensorrt_llm.commands.build \
  --checkpoint_dir /tmp/llama-3-8B-1048k/trt_ckpts \
  --output_dir /tmp/llama-3-8B-1048k/trt_engines \
  --gemm_plugin float16 \
  --max_num_tokens 4096 \
  --max_input_len 131072 \
  --max_seq_len 131082 \
  --use_paged_context_fmha enable

python ../eval_long_context.py \
  --task passkey \
  --engine_dir /tmp/llama-3-8B-1048k/trt_engines \
  --tokenizer_dir ./Llama-3-8B-Instruct-Gradient-1048k/ \
  --stop_idx 10 \
  --max_input_length 131072 \
  --enable_chunked_context \
  --max_tokens_in_paged_kv_cache 131136 \
  --data_dir ~/data \
  --output_dir ./
```

diff --git a/benchmark/infinitebench/compute_scores.py b/benchmark/infinitebench/compute_scores.py
new file mode 100644
index 0000000000..810c3b0a4d
--- /dev/null
+++ b/benchmark/infinitebench/compute_scores.py
@@ -0,0 +1,199 @@
# MIT License

# Copyright (c) 2023 OpenBMB

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# reference: https://github.com/OpenBMB/InfiniteBench/blob/main/src/compute_scores.py

import json
import re
import string
from collections import Counter
from pathlib import Path

from tqdm import tqdm


def normalize_answer(s: str) -> str:
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def f1_score(prediction, ground_truth) -> tuple[float, float, float]:
    common = Counter(prediction) & Counter(ground_truth)
    num_same = sum(common.values())
    if num_same == 0:
        return 0, 0, 0
    precision = 1.0 * num_same / len(prediction)
    recall = 1.0 * num_same / len(ground_truth)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1, precision, recall


def load_json(fname):
    return json.load(open(fname))


def iter_jsonl(fname, cnt=None):
    i = 0
    with open(fname, "r", encoding="utf8") as fin:
        for line in fin:
            if line.strip() == "":  # Skip empty lines
                continue
            if i == cnt:
                break
            yield json.loads(line)
            i += 1


def first_int_match(prediction):
    pred_list = re.split("[^0-9]", prediction)
    pred_value = ""
    for item in pred_list:
        if item != "":
            pred_value = item
            break
    return pred_value


def split_retrieval_answer(pred: str):
    for c in ["\n", ":", '"', "'", ".", ",", "?", "!", "{", "}"]:
        pred = pred.replace(c, " ")
    words = pred.split()
    return words


def get_score_one_kv_retrieval(pred, label) -> bool:
    for c in ["\n", ":", '"', "'", ".", ",", "?", "!", "{", "}"]:
        pred = pred.replace(c, " ")
    words = pred.split()
    return label in words


def get_score_one_passkey(pred, label) -> bool:
    if isinstance(label, list):
        label = label[0]
    return label == first_int_match(pred)


def get_score_one(pred: str, label: str, task_name: str) -> float:
    """
    Computes the score for one prediction.
    Returns a single float (0.0 or 1.0 for boolean scores).
    """
    NAME_TO_SCORE_GETTER = {
        # Retrieve
        "kv_retrieval": get_score_one_kv_retrieval,
        "kv_retrieval_prefix": get_score_one_kv_retrieval,
        "kv_retrieval_both": get_score_one_kv_retrieval,
        "passkey": get_score_one_passkey,
    }
    assert task_name in NAME_TO_SCORE_GETTER, f"Invalid task name: {task_name}"
    score = NAME_TO_SCORE_GETTER[task_name](pred, label)
    return float(score)
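

# Worked examples for the two scorers above (illustrative values, not taken
# from the dataset):
#
#   get_score_one("The pass key is 71432.", "71432", "passkey") -> 1.0
#       first_int_match() extracts the first run of digits from the
#       prediction, so any surrounding prose is ignored.
#
#   get_score_one('The value is "beautiful-apple".', "beautiful-apple", "kv_retrieval") -> 1.0
#       punctuation is replaced by spaces and the label must then match one
#       of the remaining whitespace-separated words exactly.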
def get_labels(preds: list) -> list[str]:
    possible_label_keys = ["ground_truth", "label"]
    for label_key in possible_label_keys:
        if label_key in preds[0]:
            return [x.get(label_key, "XXXXXXXXXX") for x in preds]
    raise ValueError(f"Cannot find label in {preds[0]}")


def get_preds(preds: list, data_name: str) -> list[str]:
    pred_strings = []
    possible_pred_keys = ["prediction", "pred"]
    for pred in preds:
        this_pred = "NO PREDICTION"
        for pred_key in possible_pred_keys:
            if pred_key in pred:
                this_pred = pred[pred_key]
                break
        else:
            raise ValueError(f"Cannot find prediction in {pred}")
        pred_strings.append(this_pred)
    return pred_strings


def get_score(labels: list, preds: list, data_name: str) -> float:
    """
    Computes the average score for a task.
    """
    assert len(labels) == len(preds)
    scores = []
    for label, pred in tqdm(zip(labels, preds)):
        score = get_score_one(pred, label, data_name)
        scores.append(score)
    return sum(scores) / len(scores)


def load_preds(preds_path):
    assert preds_path.exists(), f"Predictions not found in: {preds_path}"
    print("Loading prediction results from", preds_path)
    return list(iter_jsonl(preds_path))


def compute_scores(preds, data_name: str):
    labels = get_labels(preds)
    preds = get_preds(preds, data_name)

    acc = get_score(labels, preds, data_name)
    return acc


ALL_TASKS = [
    "passkey",
    "kv_retrieval",
]

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, choices=ALL_TASKS, required=True)
    parser.add_argument("--preds-file", type=str, required=True)
    arguments = parser.parse_args()
    tasks = [arguments.task]
    for task in tasks:
        preds_path = Path(arguments.preds_file)
        preds = load_preds(preds_path)
        acc = compute_scores(preds, task)
        print(acc)
diff --git a/benchmark/infinitebench/download_infinitebench.sh b/benchmark/infinitebench/download_infinitebench.sh
new file mode 100644
index 0000000000..f56349da15
--- /dev/null
+++ b/benchmark/infinitebench/download_infinitebench.sh
@@ -0,0 +1,8 @@
# reference: https://github.com/OpenBMB/InfiniteBench/blob/51d9b37b0f1790ead936df2243abbf7f0420e439/scripts/download_dataset.sh

save_dir=$1
mkdir -p ${save_dir}

for file in code_debug code_run kv_retrieval longbook_choice_eng longbook_qa_chn longbook_qa_eng longbook_sum_eng longdialogue_qa_eng math_calc math_find number_string passkey; do
  wget -c https://huggingface.co/datasets/xinrongzhang2022/InfiniteBench/resolve/main/${file}.jsonl?download=true -O ./${save_dir}/${file}.jsonl
done
diff --git a/benchmark/infinitebench/eval_long_context.py b/benchmark/infinitebench/eval_long_context.py
new file mode 100644
index 0000000000..605a8951eb
--- /dev/null
+++ b/benchmark/infinitebench/eval_long_context.py
@@ -0,0 +1,111 @@
# reference: https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/eval_long_context.py
import argparse
import json
import os
import time

from compute_scores import compute_scores
from eval_utils import DATA_NAME_TO_MAX_NEW_TOKENS, create_prompt, get_answer, load_data

import sglang as sgl
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)


def validate_args(args):
    assert args.task in ["passkey", "kv_retrieval"], f"Invalid task: {args.task}"


def write_answers(filename, model_id, results):
    with open(os.path.expanduser(filename), "w") as fout:
        for i in range(len(results)):
            ans_json = {
                "question_id": results[i]["question_id"],
                "model_id": model_id,
                "prediction": results[i]["prediction"],
                "ground_truth": results[i]["ground_truth"],
            }
            fout.write(json.dumps(ans_json) + "\n")


@sgl.function
def infinitebench(s, question, max_tokens):
    s += question
    s += sgl.gen(
        "answer",
        max_tokens=max_tokens,
    )
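

# Note: the decorated function above is invoked once per element of the
# `arguments` list passed to `infinitebench.run_batch(...)` in main() below;
# each dict supplies `question` and `max_tokens`, and the generated text is
# read back from the returned state as s["answer"].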
f"Invalid task: {args.task}" + + +def write_answers(filename, model_id, results): + with open(os.path.expanduser(filename), "w") as fout: + for i in range(len(results)): + ans_json = { + "question_id": results[i]["question_id"], + "model_id": model_id, + "prediction": results[i]["prediction"], + "ground_truth": results[i]["ground_truth"], + } + fout.write(json.dumps(ans_json) + "\n") + + +@sgl.function +def infinitebench(s, question, max_tokens): + s += question + s += sgl.gen( + "answer", + max_tokens=max_tokens, + ) + + +def main(args): + validate_args(args) + + # Select backend + backend = select_sglang_backend(args) + sgl.set_default_backend(backend) + + # Data + data_name = args.task + max_tokens = DATA_NAME_TO_MAX_NEW_TOKENS[data_name] # max output length + lines = load_data(data_name, data_dir=args.data_dir) + if args.end_idx is None: + args.end_idx = len(lines) + + # Construct prompts + questions = [] + labels = [] + for i in range(len(lines[args.start_idx : args.end_idx])): + questions.append(create_prompt(lines[i], data_name, args.data_dir)) + labels.append(get_answer(lines[i], data_name)) + arguments = [{"question": q, "max_tokens": max_tokens} for q in questions] + + # Run requests + tic = time.time() + results = infinitebench.run_batch( + arguments, + temperature=0, + num_threads=args.parallel, + progress_bar=True, + ) + latency = time.time() - tic + + # Compute scores + results = [ + {"ground_truth": label, "prediction": s["answer"], "question_id": line["id"]} + for line, label, s in zip(lines, labels, results) + ] + acc = compute_scores(results, args.task) + print(f"#questions: {len(questions)}, Latency: {latency:.2f}, Accuracy: {acc:.3f}") + + # Write results to file + model_id = backend.model_info["model_path"] + answer_file = f"tmp_output_{data_name}_{args.backend}.txt" + write_answers(answer_file, model_id, results) + + with open(args.result_file, "a") as fout: + value = { + "task": args.task, + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "num_requests": len(questions), + "other": { + "num_questions": len(questions), + "parallel": args.parallel, + }, + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--task", type=str, choices=["passkey", "kv_retrieval"], required=True + ) + parser.add_argument("--data-dir", type=str, default="./data") + parser.add_argument("--start-idx", type=int, default=0) + parser.add_argument("--end-idx", type=int, default=None) + args = add_common_sglang_args_and_parse(parser) + main(args) diff --git a/benchmark/infinitebench/eval_utils.py b/benchmark/infinitebench/eval_utils.py new file mode 100644 index 0000000000..66d30b6f32 --- /dev/null +++ b/benchmark/infinitebench/eval_utils.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def get_answer(eg: dict, data_name: str):
    if data_name in ["code_debug", "longbook_choice_eng"]:
        OPTIONS = "ABCD"
        if isinstance(eg["answer"], str):
            ret = [eg["answer"], OPTIONS[eg["options"].index(eg["answer"])]]
        elif isinstance(eg["answer"], list):
            if len(eg["answer"]) == 1:
                ret = [eg["answer"][0], OPTIONS[eg["options"].index(eg["answer"][0])]]
            elif len(eg["answer"]) == 2 and eg["answer"][1] in ["A", "B", "C", "D"]:
                ret = eg["answer"]
            else:
                raise ValueError
        else:
            raise ValueError
        return ret

    return eg["answer"]


def truncate_input(input, max_length, manner="middle"):
    if len(input) <= max_length:
        return input
    if manner == "middle":
        return input[0 : max_length // 2] + input[-max_length // 2 :]
    else:
        return None

From 46df6d8792728e7f2bed5ccabd4ca2f03382e4b7 Mon Sep 17 00:00:00 2001
From: Ankur Kumar
Date: Mon, 9 Dec 2024 01:50:13 -0800
Subject: [PATCH 2/4] fix download path

---
 benchmark/infinitebench/download_infinitebench.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/benchmark/infinitebench/download_infinitebench.sh b/benchmark/infinitebench/download_infinitebench.sh
index f56349da15..45f8aa12e6 100644
--- a/benchmark/infinitebench/download_infinitebench.sh
+++ b/benchmark/infinitebench/download_infinitebench.sh
@@ -4,5 +4,5 @@ save_dir=$1
 mkdir -p ${save_dir}
 
 for file in code_debug code_run kv_retrieval longbook_choice_eng longbook_qa_chn longbook_qa_eng longbook_sum_eng longdialogue_qa_eng math_calc math_find number_string passkey; do
-  wget -c https://huggingface.co/datasets/xinrongzhang2022/InfiniteBench/resolve/main/${file}.jsonl?download=true -O ./${save_dir}/${file}.jsonl
+  wget -c https://huggingface.co/datasets/xinrongzhang2022/InfiniteBench/resolve/main/${file}.jsonl?download=true -O ${save_dir}/${file}.jsonl
 done
From 275a64ef71c76670fc8f7a619c0248baa80032ac Mon Sep 17 00:00:00 2001
From: Ankur Kumar
Date: Tue, 10 Dec 2024 23:56:09 -0800
Subject: [PATCH 3/4] move dataset download to main script

---
 benchmark/infinitebench/README.md                  | 17 +++++----
 .../{eval_long_context.py => bench_sglang.py}      | 35 ++++++++++++-------
 benchmark/infinitebench/download_infinitebench.sh  |  8 -----
 3 files changed, 33 insertions(+), 27 deletions(-)
 rename benchmark/infinitebench/{eval_long_context.py => bench_sglang.py} (75%)
 delete mode 100644 benchmark/infinitebench/download_infinitebench.sh

diff --git a/benchmark/infinitebench/README.md b/benchmark/infinitebench/README.md
index cf698e325d..5bcf39d947 100644
--- a/benchmark/infinitebench/README.md
+++ b/benchmark/infinitebench/README.md
@@ -1,8 +1,3 @@
-## Download dataset
-```
-bash ./download_infinitebench.sh ~/data
-```
-
 ## Run benchmark
 
 ### SGLang
@@ -21,7 +16,7 @@ python -m sglang.launch_server --model-path gradientai/Llama-3-8B-Instruct-Gradi
 
 In a separate terminal, run eval on the first 10 samples as follows:
 ```
-python eval_long_context.py --task passkey --end-idx 10 --data-dir ~/data
+python bench_sglang.py --task passkey --num-samples 10
 ```
 
 ### TensorRT
@@ -42,6 +37,14 @@ pip install --upgrade -r requirements-dev.txt
 cd examples/llama
 ```
 
+#### Download dataset
+The following commands download the dataset files into the `./data` directory, which is hardcoded in the script.
+```bash
+URL="https://raw.githubusercontent.com/OpenBMB/InfiniteBench/51d9b37b0f1790ead936df2243abbf7f0420e439/scripts/download_dataset.sh"
+wget $URL -O download_infinitebench.sh
+bash download_infinitebench.sh
+```
+
 #### Prepare checkpoint
 ```bash
 sudo apt install git-lfs
@@ -73,6 +76,6 @@ python ../eval_long_context.py \
   --max_input_length 131072 \
   --enable_chunked_context \
   --max_tokens_in_paged_kv_cache 131136 \
-  --data_dir ~/data \
+  --data_dir ./data \
   --output_dir ./
 ```
diff --git a/benchmark/infinitebench/eval_long_context.py b/benchmark/infinitebench/bench_sglang.py
similarity index 75%
rename from benchmark/infinitebench/eval_long_context.py
rename to benchmark/infinitebench/bench_sglang.py
index 605a8951eb..269ff3164e 100644
--- a/benchmark/infinitebench/eval_long_context.py
+++ b/benchmark/infinitebench/bench_sglang.py
@@ -5,13 +5,14 @@ import time
 
 from compute_scores import compute_scores
-from eval_utils import DATA_NAME_TO_MAX_NEW_TOKENS, create_prompt, get_answer, load_data
+from eval_utils import DATA_NAME_TO_MAX_NEW_TOKENS, create_prompt, get_answer
 
 import sglang as sgl
 from sglang.test.test_utils import (
     add_common_sglang_args_and_parse,
     select_sglang_backend,
 )
+from sglang.utils import download_and_cache_file, read_jsonl
 
 
 def validate_args(args):
@@ -46,18 +47,21 @@ def main(args):
     backend = select_sglang_backend(args)
     sgl.set_default_backend(backend)
 
-    # Data
+    # Download and load data
     data_name = args.task
+    data_url = "https://huggingface.co/datasets/xinrongzhang2022/InfiniteBench/resolve/main/${data_name}.jsonl"
     max_tokens = DATA_NAME_TO_MAX_NEW_TOKENS[data_name]  # max output length
-    lines = load_data(data_name, data_dir=args.data_dir)
-    if args.end_idx is None:
-        args.end_idx = len(lines)
+
+    filename = download_and_cache_file(data_url)
+    lines = list(read_jsonl(filename))
+    if args.num_samples is None:
+        args.num_samples = len(lines)
 
     # Construct prompts
     questions = []
     labels = []
-    for i in range(args.start_idx, args.end_idx):
-        questions.append(create_prompt(lines[i], data_name, args.data_dir))
+    for i in range(len(lines[: args.num_samples])):
+        questions.append(create_prompt(lines[i], data_name, os.path.dirname(filename)))
         labels.append(get_answer(lines[i], data_name))
     arguments = [{"question": q, "max_tokens": max_tokens} for q in questions]
 
@@ -73,8 +77,12 @@ def main(args):
 
     # Compute scores
     results = [
-        {"ground_truth": label, "prediction": s["answer"], "question_id": line["id"]}
-        for line, label, s in zip(lines[args.start_idx : args.end_idx], labels, results)
+        {
+            "ground_truth": label,
+            "prediction": result["answer"],
+            "question_id": line["id"],
+        }
+        for line, label, result in zip(lines, labels, results)
     ]
     acc = compute_scores(results, args.task)
     print(f"#questions: {len(questions)}, Latency: {latency:.2f}, Accuracy: {acc:.3f}")
@@ -104,8 +112,11 @@ def main(args):
     parser.add_argument(
         "--task", type=str, choices=["passkey", "kv_retrieval"], required=True
     )
-    parser.add_argument("--data-dir", type=str, default="./data")
-    parser.add_argument("--start-idx", type=int, default=0)
-    parser.add_argument("--end-idx", type=int, default=None)
+    parser.add_argument(
+        "--num-samples",
+        type=int,
+        default=None,
+        help="Number of samples from the beginning of the dataset to use for eval.",
+    )
     args = add_common_sglang_args_and_parse(parser)
     main(args)
diff --git a/benchmark/infinitebench/download_infinitebench.sh b/benchmark/infinitebench/download_infinitebench.sh
deleted file mode 100644
index 45f8aa12e6..0000000000
--- a/benchmark/infinitebench/download_infinitebench.sh
+++ /dev/null
@@ -1,8 +0,0 @@
-# reference: https://github.com/OpenBMB/InfiniteBench/blob/51d9b37b0f1790ead936df2243abbf7f0420e439/scripts/download_dataset.sh
-
-save_dir=$1
-mkdir -p ${save_dir}
-
-for file in code_debug code_run kv_retrieval longbook_choice_eng longbook_qa_chn longbook_qa_eng longbook_sum_eng longdialogue_qa_eng math_calc math_find number_string passkey; do
-  wget -c https://huggingface.co/datasets/xinrongzhang2022/InfiniteBench/resolve/main/${file}.jsonl?download=true -O ${save_dir}/${file}.jsonl
-done
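[PATCH 4/4] below fixes a latent bug introduced by the rename above: the URL is built with shell-style `${data_name}` interpolation inside a plain Python string, so the literal text `${data_name}` would be sent to Hugging Face instead of the task name. A minimal illustration of the difference (the `prefix/` base is a shortened stand-in for the real URL):

```python
data_name = "passkey"

# Plain string: no interpolation happens, the placeholder is kept verbatim.
print("prefix/${data_name}.jsonl")  # -> prefix/${data_name}.jsonl

# f-string: the task name is substituted as intended.
print(f"prefix/{data_name}.jsonl")  # -> prefix/passkey.jsonl
```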
From d0645b834581de80c8ef24647bc52fc9fdf1eaac Mon Sep 17 00:00:00 2001
From: Ankur Kumar
Date: Wed, 11 Dec 2024 00:41:20 -0800
Subject: [PATCH 4/4] fix url

---
 benchmark/infinitebench/bench_sglang.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/benchmark/infinitebench/bench_sglang.py b/benchmark/infinitebench/bench_sglang.py
index 269ff3164e..ae021ff5ab 100644
--- a/benchmark/infinitebench/bench_sglang.py
+++ b/benchmark/infinitebench/bench_sglang.py
@@ -49,7 +49,7 @@ def main(args):
 
     # Download and load data
     data_name = args.task
-    data_url = "https://huggingface.co/datasets/xinrongzhang2022/InfiniteBench/resolve/main/${data_name}.jsonl"
+    data_url = f"https://huggingface.co/datasets/xinrongzhang2022/InfiniteBench/resolve/main/{data_name}.jsonl"
     max_tokens = DATA_NAME_TO_MAX_NEW_TOKENS[data_name]  # max output length
 
     filename = download_and_cache_file(data_url)
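Each run of `bench_sglang.py` also appends a one-line JSON summary to the result file used by the common SGLang benchmark args. A small sketch for tabulating those records afterwards (the `result.jsonl` filename is an assumption; use whatever `--result-file` points at in your setup):

```python
import json

# Each line is one benchmark run recorded by bench_sglang.py.
with open("result.jsonl") as f:
    for record in map(json.loads, filter(str.strip, f)):
        print(record["task"], record["backend"], record["latency"], record["num_requests"])
```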