Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions src/proxy_app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1486,6 +1486,67 @@ async def refresh_quota_stats(
raise HTTPException(status_code=500, detail=str(e))


@app.post("/v1/force-credential")
async def force_credential(
    request: Request,
    client: RotatingClient = Depends(get_rotating_client),
    _=Depends(verify_api_key),
):
    """
    Force the proxy to use a specific credential for all requests.

    This overrides the normal rotation logic and always selects the specified
    credential (if available and not on cooldown). Used by the TUI for manual
    credential selection.

    Request body:
        {
            "credential": "path/to/credential.json" | null,
            "provider": "antigravity"  // optional, for display/logging only
        }

    Set credential to null to clear the override and resume normal rotation.

    Returns:
        When a credential is forced:
            {
                "status": "forced",
                "credential": "path/to/credential.json",
                "provider": "...",
                "message": "..."
            }
        When the override is cleared:
            {
                "status": "cleared",
                "credential": null,
                "message": "..."
            }

    Raises:
        HTTPException(400): the body is not valid JSON, is not a JSON object,
            or "credential" is neither a string nor null.
        HTTPException(500): the usage manager failed to apply the override.
    """
    # A malformed body is the caller's mistake: answer 400 instead of letting
    # it surface as a generic 500.
    try:
        data = await request.json()
    except ValueError as e:
        raise HTTPException(
            status_code=400, detail=f"Invalid JSON body: {e}"
        ) from e

    if not isinstance(data, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    credential = data.get("credential")
    if credential is not None and not isinstance(credential, str):
        raise HTTPException(
            status_code=400, detail="'credential' must be a string or null"
        )
    provider = data.get("provider", "unknown")

    # Set or clear the forced credential; only failures here are server errors.
    try:
        await client.usage_manager.set_forced_credential(credential)
    except Exception as e:
        logging.error(f"Failed to set forced credential: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    if not credential:
        return {
            "status": "cleared",
            "credential": None,
            "message": "Forced credential cleared. Resuming normal rotation.",
        }

    # Friendly display name: the last path segment. Both "/" and "\" are
    # treated as separators by hand because Path(credential).name (suggested
    # in review) only honours the host OS separator, so a Windows-style path
    # sent to a POSIX-hosted proxy would not be shortened.
    display_name = credential.replace("\\", "/").rsplit("/", 1)[-1]

    return {
        "status": "forced",
        "credential": credential,
        "provider": provider,
        "message": f"Now forcing credential: {display_name}",
    }


@app.post("/v1/token-count")
async def token_count(
request: Request,
Expand Down
81 changes: 81 additions & 0 deletions src/proxy_app/quota_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,9 @@ def show_provider_detail_screen(self, provider: str):
self.console.print(" G. Toggle view mode (current/global)")
self.console.print(" R. Reload stats (from proxy cache)")
self.console.print(" RA. Reload all stats")
self.console.print()
self.console.print(" [cyan]1-9. Force use of credential [N] (locks rotation)[/cyan]")
self.console.print(" C. Clear forced credential (resume normal rotation)")

# Force refresh options (only for providers that support it)
has_quota_groups = bool(
Expand Down Expand Up @@ -964,6 +967,84 @@ def show_provider_detail_screen(self, provider: str):
"[bold]Reloading all stats...", spinner="dots"
):
self.post_action("reload", scope="all")
elif choice.isdigit() and 1 <= int(choice) <= 9:
# Handle numeric selection (force credential)
idx = int(choice)
credentials = (
self.cached_stats.get("providers", {})
.get(provider, {})
.get("credentials", [])
if self.cached_stats
else []
)
# Sort credentials naturally to match display order
credentials = sorted(credentials, key=natural_sort_key)

if idx <= len(credentials):
cred = credentials[idx - 1]
# Use full_path for matching, fall back to identifier
cred_identifier = cred.get("full_path", cred.get("identifier", ""))
cred_email = cred.get("email", cred.get("identifier", ""))

# Call API to force this credential
url = self._build_endpoint_url("/v1/force-credential")
payload = {
"credential": cred_identifier,
"provider": provider
}

try:
with httpx.Client(timeout=10.0) as http_client:
response = http_client.post(
url,
headers=self._get_headers(),
json=payload
)
Comment on lines +997 to +1002
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Creating a new httpx.Client for every request in a loop is slightly inefficient. For a TUI it likely doesn't matter, but using a single client instance for the QuotaViewer session might be cleaner for the long term.


if response.status_code == 200:
result = response.json()
self.console.print(
f"\n[green]✓ Forced credential:[/green] [{idx}] {cred_email}"
)
self.console.print(
"[dim]All requests will now use this credential (if available)[/dim]"
)
else:
self.console.print(
f"\n[red]Failed to force credential: HTTP {response.status_code}[/red]"
)
except Exception as e:
self.console.print(f"\n[red]Error: {e}[/red]")

Prompt.ask("Press Enter to continue", default="")
else:
self.console.print(f"\n[red]Invalid selection. Only {len(credentials)} credentials available.[/red]")
Prompt.ask("Press Enter to continue", default="")
elif choice == "C":
# Clear forced credential
url = self._build_endpoint_url("/v1/force-credential")
payload = {"credential": None}

try:
with httpx.Client(timeout=10.0) as http_client:
response = http_client.post(
url,
headers=self._get_headers(),
json=payload
)

if response.status_code == 200:
self.console.print(
"\n[green]✓ Forced credential cleared. Resuming normal rotation.[/green]"
)
else:
self.console.print(
f"\n[red]Failed to clear forced credential: HTTP {response.status_code}[/red]"
)
except Exception as e:
self.console.print(f"\n[red]Error: {e}[/red]")

Prompt.ask("Press Enter to continue", default="")
elif choice == "F" and has_quota_groups:
result = None
with self.console.status(
Expand Down
93 changes: 93 additions & 0 deletions src/rotator_library/usage_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ def __init__(
# Resilient writer for usage data persistence
self._state_writer = ResilientStateWriter(file_path, lib_logger)

# Forced credential for manual override (TUI control)
self._forced_credential: Optional[str] = None
self._forced_credential_lock = asyncio.Lock()

if daily_reset_time_utc:
hour, minute = map(int, daily_reset_time_utc.split(":"))
self.daily_reset_time_utc = dt_time(
Expand All @@ -182,6 +186,37 @@ def _get_rotation_mode(self, provider: str) -> str:
"""
return self.provider_rotation_modes.get(provider, "balanced")

# =========================================================================
# FORCED CREDENTIAL (TUI OVERRIDE)
# =========================================================================

async def set_forced_credential(self, credential: Optional[str]) -> None:
    """
    Install or remove the manual credential override.

    While an override is stored, key acquisition tries that credential first
    and falls back to normal rotation when it cannot be used.

    Args:
        credential: Full credential path/identifier to force, or None to
            clear the override.
    """
    async with self._forced_credential_lock:
        self._forced_credential = credential
        # Falsy values (None, empty string) are reported as a clear.
        if not credential:
            lib_logger.info("Forced credential cleared")
            return
        lib_logger.info(f"Forced credential set to: {mask_credential(credential)}")

async def get_forced_credential(self) -> Optional[str]:
"""
Get the currently forced credential, if any.

Returns:
The forced credential path/identifier, or None if no override is active
"""
async with self._forced_credential_lock:
return self._forced_credential

# =========================================================================
# FAIR CYCLE ROTATION HELPERS
# =========================================================================
Expand Down Expand Up @@ -2163,6 +2198,64 @@ async def acquire_key(
self._normalize_model(available_keys[0], model) if available_keys else model
)

# Check if a specific credential is forced (TUI override)
forced_cred = await self.get_forced_credential()
if forced_cred:
# Find matching credential - support both full path and filename matching
matched_cred = None
if forced_cred in available_keys:
matched_cred = forced_cred
else:
# Try matching by filename (basename)
for key in available_keys:
if key.endswith(forced_cred) or Path(key).name == forced_cred:
matched_cred = key
break
Comment on lines +2209 to +2213
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Matching by basename (filename) could be ambiguous if multiple credentials have the same filename in different directories. While probably rare, it might be safer to prioritize exact matches (which you already do) and perhaps log a warning if multiple basename matches are found.


if matched_cred:
now = time.time()
async with self._data_lock:
key_data = self._usage_data.get(matched_cred, {})
# Check if forced credential is available (not on cooldown)
is_on_cooldown = (
(key_data.get("key_cooldown_until") or 0) > now or
(key_data.get("model_cooldowns", {}).get(normalized_model) or 0) > now
)

if not is_on_cooldown:
# Try to acquire the forced credential
state = self.key_states[matched_cred]
async with state["lock"]:
current_count = state["models_in_use"].get(model, 0)
if current_count < max_concurrent:
state["models_in_use"][model] = current_count + 1
tier_name = (
credential_tier_names.get(matched_cred, "unknown")
if credential_tier_names
else "unknown"
)
quota_display = self._get_quota_display(matched_cred, model)
lib_logger.info(
f"Acquired FORCED key {mask_credential(matched_cred)} for model {model} "
f"(tier: {tier_name}, {quota_display})"
)
return matched_cred
else:
lib_logger.warning(
f"Forced credential {mask_credential(matched_cred)} is at max concurrency "
f"({current_count}/{max_concurrent}), falling back to normal rotation"
)
else:
lib_logger.warning(
f"Forced credential {mask_credential(matched_cred)} is on cooldown, "
f"falling back to normal rotation"
)
else:
lib_logger.warning(
f"Forced credential {mask_credential(forced_cred)} not found in available credentials, "
f"falling back to normal rotation"
)
Comment on lines +2244 to +2257
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fallback to normal rotation when a forced credential is on cooldown or at max concurrency is a safe choice for availability. However, if a user forces a credential, they might prefer a clear error if it can't be used. Since the TUI mentions "(if available)", this behavior is at least documented, but it's worth considering if a stricter 'force' is needed.


# This loop continues as long as the global deadline has not been met.
while time.time() < deadline:
now = time.time()
Expand Down
Loading