[nanoeval] sync #50

Merged
merged 3 commits on Mar 3, 2025
34 changes: 24 additions & 10 deletions project/nanoeval/nanoeval/eval.py
@@ -6,6 +6,7 @@

import inspect
import logging
import os
from pprint import pformat
from typing import (
Any,
@@ -23,14 +24,14 @@
from chz.factories import function
from nanoeval._multiprocessing_utils import check_multiprocess_safe
from nanoeval.asyncio_utils import HasAsyncContextManager
from nanoeval.recorder import dummy_recorder
from nanoeval.recorder_protocol import RecorderConfig


class Task(BaseModel):
"""
All nanoeval Tasks must inherit from this class.
"""

question_id: str
attempt_id: int = 1
retry_idx: int = 0
@@ -129,10 +130,10 @@ async def get_summary(self, results: list[tuple[TTask, TResult]]) -> dict[str, A
is to compute accuracy (this will underweight instances with fewer rollouts). Instead, you should compute accuracy
by instance, and then average over instances.

Notably, this function is called on an interval before the eval is completed, so it should be able to handle partial
results and partial results are very likely to be ragged.
Notably, this function is called on an interval, so it should be able to handle partial results and partial
results are very likely to be ragged.

This function is used by the default implementation of `eval.get_full_summary()`.
This function is called by `eval.get_full_summary()`.
"""
raise NotImplementedError

Expand Down Expand Up @@ -198,7 +199,10 @@ async def self_test(self) -> None:
@chz.chz
class RunnerArgs:
# Runner options.
concurrency: int = 4096
concurrency: int | None = chz.field(
default=4096,
doc="Per-eval concurrency. If None, concurrency is not limited.",
)
# If enabled, use multiprocessing. This can be useful for CPU-bound tasks, and uses
# multiprocessing as the outer loop, and asyncio concurrency as the inner loop.
# We split tasks into groups of size `concurrency`. A subprocess processes one group
@@ -215,10 +219,10 @@ class RunnerArgs:
doc="Limit the number of tasks run. The limit is the first N tasks selected before shuffling.",
)
run_set_id: str | None = None
recorder: RecorderConfig = chz.field(
meta_factory=function(default_module="nanoeval.recorders"),
default_factory=dummy_recorder,
doc="Recorder configuration used to create a recorder for the eval.",
recorder: RecorderConfig | None = chz.field(
meta_factory=function(),
default=None,
doc="Recorder configuration used to create a recorder for the eval. If None, default recorder is used (as determined by `library_config().get_default_recorder()`.",
)
enable_slackbot: bool = True
slack_name: str | None = None
@@ -227,7 +231,7 @@ class RunnerArgs:
summary_interval: float | None = None
use_monitor: bool = chz.field(
default=False,
doc="If enabled, starts a streamlit server on port 8501 to monitor the eval. You can also run it manually by running `streamlit run project/nanoeval/monitor.py`.",
doc="If enabled, starts a streamlit server on port 8501 to monitor the eval. You can also run it manually by running `python3 -m nanoeval.bin.mon`.",
)
max_retries: int = chz.field(
default=16,
@@ -255,6 +259,16 @@ def _validate_multiprocessing_options(self) -> None:
def _numerical_limits(self) -> None:
assert self.n_tasks is None or self.n_tasks > 0

@chz.validate
def _validate_concurrency(self) -> None:
if self.concurrency is not None and self.concurrency <= 0:
if self.concurrency == 0 and os.environ.get("NANOEVAL_ALLOW_ZERO_CONCURRENCY"):
pass
else:
raise ValueError(
"concurrency must be > 0 or None unless NANOEVAL_ALLOW_ZERO_CONCURRENCY is set."
)


@chz.chz
class EvalSpec:
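For orientation (not part of the diff): a minimal standalone sketch of the rule the new _validate_concurrency check enforces, so the NANOEVAL_ALLOW_ZERO_CONCURRENCY escape hatch used by the tests below is easy to see. Names mirror the diff; the function itself is illustrative, not the nanoeval API.

import os

def validate_concurrency(concurrency: int | None) -> None:
    # Mirrors the validator above: None means "unlimited", positive values pass,
    # and 0 is only tolerated when NANOEVAL_ALLOW_ZERO_CONCURRENCY is set
    # (the test suite uses this to build an eval that intentionally hangs).
    if concurrency is not None and concurrency <= 0:
        if concurrency == 0 and os.environ.get("NANOEVAL_ALLOW_ZERO_CONCURRENCY"):
            return
        raise ValueError(
            "concurrency must be > 0 or None unless NANOEVAL_ALLOW_ZERO_CONCURRENCY is set."
        )

validate_concurrency(4096)   # default: ok
validate_concurrency(None)   # unlimited: ok
os.environ["NANOEVAL_ALLOW_ZERO_CONCURRENCY"] = "1"
validate_concurrency(0)      # ok only because the env var is set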
5 changes: 3 additions & 2 deletions project/nanoeval/nanoeval/evaluation.py
@@ -346,7 +346,8 @@ async def run(spec: EvalSpec) -> dict[str, Any]: # type: ignore
)

# Build the recorder!
recorder = await spec.runner.recorder.factory(spec, len(tasks))
recorder_config = spec.runner.recorder or get_library_config().get_default_recorder()
recorder = await recorder_config.factory(spec, len(tasks))

# Load all tasks into the database
with db.conn() as conn:
@@ -358,7 +359,7 @@
datetime.now(),
dill.dumps(spec),
dill.dumps(recorder),
spec.runner.concurrency,
spec.runner.concurrency if spec.runner.concurrency is not None else 999999,
),
)

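A small sketch of the two fallbacks this hunk introduces, using stand-in types rather than the real EvalSpec: an explicitly configured recorder wins over the library default, and concurrency=None is persisted as a large integer sentinel so the eval table keeps an integer column.

from dataclasses import dataclass

@dataclass
class FakeRunner:  # stand-in for spec.runner, illustrative only
    recorder: str | None = None
    concurrency: int | None = None

def resolve(runner: FakeRunner, default_recorder: str) -> tuple[str, int]:
    recorder = runner.recorder or default_recorder  # fall back to the library default
    concurrency = runner.concurrency if runner.concurrency is not None else 999999
    return recorder, concurrency

print(resolve(FakeRunner(), "json_recorder"))                                   # ('json_recorder', 999999)
print(resolve(FakeRunner(recorder="custom", concurrency=8), "json_recorder"))   # ('custom', 8)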
91 changes: 57 additions & 34 deletions project/nanoeval/nanoeval/evaluation_test.py
@@ -1,7 +1,9 @@
from __future__ import annotations

import asyncio
import os
from typing import Any
from unittest.mock import patch

import numpy as np
import pytest
@@ -24,45 +26,51 @@ async def test_concurrency_blocking_edge_case() -> None:

If this test fails, it will hang forever.
"""
async with global_exit_stack, db.open_run_set_db(backup=False), asyncio.TaskGroup() as tg:
# This eval will never finish and get queued in the executor forever.
# Note concurrency=0.
background = tg.create_task(
nanoeval.validate(
EvalSpec(
eval=GPQAEval(solver=MockSolver()),
runner=RunnerArgs(n_tasks=1, concurrency=0),
)
)
)

# Wait for the task to get registered
while True:
with db.conn() as c:
(count,) = c.execute("SELECT COUNT(*) FROM eval").fetchone()
if count > 0:
break
await asyncio.sleep(0.5)

print("made it to the point where the first one got picked up by the executor")

for f in asyncio.as_completed(
[
# Will never finish - and that's ok!
background,
# We expect this one to actually finish.
# In general setting concurrency to 0 is not allowed as it would just make the eval hang,
# however we can bypass this for tests in order to create an eval that intentionally hangs.
with patch.dict(os.environ, {"NANOEVAL_ALLOW_ZERO_CONCURRENCY": "1"}, clear=False):
async with global_exit_stack, db.open_run_set_db(backup=False), asyncio.TaskGroup() as tg:
# This eval will never finish and get queued in the executor forever.
background = tg.create_task(
nanoeval.validate(
EvalSpec(
eval=GPQAEval(solver=MockSolver()),
runner=RunnerArgs(n_tasks=1, concurrency=1),
runner=RunnerArgs(
n_tasks=1,
concurrency=0, # Forces the eval to hang.
),
)
),
]
):
await f
print("We did it! one of the evals finished!")
await cancel_task(background)
break
)
)

# Wait for the task to get registered
while True:
with db.conn() as c:
(count,) = c.execute("SELECT COUNT(*) FROM eval").fetchone()
if count > 0:
break
await asyncio.sleep(0.5)

print("made it to the point where the first one got picked up by the executor")

for f in asyncio.as_completed(
[
# Will never finish - and that's ok!
background,
# We expect this one to actually finish.
nanoeval.validate(
EvalSpec(
eval=GPQAEval(solver=MockSolver()),
runner=RunnerArgs(n_tasks=1, concurrency=1),
)
),
]
):
await f
print("We did it! one of the evals finished!")
await cancel_task(background)
break


@pytest.mark.asyncio
@@ -346,3 +354,18 @@ async def solve(self, task: MCQTask) -> Answer:
)

assert first_eval_succeeded


@pytest.mark.asyncio
async def test_concurrency_none() -> None:
async with global_exit_stack, db.open_run_set_db(backup=False):
report = await nanoeval.validate(
EvalSpec(
eval=GPQAEval(solver=MockSolver()),
runner=RunnerArgs(
n_tasks=1,
concurrency=None,
),
)
)
assert "accuracy" in report
82 changes: 81 additions & 1 deletion project/nanoeval/nanoeval/library_config.py
@@ -1,14 +1,18 @@
from __future__ import annotations

import functools
import logging
import os
import tempfile
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generator, Literal, Self

import pandas as pd
import structlog
from structlog.typing import EventDict

import chz
from nanoeval.recorder_protocol import BasicRunSpec, RecorderConfig, RecorderProtocol
@@ -17,6 +21,9 @@
from nanoeval.eval import EvalSpec


ENV_NANOEVAL_LOG_ALL = "NANOEVAL_LOG_ALL"


@functools.cache
def root_dir() -> Path:
return Path(tempfile.gettempdir()) / "nanoeval"
@@ -91,6 +98,38 @@ async def factory(self, spec: EvalSpec, num_tasks: int) -> RecorderProtocol:
return _DefaultDummyRecorder(run_spec=self._make_default_run_spec(spec))


def _rename_field(
old: str, new: str, logger: logging.Logger, name: str, event_dict: EventDict
) -> EventDict:
del logger, name

if value := event_dict.get(old):
event_dict[new] = value
del event_dict[old]
return event_dict


def _remove_all_fields_except(
to_keep: list[str], logger: logging.Logger, name: str, event_dict: EventDict
) -> EventDict:
del logger, name

for key in list(event_dict.keys()):
if key not in to_keep:
del event_dict[key]
return event_dict


class PrintOrWarningFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
if record.levelno > logging.INFO or (
os.environ.get(ENV_NANOEVAL_LOG_ALL) and record.levelno == logging.INFO
):
return True

return isinstance(record.msg, dict) and record.msg.get("_print", False)


class LibraryConfig:
"""
Hooks to configure parts of the nanoeval library. Shared across all runs in the process.
@@ -112,7 +151,7 @@ async def send_user_notification(self, message: str, extra: str | None = None) -
extra=extra,
)

def on_logging_setup(self) -> None:
def setup_logging(self) -> None:
# Set up structlog according to https://www.structlog.org/en/stable/standard-library.html
# Basically, we convert structlogs to logging-style record and then process them using
# structlog formatters into json for humio, and console for stdout
@@ -124,6 +163,42 @@ def on_logging_setup(self) -> None:
logger_factory=structlog.stdlib.LoggerFactory(),
)

# Remove all StreamHandlers from the root logger
for handler in logging.getLogger().handlers:
if isinstance(handler, logging.StreamHandler):
logging.getLogger().removeHandler(handler)

handler = logging.StreamHandler()
# Use OUR `ProcessorFormatter` to format all `logging` entries to stdout.
handler.setFormatter(
structlog.stdlib.ProcessorFormatter(
processors=[
structlog.stdlib.ProcessorFormatter.remove_processors_meta,
structlog.contextvars.merge_contextvars,
structlog.processors.add_log_level,
structlog.processors.MaybeTimeStamper(fmt="iso"),
functools.partial(
_remove_all_fields_except,
["timestamp", "level", "event", "component", "exc_info"],
),
structlog.dev.ConsoleRenderer(),
],
# logger -> structlog transforms
foreign_pre_chain=[
structlog.stdlib.add_logger_name,
partial(_rename_field, "logger", "component"),
partial(_rename_field, "logger_name", "component"),
partial(_rename_field, "log", "event"),
structlog.stdlib.ExtraAdder(),
],
)
)

handler.addFilter(PrintOrWarningFilter())
root_logger = logging.getLogger()
root_logger.addHandler(handler)
root_logger.setLevel(logging.INFO)

@contextmanager
def set_recorder_context(
self, rec: RecorderProtocol, sample_id: str, group_id: str
@@ -133,6 +208,11 @@ def set_recorder_context(
def get_dummy_recorder(self, log: bool) -> RecorderConfig:
return _DummyRecorderConfig()

def get_default_recorder(self) -> RecorderConfig:
from nanoeval.json_recorder import json_recorder

return json_recorder()

def writable_root_dir(self) -> str:
return str(root_dir())

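A standalone sketch of what the new PrintOrWarningFilter admits, built only on the stdlib logging module (the structlog processors above are omitted): warnings and errors always pass, plain INFO passes only when NANOEVAL_LOG_ALL is set, and dict messages flagged with _print always pass.

import logging
import os

class PrintOrWarningFilterSketch(logging.Filter):
    # Copy of the filter's decision logic above, for illustration.
    def filter(self, record: logging.LogRecord) -> bool:
        if record.levelno > logging.INFO or (
            os.environ.get("NANOEVAL_LOG_ALL") and record.levelno == logging.INFO
        ):
            return True
        return isinstance(record.msg, dict) and record.msg.get("_print", False)

logger = logging.getLogger("demo")
handler = logging.StreamHandler()
handler.addFilter(PrintOrWarningFilterSketch())
logger.addHandler(handler)
logger.setLevel(logging.INFO)

logger.warning("always shown")                        # level > INFO: passes
logger.info("hidden unless NANOEVAL_LOG_ALL is set")  # plain INFO: filtered by default
logger.info({"event": "summary", "_print": True})     # flagged dict: passes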
11 changes: 7 additions & 4 deletions project/nanoeval/nanoeval/metrics/standard.py
@@ -79,25 +79,28 @@ def compute_default_metrics_on_correctness_without_answer_groups(
samples_df["answer_group_id"] = samples_df["is_correct"].astype(int)
del samples_df["is_correct"]

# Create two answer groups: 0 is wrong, 1 is right.
# Create two answer groups: 0 is wrong, 1 is right. Ensure no instance-level duplicates.
answer_group_correctness_df = pd.DataFrame(
flatten(
[
{
"instance": sample.instance,
"instance": instance,
"answer_group_id": 0,
"is_correct": False,
},
{
"instance": sample.instance,
"instance": instance,
"answer_group_id": 1,
"is_correct": True,
},
]
for sample in samples_df.itertuples()
for instance in samples_df["instance"].unique()
)
)

# Should have 1 answer group for False and 1 for True
assert len(answer_group_correctness_df) == 2 * len(samples_df["instance"].unique())

return compute_default_metrics(samples_df, answer_group_correctness_df)


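A tiny worked example of the fix in this hunk, with made-up sample data: the correctness frame is now built per unique instance (two rows each, answer group 0 = wrong, 1 = right), so repeated rollouts of the same instance no longer yield duplicate rows.

import pandas as pd

# Three samples but only two distinct instances; "q1" was rolled out twice.
samples_df = pd.DataFrame({"instance": ["q1", "q1", "q2"], "answer_group_id": [1, 0, 1]})

answer_group_correctness_df = pd.DataFrame(
    [
        row
        for instance in samples_df["instance"].unique()
        for row in (
            {"instance": instance, "answer_group_id": 0, "is_correct": False},
            {"instance": instance, "answer_group_id": 1, "is_correct": True},
        )
    ]
)

# Two rows per unique instance: 2 instances -> 4 rows, not 6.
assert len(answer_group_correctness_df) == 2 * samples_df["instance"].nunique()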
6 changes: 6 additions & 0 deletions project/nanoeval/nanoeval/recorders.py
@@ -0,0 +1,6 @@
"""This module exports all bundled recorders for nanoeval."""

from nanoeval.json_recorder import json_recorder
from nanoeval.recorder import dummy_recorder

__all__ = ["json_recorder", "dummy_recorder"]
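A hedged usage sketch of how the bundled recorders and the new optional recorder field fit together; the import paths are inferred from this diff and assume nanoeval is installed.

from nanoeval.eval import RunnerArgs          # path inferred from eval.py above
from nanoeval.recorders import json_recorder  # the new one-stop recorder module

# Explicit recorder: used as-is by evaluation.run().
args_explicit = RunnerArgs(concurrency=64, recorder=json_recorder())

# recorder left as None (the new default): evaluation.run() falls back to
# get_library_config().get_default_recorder(), which now returns json_recorder().
args_default = RunnerArgs(concurrency=None)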