Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add InfiniteBench for long context benchmarking #2421

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions benchmark/infinitebench/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
## Download dataset
```
bash ./download_infinitebench.sh ~/data
```

## Run benchmark

### SGLang
#### Set up environment
```bash
conda create -n sglang python=3.10
conda activate sglang
pip install --upgrade pip
pip install "sglang[all]" --find-links https://flashinfer.ai/whl/cu121/torch2.4/flashinfer/
```

#### Launch server and run eval
```
python -m sglang.launch_server --model-path gradientai/Llama-3-8B-Instruct-Gradient-1048k --port 30000 --max-total-tokens 131072
```

In a separate terminal, run eval on the first 10 samples as follows
```
python eval_long_context.py --task passkey --end-idx 10 --data-dir ~/data
```

### TensorRT
The following evaluation with TensorRT has been tested with 1xH100 (80 GB SXM5).

#### Set up enviroment
```bash
conda create -n tensorrt python=3.10
conda activate tensorrt
conda install -c conda-forge mpi4py openmpi
sudo apt-get -y install libopenmpi-dev && pip install tensorrt_llm==0.15.0
```

```bash
git clone https://github.com/NVIDIA/TensorRT-LLM.git
cd TensorRT-LLM
pip install --upgrade -r requirements-dev.txt
cd examples/llama
```

#### Prepare checkpoint
```bash
sudo apt install git-lfs
git-lfs clone https://huggingface.co/gradientai/Llama-3-8B-Instruct-Gradient-1048k/

python convert_checkpoint.py \
--model_dir ./Llama-3-8B-Instruct-Gradient-1048k/ \
--output_dir /tmp/llama-3-8B-1048k/trt_ckpts \
--dtype float16

```

#### Build engine and run eval
```bash
python -m tensorrt_llm.commands.build \
--checkpoint_dir /tmp/llama-3-8B-1048k/trt_ckpts \
--output_dir /tmp/llama-3-8B-1048k/trt_engines \
--gemm_plugin float16 \
--max_num_tokens 4096 \
--max_input_len 131072 \
--max_seq_len 131082 \
--use_paged_context_fmha enable

python ../eval_long_context.py \
--task passkey \
--engine_dir /tmp/llama-3-8B-1048k/trt_engines \
--tokenizer_dir ./Llama-3-8B-Instruct-Gradient-1048k/ \
--stop_idx 10 \
--max_input_length 131072 \
--enable_chunked_context \
--max_tokens_in_paged_kv_cache 131136 \
--data_dir ~/data \
--output_dir ./
```
199 changes: 199 additions & 0 deletions benchmark/infinitebench/compute_scores.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
# MIT License

# Copyright (c) 2023 OpenBMB

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# reference: https://github.com/OpenBMB/InfiniteBench/blob/main/src/compute_scores.py

import json
import re
import string
from collections import Counter
from pathlib import Path

from tqdm import tqdm


def normalize_answer(s: str) -> str:
"""Lower text and remove punctuation, articles and extra whitespace."""

def remove_articles(text):
return re.sub(r"\b(a|an|the)\b", " ", text)

def white_space_fix(text):
return " ".join(text.split())

def remove_punc(text):
exclude = set(string.punctuation)
return "".join(ch for ch in text if ch not in exclude)

def lower(text):
return text.lower()

return white_space_fix(remove_articles(remove_punc(lower(s))))


def f1_score(prediction, ground_truth) -> tuple[float, float, float]:
common = Counter(prediction) & Counter(ground_truth)
num_same = sum(common.values())
if num_same == 0:
return 0, 0, 0
precision = 1.0 * num_same / len(prediction)
recall = 1.0 * num_same / len(ground_truth)
f1 = (2 * precision * recall) / (precision + recall)
return f1, precision, recall


def load_json(fname):
return json.load(open(fname))


def iter_jsonl(fname, cnt=None):
i = 0
with open(fname, "r", encoding="utf8") as fin:
for line in fin:
if line.strip() == "": # Skip empty lines
continue
if i == cnt:
break
if line.strip() == "": # Skip empty lines
continue
yield json.loads(line)
i += 1


def first_int_match(prediction):
pred_list = re.split("[^0-9]", prediction)
pred_value = ""
for item in pred_list:
if item != "":
pred_value = item
break
return pred_value


def split_retrieval_answer(pred: str):
for c in ["\n", ":", '"', "'", ".", ",", "?", "!", "{", "}"]:
pred = pred.replace(c, " ")
words = pred.split()
return words


def get_score_one_kv_retrieval(pred, label) -> bool:
for c in ["\n", ":", '"', "'", ".", ",", "?", "!", "{", "}"]:
pred = pred.replace(c, " ")
words = pred.split()
return label in words


def get_score_one_passkey(pred, label) -> bool:
if isinstance(label, list):
label = label[0]
return label == first_int_match(pred)


def get_score_one(pred: str, label: str, task_name: str) -> float:
"""
Computes the score for one prediction.
Returns one float (zero and one for boolean values).
"""
NAME_TO_SCORE_GETTER = {
# Retrieve
"kv_retrieval": get_score_one_kv_retrieval,
"kv_retrieval_prefix": get_score_one_kv_retrieval,
"kv_retrieval_both": get_score_one_kv_retrieval,
"passkey": get_score_one_passkey,
}
assert task_name in NAME_TO_SCORE_GETTER, f"Invalid task name: {task_name}"
score = NAME_TO_SCORE_GETTER[task_name](pred, label)
return float(score)


def get_labels(preds: list) -> list[str]:
possible_label_keys = ["ground_truth", "label"]
for label_key in possible_label_keys:
if label_key in preds[0]:
return [x.get(label_key, "XXXXXXXXXX") for x in preds]
raise ValueError(f"Cannot find label in {preds[0]}")


def get_preds(preds: list, data_name: str) -> list[str]:
pred_strings = []
possible_pred_keys = ["prediction", "pred"]
for pred in preds:
this_pred = "NO PREDICTION"
for pred_key in possible_pred_keys:
if pred_key in pred:
this_pred = pred[pred_key]
break
else:
raise ValueError(f"Cannot find prediction in {pred}")
pred_strings.append(this_pred)
return pred_strings


def get_score(labels: list, preds: list, data_name: str) -> float:
"""
Computes the average score for a task.
"""
assert len(labels) == len(preds)
scores = []
for label, pred in tqdm(zip(labels, preds)):
score = get_score_one(pred, label, data_name)
scores.append(score)
return sum(scores) / len(scores)


def load_json(preds_path):
assert preds_path.exists(), f"Predictions not found in: {preds_path}"
print("Loading prediction results from", preds_path)
return list(iter_jsonl(preds_path))


def compute_scores(preds, data_name: str):
labels = get_labels(preds)
preds = get_preds(preds, data_name)

acc = get_score(labels, preds, data_name)
return acc


ALL_TASKS = [
"passkey",
"kv_retrieval",
]

if __name__ == "__main__":
from args import parse_args

arguments = parse_args()
tasks = [arguments.task]
for task in tasks:
preds_path = Path(arguments.preds_file)
preds = load_json(preds_path)
acc = compute_scores(preds, task)
print(acc)
8 changes: 8 additions & 0 deletions benchmark/infinitebench/download_infinitebench.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# reference: https://github.com/OpenBMB/InfiniteBench/blob/51d9b37b0f1790ead936df2243abbf7f0420e439/scripts/download_dataset.sh

save_dir=$1
mkdir -p ${save_dir}

for file in code_debug code_run kv_retrieval longbook_choice_eng longbook_qa_chn longbook_qa_eng longbook_sum_eng longdialogue_qa_eng math_calc math_find number_string passkey; do
wget -c https://huggingface.co/datasets/xinrongzhang2022/InfiniteBench/resolve/main/${file}.jsonl?download=true -O ${save_dir}/${file}.jsonl
done
111 changes: 111 additions & 0 deletions benchmark/infinitebench/eval_long_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# reference: https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/eval_long_context.py
import argparse
import json
import os
import time

from compute_scores import compute_scores
from eval_utils import DATA_NAME_TO_MAX_NEW_TOKENS, create_prompt, get_answer, load_data

import sglang as sgl
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)


def validate_args(args):
assert args.task in ["passkey", "kv_retrieval"], f"Invalid task: {args.task}"


def write_answers(filename, model_id, results):
with open(os.path.expanduser(filename), "w") as fout:
for i in range(len(results)):
ans_json = {
"question_id": results[i]["question_id"],
"model_id": model_id,
"prediction": results[i]["prediction"],
"ground_truth": results[i]["ground_truth"],
}
fout.write(json.dumps(ans_json) + "\n")


@sgl.function
def infinitebench(s, question, max_tokens):
s += question
s += sgl.gen(
"answer",
max_tokens=max_tokens,
)


def main(args):
validate_args(args)

# Select backend
backend = select_sglang_backend(args)
sgl.set_default_backend(backend)

# Data
data_name = args.task
max_tokens = DATA_NAME_TO_MAX_NEW_TOKENS[data_name] # max output length
lines = load_data(data_name, data_dir=args.data_dir)
if args.end_idx is None:
args.end_idx = len(lines)

# Construct prompts
questions = []
labels = []
for i in range(len(lines[args.start_idx : args.end_idx])):
questions.append(create_prompt(lines[i], data_name, args.data_dir))
labels.append(get_answer(lines[i], data_name))
arguments = [{"question": q, "max_tokens": max_tokens} for q in questions]

# Run requests
tic = time.time()
results = infinitebench.run_batch(
arguments,
temperature=0,
num_threads=args.parallel,
progress_bar=True,
)
latency = time.time() - tic

# Compute scores
results = [
{"ground_truth": label, "prediction": s["answer"], "question_id": line["id"]}
for line, label, s in zip(lines, labels, results)
]
acc = compute_scores(results, args.task)
print(f"#questions: {len(questions)}, Latency: {latency:.2f}, Accuracy: {acc:.3f}")

# Write results to file
model_id = backend.model_info["model_path"]
answer_file = f"tmp_output_{data_name}_{args.backend}.txt"
write_answers(answer_file, model_id, results)

with open(args.result_file, "a") as fout:
value = {
"task": args.task,
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": len(questions),
"other": {
"num_questions": len(questions),
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--task", type=str, choices=["passkey", "kv_retrieval"], required=True
)
parser.add_argument("--data-dir", type=str, default="./data")
parser.add_argument("--start-idx", type=int, default=0)
parser.add_argument("--end-idx", type=int, default=None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add more descriptions about the "--start-idx" and "--end-idx" ?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed these arguments, which were borrowed from tensorrt eval script, and added num-samples with description.

args = add_common_sglang_args_and_parse(parser)
main(args)
Loading
Loading