Skip to content

Commit

Permalink
fix(provider_manager): fix custom provider
Browse files Browse the repository at this point in the history
Signed-off-by: -LAN- <[email protected]>
  • Loading branch information
laipz8200 committed Feb 25, 2025
1 parent 76bcdc2 commit 2355509
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions api/core/provider_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,11 @@ def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]:

provider_name_to_provider_records_dict = defaultdict(list)
for provider in providers:
provider_name_to_provider_records_dict[provider.provider_name].append(provider)
if provider.provider_name.startswith(DEFAULT_PLUGIN_ID):
provider_name = provider.provider_name
else:
provider_name = f"{DEFAULT_PLUGIN_ID}/{provider.provider_name}/{provider.provider_name}"
provider_name_to_provider_records_dict[provider_name].append(provider)

return provider_name_to_provider_records_dict

Expand Down Expand Up @@ -505,14 +509,12 @@ def _init_trial_provider_records(
if quota.quota_type == ProviderQuotaType.TRIAL:
# Init trial provider records if not exists
if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
if not provider_name.startswith(DEFAULT_PLUGIN_ID):
continue
hosting_provider_name = provider_name.split("/")[-1]
try:
# FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic
provider_record = Provider(
tenant_id=tenant_id,
provider_name=hosting_provider_name,
# TODO: Use provider name with prefix after the data migration.
provider_name=ModelProviderID(provider_name).provider_name,
provider_type=ProviderType.SYSTEM.value,
quota_type=ProviderQuotaType.TRIAL.value,
quota_limit=quota.quota_limit, # type: ignore
Expand All @@ -527,13 +529,12 @@ def _init_trial_provider_records(
db.session.query(Provider)
.filter(
Provider.tenant_id == tenant_id,
Provider.provider_name == hosting_provider_name,
Provider.provider_name == ModelProviderID(provider_name).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.TRIAL.value,
)
.first()
)

if provider_record and not provider_record.is_valid:
provider_record.is_valid = True
db.session.commit()
Expand Down

0 comments on commit 2355509

Please sign in to comment.