diff --git a/pyproject.toml b/pyproject.toml index 8e26f57..1c45d8f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "uvicorn[standard]", "httpx", "tiktoken", + "requests", ] [project.optional-dependencies] diff --git a/requirements.txt b/requirements.txt index 7d25c9d..cb79ce4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ httpx fastapi uvicorn[standard] tiktoken +requests diff --git a/skillclaw/config.py b/skillclaw/config.py index de70d0f..40e9276 100644 --- a/skillclaw/config.py +++ b/skillclaw/config.py @@ -32,7 +32,13 @@ class SkillClawConfig: skills_dir: str = "memory_data/skills" skills_public_root: str = "" retrieval_mode: str = "template" + # Embedding configuration: "local" for SentenceTransformer, "api" for OpenAI-compatible APIs + embedding_type: str = "local" embedding_model_path: str = "Qwen/Qwen3-Embedding-0.6B" + # OpenAI-compatible embedding API configuration + embedding_api_url: str = "" # e.g., "https://api.openai.com/v1" or "https://api.jina.ai/v1" + embedding_api_model: str = "" # e.g., "text-embedding-3-small" or "jina-embeddings-v5-text-small" + embedding_api_key: str = "" skill_top_k: int = 6 max_skills_prompt_chars: int = 30000 diff --git a/skillclaw/embedding_api_client.py b/skillclaw/embedding_api_client.py new file mode 100644 index 0000000..5f82dc3 --- /dev/null +++ b/skillclaw/embedding_api_client.py @@ -0,0 +1,143 @@ +""" +Embedding API client supporting OpenAI-compatible APIs. + +Supports any embedding service with OpenAI API format, including: +- OpenAI (https://api.openai.com/v1/embeddings) +- Jina (https://api.jina.ai/v1/embeddings) +- Azure OpenAI +- LocalAI +- Ollama (with OpenAI-compatible server) +""" + +import logging +import numpy as np +import requests +from typing import List, Optional + +logger = logging.getLogger(__name__) + + +class EmbeddingAPIClient: + """Client for OpenAI-compatible embedding APIs.""" + + def __init__( + self, + api_url: str, + model: str, + api_key: Optional[str] = None, + timeout: int = 30, + ): + """Initialize embedding API client. + + Args: + api_url: Base URL of the embedding API (e.g., "https://api.openai.com/v1") + model: Model name to use for embeddings + api_key: API key for authentication (optional for local services) + timeout: Request timeout in seconds + """ + self.api_url = api_url.rstrip("/") + self.model = model + self.api_key = api_key + self.timeout = timeout + self._session = None + + @property + def session(self): + """Lazy-load requests session.""" + if self._session is None: + self._session = requests.Session() + return self._session + + def encode( + self, + texts: List[str], + normalize_embeddings: bool = True, + show_progress_bar: bool = False, + convert_to_numpy: bool = True, + ) -> np.ndarray: + """Encode texts into embeddings using the API. + + Args: + texts: List of text strings to encode + normalize_embeddings: Whether to normalize embeddings (L2) + show_progress_bar: Whether to show progress bar (ignored for API) + convert_to_numpy: Whether to return numpy array + + Returns: + numpy array of shape (len(texts), embedding_dim) + """ + if show_progress_bar: + logger.warning( + "show_progress_bar parameter is not supported for embedding API client and will be ignored" + ) + + if not texts: + return np.zeros((0, 0), dtype=np.float32) + + embeddings = self._call_api(texts) + + if normalize_embeddings: + # L2 normalization + norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + norms[norms == 0] = 1 # Avoid division by zero + embeddings = embeddings / norms + + if convert_to_numpy: + return embeddings.astype(np.float32) + return embeddings + + def _call_api(self, texts: List[str]) -> np.ndarray: + """Call the embedding API and return embeddings. + + Args: + texts: List of text strings to encode + + Returns: + numpy array of embeddings + """ + headers = { + "Content-Type": "application/json", + } + + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + + payload = { + "model": self.model, + "input": texts, + } + + try: + response = self.session.post( + f"{self.api_url}/embeddings", + json=payload, + headers=headers, + timeout=self.timeout, + ) + response.raise_for_status() + except requests.exceptions.RequestException as e: + logger.error(f"Embedding API request failed: {e}") + raise + + data = response.json() + + # Extract embeddings from response + # OpenAI format: {"data": [{"embedding": [...], "index": 0}, ...]} + if "data" not in data: + raise ValueError(f"Unexpected API response format: {data}") + + embeddings_list = sorted(data["data"], key=lambda x: x.get("index", 0)) + embeddings = np.array( + [item["embedding"] for item in embeddings_list], + dtype=np.float32, + ) + + logger.debug( + f"Retrieved {len(embeddings)} embeddings with dimension {embeddings.shape[1]}" + ) + return embeddings + + def __del__(self): + """Close session when client is destroyed.""" + if self._session is not None: + self._session.close() diff --git a/skillclaw/launcher.py b/skillclaw/launcher.py index 958731b..4bc94c7 100644 --- a/skillclaw/launcher.py +++ b/skillclaw/launcher.py @@ -75,6 +75,10 @@ async def _run(self, cfg): public_skill_root=cfg.skills_public_root, retrieval_mode=cfg.retrieval_mode, embedding_model_path=cfg.embedding_model_path, + embedding_type=cfg.embedding_type, + embedding_api_url=cfg.embedding_api_url, + embedding_api_model=cfg.embedding_api_model, + embedding_api_key=cfg.embedding_api_key, ) logger.info("[Launcher] SkillManager loaded: %s skills", skill_manager.get_skill_count()) diff --git a/skillclaw/skill_manager.py b/skillclaw/skill_manager.py index 55547a2..2acbfd9 100644 --- a/skillclaw/skill_manager.py +++ b/skillclaw/skill_manager.py @@ -174,6 +174,10 @@ def __init__( public_skill_root: str = "", retrieval_mode: str = "template", embedding_model_path: Optional[str] = None, + embedding_type: str = "local", + embedding_api_url: Optional[str] = None, + embedding_api_model: Optional[str] = None, + embedding_api_key: Optional[str] = None, ): if retrieval_mode not in ("template", "embedding"): raise ValueError(f"retrieval_mode must be 'template' or 'embedding', got '{retrieval_mode}'") @@ -183,7 +187,11 @@ def __init__( self._skills_dir = skills_dir self._public_skill_root = public_skill_root.strip() self.retrieval_mode = retrieval_mode + self.embedding_type = embedding_type self.embedding_model_path = embedding_model_path or "Qwen/Qwen3-Embedding-0.6B" + self.embedding_api_url = embedding_api_url + self.embedding_api_model = embedding_api_model + self.embedding_api_key = embedding_api_key self._embedding_model = None self._skill_embeddings_cache: Optional[Dict] = None @@ -199,10 +207,11 @@ def __init__( counts = self._category_counts() logger.info( - "[SkillManager] loaded %d skills from %s | mode=%s | categories=%s", + "[SkillManager] loaded %d skills from %s | mode=%s | embedding_type=%s | categories=%s", len(self.skills.get("all_skills", [])), skills_dir, retrieval_mode, + embedding_type, dict(counts), ) @@ -372,16 +381,37 @@ def refresh_if_changed(self) -> bool: # ------------------------------------------------------------------ # def _get_embedding_model(self): + """Get embedding model - either local SentenceTransformer or API client.""" if self._embedding_model is None: - try: - from sentence_transformers import SentenceTransformer - except ImportError: - raise ImportError( - "sentence-transformers is required for embedding retrieval. " - "Install with: pip install sentence-transformers" + if self.embedding_type == "api": + # Use OpenAI-compatible API + if not self.embedding_api_url or not self.embedding_api_model: + raise ValueError( + "embedding_api_url and embedding_api_model must be set when embedding_type='api'" + ) + from .embedding_api_client import EmbeddingAPIClient + + logger.info( + "[SkillManager] using embedding API: %s (model: %s)", + self.embedding_api_url, + self.embedding_api_model, + ) + self._embedding_model = EmbeddingAPIClient( + api_url=self.embedding_api_url, + model=self.embedding_api_model, + api_key=self.embedding_api_key, ) - logger.info("[SkillManager] loading embedding model: %s", self.embedding_model_path) - self._embedding_model = SentenceTransformer(self.embedding_model_path) + else: + # Use local SentenceTransformer model + try: + from sentence_transformers import SentenceTransformer + except ImportError: + raise ImportError( + "sentence-transformers is required for local embedding. " + "Install with: pip install sentence-transformers" + ) + logger.info("[SkillManager] loading embedding model: %s", self.embedding_model_path) + self._embedding_model = SentenceTransformer(self.embedding_model_path) return self._embedding_model @staticmethod diff --git a/tests/test_embedding_api.py b/tests/test_embedding_api.py new file mode 100644 index 0000000..1e64368 --- /dev/null +++ b/tests/test_embedding_api.py @@ -0,0 +1,272 @@ +""" +Unit tests for embedding API client. + +Run with: pytest tests/test_embedding_api.py +""" + +import json +import numpy as np +import pytest +import responses +from unittest.mock import MagicMock, patch + +from skillclaw.embedding_api_client import EmbeddingAPIClient + + +class TestEmbeddingAPIClient: + """Test cases for EmbeddingAPIClient.""" + + @responses.activate + def test_encode_basic(self): + """Test basic encoding with mock API.""" + # Mock API response in OpenAI format + responses.add( + responses.POST, + "https://api.example.com/embeddings", + json={ + "data": [ + {"embedding": [0.1, 0.2, 0.3], "index": 0}, + {"embedding": [0.4, 0.5, 0.6], "index": 1}, + ] + }, + status=200, + ) + + client = EmbeddingAPIClient( + api_url="https://api.example.com", + model="test-model", + api_key="test-key", + ) + + texts = ["Hello world", "Test text"] + embeddings = client.encode(texts, normalize_embeddings=False) + + assert embeddings.shape == (2, 3) + assert np.allclose(embeddings[0], [0.1, 0.2, 0.3]) + assert np.allclose(embeddings[1], [0.4, 0.5, 0.6]) + + @responses.activate + def test_encode_with_normalization(self): + """Test L2 normalization of embeddings.""" + responses.add( + responses.POST, + "https://api.example.com/embeddings", + json={ + "data": [ + {"embedding": [3.0, 4.0], "index": 0}, # magnitude = 5 + ] + }, + status=200, + ) + + client = EmbeddingAPIClient( + api_url="https://api.example.com", + model="test-model", + ) + + embeddings = client.encode(["test"], normalize_embeddings=True) + + # After normalization: [3/5, 4/5] = [0.6, 0.8] + assert np.allclose(embeddings[0], [0.6, 0.8], atol=1e-6) + # Check L2 norm is 1 + assert np.allclose(np.linalg.norm(embeddings[0]), 1.0) + + @responses.activate + def test_encode_empty_input(self): + """Test encoding with empty input.""" + client = EmbeddingAPIClient( + api_url="https://api.example.com", + model="test-model", + ) + + embeddings = client.encode([]) + + assert embeddings.shape == (0, 0) + + @responses.activate + def test_encode_with_authorization(self): + """Test that API key is correctly included in request.""" + responses.add( + responses.POST, + "https://api.example.com/embeddings", + json={"data": [{"embedding": [0.1, 0.2], "index": 0}]}, + status=200, + ) + + client = EmbeddingAPIClient( + api_url="https://api.example.com", + model="test-model", + api_key="secret-key-123", + ) + + embeddings = client.encode(["test"]) + + # Check that request includes Authorization header + assert len(responses.calls) == 1 + assert responses.calls[0].request.headers["Authorization"] == "Bearer secret-key-123" + + @responses.activate + def test_encode_without_api_key(self): + """Test encoding without API key (for local services).""" + responses.add( + responses.POST, + "https://api.example.com/embeddings", + json={"data": [{"embedding": [0.1, 0.2], "index": 0}]}, + status=200, + ) + + client = EmbeddingAPIClient( + api_url="https://api.example.com", + model="test-model", + api_key=None, + ) + + embeddings = client.encode(["test"]) + + # Check that request does not include Authorization header + assert len(responses.calls) == 1 + assert "Authorization" not in responses.calls[0].request.headers + + @responses.activate + def test_encode_out_of_order_responses(self): + """Test that embeddings are correctly sorted by index.""" + responses.add( + responses.POST, + "https://api.example.com/embeddings", + json={ + "data": [ + {"embedding": [0.3, 0.3], "index": 2}, + {"embedding": [0.1, 0.1], "index": 0}, + {"embedding": [0.2, 0.2], "index": 1}, + ] + }, + status=200, + ) + + client = EmbeddingAPIClient( + api_url="https://api.example.com", + model="test-model", + ) + + embeddings = client.encode(["a", "b", "c"], normalize_embeddings=False) + + # Should be sorted by index + assert np.allclose(embeddings[0], [0.1, 0.1]) + assert np.allclose(embeddings[1], [0.2, 0.2]) + assert np.allclose(embeddings[2], [0.3, 0.3]) + + @responses.activate + def test_api_error_handling(self): + """Test error handling for API failures.""" + responses.add( + responses.POST, + "https://api.example.com/embeddings", + json={"error": "Invalid API key"}, + status=401, + ) + + client = EmbeddingAPIClient( + api_url="https://api.example.com", + model="test-model", + api_key="invalid-key", + ) + + with pytest.raises(Exception): # Should raise RequestException + client.encode(["test"]) + + @responses.activate + def test_invalid_response_format(self): + """Test error handling for invalid response format.""" + responses.add( + responses.POST, + "https://api.example.com/embeddings", + json={"invalid": "response"}, + status=200, + ) + + client = EmbeddingAPIClient( + api_url="https://api.example.com", + model="test-model", + ) + + with pytest.raises(ValueError, match="Unexpected API response format"): + client.encode(["test"]) + + @responses.activate + def test_large_batch(self): + """Test encoding a large batch of texts.""" + num_texts = 100 + embedding_dim = 384 + + # Generate mock embeddings + embeddings_data = [ + {"embedding": list(np.random.rand(embedding_dim).astype(float)), "index": i} + for i in range(num_texts) + ] + + responses.add( + responses.POST, + "https://api.example.com/embeddings", + json={"data": embeddings_data}, + status=200, + ) + + client = EmbeddingAPIClient( + api_url="https://api.example.com", + model="test-model", + ) + + texts = [f"text_{i}" for i in range(num_texts)] + embeddings = client.encode(texts) + + assert embeddings.shape == (num_texts, embedding_dim) + assert embeddings.dtype == np.float32 + + +class TestEmbeddingAPIIntegration: + """Integration tests with SkillManager.""" + + @pytest.mark.skipif(True, reason="Requires actual API key") + def test_skill_manager_with_api(self): + """Test SkillManager integration with embedding API. + + This test is skipped by default as it requires a real API key. + Set embedding API credentials and remove @skipif to run. + """ + from skillclaw.skill_manager import SkillManager + from pathlib import Path + import tempfile + + # Create temporary skill directory + with tempfile.TemporaryDirectory() as tmpdir: + skills_dir = Path(tmpdir) + + # Create sample skills + for skill_name in ["test-skill-1", "test-skill-2"]: + skill_path = skills_dir / skill_name + skill_path.mkdir() + (skill_path / "SKILL.md").write_text( + f"""--- +name: {skill_name} +description: Test skill for {skill_name} +--- + +# {skill_name} + +Test content +""" + ) + + # Initialize with API + skill_manager = SkillManager( + skills_dir=str(skills_dir), + retrieval_mode="embedding", + embedding_type="api", + embedding_api_url="https://api.jina.ai/v1", + embedding_api_model="jina-embeddings-v5-text-small", + embedding_api_key="your-api-key", + ) + + # Test retrieval + results = skill_manager.retrieve("test query", top_k=2) + assert len(results) <= 2