diff --git a/.gitignore b/.gitignore index 64c1b658..a55c13e6 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,6 @@ __pycache__ data/\ndata/\n*.pdf .venv-ci/ data/ -example_output \ No newline at end of file +example_output!/docs/Why_lang_models_hallucinate.pdf +!/docs/Why_lang_models_hallucinate.pdf +vllm/ \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 00000000..13566b81 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/copilot.data.migration.agent.xml b/.idea/copilot.data.migration.agent.xml new file mode 100644 index 00000000..4ea72a91 --- /dev/null +++ b/.idea/copilot.data.migration.agent.xml @@ -0,0 +1,6 @@ + + + + + \ No newline at end of file diff --git a/.idea/copilot.data.migration.ask.xml b/.idea/copilot.data.migration.ask.xml new file mode 100644 index 00000000..7ef04e2e --- /dev/null +++ b/.idea/copilot.data.migration.ask.xml @@ -0,0 +1,6 @@ + + + + + \ No newline at end of file diff --git a/.idea/copilot.data.migration.ask2agent.xml b/.idea/copilot.data.migration.ask2agent.xml new file mode 100644 index 00000000..1f2ea11e --- /dev/null +++ b/.idea/copilot.data.migration.ask2agent.xml @@ -0,0 +1,6 @@ + + + + + \ No newline at end of file diff --git a/.idea/copilot.data.migration.edit.xml b/.idea/copilot.data.migration.edit.xml new file mode 100644 index 00000000..8648f940 --- /dev/null +++ b/.idea/copilot.data.migration.edit.xml @@ -0,0 +1,6 @@ + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 00000000..105ce2da --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 00000000..6b24166b --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,8 @@ + + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 00000000..c4a72400 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/synthetic-data-kit.iml b/.idea/synthetic-data-kit.iml new file mode 100644 index 00000000..5e667b78 --- /dev/null +++ b/.idea/synthetic-data-kit.iml @@ -0,0 +1,27 @@ + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 00000000..03dcdbea --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/README.md b/README.md index 4fe198ed..824f355a 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,17 @@ mkdir -p data/{input,parsed,generated,curated,final} mkdir -p data/{pdf,html,youtube,docx,ppt,txt,output,generated,cleaned,final} ``` -- You also need a LLM backend that you will utilize for generating your dataset, if using vLLM: +- You also need an LLM backend that you will use for generating your dataset: +- If using Ollama: +```bash +# Download from https://ollama.com/download +# Get the Llama model: + ollama pull llama3 +# Run Ollama + ollama serve +# The server is now running at http://localhost:11434 +``` +- If using vLLM: ```bash # Start vLLM server diff --git a/configs/config.yaml b/configs/config.yaml index b105ac75..bc954267 100644 ---
a/configs/config.yaml +++ b/configs/config.yaml @@ -14,8 +14,11 @@ paths: # LLM Provider configuration llm: + provider: "ollama" #Using api-endpoint for Ollama # Provider selection: "vllm" or "api-endpoint" - provider: "api-endpoint" + # provider: "api-endpoint" #Using api-endpoint for Llama API + + # VLLM server configuration vllm: @@ -35,6 +38,15 @@ api-endpoint: retry_delay: 1.0 # Initial delay between retries (seconds) sleep_time: 0.5 # Small delay in seconds between batches to avoid rate limits +# Ollama server configuration (for Ollama via OpenAI-compatible API) +ollama: + api_base: "http://localhost:11434/v1" # Ollama's OpenAI-compatible endpoint + api_key: "not-needed" # Ollama doesn't require an API key + model: "llama3:latest" # Your Ollama model + max_retries: 3 # Number of retries for API calls + retry_delay: 1.0 # Initial delay between retries (seconds) + + # Ingest configuration ingest: default_format: "txt" # Default output format for parsed files @@ -67,6 +79,39 @@ format: include_metadata: true # Include metadata in output files pretty_json: true # Use indentation in JSON output +# Provider-specific settings for different use cases +provider_configs: + # For local development and testing (using Ollama via api-endpoint) + local_dev: + provider: "api-endpoint" + api_base: "http://localhost:11434/v1" + api_key: "not-needed" + model: "llama3:latest" + generation: + temperature: 0.8 + max_tokens: 2048 + batch_size: 5 # Smaller batch for local resources + + # For production with Llama API + production: + provider: "api-endpoint" + api_base: "https://api.llama.com/v1" + api_key: "llama_api_key" + model: "Llama-4-Maverick-17B-128E-Instruct-FP8" + generation: + temperature: 0.7 + max_tokens: 4096 + batch_size: 32 + + # For high-performance local inference + local_vllm: + provider: "vllm" + model: "meta-llama/Llama-3.3-70B-Instruct" + generation: + temperature: 0.7 + max_tokens: 4096 + batch_size: 64 + # Prompts for different tasks prompts: # Summary generation prompt @@ -173,3 +218,11 @@ prompts: Original conversations: {conversations} + +# Environment variable mappings (optional) +env_vars: + API_ENDPOINT_KEY: "not-needed" # For api-endpoint provider (Ollama doesn't need a key) + OLLAMA_HOST: "http://localhost:11434" # For reference + VLLM_HOST: "http://localhost:8000" # For vllm provider +# SDK_VERBOSE: 'True' + diff --git a/docs/Why_lang_models_hallucinate.pdf b/docs/Why_lang_models_hallucinate.pdf new file mode 100644 index 00000000..d5c99e41 Binary files /dev/null and b/docs/Why_lang_models_hallucinate.pdf differ diff --git a/synthetic_data_kit/cli.py b/synthetic_data_kit/cli.py index e88dbb59..e31d045c 100644 --- a/synthetic_data_kit/cli.py +++ b/synthetic_data_kit/cli.py @@ -13,7 +13,7 @@ from rich.console import Console from rich.table import Table -from synthetic_data_kit.utils.config import load_config, get_vllm_config, get_openai_config, get_llm_provider, get_path_config +from synthetic_data_kit.utils.config import load_config, get_vllm_config, get_openai_config, get_llm_provider, get_path_config, get_ollama_config from synthetic_data_kit.core.context import AppContext from synthetic_data_kit.server.app import run_server @@ -338,6 +338,14 @@ def create( api_base = api_base or api_endpoint_config.get("api_base") model = model or api_endpoint_config.get("model") # No server check needed for API endpoint + + if provider == "ollama": + # Use Ollama config + ollama_config = get_ollama_config(ctx.config) + api_base = api_base or ollama_config.get("api_base") + model = model or 
ollama_config.get("model") + # No server check needed for Ollama endpoint + else: # Use vLLM config vllm_config = get_vllm_config(ctx.config) @@ -498,6 +506,14 @@ def curate( api_base = api_base or api_endpoint_config.get("api_base") model = model or api_endpoint_config.get("model") # No server check needed for API endpoint + + if provider == "ollama": + # Use Ollama config + ollama_config = get_ollama_config(ctx.config) + api_base = api_base or ollama_config.get("api_base") + model = model or ollama_config.get("model") + # No server check needed for Ollama endpoint + else: # Use vLLM config vllm_config = get_vllm_config(ctx.config) diff --git a/synthetic_data_kit/config.yaml b/synthetic_data_kit/config.yaml index 2690697d..5df1f59b 100644 --- a/synthetic_data_kit/config.yaml +++ b/synthetic_data_kit/config.yaml @@ -15,7 +15,8 @@ paths: # LLM Provider configuration llm: # Provider selection: "vllm" or "api-endpoint" - provider: "api-endpoint" + # provider: "api-endpoint" + provider: "ollama" # VLLM server configuration vllm: @@ -24,7 +25,15 @@ vllm: model: "meta-llama/Llama-3.3-70B-Instruct" # Default model to use max_retries: 3 # Number of retries for API calls retry_delay: 1.0 # Initial delay between retries (seconds) - + +#ollama server configuration +ollama: + api_base: "http://localhost:11434/v1" # Ollama's OpenAI-compatible endpoint + api_key: "not-needed" # Ollama doesn't require an API key + model: "llama3:latest" # Your Ollama model + max_retries: 3 # Number of retries for API calls + retry_delay: 1.0 # Initial delay between retries (seconds) + # API endpoint configuration api-endpoint: api_base: "https://api.llama.com/v1" # Optional base URL for API endpoint (null for default API) diff --git a/synthetic_data_kit/core/context.py b/synthetic_data_kit/core/context.py index 9642aa3a..52b2b2e3 100644 --- a/synthetic_data_kit/core/context.py +++ b/synthetic_data_kit/core/context.py @@ -21,8 +21,8 @@ def __init__(self, config_path: Optional[Path] = None): # Ensure data directories exist self._ensure_data_dirs() - # Why have separeate folders? Yes ideally you should just be able to ingest an input folder and have everything being ingested and converted BUT - # Managing context window is hard and there are more edge cases which needs to be handled carefully + # Why have separate folders? Yes, ideally, you should just be able to ingest an input folder and have everything being ingested and converted BUT + # Managing context window is hard and there are more edge cases that need to be handled carefully # it's also easier to debug in alpha if we have multiple files. 
def _ensure_data_dirs(self): """Ensure data directories exist based on configuration""" @@ -30,7 +30,7 @@ def _ensure_data_dirs(self): config = load_config(self.config_path) paths_config = config.get('paths', {}) - # Create input directory - handle new config format where input is a string + # Create input directory - handle a new config format where input is a string input_dir = paths_config.get('input', 'data/input') os.makedirs(input_dir, exist_ok=True) diff --git a/synthetic_data_kit/core/create.py b/synthetic_data_kit/core/create.py index 9c2dd78b..9ec734b6 100644 --- a/synthetic_data_kit/core/create.py +++ b/synthetic_data_kit/core/create.py @@ -33,7 +33,7 @@ def process_file( model: Optional[str] = None, content_type: str = "qa", num_pairs: Optional[int] = None, - verbose: bool = False, + verbose: bool = True, provider: Optional[str] = None, chunk_size: Optional[int] = None, chunk_overlap: Optional[int] = None, @@ -42,9 +42,14 @@ def process_file( """Process a file to generate content Args: + provider: llm provider to use + chunk_size: size of text chunks for processing + chunk_overlap: overlap between text chunks + rolling_summary: use rolling summary for context + verbose: determine if extra logging is needed file_path: Path to the text file to process output_dir: Directory to save generated content - config_path: Path to configuration file + config_path: Path to a configuration file api_base: VLLM API base URL model: Model to use content_type: Type of content to generate (qa, summary, cot) @@ -54,7 +59,7 @@ def process_file( Returns: Path to the output file """ - # Create output directory if it doesn't exist + # Create an output directory if it doesn't exist # The reason for having this directory logic for now is explained in context.py os.makedirs(output_dir, exist_ok=True) @@ -86,6 +91,7 @@ def process_file( documents = [{"text": read_json(file_path), "image": None}] if content_type == "qa": + print("Generating QA pairs...") generator = QAGenerator(client, config_path) # Get num_pairs from args or config @@ -93,7 +99,6 @@ def process_file( config = client.config generation_config = get_generation_config(config) num_pairs = generation_config.get("num_pairs", 25) - # Process document result = generator.process_documents( documents, diff --git a/synthetic_data_kit/core/curate.py b/synthetic_data_kit/core/curate.py index ec83ee7a..47ac99c1 100644 --- a/synthetic_data_kit/core/curate.py +++ b/synthetic_data_kit/core/curate.py @@ -28,6 +28,7 @@ def curate_qa_pairs( """Clean and filter QA pairs based on quality ratings Args: + provider: llm provider to use eg, vllm, ollama, etc. 
input_path: Path to the input file with QA pairs output_path: Path to save the cleaned output threshold: Quality threshold (1-10) @@ -117,7 +118,7 @@ def curate_qa_pairs( # This avoids conflicts with other output messages print(f"Processing {len(batches)} batches of QA pairs...") - # Only use detailed progress bar in verbose mode + # Only use a detailed progress bar in verbose mode if verbose: from rich.progress import Progress, BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn diff --git a/synthetic_data_kit/core/ingest.py b/synthetic_data_kit/core/ingest.py index aba75c15..d50bd457 100644 --- a/synthetic_data_kit/core/ingest.py +++ b/synthetic_data_kit/core/ingest.py @@ -56,7 +56,7 @@ def determine_parser(file_path: str, config: Dict[str, Any], multimodal: bool = # Check if it's a URL if file_path.startswith(("http://", "https://")): # YouTube URL - if "youtube.com" in file_path or "youtu.be" in file_path: + if "youtube.com" in file_path or "youtube" in file_path: return YouTubeParser() # PDF URL elif _check_pdf_url(file_path): diff --git a/synthetic_data_kit/generators/qa_generator.py b/synthetic_data_kit/generators/qa_generator.py index e892cdd2..8bbb3cf6 100644 --- a/synthetic_data_kit/generators/qa_generator.py +++ b/synthetic_data_kit/generators/qa_generator.py @@ -50,7 +50,7 @@ def generate_summary(self, chunks = split_into_chunks(document_text, chunk_size=max_context_length, overlap=summary_overlap) - + # print(f"Document split into {len(chunks)} chunks") for chunk in chunks: messages = [ {"role": "system", "content": prompt}, @@ -71,12 +71,12 @@ def generate_summary(self, {"role": "system", "content": prompt}, {"role": "user", "content": document_text[0:max_context_length]} ] - + # print(self.client.config) summary = self.client.chat_completion( messages, temperature=0.1 # Use lower temperature for summaries ) - + if verbose: print(f"Summary generated ({len(summary)} chars)") return summary diff --git a/synthetic_data_kit/models/llm_client.py b/synthetic_data_kit/models/llm_client.py index 4f964100..100670f8 100644 --- a/synthetic_data_kit/models/llm_client.py +++ b/synthetic_data_kit/models/llm_client.py @@ -13,7 +13,7 @@ import asyncio from pathlib import Path -from synthetic_data_kit.utils.config import load_config, get_vllm_config, get_openai_config, get_llm_provider +from synthetic_data_kit.utils.config import load_config, get_vllm_config, get_openai_config, get_llm_provider, get_ollama_config # Set up logging logging.basicConfig(level=logging.INFO) @@ -82,6 +82,24 @@ def __init__(self, # Initialize OpenAI client self._init_openai_client() + + if self.provider=='ollama': + # Load Ollama configuration + ollama_config = get_ollama_config(self.config) + + # Set parameters, with CLI overrides taking precedence + self.api_base = api_base or ollama_config.get('api_base') + self.model = model_name or ollama_config.get('model') + self.max_retries = max_retries or ollama_config.get('max_retries') + self.retry_delay = retry_delay or ollama_config.get('retry_delay') + self.sleep_time = ollama_config.get('sleep_time',0.1) + + # No client to initialize for Ollama as we use requests directly + # Verify server is running + available, info = self._check_llm_server() + if not available: + raise ConnectionError(f"Ollama server not available at {self.api_base}: {info}") + else: # Default to vLLM # Load vLLM configuration vllm_config = get_vllm_config(self.config) @@ -95,7 +113,7 @@ def __init__(self, # No client to initialize for vLLM as we use requests directly # Verify server is 
running - available, info = self._check_vllm_server() + available, info = self._check_llm_server() if not available: raise ConnectionError(f"VLLM server not available at {self.api_base}: {info}") @@ -118,8 +136,28 @@ def _init_openai_client(self): self.openai_client = OpenAI(**client_kwargs) - def _check_vllm_server(self) -> tuple: - """Check if the VLLM server is running and accessible""" + # def _check_vllm_server(self) -> tuple: + # """Check if the VLLM server is running and accessible""" + # try: + # response = requests.get(f"{self.api_base}/models", timeout=5) + # if response.status_code == 200: + # return True, response.json() + # return False, f"Server returned status code: {response.status_code}" + # except requests.exceptions.RequestException as e: + # return False, f"Server connection error: {str(e)}" + # + # def _check_ollama_server(self) -> tuple: + # """Check if the Ollama server is running and accessible""" + # try: + # response = requests.get(f"{self.api_base}/models", timeout=5) + # if response.status_code == 200: + # return True, response.json() + # return False, f"Server returned status code: {response.status_code}" + # except requests.exceptions.RequestException as e: + # return False, f"Server connection error: {str(e)}" + + def _check_llm_server(self) -> tuple: + """Check if the vllm or Ollama server is running and accessible""" try: response = requests.get(f"{self.api_base}/models", timeout=5) if response.status_code == 200: diff --git a/synthetic_data_kit/utils/config.py b/synthetic_data_kit/utils/config.py index e17600f7..d817151a 100644 --- a/synthetic_data_kit/utils/config.py +++ b/synthetic_data_kit/utils/config.py @@ -20,14 +20,15 @@ os.path.join(os.path.dirname(os.path.dirname(__file__)), "config.yaml") ) -# Use internal package path as default -DEFAULT_CONFIG_PATH = PACKAGE_CONFIG_PATH +# Use config/config.yml package path as default +DEFAULT_CONFIG_PATH = ORIGINAL_CONFIG_PATH def load_config(config_path: Optional[str] = None) -> Dict[str, Any]: """Load YAML configuration file""" if config_path is None: # Try each path in order until one exists - for path in [PACKAGE_CONFIG_PATH, ORIGINAL_CONFIG_PATH]: + # for path in [ORIGINAL_CONFIG_PATH, PACKAGE_CONFIG_PATH]: + for path in [ORIGINAL_CONFIG_PATH, PACKAGE_CONFIG_PATH]: if os.path.exists(path): config_path = path break @@ -79,7 +80,7 @@ def get_llm_provider(config: Dict[str, Any]) -> str: """Get the selected LLM provider Returns: - String with provider name: 'vllm' or 'api-endpoint' + String with provider name: 'vllm' or 'api-endpoint' or 'ollama' """ llm_config = config.get('llm', {}) provider = llm_config.get('provider', 'vllm') @@ -98,6 +99,16 @@ def get_vllm_config(config: Dict[str, Any]) -> Dict[str, Any]: 'retry_delay': 1.0 }) +def get_ollama_config(config: Dict[str, Any]) -> Dict[str, Any]: + """Get VLLM configuration""" + return config.get('ollama', { + 'api_base': "http://localhost:11434/v1", # Ollama's OpenAI-compatible endpoint + 'api_key': "not-needed", # Ollama doesn't require an API key + 'model': "llama3:latest", # Your Ollama model + 'max_retries': 3 , # Number of retries for API calls + 'retry_delay': 1.0 , # Initial delay between retries (seconds) + }) + def get_openai_config(config: Dict[str, Any]) -> Dict[str, Any]: """Get API endpoint configuration""" return config.get('api-endpoint', { diff --git a/tests/conftest.py b/tests/conftest.py index 892b78d8..e1b2dc98 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -199,6 +199,26 @@ def create_vllm_config(model="mock-vllm-model"): }, } 
+ def create_ollama_config(self, model="mock-ollama-model"): + """Create a mock Ollama configuration.""" + return { + "llm": {"provider": "ollama"}, + "ollama": { + "model": model, + "max_retries": 3, + "retry_delay": 1, + }, + "generation": { + "temperature": 0.3, + "max_tokens": 4096, + "top_p": 0.95, + "batch_size": 8, + }, + "paths": { + "data_dir": "data", + "output_dir": "output", + }, + } @pytest.fixture def config_factory(): @@ -243,10 +263,13 @@ def patch_vllm_config(config_factory): mock_load_config.return_value = config_factory.create_vllm_config() yield mock_load_config +def patch_ollama_config(config_factory): + """Patch the config loader to return an Ollama configuration.""" + with patch("synthetic_data_kit.utils.config.load_config") as mock_load_config: + mock_load_config.return_value = config_factory.create_ollama_config() + yield mock_load_config # Additional utility fixtures for common test patterns - - @pytest.fixture def temp_output_dir(): """Create a temporary output directory for tests.""" diff --git a/tests/unit/test_llm_client.py b/tests/unit/test_llm_client.py index d030f652..d2220d5f 100644 --- a/tests/unit/test_llm_client.py +++ b/tests/unit/test_llm_client.py @@ -44,6 +44,23 @@ def test_llm_client_vllm_initialization(patch_config, test_env): # Check that vLLM server was checked assert mock_get.called +def test_llm_client_ollama_initialization(patch_config, test_env): + """Test LLM client initialization with Ollama provider.""" + with patch("requests.get") as mock_get: + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = [{"name": "mock-model"}] + mock_get.return_value = mock_response + + # Initialize client + client = LLMClient(provider="ollama") + + # Check that the client was initialized correctly + assert client.provider == "ollama" + assert client.api_base is not None + assert client.model is not None + # Check that Ollama server was checked + assert mock_get.called @pytest.mark.unit def test_llm_client_chat_completion(patch_config, test_env): @@ -89,7 +106,6 @@ def test_llm_client_chat_completion(patch_config, test_env): # Check that OpenAI client was called assert mock_create.called - @pytest.mark.unit def test_llm_client_vllm_chat_completion(patch_config, test_env): """Test LLM client chat completion with vLLM provider.""" @@ -123,3 +139,36 @@ def test_llm_client_vllm_chat_completion(patch_config, test_env): assert response == "This is a test response" # Check that vLLM API was called assert mock_post.called + +def test_llm_client_ollama_chat_completion(patch_config, test_env): + """Test LLM client chat completion with Ollama provider.""" + with patch("requests.post") as mock_post, patch("requests.get") as mock_get: + # Mock Ollama server check + mock_check_response = MagicMock() + mock_check_response.status_code = 200 + mock_check_response.json.return_value = [{"name": "mock-model"}] + mock_get.return_value = mock_check_response + + # Mock Ollama API response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "response": "This is a test response" + } + mock_post.return_value = mock_response + + # Initialize client + client = LLMClient(provider="ollama") + + # Test chat completion + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is synthetic data?"}, + ] + + response = client.chat_completion(messages, temperature=0.7) + + # Check that the response is correct + assert response == "This is a 
test response" + # Check that Ollama API was called + assert mock_post.called diff --git a/vllm b/vllm new file mode 160000 index 00000000..57329a8c --- /dev/null +++ b/vllm @@ -0,0 +1 @@ +Subproject commit 57329a8c013ca9e5d575faad3f04436f2eabad15