-
Notifications
You must be signed in to change notification settings - Fork 108
Add a wikirace task #471
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
phhusson
wants to merge
2
commits into
open-thought:main
Choose a base branch
from
phhusson:dev/phh/wikirace
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Add a wikirace task #471
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,238 @@ | ||
| import re | ||
| from collections import defaultdict, deque | ||
| from dataclasses import dataclass | ||
| from functools import lru_cache | ||
| from random import Random | ||
| from typing import Any, Optional | ||
|
|
||
| from ..coaching import BaseCurriculum, RangeAttributeDefinition | ||
| from ..factory import ProceduralDataset, register_dataset | ||
|
|
||
| try: | ||
| from datasets import load_dataset | ||
| except: | ||
| raise Exception("wikirace requires datasets library. Run `pip install datasets`") | ||
|
|
||
| QUESTION_FORMAT_TEMPLATE = """ | ||
| You are playing WikiRace, trying to navigate from one Wikipedia article to another using only links. | ||
|
|
||
| Answer with just the link number. | ||
|
|
||
| Current article: {current} | ||
| Target article: {target} | ||
| Available links (numbered): | ||
| {formatted_links} | ||
|
|
||
| Your path so far: {formatted_path} | ||
|
|
||
| Think about which link is most likely to lead you toward the target article. | ||
| First, analyze each link briefly and how it connects to your goal, then select the most promising one. | ||
| """ | ||
|
|
||
|
|
||
| DATASET_NAME = "wikirace" | ||
|
|
||
|
|
||
| @dataclass | ||
| class WikiraceConfig: | ||
| """Configuration for WikiRace task generation""" | ||
|
|
||
| min_distance: int = 3 | ||
| max_distance: int = 6 | ||
| max_tries: int = 100 | ||
| seed: Optional[int] = None | ||
| size: int = 500 | ||
|
|
||
| def validate(self) -> None: | ||
| """Validate configuration parameters""" | ||
| assert self.min_distance > 1, "min_distance must be greater than 1" | ||
| assert self.max_distance >= self.min_distance, "max_distance must be >= min_distance" | ||
|
|
||
| assert self.max_tries >= 1, "max_tries must be greater than 1" | ||
|
|
||
|
|
||
| def load_wiki_graph(): | ||
| dataset = load_dataset("HuggingFaceTB/simplewiki-pruned-350k") | ||
|
|
||
| graph = defaultdict(set) | ||
| titles = set() | ||
|
|
||
| # Build the graph | ||
| for example in dataset["train"]: | ||
| title = example["article"] | ||
| links = example["links"] | ||
|
|
||
| titles.add(title) | ||
|
|
||
| for link in links: | ||
| graph[title].add(link) | ||
|
|
||
| # Note: Since titles was a set, and hash are naturally unstable | ||
| # We want to sort it, so that prng.choice() is stable | ||
| return graph, sorted(list(titles)) | ||
|
|
||
|
|
||
| class WikiraceDataset(ProceduralDataset): | ||
| """Generates Wikirace Game tasks""" | ||
|
|
||
| def __init__(self, config: WikiraceConfig): | ||
| self.wikigraph, self.wikititles = load_wiki_graph() | ||
| super().__init__(config=config, seed=config.seed, size=config.size) | ||
|
|
||
| # We'll be computing a lot of shortest_path of very similar paths | ||
| # So cache it | ||
| @lru_cache(maxsize=128 * 1024) | ||
| def shortest_path(self, source, target): | ||
| if source not in self.wikigraph or target not in self.wikigraph: | ||
| return None | ||
|
|
||
| if source == target: | ||
| return 1, [source] | ||
|
|
||
| visited = {source} | ||
| queue = deque([(source, [source], 0)]) | ||
|
|
||
| while queue: | ||
| current_node, path, l = queue.popleft() | ||
| for neighbor in self.wikigraph[current_node]: | ||
| if neighbor == target: | ||
| return 1 + l, (path + [neighbor]) | ||
| if neighbor not in visited: | ||
| visited.add(neighbor) | ||
| queue.append((neighbor, path + [neighbor], 1 + l)) | ||
|
|
||
| return None # No path found | ||
|
|
||
| def __getitem__(self, idx: int) -> dict: | ||
| """Generate a single Wikirace Game task | ||
|
|
||
| Returns: | ||
| dict with keys: | ||
| - question: str, the task description with a source article, target article, and current chosen path | ||
| - answer: str, one possible article on the shortest path | ||
| - metadata: dict with generation parameters | ||
| """ | ||
| rng = Random(self.seed + idx) | ||
|
|
||
| # Find a task that suits our min_distance/max_distance | ||
| # Since some pages might be dead-ends, we might need to try multiple times | ||
| for _ in range(self.config.max_tries): | ||
| source = rng.choice(self.wikititles) | ||
| target = source | ||
| chosen_distance = rng.randint(self.config.min_distance, self.config.max_distance) | ||
| path = [source] | ||
| length = 0 | ||
| while self.shortest_path(source, target)[0] != chosen_distance: | ||
| possibilities = self.wikigraph[target] - set(path) | ||
| if not possibilities: | ||
| break | ||
| # Since hash() is random, we need to sort the set into a list | ||
| # for prng stability | ||
| possibilities = sorted(list(possibilities)) | ||
| target = rng.choice(possibilities) | ||
| length += 1 | ||
| # Are we lost? Are we looping? Aborting | ||
| if length > 12: | ||
| break | ||
| path.append(target) | ||
| if self.shortest_path(source, target)[0] == chosen_distance: | ||
| break | ||
| # We got lost in a loop or a dead end, try again | ||
|
|
||
| if self.shortest_path(source, target)[0] != chosen_distance: | ||
| raise Exception(f"After {self.config.max_tries}, we failed to find a suitable wikipedia articles pair") | ||
|
|
||
| _, path = self.shortest_path(source, target) | ||
| # This is the length of the current path (let's call it state) | ||
| # 0 mean that we are still at the source of the path we're searching for | ||
| path_len = rng.randint(0, min(self.config.min_distance, len(path)) - 2) | ||
| given_path = path[:path_len] | ||
| given_path = " => ".join(given_path) | ||
| current = path[path_len] | ||
| # Stable links | ||
| links = sorted(list(self.wikigraph[current])) | ||
| links = list(enumerate(links)) | ||
| question = QUESTION_FORMAT_TEMPLATE.format( | ||
| current=current, | ||
| target=target, | ||
| formatted_links=[f"{x[0]} - {x[1]}\n" for x in links], | ||
| formatted_path=given_path, | ||
| ) | ||
| answer = [x[0] for x in links if x[1] == path[path_len + 1]][0] | ||
|
|
||
| return { | ||
| "question": question, | ||
| "answer": str(answer), | ||
| "metadata": { | ||
| "source_dataset": DATASET_NAME, | ||
| "source_index": idx, | ||
| "source": source, | ||
| "current": current, | ||
| "target": target, | ||
| "distance": chosen_distance, | ||
| "path": given_path, | ||
| "remaining_path": path[path_len:], | ||
| "links": links, | ||
| "difficulty": { | ||
| "distance": (self.config.min_distance, self.config.max_distance), | ||
| }, | ||
| }, | ||
| } | ||
|
|
||
| def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: | ||
| """Determine if the solution provided solves the problem""" | ||
| reward = 0.01 # Default reward | ||
| source = entry["metadata"]["source"] | ||
| target = entry["metadata"]["target"] | ||
| current = entry["metadata"]["current"] | ||
| links = entry["metadata"]["links"] | ||
|
|
||
| if answer is None or not answer.strip(): | ||
| return reward | ||
|
|
||
| try: | ||
| answer = answer.strip() | ||
| answer = int(answer) | ||
| if answer < 0: | ||
| return 0.01 | ||
| link = links[answer][1] | ||
| new_distance = self.shortest_path(link, target)[0] | ||
| old_distance = self.shortest_path(current, target)[0] | ||
| if new_distance < old_distance: | ||
| # Path is shortet than before, it is following (a) shortest path! | ||
| return 1.0 | ||
| elif new_distance == old_distance: | ||
| # Path isn't shorter, but not longer either, that's still something | ||
| return 0.5 | ||
| else: | ||
| # At least answer is valid... | ||
| return 0.1 | ||
|
|
||
| except Exception: | ||
| return 0.01 | ||
|
|
||
|
|
||
| class WikiraceCurriculum(BaseCurriculum): | ||
| def __init__(self): | ||
| super().__init__(WikiraceCurriculum.__name__, WikiraceConfig) | ||
|
|
||
| # Define attributes | ||
| self._define_attributes( | ||
| RangeAttributeDefinition( | ||
| name="distance", | ||
| levels=[3, 6, 9, 12, 15], | ||
| description="Number of source numbers", | ||
| lower_field_name="min_distance", | ||
| upper_field_name="max_distance", | ||
| ensure_interval=True, | ||
| ), | ||
| ScalarAttributeDefinition( | ||
| name="max_tries", | ||
| description="Max number of tries to find test cases", | ||
| field_name="max_tries", | ||
| ), | ||
| ) | ||
|
|
||
|
|
||
| # Register the dataset | ||
| register_dataset(DATASET_NAME, WikiraceDataset, WikiraceConfig, WikiraceCurriculum) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| datasets>=3.6.0 | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,100 @@ | ||
| import pytest | ||
|
|
||
| from reasoning_gym.games.wikirace import WikiraceConfig, WikiraceCurriculum, WikiraceDataset | ||
|
|
||
|
|
||
| def test_wikirace_game_config_validation(): | ||
| """Test that invalid configs raise appropriate errors""" | ||
| with pytest.raises(AssertionError): | ||
| config = WikiraceConfig(min_distance=0) | ||
| config.validate() | ||
|
|
||
| with pytest.raises(AssertionError): | ||
| config = WikiraceConfig(min_distance=3, max_distance=2) | ||
| config.validate() | ||
|
|
||
| with pytest.raises(AssertionError): | ||
| config = WikiraceConfig(max_tries=-2) | ||
| config.validate() | ||
|
|
||
|
|
||
| def test_wikirace_game_deterministic(): | ||
| """Test that dataset generates same items with same seed""" | ||
| config1 = WikiraceConfig(seed=42, size=2) | ||
| dataset1 = WikiraceDataset(config1) | ||
| config2 = WikiraceConfig(seed=42, size=2) | ||
| dataset2 = WikiraceDataset(config2) | ||
|
|
||
| for i in range(len(dataset1)): | ||
| assert dataset1[i] == dataset2[i] | ||
|
|
||
|
|
||
| def test_wikirace_game_items(): | ||
| """Test basic properties of generated items""" | ||
| config = WikiraceConfig( | ||
| seed=42, | ||
| size=2, | ||
| ) | ||
| dataset = WikiraceDataset(config) | ||
|
|
||
| for item in dataset: | ||
| assert isinstance(item, dict) | ||
| assert "question" in item | ||
| assert "answer" in item | ||
| assert "metadata" in item | ||
|
|
||
| # Check metadata contains required fields | ||
| assert "source" in item["metadata"] | ||
| assert "links" in item["metadata"] | ||
| assert "target" in item["metadata"] | ||
| assert "current" in item["metadata"] | ||
| assert "distance" in item["metadata"] | ||
|
|
||
| # Verify number of source numbers is within config range | ||
| assert config.min_distance <= item["metadata"]["distance"] <= config.max_distance | ||
|
|
||
| # A non-int answer fails | ||
| assert dataset.score_answer(answer="nope", entry=item) == 0.01 | ||
|
|
||
| # A negative answer fails | ||
| assert dataset.score_answer(answer="-1", entry=item) == 0.01 | ||
|
|
||
| # An out of bond answer fails | ||
| assert dataset.score_answer(answer=str(len(item["metadata"]["links"])), entry=item) == 0.01 | ||
|
|
||
| # A parsable answer gives at least 0.1 | ||
| assert dataset.score_answer(answer="0", entry=item) >= 0.1 | ||
|
|
||
| # The expected answer gives 1.0 | ||
| assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 | ||
|
|
||
|
|
||
| def test_wikirace_game_single(): | ||
| """Test a known item""" | ||
| config = WikiraceConfig( | ||
| seed=42, | ||
| size=1, | ||
| ) | ||
| dataset = WikiraceDataset(config) | ||
| item = dataset[0] | ||
|
|
||
| # If those asserts fails, it probably just means you changed the generation algorithm, which is fine | ||
| # you'll have have to update this test | ||
| assert item["metadata"]["source"] == "Vadim Bakatin" | ||
| assert item["metadata"]["target"] == "Azerbaijan Technological University" | ||
| assert item["metadata"]["distance"] == 3 | ||
| assert len(item["metadata"]["path"]) == 0 | ||
|
|
||
| # If those asserts fails, it is most likely an actual error | ||
|
|
||
| # Only valid answer is 4 - Moscow | ||
| assert dataset.score_answer(answer="4", entry=item) == 1.0 | ||
| # Selecting 8 - Russians makes you go further away from the target | ||
| assert dataset.score_answer(answer="2", entry=item) == 0.1 | ||
| # Selecting 0 - Commmunist Party of the Soviet Union doesn't get you further away, but it doesn't get you closer either | ||
| assert dataset.score_answer(answer="2", entry=item) == 0.1 | ||
|
|
||
| # Use this to check the results if you need to update this test | ||
| # (with pytest -s) | ||
| # for (i,_) in item['metadata']['links']: | ||
| # print(i, dataset.score_answer(answer=str(i), entry=item)) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rather than using a
requirements-optional.txtwe can add a section inpyproject.tomlcontaining this dep, as we already do for some other optional requirements