diff --git a/pyproject.toml b/pyproject.toml index 639d14ef..77133791 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,8 @@ dependencies = [ "bootstrap-flask>=2.2.0", "beautifulsoup4>=4.12.0", "pylance", - "PyMuPDF" + "PyMuPDF", + "python-dotenv>=1.0.1", ] # These fields appear in pip show diff --git a/synthetic_data_kit/models/llm_client.py b/synthetic_data_kit/models/llm_client.py index 4f964100..0cc60c3b 100644 --- a/synthetic_data_kit/models/llm_client.py +++ b/synthetic_data_kit/models/llm_client.py @@ -48,6 +48,8 @@ def __init__(self, max_retries: Override max retries from config retry_delay: Override retry delay from config """ + from dotenv import load_dotenv + load_dotenv() # Load config self.config = load_config(config_path) @@ -65,7 +67,7 @@ def __init__(self, self.api_base = api_base or api_endpoint_config.get('api_base') # Check for environment variables - api_endpoint_key = os.environ.get('API_ENDPOINT_KEY') + api_endpoint_key = os.getenv('API_ENDPOINT_KEY') print(f"API_ENDPOINT_KEY from environment: {'Found' if api_endpoint_key else 'Not found'}") # Set API key with priority: CLI arg > env var > config @@ -167,6 +169,48 @@ def _openai_chat_completion(self, debug_mode = os.environ.get('SDK_DEBUG', 'false').lower() == 'true' if verbose: logger.info(f"Sending request to {self.provider} model {self.model}...") + + if "gemini" in self.model: + url = f"https://generativelanguage.googleapis.com/v1beta/models/{self.model}:generateContent" + headers = { + "x-goog-api-key": self.api_key, + "Content-Type": "application/json", + } + + # Convert messages to Gemini format + contents = [] + for msg in messages: + # Gemini API does not support system role, so we convert it to user role + role = msg["role"] if msg["role"] != "system" else "user" + contents.append({ + "role": role, + "parts": [{"text": msg["content"]}] + }) + + data = { + "contents": contents, + "generationConfig": { + "temperature": temperature, + "maxOutputTokens": max_tokens, + "topP": 
top_p, + } + } + for attempt in range(self.max_retries): + try: + response = requests.post(url, headers=headers, data=json.dumps(data)) + response.raise_for_status() + response_json = response.json() + if "candidates" in response_json and len(response_json["candidates"]) > 0: + candidate = response_json["candidates"][0] + if "content" in candidate and "parts" in candidate["content"] and len(candidate["content"]["parts"]) > 0: + return candidate["content"]["parts"][0]["text"] + return "" # return empty string if no content found + except Exception as e: + if verbose: + logger.error(f"Gemini API error (attempt {attempt+1}/{self.max_retries}): {str(e)}") + if attempt == self.max_retries - 1: + raise Exception(f"Failed to get Gemini completion after {self.max_retries} attempts: {str(e)}") + time.sleep(self.retry_delay * (attempt + 1)) # Linear backoff: delay grows linearly with attempt count for attempt in range(self.max_retries): try: @@ -351,6 +395,11 @@ async def _process_message_async(self, verbose: bool, debug_mode: bool): """Process a single message set asynchronously using the OpenAI API""" + if "gemini" in self.model: + # This is a synchronous call inside an async function. + # Not ideal, but avoids adding new dependencies like httpx. + return self._openai_chat_completion(messages, temperature, max_tokens, top_p, verbose) + try: from openai import AsyncOpenAI except ImportError: diff --git a/synthetic_data_kit/utils/lance_utils.py b/synthetic_data_kit/utils/lance_utils.py index ec94e037..3d7b5ee2 100644 --- a/synthetic_data_kit/utils/lance_utils.py +++ b/synthetic_data_kit/utils/lance_utils.py @@ -5,6 +5,8 @@ # the root directory of this source tree.
import lance +from lance.dataset import write_dataset as lance_write_dataset +from lance.dataset import LanceDataset import pyarrow as pa from typing import List, Dict, Any, Optional import os @@ -30,7 +32,7 @@ def create_lance_dataset( os.makedirs(output_dir) table = pa.Table.from_pylist(data, schema=schema) - lance.write_dataset(table, output_path, mode="overwrite") + lance_write_dataset(table, output_path, mode="overwrite") def load_lance_dataset( dataset_path: str @@ -45,4 +47,4 @@ def load_lance_dataset( """ if not os.path.exists(dataset_path): return None - return lance.dataset(dataset_path) + return LanceDataset(dataset_path)