diff --git a/tests/test_model_workflow.py b/tests/test_model_workflow.py
index 367cfdfb..fe07bbd0 100644
--- a/tests/test_model_workflow.py
+++ b/tests/test_model_workflow.py
@@ -12,11 +12,17 @@
 from vllm_mlx.model_workflow import (
     CONVERSION_MANIFEST_NAME,
     MODEL_MANIFEST_NAME,
+    QUALIFICATION_REQUEST_NAME,
+    REGISTRATION_MANIFEST_NAME,
     AcquisitionOptions,
     ConversionOptions,
+    QualificationOptions,
+    RegistrationOptions,
     acquire_model,
     convert_model,
     inspect_model,
+    qualify_model,
+    register_model,
 )
 
@@ -249,3 +255,179 @@ def test_inspect_gptq_model_is_not_detected_as_mlx(tmp_path):
 
     assert payload["mlx"]["looks_like_mlx_artifact"] is False
     assert payload["mlx"]["needs_conversion"] is True
+
+
+def test_register_model_writes_manifest_from_artifact(tmp_path):
+    artifact = tmp_path / "mlx-model"
+    artifact.mkdir()
+    (artifact / "config.json").write_text(
+        json.dumps({"model_type": "qwen3", "quantization": {"bits": 4}})
+    )
+    (artifact / MODEL_MANIFEST_NAME).write_text(
+        json.dumps({"kind": "vllm-mlx-model-artifact", "model_id": "org/model"})
+    )
+
+    payload = register_model(
+        RegistrationOptions(
+            artifact_path=str(artifact),
+            model_id="qwen-test",
+            served_model_name="qwen-test-served",
+            preset_alias="fast-qwen",
+            mllm=True,
+            tool_call_parser="qwen3_coder",
+            reasoning_parser="qwen3",
+            default_temperature=0.6,
+            default_top_p=0.95,
+            default_top_k=20,
+            default_min_p=0.0,
+            default_presence_penalty=0.0,
+            default_repetition_penalty=1.0,
+            chat_template_kwargs={"enable_thinking": True},
+            feature_flags=["prefix_cache"],
+        )
+    )
+
+    assert payload["kind"] == "vllm-mlx-model-registration"
+    assert payload["model_id"] == "qwen-test"
+    assert payload["served_model_name"] == "qwen-test-served"
+    assert payload["preset_alias"] == "fast-qwen"
+    assert payload["mllm"] is True
+    assert payload["production_ready"] is False
+    assert payload["qualification_required"] is True
+    assert payload["serving_defaults"]["top_k"] == 20
+    assert payload["serving_defaults"]["chat_template_kwargs"] == {
+        "enable_thinking": True
+    }
+    assert payload["parser_policy"]["reasoning_parser"] == "qwen3"
+    assert payload["source_manifests"]["acquisition"]["payload"]["model_id"] == (
+        "org/model"
+    )
+    assert (artifact / REGISTRATION_MANIFEST_NAME).exists()
+
+
+def test_register_model_minimal_defaults(tmp_path):
+    """register_model with only artifact_path derives model_id from directory name."""
+    artifact = tmp_path / "my-cool-model"
+    artifact.mkdir()
+    (artifact / "config.json").write_text(
+        json.dumps({"model_type": "llama", "quantization": {"bits": 4}})
+    )
+
+    payload = register_model(RegistrationOptions(artifact_path=str(artifact)))
+
+    assert payload["model_id"] == "my-cool-model"
+    assert payload["served_model_name"] == "my-cool-model"
+    assert payload["preset_alias"] is None
+    assert payload["mllm"] is None
+    assert payload["serving_defaults"] == {}
+    assert payload["parser_policy"] == {}
+    assert payload["feature_flags"] == []
+    assert payload["qualification_required"] is True
+    assert (artifact / REGISTRATION_MANIFEST_NAME).exists()
+
+
+def test_register_model_requires_local_directory(tmp_path):
+    missing = tmp_path / "missing"
+
+    try:
+        register_model(RegistrationOptions(artifact_path=str(missing)))
+    except FileNotFoundError:
+        pass
+    else:
+        raise AssertionError("expected FileNotFoundError")
+
+
+def test_register_model_rejects_file_as_artifact(tmp_path):
+    """register_model raises NotADirectoryError for a file path."""
+    file_path = tmp_path / "not-a-dir.safetensors"
+    file_path.write_bytes(b"weights")
+
+    try:
+        register_model(RegistrationOptions(artifact_path=str(file_path)))
+    except NotADirectoryError:
+        pass
+    else:
+        raise AssertionError("expected NotADirectoryError")
+
+
+def test_qualify_model_dry_run_records_bench_command(tmp_path):
+    output = tmp_path / QUALIFICATION_REQUEST_NAME
+
+    payload = qualify_model(
+        QualificationOptions(
+            model_id="qwen-test",
+            server_url="http://127.0.0.1:8090",
+            workload_path="/tmp/workload.json",
+            output_path=str(output),
+            result_path="/tmp/results.json",
+            repetitions=3,
+            dry_run=True,
+            extra_args=["--tag", "nightly"],
+        )
+    )
+
+    assert payload["status"] == "dry_run"
+    assert payload["production_ready"] is False
+    assert "--workload" in payload["command"]
+    assert "/tmp/workload.json" in payload["command"]
+    assert "--tag" in payload["command"]
+    assert output.exists()
+
+
+def test_qualify_model_runs_command_and_records_success(tmp_path):
+    def fake_run(*args, **kwargs):
+        return SimpleNamespace(returncode=0, stdout="all passed", stderr="")
+
+    with patch("vllm_mlx.model_workflow.subprocess.run", side_effect=fake_run):
+        payload = qualify_model(
+            QualificationOptions(
+                model_id="qwen-test",
+                workload_path="/tmp/workload.json",
+                output_path=str(tmp_path / "result.json"),
+            )
+        )
+
+    assert payload["status"] == "succeeded"
+    assert payload["returncode"] == 0
+    assert payload["stdout"] == "all passed"
+    assert "completed_at" in payload
+    assert (tmp_path / "result.json").exists()
+
+
+def test_qualify_model_runs_command_and_records_failure(tmp_path):
+    def fake_run(*args, **kwargs):
+        return SimpleNamespace(returncode=7, stdout="", stderr="bad workload")
+
+    with patch("vllm_mlx.model_workflow.subprocess.run", side_effect=fake_run):
+        payload = qualify_model(
+            QualificationOptions(
+                model_id="qwen-test",
+                workload_path="/tmp/workload.json",
+            )
+        )
+
+    assert payload["status"] == "failed"
+    assert payload["returncode"] == 7
+    assert payload["stderr"] == "bad workload"
+
+
+def test_drop_none_preserves_zero_and_false_values():
+    """_drop_none must keep 0, 0.0, and False -- only drop None."""
+    from vllm_mlx.model_workflow import _drop_none
+
+    result = _drop_none(
+        {
+            "temperature": 0.0,
+            "top_k": 0,
+            "presence_penalty": 0.0,
+            "enabled": False,
+            "missing": None,
+        }
+    )
+    assert result == {
+        "temperature": 0.0,
+        "top_k": 0,
+        "presence_penalty": 0.0,
+        "enabled": False,
+    }
+    assert "missing" not in result
diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py
index 05da7de1..e82299d0 100644
--- a/vllm_mlx/cli.py
+++ b/vllm_mlx/cli.py
@@ -329,9 +329,13 @@ def model_command(args):
     from .model_workflow import (
         AcquisitionOptions,
         ConversionOptions,
+        QualificationOptions,
+        RegistrationOptions,
         acquire_model,
         convert_model,
         inspect_model,
+        qualify_model,
+        register_model,
     )
 
     if args.model_command == "inspect":
@@ -370,7 +374,43 @@ def model_command(args):
         if payload.get("status") == "failed":
             print(json.dumps(payload, indent=2))
             sys.exit(payload.get("returncode") or 1)
-        return  # unreachable, but prevents double-print if sys.exit is caught
+    elif args.model_command == "register":
+        payload = register_model(
+            RegistrationOptions(
+                artifact_path=args.artifact,
+                model_id=args.model_id,
+                served_model_name=args.served_model_name,
+                preset_alias=args.preset_alias,
+                output_path=args.output,
+                mllm=args.mllm,
+                tool_call_parser=args.tool_call_parser,
+                reasoning_parser=args.reasoning_parser,
+                default_temperature=args.default_temperature,
+                default_top_p=args.default_top_p,
+                default_top_k=args.default_top_k,
+                default_min_p=args.default_min_p,
+                default_presence_penalty=args.default_presence_penalty,
+                default_repetition_penalty=args.default_repetition_penalty,
+                chat_template_kwargs=args.default_chat_template_kwargs,
+                feature_flags=args.feature_flag,
+            )
+        )
+    elif args.model_command == "qualify":
+        payload = qualify_model(
+            QualificationOptions(
+                model_id=args.model_id,
+                server_url=args.url,
+                workload_path=args.workload,
+                output_path=args.output,
+                result_path=args.result_output,
+                repetitions=args.repetitions,
+                dry_run=args.dry_run,
+                extra_args=args.extra_arg,
+            )
+        )
+        if payload.get("status") == "failed":
+            print(json.dumps(payload, indent=2))
+            sys.exit(payload.get("returncode") or 1)
     else:
         raise ValueError(f"Unsupported model command: {args.model_command}")
 
@@ -1511,8 +1551,18 @@ def create_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="Generate a quantized MLX model",
     )
-    model_convert_parser.add_argument("--q-bits", type=int, default=None)
-    model_convert_parser.add_argument("--q-group-size", type=int, default=None)
+    model_convert_parser.add_argument(
+        "--q-bits",
+        type=int,
+        default=None,
+        help="Quantization bit width (e.g. 3, 4, 8)",
+    )
+    model_convert_parser.add_argument(
+        "--q-group-size",
+        type=int,
+        default=None,
+        help="Quantization group size (falls back to the mlx-lm default)",
+    )
     model_convert_parser.add_argument(
         "--q-mode",
         choices=["affine", "mxfp4", "nvfp4", "mxfp8"],
@@ -1541,6 +1591,164 @@
         help="Print the conversion command and manifest without executing",
     )
 
+    model_register_parser = model_subparsers.add_parser(
+        "register",
+        help="Write a portable registration manifest for a finalized artifact",
+    )
+    model_register_parser.add_argument(
+        "artifact",
+        type=str,
+        help="Finalized local model artifact directory",
+    )
+    model_register_parser.add_argument(
+        "--model-id",
+        type=str,
+        default=None,
+        help="Override model ID (default: directory name of artifact)",
+    )
+    model_register_parser.add_argument(
+        "--served-model-name",
+        type=str,
+        default=None,
+        help="Model name exposed by the API (default: model-id)",
+    )
+    model_register_parser.add_argument(
+        "--preset-alias",
+        type=str,
+        default=None,
+        help="Optional alias for preset lookup in the registry",
+    )
+    model_register_parser.add_argument(
+        "--output",
+        type=str,
+        default=None,
+        help="Manifest path. Defaults to artifact/vllm_mlx_registration_manifest.json",
+    )
+    mllm_group = model_register_parser.add_mutually_exclusive_group()
+    mllm_group.add_argument(
+        "--mllm",
+        action="store_true",
+        default=None,
+        help="Mark the artifact as an MLLM serving candidate",
+    )
+    mllm_group.add_argument(
+        "--no-mllm",
+        action="store_false",
+        dest="mllm",
+        help="Explicitly mark the artifact as text-only",
+    )
+    model_register_parser.add_argument(
+        "--tool-call-parser",
+        type=str,
+        default=None,
+        help="Tool call parser name for the model (e.g. qwen3_coder, mistral)",
+    )
+    model_register_parser.add_argument(
+        "--reasoning-parser",
+        type=str,
+        default=None,
+        help="Reasoning parser name for thinking models (e.g. qwen3)",
+    )
+    model_register_parser.add_argument(
+        "--default-temperature",
+        type=float,
+        default=None,
+        help="Default temperature for all requests",
+    )
+    model_register_parser.add_argument(
+        "--default-top-p",
+        type=float,
+        default=None,
+        help="Default top_p for all requests",
+    )
+    model_register_parser.add_argument(
+        "--default-top-k",
+        type=int,
+        default=None,
+        help="Default top_k for all requests",
+    )
+    model_register_parser.add_argument(
+        "--default-min-p",
+        type=float,
+        default=None,
+        help="Default min_p for all requests",
+    )
+    model_register_parser.add_argument(
+        "--default-presence-penalty",
+        type=float,
+        default=None,
+        help="Default presence_penalty for all requests",
+    )
+    model_register_parser.add_argument(
+        "--default-repetition-penalty",
+        type=float,
+        default=None,
+        help="Default repetition_penalty for all requests",
+    )
+    model_register_parser.add_argument(
+        "--default-chat-template-kwargs",
+        type=make_json_object_arg_parser("--default-chat-template-kwargs"),
+        default=None,
+        help='Default chat template kwargs as JSON, e.g. {"enable_thinking": true}',
+    )
+    model_register_parser.add_argument(
+        "--feature-flag",
+        action="append",
+        default=[],
+        help="Feature flag to record in the registration manifest. Repeatable.",
+    )
+
+    model_qualify_parser = model_subparsers.add_parser(
+        "qualify",
+        help="Create or run a bench-serve qualification handoff",
+    )
+    model_qualify_parser.add_argument(
+        "model_id",
+        type=str,
+        help="Model ID to qualify against the running server",
+    )
+    model_qualify_parser.add_argument(
+        "--url",
+        type=str,
+        default="http://127.0.0.1:8080",
+        help="Running server URL for bench-serve",
+    )
+    model_qualify_parser.add_argument(
+        "--workload",
+        type=str,
+        default=None,
+        help="bench-serve workload contract path",
+    )
+    model_qualify_parser.add_argument(
+        "--output",
+        type=str,
+        default=None,
+        help="Qualification request manifest path",
+    )
+    model_qualify_parser.add_argument(
+        "--result-output",
+        type=str,
+        default=None,
+        help="Result output path passed to bench-serve",
+    )
+    model_qualify_parser.add_argument(
+        "--repetitions",
+        type=int,
+        default=None,
+        help="Number of repetitions per benchmark sweep configuration",
+    )
+    model_qualify_parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        help="Write or print the qualification command without running it",
+    )
+    model_qualify_parser.add_argument(
+        "--extra-arg",
+        action="append",
+        default=[],
+        help="Extra argument passed through to bench-serve. Repeatable.",
+    )
+
     # Serving benchmark
     bench_serve_parser = subparsers.add_parser(
         "bench-serve", help="Benchmark a running vllm-mlx server via HTTP API"
Repeatable.", + ) + # Serving benchmark bench_serve_parser = subparsers.add_parser( "bench-serve", help="Benchmark a running vllm-mlx server via HTTP API" diff --git a/vllm_mlx/model_workflow.py b/vllm_mlx/model_workflow.py index 932005ec..83b56d92 100644 --- a/vllm_mlx/model_workflow.py +++ b/vllm_mlx/model_workflow.py @@ -28,6 +28,8 @@ MODEL_MANIFEST_NAME = "vllm_mlx_model_manifest.json" CONVERSION_MANIFEST_NAME = "vllm_mlx_conversion_manifest.json" +REGISTRATION_MANIFEST_NAME = "vllm_mlx_registration_manifest.json" +QUALIFICATION_REQUEST_NAME = "vllm_mlx_qualification_request.json" _MODEL_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9_.-]*/[A-Za-z0-9][A-Za-z0-9_.-]*$") @@ -60,6 +62,42 @@ class ConversionOptions: dry_run: bool = False +@dataclass(frozen=True) +class RegistrationOptions: + """Options for generating a portable model registration manifest.""" + + artifact_path: str + model_id: str | None = None + served_model_name: str | None = None + preset_alias: str | None = None + output_path: str | None = None + mllm: bool | None = None + tool_call_parser: str | None = None + reasoning_parser: str | None = None + default_temperature: float | None = None + default_top_p: float | None = None + default_top_k: int | None = None + default_min_p: float | None = None + default_presence_penalty: float | None = None + default_repetition_penalty: float | None = None + chat_template_kwargs: dict[str, Any] | None = None + feature_flags: list[str] | None = None + + +@dataclass(frozen=True) +class QualificationOptions: + """Options for creating or running a bench-serve qualification handoff.""" + + model_id: str + server_url: str = "http://127.0.0.1:8080" + workload_path: str | None = None + output_path: str | None = None + result_path: str | None = None + repetitions: int | None = None + dry_run: bool = False + extra_args: list[str] | None = None + + def _now_iso() -> str: return datetime.now(timezone.utc).isoformat() @@ -485,3 +523,139 @@ def convert_model(options: ConversionOptions) -> dict[str, Any]: _write_json(manifest_path, result) result["manifest_path"] = str(manifest_path) return result + + +def _existing_manifests(path: Path) -> dict[str, Any]: + manifests: dict[str, Any] = {} + for name, key in ( + (MODEL_MANIFEST_NAME, "acquisition"), + (CONVERSION_MANIFEST_NAME, "conversion"), + ): + manifest_path = path / name + if manifest_path.exists(): + manifests[key] = { + "path": str(manifest_path), + "payload": _read_json(manifest_path), + } + return manifests + + +def _drop_none(payload: dict[str, Any]) -> dict[str, Any]: + return {key: value for key, value in payload.items() if value is not None} + + +def register_model(options: RegistrationOptions) -> dict[str, Any]: + """Write a portable registration manifest for a finalized local artifact. + + This deliberately does not mutate a production registry. The manifest is a + handoff artifact that Ops or a deployment tool can apply after qualification. 
+ """ + artifact = Path(options.artifact_path).expanduser() + if not artifact.exists(): + raise FileNotFoundError(f"artifact path does not exist: {artifact}") + if not artifact.is_dir(): + raise NotADirectoryError(f"artifact path must be a directory: {artifact}") + + inspection = inspect_model(str(artifact)) + model_id = options.model_id or artifact.name + serving_defaults = _drop_none( + { + "temperature": options.default_temperature, + "top_p": options.default_top_p, + "top_k": options.default_top_k, + "min_p": options.default_min_p, + "presence_penalty": options.default_presence_penalty, + "repetition_penalty": options.default_repetition_penalty, + "chat_template_kwargs": options.chat_template_kwargs, + } + ) + parser_policy = _drop_none( + { + "tool_call_parser": options.tool_call_parser, + "reasoning_parser": options.reasoning_parser, + } + ) + payload = { + "kind": "vllm-mlx-model-registration", + "schema_version": 1, + "created_at": _now_iso(), + "model_id": model_id, + "served_model_name": options.served_model_name or model_id, + "preset_alias": options.preset_alias, + "artifact_path": str(artifact), + "mllm": options.mllm, + "feature_flags": options.feature_flags or [], + "serving_defaults": serving_defaults, + "parser_policy": parser_policy, + "inspection": inspection, + "source_manifests": _existing_manifests(artifact), + "qualification_required": True, + "production_ready": False, + } + + output = ( + Path(options.output_path).expanduser() + if options.output_path + else artifact / REGISTRATION_MANIFEST_NAME + ) + _write_json(output, payload) + payload["manifest_path"] = str(output) + return payload + + +def _qualification_command(options: QualificationOptions) -> list[str]: + command = [ + sys.executable, + "-m", + "vllm_mlx.cli", + "bench-serve", + "--url", + options.server_url, + "--model", + options.model_id, + "--format", + "json", + ] + if options.workload_path: + command.extend(["--workload", options.workload_path]) + if options.repetitions is not None: + command.extend(["--repetitions", str(options.repetitions)]) + if options.result_path: + command.extend(["--output", options.result_path]) + if options.extra_args: + command.extend(options.extra_args) + return command + + +def qualify_model(options: QualificationOptions) -> dict[str, Any]: + """Create or run a bench-serve qualification handoff.""" + command = _qualification_command(options) + payload = { + "kind": "vllm-mlx-model-qualification", + "schema_version": 1, + "created_at": _now_iso(), + "model_id": options.model_id, + "server_url": options.server_url, + "workload_path": options.workload_path, + "result_path": options.result_path, + "repetitions": options.repetitions, + "dry_run": options.dry_run, + "command": command, + "production_ready": False, + } + + if not options.dry_run: + completed = subprocess.run(command, text=True, capture_output=True, check=False) + payload["returncode"] = completed.returncode + payload["stdout"] = completed.stdout + payload["stderr"] = completed.stderr + payload["completed_at"] = _now_iso() + payload["status"] = "succeeded" if completed.returncode == 0 else "failed" + else: + payload["status"] = "dry_run" + + if options.output_path: + output = Path(options.output_path).expanduser() + _write_json(output, payload) + payload["manifest_path"] = str(output) + return payload