4 changes: 4 additions & 0 deletions reasoning_gym/games/__init__.py
@@ -23,6 +23,7 @@
from .survo import SurvoConfig, SurvoCurriculum, SurvoDataset
from .tower_of_hanoi import HanoiConfig, HanoiCurriculum, HanoiDataset
from .tsumego import TsumegoConfig, TsumegoCurriculum, TsumegoDataset
from .wikirace import WikiraceConfig, WikiraceCurriculum, WikiraceDataset

__all__ = [
"BoxnetConfig",
@@ -76,4 +77,7 @@
"MahjongPuzzleConfig",
"MahjongPuzzleDataset",
"MahjongPuzzleCurriculum",
"WikiraceConfig",
"WikiraceDataset",
"WikiraceCurriculum",
]
238 changes: 238 additions & 0 deletions reasoning_gym/games/wikirace.py
@@ -0,0 +1,238 @@
from collections import defaultdict, deque
from dataclasses import dataclass
from functools import lru_cache
from random import Random
from typing import Any, Optional

from ..coaching import BaseCurriculum, RangeAttributeDefinition, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset

try:
from datasets import load_dataset
except ImportError as e:
raise ImportError("wikirace requires the `datasets` library. Run `pip install datasets`") from e

QUESTION_FORMAT_TEMPLATE = """
You are playing WikiRace, trying to navigate from one Wikipedia article to another using only links.

Answer with just the link number.

Current article: {current}
Target article: {target}
Available links (numbered):
{formatted_links}

Your path so far: {formatted_path}

Think about which link is most likely to lead you toward the target article.
First, analyze each link briefly and how it connects to your goal, then select the most promising one.
"""


DATASET_NAME = "wikirace"


@dataclass
class WikiraceConfig:
"""Configuration for WikiRace task generation"""

min_distance: int = 3
max_distance: int = 6
max_tries: int = 100
seed: Optional[int] = None
size: int = 500

def validate(self) -> None:
"""Validate configuration parameters"""
assert self.min_distance > 1, "min_distance must be greater than 1"
assert self.max_distance >= self.min_distance, "max_distance must be >= min_distance"

assert self.max_tries >= 1, "max_tries must be at least 1"


def load_wiki_graph():
dataset = load_dataset("HuggingFaceTB/simplewiki-pruned-350k")

graph = defaultdict(set)
titles = set()

# Build the graph
for example in dataset["train"]:
title = example["article"]
links = example["links"]

titles.add(title)

for link in links:
graph[title].add(link)

# Note: titles is a set and set iteration order is not stable across runs,
# so we sort it so that rng.choice() is deterministic
return graph, sorted(list(titles))


class WikiraceDataset(ProceduralDataset):
"""Generates Wikirace Game tasks"""

def __init__(self, config: WikiraceConfig):
self.wikigraph, self.wikititles = load_wiki_graph()
super().__init__(config=config, seed=config.seed, size=config.size)

# We compute shortest paths between many similar node pairs,
# so cache the results
@lru_cache(maxsize=128 * 1024)
def shortest_path(self, source, target):
if source not in self.wikigraph or target not in self.wikigraph:
return None

if source == target:
return 0, [source]  # zero edges needed to reach the target

visited = {source}
queue = deque([(source, [source], 0)])

while queue:
current_node, path, dist = queue.popleft()
for neighbor in self.wikigraph[current_node]:
if neighbor == target:
return 1 + dist, (path + [neighbor])
if neighbor not in visited:
visited.add(neighbor)
queue.append((neighbor, path + [neighbor], 1 + dist))

return None # No path found

def __getitem__(self, idx: int) -> dict:
"""Generate a single Wikirace Game task

Returns:
dict with keys:
- question: str, the task description with a source article, target article, and current chosen path
- answer: str, one possible article on the shortest path
- metadata: dict with generation parameters
"""
rng = Random(self.seed + idx)

# Find a source/target pair whose shortest distance matches the chosen value
# Since some pages may be dead ends, we might need several attempts
for _ in range(self.config.max_tries):
source = rng.choice(self.wikititles)
target = source
chosen_distance = rng.randint(self.config.min_distance, self.config.max_distance)
path = [source]
length = 0
while self.shortest_path(source, target)[0] != chosen_distance:
possibilities = self.wikigraph[target] - set(path)
if not possibilities:
break
# String hashes are randomized across runs, so sort the set into a list
# for rng stability
possibilities = sorted(list(possibilities))
target = rng.choice(possibilities)
length += 1
# Are we lost or stuck in a loop? Abort this attempt
if length > 12:
break
path.append(target)
if self.shortest_path(source, target)[0] == chosen_distance:
break
# We got lost in a loop or a dead end, try again

if self.shortest_path(source, target)[0] != chosen_distance:
raise Exception(f"After {self.config.max_tries}, we failed to find a suitable wikipedia articles pair")

_, path = self.shortest_path(source, target)
# path_len is the length of the path already taken (the current state);
# 0 means we are still at the source of the path we're searching for
path_len = rng.randint(0, min(self.config.min_distance, len(path)) - 2)
given_path = path[:path_len]
given_path = " => ".join(given_path)
current = path[path_len]
# Sort links so their ordering (and numbering) is stable
links = sorted(list(self.wikigraph[current]))
links = list(enumerate(links))
question = QUESTION_FORMAT_TEMPLATE.format(
current=current,
target=target,
formatted_links="".join(f"{i} - {title}\n" for i, title in links),
formatted_path=given_path,
)
# The expected answer is the index of the next article along the shortest path
answer = next(i for i, title in links if title == path[path_len + 1])

return {
"question": question,
"answer": str(answer),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"source": source,
"current": current,
"target": target,
"distance": chosen_distance,
"path": given_path,
"remaining_path": path[path_len:],
"links": links,
"difficulty": {
"distance": (self.config.min_distance, self.config.max_distance),
},
},
}

def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Determine if the solution provided solves the problem"""
reward = 0.01 # Default reward
source = entry["metadata"]["source"]
target = entry["metadata"]["target"]
current = entry["metadata"]["current"]
links = entry["metadata"]["links"]

if answer is None or not answer.strip():
return reward

try:
answer = answer.strip()
answer = int(answer)
if answer < 0:
return 0.01
link = links[answer][1]
new_distance = self.shortest_path(link, target)[0]
old_distance = self.shortest_path(current, target)[0]
if new_distance < old_distance:
# The distance shrank, so this step follows (a) shortest path!
return 1.0
elif new_distance == old_distance:
# Not closer, but not further away either; that's still worth something
return 0.5
else:
# At least the answer is a valid link...
return 0.1

except Exception:
return 0.01


class WikiraceCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(WikiraceCurriculum.__name__, WikiraceConfig)

# Define attributes
self._define_attributes(
RangeAttributeDefinition(
name="distance",
levels=[3, 6, 9, 12, 15],
description="Number of source numbers",
lower_field_name="min_distance",
upper_field_name="max_distance",
ensure_interval=True,
),
ScalarAttributeDefinition(
name="max_tries",
description="Max number of tries to find test cases",
field_name="max_tries",
),
)


# Register the dataset
register_dataset(DATASET_NAME, WikiraceDataset, WikiraceConfig, WikiraceCurriculum)
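
For illustration, a minimal usage sketch of the new dataset, built only from the classes defined above (constructing the dataset downloads the simplewiki graph; the config values here are arbitrary):

from reasoning_gym.games.wikirace import WikiraceConfig, WikiraceDataset

config = WikiraceConfig(seed=42, size=10, min_distance=3, max_distance=6)
dataset = WikiraceDataset(config)  # loads HuggingFaceTB/simplewiki-pruned-350k on construction

item = dataset[0]
print(item["question"])  # prompt listing the numbered links of the current article

# score_answer grades a link index: 1.0 if the step shortens the distance to the target,
# 0.5 if the distance stays the same, 0.1 for any other valid link,
# and 0.01 for missing, unparsable, or out-of-range answers.
print(dataset.score_answer(answer=item["answer"], entry=item))  # 1.0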
1 change: 1 addition & 0 deletions requirements-optional.txt
@@ -0,0 +1 @@
datasets>=3.6.0
Member commented:
Rather than using a requirements-optional.txt we can add a section in pyproject.toml containing this dep, as we already do for some other optional requirements
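
A minimal sketch of what that could look like, assuming the project's pyproject.toml already declares a [project.optional-dependencies] table; the extra name "wikirace" is illustrative only:

[project.optional-dependencies]
# Hypothetical extra; the name and exact grouping are up to the maintainers.
wikirace = [
    "datasets>=3.6.0",
]

Users would then install it with pip install "reasoning-gym[wikirace]" (assuming that is the published package name).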

100 changes: 100 additions & 0 deletions tests/test_wikirace.py
@@ -0,0 +1,100 @@
import pytest

from reasoning_gym.games.wikirace import WikiraceConfig, WikiraceCurriculum, WikiraceDataset


def test_wikirace_game_config_validation():
"""Test that invalid configs raise appropriate errors"""
with pytest.raises(AssertionError):
config = WikiraceConfig(min_distance=0)
config.validate()

with pytest.raises(AssertionError):
config = WikiraceConfig(min_distance=3, max_distance=2)
config.validate()

with pytest.raises(AssertionError):
config = WikiraceConfig(max_tries=-2)
config.validate()


def test_wikirace_game_deterministic():
"""Test that dataset generates same items with same seed"""
config1 = WikiraceConfig(seed=42, size=2)
dataset1 = WikiraceDataset(config1)
config2 = WikiraceConfig(seed=42, size=2)
dataset2 = WikiraceDataset(config2)

for i in range(len(dataset1)):
assert dataset1[i] == dataset2[i]


def test_wikirace_game_items():
"""Test basic properties of generated items"""
config = WikiraceConfig(
seed=42,
size=2,
)
dataset = WikiraceDataset(config)

for item in dataset:
assert isinstance(item, dict)
assert "question" in item
assert "answer" in item
assert "metadata" in item

# Check metadata contains required fields
assert "source" in item["metadata"]
assert "links" in item["metadata"]
assert "target" in item["metadata"]
assert "current" in item["metadata"]
assert "distance" in item["metadata"]

# Verify the chosen distance is within the config range
assert config.min_distance <= item["metadata"]["distance"] <= config.max_distance

# A non-int answer fails
assert dataset.score_answer(answer="nope", entry=item) == 0.01

# A negative answer fails
assert dataset.score_answer(answer="-1", entry=item) == 0.01

# An out-of-bounds answer fails
assert dataset.score_answer(answer=str(len(item["metadata"]["links"])), entry=item) == 0.01

# A parsable answer gives at least 0.1
assert dataset.score_answer(answer="0", entry=item) >= 0.1

# The expected answer gives 1.0
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0


def test_wikirace_game_single():
"""Test a known item"""
config = WikiraceConfig(
seed=42,
size=1,
)
dataset = WikiraceDataset(config)
item = dataset[0]

# If these asserts fail, it probably just means you changed the generation algorithm, which is fine;
# you'll just have to update this test
assert item["metadata"]["source"] == "Vadim Bakatin"
assert item["metadata"]["target"] == "Azerbaijan Technological University"
assert item["metadata"]["distance"] == 3
assert len(item["metadata"]["path"]) == 0

# If these asserts fail, it is most likely an actual error

# The only answer scoring 1.0 is 4 - Moscow
assert dataset.score_answer(answer="4", entry=item) == 1.0
# Selecting 8 - Russians takes you further away from the target
assert dataset.score_answer(answer="8", entry=item) == 0.1
# Selecting 0 - Communist Party of the Soviet Union doesn't take you further away, but it doesn't get you closer either
assert dataset.score_answer(answer="0", entry=item) == 0.5

# Use this to check the results if you need to update this test
# (with pytest -s)
# for (i,_) in item['metadata']['links']:
# print(i, dataset.score_answer(answer=str(i), entry=item))