-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrules_base.py
More file actions
97 lines (74 loc) · 3.57 KB
/
Copy pathrules_base.py
File metadata and controls
97 lines (74 loc) · 3.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""Internal RAG over a small library of math rules, backed by ChromaDB.
The vector store is built lazily on first query from ``MATH_RULES`` (the source
of truth in ``math_rules.py``), so importing this module performs no I/O. If the
persistent store under ``chroma_db_reguli/`` is missing or empty it is rebuilt
automatically on first use.
The active collection is ``reguli_matematice_en``: ``math_rules.py`` is written
in English to match the language of the evaluation benchmarks (GSM8K, MATH500,
AIME, SVAMP). The legacy Romanian ``reguli_matematice`` collection is left
untouched on disk for reproducibility of earlier runs.
"""
import chromadb
# Location of the persistent ChromaDB store; the leading "./" makes it relative
# to the process's current working directory.
_DB_PATH = "./chroma_db_reguli"
# Name of the active collection inside that store.
_COLLECTION_NAME = "reguli_matematice_en"
# Module-level cache for the collection handle; stays None until
# _get_collection() has run successfully once.
_collection = None
def _get_collection():
    """Return the ChromaDB collection, building/populating it on first use.

    The handle is cached in the module-level ``_collection`` so the persistent
    client is opened at most once per process. If the collection is empty it is
    populated from ``MATH_RULES``: each rule's ``description`` is the indexed
    document, and its ``hint`` and ``id`` are stored as metadata under the keys
    ``hint`` and ``rule_id``.

    Returns:
        The collection named ``_COLLECTION_NAME`` from the store at ``_DB_PATH``.
    """
    global _collection
    if _collection is not None:
        return _collection

    client = chromadb.PersistentClient(path=_DB_PATH)
    collection = client.get_or_create_collection(name=_COLLECTION_NAME)

    if collection.count() == 0:
        # Deferred import: the rule table is only needed when the store has to
        # be (re)built.
        from math_rules import MATH_RULES

        # Use _safe_print, not print: this message contains a non-ASCII dash
        # and is emitted *before* the rules are added. An encoding error here
        # would abort the build and leave the store permanently empty.
        _safe_print("[ChromaDB] Empty store — generating rule embeddings...")
        collection.add(
            documents=[rule["description"] for rule in MATH_RULES],
            metadatas=[
                {"hint": rule["hint"], "rule_id": rule["id"]} for rule in MATH_RULES
            ],
            ids=[rule["id"] for rule in MATH_RULES],
        )
        _safe_print(f"[ChromaDB] Indexed {collection.count()} math rules.")

    # Cache only after a successful build so a failed attempt is retried later.
    _collection = collection
    return _collection
def _safe_print(msg: str) -> None:
    """Print ``msg``, degrading to an ASCII-escaped form if encoding fails.

    Some output streams (e.g. legacy Windows consoles using cp1252) cannot
    encode every character. When the plain ``print`` raises
    ``UnicodeEncodeError`` the message is printed again with each non-ASCII
    character replaced by its backslash escape.
    """
    try:
        print(msg)
    except UnicodeEncodeError:
        escaped_bytes = msg.encode("ascii", "backslashreplace")
        ascii_text = escaped_bytes.decode("ascii")
        print(ascii_text)
def find_hints(query: str, n_results: int = 2, max_distance: float = 1.2) -> str:
    """Retrieve the top-N most relevant hints by embedding similarity.

    Implements Eq. 1 of the paper: ``H = argmax cos(E(P), E(r_i))`` — the query
    can be the raw problem text itself, no LLM classification needed. Matches
    whose semantic distance is at or above ``max_distance`` are discarded; the
    kept hints are joined into a single hint block.

    Args:
        query: Text matched against the indexed rule descriptions.
        n_results: Number of nearest rules requested from the collection.
        max_distance: Exclusive upper bound on the distance of an accepted match.

    Returns:
        The accepted, non-empty hints separated by a blank line, or ``""`` when
        the query is blank, no match is close enough, or the lookup fails.
    """
    if not query or not query.strip():
        return ""

    try:
        results = _get_collection().query(query_texts=[query], n_results=n_results)
        hints = []
        # Index 0: a single query text was sent, so only the first result row
        # is relevant. strict=True surfaces a malformed response as an error.
        for distance, metadata in zip(
            results["distances"][0], results["metadatas"][0], strict=True
        ):
            # 'id_regula' is the metadata key used by older persisted DBs.
            rule_id = metadata.get("rule_id") or metadata.get("id_regula", "UNKNOWN_ID")
            hint = metadata.get("hint", "")
            if distance < max_distance:
                _safe_print(f" [Chroma Match] Rule {rule_id} (distance {distance:.2f})")
                # Skip rules without hint text so the joined block never
                # contains empty entries / stray blank separators.
                if hint:
                    hints.append(hint)
            else:
                _safe_print(
                    f" [Chroma Miss] Closest rule {rule_id} rejected "
                    f"(distance {distance:.2f} >= {max_distance})."
                )
        return "\n\n".join(hints)
    except Exception as exc:  # noqa: BLE001 - never let RAG break the pipeline
        # _safe_print, not print: the exception text may contain characters the
        # console cannot encode, and an error raised here would escape.
        _safe_print(f"[ChromaDB Error] Query failed: {exc}")
        return ""
def find_hint(problem_type: str, max_distance: float = 1.2) -> str:
    """Look up a single hint; kept for callers of the older one-hint API.

    Delegates to ``find_hints`` with ``n_results=1``. ``max_distance`` is the
    semantic-distance threshold forwarded to that call; a closest match that
    does not fall below it is rejected as irrelevant.
    """
    single_hint = find_hints(
        problem_type,
        n_results=1,
        max_distance=max_distance,
    )
    return single_hint