Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,21 @@ voicemode config get VOICEMODE_TTS_VOICE
voicemode config set VOICEMODE_TTS_VOICE nova
```

### Proxy Service Configuration

VoiceMode supports custom HTTP headers for proxy services like Portkey:

```bash
# Set comma-separated headers for TTS requests
voicemode config set VOICEMODE_TTS_EXTRA_HEADERS "X-Portkey-API-Key=pk_xxx,X-Portkey-Provider=@openai"

# Set comma-separated headers for STT requests
voicemode config set VOICEMODE_STT_EXTRA_HEADERS "X-Portkey-API-Key=pk_xxx,X-Portkey-Provider=@openai"

# View current header configuration
voicemode config get VOICEMODE_TTS_EXTRA_HEADERS
```

### Building & Publishing
```bash
# Build Python package
Expand Down Expand Up @@ -96,6 +111,7 @@ make docs-check
- Environment-based configuration with sensible defaults
- Support for voice preference files (project/user level)
- Audio format configuration (PCM, MP3, WAV, FLAC, AAC, Opus)
- Custom HTTP headers for proxy services via comma-separated format

5. **Resources (`voice_mode/resources/`)**
- MCP resources exposed for client access
Expand Down
220 changes: 220 additions & 0 deletions tests/test_extra_headers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
"""Tests for custom HTTP header support."""

import os
import pytest
from unittest.mock import patch, MagicMock
from voice_mode.config import parse_extra_headers


class TestHeaderParsing:
"""Test comma-separated header parsing functionality."""

def test_parse_valid_headers(self):
"""Test parsing valid comma-separated headers."""
with patch.dict(os.environ, {"TEST_VAR": "X-Custom=value"}):
headers = parse_extra_headers("TEST_VAR")
assert headers == {"X-Custom": "value"}

def test_parse_multiple_headers(self):
"""Test parsing multiple headers."""
with patch.dict(os.environ, {
"TEST_VAR": "X-API-Key=key123,X-Provider=test"
}):
headers = parse_extra_headers("TEST_VAR")
assert headers == {"X-API-Key": "key123", "X-Provider": "test"}

def test_parse_empty_headers(self):
"""Test parsing empty header string."""
headers = parse_extra_headers("NONEXISTENT_VAR", "")
assert headers == {}

def test_parse_empty_string(self):
"""Test parsing empty string returns empty dict."""
with patch.dict(os.environ, {"TEST_VAR": ""}):
headers = parse_extra_headers("TEST_VAR")
assert headers == {}

def test_parse_missing_equals(self):
"""Test handling of pairs without equals sign."""
with patch.dict(os.environ, {"TEST_VAR": "InvalidHeader"}):
headers = parse_extra_headers("TEST_VAR")
assert headers == {}

def test_parse_empty_key(self):
"""Test handling of empty key."""
with patch.dict(os.environ, {"TEST_VAR": "=value"}):
headers = parse_extra_headers("TEST_VAR")
assert headers == {}

def test_parse_value_with_equals(self):
"""Test parsing value that contains equals sign."""
with patch.dict(os.environ, {"TEST_VAR": "X-Header=value=with=equals"}):
headers = parse_extra_headers("TEST_VAR")
assert headers == {"X-Header": "value=with=equals"}

def test_parse_with_whitespace(self):
"""Test parsing with extra whitespace."""
with patch.dict(os.environ, {
"TEST_VAR": " X-Custom = value , X-Other = test "
}):
headers = parse_extra_headers("TEST_VAR")
assert headers == {"X-Custom": "value", "X-Other": "test"}

def test_parse_empty_pairs(self):
"""Test handling of empty pairs (consecutive commas)."""
with patch.dict(os.environ, {"TEST_VAR": "X-Key=val1,,X-Key2=val2"}):
headers = parse_extra_headers("TEST_VAR")
assert headers == {"X-Key": "val1", "X-Key2": "val2"}

def test_parse_portkey_example(self):
"""Test parsing real-world Portkey header example."""
with patch.dict(os.environ, {
"TEST_VAR": "X-Portkey-API-Key=pk_xxx,X-Portkey-Provider=@openai"
}):
headers = parse_extra_headers("TEST_VAR")
assert headers == {
"X-Portkey-API-Key": "pk_xxx",
"X-Portkey-Provider": "@openai"
}

def test_parse_from_env_var(self):
"""Test reading from actual environment variable."""
with patch.dict(os.environ, {"VOICEMODE_TEST_HEADERS": "X-Test=value"}):
headers = parse_extra_headers("VOICEMODE_TEST_HEADERS")
assert headers == {"X-Test": "value"}

def test_fallback_value(self):
"""Test fallback parameter when env var not set."""
headers = parse_extra_headers("NONEXISTENT", "X-Default=value")
assert headers == {"X-Default": "value"}

def test_empty_value(self):
"""Test parsing header with empty value."""
with patch.dict(os.environ, {"TEST_VAR": "X-Header="}):
headers = parse_extra_headers("TEST_VAR")
assert headers == {"X-Header": ""}

def test_special_characters_in_value(self):
"""Test parsing values with special characters."""
with patch.dict(os.environ, {"TEST_VAR": "X-Header=@value-with_special.chars"}):
headers = parse_extra_headers("TEST_VAR")
assert headers == {"X-Header": "@value-with_special.chars"}

def test_multiple_commas_in_value(self):
"""Test that commas in values are not supported (limitation of format)."""
# This is a known limitation - values cannot contain commas
with patch.dict(os.environ, {"TEST_VAR": "X-Header=val1,val2"}):
headers = parse_extra_headers("TEST_VAR")
# This will parse as two headers, second one invalid
# Only X-Header=val1 is valid
assert "X-Header" in headers


class TestConfigVariables:
"""Test that config variables are properly initialized."""

def test_tts_extra_headers_default(self):
"""Test TTS_EXTRA_HEADERS defaults to empty dict."""
from voice_mode import config
# Default should be empty dict when env var not set
if "VOICEMODE_TTS_EXTRA_HEADERS" not in os.environ:
assert config.TTS_EXTRA_HEADERS == {}

def test_stt_extra_headers_default(self):
"""Test STT_EXTRA_HEADERS defaults to empty dict."""
from voice_mode import config
# Default should be empty dict when env var not set
if "VOICEMODE_STT_EXTRA_HEADERS" not in os.environ:
assert config.STT_EXTRA_HEADERS == {}

def test_headers_from_environment(self):
"""Test that headers are loaded from environment variables."""
import importlib
with patch.dict(os.environ, {
"VOICEMODE_TTS_EXTRA_HEADERS": "X-TTS=test",
"VOICEMODE_STT_EXTRA_HEADERS": "X-STT=test"
}):
# Reload config to pick up new env vars
from voice_mode import config
importlib.reload(config)

assert config.TTS_EXTRA_HEADERS == {"X-TTS": "test"}
assert config.STT_EXTRA_HEADERS == {"X-STT": "test"}


class TestClientInstantiation:
"""Test that headers are passed to AsyncOpenAI clients."""

@pytest.mark.asyncio
async def test_core_get_openai_clients_with_headers(self):
"""Test get_openai_clients passes headers correctly."""
from voice_mode.core import get_openai_clients
from unittest.mock import AsyncMock

with patch.dict(os.environ, {
"VOICEMODE_TTS_EXTRA_HEADERS": "X-TTS=test",
"VOICEMODE_STT_EXTRA_HEADERS": "X-STT=test"
}):
# Reload config to pick up new env vars
import importlib
from voice_mode import config
importlib.reload(config)

with patch('voice_mode.core.AsyncOpenAI') as mock_openai_class:
mock_client = AsyncMock()
mock_openai_class.return_value = mock_client

clients = get_openai_clients(
api_key="test-key",
stt_base_url="http://test-stt",
tts_base_url="http://test-tts"
)

# Verify AsyncOpenAI was called twice
assert mock_openai_class.call_count == 2

# Get all calls
calls = mock_openai_class.call_args_list

# Find STT and TTS calls
stt_call = None
tts_call = None
for call in calls:
if call.kwargs.get('base_url') == 'http://test-stt':
stt_call = call
elif call.kwargs.get('base_url') == 'http://test-tts':
tts_call = call

# Verify headers were passed
assert stt_call is not None
assert tts_call is not None
assert stt_call.kwargs.get('default_headers') == {"X-STT": "test"}
assert tts_call.kwargs.get('default_headers') == {"X-TTS": "test"}

def test_headers_none_when_empty(self):
"""Test that None is passed when headers are empty."""
from voice_mode.core import get_openai_clients
from unittest.mock import AsyncMock

with patch.dict(os.environ, {}, clear=True):
# Reload config to clear headers
import importlib
from voice_mode import config
importlib.reload(config)

with patch('voice_mode.core.AsyncOpenAI') as mock_openai_class:
mock_client = AsyncMock()
mock_openai_class.return_value = mock_client

clients = get_openai_clients(
api_key="test-key",
stt_base_url="http://test-stt",
tts_base_url="http://test-tts"
)

# Verify AsyncOpenAI was called with None or empty dict for default_headers
calls = mock_openai_class.call_args_list
for call in calls:
headers = call.kwargs.get('default_headers')
# Should be None or empty dict
assert headers is None or headers == {}
52 changes: 52 additions & 0 deletions voice_mode/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,11 @@ def load_voicemode_env():
# OpenAI API key for cloud TTS/STT
# OPENAI_API_KEY=your-key-here

# Custom HTTP headers for proxy services (comma-separated key=value pairs)
# Useful for services like Portkey that require custom headers
# VOICEMODE_TTS_EXTRA_HEADERS=X-Portkey-API-Key=pk_xxx,X-Portkey-Provider=@openai
# VOICEMODE_STT_EXTRA_HEADERS=X-Portkey-API-Key=pk_xxx,X-Portkey-Provider=@openai

# LiveKit server URL
# LIVEKIT_URL=ws://127.0.0.1:7880

Expand Down Expand Up @@ -483,12 +488,59 @@ def parse_comma_list(env_var: str, fallback: str) -> list:
value = os.getenv(env_var, fallback)
return [item.strip() for item in value.split(",") if item.strip()]

def parse_extra_headers(env_var: str, fallback: str = "") -> dict:
"""Parse comma-separated header pairs from environment variable.

Format: "Key1=value1,Key2=value2"
Example: "X-Portkey-API-Key=pk_xxx,X-Portkey-Provider=@openai"

Args:
env_var: Environment variable name
fallback: Fallback string if env var not set

Returns:
Dict of headers, empty dict if parsing fails
"""
value = os.getenv(env_var, fallback)
if not value or value.strip() == "":
return {}

headers = {}
pairs = value.split(",")

for pair in pairs:
pair = pair.strip()
if not pair:
continue

if "=" not in pair:
logger.warning(f"Invalid header pair in {env_var}: '{pair}' (missing '=')")
continue

# Split on first '=' only, in case value contains '='
key, sep, val = pair.partition("=")
key = key.strip()
val = val.strip()

if not key:
logger.warning(f"Invalid header pair in {env_var}: '{pair}' (empty key)")
continue

headers[key] = val

return headers

# New provider endpoint lists configuration
TTS_BASE_URLS = parse_comma_list("VOICEMODE_TTS_BASE_URLS", "http://127.0.0.1:8880/v1,https://api.openai.com/v1")
STT_BASE_URLS = parse_comma_list("VOICEMODE_STT_BASE_URLS", "http://127.0.0.1:2022/v1,https://api.openai.com/v1")
TTS_VOICES = parse_comma_list("VOICEMODE_VOICES", "af_sky,alloy")
TTS_MODELS = parse_comma_list("VOICEMODE_TTS_MODELS", "tts-1,tts-1-hd,gpt-4o-mini-tts")

# Custom HTTP headers for TTS/STT providers (comma-separated key=value format)
# Useful for proxy services like Portkey
TTS_EXTRA_HEADERS = parse_extra_headers("VOICEMODE_TTS_EXTRA_HEADERS", "")
STT_EXTRA_HEADERS = parse_extra_headers("VOICEMODE_STT_EXTRA_HEADERS", "")

# Voice preferences cache
_cached_voice_preferences: Optional[list] = None
_voice_preferences_loaded = False
Expand Down
8 changes: 6 additions & 2 deletions voice_mode/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ def save_debug_file(data: bytes, prefix: str, extension: str, debug_dir: Path, d

def get_openai_clients(api_key: str, stt_base_url: Optional[str] = None, tts_base_url: Optional[str] = None) -> dict:
"""Initialize OpenAI clients for STT and TTS with connection pooling"""
from .config import STT_EXTRA_HEADERS, TTS_EXTRA_HEADERS

# Configure timeouts and connection pooling
http_client_config = {
'timeout': httpx.Timeout(30.0, connect=5.0),
Expand All @@ -150,13 +152,15 @@ def get_openai_clients(api_key: str, stt_base_url: Optional[str] = None, tts_bas
api_key=api_key,
base_url=stt_base_url,
http_client=stt_http_client,
max_retries=stt_max_retries
max_retries=stt_max_retries,
default_headers=STT_EXTRA_HEADERS or None
),
'tts': AsyncOpenAI(
api_key=api_key,
base_url=tts_base_url,
http_client=tts_http_client,
max_retries=tts_max_retries
max_retries=tts_max_retries,
default_headers=TTS_EXTRA_HEADERS or None
)
}

Expand Down
8 changes: 6 additions & 2 deletions voice_mode/provider_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from openai import AsyncOpenAI

from . import config
from .config import TTS_BASE_URLS, STT_BASE_URLS, OPENAI_API_KEY
from .config import TTS_BASE_URLS, STT_BASE_URLS, OPENAI_API_KEY, TTS_EXTRA_HEADERS, STT_EXTRA_HEADERS

logger = logging.getLogger("voicemode")

Expand Down Expand Up @@ -141,10 +141,14 @@ async def _discover_endpoint(self, service_type: str, base_url: str) -> None:

try:
# Create OpenAI client for the endpoint
# Determine which headers to use based on service_type
extra_headers = TTS_EXTRA_HEADERS if service_type == "tts" else STT_EXTRA_HEADERS

client = AsyncOpenAI(
api_key=OPENAI_API_KEY or "dummy-key-for-local",
base_url=base_url,
timeout=10.0
timeout=10.0,
default_headers=extra_headers or None
)

# Try to list models
Expand Down
Loading