Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ dependencies = [
"bootstrap-flask>=2.2.0",
"beautifulsoup4>=4.12.0",
"pylance",
"PyMuPDF"
"PyMuPDF",
"python-dotenv>=1.0.1",
]

# These fields appear in pip show
Expand Down
51 changes: 50 additions & 1 deletion synthetic_data_kit/models/llm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def __init__(self,
max_retries: Override max retries from config
retry_delay: Override retry delay from config
"""
from dotenv import load_dotenv
load_dotenv()
# Load config
self.config = load_config(config_path)

Expand All @@ -65,7 +67,7 @@ def __init__(self,
self.api_base = api_base or api_endpoint_config.get('api_base')

# Check for environment variables
api_endpoint_key = os.environ.get('API_ENDPOINT_KEY')
api_endpoint_key = os.getenv('API_ENDPOINT_KEY')
print(f"API_ENDPOINT_KEY from environment: {'Found' if api_endpoint_key else 'Not found'}")

# Set API key with priority: CLI arg > env var > config
Expand Down Expand Up @@ -167,6 +169,48 @@ def _openai_chat_completion(self,
debug_mode = os.environ.get('SDK_DEBUG', 'false').lower() == 'true'
if verbose:
logger.info(f"Sending request to {self.provider} model {self.model}...")

if "gemini" in self.model:
url = f"https://generativelanguage.googleapis.com/v1beta/models/{self.model}:generateContent"
headers = {
"x-goog-api-key": self.api_key,
"Content-Type": "application/json",
}

# Convert messages to Gemini format
contents = []
for msg in messages:
# Gemini API does not support system role, so we convert it to user role
role = msg["role"] if msg["role"] != "system" else "user"
contents.append({
"role": role,
"parts": [{"text": msg["content"]}]
})

data = {
"contents": contents,
"generationConfig": {
"temperature": temperature,
"maxOutputTokens": max_tokens,
"topP": top_p,
}
}
for attempt in range(self.max_retries):
try:
response = requests.post(url, headers=headers, data=json.dumps(data))
response.raise_for_status()
response_json = response.json()
if "candidates" in response_json and len(response_json["candidates"]) > 0:
candidate = response_json["candidates"][0]
if "content" in candidate and "parts" in candidate["content"] and len(candidate["content"]["parts"]) > 0:
return candidate["content"]["parts"][0]["text"]
return "" # return empty string if no content found
except Exception as e:
if verbose:
logger.error(f"Gemini API error (attempt {attempt+1}/{self.max_retries}): {str(e)}")
if attempt == self.max_retries - 1:
raise Exception(f"Failed to get Gemini completion after {self.max_retries} attempts: {str(e)}")
                    time.sleep(self.retry_delay * (attempt + 1))  # Linear backoff (delay grows by retry_delay each attempt)

for attempt in range(self.max_retries):
try:
Expand Down Expand Up @@ -351,6 +395,11 @@ async def _process_message_async(self,
verbose: bool,
debug_mode: bool):
"""Process a single message set asynchronously using the OpenAI API"""
if "gemini" in self.model:
# This is a synchronous call inside an async function.
# Not ideal, but avoids adding new dependencies like httpx.
return self._openai_chat_completion(messages, temperature, max_tokens, top_p, verbose)

try:
from openai import AsyncOpenAI
except ImportError:
Expand Down
6 changes: 4 additions & 2 deletions synthetic_data_kit/utils/lance_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# the root directory of this source tree.

import lance
from lance.dataset import write_dataset as lance_write_dataset
from lance.dataset import LanceDataset
import pyarrow as pa
from typing import List, Dict, Any, Optional
import os
Expand All @@ -30,7 +32,7 @@ def create_lance_dataset(
os.makedirs(output_dir)

table = pa.Table.from_pylist(data, schema=schema)
lance.write_dataset(table, output_path, mode="overwrite")
lance_write_dataset(table, output_path, mode="overwrite")

def load_lance_dataset(
dataset_path: str
Expand All @@ -45,4 +47,4 @@ def load_lance_dataset(
"""
if not os.path.exists(dataset_path):
return None
return lance.dataset(dataset_path)
return LanceDataset(dataset_path)