Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
383 changes: 383 additions & 0 deletions python-bridge/agent_gateway.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,383 @@
"""
Agent Gateway - Multi-LLM Router
================================
Routes queries to multiple LLM providers, scores responses,
and allows human selection or auto-selects best response.

Providers:
- Ollama (local, free)
- Gemini (fast, cheap)
- ChatGPT (accurate, expensive)
"""

import os
import asyncio
import logging
import time
from typing import Dict, List, Optional, Callable
from dataclasses import dataclass, field
from abc import ABC, abstractmethod
import requests

from confidence_scorer import ConfidenceScorer, ScoredResponse

logger = logging.getLogger(__name__)


@dataclass
class GatewayConfig:
    """Configuration for Agent Gateway.

    The API keys default to the ``GEMINI_API_KEY`` / ``OPENAI_API_KEY``
    environment variables (empty string when unset). Both key fields are
    excluded from the generated ``repr`` so that printing or logging a
    config object cannot leak credentials.
    """
    mode: str = "auto"  # auto, human, fallback
    timeout: float = 10.0  # seconds to wait for human
    parallel: bool = True  # query all providers in parallel
    max_response_length: int = 200  # for chat

    # Provider toggles
    use_ollama: bool = True
    use_gemini: bool = True
    use_chatgpt: bool = True

    # API Keys (from env); repr=False keeps the secrets out of repr()/logs
    gemini_api_key: str = field(
        default_factory=lambda: os.getenv("GEMINI_API_KEY", ""),
        repr=False,
    )
    openai_api_key: str = field(
        default_factory=lambda: os.getenv("OPENAI_API_KEY", ""),
        repr=False,
    )

    # Model names
    ollama_model: str = "llama3.2"
    gemini_model: str = "gemini-2.0-flash"
    chatgpt_model: str = "gpt-4o"


class LLMProvider(ABC):
    """Abstract interface that every LLM backend implements."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the identifier of this provider."""

    @abstractmethod
    async def query(self, prompt: str, max_tokens: int = 150) -> str:
        """Send ``prompt`` to the backend and return its text reply."""


class OllamaProvider(LLMProvider):
    """Local Ollama provider.

    Posts non-streaming generation requests to ``{base_url}/api/generate``.
    """

    def __init__(self, model: str = "llama3.2", base_url: str = "http://localhost:11434"):
        self.model = model
        # Strip trailing slashes so the joined URL never contains "//api/...".
        self.base_url = base_url.rstrip("/")

    @property
    def name(self) -> str:
        """Provider identifier in the form ``ollama:<model>``."""
        return f"ollama:{self.model}"

    async def query(self, prompt: str, max_tokens: int = 150) -> str:
        """Return the model's reply to ``prompt``, or ``""`` on any failure.

        Args:
            prompt: Text sent as the generation prompt.
            max_tokens: Forwarded to Ollama as ``options.num_predict``.

        Returns:
            The stripped ``response`` field of the JSON reply, or an empty
            string when the request fails, the status is not 200, or the
            field is missing/null.
        """
        try:
            # requests is blocking; run it in a worker thread so the event
            # loop stays free for the other providers.
            response = await asyncio.to_thread(
                requests.post,
                f"{self.base_url}/api/generate",
                json={
                    "model": self.model,
                    "prompt": prompt,
                    "stream": False,
                    "options": {"num_predict": max_tokens, "temperature": 0.7},
                },
                timeout=30,
            )
            if response.status_code != 200:
                # Previously dropped silently; surface it for diagnosis.
                logger.warning(f"Ollama returned HTTP {response.status_code}")
                return ""
            # "or" guards against an explicit null in the payload.
            return (response.json().get("response") or "").strip()
        except Exception as e:
            # Deliberate best-effort: a failing provider must never break
            # the gateway, so every error degrades to an empty reply.
            logger.error(f"Ollama error: {e}")
        return ""


class GeminiProvider(LLMProvider):
    """Google Gemini provider (REST ``generateContent`` endpoint)."""

    def __init__(self, api_key: str, model: str = "gemini-2.0-flash"):
        self.api_key = api_key
        self.model = model
        self.base_url = "https://generativelanguage.googleapis.com/v1beta"

    @property
    def name(self) -> str:
        """Provider identifier in the form ``gemini:<model>``."""
        return f"gemini:{self.model}"

    async def query(self, prompt: str, max_tokens: int = 150) -> str:
        """Return the model's reply to ``prompt``, or ``""`` on any failure.

        Args:
            prompt: Text sent as the single content part.
            max_tokens: Forwarded as ``generationConfig.maxOutputTokens``.

        Returns:
            The stripped text of the first part of the first candidate, or
            an empty string when the key is unset, the request fails, the
            status is not 200, or the reply carries no text.
        """
        if not self.api_key:
            logger.warning("Gemini API key not set")
            return ""

        try:
            # The key travels in a header instead of the query string:
            # requests exceptions embed the full URL, so a "?key=..." URL
            # would end up in the error log below.
            url = f"{self.base_url}/models/{self.model}:generateContent"
            response = await asyncio.to_thread(
                requests.post,
                url,
                headers={"x-goog-api-key": self.api_key},
                json={
                    "contents": [{"parts": [{"text": prompt}]}],
                    "generationConfig": {
                        "maxOutputTokens": max_tokens,
                        "temperature": 0.7,
                    },
                },
                timeout=30,
            )
            if response.status_code != 200:
                # Previously dropped silently; surface it for diagnosis.
                logger.warning(f"Gemini returned HTTP {response.status_code}")
                return ""

            candidates = response.json().get("candidates") or []
            if not candidates:
                return ""
            parts = (candidates[0].get("content") or {}).get("parts") or []
            if not parts:
                return ""
            return (parts[0].get("text") or "").strip()
        except Exception as e:
            # Deliberate best-effort: a failing provider must never break
            # the gateway, so every error degrades to an empty reply.
            logger.error(f"Gemini error: {e}")
        return ""


class ChatGPTProvider(LLMProvider):
    """OpenAI ChatGPT provider (chat completions endpoint)."""

    def __init__(self, api_key: str, model: str = "gpt-4o"):
        self.api_key = api_key
        self.model = model
        self.base_url = "https://api.openai.com/v1/chat/completions"

    @property
    def name(self) -> str:
        """Provider identifier in the form ``openai:<model>``."""
        return f"openai:{self.model}"

    async def query(self, prompt: str, max_tokens: int = 150) -> str:
        """Return the model's reply to ``prompt``, or ``""`` on any failure.

        Args:
            prompt: Text sent as a single ``user`` message.
            max_tokens: Forwarded as the request's ``max_tokens``.

        Returns:
            The stripped ``content`` of the first choice's message, or an
            empty string when the key is unset, the request fails, the
            status is not 200, or the reply carries no text content.
        """
        if not self.api_key:
            logger.warning("OpenAI API key not set")
            return ""

        try:
            # requests is blocking; run it in a worker thread so the event
            # loop stays free for the other providers.
            response = await asyncio.to_thread(
                requests.post,
                self.base_url,
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json",
                },
                json={
                    "model": self.model,
                    "messages": [{"role": "user", "content": prompt}],
                    "max_tokens": max_tokens,
                    "temperature": 0.7,
                },
                timeout=30,
            )
            if response.status_code != 200:
                # Previously dropped silently; surface it for diagnosis.
                logger.warning(f"ChatGPT returned HTTP {response.status_code}")
                return ""

            choices = response.json().get("choices") or []
            if not choices:
                return ""
            message = choices[0].get("message") or {}
            # "content" may be present but null; treat that as no text
            # rather than letting .strip() raise on None.
            return (message.get("content") or "").strip()
        except Exception as e:
            # Deliberate best-effort: a failing provider must never break
            # the gateway, so every error degrades to an empty reply.
            logger.error(f"ChatGPT error: {e}")
        return ""


class AgentGateway:
    """
    Multi-LLM Agent Gateway with confidence scoring.

    Sends one prompt to every enabled provider, scores the non-empty
    replies with ``ConfidenceScorer`` and returns either the top-ranked
    reply or the one picked through the human-selection callback.

    Usage:
        gateway = AgentGateway()
        result = await gateway.query_best("Solve 3Sum problem", problem_context)
    """

    # Upper bound, in seconds, on a single provider call.
    _PROVIDER_TIMEOUT = 30

    def __init__(self, config: Optional[GatewayConfig] = None):
        """Create the gateway; a default ``GatewayConfig`` is used if none is given."""
        self.config = config or GatewayConfig()
        self.scorer = ConfidenceScorer()
        self.providers: List[LLMProvider] = []
        self._setup_providers()

        # Callback for human selection UI
        self.human_callback: Optional[Callable] = None

        logger.info(f"🚪 Agent Gateway initialized with {len(self.providers)} providers")

    def _setup_providers(self):
        """Initialize enabled providers.

        Registration order (Ollama, Gemini, ChatGPT) is also the order
        used by ``query_fallback``. Gemini and ChatGPT are only registered
        when their API key is non-empty.
        """
        cfg = self.config
        if cfg.use_ollama:
            self.providers.append(OllamaProvider(cfg.ollama_model))

        if cfg.use_gemini and cfg.gemini_api_key:
            self.providers.append(GeminiProvider(cfg.gemini_api_key, cfg.gemini_model))

        if cfg.use_chatgpt and cfg.openai_api_key:
            self.providers.append(ChatGPTProvider(cfg.openai_api_key, cfg.chatgpt_model))

    def set_problem_context(self, problem_text: str):
        """Set problem context for relevance scoring"""
        self.scorer.set_problem_context(problem_text)

    def set_human_callback(self, callback: Callable):
        """Set callback for human selection UI.

        The callback is invoked in a worker thread with the ranked list and
        is expected to return the index of the chosen entry (or ``None``).
        """
        self.human_callback = callback

    async def query_all(self, prompt: str) -> Dict[str, str]:
        """Query every provider and map provider name to its reply.

        Providers are queried concurrently when ``config.parallel`` is
        true and one after another otherwise. A provider that fails maps
        to an empty string.
        """
        if self.config.parallel:
            results = await asyncio.gather(
                *(self._query_provider(provider, prompt) for provider in self.providers),
                return_exceptions=True,
            )
        else:
            # Honour parallel=False: strictly sequential, in provider order.
            results = [
                await self._query_provider(provider, prompt)
                for provider in self.providers
            ]

        responses: Dict[str, str] = {}
        for provider, result in zip(self.providers, results):
            # BaseException so a cancelled sub-task (CancelledError is not
            # an Exception subclass) is not mistaken for a reply.
            if isinstance(result, BaseException):
                logger.error(f"Provider {provider.name} failed: {result}")
                responses[provider.name] = ""
            else:
                responses[provider.name] = result

        return responses

    async def _query_provider(self, provider: LLMProvider, prompt: str) -> str:
        """Query a single provider with timeout.

        Never raises for provider errors: a timeout or any exception from
        ``provider.query`` is logged and reported as ``""`` so one broken
        provider cannot abort ``query_fallback`` or ``query_all``.
        """
        try:
            reply = await asyncio.wait_for(
                provider.query(prompt, self.config.max_response_length),
                timeout=self._PROVIDER_TIMEOUT,
            )
        except asyncio.TimeoutError:
            logger.warning(f"Provider {provider.name} timed out")
            return ""
        except Exception as exc:
            logger.error(f"Provider {provider.name} failed: {exc}")
            return ""
        # Normalise a None reply to the empty-string "no answer" convention.
        return reply or ""

    async def query_best(
        self,
        prompt: str,
        problem_context: str = ""
    ) -> Optional[ScoredResponse]:
        """
        Query all providers and return best response.

        In 'human' mode (with a callback registered), calls human_callback
        for selection. In every other mode, returns the first entry of the
        scorer's ranking.

        Returns None when no provider produced a non-empty reply.
        """
        # Set problem context for relevance scoring
        if problem_context:
            self.set_problem_context(problem_context)

        # Build the full prompt
        full_prompt = self._build_prompt(prompt, problem_context)

        # Query all providers
        logger.info(f"🔍 Querying {len(self.providers)} providers...")
        # Monotonic clock: the elapsed time is immune to wall-clock jumps.
        start = time.monotonic()
        responses = await self.query_all(full_prompt)
        elapsed = time.monotonic() - start
        logger.info(f"⏱️ All responses received in {elapsed:.2f}s")

        # Filter empty responses
        valid_responses = {k: v for k, v in responses.items() if v}
        if not valid_responses:
            logger.warning("⚠️ No valid responses from any provider")
            return None

        # Score and rank responses (ranking is assumed to be best-first)
        ranked = self.scorer.rank_responses(valid_responses)

        # Log scores
        for r in ranked:
            logger.info(f" {r.provider}: {r.total_score:.2f}")

        # Selection based on mode
        if self.config.mode == "human" and self.human_callback:
            return await self._human_selection(ranked)

        # Auto: return best
        return ranked[0] if ranked else None

    def _build_prompt(self, query: str, problem_context: str) -> str:
        """Build a structured prompt for consistent responses"""
        return f"""You are a coding assistant helping with a live stream.
Give a SHORT, SPECIFIC answer suitable for YouTube chat (under 200 chars).
Focus on the algorithmic hint, not full code.

Problem Context:
{problem_context}

User Query: {query}

Your concise answer (one line, under 200 chars):"""

    async def _human_selection(
        self,
        ranked: List[ScoredResponse]
    ) -> Optional[ScoredResponse]:
        """
        Present choices to human and wait for selection.

        Falls back to the first ranked entry when the callback times out,
        raises, or returns anything other than a valid integer index.
        """
        if self.human_callback:
            try:
                # Run the (blocking) UI callback in a worker thread and
                # stop waiting after config.timeout seconds.
                selection = await asyncio.wait_for(
                    asyncio.to_thread(self.human_callback, ranked),
                    timeout=self.config.timeout,
                )
            except asyncio.TimeoutError:
                logger.info("⏰ Human selection timeout, using best score")
            except Exception as exc:
                # A broken UI callback must not lose the answer.
                logger.error(f"Human selection callback failed: {exc}")
            else:
                # Type check first so a non-int return cannot raise in the
                # range comparison or the list index.
                if isinstance(selection, int) and 0 <= selection < len(ranked):
                    logger.info(f"👤 Human selected: {ranked[selection].provider}")
                    return ranked[selection]

        return ranked[0] if ranked else None

    async def query_fallback(
        self,
        prompt: str,
        problem_context: str = "",
        min_score: float = 0.5,
    ) -> str:
        """
        Fallback chain: Ollama → Gemini → ChatGPT

        Tries the registered providers in order and returns the first
        reply whose score reaches ``min_score``. Returns ``""`` when no
        provider delivers an acceptable reply.
        """
        if problem_context:
            self.set_problem_context(problem_context)

        full_prompt = self._build_prompt(prompt, problem_context)

        for provider in self.providers:
            response = await self._query_provider(provider, full_prompt)
            if not response:
                continue

            score = self.scorer.score_response(response, provider.name)
            # Accept if score is above threshold
            if score.total_score >= min_score:
                logger.info(f"✅ Using {provider.name} (score: {score.total_score:.2f})")
                return response
            logger.info(f"⚠️ {provider.name} response too weak ({score.total_score:.2f})")

        return ""


# CLI test
if __name__ == "__main__":
    import asyncio

    logging.basicConfig(level=logging.INFO)

    async def _run_demo():
        """Ask the gateway one sample question and print the winning reply."""
        gateway = AgentGateway(GatewayConfig())

        sample_problem = """
3 Sum (Closest): Given an array A of N integers,
find three integers in A such that the sum is closest to B.
"""

        best = await gateway.query_best("What's the optimal approach?", sample_problem)

        # Nothing is printed when no provider produced a usable reply.
        if not best:
            return

        print(f"\n🏆 Best Response ({best.provider}):")
        print(f" Score: {best.total_score:.2f}")
        print(f" Response: {best.response}")

    asyncio.run(_run_demo())
Loading