diff --git a/pyproject.toml b/pyproject.toml
index 68251cc..29fc7f2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,6 +15,8 @@ dependencies = [
     "openai>=1.0.0",
     "google-generativeai>=0.3.0",
     "types-requests>=2.32",
+    "scikit-learn>=1.0",
+    "numpy>=1.22",
 ]
 
 [project.optional-dependencies]
diff --git a/requirements.txt b/requirements.txt
index 12ec92a..2a5bacb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,3 +7,5 @@ types-requests>=2.32
 anthropic>=0.7.0
 openai>=1.0.0
 google-generativeai>=0.3.0
+scikit-learn>=1.0
+numpy>=1.22
diff --git a/src/search/__init__.py b/src/search/__init__.py
index 45ad7ca..870997e 100644
--- a/src/search/__init__.py
+++ b/src/search/__init__.py
@@ -1,5 +1,20 @@
 """Utilities for indexing and searching project files."""
 
 from .file_loader import load_project_files, read_file
+from .semantic_indexer import (
+    SemanticIndex,
+    build_semantic_index,
+    get_file_embedding,
+    update_file_embedding,
+    rank_files_by_query,
+)
 
-__all__ = ["load_project_files", "read_file"]
+__all__ = [
+    "load_project_files",
+    "read_file",
+    "SemanticIndex",
+    "build_semantic_index",
+    "get_file_embedding",
+    "update_file_embedding",
+    "rank_files_by_query",
+]
diff --git a/src/search/semantic_indexer.py b/src/search/semantic_indexer.py
new file mode 100644
index 0000000..2604b53
--- /dev/null
+++ b/src/search/semantic_indexer.py
@@ -0,0 +1,63 @@
+"""Simple semantic indexing of project files."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Dict
+
+import numpy as np
+from sklearn.feature_extraction.text import TfidfVectorizer
+
+from .file_loader import load_project_files, read_file
+
+
+@dataclass
+class SemanticIndex:
+    """Mapping of file paths to embedding vectors."""
+
+    embeddings: Dict[str, np.ndarray]
+    vectorizer: TfidfVectorizer
+
+    def update_file(self, path: str | Path) -> None:
+        """Update ``embeddings`` for ``path`` in-place."""
+        emb = get_file_embedding(path, self)
+        self.embeddings[str(path)] = emb
+
+    def rank(self, query: str, top_k: int = 5) -> list[tuple[str, float]]:
+        """Return ``top_k`` files ranked by similarity to ``query``."""
+        q_vec = self.vectorizer.transform([query]).toarray()[0]
+        scores = {
+            path: float(np.dot(vec, q_vec))
+            for path, vec in self.embeddings.items()
+        }
+        return sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_k]
+
+
+def build_semantic_index(base_dir: str | Path = ".") -> SemanticIndex:
+    """Return a :class:`SemanticIndex` for all project files under ``base_dir``."""
+    files = load_project_files(base_dir)
+    texts = [read_file(p) for p in files]
+    vectorizer = TfidfVectorizer()
+    matrix = vectorizer.fit_transform(texts)
+    embeddings = {
+        str(path): matrix[idx].toarray()[0]
+        for idx, path in enumerate(files)
+    }
+    return SemanticIndex(embeddings=embeddings, vectorizer=vectorizer)
+
+
+def get_file_embedding(path: str | Path, index: SemanticIndex) -> np.ndarray:
+    """Return the embedding vector for ``path`` using ``index``'s vectorizer."""
+    text = read_file(path)
+    return index.vectorizer.transform([text]).toarray()[0]
+
+
+def update_file_embedding(path: str | Path, index: SemanticIndex) -> None:
+    """Update ``index`` with the embedding for ``path``."""
+    index.update_file(path)
+
+
+def rank_files_by_query(query: str, index: SemanticIndex, top_k: int = 5) -> list[tuple[str, float]]:
+    """Rank indexed files by semantic similarity to ``query``."""
+    return index.rank(query, top_k=top_k)
diff --git a/tests/test_semantic_indexer.py b/tests/test_semantic_indexer.py
new file mode 100644
index 0000000..c23f6be
--- /dev/null
+++ b/tests/test_semantic_indexer.py
@@ -0,0 +1,43 @@
+from search.semantic_indexer import (
+    build_semantic_index,
+    get_file_embedding,
+    update_file_embedding,
+    rank_files_by_query,
+)
+import numpy as np
+
+
+def test_build_semantic_index(tmp_path):
+    (tmp_path / "a.txt").write_text("hello world")
+    (tmp_path / "b.py").write_text("print('hi')")
+    index = build_semantic_index(tmp_path)
+    assert set(index.embeddings.keys()) == {
+        str(tmp_path / "a.txt"),
+        str(tmp_path / "b.py"),
+    }
+    for emb in index.embeddings.values():
+        assert isinstance(emb, np.ndarray)
+        assert emb.ndim == 1
+
+
+def test_get_file_embedding(tmp_path):
+    (tmp_path / "a.txt").write_text("hello")
+    index = build_semantic_index(tmp_path)
+    emb1 = get_file_embedding(tmp_path / "a.txt", index)
+    emb2 = get_file_embedding(tmp_path / "a.txt", index)
+    assert np.allclose(emb1, emb2)
+
+
+def test_update_and_rank(tmp_path):
+    f1 = tmp_path / "hello.txt"
+    f2 = tmp_path / "bye.txt"
+    f1.write_text("hello world")
+    f2.write_text("goodbye world")
+    index = build_semantic_index(tmp_path)
+
+    # modify a file and update embedding
+    f1.write_text("hello there")
+    update_file_embedding(f1, index)
+
+    ranking = rank_files_by_query("hello", index, top_k=2)
+    assert ranking[0][0] == str(f1)