From 8b858e0f68386c6b29a206858db3e19c2a60b4f1 Mon Sep 17 00:00:00 2001 From: sky0walker99 Date: Thu, 23 Oct 2025 22:29:34 +0530 Subject: [PATCH 1/3] Update llm client --- pyproject.toml | 3 +- synthetic_data_kit/models/llm_client.py | 51 ++++++++++++++++++++++++- 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 639d14ef..77133791 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,8 @@ dependencies = [ "bootstrap-flask>=2.2.0", "beautifulsoup4>=4.12.0", "pylance", - "PyMuPDF" + "PyMuPDF", + "python-dotenv>=1.0.1", ] # These fields appear in pip show diff --git a/synthetic_data_kit/models/llm_client.py b/synthetic_data_kit/models/llm_client.py index 4f964100..7752865a 100644 --- a/synthetic_data_kit/models/llm_client.py +++ b/synthetic_data_kit/models/llm_client.py @@ -48,6 +48,8 @@ def __init__(self, max_retries: Override max retries from config retry_delay: Override retry delay from config """ + from dotenv import load_dotenv + load_dotenv() # Load config self.config = load_config(config_path) @@ -65,7 +67,7 @@ def __init__(self, self.api_base = api_base or api_endpoint_config.get('api_base') # Check for environment variables - api_endpoint_key = os.environ.get('API_ENDPOINT_KEY') + api_endpoint_key = os.getenv('GEMINI_API_KEY') print(f"API_ENDPOINT_KEY from environment: {'Found' if api_endpoint_key else 'Not found'}") # Set API key with priority: CLI arg > env var > config @@ -167,6 +169,48 @@ def _openai_chat_completion(self, debug_mode = os.environ.get('SDK_DEBUG', 'false').lower() == 'true' if verbose: logger.info(f"Sending request to {self.provider} model {self.model}...") + + if "gemini" in self.model: + url = f"https://generativelanguage.googleapis.com/v1beta/models/{self.model}:generateContent" + headers = { + "x-goog-api-key": self.api_key, + "Content-Type": "application/json", + } + + # Convert messages to Gemini format + contents = [] + for msg in messages: + # Gemini API does not support system role, so we convert it to user role + role = msg["role"] if msg["role"] != "system" else "user" + contents.append({ + "role": role, + "parts": [{"text": msg["content"]}] + }) + + data = { + "contents": contents, + "generationConfig": { + "temperature": temperature, + "maxOutputTokens": max_tokens, + "topP": top_p, + } + } + for attempt in range(self.max_retries): + try: + response = requests.post(url, headers=headers, data=json.dumps(data)) + response.raise_for_status() + response_json = response.json() + if "candidates" in response_json and len(response_json["candidates"]) > 0: + candidate = response_json["candidates"][0] + if "content" in candidate and "parts" in candidate["content"] and len(candidate["content"]["parts"]) > 0: + return candidate["content"]["parts"][0]["text"] + return "" # return empty string if no content found + except Exception as e: + if verbose: + logger.error(f"Gemini API error (attempt {attempt+1}/{self.max_retries}): {str(e)}") + if attempt == self.max_retries - 1: + raise Exception(f"Failed to get Gemini completion after {self.max_retries} attempts: {str(e)}") + time.sleep(self.retry_delay * (attempt + 1)) # Exponential backoff for attempt in range(self.max_retries): try: @@ -351,6 +395,11 @@ async def _process_message_async(self, verbose: bool, debug_mode: bool): """Process a single message set asynchronously using the OpenAI API""" + if "gemini" in self.model: + # This is a synchronous call inside an async function. + # Not ideal, but avoids adding new dependencies like httpx. + return self._openai_chat_completion(messages, temperature, max_tokens, top_p, verbose) + try: from openai import AsyncOpenAI except ImportError: From 5adfed0a9dedcc8430e6d12de71104a8d976767e Mon Sep 17 00:00:00 2001 From: sky0walker99 Date: Thu, 23 Oct 2025 22:44:13 +0530 Subject: [PATCH 2/3] Update api endpoint key --- synthetic_data_kit/models/llm_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synthetic_data_kit/models/llm_client.py b/synthetic_data_kit/models/llm_client.py index 7752865a..0cc60c3b 100644 --- a/synthetic_data_kit/models/llm_client.py +++ b/synthetic_data_kit/models/llm_client.py @@ -67,7 +67,7 @@ def __init__(self, self.api_base = api_base or api_endpoint_config.get('api_base') # Check for environment variables - api_endpoint_key = os.getenv('GEMINI_API_KEY') + api_endpoint_key = os.getenv('API_ENDPOINT_KEY') print(f"API_ENDPOINT_KEY from environment: {'Found' if api_endpoint_key else 'Not found'}") # Set API key with priority: CLI arg > env var > config From 7194c5bb25900f485937c8d21086e94a590c27ed Mon Sep 17 00:00:00 2001 From: Haroon <106879583+haroon0x@users.noreply.github.com> Date: Wed, 29 Oct 2025 19:39:24 +0530 Subject: [PATCH 3/3] Update lance_utils.py --- synthetic_data_kit/utils/lance_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/synthetic_data_kit/utils/lance_utils.py b/synthetic_data_kit/utils/lance_utils.py index ec94e037..3d7b5ee2 100644 --- a/synthetic_data_kit/utils/lance_utils.py +++ b/synthetic_data_kit/utils/lance_utils.py @@ -5,6 +5,8 @@ # the root directory of this source tree. import lance +from lance.dataset import write_dataset as lance_write_dataset +from lance.dataset import LanceDataset import pyarrow as pa from typing import List, Dict, Any, Optional import os @@ -30,7 +32,7 @@ def create_lance_dataset( os.makedirs(output_dir) table = pa.Table.from_pylist(data, schema=schema) - lance.write_dataset(table, output_path, mode="overwrite") + lance_write_dataset(table, output_path, mode="overwrite") def load_lance_dataset( dataset_path: str @@ -45,4 +47,4 @@ def load_lance_dataset( """ if not os.path.exists(dataset_path): return None - return lance.dataset(dataset_path) + return LanceDataset(dataset_path)