Skip to content

Add consistent synthetic data flag #241

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/guidellm/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,15 @@ def benchmark():
type=int,
help="The random seed to use for benchmarking to ensure reproducibility.",
)
@click.option(
"--consistent-synthetic-data",
is_flag=True,
default=GenerativeTextScenario.get_default("consistent_synthetic_data"),
help=(
"Ensure synthetic datasets generate the same prompts across different "
"concurrency levels for fair comparison. Only applies to synthetic data."
),
)
def run(
scenario,
target,
Expand All @@ -268,6 +277,7 @@ def run(
output_extras,
output_sampling,
random_seed,
consistent_synthetic_data,
):
click_ctx = click.get_current_context()

Expand All @@ -290,6 +300,7 @@ def run(
cooldown_percent=cooldown_percent,
output_sampling=output_sampling,
random_seed=random_seed,
consistent_synthetic_data=consistent_synthetic_data,
)

try:
Expand Down
2 changes: 2 additions & 0 deletions src/guidellm/benchmark/entrypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ async def benchmark_generative_text(
output_extras: Optional[dict[str, Any]],
output_sampling: Optional[int],
random_seed: int,
consistent_synthetic_data: bool = False,
show_progress: bool = True,
show_progress_scheduler_stats: bool = False,
output_console: bool = True,
Expand Down Expand Up @@ -89,6 +90,7 @@ async def benchmark_generative_text(
else "infinite" # default to infinite so we don't run out of data
),
random_seed=random_seed,
consistent_synthetic_data=consistent_synthetic_data,
)
unique_requests = request_loader.num_unique_items(raise_err=False)
console.print_line(
Expand Down
1 change: 1 addition & 0 deletions src/guidellm/benchmark/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,4 @@ class Config:
cooldown_percent: Annotated[Optional[float], Field(gt=0, le=1)] = None
output_sampling: Optional[NonNegativeInt] = None
random_seed: int = 42
consistent_synthetic_data: bool = False
24 changes: 22 additions & 2 deletions src/guidellm/request/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from guidellm.config import settings
from guidellm.dataset import ColumnInputTypes, load_dataset
from guidellm.dataset.synthetic import SyntheticDatasetCreator
from guidellm.objects import StandardBaseModel
from guidellm.request.request import GenerationRequest

Expand Down Expand Up @@ -84,6 +85,7 @@ def __init__(
shuffle: bool = True,
iter_type: Literal["finite", "infinite"] = "finite",
random_seed: int = 42,
consistent_synthetic_data: bool = False,
):
self.data = data
self.data_args = data_args
Expand All @@ -100,6 +102,7 @@ def __init__(
self.shuffle = shuffle
self.iter_type = iter_type
self.random_seed = random_seed
self.consistent_synthetic_data = consistent_synthetic_data

self.column_mappings = self._create_column_mappings(args_column_mappings)
self.preserve_iter_state = iter_type == "infinite" # ensure no caching requests
Expand Down Expand Up @@ -244,8 +247,20 @@ def _get_dataset_iter(
if scope_create_count > 0 and self.iter_type != "infinite":
return None

# With an infinite iter_type the iterator is normally preserved and reused
# between calls. When the consistent_synthetic_data flag is enabled and the
# data is synthetic, bypass the preserved iterator so that each concurrency
# rate gets the same prompts.
if self.preserve_iter_state and self._preserved_iter is not None:
return self._preserved_iter
if self.consistent_synthetic_data and SyntheticDatasetCreator.is_supported(
self.data, self.data_args
):
# reset the iterator for each concurrency rate to ensure
# consistent prompts across different concurrency levels
pass # Continue to create a new iterator below
else:
# For non-synthetic datasets, or when the flag is disabled, reuse the
# preserved iterator as before
return self._preserved_iter

dataset = (
self.dataset
Expand All @@ -255,7 +270,12 @@ def _get_dataset_iter(

dataset_iter = iter(dataset)

if self.preserve_iter_state:
# Store the iterator for reuse only for non-synthetic datasets or when the
# flag is disabled
if self.preserve_iter_state and not (
self.consistent_synthetic_data
and SyntheticDatasetCreator.is_supported(self.data, self.data_args)
):
self._preserved_iter = dataset_iter

return dataset_iter
Expand Down
Loading