Skip to content
Open
155 changes: 155 additions & 0 deletions xinference/api/restful_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,15 @@ class RegisterModelRequest(BaseModel):
persist: bool


class AddModelRequest(BaseModel):
    # Request body for POST /v1/models/add (parsed in the add_model handler).
    # Kind of model being added; forwarded unchanged to the supervisor.
    model_type: str
    # Model definition as a JSON object; forwarded unchanged to the supervisor.
    # NOTE(review): the expected schema is not visible here — confirm against
    # the supervisor's add_model implementation.
    model_json: Dict[str, Any]


class UpdateModelRequest(BaseModel):
    # Request body for POST /v1/models/update_type (parsed in the
    # update_model_type handler).
    # Model type whose configurations should be updated; forwarded unchanged
    # to the supervisor.
    model_type: str


class BuildGradioInterfaceRequest(BaseModel):
model_type: str
model_name: str
Expand Down Expand Up @@ -900,6 +909,26 @@ async def internal_exception_handler(request: Request, exc: Exception):
else None
),
)
# Register POST /v1/models/add, served by self.add_model. When
# self.is_authenticated() is true the route is guarded by the auth service
# with the "models:add" scope; otherwise no dependencies are attached.
self._router.add_api_route(
    "/v1/models/add",
    self.add_model,
    methods=["POST"],
    dependencies=(
        [Security(self._auth_service, scopes=["models:add"])]
        if self.is_authenticated()
        else None
    ),
)
# Register POST /v1/models/update_type, served by self.update_model_type.
# When self.is_authenticated() is true the route is guarded by the auth
# service; otherwise no dependencies are attached.
# NOTE(review): this route reuses the "models:add" scope rather than a
# dedicated update scope — confirm that is intentional.
self._router.add_api_route(
    "/v1/models/update_type",
    self.update_model_type,
    methods=["POST"],
    dependencies=(
        [Security(self._auth_service, scopes=["models:add"])]
        if self.is_authenticated()
        else None
    ),
)
self._router.add_api_route(
"/v1/cache/models",
self.list_cached_models,
Expand Down Expand Up @@ -3123,25 +3152,151 @@ async def unregister_model(self, model_type: str, model_name: str) -> JSONRespon
raise HTTPException(status_code=500, detail=str(e))
return JSONResponse(content=None)

async def add_model(self, request: Request) -> JSONResponse:
    """Add a model definition of the given type through the supervisor.

    The request body must be a JSON object that validates against
    ``AddModelRequest`` (``model_type`` and ``model_json``).

    Args:
        request: Incoming HTTP request carrying the JSON payload.

    Returns:
        A ``JSONResponse`` whose ``message`` confirms the model type that
        was added.

    Raises:
        HTTPException: With status 400 when parsing/validating the body or
            the supervisor call raises ``ValueError``; with status 500 for
            any other exception.
    """
    try:
        # Request headers and the raw payload are deliberately NOT logged:
        # headers can carry credentials (Authorization, cookies, API keys)
        # and the payload can be arbitrarily large.
        body = AddModelRequest.parse_obj(await request.json())
        model_type = body.model_type
        model_json = body.model_json
        logger.debug(
            "Adding model: model_type=%s, model_json keys=%s",
            model_type,
            list(model_json),
        )

        supervisor_ref = await self._get_supervisor_ref()
        await supervisor_ref.add_model(model_type, model_json)
    except ValueError as ve:
        # exc_info=True already records the full traceback once.
        logger.error(ve, exc_info=True)
        raise HTTPException(status_code=400, detail=str(ve))
    except Exception as e:
        logger.error(e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))

    logger.info("Model added successfully for model_type: %s", model_type)
    return JSONResponse(
        content={"message": f"Model added successfully for type: {model_type}"}
    )

async def update_model_type(self, request: Request) -> JSONResponse:
    """Ask the supervisor to update the model configurations of one type.

    The request body must be a JSON object that validates against
    ``UpdateModelRequest`` (a single ``model_type`` field).

    Args:
        request: Incoming HTTP request carrying the JSON payload.

    Returns:
        A ``JSONResponse`` whose ``message`` confirms the model type that
        was updated.

    Raises:
        HTTPException: With status 400 when parsing/validating the body or
            the supervisor call raises ``ValueError``; with status 500 for
            any other exception.
    """
    try:
        body = AddModelRequest.parse_obj if False else None  # placeholder removed below
        body = UpdateModelRequest.parse_obj(await request.json())
        model_type = body.model_type
        # Diagnostic detail belongs at DEBUG level; lazy %-args avoid
        # formatting the message when that level is disabled.
        logger.debug("Updating model type: model_type=%s", model_type)

        supervisor_ref = await self._get_supervisor_ref()
        await supervisor_ref.update_model_type(model_type)
    except ValueError as ve:
        logger.error(ve, exc_info=True)
        raise HTTPException(status_code=400, detail=str(ve))
    except Exception as e:
        logger.error(e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))

    logger.info(
        "Model configurations updated successfully for model_type: %s",
        model_type,
    )
    return JSONResponse(
        content={
            "message": f"Model configurations updated successfully for type: {model_type}"
        }
    )

async def list_model_registrations(
    self, model_type: str, detailed: bool = Query(False)
) -> JSONResponse:
    """List the registrations of a model type, one entry per model name.

    Args:
        model_type: Model type to list; passed through to the supervisor.
        detailed: Passed through to the supervisor as ``detailed``.

    Returns:
        A ``JSONResponse`` containing the supervisor's items with later
        entries that repeat an already-seen ``model_name`` removed; the
        original order of first occurrences is kept.

    Raises:
        HTTPException: With status 400 when the supervisor call raises
            ``ValueError``; with status 500 for any other exception.
    """
    try:
        supervisor_ref = await self._get_supervisor_ref()
        data = await supervisor_ref.list_model_registrations(
            model_type, detailed=detailed
        )

        # Remove duplicate model names, keeping the first occurrence.
        seen_names = set()
        final_data = []
        for item in data:
            name = item["model_name"]
            if name not in seen_names:
                seen_names.add(name)
                final_data.append(item)

        # One summary line at DEBUG level instead of one INFO line per item.
        logger.debug(
            "list_model_registrations(model_type=%s, detailed=%s): "
            "%d items after deduplication",
            model_type,
            detailed,
            len(final_data),
        )
        return JSONResponse(content=final_data)
    except ValueError as ve:
        # A single error record with traceback per failure.
        logger.error(ve, exc_info=True)
        raise HTTPException(status_code=400, detail=str(ve))
    except Exception as e:
        logger.error(e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))

Expand Down
Loading
Loading