From 668502fd94c45ee4df5ac358af64741b823325c8 Mon Sep 17 00:00:00 2001
From: OliverBryant <2713999266@qq.com>
Date: Wed, 22 Oct 2025 15:45:53 +0800
Subject: [PATCH 05/25] FEAT: add model backend
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
xinference/api/restful_api.py | 46 ++++++
xinference/core/supervisor.py | 273 ++++++++++++++++++++++++++++++++++
2 files changed, 319 insertions(+)
diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py
index 84c7b18d80..8dea5ab6c8 100644
--- a/xinference/api/restful_api.py
+++ b/xinference/api/restful_api.py
@@ -198,6 +198,11 @@ class RegisterModelRequest(BaseModel):
persist: bool
+class AddModelRequest(BaseModel):
+ model_type: str
+ model_json: Dict[str, Any]
+
+
class BuildGradioInterfaceRequest(BaseModel):
model_type: str
model_name: str
@@ -900,6 +905,16 @@ async def internal_exception_handler(request: Request, exc: Exception):
else None
),
)
+ self._router.add_api_route(
+ "/v1/models/add",
+ self.add_model,
+ methods=["POST"],
+ dependencies=(
+ [Security(self._auth_service, scopes=["models:add"])]
+ if self.is_authenticated()
+ else None
+ ),
+ )
self._router.add_api_route(
"/v1/cache/models",
self.list_cached_models,
@@ -3123,6 +3138,37 @@ async def unregister_model(self, model_type: str, model_name: str) -> JSONRespon
raise HTTPException(status_code=500, detail=str(e))
return JSONResponse(content=None)
+ async def add_model(self, request: Request) -> JSONResponse:
+ try:
+ # Parse request
+ raw_json = await request.json()
+ logger.info(f"[DEBUG] add_model API received raw JSON: {json.dumps(raw_json, indent=2)}")
+
+ body = AddModelRequest.parse_obj(raw_json)
+ model_type = body.model_type
+ model_json = body.model_json
+
+ logger.info(f"[DEBUG] Parsed request - model_type: {model_type}")
+ logger.info(f"[DEBUG] Parsed request - model_json keys: {list(model_json.keys())}")
+ logger.info(f"[DEBUG] model_name from JSON: {model_json.get('model_name', 'NOT_FOUND')}")
+
+ # Call supervisor
+ supervisor_ref = await self._get_supervisor_ref()
+ logger.info(f"[DEBUG] Got supervisor ref: {supervisor_ref}")
+
+ await supervisor_ref.add_model(model_type, model_json)
+
+ logger.info(f"[DEBUG] Supervisor add_model completed successfully")
+
+ except ValueError as re:
+ logger.error(f"[DEBUG] ValueError in add_model API: {re}", exc_info=True)
+ raise HTTPException(status_code=400, detail=str(re))
+ except Exception as e:
+ logger.error(f"[DEBUG] Unexpected error in add_model API: {e}", exc_info=True)
+ raise HTTPException(status_code=500, detail=str(e))
+
+ return JSONResponse(content={"message": f"Model added successfully for type: {model_type}"})
+
async def list_model_registrations(
self, model_type: str, detailed: bool = Query(False)
) -> JSONResponse:
diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py
index 1ed96cd703..ade4830035 100644
--- a/xinference/core/supervisor.py
+++ b/xinference/core/supervisor.py
@@ -14,6 +14,7 @@
import asyncio
import itertools
+import json
import os
import signal
import time
@@ -932,6 +933,278 @@ async def register_model(
else:
raise ValueError(f"Unsupported model type: {model_type}")
+ @log_async(logger=logger)
+ async def add_model(self, model_type: str, model_json: Dict[str, Any]):
+ """
+ Add a new model by parsing the provided JSON and registering it.
+
+ Args:
+ model_type: Type of model (LLM, embedding, image, etc.)
+ model_json: JSON configuration for the model
+ """
+ logger.info(f"[DEBUG] Supervisor add_model called with model_type: {model_type}")
+ logger.info(f"[DEBUG] Supervisor add_model received JSON with keys: {list(model_json.keys())}")
+ logger.info(f"[DEBUG] JSON content: {json.dumps(model_json, indent=2)}")
+
+ # Validate model type
+ supported_types = list(self._custom_register_type_to_cls.keys())
+ logger.info(f"[DEBUG] Supported model types: {supported_types}")
+
+ if model_type not in self._custom_register_type_to_cls:
+ logger.error(f"[DEBUG] Unsupported model type: {model_type}")
+ raise ValueError(
+ f"Unsupported model type '{model_type}'. "
+ f"Supported types are: {', '.join(supported_types)}"
+ )
+
+ logger.info(f"[DEBUG] Model type validation passed for: {model_type}")
+
+ # Get the appropriate model class and register function
+ (
+ model_spec_cls,
+ register_fn,
+ unregister_fn,
+ generate_fn,
+ ) = self._custom_register_type_to_cls[model_type]
+
+ logger.info(f"[DEBUG] Got model spec class: {model_spec_cls}")
+ logger.info(f"[DEBUG] Got register function: {register_fn}")
+ logger.info(f"[DEBUG] Got unregister function: {unregister_fn}")
+ logger.info(f"[DEBUG] Got generate function: {generate_fn}")
+
+ # Validate required fields
+ required_fields = ["model_name", "model_specs"]
+ logger.info(f"[DEBUG] Checking required fields: {required_fields}")
+
+ for field in required_fields:
+ if field not in model_json:
+ logger.error(f"[DEBUG] Missing required field: {field}")
+ raise ValueError(f"Missing required field: {field}")
+ logger.info(f"[DEBUG] Field {field} found: {type(model_json[field])}")
+
+ # Validate model name format
+ from ..model.utils import is_valid_model_name
+ model_name = model_json["model_name"]
+ logger.info(f"[DEBUG] Validating model name: {model_name}")
+
+ if not is_valid_model_name(model_name):
+ logger.error(f"[DEBUG] Invalid model name format: {model_name}")
+ raise ValueError(f"Invalid model name format: {model_name}")
+
+ logger.info(f"[DEBUG] Model name validation passed")
+
+ # Convert model hub JSON format to Xinference expected format
+ try:
+ logger.info(f"[DEBUG] Converting model JSON format if needed...")
+ converted_model_json = self._convert_model_json_format(model_json)
+ logger.info(f"[DEBUG] JSON conversion completed successfully")
+ except Exception as e:
+ logger.error(f"[DEBUG] JSON conversion failed: {e}", exc_info=True)
+ raise ValueError(f"Failed to convert model JSON format: {str(e)}")
+
+ # Parse the JSON into the appropriate model spec
+ try:
+ logger.info(f"[DEBUG] Attempting to parse converted JSON with {model_spec_cls}")
+ model_spec = model_spec_cls.parse_obj(converted_model_json)
+ logger.info(f"[DEBUG] JSON parsing successful, model_spec created: {model_spec}")
+ except Exception as e:
+ logger.error(f"[DEBUG] JSON parsing failed: {e}", exc_info=True)
+ raise ValueError(f"Invalid model JSON format: {str(e)}")
+
+ # Check if model already exists
+ try:
+ logger.info(f"[DEBUG] Checking if model '{model_spec.model_name}' already exists")
+ existing_model = await self.get_model_registration(
+ model_type, model_spec.model_name
+ )
+ logger.info(f"[DEBUG] Existing model check result: {existing_model}")
+
+ if existing_model is not None:
+ logger.error(f"[DEBUG] Model already exists: {model_spec.model_name}")
+ raise ValueError(
+ f"Model '{model_spec.model_name}' already exists for type '{model_type}'. "
+ f"Please choose a different model name or remove the existing model first."
+ )
+
+ logger.info(f"[DEBUG] Model does not exist, can proceed with registration")
+
+ except ValueError as e:
+ if "not found" in str(e):
+ # Model doesn't exist, we can proceed
+ logger.info(f"[DEBUG] Model not found (expected): {e}")
+ pass
+ else:
+ # Re-raise validation errors
+ logger.error(f"[DEBUG] ValueError during model existence check: {e}")
+ raise e
+ except Exception as ex:
+ logger.error(f"[DEBUG] Unexpected error checking model registration for '{model_spec.model_name}': {ex}", exc_info=True)
+ raise ValueError(f"Failed to validate model registration: {str(ex)}")
+
+ # Register the model (persist=True for adding models)
+ try:
+ logger.info(f"[DEBUG] Starting model registration process")
+ logger.info(f"[DEBUG] Calling register_fn with persist=True")
+
+ register_fn(model_spec, persist=True)
+ logger.info(f"[DEBUG] register_fn completed successfully")
+
+ # Record model version
+ logger.info(f"[DEBUG] Generating version info")
+ version_info = generate_fn(model_spec)
+ logger.info(f"[DEBUG] Version info generated: {version_info}")
+
+ logger.info(f"[DEBUG] Recording model version to cache tracker")
+ await self._cache_tracker_ref.record_model_version(
+ version_info, self.address
+ )
+ logger.info(f"[DEBUG] Model version recorded successfully")
+
+ # Sync to workers if not local deployment
+ is_local = self.is_local_deployment()
+ logger.info(f"[DEBUG] Is local deployment: {is_local}")
+
+ if not is_local:
+ logger.info(f"[DEBUG] Syncing model to workers")
+ await self._sync_register_model(
+ model_type, converted_model_json, True, model_spec.model_name
+ )
+ logger.info(f"[DEBUG] Model synced to workers successfully")
+
+ logger.info(f"Successfully added model '{model_spec.model_name}' (type: {model_type})")
+ logger.info(f"[DEBUG] add_model process completed successfully")
+
+ except ValueError as e:
+ # Validation errors - don't need cleanup as model wasn't registered
+ logger.error(f"[DEBUG] Validation error during model registration: {e}")
+ raise e
+ except Exception as e:
+ # Unexpected errors - attempt cleanup
+ logger.error(f"[DEBUG] Unexpected error during model registration: {e}", exc_info=True)
+ try:
+ logger.info(f"[DEBUG] Attempting cleanup of failed registration")
+ unregister_fn(model_spec.model_name, raise_error=False)
+ logger.info(f"[DEBUG] Cleanup completed successfully")
+ except Exception as cleanup_error:
+ logger.warning(f"[DEBUG] Cleanup failed: {cleanup_error}")
+ raise ValueError(f"Failed to register model '{model_spec.model_name}': {str(e)}")
+
+ def _convert_model_json_format(self, model_json: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Convert model hub JSON format to Xinference expected format.
+
+ The input format uses a nested 'model_src' structure, but Xinference expects
+ flattened fields at the spec level.
+ """
+ logger.info(f"[DEBUG] Starting JSON format conversion")
+
+ # Check if conversion is needed (detect model_src structure)
+ needs_conversion = False
+ for spec in model_json["model_specs"]:
+ if "model_src" in spec:
+ needs_conversion = True
+ break
+
+ if not needs_conversion:
+ logger.info(f"[DEBUG] No conversion needed, JSON is already in expected format")
+ return model_json
+
+ logger.info(f"[DEBUG] Converting model_src nested structure to flattened format")
+
+ converted = model_json.copy()
+ converted_specs = []
+
+ for spec in model_json["model_specs"]:
+ model_format = spec["model_format"]
+ model_size = spec["model_size_in_billions"]
+
+ logger.info(f"[DEBUG] Processing spec: {model_format} - {model_size}B")
+
+ if "model_src" not in spec:
+ # No model_src, keep spec as is but ensure required fields
+ converted_spec = spec.copy()
+ if "quantization" not in converted_spec:
+ converted_spec["quantization"] = "none" # Default
+ converted_specs.append(converted_spec)
+ continue
+
+ model_src = spec["model_src"]
+
+ # Handle different model sources
+ if "huggingface" in model_src:
+ hf_info = model_src["huggingface"]
+ quantizations = hf_info.get("quantizations", ["none"])
+
+ logger.info(f"[DEBUG] Found {len(quantizations)} quantizations for {model_format}")
+
+ # Create separate specs for each quantization
+ for quant in quantizations:
+ converted_spec = {
+ "model_format": model_format,
+ "model_size_in_billions": model_size,
+ "quantization": quant,
+ "model_hub": "huggingface",
+ }
+
+ # Add common fields
+ if "model_id" in hf_info:
+ converted_spec["model_id"] = hf_info["model_id"]
+ if "model_revision" in hf_info:
+ converted_spec["model_revision"] = hf_info["model_revision"]
+
+ # Format-specific fields
+ if model_format == "ggufv2":
+ if "model_id" in hf_info:
+ converted_spec["model_id"] = hf_info["model_id"]
+ if "model_file_name_template" in hf_info:
+ converted_spec["model_file_name_template"] = hf_info["model_file_name_template"]
+ else:
+ # Default template
+ model_name = model_json["model_name"]
+ converted_spec["model_file_name_template"] = f"{model_name}-{{quantization}}.gguf"
+ elif model_format in ["pytorch", "mlx"]:
+ if "model_id" in hf_info:
+ converted_spec["model_id"] = hf_info["model_id"]
+ if "model_revision" in hf_info:
+ converted_spec["model_revision"] = hf_info["model_revision"]
+
+ converted_specs.append(converted_spec)
+ logger.debug(f"[DEBUG] Created spec: {model_format} - {quant}")
+
+ elif "modelscope" in model_src:
+ # Handle ModelScope similarly
+ ms_info = model_src["modelscope"]
+ quantizations = ms_info.get("quantizations", ["none"])
+
+ for quant in quantizations:
+ converted_spec = {
+ "model_format": model_format,
+ "model_size_in_billions": model_size,
+ "quantization": quant,
+ "model_hub": "modelscope",
+ }
+
+ if "model_id" in ms_info:
+ converted_spec["model_id"] = ms_info["model_id"]
+ if "model_revision" in ms_info:
+ converted_spec["model_revision"] = ms_info["model_revision"]
+
+ converted_specs.append(converted_spec)
+
+ else:
+ # Unknown model source, skip or handle as error
+ logger.warning(f"[DEBUG] Unknown model source in spec: {list(model_src.keys())}")
+ # Keep original spec but add required fields
+ converted_spec = spec.copy()
+ if "quantization" not in converted_spec:
+ converted_spec["quantization"] = "none"
+ converted_specs.append(converted_spec)
+
+ converted["model_specs"] = converted_specs
+ logger.info(f"[DEBUG] Conversion completed: {len(model_json['model_specs'])} -> {len(converted_specs)} specs")
+
+ return converted
+
async def _sync_register_model(
self, model_type: str, model: str, persist: bool, model_name: str
):
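For reference, a minimal sketch of how a client could call the /v1/models/add
endpoint introduced above. The address (localhost:9997, Xinference's default)
and the model JSON are illustrative assumptions, not part of the patch:

    import requests  # any HTTP client works; requests is used here for brevity

    # Hypothetical model JSON in the model-hub format that add_model accepts;
    # model_name and model_specs are the fields the supervisor validates.
    model_json = {
        "model_name": "my-llm",
        "model_specs": [
            {
                "model_format": "ggufv2",
                "model_size_in_billions": 7,
                "model_src": {
                    "huggingface": {
                        "model_id": "org/my-llm-GGUF",
                        "quantizations": ["Q4_K_M", "Q8_0"],
                    }
                },
            }
        ],
    }

    resp = requests.post(
        "http://localhost:9997/v1/models/add",
        json={"model_type": "LLM", "model_json": model_json},
    )
    resp.raise_for_status()
    print(resp.json())  # {'message': 'Model added successfully for type: LLM'}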
From 3cc4aa1a67b59fe5c84406d99fd54cb15a0872aa Mon Sep 17 00:00:00 2001
From: OliverBryant <2713999266@qq.com>
Date: Wed, 22 Oct 2025 15:52:13 +0800
Subject: [PATCH 06/25] FEAT: add model backend
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
xinference/api/restful_api.py | 20 +++++++---
xinference/core/supervisor.py | 71 ++++++++++++++++++++++++++---------
2 files changed, 69 insertions(+), 22 deletions(-)
diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py
index 8dea5ab6c8..57a623ed5d 100644
--- a/xinference/api/restful_api.py
+++ b/xinference/api/restful_api.py
@@ -3142,15 +3142,21 @@ async def add_model(self, request: Request) -> JSONResponse:
try:
# Parse request
raw_json = await request.json()
- logger.info(f"[DEBUG] add_model API received raw JSON: {json.dumps(raw_json, indent=2)}")
+ logger.info(
+ f"[DEBUG] add_model API received raw JSON: {json.dumps(raw_json, indent=2)}"
+ )
body = AddModelRequest.parse_obj(raw_json)
model_type = body.model_type
model_json = body.model_json
logger.info(f"[DEBUG] Parsed request - model_type: {model_type}")
- logger.info(f"[DEBUG] Parsed request - model_json keys: {list(model_json.keys())}")
- logger.info(f"[DEBUG] model_name from JSON: {model_json.get('model_name', 'NOT_FOUND')}")
+ logger.info(
+ f"[DEBUG] Parsed request - model_json keys: {list(model_json.keys())}"
+ )
+ logger.info(
+ f"[DEBUG] model_name from JSON: {model_json.get('model_name', 'NOT_FOUND')}"
+ )
# Call supervisor
supervisor_ref = await self._get_supervisor_ref()
@@ -3164,10 +3170,14 @@ async def add_model(self, request: Request) -> JSONResponse:
logger.error(f"[DEBUG] ValueError in add_model API: {re}", exc_info=True)
raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
- logger.error(f"[DEBUG] Unexpected error in add_model API: {e}", exc_info=True)
+ logger.error(
+ f"[DEBUG] Unexpected error in add_model API: {e}", exc_info=True
+ )
raise HTTPException(status_code=500, detail=str(e))
- return JSONResponse(content={"message": f"Model added successfully for type: {model_type}"})
+ return JSONResponse(
+ content={"message": f"Model added successfully for type: {model_type}"}
+ )
async def list_model_registrations(
self, model_type: str, detailed: bool = Query(False)
diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py
index ade4830035..33b728653c 100644
--- a/xinference/core/supervisor.py
+++ b/xinference/core/supervisor.py
@@ -942,8 +942,12 @@ async def add_model(self, model_type: str, model_json: Dict[str, Any]):
model_type: Type of model (LLM, embedding, image, etc.)
model_json: JSON configuration for the model
"""
- logger.info(f"[DEBUG] Supervisor add_model called with model_type: {model_type}")
- logger.info(f"[DEBUG] Supervisor add_model received JSON with keys: {list(model_json.keys())}")
+ logger.info(
+ f"[DEBUG] Supervisor add_model called with model_type: {model_type}"
+ )
+ logger.info(
+ f"[DEBUG] Supervisor add_model received JSON with keys: {list(model_json.keys())}"
+ )
logger.info(f"[DEBUG] JSON content: {json.dumps(model_json, indent=2)}")
# Validate model type
@@ -984,6 +988,7 @@ async def add_model(self, model_type: str, model_json: Dict[str, Any]):
# Validate model name format
from ..model.utils import is_valid_model_name
+
model_name = model_json["model_name"]
logger.info(f"[DEBUG] Validating model name: {model_name}")
@@ -1004,16 +1009,22 @@ async def add_model(self, model_type: str, model_json: Dict[str, Any]):
# Parse the JSON into the appropriate model spec
try:
- logger.info(f"[DEBUG] Attempting to parse converted JSON with {model_spec_cls}")
+ logger.info(
+ f"[DEBUG] Attempting to parse converted JSON with {model_spec_cls}"
+ )
model_spec = model_spec_cls.parse_obj(converted_model_json)
- logger.info(f"[DEBUG] JSON parsing successful, model_spec created: {model_spec}")
+ logger.info(
+ f"[DEBUG] JSON parsing successful, model_spec created: {model_spec}"
+ )
except Exception as e:
logger.error(f"[DEBUG] JSON parsing failed: {e}", exc_info=True)
raise ValueError(f"Invalid model JSON format: {str(e)}")
# Check if model already exists
try:
- logger.info(f"[DEBUG] Checking if model '{model_spec.model_name}' already exists")
+ logger.info(
+ f"[DEBUG] Checking if model '{model_spec.model_name}' already exists"
+ )
existing_model = await self.get_model_registration(
model_type, model_spec.model_name
)
@@ -1038,7 +1049,10 @@ async def add_model(self, model_type: str, model_json: Dict[str, Any]):
logger.error(f"[DEBUG] ValueError during model existence check: {e}")
raise e
except Exception as ex:
- logger.error(f"[DEBUG] Unexpected error checking model registration for '{model_spec.model_name}': {ex}", exc_info=True)
+ logger.error(
+ f"[DEBUG] Unexpected error checking model registration for '{model_spec.model_name}': {ex}",
+ exc_info=True,
+ )
raise ValueError(f"Failed to validate model registration: {str(ex)}")
# Register the model (persist=True for adding models)
@@ -1066,12 +1080,16 @@ async def add_model(self, model_type: str, model_json: Dict[str, Any]):
if not is_local:
logger.info(f"[DEBUG] Syncing model to workers")
+ # Convert back to JSON string for sync compatibility
+ model_json_str = json.dumps(converted_model_json)
await self._sync_register_model(
- model_type, converted_model_json, True, model_spec.model_name
+ model_type, model_json_str, True, model_spec.model_name
)
logger.info(f"[DEBUG] Model synced to workers successfully")
- logger.info(f"Successfully added model '{model_spec.model_name}' (type: {model_type})")
+ logger.info(
+ f"Successfully added model '{model_spec.model_name}' (type: {model_type})"
+ )
logger.info(f"[DEBUG] add_model process completed successfully")
except ValueError as e:
@@ -1080,14 +1098,19 @@ async def add_model(self, model_type: str, model_json: Dict[str, Any]):
raise e
except Exception as e:
# Unexpected errors - attempt cleanup
- logger.error(f"[DEBUG] Unexpected error during model registration: {e}", exc_info=True)
+ logger.error(
+ f"[DEBUG] Unexpected error during model registration: {e}",
+ exc_info=True,
+ )
try:
logger.info(f"[DEBUG] Attempting cleanup of failed registration")
unregister_fn(model_spec.model_name, raise_error=False)
logger.info(f"[DEBUG] Cleanup completed successfully")
except Exception as cleanup_error:
logger.warning(f"[DEBUG] Cleanup failed: {cleanup_error}")
- raise ValueError(f"Failed to register model '{model_spec.model_name}': {str(e)}")
+ raise ValueError(
+ f"Failed to register model '{model_spec.model_name}': {str(e)}"
+ )
def _convert_model_json_format(self, model_json: Dict[str, Any]) -> Dict[str, Any]:
"""
@@ -1106,10 +1129,14 @@ def _convert_model_json_format(self, model_json: Dict[str, Any]) -> Dict[str, An
break
if not needs_conversion:
- logger.info(f"[DEBUG] No conversion needed, JSON is already in expected format")
+ logger.info(
+ f"[DEBUG] No conversion needed, JSON is already in expected format"
+ )
return model_json
- logger.info(f"[DEBUG] Converting model_src nested structure to flattened format")
+ logger.info(
+ f"[DEBUG] Converting model_src nested structure to flattened format"
+ )
converted = model_json.copy()
converted_specs = []
@@ -1135,7 +1162,9 @@ def _convert_model_json_format(self, model_json: Dict[str, Any]) -> Dict[str, An
hf_info = model_src["huggingface"]
quantizations = hf_info.get("quantizations", ["none"])
- logger.info(f"[DEBUG] Found {len(quantizations)} quantizations for {model_format}")
+ logger.info(
+ f"[DEBUG] Found {len(quantizations)} quantizations for {model_format}"
+ )
# Create separate specs for each quantization
for quant in quantizations:
@@ -1157,11 +1186,15 @@ def _convert_model_json_format(self, model_json: Dict[str, Any]) -> Dict[str, An
if "model_id" in hf_info:
converted_spec["model_id"] = hf_info["model_id"]
if "model_file_name_template" in hf_info:
- converted_spec["model_file_name_template"] = hf_info["model_file_name_template"]
+ converted_spec["model_file_name_template"] = hf_info[
+ "model_file_name_template"
+ ]
else:
# Default template
model_name = model_json["model_name"]
- converted_spec["model_file_name_template"] = f"{model_name}-{{quantization}}.gguf"
+ converted_spec["model_file_name_template"] = (
+ f"{model_name}-{{quantization}}.gguf"
+ )
elif model_format in ["pytorch", "mlx"]:
if "model_id" in hf_info:
converted_spec["model_id"] = hf_info["model_id"]
@@ -1193,7 +1226,9 @@ def _convert_model_json_format(self, model_json: Dict[str, Any]) -> Dict[str, An
else:
# Unknown model source, skip or handle as error
- logger.warning(f"[DEBUG] Unknown model source in spec: {list(model_src.keys())}")
+ logger.warning(
+ f"[DEBUG] Unknown model source in spec: {list(model_src.keys())}"
+ )
# Keep original spec but add required fields
converted_spec = spec.copy()
if "quantization" not in converted_spec:
@@ -1201,7 +1236,9 @@ def _convert_model_json_format(self, model_json: Dict[str, Any]) -> Dict[str, An
converted_specs.append(converted_spec)
converted["model_specs"] = converted_specs
- logger.info(f"[DEBUG] Conversion completed: {len(model_json['model_specs'])} -> {len(converted_specs)} specs")
+ logger.info(
+ f"[DEBUG] Conversion completed: {len(model_json['model_specs'])} -> {len(converted_specs)} specs"
+ )
return converted
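To make the conversion concrete: one spec carrying a nested model_src with two
quantizations expands into one flattened spec per quantization, and a GGUF spec
without model_file_name_template falls back to a template derived from the
model name. A sketch with illustrative values, assuming model_name "my-llm":

    # Input spec (model-hub format, nested model_src):
    spec_in = {
        "model_format": "ggufv2",
        "model_size_in_billions": 7,
        "model_src": {
            "huggingface": {
                "model_id": "org/my-llm-GGUF",
                "quantizations": ["Q4_K_M", "Q8_0"],
            }
        },
    }

    # _convert_model_json_format emits one flattened spec per quantization:
    specs_out = [
        {
            "model_format": "ggufv2",
            "model_size_in_billions": 7,
            "quantization": "Q4_K_M",
            "model_hub": "huggingface",
            "model_id": "org/my-llm-GGUF",
            "model_file_name_template": "my-llm-{quantization}.gguf",
        },
        {
            "model_format": "ggufv2",
            "model_size_in_billions": 7,
            "quantization": "Q8_0",
            "model_hub": "huggingface",
            "model_id": "org/my-llm-GGUF",
            "model_file_name_template": "my-llm-{quantization}.gguf",
        },
    ]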
From 79ad0d02e54136e9aac898f22056878607296d12 Mon Sep 17 00:00:00 2001
From: OliverBryant <2713999266@qq.com>
Date: Wed, 22 Oct 2025 18:23:06 +0800
Subject: [PATCH 07/25] remove model_specs verification
---
xinference/api/restful_api.py | 21 +------
xinference/core/supervisor.py | 111 ++++++----------------------------
2 files changed, 21 insertions(+), 111 deletions(-)
diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py
index 57a623ed5d..97534aaaa9 100644
--- a/xinference/api/restful_api.py
+++ b/xinference/api/restful_api.py
@@ -3142,37 +3142,20 @@ async def add_model(self, request: Request) -> JSONResponse:
try:
# Parse request
raw_json = await request.json()
- logger.info(
- f"[DEBUG] add_model API received raw JSON: {json.dumps(raw_json, indent=2)}"
- )
body = AddModelRequest.parse_obj(raw_json)
model_type = body.model_type
model_json = body.model_json
- logger.info(f"[DEBUG] Parsed request - model_type: {model_type}")
- logger.info(
- f"[DEBUG] Parsed request - model_json keys: {list(model_json.keys())}"
- )
- logger.info(
- f"[DEBUG] model_name from JSON: {model_json.get('model_name', 'NOT_FOUND')}"
- )
-
# Call supervisor
supervisor_ref = await self._get_supervisor_ref()
- logger.info(f"[DEBUG] Got supervisor ref: {supervisor_ref}")
-
await supervisor_ref.add_model(model_type, model_json)
- logger.info(f"[DEBUG] Supervisor add_model completed successfully")
-
except ValueError as re:
- logger.error(f"[DEBUG] ValueError in add_model API: {re}", exc_info=True)
+ logger.error(f"ValueError in add_model API: {re}", exc_info=True)
raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
- logger.error(
- f"[DEBUG] Unexpected error in add_model API: {e}", exc_info=True
- )
+ logger.error(f"Unexpected error in add_model API: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
return JSONResponse(
diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py
index 33b728653c..098c29a385 100644
--- a/xinference/core/supervisor.py
+++ b/xinference/core/supervisor.py
@@ -942,27 +942,15 @@ async def add_model(self, model_type: str, model_json: Dict[str, Any]):
model_type: Type of model (LLM, embedding, image, etc.)
model_json: JSON configuration for the model
"""
- logger.info(
- f"[DEBUG] Supervisor add_model called with model_type: {model_type}"
- )
- logger.info(
- f"[DEBUG] Supervisor add_model received JSON with keys: {list(model_json.keys())}"
- )
- logger.info(f"[DEBUG] JSON content: {json.dumps(model_json, indent=2)}")
-
# Validate model type
supported_types = list(self._custom_register_type_to_cls.keys())
- logger.info(f"[DEBUG] Supported model types: {supported_types}")
if model_type not in self._custom_register_type_to_cls:
- logger.error(f"[DEBUG] Unsupported model type: {model_type}")
raise ValueError(
f"Unsupported model type '{model_type}'. "
f"Supported types are: {', '.join(supported_types)}"
)
- logger.info(f"[DEBUG] Model type validation passed for: {model_type}")
-
# Get the appropriate model class and register function
(
model_spec_cls,
@@ -971,143 +959,85 @@ async def add_model(self, model_type: str, model_json: Dict[str, Any]):
generate_fn,
) = self._custom_register_type_to_cls[model_type]
- logger.info(f"[DEBUG] Got model spec class: {model_spec_cls}")
- logger.info(f"[DEBUG] Got register function: {register_fn}")
- logger.info(f"[DEBUG] Got unregister function: {unregister_fn}")
- logger.info(f"[DEBUG] Got generate function: {generate_fn}")
-
- # Validate required fields
- required_fields = ["model_name", "model_specs"]
- logger.info(f"[DEBUG] Checking required fields: {required_fields}")
-
+ # Validate required fields (only model_name is required)
+ required_fields = ["model_name"]
for field in required_fields:
if field not in model_json:
- logger.error(f"[DEBUG] Missing required field: {field}")
raise ValueError(f"Missing required field: {field}")
- logger.info(f"[DEBUG] Field {field} found: {type(model_json[field])}")
-
# Validate model name format
from ..model.utils import is_valid_model_name
model_name = model_json["model_name"]
- logger.info(f"[DEBUG] Validating model name: {model_name}")
if not is_valid_model_name(model_name):
- logger.error(f"[DEBUG] Invalid model name format: {model_name}")
raise ValueError(f"Invalid model name format: {model_name}")
- logger.info(f"[DEBUG] Model name validation passed")
-
# Convert model hub JSON format to Xinference expected format
try:
- logger.info(f"[DEBUG] Converting model JSON format if needed...")
converted_model_json = self._convert_model_json_format(model_json)
- logger.info(f"[DEBUG] JSON conversion completed successfully")
except Exception as e:
- logger.error(f"[DEBUG] JSON conversion failed: {e}", exc_info=True)
raise ValueError(f"Failed to convert model JSON format: {str(e)}")
# Parse the JSON into the appropriate model spec
try:
- logger.info(
- f"[DEBUG] Attempting to parse converted JSON with {model_spec_cls}"
- )
model_spec = model_spec_cls.parse_obj(converted_model_json)
- logger.info(
- f"[DEBUG] JSON parsing successful, model_spec created: {model_spec}"
- )
except Exception as e:
- logger.error(f"[DEBUG] JSON parsing failed: {e}", exc_info=True)
raise ValueError(f"Invalid model JSON format: {str(e)}")
# Check if model already exists
try:
- logger.info(
- f"[DEBUG] Checking if model '{model_spec.model_name}' already exists"
- )
existing_model = await self.get_model_registration(
model_type, model_spec.model_name
)
- logger.info(f"[DEBUG] Existing model check result: {existing_model}")
if existing_model is not None:
- logger.error(f"[DEBUG] Model already exists: {model_spec.model_name}")
raise ValueError(
f"Model '{model_spec.model_name}' already exists for type '{model_type}'. "
f"Please choose a different model name or remove the existing model first."
)
- logger.info(f"[DEBUG] Model does not exist, can proceed with registration")
-
except ValueError as e:
if "not found" in str(e):
# Model doesn't exist, we can proceed
- logger.info(f"[DEBUG] Model not found (expected): {e}")
pass
else:
# Re-raise validation errors
- logger.error(f"[DEBUG] ValueError during model existence check: {e}")
raise e
except Exception as ex:
- logger.error(
- f"[DEBUG] Unexpected error checking model registration for '{model_spec.model_name}': {ex}",
- exc_info=True,
- )
raise ValueError(f"Failed to validate model registration: {str(ex)}")
# Register the model (persist=True for adding models)
try:
- logger.info(f"[DEBUG] Starting model registration process")
- logger.info(f"[DEBUG] Calling register_fn with persist=True")
-
register_fn(model_spec, persist=True)
- logger.info(f"[DEBUG] register_fn completed successfully")
# Record model version
- logger.info(f"[DEBUG] Generating version info")
version_info = generate_fn(model_spec)
- logger.info(f"[DEBUG] Version info generated: {version_info}")
-
- logger.info(f"[DEBUG] Recording model version to cache tracker")
await self._cache_tracker_ref.record_model_version(
version_info, self.address
)
- logger.info(f"[DEBUG] Model version recorded successfully")
# Sync to workers if not local deployment
is_local = self.is_local_deployment()
- logger.info(f"[DEBUG] Is local deployment: {is_local}")
-
if not is_local:
- logger.info(f"[DEBUG] Syncing model to workers")
# Convert back to JSON string for sync compatibility
model_json_str = json.dumps(converted_model_json)
await self._sync_register_model(
model_type, model_json_str, True, model_spec.model_name
)
- logger.info(f"[DEBUG] Model synced to workers successfully")
logger.info(
f"Successfully added model '{model_spec.model_name}' (type: {model_type})"
)
- logger.info(f"[DEBUG] add_model process completed successfully")
except ValueError as e:
# Validation errors - don't need cleanup as model wasn't registered
- logger.error(f"[DEBUG] Validation error during model registration: {e}")
raise e
except Exception as e:
# Unexpected errors - attempt cleanup
- logger.error(
- f"[DEBUG] Unexpected error during model registration: {e}",
- exc_info=True,
- )
try:
- logger.info(f"[DEBUG] Attempting cleanup of failed registration")
unregister_fn(model_spec.model_name, raise_error=False)
- logger.info(f"[DEBUG] Cleanup completed successfully")
except Exception as cleanup_error:
- logger.warning(f"[DEBUG] Cleanup failed: {cleanup_error}")
+ logger.warning(f"Cleanup failed: {cleanup_error}")
raise ValueError(
f"Failed to register model '{model_spec.model_name}': {str(e)}"
)
@@ -1118,8 +1048,22 @@ def _convert_model_json_format(self, model_json: Dict[str, Any]) -> Dict[str, An
The input format uses a nested 'model_src' structure, but Xinference expects
flattened fields at the spec level.
+
+ Also handles cases where the model_specs field is missing by providing a default.
"""
- logger.info(f"[DEBUG] Starting JSON format conversion")
+ # If model_specs is missing, provide a default minimal spec
+ if "model_specs" not in model_json or not model_json["model_specs"]:
+ # Create a minimal default spec
+ return {
+ **model_json,
+ "model_specs": [
+ {
+ "model_format": "pytorch",
+ "model_size_in_billions": None,
+ "quantization": "none",
+ }
+ ],
+ }
# Check if conversion is needed (detect model_src structure)
needs_conversion = False
@@ -1129,15 +1073,8 @@ def _convert_model_json_format(self, model_json: Dict[str, Any]) -> Dict[str, An
break
if not needs_conversion:
- logger.info(
- f"[DEBUG] No conversion needed, JSON is already in expected format"
- )
return model_json
- logger.info(
- f"[DEBUG] Converting model_src nested structure to flattened format"
- )
-
converted = model_json.copy()
converted_specs = []
@@ -1145,8 +1082,6 @@ def _convert_model_json_format(self, model_json: Dict[str, Any]) -> Dict[str, An
model_format = spec["model_format"]
model_size = spec["model_size_in_billions"]
- logger.info(f"[DEBUG] Processing spec: {model_format} - {model_size}B")
-
if "model_src" not in spec:
# No model_src, keep spec as is but ensure required fields
converted_spec = spec.copy()
@@ -1162,10 +1097,6 @@ def _convert_model_json_format(self, model_json: Dict[str, Any]) -> Dict[str, An
hf_info = model_src["huggingface"]
quantizations = hf_info.get("quantizations", ["none"])
- logger.info(
- f"[DEBUG] Found {len(quantizations)} quantizations for {model_format}"
- )
-
# Create separate specs for each quantization
for quant in quantizations:
converted_spec = {
@@ -1202,7 +1133,6 @@ def _convert_model_json_format(self, model_json: Dict[str, Any]) -> Dict[str, An
converted_spec["model_revision"] = hf_info["model_revision"]
converted_specs.append(converted_spec)
- logger.debug(f"[DEBUG] Created spec: {model_format} - {quant}")
elif "modelscope" in model_src:
# Handle ModelScope similarly
@@ -1227,7 +1157,7 @@ def _convert_model_json_format(self, model_json: Dict[str, Any]) -> Dict[str, An
else:
# Unknown model source, skip or handle as error
logger.warning(
- f"[DEBUG] Unknown model source in spec: {list(model_src.keys())}"
+ f"Unknown model source in spec: {list(model_src.keys())}"
)
# Keep original spec but add required fields
converted_spec = spec.copy()
@@ -1236,9 +1166,6 @@ def _convert_model_json_format(self, model_json: Dict[str, Any]) -> Dict[str, An
converted_specs.append(converted_spec)
converted["model_specs"] = converted_specs
- logger.info(
- f"[DEBUG] Conversion completed: {len(model_json['model_specs'])} -> {len(converted_specs)} specs"
- )
return converted
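A sketch of the default introduced here: a model JSON that omits model_specs
entirely is now accepted, and the converter injects a minimal pytorch spec
(the model name is illustrative):

    # _convert_model_json_format({"model_name": "qwen3"}) now returns:
    {
        "model_name": "qwen3",
        "model_specs": [
            {
                "model_format": "pytorch",
                "model_size_in_billions": None,
                "quantization": "none",
            }
        ],
    }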
From 424ec5ef85aff4ae889728d793157613f63adfca Mon Sep 17 00:00:00 2001
From: OliverBryant <2713999266@qq.com>
Date: Thu, 23 Oct 2025 09:47:40 +0800
Subject: [PATCH 08/25] make model_format and model_size_in_billions optional in model specs
---
xinference/core/supervisor.py | 10 +++++++---
1 file changed, 7 insertions(+), 3 deletions(-)
diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py
index 098c29a385..6e5a10368f 100644
--- a/xinference/core/supervisor.py
+++ b/xinference/core/supervisor.py
@@ -1079,14 +1079,16 @@ def _convert_model_json_format(self, model_json: Dict[str, Any]) -> Dict[str, An
converted_specs = []
for spec in model_json["model_specs"]:
- model_format = spec["model_format"]
- model_size = spec["model_size_in_billions"]
+ model_format = spec.get("model_format", "pytorch")
+ model_size = spec.get("model_size_in_billions")
if "model_src" not in spec:
# No model_src, keep spec as is but ensure required fields
converted_spec = spec.copy()
if "quantization" not in converted_spec:
- converted_spec["quantization"] = "none" # Default
+ converted_spec["quantization"] = "none"
+ if "model_format" not in converted_spec:
+ converted_spec["model_format"] = "pytorch"
converted_specs.append(converted_spec)
continue
@@ -1163,6 +1165,8 @@ def _convert_model_json_format(self, model_json: Dict[str, Any]) -> Dict[str, An
converted_spec = spec.copy()
if "quantization" not in converted_spec:
converted_spec["quantization"] = "none"
+ if "model_format" not in converted_spec:
+ converted_spec["model_format"] = "pytorch"
converted_specs.append(converted_spec)
converted["model_specs"] = converted_specs
From 05a7b06a5f91e0d09f2bbcea5eade74c9a94418a Mon Sep 17 00:00:00 2001
From: yiboyasss <3359595624@qq.com>
Date: Thu, 23 Oct 2025 14:52:05 +0800
Subject: [PATCH 09/25] fix: frontend add-model dialog
---
xinference/ui/web/ui/src/locales/en.json | 27 +-
xinference/ui/web/ui/src/locales/ja.json | 29 +-
xinference/ui/web/ui/src/locales/ko.json | 25 +-
xinference/ui/web/ui/src/locales/zh.json | 25 +-
.../launch_model/components/addModelDialog.js | 300 +++++++-----------
.../web/ui/src/scenes/launch_model/index.js | 63 +++-
6 files changed, 199 insertions(+), 270 deletions(-)
diff --git a/xinference/ui/web/ui/src/locales/en.json b/xinference/ui/web/ui/src/locales/en.json
index b39f1e2939..437fb45a5c 100644
--- a/xinference/ui/web/ui/src/locales/en.json
+++ b/xinference/ui/web/ui/src/locales/en.json
@@ -127,26 +127,19 @@
"mustBeUnique": "{{key}} must be unique",
"addModel": "Add Model",
"addModelDialog": {
- "introPrefix": "To add a model, please use",
- "platformLinkText": "Model Management Platform",
- "introSuffix": " and paste the model's URL",
- "example": "Example: The URL for {{modelName}} on the platform is {{modelUrl}}",
- "urlLabel": "URL"
- },
- "loginDialog": {
- "title": "No permission to download this model. Please log in and try again.",
- "usernameOrEmail": "Username or Email",
- "password": "Password",
- "login": "Login"
+ "introPrefix": "To add a model, please go to the",
+ "platformLinkText": "Xinference Model Hub",
+ "introSuffix": "and fill in the corresponding model name.",
+ "modelName": "Model Name",
+ "modelName.tip": "Please enter the model name",
+ "placeholder": "e.g. qwen3 (case-sensitive)"
},
+ "update": "Update",
"error": {
- "cannotExtractModelId": "Unable to extract model_id from URL. Please check your input.",
- "downloadFailed": "Download failed: {{status}} {{text}}",
+ "name_not_matched": "No exact model name match found (case-sensitive)",
+ "downloadFailed": "Download failed",
"requestFailed": "Request failed",
- "loginFailedText": "Login failed: {{status}} {{text}}",
- "noTokenAfterLogin": "Login succeeded but no token was returned",
- "modelPrivate": "This model is private and requires download permission.",
- "noPermissionAfterLogin": "The logged-in account does not have permission to download this model. Please contact the administrator or use a different account."
+ "json_parse_error": "Failed to parse JSON"
}
},
diff --git a/xinference/ui/web/ui/src/locales/ja.json b/xinference/ui/web/ui/src/locales/ja.json
index 2dd70bc1ab..e4075f9e1d 100644
--- a/xinference/ui/web/ui/src/locales/ja.json
+++ b/xinference/ui/web/ui/src/locales/ja.json
@@ -127,26 +127,19 @@
"mustBeUnique": "{{key}} は一意でなければなりません",
"addModel": "モデルを追加",
"addModelDialog": {
- "introPrefix": "モデルを追加するには",
- "platformLinkText": "モデル管理プラットフォーム",
- "introSuffix": "に基づき、対応するURLを入力してください",
- "example": "例:{{modelName}} のモデル管理プラットフォーム上のURLは {{modelUrl}} です",
- "urlLabel": "URL"
- },
- "loginDialog": {
- "title": "このモデルをダウンロードする権限がありません。ログイン後に再度お試しください",
- "usernameOrEmail": "ユーザー名またはメールアドレス",
- "password": "パスワード",
- "login": "ログイン"
+ "introPrefix": "モデルを追加するには、",
+ "platformLinkText": "Xinference モデルセンター",
+ "introSuffix": "で対応するモデル名を入力してください。",
+ "modelName": "モデル名",
+ "modelName.tip": "モデル名を入力してください",
+ "placeholder": "例:qwen3(大文字と小文字を区別します)"
},
+ "update": "更新",
"error": {
- "cannotExtractModelId": "URLから model_id を抽出できません。入力内容を確認してください",
- "downloadFailed": "ダウンロード失敗: {{status}} {{text}}",
- "requestFailed": "リクエスト失敗",
- "loginFailedText": "ログイン失敗: {{status}} {{text}}",
- "noTokenAfterLogin": "ログインは成功しましたが、トークンを取得できませんでした",
- "modelPrivate": "このモデルは非公開であり、ダウンロード権限が必要です。",
- "noPermissionAfterLogin": "このアカウントにはモデルをダウンロードする権限がありません。管理者に連絡するか、別のアカウントを使用してください。"
+ "name_not_matched": "完全に一致するモデル名が見つかりません(大文字と小文字を区別します)",
+ "downloadFailed": "ダウンロードに失敗しました",
+ "requestFailed": "リクエストに失敗しました",
+ "json_parse_error": "JSON の解析に失敗しました"
}
},
diff --git a/xinference/ui/web/ui/src/locales/ko.json b/xinference/ui/web/ui/src/locales/ko.json
index f6eeb9b51d..36fd0cd0c2 100644
--- a/xinference/ui/web/ui/src/locales/ko.json
+++ b/xinference/ui/web/ui/src/locales/ko.json
@@ -128,25 +128,18 @@
"addModel": "모델 추가",
"addModelDialog": {
"introPrefix": "모델을 추가하려면",
- "platformLinkText": "모델 관리 플랫폼",
- "introSuffix": "을(를) 기반으로 해당 URL을 입력하세요",
- "example": "예: {{modelName}}의 모델 관리 플랫폼 URL은 {{modelUrl}} 입니다",
- "urlLabel": "URL"
- },
- "loginDialog": {
- "title": "이 모델을 다운로드할 권한이 없습니다. 로그인 후 다시 시도하세요",
- "usernameOrEmail": "사용자 이름 또는 이메일",
- "password": "비밀번호",
- "login": "로그인"
+ "platformLinkText": "Xinference 모델 센터",
+ "introSuffix": "에서 해당 모델 이름을 입력하세요.",
+ "modelName": "모델 이름",
+ "modelName.tip": "모델 이름을 입력하세요",
+ "placeholder": "예: qwen3 (대소문자를 구분합니다)"
},
+ "update": "업데이트",
"error": {
- "cannotExtractModelId": "URL에서 model_id를 추출할 수 없습니다. 입력을 확인하세요",
- "downloadFailed": "다운로드 실패: {{status}} {{text}}",
+ "name_not_matched": "완전히 일치하는 모델 이름을 찾을 수 없습니다(대소문자 구분)",
+ "downloadFailed": "다운로드 실패",
"requestFailed": "요청 실패",
- "loginFailedText": "로그인 실패: {{status}} {{text}}",
- "noTokenAfterLogin": "로그인은 성공했지만 토큰을 가져오지 못했습니다",
- "modelPrivate": "이 모델은 비공개이며 다운로드 권한이 필요합니다.",
- "noPermissionAfterLogin": "이 계정에는 해당 모델을 다운로드할 권한이 없습니다. 관리자에게 문의하거나 다른 계정을 사용하세요."
+ "json_parse_error": "JSON 구문 분석에 실패했습니다"
}
},
diff --git a/xinference/ui/web/ui/src/locales/zh.json b/xinference/ui/web/ui/src/locales/zh.json
index 066781855a..3a0a1d7a19 100644
--- a/xinference/ui/web/ui/src/locales/zh.json
+++ b/xinference/ui/web/ui/src/locales/zh.json
@@ -128,25 +128,18 @@
"addModel": "添加模型",
"addModelDialog": {
"introPrefix": "添加模型需基于",
- "platformLinkText": "模型管理平台",
- "introSuffix": ",填写模型对应的 URL",
- "example": "例:{{modelName}}在模型管理平台上对应的 URL 如下 {{modelUrl}}",
- "urlLabel": "URL"
- },
- "loginDialog": {
- "title": "暂无权限下载该模型,登录后重新尝试下载",
- "usernameOrEmail": "用户名或邮箱",
- "password": "密码",
- "login": "登录"
+ "platformLinkText": "Xinference 模型中心",
+ "introSuffix": ",填写模型对应的名称",
+ "modelName": "模型名称",
+ "modelName.tip": "请输入模型名称",
+ "placeholder": "例如:qwen3(需大小写完全匹配)"
},
+ "update": "更新",
"error": {
- "cannotExtractModelId": "无法从 URL 中提取 model_id,请检查输入",
- "downloadFailed": "下载失败: {{status}} {{text}}",
+ "name_not_matched": "未找到完全匹配的模型名称(需大小写一致)",
+ "downloadFailed": "下载失败",
"requestFailed": "请求失败",
- "loginFailedText": "登录失败: {{status}} {{text}}",
- "noTokenAfterLogin": "登录成功但未获取到 token",
- "modelPrivate": "该模型为私有,需要具有下载权限。",
- "noPermissionAfterLogin": "该登录账户暂无权限下载该模型,请联系管理员或更换账户。"
+ "json_parse_error": "JSON 解析失败"
}
},
diff --git a/xinference/ui/web/ui/src/scenes/launch_model/components/addModelDialog.js b/xinference/ui/web/ui/src/scenes/launch_model/components/addModelDialog.js
index 5791d7e364..af38bcffba 100644
--- a/xinference/ui/web/ui/src/scenes/launch_model/components/addModelDialog.js
+++ b/xinference/ui/web/ui/src/scenes/launch_model/components/addModelDialog.js
@@ -6,181 +6,134 @@ import {
DialogTitle,
TextField,
} from '@mui/material'
-import React, { useEffect, useRef, useState } from 'react'
+import React, { useContext, useState } from 'react'
import { useTranslation } from 'react-i18next'
+import { ApiContext } from '../../../components/apiContext'
+
const API_BASE_URL = 'https://model.xinference.io'
-const AddModelDialog = ({ open, onClose }) => {
+function AddModelDialog({ open, onClose }) {
const { t } = useTranslation()
- const [url, setUrl] = useState('')
- const [loginOpen, setLoginOpen] = useState(false)
- const [pendingModelId, setPendingModelId] = useState(null)
+ const [modelName, setModelName] = useState('')
const [loading, setLoading] = useState(false)
- const [errorMsg, setErrorMsg] = useState('')
- const loginIframeRef = useRef(null)
-
- const handleClose = (type) => {
- setErrorMsg('')
+ const { endPoint, setErrorMsg } = useContext(ApiContext)
- const actions = {
- add: onClose,
- login: () => setLoginOpen(false),
+ const searchModelByName = async (name) => {
+ try {
+ const url = `${API_BASE_URL}/api/models?order=featured&query=${encodeURIComponent(
+ name
+ )}&page=1&pageSize=5`
+ const res = await fetch(url, { method: 'GET' })
+ const rawText = await res.text().catch(() => '')
+ if (!res.ok) {
+ setErrorMsg(rawText || `HTTP ${res.status}`)
+ return null
+ }
+ try {
+ const data = JSON.parse(rawText)
+ const items = data?.data || []
+ const exact = items.find((it) => it?.model_name === name)
+ if (!exact) {
+ setErrorMsg(t('launchModel.error.name_not_matched'))
+ return null
+ }
+ const id = exact?.id
+ const modelType = exact?.model_type
+ if (!id || !modelType) {
+ setErrorMsg(t('launchModel.error.downloadFailed'))
+ return null
+ }
+ return { id, modelType }
+ } catch {
+ setErrorMsg(rawText || t('launchModel.error.json_parse_error'))
+ return null
+ }
+ } catch (err) {
+ console.error(err)
+ setErrorMsg(err.message || t('launchModel.error.requestFailed'))
+ return null
}
-
- actions[type]?.()
}
- const extractModelId = (input) => {
+ const fetchModelJson = async (modelId) => {
try {
- const u = new URL(input)
- const m1 = u.pathname.match(/\/(\d+)(?:\/?$)/)
- if (m1 && m1[1]) return m1[1]
- const qp = u.searchParams.get('model_id')
- if (qp) return qp
- } catch (e) {
- const m2 = String(input).match(/(\d+)(?:\/?$)/)
- if (m2 && m2[1]) return m2[1]
+ const res = await fetch(
+ `${API_BASE_URL}/api/models/download?model_id=${encodeURIComponent(
+ modelId
+ )}`,
+ { method: 'GET' }
+ )
+ const rawText = await res.text().catch(() => '')
+ if (!res.ok) {
+ setErrorMsg(rawText || `HTTP ${res.status}`)
+ return null
+ }
+ try {
+ const data = JSON.parse(rawText)
+ return data
+ } catch {
+ setErrorMsg(rawText || t('launchModel.error.json_parse_error'))
+ return null
+ }
+ } catch (err) {
+ console.error(err)
+ setErrorMsg(err.message || t('launchModel.error.requestFailed'))
+ return null
}
- return null
}
- // Note: download reads the token from sessionStorage by default (an explicitly passed token takes precedence)
- // performDownload: once a token is available, call the endpoint directly and fetch the JSON
- const performDownload = async (
- modelId,
- tokenFromParam,
- fromLogin = false
- ) => {
- const endpoint = `${API_BASE_URL}/api/models/download?model_id=${encodeURIComponent(
- modelId
- )}`
- const effectiveToken =
- tokenFromParam ||
- sessionStorage.getItem('model_hub_token') ||
- localStorage.getItem('io_login_success')
- const headers = effectiveToken
- ? { Authorization: `Bearer ${effectiveToken}` }
- : {}
- setLoading(true)
- setErrorMsg('')
+ const addToLocal = async (modelType, modelJson) => {
try {
- const res = await fetch(endpoint, {
- method: 'GET',
- headers,
+ const res = await fetch(endPoint + '/v1/models/add', {
+ method: 'POST',
+ headers: { 'Content-Type': 'application/json' },
+ body: JSON.stringify({ model_type: modelType, model_json: modelJson }),
})
-
- if (res.status === 401) {
- const refreshToken = sessionStorage.getItem('model_hub_refresh_token')
- if (!refreshToken) {
- sessionStorage.removeItem('model_hub_token')
- setPendingModelId(modelId)
- setLoginOpen(true)
- return
- }
- try {
- const refreshRes = await fetch(`${API_BASE_URL}/api/users/refresh`, {
- method: 'POST',
- headers: { 'Content-Type': 'application/json' },
- body: JSON.stringify({ token: refreshToken }),
- })
- if (!refreshRes.ok) {
- throw new Error(`refresh failed: ${refreshRes.status}`)
- }
- const refreshData = await refreshRes.json().catch(() => ({}))
- const newToken = refreshData?.data?.accessToken
- if (newToken) {
- sessionStorage.setItem('model_hub_token', newToken)
- await performDownload(modelId, newToken, false)
- return
- } else {
- sessionStorage.removeItem('model_hub_token')
- setPendingModelId(modelId)
- setLoginOpen(true)
- return
- }
- } catch (e) {
- sessionStorage.removeItem('model_hub_token')
- setPendingModelId(modelId)
- setLoginOpen(true)
- return
- }
+ const rawText = await res.text().catch(() => '')
+ if (!res.ok) {
+ setErrorMsg(rawText || `HTTP ${res.status}`)
+ return
}
-
- if (res.status === 403) {
- let detailMsg = ''
- try {
- const body = await res.json()
- if (body?.error_code === 'MODEL_PRIVATE') {
- detailMsg = t('launchModel.error.modelPrivate')
- } else if (body?.message) {
- detailMsg = body.message
- }
- } catch {
- console.log('')
- }
- if (fromLogin) {
- setErrorMsg(
- detailMsg || t('launchModel.error.noPermissionAfterLogin')
- )
- return
- } else {
- setPendingModelId(modelId)
- setLoginOpen(true)
- return
- }
+ try {
+ const data = JSON.parse(rawText)
+ console.log('Local /v1/models/add response:', data)
+ } catch {
+ console.log('Local /v1/models/add raw response:', rawText)
}
-
- if (!res.ok) {
- const text = await res.text().catch(() => '')
- throw new Error(
- t('launchModel.error.downloadFailed', { status: res.status, text })
- )
+ } catch (error) {
+ console.error('Error:', error)
+ if (error?.response?.status !== 403) {
+ setErrorMsg(error.message)
}
- const data = await res.json()
- console.log('models/download response:', data)
- handleClose('add')
- } catch (err) {
- console.error(err)
- setErrorMsg(err.message || t('launchModel.error.requestFailed'))
- } finally {
- setLoading(false)
}
}
const handleFormSubmit = async (e) => {
e.preventDefault()
- const modelId = extractModelId(url?.trim())
- if (!modelId) {
- setErrorMsg(t('launchModel.error.cannotExtractModelId'))
+ const name = modelName?.trim()
+ if (!name) {
+ setErrorMsg(t('launchModel.addModelDialog.modelName.tip'))
return
}
- await performDownload(modelId)
- }
-
- useEffect(() => {
- const listener = (event) => {
- if (event.origin !== API_BASE_URL) return
- const { type, token, refresh_token } = event.data || {}
+ setLoading(true)
+ setErrorMsg('')
+ try {
+ const found = await searchModelByName(name)
+ if (!found) return
+ const { id, modelType } = found
- if (type === 'io_login_success' && token && refresh_token) {
- handleClose('login')
- sessionStorage.setItem('model_hub_token', token)
- sessionStorage.setItem('model_hub_refresh_token', refresh_token)
- if (pendingModelId) {
- void performDownload(pendingModelId, token, true)
- }
- }
- }
+ const modelJson = await fetchModelJson(id)
+ if (!modelJson) return
- window.addEventListener('message', listener)
- return () => {
- window.removeEventListener('message', listener)
+ await addToLocal(modelType, modelJson)
+ } finally {
+ setLoading(false)
}
- }, [pendingModelId])
+ }
return (
-