diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..47a43e58 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,112 @@ +name: CI + +on: + push: + branches: + - master + - main + pull_request: + branches: + - master + - main + +jobs: + lint: + name: Lint (Ruff) + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install Ruff + run: pip install ruff + + - name: Run Ruff check + run: ruff check . --output-format=github + + - name: Run Ruff format check + run: ruff format --check . + + type-check: + name: Type Check (mypy) + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: pip install -e ".[dev]" + + - name: Run mypy + run: mypy vastai/ + + test: + name: Test (${{ matrix.os }} / Python ${{ matrix.python-version }}) + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + + - name: Install dependencies + run: pip install -e ".[dev]" + + - name: Run pytest with coverage + run: pytest --cov=vastai --cov-report=xml + + - name: Upload coverage to Codecov + if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.11' + uses: codecov/codecov-action@v4 + with: + files: ./coverage.xml + fail_ci_if_error: false + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + + smoke-standalone: + name: Smoke Test Standalone (${{ matrix.os }}) + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install minimal dependencies (standalone mode) + run: pip install requests python-dateutil + + - name: Test vast.py --help + run: python vast.py --help + + - name: Test vast.py search offers --help + run: python vast.py search offers --help + + - name: Test vast.py show instances --help + run: python vast.py show instances --help diff --git a/.gitignore b/.gitignore index a3df4740..28390020 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,17 @@ passed_machines.txt failed_machines.txt Pass_testresults.log dist/ +build/ __pycache__/ +*.egg-info/ +*.egg +.eggs/ +*.pyc +*.pyo + +# Test artifacts +.coverage +.pytest_cache/ + +# MkDocs build output +site/ diff --git a/__init__.py b/__init__.py index 426f99d8..f2428881 100644 --- a/__init__.py +++ b/__init__.py @@ -1 +1,4 @@ -from .vastai_sdk import VastAI \ No newline at end of file +try: + from .vastai_sdk import VastAI +except ImportError: + pass diff --git a/benchmarks/docker-compose.yml b/benchmarks/docker-compose.yml new file mode 100644 index 00000000..fce6c817 --- /dev/null +++ b/benchmarks/docker-compose.yml @@ -0,0 +1,31 @@ +version: "3.8" + +services: + prometheus: + image: prom/prometheus:latest + container_name: prometheus + restart: unless-stopped + ports: + - "9090:9090" + volumes: + - ./prometheus/prometheus.yml:/etc/prometheus/prometheus.yml:ro + command: + - "--config.file=/etc/prometheus/prometheus.yml" + + grafana: + image: grafana/grafana:latest + container_name: grafana + restart: unless-stopped + ports: + - "3000:3000" + environment: + - GF_SECURITY_ADMIN_USER=admin + - GF_SECURITY_ADMIN_PASSWORD=admin + - GF_USERS_ALLOW_SIGN_UP=false + depends_on: + - prometheus + volumes: + - grafana-storage:/var/lib/grafana + +volumes: + grafana-storage: diff --git a/benchmarks/llm_prom_metrics.py b/benchmarks/llm_prom_metrics.py new file mode 100644 index 00000000..6ada242a --- /dev/null +++ b/benchmarks/llm_prom_metrics.py @@ -0,0 +1,260 @@ +import asyncio +from vastai import Serverless +import os +import random +import collections +import time +import uuid +from prometheus_client import start_http_server, Gauge, Histogram +from typing import Dict, Any +import nltk + +API_KEY = os.environ.get("VAST_API_KEY") + + +nltk.download("words") +WORD_LIST = nltk.corpus.words.words() +# Generate unique run ID for this session +RUN_ID = time.strftime("vLLM_%Y%m%d_%H%M%S") + "_" + str(uuid.uuid4())[:8] + +# Prometheus metrics +REQUEST_STATUS = Gauge( + "vast_request_status_current", + "Current number of requests by status", + ["status", "run_id"] +) +LATENCY = Histogram( + "vast_request_latency_seconds", + "Latency of responses (seconds)", + ["run_id"] +) + +LATENCY_AVG = Gauge( + "vast_request_latency_avg_seconds", + "Rolling average latency (seconds) over recent responses", + ["run_id"] +) + +CURRENT_LOAD = Gauge( + "vast_current_load", + "Current request load (tokens per second or arbitrary units)", + ["run_id"] +) + +WORKER_STATUS = Gauge( + "vast_workers_status_current", + "Current number of workers by status", + ["status", "run_id"] +) + +WORKERS_TOTAL = Gauge( + "vast_workers_total", + "Total number of workers returned by get_endpoint_workers", + ["run_id"] +) + +WORKER_CUR_LOAD_TOTAL = Gauge( + "vast_workers_cur_load_total", + "Sum of current load across workers", + ["run_id"] +) + +WORKER_NEW_LOAD_TOTAL = Gauge( + "vast_workers_new_load_total", + "Sum of new load across workers", + ["run_id"] +) + +WORKER_REQS_WORKING_TOTAL = Gauge( + "vast_workers_requests_working_total", + "Sum of requests_working across workers", + ["run_id"] +) + +# Averages +WORKERS_AVG_CUR_PERF = Gauge("vast_workers_avg_cur_perf", "Average cur_perf", ["run_id"]) +WORKERS_AVG_PERF = Gauge("vast_workers_avg_perf", "Average perf", ["run_id"]) +WORKERS_AVG_MEASURED_PERF = Gauge("vast_workers_avg_measured_perf", "Average measured_perf", ["run_id"]) +WORKERS_AVG_DLPERF = Gauge("vast_workers_avg_dlperf", "Average dlperf", ["run_id"]) +WORKERS_AVG_RELIABILITY = Gauge("vast_workers_avg_reliability", "Average reliability", ["run_id"]) +WORKERS_AVG_CUR_LOAD_ROLLING = Gauge("vast_workers_avg_cur_load_rolling", "Average cur_load_rolling_avg", ["run_id"]) +WORKERS_AVG_DISK_USAGE = Gauge("vast_workers_avg_disk_usage", "Average disk_usage", ["run_id"]) + +# Per-worker gauges (also labeled with worker_id) +WORKER_MEASURED_PERF = Gauge( + "vast_worker_measured_perf", + "Measured perf per worker", + ["worker_id", "run_id"] +) +WORKER_CUR_LOAD = Gauge( + "vast_worker_cur_load", + "Current load per worker", + ["worker_id", "run_id"] +) +WORKER_REQS_WORKING = Gauge( + "vast_worker_requests_working", + "Current requests_working per worker", + ["worker_id", "run_id"] +) +WORKER_STATUS_LIVE = Gauge( + "vast_worker_status_current", + "Status (1 for present) per worker; useful for per-worker dashboards", + ["worker_id", "status", "run_id"] +) + + +latencies = collections.deque(maxlen=50) + +def export_to_prom(agg: dict, run_id: str) -> None: + """ + Push aggregate + per-worker metrics with run_id label. + Call this after aggregate_workers(workers). + """ + + # --- status counts: clear existing for this run_id, then set --- + # (Same internal reset pattern you used.) + for (label_status, label_run_id) in list(WORKER_STATUS._metrics.keys()): + if label_run_id == run_id: + WORKER_STATUS.labels(status=label_status, run_id=run_id).set(0) + + for status, count in agg["status_counts"].items(): + WORKER_STATUS.labels(status=status, run_id=run_id).set(count) + + # --- totals / averages --- + WORKERS_TOTAL.labels(run_id=run_id).set(agg["total_workers"]) + WORKER_CUR_LOAD_TOTAL.labels(run_id=run_id).set(agg["total_cur_load"]) + WORKER_NEW_LOAD_TOTAL.labels(run_id=run_id).set(agg["total_new_load"]) + WORKER_REQS_WORKING_TOTAL.labels(run_id=run_id).set(agg["total_reqs_working"]) + + WORKERS_AVG_CUR_PERF.labels(run_id=run_id).set(agg["avg_cur_perf"]) + WORKERS_AVG_PERF.labels(run_id=run_id).set(agg["avg_perf"]) + WORKERS_AVG_MEASURED_PERF.labels(run_id=run_id).set(agg["avg_measured_perf"]) + WORKERS_AVG_DLPERF.labels(run_id=run_id).set(agg["avg_dlperf"]) + WORKERS_AVG_RELIABILITY.labels(run_id=run_id).set(agg["avg_reliability"]) + WORKERS_AVG_CUR_LOAD_ROLLING.labels(run_id=run_id).set(agg["avg_cur_load_rolling"]) + WORKERS_AVG_DISK_USAGE.labels(run_id=run_id).set(agg["avg_disk_usage"]) + + # --- per-worker: set gauges keyed by worker_id + run_id --- + # First, clear any old per-worker status rows for this run_id + for (wid, status, label_run_id) in list(WORKER_STATUS_LIVE._metrics.keys()): + if label_run_id == run_id: + WORKER_STATUS_LIVE.labels(worker_id=wid, status=status, run_id=run_id).set(0) + + for row in agg["per_worker"]: + wid = str(row["id"]) + WORKER_MEASURED_PERF.labels(worker_id=wid, run_id=run_id).set(row["measured_perf"]) + WORKER_CUR_LOAD.labels(worker_id=wid, run_id=run_id).set(row["cur_load"]) + WORKER_REQS_WORKING.labels(worker_id=wid, run_id=run_id).set(row["reqs_working"]) + # Mark this worker present in its current status (value=1) + WORKER_STATUS_LIVE.labels(worker_id=wid, status=row["status"], run_id=run_id).set(1) + +async def status_reporter(client, endpoint, responses, latencies, window_size=50): + """Continuously updates gauges to reflect live status counts.""" + while True: + status_counts = collections.Counter() + + for r in responses: + try: + status = getattr(r, "status", None) + if callable(status): + status = status() + if status is None: + status = "unknown" + status_counts[status] += 1 + except Exception: + status_counts["error"] += 1 + + # Reset gauges before updating + for label_tuple in list(REQUEST_STATUS._metrics.keys()): + label_status, label_run_id = label_tuple + if label_run_id == RUN_ID: + REQUEST_STATUS.labels(status=label_status, run_id=RUN_ID).set(0) + + # Update gauges to reflect current counts + for status, count in status_counts.items(): + REQUEST_STATUS.labels(status=status, run_id=RUN_ID).set(count) + + # Compute rolling latency average + avg_latency = sum(latencies) / len(latencies) if latencies else 0.0 + + # Record it in Prometheus + LATENCY_AVG.labels(run_id=RUN_ID).set(avg_latency) + + #workers = await endpoint.get_workers() + #aggregate = client.aggregate_workers(workers) + #export_to_prom(aggregate) + + # Console display + print("\n=== Live Response Status ===") + for status, count in status_counts.items(): + print(f"{status:>12}: {count}") + print(f"Rolling Avg Latency (last {len(latencies)}): {avg_latency:.2f} s") + print("=============================\n") + + await asyncio.sleep(1) + + +def latency_callback(response): + latency = response.get("latency") + if latency is not None: + latencies.append(latency) + LATENCY.labels(run_id=RUN_ID).observe(latency) + + +def cur_load(start_time, ramp_time, min_load, max_load, backwards=False): + # progress clamped to [0,1] + t = max(0.0, min((time.time() - start_time) / ramp_time, 1.0)) + e = t * t # quadratic ease-in; keep as-is for reverse too + a, b = (max_load, min_load) if backwards else (min_load, max_load) + return a + (b - a) * e + +async def main(): + async with Serverless() as client: + endpoint = await client.get_endpoint(name="big") + + system_prompt = """ + You are Qwen. + You are to only speak in English. + You are to use the token when you are done generating tokens. Do *not* use the token unless you intend to stop responding. + Your task is to return Yes or No to the following question: + Does the document contain the word "establish"? + """ + + + user_prompt = """ + Who is your favorite character from late 2000's Disney/Nickelodeon? + """ + responses = [] + + asyncio.create_task(status_reporter(client, endpoint, responses, latencies)) + + LOAD_PER_REQUEST = 10 + MAX_LOAD = 10000 + MIN_LOAD = 500 + RAMP_TIME = 15 * 60 # seconds + BACKWARDS = False + start_time = time.time() + + while True: + user_prompt = " ".join(random.choices(WORD_LIST, k=int(random.randint(300,1500)))) + payload = { + "input" : { + "model": "Qwen/Qwen2-72B", + "prompt" : f"Random_Hash: {random.randint(1, 10000)}\nSystem: {system_prompt}\n\n{user_prompt}\nAssistant: ", + "max_tokens" : 500, + "temperature" : 0.7, + "stop" : [""] + + } + } + CUR_LOAD = cur_load(start_time, RAMP_TIME, MIN_LOAD, MAX_LOAD, backwards=BACKWARDS) + CUR_LOAD = 1000 # Hardcode a set load + CURRENT_LOAD.labels(run_id=RUN_ID).set(CUR_LOAD) + request = endpoint.request("/v1/completions", payload, cost=LOAD_PER_REQUEST).then(latency_callback) + responses.append(request) + await asyncio.sleep(LOAD_PER_REQUEST / CUR_LOAD) + + +if __name__ == "__main__": + start_http_server(8000) + asyncio.run(main()) diff --git a/benchmarks/prometheus/prometheus.yml b/benchmarks/prometheus/prometheus.yml new file mode 100644 index 00000000..9eb43719 --- /dev/null +++ b/benchmarks/prometheus/prometheus.yml @@ -0,0 +1,8 @@ +global: + scrape_interval: 2s + +scrape_configs: + - job_name: "vast_metrics" + static_configs: + - targets: ["172.17.0.1:8000"] + diff --git a/benchmarks/prometheus_metrics.py b/benchmarks/prometheus_metrics.py new file mode 100644 index 00000000..2b6d072e --- /dev/null +++ b/benchmarks/prometheus_metrics.py @@ -0,0 +1,134 @@ +import asyncio +from vastai import Serverless +import os +import random +import collections +import time +import uuid +from prometheus_client import start_http_server, Gauge, Histogram + +API_KEY = os.environ.get("VAST_API_KEY") + +# Generate unique run ID for this session +RUN_ID = time.strftime("ComfyUI_Image_%Y%m%d_%H%M%S") + "_" + str(uuid.uuid4())[:8] + +# Prometheus metrics +REQUEST_STATUS = Gauge( + "vast_request_status_current", + "Current number of requests by status", + ["status", "run_id"] +) +LATENCY = Histogram( + "vast_request_latency_seconds", + "Latency of responses (seconds)", + ["run_id"] +) + +LATENCY_AVG = Gauge( + "vast_request_latency_avg_seconds", + "Rolling average latency (seconds) over recent responses", + ["run_id"] +) + + +latencies = collections.deque(maxlen=50) + + +async def status_reporter(responses, latencies, window_size=50): + """Continuously updates gauges to reflect live status counts.""" + while True: + status_counts = collections.Counter() + + for r in responses: + try: + status = getattr(r, "status", None) + if callable(status): + status = status() + if status is None: + status = "unknown" + status_counts[status] += 1 + except Exception: + status_counts["error"] += 1 + + # Reset gauges before updating + for label_tuple in list(REQUEST_STATUS._metrics.keys()): + label_status, label_run_id = label_tuple + if label_run_id == RUN_ID: + REQUEST_STATUS.labels(status=label_status, run_id=RUN_ID).set(0) + + # Update gauges to reflect current counts + for status, count in status_counts.items(): + REQUEST_STATUS.labels(status=status, run_id=RUN_ID).set(count) + + # Compute rolling latency average + avg_latency = sum(latencies) / len(latencies) if latencies else 0.0 + + # Record it in Prometheus + LATENCY_AVG.labels(run_id=RUN_ID).set(avg_latency) + + # Console display + print("\n=== Live Response Status ===") + for status, count in status_counts.items(): + print(f"{status:>12}: {count}") + print(f"Rolling Avg Latency (last {len(latencies)}): {avg_latency:.2f} s") + print("=============================\n") + + await asyncio.sleep(1) + + +def latency_callback(response): + latency = response.get("latency") + if latency is not None: + latencies.append(latency) + LATENCY.labels(run_id=RUN_ID).observe(latency) + + +def cur_load(start_time, ramp_time, min_load, max_load, backwards=False): + # progress clamped to [0,1] + t = max(0.0, min((time.time() - start_time) / ramp_time, 1.0)) + e = t * t # quadratic ease-in; keep as-is for reverse too + a, b = (max_load, min_load) if backwards else (min_load, max_load) + return a + (b - a) * e + +async def main(): + client = Serverless() + endpoint = await client.get_endpoint(name="comfy") + + prompts = [ + "a page from a peanuts comic strip", + ] + + responses = [] + + asyncio.create_task(status_reporter(responses, latencies)) + + LOAD_PER_REQUEST = 100 + MAX_LOAD = 1000 + MIN_LOAD = 50 + RAMP_TIME = 15 * 60 # seconds + BACKWARDS = False + start_time = time.time() + + while True: + payload = { + "input": { + "modifier": "Text2Image", + "modifications": { + "prompt": random.choice(prompts), + "width": 512, + "height": 512, + "steps": 10, + "seed": random.randint(1, 10000) + } + } + } + CUR_LOAD = cur_load(start_time, RAMP_TIME, MIN_LOAD, MAX_LOAD, backwards=BACKWARDS) + CUR_LOAD = 1000 # Hardcode a set load + request = endpoint.request("/generate/sync", payload, cost=LOAD_PER_REQUEST).then(latency_callback) + responses.append(request) + await asyncio.sleep(LOAD_PER_REQUEST / CUR_LOAD) + + +if __name__ == "__main__": + start_http_server(8000) + asyncio.run(main()) diff --git a/conftest.py b/conftest.py new file mode 100644 index 00000000..ecd3474f --- /dev/null +++ b/conftest.py @@ -0,0 +1 @@ +collect_ignore = ["__init__.py", "vast.py", "vast_pdf.py", "vast_config.py"] diff --git a/examples/client/ace_example.py b/examples/client/ace_example.py new file mode 100644 index 00000000..4f6f5777 --- /dev/null +++ b/examples/client/ace_example.py @@ -0,0 +1,149 @@ +from vastai import Serverless +import asyncio + + +async def main(): + async with Serverless() as client: + endpoint = await client.get_endpoint(name="my-ace-endpoint") + + # ComfyUI API compatible json workflow for ACE Step + workflow = { + "14": { + "inputs": { + "tags": "funk, pop, soul, rock, melodic, guitar, drums, bass, keyboard, percussion, 105 BPM, energetic, upbeat, groovy, vibrant, dynamic", + "lyrics": "[verse]\nNeon lights they flicker bright\nCity hums in dead of night\nRhythms pulse through concrete veins\nLost in echoes of refrains\n\n[verse]\nBassline groovin in my chest\nHeartbeats match the citys zest\nElectric whispers fill the air\nSynthesized dreams everywhere\n\n[chorus]\nTurn it up and let it flow\nFeel the fire let it grow\nIn this rhythm we belong\nHear the night sing out our song", + "lyrics_strength": 0.99, + "clip": ["40", 1] + }, + "class_type": "TextEncodeAceStepAudio", + "_meta": { + "title": "TextEncodeAceStepAudio" + } + }, + "17": { + "inputs": { + "seconds": 180, + "batch_size": 1 + }, + "class_type": "EmptyAceStepLatentAudio", + "_meta": { + "title": "EmptyAceStepLatentAudio" + } + }, + "18": { + "inputs": { + "samples": ["52", 0], + "vae": ["40", 2] + }, + "class_type": "VAEDecodeAudio", + "_meta": { + "title": "VAE Decode Audio" + } + }, + "40": { + "inputs": { + "ckpt_name": "ace_step_v1_3.5b.safetensors" + }, + "class_type": "CheckpointLoaderSimple", + "_meta": { + "title": "Load Checkpoint" + } + }, + "44": { + "inputs": { + "conditioning": ["14", 0] + }, + "class_type": "ConditioningZeroOut", + "_meta": { + "title": "ConditioningZeroOut" + } + }, + "49": { + "inputs": { + "model": ["51", 0], + "operation": ["50", 0] + }, + "class_type": "LatentApplyOperationCFG", + "_meta": { + "title": "LatentApplyOperationCFG" + } + }, + "50": { + "inputs": { + "multiplier": 1.15 + }, + "class_type": "LatentOperationTonemapReinhard", + "_meta": { + "title": "LatentOperationTonemapReinhard" + } + }, + "51": { + "inputs": { + "shift": 6, + "model": ["40", 0] + }, + "class_type": "ModelSamplingSD3", + "_meta": { + "title": "ModelSamplingSD3" + } + }, + "52": { + "inputs": { + "seed": "__RANDOM_INT__", + "steps": 65, + "cfg": 4, + "sampler_name": "er_sde", + "scheduler": "linear_quadratic", + "denoise": 1, + "model": ["49", 0], + "positive": ["14", 0], + "negative": ["44", 0], + "latent_image": ["17", 0] + }, + "class_type": "KSampler", + "_meta": { + "title": "KSampler" + } + }, + "59": { + "inputs": { + "filename_prefix": "audio/ComfyUI", + "quality": "V0", + "audioUI": "", + "audio": ["18", 0] + }, + "class_type": "SaveAudioMP3", + "_meta": { + "title": "Save Audio (MP3)" + } + } + } + + payload = { + "input": { + "request_id": "", + "workflow_json": workflow, + "s3": { + "access_key_id": "", + "secret_access_key": "", + "endpoint_url": "", + "bucket_name": "", + "region": "" + }, + "webhook": { + "url": "", + "extra_params": { + "user_id": "12345", + "project_id": "abc-def" + } + } + } + } + + response = await endpoint.request("/generate/sync", payload) + + # Response contains status, output, and any errors + print(response["response"]) + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/examples/client/callback_example.py b/examples/client/callback_example.py new file mode 100644 index 00000000..b9069e2c --- /dev/null +++ b/examples/client/callback_example.py @@ -0,0 +1,32 @@ +import asyncio +from vastai import Serverless, ServerlessRequest + +MAX_TOKENS = 128 + +async def main(): + async with Serverless() as client: + endpoint = await client.get_endpoint(name="my-vllm-endpoint") + + payload = { + "input" : { + "model": "Qwen/Qwen3-8B", + "prompt" : "Who are you?", + "max_tokens" : MAX_TOKENS, + "temperature" : 0.7 + } + } + + # Create a ServerlessRequest object to attach callbacks before submitting the request + req = ServerlessRequest() + + # Attach a callback to run when the machine finished work on the request + def work_finished_callback(response): + print(f"Request finished. Got response of length {len(response["response"]["choices"][0]["text"])}") + + req.then(work_finished_callback) + + response = await endpoint.request(route="/v1/completions", payload=payload, serverless_request=req, cost=MAX_TOKENS) + print(response) + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/examples/client/comfy_example.py b/examples/client/comfy_example.py new file mode 100644 index 00000000..73147a5d --- /dev/null +++ b/examples/client/comfy_example.py @@ -0,0 +1,29 @@ +import asyncio +from vastai import Serverless +import random + +async def main(): + async with Serverless() as client: + endpoint = await client.get_endpoint(name="my-comfy-endpoint") + + payload = { + "input": { + "modifier": "Text2Image", + "modifications": { + "prompt": "Generate a page from a peanuts comic strip.", + "width": 512, + "height": 512, + "steps": 10, + "seed": random.randint(1, 1000) + } + } + } + + response = await endpoint.request("/generate/sync", payload) + + # Get the file from the path on the local machine using SCP or SFTP + # or configure S3 to upload to cloud storage. + print(response["response"]["output"][0]["local_path"]) + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/examples/client/comfy_load_example.py b/examples/client/comfy_load_example.py new file mode 100644 index 00000000..597bbdb8 --- /dev/null +++ b/examples/client/comfy_load_example.py @@ -0,0 +1,38 @@ +import asyncio +from vastai import Serverless, ServerlessRequest +import random + +COST_PER_REQUEST = 100 + +async def main(): + async with Serverless() as client: + endpoint = await client.get_endpoint(name="my-comfy-endpoint") + + payload = { + "input": { + "modifier": "Text2Image", + "modifications": { + "prompt": "Generate a page from a peanuts comic strip.", + "width": 512, + "height": 512, + "steps": 10, + "seed": random.randint(1, 1000) + } + } + } + + responses = [] + + CUR_LOAD = 300 + while True: + # Create a ServerlessRequest object to attach callbacks before submitting the request + req = ServerlessRequest() + # Attach a callback to run when the machine finished work on the request + def work_finished_callback(response): + print(f"{len([x for x in responses if x.status != "Complete"])} in flight") + req.then(work_finished_callback) + responses.append(endpoint.request(route="/generate/sync", payload=payload, serverless_request=req, cost=COST_PER_REQUEST)) + await asyncio.sleep(COST_PER_REQUEST / CUR_LOAD) + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/examples/client/tgi_simple_example.py b/examples/client/tgi_simple_example.py new file mode 100644 index 00000000..45618baf --- /dev/null +++ b/examples/client/tgi_simple_example.py @@ -0,0 +1,26 @@ +import asyncio +from vastai import Serverless + +MAX_TOKENS = 128 + +async def main(): + async with Serverless() as client: + endpoint = await client.get_endpoint(name="my-tgi-endpoint") + + prompt = "Who are you?" + + payload = { + "inputs": prompt, + "parameters": { + "max_new_tokens": MAX_TOKENS, + "temperature": 0.7, + "return_full_text": False + } + } + + resp = await endpoint.request("/generate", payload, cost=MAX_TOKENS) + + print(resp["response"]["generated_text"]) + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/client/tgi_stream_example.py b/examples/client/tgi_stream_example.py new file mode 100644 index 00000000..5021485d --- /dev/null +++ b/examples/client/tgi_stream_example.py @@ -0,0 +1,55 @@ +import asyncio +from vastai import Serverless + +MAX_TOKENS = 1024 + +def build_prompt(system_prompt: str, user_prompt: str) -> str: + return ( + f"<>\n{system_prompt.strip()}\n<>\n\n" + f"User: {user_prompt.strip()}\n" + f"Assistant:" + ) + +async def main(): + async with Serverless() as client: + endpoint = await client.get_endpoint(name="my-tgi-endpoint") + + system_prompt = ( + "You are Qwen.\n" + "You are to only speak in English.\n" + ) + user_prompt = """ + Critically analyze the extent to which hotdogs are sandwiches. + """ + + prompt = build_prompt(system_prompt, user_prompt) + + payload = { + "inputs": prompt, + "parameters": { + "max_new_tokens": MAX_TOKENS, + "temperature": 0.7, + "do_sample": True, + "return_full_text": False, + } + } + + resp = await endpoint.request( + "/generate_stream", + payload, + cost=MAX_TOKENS, + stream=True, + ) + stream = resp["response"] + + printed_answer = False + async for event in stream: + tok = (event.get("token") or {}).get("text") + if tok: + if not printed_answer: + printed_answer = True + print("Answer:\n", end="", flush=True) + print(tok, end="", flush=True) + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/examples/client/vllm_load_example.py b/examples/client/vllm_load_example.py new file mode 100644 index 00000000..4ff103ce --- /dev/null +++ b/examples/client/vllm_load_example.py @@ -0,0 +1,31 @@ +import asyncio +from vastai import Serverless, ServerlessRequest + +MAX_TOKENS = 500 + +async def main(): + async with Serverless() as client: + endpoint = await client.get_endpoint(name="my-vllm-endpoint") + + payload = { + "model": "Qwen/Qwen3-8B", + "prompt" : "Who are you?", + "max_tokens" : MAX_TOKENS, + "temperature" : 0.7 + } + + responses = [] + + CUR_LOAD = 16000 + while True: + # Create a ServerlessRequest object to attach callbacks before submitting the request + req = ServerlessRequest() + # Attach a callback to run when the machine finished work on the request + def work_finished_callback(response): + print(response["response"]["choices"][0]["text"]) + req.then(work_finished_callback) + responses.append(endpoint.request(route="/v1/completions", payload=payload, serverless_request=req, cost=MAX_TOKENS)) + await asyncio.sleep(MAX_TOKENS / CUR_LOAD) + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/examples/client/vllm_simple_example.py b/examples/client/vllm_simple_example.py new file mode 100644 index 00000000..9f527b30 --- /dev/null +++ b/examples/client/vllm_simple_example.py @@ -0,0 +1,21 @@ +import asyncio +from vastai import Serverless + +MAX_TOKENS = 128 + +async def main(): + async with Serverless() as client: + endpoint = await client.get_endpoint(name="my-vllm-endpoint") + + payload = { + "model": "Qwen/Qwen3-8B", + "prompt" : "Who are you?", + "max_tokens" : MAX_TOKENS, + "temperature" : 0.7 + } + + response = await endpoint.request("/v1/completions", payload, cost=MAX_TOKENS) + print(response["response"]["choices"][0]["text"]) + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/examples/client/vllm_stream_convo_example.py b/examples/client/vllm_stream_convo_example.py new file mode 100644 index 00000000..753219d9 --- /dev/null +++ b/examples/client/vllm_stream_convo_example.py @@ -0,0 +1,58 @@ +import asyncio +from vastai import Serverless + +MAX_TOKENS = 1024 + +async def main(): + async with Serverless() as client: + endpoint = await client.get_endpoint(name="my-vllm-endpoint") + + system_prompt = ( + "You are Qwen.\n" + "You are to only speak in English.\n" + ) + + user_prompt = "What is the integral of 2x^2 from 0 to 5?" + + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + + payload = { + "model": "Qwen/Qwen3-8B", + "messages": messages, + "stream": True, + "max_tokens": MAX_TOKENS, + "temperature": 0.7, + } + + response = await endpoint.request("/v1/chat/completions", payload, cost=MAX_TOKENS, stream=True) + stream = response["response"] + + printed_reasoning = False + printed_answer = False + + async for chunk in stream: + delta = chunk["choices"][0].get("delta", {}) + + rc = delta.get("reasoning_content", None) + if rc: + if not printed_reasoning: + printed_reasoning = True + print("Reasoning:\n", end="", flush=True) + print(rc, end="", flush=True) + + content = delta.get("content", None) + if content: + if not printed_answer: + printed_answer = True + if printed_reasoning: + print("\n\nAnswer:\n", end="", flush=True) + else: + print("Answer:\n", end="", flush=True) + print(content, end="", flush=True) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/client/vllm_streaming_example.py b/examples/client/vllm_streaming_example.py new file mode 100644 index 00000000..b9f49a0e --- /dev/null +++ b/examples/client/vllm_streaming_example.py @@ -0,0 +1,37 @@ +import asyncio +from vastai import Serverless + +MAX_TOKENS = 1024 + +async def main(): + async with Serverless() as client: + endpoint = await client.get_endpoint(name="my-vllm-endpoint") + + system_prompt = ( + "You are Qwen, a helpful AI assistant.\n" + "You are to only speak in English.\n" + "Please answer the users response.\n" + "When you are done, use the token.\n" + ) + + + user_prompt = """ + What is the 118th element in the periodic table? + """ + + payload = { + "model": "Qwen/Qwen3-8B", + "prompt" : f"{system_prompt}\n{user_prompt}\n", + "max_tokens" : MAX_TOKENS, + "temperature" : 0.8, + "stop" : [""], + "stream" : True, + } + + response = await endpoint.request("/v1/completions", payload, cost=MAX_TOKENS, stream=True) + stream = response["response"] + async for event in stream: + print(event["choices"][0]["text"], end="", flush=True) + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/examples/client/wan_example.py b/examples/client/wan_example.py new file mode 100644 index 00000000..cfb708dc --- /dev/null +++ b/examples/client/wan_example.py @@ -0,0 +1,205 @@ +from vastai import Serverless +import asyncio + +async def main(): + async with Serverless() as client: + endpoint = await client.get_endpoint(name="my-wan-endpoint") + + # ComfyUI API compatible json workflow for Wan 2.2 T2V + workflow = { + "90": { + "inputs": { + "clip_name": "umt5_xxl_fp8_e4m3fn_scaled.safetensors", + "type": "wan", + "device": "default" + }, + "class_type": "CLIPLoader", + "_meta": { + "title": "Load CLIP" + } + }, + "91": { + "inputs": { + "text": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,裸露,NSFW", + "clip": ["90", 0] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Negative Prompt)" + } + }, + "92": { + "inputs": { + "vae_name": "wan_2.1_vae.safetensors" + }, + "class_type": "VAELoader", + "_meta": { + "title": "Load VAE" + } + }, + "93": { + "inputs": { + "shift": 8.000000000000002, + "model": ["101", 0] + }, + "class_type": "ModelSamplingSD3", + "_meta": { + "title": "ModelSamplingSD3" + } + }, + "94": { + "inputs": { + "shift": 8, + "model": ["102", 0] + }, + "class_type": "ModelSamplingSD3", + "_meta": { + "title": "ModelSamplingSD3" + } + }, + "95": { + "inputs": { + "add_noise": "disable", + "noise_seed": 0, + "steps": 20, + "cfg": 3.5, + "sampler_name": "euler", + "scheduler": "simple", + "start_at_step": 10, + "end_at_step": 10000, + "return_with_leftover_noise": "disable", + "model": ["94", 0], + "positive": ["99", 0], + "negative": ["91", 0], + "latent_image": ["96", 0] + }, + "class_type": "KSamplerAdvanced", + "_meta": { + "title": "KSampler (Advanced)" + } + }, + "96": { + "inputs": { + "add_noise": "enable", + "noise_seed": "__RANDOM_INT__", + "steps": 20, + "cfg": 3.5, + "sampler_name": "euler", + "scheduler": "simple", + "start_at_step": 0, + "end_at_step": 10, + "return_with_leftover_noise": "enable", + "model": ["93", 0], + "positive": ["99", 0], + "negative": ["91", 0], + "latent_image": ["104", 0] + }, + "class_type": "KSamplerAdvanced", + "_meta": { + "title": "KSampler (Advanced)" + } + }, + "97": { + "inputs": { + "samples": ["95", 0], + "vae": ["92", 0] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "98": { + "inputs": { + "filename_prefix": "video/ComfyUI", + "format": "auto", + "codec": "auto", + "video": ["100", 0] + }, + "class_type": "SaveVideo", + "_meta": { + "title": "Save Video" + } + }, + "99": { + "inputs": { + "text": "Beautiful young European woman with honey blonde hair gracefully turning her head back over shoulder, gentle smile, bright eyes looking at camera. Hair flowing in slow motion as she turns. Soft natural lighting, clean background, cinematic portrait.", + "clip": ["90", 0] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Positive Prompt)" + } + }, + "100": { + "inputs": { + "fps": 16, + "images": ["97", 0] + }, + "class_type": "CreateVideo", + "_meta": { + "title": "Create Video" + } + }, + "101": { + "inputs": { + "unet_name": "wan2.2_t2v_high_noise_14B_fp8_scaled.safetensors", + "weight_dtype": "default" + }, + "class_type": "UNETLoader", + "_meta": { + "title": "Load Diffusion Model" + } + }, + "102": { + "inputs": { + "unet_name": "wan2.2_t2v_low_noise_14B_fp8_scaled.safetensors", + "weight_dtype": "default" + }, + "class_type": "UNETLoader", + "_meta": { + "title": "Load Diffusion Model" + } + }, + "104": { + "inputs": { + "width": 640, + "height": 640, + "length": 81, + "batch_size": 1 + }, + "class_type": "EmptyHunyuanLatentVideo", + "_meta": { + "title": "EmptyHunyuanLatentVideo" + } + } + } + + payload = { + "input": { + "request_id": "", + "workflow_json": workflow, + "s3": { + "access_key_id": "", + "secret_access_key": "", + "endpoint_url": "", + "bucket_name": "", + "region": "" + }, + "webhook": { + "url": "", + "extra_params": { + "user_id": "12345", + "project_id": "abc-def" + } + } + } + } + + response = await endpoint.request("/generate/sync", payload) + + # Response contains status, output, and any errors + print(response["response"]) + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/examples/server/ace_worker.py b/examples/server/ace_worker.py new file mode 100644 index 00000000..11524a62 --- /dev/null +++ b/examples/server/ace_worker.py @@ -0,0 +1,184 @@ +import random +import sys + +from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig + +# ComyUI model configuration +MODEL_SERVER_URL = 'http://127.0.0.1' +MODEL_SERVER_PORT = 18288 +MODEL_LOG_FILE = '/var/log/portal/comfyui.log' +MODEL_HEALTHCHECK_ENDPOINT = "/health" + +# ComyUI-specific log messages +MODEL_LOAD_LOG_MSG = [ + "To see the GUI go to: " +] + +MODEL_ERROR_LOG_MSGS = [ + "MetadataIncompleteBuffer", + "Value not in list: ", + "[ERROR] Provisioning Script failed" +] + +MODEL_INFO_LOG_MSGS = [ + '"message":"Downloading' +] + +benchmark_lyrics = [ + "[verse]\nGuardian cloaked in twilight hue\nShadows melt where he breaks through\nEchoes swirl in mystic flight\nHooded hero owns the night\n\n[verse]\nThrough the chaos shapes arise\nFeral whispers, glowing eyes\nOrcs and creatures side by side\nMarch within the inky tide\n\n[chorus]\nRise above the fear and gloom\nLet your courage fully bloom\nIn the darkness stand your ground\nHear the night proclaim your sound", + "[verse]\nMorning sun on fields of gold\nGentle stories unfold\nEvery breeze a quiet song\nWhere the peaceful hearts belong\n\n[verse]\nLanterns glow at stable doors\nRustling leaves on orchard floors\nSimple joys in every hand\nLife grows soft in fertile land\n\n[chorus]\nLet the day drift slow and free\nRoot your soul where you can be\nIn this haven warm and bright\nFeel the earth breathe pure delight", + "[verse]\nLittle feet on dusty ground\nChasing dreams without a sound\nSoccer ball in morning light\nHopes take wing in youthful flight\n\n[verse]\nChrome reflections paint the day\nSwagger in the steps that play\nCopper tones in shining air\nChildhood gleaming everywhere\n\n[chorus]\nKick the world with boundless cheer\nHold the magic close and near\nIn each moment bold and true\nLet the sky belong to you", + "[verse]\nSunset bleeds across the street\nGilded calm in summer heat\nLow-rise towers rimmed with fire\nDreams ignite as lights climb higher\n\n[verse]\nFootsteps scatter through the haze\nFutures shimmer in the blaze\nEvery window tells a tale\nFloating through a tangerine veil\n\n[chorus]\nLet the neon softly glow\nLet your restless heartbeat slow\nIn this city forged in light\nCarry hope into the night", + "[verse]\nOcean breathes in rolling arcs\nSprays of diamond, glowing sparks\nWaves unfold a perfect line\nNature’s rhythm feels divine\n\n[verse]\nSun above in golden sweep\nPaints the rise of every deep\nShimmer drifting through the blue\nWorld reborn in every view\n\n[chorus]\nLet the tide pull you along\nHear the water’s ancient song\nIn the cresting waves you’ll find\nQuiet peace for heart and mind", + "[verse]\nGlass aglow with swirling light\nFruits and mints in colors bright\nIcy whispers clink and chime\nFlowing forms suspend in time\n\n[verse]\nCreamy spirals drift within\nGentle currents slowly spin\nWarm reflections lingering sweet\nMixing flavors at your feet\n\n[chorus]\nSip the glow and let it rise\nTaste the sunset in disguise\nIn this moment clear and true\nLet the warmth flow into you", + "[verse]\nEngines rumble down the lane\nCopper clouds of steam and rain\nOilpunk dreams in metal shine\nRider drifting down the line\n\n[verse]\nLeather jacket, steady glare\nStories sparking in the air\nMagazine lights frame his face\nKing of roads in timeless grace\n\n[chorus]\nThrottle up beyond the bend\nFeel the force of steel ascend\nRide the night and hold on tight\nClaim the world in streaks of light", + "[verse]\nCut-out shapes in swirling play\nTextures dance in bold array\nCats in denim, grinning wide\nStrut across the patterned tide\n\n[verse]\nPosters hum with neon glow\nSurreal scenes begin to grow\nColors crisp as folded art\nPatchwork beating like a heart\n\n[chorus]\nLet the collage come alive\nWatch the vibrant pieces thrive\nIn this joyful, crafted space\nEvery shape finds its own place", + "[verse]\nTiny world in crystal glass\nAncient tales behind the mass\nVillage lights in winter gleam\nFrozen in a mystic dream\n\n[verse]\nLantern beams in swirling air\nSoft enchantment everywhere\nShadows drift with gentle grace\nMagic sealed within the space\n\n[chorus]\nHold the sphere and you will see\nEchoes of a memory\nIn the glow of fragile light\nLives a realm of pure delight", + "[verse]\nArmor hums with power bright\nChopping sparks in jungle night\nMecha spirits shift and scream\nThrough the ferns like shattered beams\n\n[verse]\nAxes blaze in glowing arcs\nLighting up the shadowed marks\nNature roars in trembling air\nClash of steel and cosmic flare\n\n[chorus]\nRaise the fire, strike the ground\nLet your legend shake the sound\nIn the wild where echoes roam\nForge the fight and carve your home", + "[verse]\nCrowds ignite in vibrant flare\nBeats explode through smoky air\nDJ robes replaced with flame\nPope on decks in holy frame\n\n[verse]\nLeather gleams in blinding light\nTurntables spin with sacred might\nChoirs echo in the bass\nHeaven pulses through the place\n\n[chorus]\nLift the roof and shake the floor\nSacred rhythm evermore\nLet the music take control\nFeel the blessing in your soul", +] + +benchmark_dataset = [ + { + "input": { + "request_id": "", + "workflow_json": { + "14": { + "inputs": { + "tags": "funk, pop, soul, rock, melodic, guitar, drums, bass, keyboard, percussion, 105 BPM, energetic, upbeat, groovy, vibrant, dynamic", + "lyrics": lyrics, + "lyrics_strength": 0.99, + "clip": ["40", 1] + }, + "class_type": "TextEncodeAceStepAudio", + "_meta": { + "title": "TextEncodeAceStepAudio" + } + }, + "17": { + "inputs": { + "seconds": 180, + "batch_size": 1 + }, + "class_type": "EmptyAceStepLatentAudio", + "_meta": { + "title": "EmptyAceStepLatentAudio" + } + }, + "18": { + "inputs": { + "samples": ["52", 0], + "vae": ["40", 2] + }, + "class_type": "VAEDecodeAudio", + "_meta": { + "title": "VAE Decode Audio" + } + }, + "40": { + "inputs": { + "ckpt_name": "ace_step_v1_3.5b.safetensors" + }, + "class_type": "CheckpointLoaderSimple", + "_meta": { + "title": "Load Checkpoint" + } + }, + "44": { + "inputs": { + "conditioning": ["14", 0] + }, + "class_type": "ConditioningZeroOut", + "_meta": { + "title": "ConditioningZeroOut" + } + }, + "49": { + "inputs": { + "model": ["51", 0], + "operation": ["50", 0] + }, + "class_type": "LatentApplyOperationCFG", + "_meta": { + "title": "LatentApplyOperationCFG" + } + }, + "50": { + "inputs": { + "multiplier": 1.15 + }, + "class_type": "LatentOperationTonemapReinhard", + "_meta": { + "title": "LatentOperationTonemapReinhard" + } + }, + "51": { + "inputs": { + "shift": 6, + "model": ["40", 0] + }, + "class_type": "ModelSamplingSD3", + "_meta": { + "title": "ModelSamplingSD3" + } + }, + "52": { + "inputs": { + "seed": "__RANDOM_INT__", + "steps": 65, + "cfg": 4, + "sampler_name": "er_sde", + "scheduler": "linear_quadratic", + "denoise": 1, + "model": ["49", 0], + "positive": ["14", 0], + "negative": ["44", 0], + "latent_image": ["17", 0] + }, + "class_type": "KSampler", + "_meta": { + "title": "KSampler" + } + }, + "59": { + "inputs": { + "filename_prefix": "audio/ComfyUI", + "quality": "V0", + "audioUI": "", + "audio": ["18", 0] + }, + "class_type": "SaveAudioMP3", + "_meta": { + "title": "Save Audio (MP3)" + } + } + } + } + } for lyrics in benchmark_lyrics +] + +worker_config = WorkerConfig( + model_server_url=MODEL_SERVER_URL, + model_server_port=MODEL_SERVER_PORT, + model_log_file=MODEL_LOG_FILE, + model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT, + handlers=[ + HandlerConfig( + route="/generate/sync", + allow_parallel_requests=False, + max_queue_time=10.0, + benchmark_config=BenchmarkConfig( + dataset=benchmark_dataset, + runs=1 + ), + workload_calculator= lambda _ : 1000.0 + ) + ], + log_action_config=LogActionConfig( + on_load=MODEL_LOAD_LOG_MSG, + on_error=MODEL_ERROR_LOG_MSGS, + on_info=MODEL_INFO_LOG_MSGS + ) +) + +Worker(worker_config).run() \ No newline at end of file diff --git a/examples/server/comfy_worker.py b/examples/server/comfy_worker.py new file mode 100644 index 00000000..ddb7da62 --- /dev/null +++ b/examples/server/comfy_worker.py @@ -0,0 +1,81 @@ +import random +import sys + +from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig + +# ComyUI model configuration +MODEL_SERVER_URL = 'http://127.0.0.1' +MODEL_SERVER_PORT = 18288 +MODEL_LOG_FILE = '/var/log/portal/comfyui.log' +MODEL_HEALTHCHECK_ENDPOINT = "/health" + +# ComyUI-specific log messages +MODEL_LOAD_LOG_MSG = [ + "To see the GUI go to: " +] + +MODEL_ERROR_LOG_MSGS = [ + "MetadataIncompleteBuffer", + "Value not in list: ", + "[ERROR] Provisioning Script failed" +] + +MODEL_INFO_LOG_MSGS = [ + '"message":"Downloading' +] + +benchmark_prompts = [ + "Cartoon hoodie hero; orc, anime cat, bunny; black goo; buff; vector on white.", + "Cozy farming-game scene with fine details.", + "2D vector child with soccer ball; airbrush chrome; swagger; antique copper.", + "Realistic futuristic downtown of low buildings at sunset.", + "Perfect wave front view; sunny seascape; ultra-detailed water; artful feel.", + "Clear cup with ice, fruit, mint; creamy swirls; fluid-sim CGI; warm glow.", + "Male biker with backpack on motorcycle; oilpunk; award-worthy magazine cover.", + "Collage for textile; surreal cartoon cat in cap/jeans before poster; crisp.", + "Medieval village inside glass sphere; volumetric light; macro focus.", + "Iron Man with glowing axe; mecha sci-fi; jungle scene; dynamic light.", + "Pope Francis DJ in leather jacket, mixing on giant console; dramatic.", +] + + + +benchmark_dataset = [ + { + "input": { + "request_id": f"test-{random.randint(1000, 99999)}", + "modifier": "Text2Image", + "modifications": { + "prompt": prompt, + "width": 512, + "height": 512, + "steps": 20, + "seed": random.randint(0, sys.maxsize) + } + } + } for prompt in benchmark_prompts +] + +worker_config = WorkerConfig( + model_server_url=MODEL_SERVER_URL, + model_server_port=MODEL_SERVER_PORT, + model_log_file=MODEL_LOG_FILE, + model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT, + handlers=[ + HandlerConfig( + route="/generate/sync", + allow_parallel_requests=False, + max_queue_time=10.0, + benchmark_config=BenchmarkConfig( + dataset=benchmark_dataset, + ) + ) + ], + log_action_config=LogActionConfig( + on_load=MODEL_LOAD_LOG_MSG, + on_error=MODEL_ERROR_LOG_MSGS, + on_info=MODEL_INFO_LOG_MSGS + ) +) + +Worker(worker_config).run() \ No newline at end of file diff --git a/examples/server/tgi_worker.py b/examples/server/tgi_worker.py new file mode 100644 index 00000000..f8084ab2 --- /dev/null +++ b/examples/server/tgi_worker.py @@ -0,0 +1,76 @@ +import nltk +import random + +from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig + +# TGI model configuration +MODEL_SERVER_URL = 'http://0.0.0.0' +MODEL_SERVER_PORT = 5001 +MODEL_LOG_FILE = "/workspace/infer.log" +MODEL_HEALTHCHECK_ENDPOINT = "/health" + +# TGI-specific log messages +MODEL_LOAD_LOG_MSG = [ + '"message":"Connected","target":"text_generation_router"', + '"message":"Connected","target":"text_generation_router::server"', +] + +MODEL_ERROR_LOG_MSGS = [ + "Error: WebserverFailed", + "Error: DownloadError", + "Error: ShardCannotStart", +] + +MODEL_INFO_LOG_MSGS = [ + '"message":"Download' +] + +nltk.download("words") +WORD_LIST = nltk.corpus.words.words() + + +def benchmark_generator() -> dict: + prompt = " ".join(random.choices(WORD_LIST, k=int(250))) + + benchmark_data = { + "inputs": prompt, + "parameters": { + "max_new_tokens": 128, + "temperature": 0.7, + "return_full_text": False + } + } + + return benchmark_data + +worker_config = WorkerConfig( + model_server_url=MODEL_SERVER_URL, + model_server_port=MODEL_SERVER_PORT, + model_log_file=MODEL_LOG_FILE, + model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT, + handlers=[ + HandlerConfig( + route="/generate", + allow_parallel_requests=True, + max_queue_time=60.0, + benchmark_config=BenchmarkConfig( + generator=benchmark_generator, + concurrency=50 + ), + workload_calculator= lambda x: x["parameters"]["max_new_tokens"] + ), + HandlerConfig( + route="/generate_stream", + allow_parallel_requests=True, + max_queue_time=60.0, + workload_calculator= lambda x: x["parameters"]["max_new_tokens"] + ) + ], + log_action_config=LogActionConfig( + on_load=MODEL_LOAD_LOG_MSG, + on_error=MODEL_ERROR_LOG_MSGS, + on_info=MODEL_INFO_LOG_MSGS + ) +) + +Worker(worker_config).run() \ No newline at end of file diff --git a/examples/server/tutorial_worker.py b/examples/server/tutorial_worker.py new file mode 100644 index 00000000..4cb86223 --- /dev/null +++ b/examples/server/tutorial_worker.py @@ -0,0 +1,92 @@ +from vastai import Worker, WorkerConfig, HandlerConfig, BenchmarkConfig, LogActionConfig + +# We define a WorkerConfig object to configure our PyWorker +# Here, we can implement handlers for different routes our +# endpoint may serve +worker_config = WorkerConfig( + # --- Model Config --- + # The local URL of your model + model_server_url="http://127.0.0.1", + # The port your model is running on + model_server_port=18000, + # The file your model writes logs to + model_log_file="/var/model/out.log", + # If your model responds to a healthcheck, you can specify it here + model_healthcheck_url="/health", + + # --- Handler Config --- + # Here, we define potentially multiple endpoint handlers for our endpoint, + # each of which services a different route on our endpoint + handlers=[ + HandlerConfig( + # The route on our endpoint that we are handling + route="/my/route", + # Enable this if the model backend supports handling multiple requests at once + # If 'False', the worker will enforce one request at a time on the + # model backend with strict FIFO ordering. + allow_parallel_requests=False, + # --- Benchmark config --- + # One endpoint handler must implement a BenchmarkConfig + # The BenchmarkConfig defines sample payloads we use for + # measuring the performance of any given machine. + # This is essential for correct optimal autoscaling behavior. + benchmark_config=BenchmarkConfig( + # A list of possible request payloads to benchmark on + dataset=[ + { "prompt" : "some" }, + { "prompt" : "sample" }, + { "prompt" : "data" } + ], + # You may also implement a `generator` function, which + # returns a benchmark payload dictionary + # generator= lambda: { "prompt" : "a" * random.randint(60) } + + # How many times you should run the benchmark + runs= 5, + + # If `allow_parallel_requests` == True, how many concurrent payloads per run + concurrency=10 + ), + # A function that calculates the workload per request + # Example: the length of the input data + workload_calculator= lambda request: len(request["prompt"]) + ) + ], + + # --- Log Config --- + # Here, we define various LogActions, which inform our worker + # of model start, model error, or useful model information. + # It's important that your model outputs logs to the file + # specified in `model_log_file`, so the worker knows the state + # of the model and can react accordingly. + log_action_config=LogActionConfig( + # A log line from our model that indicates + # the model has completed loading and is ready + # to recieve requests + on_load=[ + "Application startup complete.", + ], + # The log lines from our model that indicate + # the model has suffered an irrecoverable error + # and our worker must be restarted + on_error=[ + "INFO exited: vllm", + "RuntimeError: Engine", + "Traceback (most recent call last):" + ], + # A log line the model may emit + # containing relevant information + on_info=[ + '"message":"Download' + ] + ) +) + +# --- Running the Worker --- +# Run the worker synchronously +Worker(worker_config).run() + +# Or, if you wish to continue executing Python from this entrypoint, +# you can run your PyWorker in an asyncio background task +# pyworker_task = asyncio.run(Worker(worker_config).run_async()) +# ... more python here ... \ No newline at end of file diff --git a/examples/server/vllm_worker.py b/examples/server/vllm_worker.py new file mode 100644 index 00000000..6cf17f0e --- /dev/null +++ b/examples/server/vllm_worker.py @@ -0,0 +1,78 @@ +import nltk +import random +import os + +from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig + +# vLLM model configuration +MODEL_SERVER_URL = 'http://127.0.0.1' +MODEL_SERVER_PORT = 18000 +MODEL_LOG_FILE = '/var/log/portal/vllm.log' +MODEL_HEALTHCHECK_ENDPOINT = "/health" + +# vLLM-specific log messages +MODEL_LOAD_LOG_MSG = [ + "Application startup complete.", +] + +MODEL_ERROR_LOG_MSGS = [ + "INFO exited: vllm", + "RuntimeError: Engine", + "Traceback (most recent call last):" +] + +MODEL_INFO_LOG_MSGS = [ + '"message":"Download' +] + +nltk.download("words") +WORD_LIST = nltk.corpus.words.words() + + +def completions_benchmark_generator() -> dict: + prompt = " ".join(random.choices(WORD_LIST, k=int(250))) + model = os.environ.get("MODEL_NAME") + if not model: + raise ValueError("MODEL_NAME environment variable not set") + + benchmark_data = { + "model": model, + "prompt": prompt, + "temperature": 0.7, + "max_tokens": 500, + } + + return benchmark_data + +worker_config = WorkerConfig( + model_server_url=MODEL_SERVER_URL, + model_server_port=MODEL_SERVER_PORT, + model_log_file=MODEL_LOG_FILE, + model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT, + handlers=[ + HandlerConfig( + route="/v1/completions", + workload_calculator= lambda data: data.get("max_tokens", 0), + allow_parallel_requests=True, + max_queue_time=60.0, + benchmark_config=BenchmarkConfig( + generator=completions_benchmark_generator, + concurrency=100, + runs=2 + ) + ), + HandlerConfig( + route="/v1/chat/completions", + workload_calculator= lambda data: data.get("max_tokens", 0), + allow_parallel_requests=True, + max_queue_time=60.0, + ) + ], + log_action_config=LogActionConfig( + on_load=MODEL_LOAD_LOG_MSG, + on_error=MODEL_ERROR_LOG_MSGS, + on_info=MODEL_INFO_LOG_MSGS + ) +) + +Worker(worker_config).run() \ No newline at end of file diff --git a/examples/server/wan_example.py b/examples/server/wan_example.py new file mode 100644 index 00000000..174b5f4f --- /dev/null +++ b/examples/server/wan_example.py @@ -0,0 +1,288 @@ +import random +import sys + +from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig + +# ComyUI model configuration +MODEL_SERVER_URL = 'http://127.0.0.1' +MODEL_SERVER_PORT = 18288 +MODEL_LOG_FILE = '/var/log/portal/comfyui.log' +MODEL_HEALTHCHECK_ENDPOINT = "/health" + +# ComyUI-specific log messages +MODEL_LOAD_LOG_MSG = [ + "To see the GUI go to: " +] + +MODEL_ERROR_LOG_MSGS = [ + "MetadataIncompleteBuffer", + "Value not in list: ", + "[ERROR] Provisioning Script failed" +] + +MODEL_INFO_LOG_MSGS = [ + '"message":"Downloading' +] + +benchmark_prompts = [ + "Cartoon hoodie hero; orc, anime cat, bunny; black goo; buff; vector on white.", + "Cozy farming-game scene with fine details.", + "2D vector child with soccer ball; airbrush chrome; swagger; antique copper.", + "Realistic futuristic downtown of low buildings at sunset.", + "Perfect wave front view; sunny seascape; ultra-detailed water; artful feel.", + "Clear cup with ice, fruit, mint; creamy swirls; fluid-sim CGI; warm glow.", + "Male biker with backpack on motorcycle; oilpunk; award-worthy magazine cover.", + "Collage for textile; surreal cartoon cat in cap/jeans before poster; crisp.", + "Medieval village inside glass sphere; volumetric light; macro focus.", + "Iron Man with glowing axe; mecha sci-fi; jungle scene; dynamic light.", + "Pope Francis DJ in leather jacket, mixing on giant console; dramatic.", +] + +benchmark_dataset = [ + { + "input": { + "workflow_json": { + "90": { + "inputs": { + "clip_name": "umt5_xxl_fp8_e4m3fn_scaled.safetensors", + "type": "wan", + "device": "default" + }, + "class_type": "CLIPLoader", + "_meta": { + "title": "Load CLIP" + } + }, + "91": { + "inputs": { + "text": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,裸露,NSFW", + "clip": [ + "90", + 0 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Negative Prompt)" + } + }, + "92": { + "inputs": { + "vae_name": "wan_2.1_vae.safetensors" + }, + "class_type": "VAELoader", + "_meta": { + "title": "Load VAE" + } + }, + "93": { + "inputs": { + "shift": 8.000000000000002, + "model": [ + "101", + 0 + ] + }, + "class_type": "ModelSamplingSD3", + "_meta": { + "title": "ModelSamplingSD3" + } + }, + "94": { + "inputs": { + "shift": 8, + "model": [ + "102", + 0 + ] + }, + "class_type": "ModelSamplingSD3", + "_meta": { + "title": "ModelSamplingSD3" + } + }, + "95": { + "inputs": { + "add_noise": "disable", + "noise_seed": 0, + "steps": 20, + "cfg": 3.5, + "sampler_name": "euler", + "scheduler": "simple", + "start_at_step": 10, + "end_at_step": 10000, + "return_with_leftover_noise": "disable", + "model": [ + "94", + 0 + ], + "positive": [ + "99", + 0 + ], + "negative": [ + "91", + 0 + ], + "latent_image": [ + "96", + 0 + ] + }, + "class_type": "KSamplerAdvanced", + "_meta": { + "title": "KSampler (Advanced)" + } + }, + "96": { + "inputs": { + "add_noise": "enable", + "noise_seed": "__RANDOM_INT__", + "steps": 20, + "cfg": 3.5, + "sampler_name": "euler", + "scheduler": "simple", + "start_at_step": 0, + "end_at_step": 10, + "return_with_leftover_noise": "enable", + "model": [ + "93", + 0 + ], + "positive": [ + "99", + 0 + ], + "negative": [ + "91", + 0 + ], + "latent_image": [ + "104", + 0 + ] + }, + "class_type": "KSamplerAdvanced", + "_meta": { + "title": "KSampler (Advanced)" + } + }, + "97": { + "inputs": { + "samples": [ + "95", + 0 + ], + "vae": [ + "92", + 0 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "98": { + "inputs": { + "filename_prefix": "video/ComfyUI", + "format": "auto", + "codec": "auto", + "video": [ + "100", + 0 + ] + }, + "class_type": "SaveVideo", + "_meta": { + "title": "Save Video" + } + }, + "99": { + "inputs": { + "text":prompt, + "clip": [ + "90", + 0 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Positive Prompt)" + } + }, + "100": { + "inputs": { + "fps": 16, + "images": [ + "97", + 0 + ] + }, + "class_type": "CreateVideo", + "_meta": { + "title": "Create Video" + } + }, + "101": { + "inputs": { + "unet_name": "wan2.2_t2v_high_noise_14B_fp8_scaled.safetensors", + "weight_dtype": "default" + }, + "class_type": "UNETLoader", + "_meta": { + "title": "Load Diffusion Model" + } + }, + "102": { + "inputs": { + "unet_name": "wan2.2_t2v_low_noise_14B_fp8_scaled.safetensors", + "weight_dtype": "default" + }, + "class_type": "UNETLoader", + "_meta": { + "title": "Load Diffusion Model" + } + }, + "104": { + "inputs": { + "width": 640, + "height": 640, + "length": 81, + "batch_size": 1 + }, + "class_type": "EmptyHunyuanLatentVideo", + "_meta": { + "title": "EmptyHunyuanLatentVideo" + } + } + } + } + } for prompt in benchmark_prompts +] + +worker_config = WorkerConfig( + model_server_url=MODEL_SERVER_URL, + model_server_port=MODEL_SERVER_PORT, + model_log_file=MODEL_LOG_FILE, + model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT, + handlers=[ + HandlerConfig( + route="/generate/sync", + allow_parallel_requests=False, + max_queue_time=10.0, + benchmark_config=BenchmarkConfig( + dataset=benchmark_dataset, + runs=1 + ), + workload_calculator= lambda _ : 10000.0 + ) + ], + log_action_config=LogActionConfig( + on_load=MODEL_LOAD_LOG_MSG, + on_error=MODEL_ERROR_LOG_MSGS, + on_info=MODEL_INFO_LOG_MSGS + ) +) + +Worker(worker_config).run() \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index aeb35545..67bef6ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,26 +21,84 @@ requires-python = ">3.9.1, <4.0" license = { text = "MIT" } dynamic = ["version"] dependencies = [ + # --- CLI core --- "xdg", "argcomplete==3.5.1", - "requests (>=2.32.4)", - "borb (==2.1.25)", - "python-dateutil==2.6.1", + "requests>=2.32.3", + "borb==2.1.25", + "python-dateutil>=2.6.1", "pytz", "urllib3>=2.6.3", - "poetry-dynamic-versioning (>=1.8.1,<2.0.0)", - "gitpython (>=3.1.44,<4.0.0)", - "toml (>=0.10.2,<0.11.0)", - "curlify (>=2.2.1,<3.0.0)", + "curlify>=2.2.1,<3.0.0", "setuptools", - "cryptography (>=44.0.2,<45.0.0)", + "cryptography>=44.0.2,<45.0.0", "rich", "fonttools>=4.60.2", - "qrcode" + "qrcode", + # --- SDK core (needed for vastai.sdk) --- + "pyparsing>=3.1,<4.0", +] + +[project.optional-dependencies] +# Serverless client + server dependencies +serverless = [ + "aiohttp>=3.10,<4.0", + "aiodns>=3.2.0", + "pycares>=4.4.0", + "anyio>=4.4,<5.0", + "psutil>=6.0,<7.0", + "pycryptodome>=3.20,<4.0", +] +# Everything: SDK + serverless +all = [ + "vastai[serverless]", +] +dev = [ + "pytest>=8.0", + "pytest-cov>=4.0", + "mypy>=1.14", + "types-requests>=2.32", + "types-python-dateutil>=2.9", + "ruff>=0.9", + "pre-commit>=4.0", +] + +[tool.pytest.ini_options] +testpaths = ["tests"] +addopts = [ + "--cov=vastai", + "--cov-report=term-missing", + "--cov-report=html:coverage_html", +] + +[tool.coverage.run] +branch = true +source = ["vastai"] +omit = [ + "*/tests/*", + "*/.venv/*", + "*/site-packages/*", +] + +[tool.coverage.report] +precision = 2 +show_missing = true +fail_under = 44 +exclude_lines = [ + "pragma: no cover", + "if __name__ == .__main__.:", + "raise NotImplementedError", + "if TYPE_CHECKING:", ] [tool.poetry] -packages = [{ include = "utils" }, { include = "vast.py" }] +packages = [ + { include = "vast.py" }, + { include = "vast_config.py" }, + { include = "utils" }, + { include = "vastai" }, + { include = "vastai_sdk" }, +] version = "0.0.0" [project.scripts] @@ -58,3 +116,91 @@ style = "semver" [tool.poetry.requires-plugins] poetry-dynamic-versioning = { version = ">=1.0.0,<2.0.0", extras = ["plugin"] } + +[tool.mypy] +python_version = "3.10" +strict = true +warn_return_any = true +warn_unused_configs = true +warn_redundant_casts = true +warn_unused_ignores = true + +[[tool.mypy.overrides]] +module = [ + "borb.*", + "PIL.*", + "argcomplete.*", + "curlify.*", + "pyparsing.*", + "xdg.*", +] +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "vast" +ignore_errors = true + +[[tool.mypy.overrides]] +module = "vast_pdf" +ignore_errors = true + +[[tool.mypy.overrides]] +module = "vastai.vastai_base" +ignore_errors = true + +[[tool.mypy.overrides]] +module = "vastai.vastai_sdk" +ignore_errors = true + +[[tool.mypy.overrides]] +module = "vastai.serverless.*" +ignore_errors = true + +[[tool.mypy.overrides]] +module = "tests.*" +disallow_untyped_defs = false +disallow_untyped_calls = false + +[tool.ruff] +line-length = 120 +target-version = "py310" +exclude = [ + ".git", + ".venv", + "__pycache__", + "build", + "dist", + ".eggs", + "*.egg-info", +] + +[tool.ruff.lint] +select = [ + "E", + "W", + "F", + "I", + "B", + "C4", + "UP", +] +ignore = [ + "E501", +] + +[tool.ruff.lint.per-file-ignores] +"vast.py" = ["T20", "A001", "A002", "E501", "B006", "F401", "F523", "F541", "F811", "F841", "E701", "E703", "E711", "E713", "E721", "E741"] +"vast_pdf.py" = ["T20", "E501"] +"tests/**/*.py" = ["F401", "S101"] +"vastai/vastai_base.py" = ["B027"] +"vastai/vastai_sdk.py" = ["E402", "F401"] +"vastai/serverless/**/*.py" = ["F401", "F541", "F811", "F841", "E501"] + +[tool.ruff.lint.isort] +known-first-party = ["vastai", "vast"] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +skip-magic-trailing-comma = false +line-ending = "auto" diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..efd75be9 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,76 @@ +import json +import os +import pytest +import argparse +import requests +from unittest.mock import MagicMock, patch + + +@pytest.fixture +def mock_args(): + """Minimal argparse.Namespace for testing CLI functions.""" + return argparse.Namespace( + api_key="test-key", + url="https://console.vast.ai", + retry=3, + raw=False, + explain=False, + quiet=False, + curl=False, + full=False, + no_color=True, + debugging=False, + ) + + +@pytest.fixture +def mock_response(): + """Mock HTTP response with configurable status and JSON body.""" + response = MagicMock() + response.status_code = 200 + response.json.return_value = {"success": True} + response.text = '{"success": true}' + response.content = b'{"success": true}' + response.headers = {"Content-Type": "application/json"} + response.raise_for_status = MagicMock() + return response + + +@pytest.fixture +def mock_api_response(): + """Factory fixture for creating mock API responses with configurable status and data.""" + def _make_response(status_code=200, json_data=None, text=None, headers=None): + response = MagicMock() + response.status_code = status_code + response.json.return_value = json_data if json_data is not None else {} + response.text = text if text is not None else json.dumps(json_data or {}) + response.content = response.text.encode() + response.headers = headers if headers is not None else {"Content-Type": "application/json"} + if status_code >= 400: + response.raise_for_status.side_effect = requests.HTTPError(f"{status_code} Error") + else: + response.raise_for_status = MagicMock() + return response + return _make_response + + +@pytest.fixture +def mock_http_get(mock_api_response): + """Patch vast.http_get to return controlled responses.""" + with patch('vast.http_get') as mock: + mock.return_value = mock_api_response(200, {"success": True}) + yield mock + + +@pytest.fixture +def mock_http_post(mock_api_response): + """Patch vast.http_post to return controlled responses.""" + with patch('vast.http_post') as mock: + mock.return_value = mock_api_response(200, {"success": True}) + yield mock + + +@pytest.fixture +def vast_cli_path(): + """Return path to vast.py for subprocess tests.""" + return os.path.join(os.path.dirname(__file__), '..', 'vast.py') diff --git a/tests/regression/__init__.py b/tests/regression/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/regression/test_bare_except.py b/tests/regression/test_bare_except.py new file mode 100644 index 00000000..a2ac0ebb --- /dev/null +++ b/tests/regression/test_bare_except.py @@ -0,0 +1,102 @@ +"""No bare except: clauses in vast.py. + +The bug: Bare except: catches SystemExit and KeyboardInterrupt, which can +mask critical errors and make the program unresponsive to Ctrl+C. It also +swallows programming errors (NameError, TypeError) that should crash loudly. + +The fix: Replace all bare except: with specific exception types appropriate +to each try block's expected failure modes. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import re +import pytest + + +VAST_PY_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'vast.py') + + +class TestNoBareExcept: + """Lint-style tests verifying no bare except: remains in vast.py.""" + + def test_no_bare_except(self): + """No bare except: clauses should exist in vast.py. + + A bare except: catches everything including SystemExit and + KeyboardInterrupt, making Ctrl+C ineffective and masking bugs. + """ + with open(VAST_PY_PATH, encoding='utf-8') as f: + content = f.read() + # Match "except:" but not "except SomeName:" or "except (A, B):" + bare_excepts = re.findall(r'^\s*except\s*:', content, re.MULTILINE) + assert len(bare_excepts) == 0, ( + f"Found {len(bare_excepts)} bare except: clauses in vast.py. " + "Each except must catch specific exception types." + ) + + def test_except_clauses_have_types(self): + """Every except clause should specify at least one exception type.""" + with open(VAST_PY_PATH, encoding='utf-8') as f: + content = f.read() + # Find all except lines + except_lines = re.findall(r'^\s*except\b.*:', content, re.MULTILINE) + for line in except_lines: + stripped = line.strip() + # Must be "except SomeType:" or "except (A, B) as e:" etc. + # NOT just "except:" + assert stripped != "except:", ( + f"Found bare except: -- should catch specific types: {line!r}" + ) + + +class TestKeyboardInterruptPropagation: + """Verify KeyboardInterrupt is not swallowed by import handlers.""" + + def test_argcomplete_import_does_not_catch_keyboard_interrupt(self): + """The argcomplete import try/except should not catch KeyboardInterrupt. + + Before the fix, bare except: would catch KeyboardInterrupt during + import, making Ctrl+C during startup silently ignored. + """ + import importlib + import unittest.mock as mock + + # Simulate argcomplete import raising KeyboardInterrupt + original_import = __builtins__.__import__ if hasattr(__builtins__, '__import__') else __import__ + + def mock_import(name, *args, **kwargs): + if name == 'argcomplete': + raise KeyboardInterrupt() + return original_import(name, *args, **kwargs) + + # The except ImportError: handler should NOT catch KeyboardInterrupt + # so it should propagate + with pytest.raises(KeyboardInterrupt): + with mock.patch('builtins.__import__', side_effect=mock_import): + # Re-execute the import block logic + try: + __import__('argcomplete') + except ImportError: + pass # This is what the fixed code does + + def test_argcomplete_import_catches_import_error(self): + """The argcomplete import try/except should catch ImportError.""" + import unittest.mock as mock + + original_import = __builtins__.__import__ if hasattr(__builtins__, '__import__') else __import__ + + def mock_import(name, *args, **kwargs): + if name == 'argcomplete': + raise ImportError("No module named 'argcomplete'") + return original_import(name, *args, **kwargs) + + # ImportError should be caught (not propagated) + caught = False + with mock.patch('builtins.__import__', side_effect=mock_import): + try: + __import__('argcomplete') + except ImportError: + caught = True + assert caught, "ImportError should be caught by except ImportError:" diff --git a/tests/regression/test_direct_requests.py b/tests/regression/test_direct_requests.py new file mode 100644 index 00000000..97be8800 --- /dev/null +++ b/tests/regression/test_direct_requests.py @@ -0,0 +1,146 @@ +"""No direct requests.get/post calls outside http_* helpers. + +The bug: Several functions used requests.get() or requests.post() directly, +bypassing the centralized http_* helpers that provide timeout, retry, and +error handling. + +The fix: Convert all direct requests.get/post calls in CLI command functions +to use http_get/http_post. Only allowed exceptions are: + - get_project_data(): module-level PyPI check, no args available + - fetch_url_content(): utility function, no args available (dead code) + - _get_gpu_names(): module-level GPU cache, no args available + - http_request(): the low-level implementation that uses requests.Session + - import statements and commented-out code +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import ast +import re + + +# Allowed locations for direct requests.get/post calls +ALLOWED_FUNCTIONS = { + 'get_project_data', # Module-level PyPI check, no args + 'fetch_url_content', # Utility function, no args (dead code) + '_get_gpu_names', # Module-level GPU cache, no args + 'http_request', # Low-level implementation +} + + +def _get_vast_source(): + """Read the vast.py source file.""" + vast_path = os.path.join(os.path.dirname(__file__), '..', '..', 'vast.py') + with open(vast_path, encoding='utf-8') as f: + return f.read() + + +def test_no_direct_requests_get_in_command_functions(): + """No requests.get() calls exist in CLI command functions (outside allowed exceptions).""" + source = _get_vast_source() + tree = ast.parse(source) + + violations = [] + + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef): + func_name = node.name + if func_name in ALLOWED_FUNCTIONS: + continue + + # Walk the function body looking for requests.get( or requests.post( + for child in ast.walk(node): + if isinstance(child, ast.Call): + func = child.func + # Match: requests.get(...) or requests.post(...) + if (isinstance(func, ast.Attribute) and + isinstance(func.value, ast.Name) and + func.value.id == 'requests' and + func.attr in ('get', 'post')): + violations.append( + f" {func_name}() at line {child.lineno}: " + f"requests.{func.attr}()" + ) + + assert not violations, ( + "Found direct requests.get/post calls in command functions " + "(should use http_get/http_post):\n" + "\n".join(violations) + ) + + +def test_no_direct_requests_get_via_grep(): + """Grep-style check: no unprotected requests.get/post patterns in vast.py.""" + source = _get_vast_source() + lines = source.split('\n') + + violations = [] + pattern = re.compile(r'(?Bad Gateway" + mock_response.text = "Bad Gateway" + mock_response.json.side_effect = JSONDecodeError("msg", "doc", 0) + mock_response.raise_for_status.return_value = None + mock_http_get.return_value = mock_response + + args = _make_args() + result = api_call(args, "GET", "/test/") + + assert result is not None + assert "_raw_text" in result + assert result["_raw_text"] == "Bad Gateway" + + +@patch('vast.http_get') +def test_api_call_returns_json_on_valid_response(mock_http_get): + """api_call() returns parsed JSON normally when response is valid.""" + from vast import api_call + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.content = b'{"success": true}' + mock_response.json.return_value = {"success": True} + mock_response.raise_for_status.return_value = None + mock_http_get.return_value = mock_response + + args = _make_args() + result = api_call(args, "GET", "/test/") + + assert result == {"success": True} + + +@patch('vast.http_get') +def test_api_call_returns_none_for_empty_response(mock_http_get): + """api_call() returns None when response has no content.""" + from vast import api_call + + mock_response = MagicMock() + mock_response.status_code = 204 + mock_response.content = b"" + mock_response.raise_for_status.return_value = None + mock_http_get.return_value = mock_response + + args = _make_args() + result = api_call(args, "GET", "/test/") + + assert result is None + + +@patch('vast.http_get') +def test_api_call_no_exception_raised_on_html_response(mock_http_get): + """api_call() does NOT raise an exception when API returns HTML.""" + from vast import api_call, JSONDecodeError + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.content = b"502 Bad Gateway" + mock_response.text = "502 Bad Gateway" + mock_response.json.side_effect = JSONDecodeError("Expecting value", "", 0) + mock_response.raise_for_status.return_value = None + mock_http_get.return_value = mock_response + + args = _make_args() + + # This should NOT raise - the whole point of the fix + try: + result = api_call(args, "GET", "/instances/") + except JSONDecodeError: + pytest.fail("api_call() should not raise JSONDecodeError") + + assert "_raw_text" in result + + +@patch('vast.http_post') +def test_api_call_post_handles_json_decode_error(mock_http_post): + """api_call() handles JSONDecodeError for POST requests too.""" + from vast import api_call, JSONDecodeError + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.content = b"Internal Server Error" + mock_response.text = "Internal Server Error" + mock_response.json.side_effect = JSONDecodeError("msg", "doc", 0) + mock_response.raise_for_status.return_value = None + mock_http_post.return_value = mock_response + + args = _make_args() + result = api_call(args, "POST", "/instances/", json_body={"test": True}) + + assert result is not None + assert result["_raw_text"] == "Internal Server Error" diff --git a/tests/regression/test_main_error_handling.py b/tests/regression/test_main_error_handling.py new file mode 100644 index 00000000..b5d4187b --- /dev/null +++ b/tests/regression/test_main_error_handling.py @@ -0,0 +1,221 @@ +"""Regression tests for code quality fixes. + +- No Python builtins (id, sum) should be shadowed by local variables +- strip('-') should not be used for prefix removal (use startswith + lstrip) +- Error messages should reference correct field name +- Unused variables should be removed +""" +import re +import sys +from pathlib import Path + +# Add vast-cli to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + + +class TestVariableShadowing: + """No Python builtins should be shadowed by local variable assignments.""" + + def test_no_id_shadowing(self): + """Local variable 'id' should not shadow builtin. + + Note: keyword args in argparse.Namespace() like id=value are NOT shadowing. + """ + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # Find all assignments to bare 'id' (not id= which is keyword arg) + # Pattern: whitespace + id + optional space + = + not = (not ==) + # This should NOT match lines like "id=ask_contract_id," in Namespace() + matches = re.findall(r'^\s+id\s+=[^=]', content, re.MULTILINE) + + assert len(matches) == 0, ( + f"Found {len(matches)} instances of 'id' shadowing builtin. " + f"These should be renamed to domain-specific names like instance_id, " + f"workergroup_id, etc. First few matches: {matches[:5]}" + ) + + def test_no_sum_function_shadowing(self): + """Function 'sum' should not shadow builtin. + + The custom sum function should be renamed to sum_field or similar. + """ + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # Check for def sum( function definition + matches = re.findall(r'^def sum\s*\(', content, re.MULTILINE) + + assert len(matches) == 0, ( + f"Found def sum() which shadows Python builtin. " + f"Should be renamed to sum_field() or similar." + ) + + def test_sum_field_function_exists(self): + """The renamed sum_field function should exist.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # Check for def sum_field( function definition + assert 'def sum_field(' in content, ( + "Expected sum_field() function to exist after renaming from sum()" + ) + + def test_domain_specific_id_names_used(self): + """Verify domain-specific id names are used instead of bare 'id'.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + expected_patterns = [ + 'workergroup_id = args.id', + 'endpoint_id = args.id', + 'volume_id = args.id', + ] + + for pattern in expected_patterns: + assert pattern in content, ( + f"Expected domain-specific name pattern '{pattern}' not found. " + f"Bare 'id' may not have been renamed properly." + ) + + +class TestStringMethodFixes: + """strip('-') should not be used for prefix removal.""" + + def test_no_strip_dash_for_direction(self): + """strip('-') removes from both ends - should use startswith + lstrip.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # .strip("-") or .strip('-') should not appear for direction parsing + matches = re.findall(r'name\.strip\(["\'][-+]["\']', content) + + assert len(matches) == 0, ( + f"Found {len(matches)} instances of name.strip('-') or name.strip('+'). " + f"These should use startswith() + lstrip() instead: {matches}" + ) + + def test_direction_parsing_uses_startswith(self): + """Sort direction parsing should use startswith, not strip comparison.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # Should find startswith("-") for direction checks + has_startswith_minus = 'name.startswith("-")' in content + has_startswith_plus = 'name.startswith("+")' in content + + assert has_startswith_minus, ( + "Expected name.startswith('-') for descending sort direction parsing" + ) + assert has_startswith_plus, ( + "Expected name.startswith('+') for ascending sort direction parsing" + ) + + def test_direction_parsing_uses_lstrip(self): + """After detecting prefix, should use lstrip to remove it.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # Should find lstrip("-") and lstrip("+") + has_lstrip_minus = 'name.lstrip("-")' in content + has_lstrip_plus = 'name.lstrip("+")' in content + + assert has_lstrip_minus, ( + "Expected name.lstrip('-') for removing descending prefix" + ) + assert has_lstrip_plus, ( + "Expected name.lstrip('+') for removing ascending prefix" + ) + + def test_elif_structure_for_direction_parsing(self): + """Direction parsing should use elif, not two separate if statements. + + Using two separate if statements means a field like '-score' would first + match startswith('-') then incorrectly also check startswith('+'). + """ + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # Look for the correct pattern: if startswith("-") ... elif startswith("+") + # Not: if startswith("-") ... if startswith("+") + pattern = r'if name\.startswith\("-"\):.*?elif name\.startswith\("\+"\):' + matches = re.findall(pattern, content, re.DOTALL) + + assert len(matches) >= 4, ( + f"Expected at least 4 occurrences of correct if/elif pattern for " + f"direction parsing (in search__offers, search__instances, " + f"search__volumes, search__network_volumes), found {len(matches)}" + ) + + +class TestErrorMessages: + """Error messages should reference correct field.""" + + def test_start_date_error_message_correct(self): + """Error for start_date should say 'start date', not 'end date'.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # Should not have "start date" error saying "Ignoring end date" + bad_pattern = re.search( + r'Invalid start date.*Ignoring end date', + content, + re.IGNORECASE + ) + + assert bad_pattern is None, ( + "Found misleading error message - start_date error mentions 'Ignoring end date'. " + "Should say 'Ignoring start date' instead." + ) + + def test_start_date_error_says_start_date(self): + """Start date errors should reference start date.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # Should find "Invalid start date" + "Ignoring start date" pattern + correct_pattern = re.search( + r'Invalid start date.*Ignoring start date', + content, + re.IGNORECASE + ) + + assert correct_pattern is not None, ( + "Expected start_date error message to say 'Ignoring start date'" + ) + + +class TestUnusedVariables: + """Unused variables should be removed.""" + + def test_no_unused_date_txt_in_show_earnings(self): + """In show__earnings, date_txt variables should not be assigned if unused. + + Note: date_txt variables ARE used in invoice functions, just not in show__earnings. + """ + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # Find the show__earnings function + show_earnings_match = re.search( + r'def show__earnings\(args\):.*?(?=\ndef |$)', + content, + re.DOTALL + ) + + assert show_earnings_match is not None, "Could not find show__earnings function" + + show_earnings_code = show_earnings_match.group(0) + + # Check that end_date_txt and start_date_txt are not assigned in this function + has_end_date_txt = 'end_date_txt' in show_earnings_code + has_start_date_txt = 'start_date_txt' in show_earnings_code + + assert not has_end_date_txt, ( + "Found unused end_date_txt assignment in show__earnings. " + "This variable is assigned but never used in this function." + ) + assert not has_start_date_txt, ( + "Found unused start_date_txt assignment in show__earnings. " + "This variable is assigned but never used in this function." + ) diff --git a/tests/regression/test_main_raw_output.py b/tests/regression/test_main_raw_output.py new file mode 100644 index 00000000..7e9df44f --- /dev/null +++ b/tests/regression/test_main_raw_output.py @@ -0,0 +1,199 @@ +"""Regression tests for API request and raw output fixes.""" +import re +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + + +class TestApiKeyHandling: + """API key should be in headers, not JSON body.""" + + def test_no_api_key_in_json_blob(self): + """api_key should not appear in json_blob assignments.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # Look for json_blob containing api_key (but not in comments) + lines = content.split('\n') + violations = [] + for i, line in enumerate(lines, 1): + # Skip comments + if line.strip().startswith('#'): + continue + # Check for api_key in json dict literal on same line as json_blob + if 'json_blob' in line and 'api_key' in line and '=' in line: + violations.append(f"Line {i}: {line.strip()}") + + assert len(violations) == 0, f"Found api_key in json_blob: {violations}" + + def test_get_endpt_logs_no_api_key_in_body(self): + """get__endpt_logs should not have api_key in JSON.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # Find function and check its body + pattern = r'def get__endpt_logs\(.*?(?=\ndef |\Z)' + match = re.search(pattern, content, re.DOTALL) + assert match, "get__endpt_logs function not found" + + func_body = match.group(0) + # Should not have api_key in any dict literal + json_lines = [l for l in func_body.split('\n') + if 'json_blob' in l and 'api_key' in l and '=' in l + and not l.strip().startswith('#')] + assert len(json_lines) == 0, f"get__endpt_logs has api_key in json_blob: {json_lines}" + + def test_get_wrkgrp_logs_no_api_key_in_body(self): + """get__wrkgrp_logs should not have api_key in JSON.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + pattern = r'def get__wrkgrp_logs\(.*?(?=\ndef |\Z)' + match = re.search(pattern, content, re.DOTALL) + assert match, "get__wrkgrp_logs function not found" + + func_body = match.group(0) + json_lines = [l for l in func_body.split('\n') + if 'json_blob' in l and 'api_key' in l and '=' in l + and not l.strip().startswith('#')] + assert len(json_lines) == 0, f"get__wrkgrp_logs has api_key in json_blob: {json_lines}" + + def test_show_workergroups_no_api_key_in_body(self): + """show__workergroups should not have api_key in JSON.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + pattern = r'def show__workergroups\(.*?(?=\ndef |\Z)' + match = re.search(pattern, content, re.DOTALL) + assert match, "show__workergroups function not found" + + func_body = match.group(0) + json_lines = [l for l in func_body.split('\n') + if 'json_blob' in l and 'api_key' in l and '=' in l + and not l.strip().startswith('#')] + assert len(json_lines) == 0, f"show__workergroups has api_key in json_blob: {json_lines}" + + def test_show_endpoints_no_api_key_in_body(self): + """show__endpoints should not have api_key in JSON.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + pattern = r'def show__endpoints\(.*?(?=\ndef |\Z)' + match = re.search(pattern, content, re.DOTALL) + assert match, "show__endpoints function not found" + + func_body = match.group(0) + json_lines = [l for l in func_body.split('\n') + if 'json_blob' in l and 'api_key' in l and '=' in l + and not l.strip().startswith('#')] + assert len(json_lines) == 0, f"show__endpoints has api_key in json_blob: {json_lines}" + + +class TestSafeIteration: + """next() calls should have default to prevent StopIteration.""" + + def test_next_calls_have_default(self): + """All next() calls with generators should have a default parameter.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # Find next() calls without default + # Pattern: next(something for something) without a comma before closing + lines = content.split('\n') + risky_calls = [] + for i, line in enumerate(lines, 1): + if 'next(' in line and not line.strip().startswith('#'): + # Simple heuristic: if line has next( with 'for' inside but no comma + # This catches: next(x for x in y) but not next((x for x in y), None) + match = re.search(r'next\(\s*[^,)]+\s+for\s+[^,)]+\)', line) + if match: + risky_calls.append(f"Line {i}: {line.strip()}") + + assert len(risky_calls) == 0, f"Found next() without default: {risky_calls}" + + def test_show_clusters_next_has_default(self): + """show__clusters manager_node lookup should have default.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + pattern = r'def show__clusters\(.*?(?=\ndef |\Z)' + match = re.search(pattern, content, re.DOTALL) + assert match, "show__clusters function not found" + + func_body = match.group(0) + + # Should have next(..., None) pattern - may be multiline with nested parens + assert 'next(' in func_body, "show__clusters should use next()" + # Check for manager_node = next(..., None) with multiline and nested parens + # The pattern has: next(\n(generator),\nNone\n) + has_none_default = re.search(r'manager_node\s*=\s*next\s*\(.*?,\s*None\s*\)', func_body, re.DOTALL) + assert has_none_default, "show__clusters next() should have None default" + + def test_show_clusters_handles_missing_manager(self): + """show__clusters should handle case when no manager node exists.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + pattern = r'def show__clusters\(.*?(?=\ndef |\Z)' + match = re.search(pattern, content, re.DOTALL) + assert match, "show__clusters function not found" + + func_body = match.group(0) + + # Should check for None manager_node + assert 'manager_node is None' in func_body or 'if manager_node is None' in func_body, \ + "show__clusters should check for None manager_node" + + +class TestTransferCredit: + """--transfer_credit should be implemented or removed from docs.""" + + def test_transfer_credit_not_in_create_team_epilog(self): + """create team epilog should not mention --transfer_credit.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # Find create__team function and its decorator + pattern = r'@parser\.command\([^)]*argument\([^)]*team_name[^)]*\)[^)]*\)[^@]*def create__team' + match = re.search(pattern, content, re.DOTALL) + if match: + decorator_and_func = match.group(0) + # Should not mention --transfer_credit as a flag + assert '--transfer_credit' not in decorator_and_func, \ + "create team should not document --transfer_credit as a flag" + + def test_transfer_credit_consistency(self): + """If --transfer_credit is mentioned, it should be documented correctly.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # Find create__team decorator/epilog section + pattern = r'@parser\.command\(\s*argument\("--team_name".*?def create__team' + match = re.search(pattern, content, re.DOTALL) + + if match: + section = match.group(0) + # If transfer_credit is mentioned at all, it should NOT be as --transfer_credit flag + if 'transfer_credit' in section.lower(): + # Should be pointing to the separate command, not documenting a flag + assert 'vastai transfer credit' in section or '--transfer_credit' not in section, \ + "transfer_credit should point to 'vastai transfer credit' command, not a flag" + + def test_transfer_credit_command_exists(self): + """The transfer__credit command should exist as separate command.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # transfer__credit should exist as its own command + assert 'def transfer__credit' in content, \ + "transfer__credit should exist as a separate command" + + # It should have proper argument decorators - multiline decorator + # Look for recipient and amount in argument() calls before def transfer__credit + pattern = r'@parser\.command\(.*?def transfer__credit' + match = re.search(pattern, content, re.DOTALL) + assert match, "transfer__credit should have @parser.command decorator" + decorator_section = match.group(0) + assert 'recipient' in decorator_section, "transfer__credit should have recipient argument" + assert 'amount' in decorator_section, "transfer__credit should have amount argument" diff --git a/tests/regression/test_mutable_defaults.py b/tests/regression/test_mutable_defaults.py new file mode 100644 index 00000000..777c2b72 --- /dev/null +++ b/tests/regression/test_mutable_defaults.py @@ -0,0 +1,54 @@ +"""Mutable default arguments in http_put, http_post, http_del. + +The bug: Using json={} as a default argument means all calls share the same +dict object. If any caller mutates the dict, subsequent calls see the mutation. + +The fix: Use json=None and initialize to {} at runtime. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +from unittest.mock import MagicMock, patch +import argparse + + +def _make_args(): + return argparse.Namespace(retry=1, curl=False) + + +def _mock_http_request(verb, args, req_url, headers, json, **kwargs): + """Capture the json arg and return a mock response.""" + r = MagicMock() + r.status_code = 200 + # Store reference to the json dict that was passed + r._captured_json = json + return r + + +@patch('vast.http_request', side_effect=_mock_http_request) +def test_http_put_no_shared_default(mock_req): + from vast import http_put + args = _make_args() + r1 = http_put(args, "http://test1", headers=None) + r2 = http_put(args, "http://test2", headers=None) + # Each call must get its own dict, not share the mutable default + assert r1._captured_json is not r2._captured_json + + +@patch('vast.http_request', side_effect=_mock_http_request) +def test_http_post_no_shared_default(mock_req): + from vast import http_post + args = _make_args() + r1 = http_post(args, "http://test1", headers=None) + r2 = http_post(args, "http://test2", headers=None) + assert r1._captured_json is not r2._captured_json + + +@patch('vast.http_request', side_effect=_mock_http_request) +def test_http_del_no_shared_default(mock_req): + from vast import http_del + args = _make_args() + r1 = http_del(args, "http://test1", headers=None) + r2 = http_del(args, "http://test2", headers=None) + assert r1._captured_json is not r2._captured_json diff --git a/tests/regression/test_namespace_typo.py b/tests/regression/test_namespace_typo.py new file mode 100644 index 00000000..2aa8a79b --- /dev/null +++ b/tests/regression/test_namespace_typo.py @@ -0,0 +1,23 @@ +"""Typo 'debbuging' in Namespace construction. + +The bug: destroy_args = argparse.Namespace(..., debbuging=args.debugging, ...) +creates an attribute 'debbuging' instead of 'debugging'. Any code accessing +destroy_args.debugging gets AttributeError. + +The fix: Change 'debbuging' to 'debugging'. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + + +def test_no_debbuging_typo_in_source(): + """Verify the typo 'debbuging' does not appear in vast.py source.""" + vast_path = os.path.join(os.path.dirname(__file__), '..', '..', 'vast.py') + with open(vast_path, 'r', encoding='utf-8', errors='replace') as f: + source = f.read() + + assert 'debbuging' not in source, ( + "Found 'debbuging' typo in vast.py. " + "Should be 'debugging' in the Namespace construction." + ) diff --git a/tests/regression/test_parse_query.py b/tests/regression/test_parse_query.py new file mode 100644 index 00000000..9d6bfaac --- /dev/null +++ b/tests/regression/test_parse_query.py @@ -0,0 +1,59 @@ +"""parse_query() field alias bug -- v references dict after pop. + +The bug: `v = res.setdefault(field, {})` gets a reference to a dict at the +original field name. Then `res.pop(field)` removes it from res. Writing to v +modifies an orphaned dict that's no longer in res. + +Fields affected: cuda_vers->cuda_max_good, dph->dph_total, dlperf_usd->dlperf_per_dphtotal + +The fix: Resolve alias before calling setdefault. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + + +def test_field_alias_cuda_vers(): + """parse_query correctly aliases cuda_vers to cuda_max_good.""" + from vast import parse_query, offers_fields, offers_alias + + result = parse_query("cuda_vers >= 12.0", {}, offers_fields, offers_alias) + # The result should have cuda_max_good, NOT cuda_vers + assert 'cuda_max_good' in result, ( + f"Expected 'cuda_max_good' in result, got keys: {list(result.keys())}. " + f"Field alias not applied correctly." + ) + assert 'cuda_vers' not in result, "Old field name 'cuda_vers' should not be in result" + assert 'gte' in result['cuda_max_good'], ( + f"Expected 'gte' operator in cuda_max_good, got: {result['cuda_max_good']}" + ) + + +def test_field_alias_dph(): + """parse_query correctly aliases dph to dph_total.""" + from vast import parse_query, offers_fields, offers_alias, offers_mult + + result = parse_query("dph <= 1.5", {}, offers_fields, offers_alias, offers_mult) + assert 'dph_total' in result, f"Expected 'dph_total', got keys: {list(result.keys())}" + assert 'dph' not in result + assert 'lte' in result['dph_total'] + + +def test_field_alias_value_preserved(): + """After alias resolution, the operator value is correctly stored.""" + from vast import parse_query, offers_fields, offers_alias + + result = parse_query("cuda_vers >= 12.0", {}, offers_fields, offers_alias) + assert result['cuda_max_good']['gte'] == '12.0', ( + f"Expected value '12.0', got: {result['cuda_max_good'].get('gte')}" + ) + + +def test_non_aliased_field_unaffected(): + """Fields without aliases work normally.""" + from vast import parse_query, offers_fields, offers_alias + + result = parse_query("num_gpus >= 2", {}, offers_fields, offers_alias) + assert 'num_gpus' in result + assert 'gte' in result['num_gpus'] + assert result['num_gpus']['gte'] == '2' diff --git a/tests/regression/test_raw_completeness.py b/tests/regression/test_raw_completeness.py new file mode 100644 index 00000000..a798c9dd --- /dev/null +++ b/tests/regression/test_raw_completeness.py @@ -0,0 +1,176 @@ +""" +Regression Tests: Raw Mode Completeness + +Verifies that command functions have consistent --raw handling. +""" +import re +import sys +from pathlib import Path + +# Add vast-cli to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + + +class TestRawModeCompleteness: + """Tests that command functions have raw mode handling.""" + + def setup_method(self): + """Load vast.py source code once per test.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + self.source = vast_path.read_text(encoding="utf-8") + + def test_minimum_raw_handlers(self): + """Verify at least 90 raw mode handlers exist.""" + raw_count = len(re.findall(r"if args\.raw:", self.source)) + assert raw_count >= 90, f"Expected at least 90 raw handlers, found {raw_count}" + + def test_command_functions_have_raw_handling(self): + """Check that common command functions have if args.raw: patterns.""" + # Key command functions that should have raw handling + expected_functions = [ + "attach__ssh", + "cancel__copy", + "cancel__sync", + "change__bid", + "create__api_key", + "create__env_var", + "create__ssh_key", + "create__workergroup", + "create__endpoint", + "create__subaccount", + "create__team", + "create__team_role", + "create__template", + "delete__api_key", + "delete__ssh_key", + "delete__scheduled_job", + "delete__workergroup", + "delete__endpoint", + "delete__env_var", + "delete__template", + "destroy__team", + "detach__ssh", + "invite__member", + "label__instance", + "prepay__instance", + "reboot__instance", + "recycle__instance", + "remove__member", + "remove__team_role", + "reports", + "reset__api_key", + "transfer__credit", + ] + + for func_name in expected_functions: + # Find the function definition + func_pattern = rf"^def {func_name}\(args" + match = re.search(func_pattern, self.source, re.MULTILINE) + assert match, f"Function {func_name} not found" + + # Get function body (rough approximation - from def to next def or EOF) + start = match.start() + next_def = re.search(r"^def \w+\(", self.source[start + 10:], re.MULTILINE) + end = start + 10 + next_def.start() if next_def else len(self.source) + func_body = self.source[start:end] + + # Check for raw handling + has_raw = "if args.raw:" in func_body + assert has_raw, f"Function {func_name} missing 'if args.raw:' handling" + + def test_no_orphan_print_without_raw_check(self): + """ + Verify that functions returning JSON data have raw checks before prints. + + This is a sampling test - check a few critical functions. + """ + # Functions that should have raw handling before their main output + critical_patterns = [ + # (function_name, expected_pattern_after_raw_check) + ("create__team", r"if args\.raw:.*?return.*?print\(result\)"), + ("delete__api_key", r"if args\.raw:.*?return.*?print\(result\)"), + ] + + for func_name, _ in critical_patterns: + func_pattern = rf"^def {func_name}\(args" + match = re.search(func_pattern, self.source, re.MULTILINE) + assert match, f"Function {func_name} not found" + + start = match.start() + next_def = re.search(r"^def \w+\(", self.source[start + 10:], re.MULTILINE) + end = start + 10 + next_def.start() if next_def else len(self.source) + func_body = self.source[start:end] + + # Verify raw check exists + assert "if args.raw:" in func_body, f"{func_name} should have raw check" + + +class TestRawModeReturnsData: + """Tests that raw mode handlers return data, not None.""" + + def setup_method(self): + """Load vast.py source code once per test.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + self.source = vast_path.read_text(encoding="utf-8") + + def test_raw_handlers_have_return_statements(self): + """Verify raw mode handlers include return statements.""" + # Find all if args.raw: blocks + raw_pattern = r"if args\.raw:\s*\n\s+return" + matches = re.findall(raw_pattern, self.source) + # Should have many return statements following raw checks + assert len(matches) >= 80, f"Expected 80+ 'if args.raw: return' patterns, found {len(matches)}" + + def test_no_empty_raw_blocks(self): + """Verify no raw blocks that just pass or do nothing.""" + # Pattern for raw checks that just pass + empty_raw_pattern = r"if args\.raw:\s*\n\s+pass\s*\n" + matches = re.findall(empty_raw_pattern, self.source) + assert len(matches) == 0, f"Found {len(matches)} empty 'if args.raw: pass' blocks" + + +class TestRawModeConsistency: + """Tests for consistent raw mode patterns across the codebase.""" + + def setup_method(self): + """Load vast.py source code once per test.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + self.source = vast_path.read_text(encoding="utf-8") + + def test_consistent_raw_return_pattern(self): + """ + Verify raw mode uses consistent return patterns. + + Expected patterns: + - if args.raw: return rj + - if args.raw: return result + - if args.raw: return rows + - if args.raw: return data + """ + # Find all raw return patterns + raw_return_pattern = r"if args\.raw:\s*\n\s+return\s+(\w+)" + matches = re.findall(raw_return_pattern, self.source) + + # Common return variable names + valid_names = {"rj", "result", "rows", "data", "processed", "user_blob", + "response_data", "instances", "volumes", "machines", "offers"} + + for var_name in matches: + # Allow any reasonable variable name (not just the common ones) + # This is a sanity check - variables should be short identifiers + assert len(var_name) < 30, f"Suspicious return variable: {var_name}" + + def test_output_result_handles_raw(self): + """Verify output_result function exists and handles raw mode.""" + # Check that output_result is defined + assert "def output_result(" in self.source, "output_result function not found" + + # Check that output_result checks args.raw + output_result_match = re.search( + r"def output_result\(.*?\n(.*?)(?=^def |\Z)", + self.source, + re.MULTILINE | re.DOTALL + ) + assert output_result_match, "Could not extract output_result function body" + func_body = output_result_match.group(1) + assert "args.raw" in func_body, "output_result should check args.raw" diff --git a/tests/regression/test_raw_errors.py b/tests/regression/test_raw_errors.py new file mode 100644 index 00000000..b30b54b7 --- /dev/null +++ b/tests/regression/test_raw_errors.py @@ -0,0 +1,188 @@ +"""Error messages in --raw mode must be valid JSON. + +The bug: When --raw mode is active and an HTTPError or ValueError occurs, +the error handler prints plain text (e.g., "failed with error 500: ..."). +Scripts and automation tools parsing JSON output get broken by mixed +text/JSON output. + +The fix: Check args.raw in the HTTPError and ValueError exception handlers +in main(), and output a JSON object with error/status_code/msg fields. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import json +import argparse +import pytest +from unittest import mock +from io import StringIO + +import requests + + +class TestHTTPErrorRawMode: + """Verify HTTPError in --raw mode produces valid JSON.""" + + def _make_mock_response(self, status_code, json_body=None): + """Create a mock response for HTTPError.""" + resp = mock.MagicMock() + resp.status_code = status_code + if json_body is not None: + resp.json.return_value = json_body + else: + resp.json.side_effect = json.JSONDecodeError("No JSON", "", 0) + return resp + + def _run_main_with_error(self, args): + """Run vast.main() with proper mocking, return captured stdout.""" + import vast + + captured = StringIO() + with mock.patch.object(vast, 'ARGS', args): + with mock.patch('vast.parser') as mock_parser: + mock_parser.parse_args.return_value = args + mock_parser.add_argument = mock.MagicMock() + mock_parser.parser = mock.MagicMock() + with mock.patch('vast.should_check_for_update', False): + with mock.patch('vast.TABCOMPLETE', False): + with mock.patch('vast.api_key_guard', 'GUARD'): + with mock.patch('sys.stdout', captured): + try: + vast.main() + except SystemExit: + pass + return captured.getvalue().strip() + + def test_http_error_raw_mode_produces_json(self): + """HTTPError with --raw should output valid JSON with error/status_code/msg.""" + mock_resp = self._make_mock_response(500, {"msg": "Internal server error"}) + http_error = requests.exceptions.HTTPError(response=mock_resp) + + args = argparse.Namespace( + raw=True, func=mock.MagicMock(side_effect=http_error), + api_key="test", url="https://test.com", explain=False, + retry=3, full=False, curl=False, no_color=False, + ) + + output = self._run_main_with_error(args) + parsed = json.loads(output) + assert parsed["error"] is True + assert parsed["status_code"] == 500 + assert parsed["msg"] == "Internal server error" + + def test_http_error_raw_mode_401_produces_json(self): + """HTTPError 401 with --raw should output JSON with login message.""" + mock_resp = self._make_mock_response(401, json_body=None) + http_error = requests.exceptions.HTTPError(response=mock_resp) + + args = argparse.Namespace( + raw=True, func=mock.MagicMock(side_effect=http_error), + api_key="test", url="https://test.com", explain=False, + retry=3, full=False, curl=False, no_color=False, + ) + + output = self._run_main_with_error(args) + parsed = json.loads(output) + assert parsed["error"] is True + assert parsed["status_code"] == 401 + assert "log in" in parsed["msg"].lower() or "sign up" in parsed["msg"].lower() + + def test_http_error_non_raw_mode_produces_text(self): + """HTTPError without --raw should output human-readable text.""" + mock_resp = self._make_mock_response(500, {"msg": "Server error"}) + http_error = requests.exceptions.HTTPError(response=mock_resp) + + args = argparse.Namespace( + raw=False, func=mock.MagicMock(side_effect=http_error), + api_key="test", url="https://test.com", explain=False, + retry=3, full=False, curl=False, no_color=False, + ) + + output = self._run_main_with_error(args) + # Non-raw should NOT be valid JSON with error key + assert "failed with error" in output.lower() + + +class TestValueErrorRawMode: + """Verify ValueError in --raw mode produces valid JSON.""" + + def _run_main_with_error(self, args): + """Run vast.main() with proper mocking, return captured stdout.""" + import vast + + captured = StringIO() + with mock.patch.object(vast, 'ARGS', args): + with mock.patch('vast.parser') as mock_parser: + mock_parser.parse_args.return_value = args + mock_parser.add_argument = mock.MagicMock() + mock_parser.parser = mock.MagicMock() + with mock.patch('vast.should_check_for_update', False): + with mock.patch('vast.TABCOMPLETE', False): + with mock.patch('vast.api_key_guard', 'GUARD'): + with mock.patch('sys.stdout', captured): + try: + vast.main() + except SystemExit: + pass + return captured.getvalue().strip() + + def test_value_error_raw_mode_produces_json(self): + """ValueError with --raw should output valid JSON with error/msg.""" + args = argparse.Namespace( + raw=True, func=mock.MagicMock(side_effect=ValueError("bad value")), + api_key="test", url="https://test.com", explain=False, + retry=3, full=False, curl=False, no_color=False, + ) + + output = self._run_main_with_error(args) + parsed = json.loads(output) + assert parsed["error"] is True + assert parsed["msg"] == "bad value" + + def test_value_error_non_raw_mode_produces_text(self): + """ValueError without --raw should print the error message as text.""" + args = argparse.Namespace( + raw=False, func=mock.MagicMock(side_effect=ValueError("bad value")), + api_key="test", url="https://test.com", explain=False, + retry=3, full=False, curl=False, no_color=False, + ) + + output = self._run_main_with_error(args) + assert output == "bad value" + + +class TestRawErrorHandlerLintChecks: + """Lint-style tests to verify raw error handling exists in main().""" + + def test_httperror_handler_checks_args_raw(self): + """The HTTPError handler in main() must check args.raw.""" + VAST_PY_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'vast.py') + with open(VAST_PY_PATH, encoding='utf-8') as f: + content = f.read() + + # Find the HTTPError handler section + assert 'HTTPError' in content, "HTTPError handler must exist" + # Find args.raw in the context of error handling + import re + # Look for args.raw near HTTPError handling + http_error_section = content[content.index('HTTPError'):] + # Limit to the next except or end of function + next_section = http_error_section[:http_error_section.index('except ValueError')] + assert 'args.raw' in next_section, ( + "HTTPError handler must check args.raw for JSON output" + ) + + def test_valueerror_handler_checks_args_raw(self): + """The ValueError handler in main() must check args.raw.""" + VAST_PY_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'vast.py') + with open(VAST_PY_PATH, encoding='utf-8') as f: + content = f.read() + + # Find the ValueError handler in main() - the one near end of file + # Look for the last 'except ValueError' which is in main() + last_ve_idx = content.rindex('except ValueError') + ve_section = content[last_ve_idx:last_ve_idx + 300] + assert 'args.raw' in ve_section, ( + "ValueError handler in main() must check args.raw for JSON output" + ) diff --git a/tests/regression/test_retry.py b/tests/regression/test_retry.py new file mode 100644 index 00000000..ab1822a5 --- /dev/null +++ b/tests/regression/test_retry.py @@ -0,0 +1,233 @@ +"""Incomplete retry logic -- only retries on HTTP 429. + +The bug: http_request() only retries when the server returns 429 (rate limit). +Transient 5xx errors (502, 503, 504) and connection failures cause immediate +command failure instead of being retried. + +The fix: Expand retry to cover {429, 502, 503, 504} status codes and split +exception handling so ConnectionError/Timeout are retried while other +RequestException subclasses are raised immediately. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import argparse +from unittest.mock import MagicMock, patch +import pytest +import requests.exceptions + + +def _make_args(retry=3): + return argparse.Namespace(retry=retry, curl=False) + + +@patch('vast.time.sleep') +@patch('vast.requests.Session') +def test_retry_on_502(mock_session_cls, mock_sleep): + """http_request retries on 502 Bad Gateway and recovers.""" + from vast import http_request + + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + response_502 = MagicMock() + response_502.status_code = 502 + response_200 = MagicMock() + response_200.status_code = 200 + + mock_session.send.side_effect = [response_502, response_200] + mock_session.prepare_request.return_value = MagicMock() + + args = _make_args(retry=3) + with patch('vast.ARGS', args): + result = http_request('GET', args, 'http://example.com/test') + + assert result.status_code == 200 + assert mock_session.send.call_count == 2 + + +@patch('vast.time.sleep') +@patch('vast.requests.Session') +def test_retry_on_503(mock_session_cls, mock_sleep): + """http_request retries on 503 Service Unavailable.""" + from vast import http_request + + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + response_503 = MagicMock() + response_503.status_code = 503 + response_200 = MagicMock() + response_200.status_code = 200 + + mock_session.send.side_effect = [response_503, response_200] + mock_session.prepare_request.return_value = MagicMock() + + args = _make_args(retry=3) + with patch('vast.ARGS', args): + result = http_request('GET', args, 'http://example.com/test') + + assert result.status_code == 200 + assert mock_session.send.call_count == 2 + + +@patch('vast.time.sleep') +@patch('vast.requests.Session') +def test_retry_on_504(mock_session_cls, mock_sleep): + """http_request retries on 504 Gateway Timeout.""" + from vast import http_request + + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + response_504 = MagicMock() + response_504.status_code = 504 + response_200 = MagicMock() + response_200.status_code = 200 + + mock_session.send.side_effect = [response_504, response_200] + mock_session.prepare_request.return_value = MagicMock() + + args = _make_args(retry=3) + with patch('vast.ARGS', args): + result = http_request('GET', args, 'http://example.com/test') + + assert result.status_code == 200 + assert mock_session.send.call_count == 2 + + +@patch('vast.time.sleep') +@patch('vast.requests.Session') +def test_retry_on_429_still_works(mock_session_cls, mock_sleep): + """Original 429 retry behavior is preserved.""" + from vast import http_request + + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + response_429 = MagicMock() + response_429.status_code = 429 + response_200 = MagicMock() + response_200.status_code = 200 + + mock_session.send.side_effect = [response_429, response_200] + mock_session.prepare_request.return_value = MagicMock() + + args = _make_args(retry=3) + with patch('vast.ARGS', args): + result = http_request('GET', args, 'http://example.com/test') + + assert result.status_code == 200 + assert mock_session.send.call_count == 2 + # Should have slept once (after the 429) + assert mock_sleep.call_count == 1 + + +@patch('vast.time.sleep') +@patch('vast.requests.Session') +def test_retry_on_connection_error(mock_session_cls, mock_sleep): + """http_request retries on ConnectionError and recovers.""" + from vast import http_request + + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + response_200 = MagicMock() + response_200.status_code = 200 + + mock_session.send.side_effect = [ + requests.exceptions.ConnectionError("connection refused"), + response_200, + ] + mock_session.prepare_request.return_value = MagicMock() + + args = _make_args(retry=3) + with patch('vast.ARGS', args): + result = http_request('GET', args, 'http://example.com/test') + + assert result.status_code == 200 + assert mock_session.send.call_count == 2 + + +@patch('vast.time.sleep') +@patch('vast.requests.Session') +def test_retry_on_timeout_exception(mock_session_cls, mock_sleep): + """http_request retries on Timeout exception and recovers.""" + from vast import http_request + + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + response_200 = MagicMock() + response_200.status_code = 200 + + mock_session.send.side_effect = [ + requests.exceptions.Timeout("read timed out"), + response_200, + ] + mock_session.prepare_request.return_value = MagicMock() + + args = _make_args(retry=3) + with patch('vast.ARGS', args): + result = http_request('GET', args, 'http://example.com/test') + + assert result.status_code == 200 + assert mock_session.send.call_count == 2 + + +@patch('vast.time.sleep') +@patch('vast.requests.Session') +def test_no_retry_on_non_retryable_exception(mock_session_cls, mock_sleep): + """Non-retryable RequestException subclasses are raised immediately.""" + from vast import http_request + + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + # InvalidURL is a non-retryable error + mock_session.send.side_effect = requests.exceptions.InvalidURL("bad url") + mock_session.prepare_request.return_value = MagicMock() + + args = _make_args(retry=3) + with patch('vast.ARGS', args): + with pytest.raises(requests.exceptions.InvalidURL): + http_request('GET', args, 'http://example.com/test') + + # Should NOT have retried -- only 1 call + assert mock_session.send.call_count == 1 + # Should NOT have slept + assert mock_sleep.call_count == 0 + + +@patch('vast.time.sleep') +@patch('vast.requests.Session') +def test_no_retry_on_non_retryable_status(mock_session_cls, mock_sleep): + """Non-retryable status codes (e.g., 400, 404, 500) break out immediately.""" + from vast import http_request + + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + response_400 = MagicMock() + response_400.status_code = 400 + + mock_session.send.return_value = response_400 + mock_session.prepare_request.return_value = MagicMock() + + args = _make_args(retry=3) + with patch('vast.ARGS', args): + result = http_request('GET', args, 'http://example.com/test') + + # Should return on first call without retrying + assert result.status_code == 400 + assert mock_session.send.call_count == 1 + + +@patch('vast.time.sleep') +@patch('vast.requests.Session') +def test_retryable_status_codes_constant(mock_session_cls, mock_sleep): + """RETRYABLE_STATUS_CODES contains exactly {429, 502, 503, 504}.""" + from vast import RETRYABLE_STATUS_CODES + + assert RETRYABLE_STATUS_CODES == {429, 502, 503, 504} diff --git a/tests/regression/test_return_response.py b/tests/regression/test_return_response.py new file mode 100644 index 00000000..c97d9ca1 --- /dev/null +++ b/tests/regression/test_return_response.py @@ -0,0 +1,58 @@ +"""22 functions return Response object instead of parsed JSON. + +The bug: Functions using http_* directly return `r` (Response object) in raw +mode. Response objects are not JSON-serializable, causing json.dumps to fail. +The bare except: in main() masks this by calling res.json() as a fallback. + +The fix: Change `return r` to `return r.json()` in all 22 functions. +Also: Remove bare except in main() since all returns are now serializable. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import re + + +def test_no_bare_return_r_in_functions(): + """No function (except http_request) returns bare `r` (Response object).""" + vast_path = os.path.join(os.path.dirname(__file__), '..', '..', 'vast.py') + with open(vast_path, 'r', encoding='utf-8', errors='replace') as f: + lines = f.readlines() + + # Find all 'return r' lines outside http_request + in_http_request = False + violations = [] + for i, line in enumerate(lines, 1): + stripped = line.strip() + if stripped.startswith('def http_request('): + in_http_request = True + elif stripped.startswith('def ') and in_http_request: + in_http_request = False + + if in_http_request: + continue + + # Match 'return r' but not 'return rows', 'return r.json()', etc. + if re.match(r'\s+return r\s*$', line): + violations.append(f"Line {i}: {stripped}") + + assert not violations, ( + f"Found {len(violations)} function(s) still returning bare Response object:\n" + + "\n".join(violations) + ) + + +def test_no_bare_except_in_main(): + """main() should not have a bare except: clause for raw output.""" + vast_path = os.path.join(os.path.dirname(__file__), '..', '..', 'vast.py') + with open(vast_path, 'r', encoding='utf-8', errors='replace') as f: + source = f.read() + + # Check that res.json() fallback is gone from main's raw handler + # The old pattern was: try: json.dumps(res) except: json.dumps(res.json()) + # After fix: just json.dumps(res) with no try/except + assert 'res.json()' not in source, ( + "Found res.json() fallback in main(). " + "All returns are JSON-serializable and the fallback is dead code." + ) diff --git a/tests/regression/test_safe_dict_access.py b/tests/regression/test_safe_dict_access.py new file mode 100644 index 00000000..fc285a5e --- /dev/null +++ b/tests/regression/test_safe_dict_access.py @@ -0,0 +1,237 @@ +"""Safe dict access on API response dicts. + +The bug: 60+ locations accessed API response dicts with rj["key"], +r.json()["key"], or result["key"] which raises KeyError if the API +response format changes or an endpoint returns unexpected data. + +The fix: Convert all API response dict accesses to .get() with +appropriate defaults: + - Boolean checks: rj.get("success") -- None is falsy + - Messages: rj.get("msg", "Unknown error") -- fallback text + - Iterable data: rj.get("offers", []) -- empty list for iteration + - Required fields: rj.get("result_url") with explicit error check +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import re +import argparse +import pytest +from unittest.mock import MagicMock, patch + +VAST_PY_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'vast.py') + + +# --------------------------------------------------------------------------- # +# Lint-style tests # +# --------------------------------------------------------------------------- # + +class TestMinimalRawDictAccess: + """Ensure almost no raw rj['key'] access patterns remain on API data.""" + + def test_minimal_rj_bracket_access(self): + """Count rj["..."] patterns -- should be zero after the fix.""" + with open(VAST_PY_PATH, encoding='utf-8') as f: + content = f.read() + raw_accesses = re.findall(r'rj\["[^"]+"\]', content) + assert len(raw_accesses) == 0, ( + f"Found {len(raw_accesses)} raw rj[\"key\"] accesses; " + f"expected 0. Convert to rj.get('key', default). " + f"Matches: {raw_accesses[:5]}" + ) + + def test_minimal_rj_single_quote_access(self): + """Count rj['...'] patterns -- should be zero after the fix.""" + with open(VAST_PY_PATH, encoding='utf-8') as f: + content = f.read() + raw_accesses = re.findall(r"rj\['[^']+'\]", content) + assert len(raw_accesses) == 0, ( + f"Found {len(raw_accesses)} raw rj['key'] accesses; " + f"expected 0. Convert to rj.get('key', default). " + f"Matches: {raw_accesses[:5]}" + ) + + def test_minimal_r_json_bracket_access(self): + """Count r.json()["..."] patterns (excluding comments).""" + with open(VAST_PY_PATH, encoding='utf-8') as f: + lines = f.readlines() + raw_accesses = [] + for i, line in enumerate(lines, 1): + stripped = line.lstrip() + if stripped.startswith('#'): + continue + matches = re.findall(r'\.json\(\)\["[^"]+"\]', line) + for m in matches: + raw_accesses.append(f"line {i}: {m}") + assert len(raw_accesses) == 0, ( + f"Found {len(raw_accesses)} raw .json()[\"key\"] accesses in " + f"non-comment lines; expected 0. Matches: {raw_accesses[:5]}" + ) + + def test_high_safe_access_count(self): + """Verify a high number of .get() patterns exist.""" + with open(VAST_PY_PATH, encoding='utf-8') as f: + content = f.read() + rj_gets = len(re.findall(r'rj\.get\(', content)) + json_gets = len(re.findall(r'\.json\(\)\.get\(', content)) + result_gets = len(re.findall(r'result\.get\(', content)) + total = rj_gets + json_gets + result_gets + assert total >= 60, ( + f"Found only {total} safe .get() accesses on API response dicts " + f"(rj.get: {rj_gets}, .json().get: {json_gets}, result.get: {result_gets}); " + f"expected >= 60 after safe dict access conversion." + ) + + +# --------------------------------------------------------------------------- # +# Functional tests: missing "success" key # +# --------------------------------------------------------------------------- # + +class TestMissingSuccessKey: + """Verify functions handle missing 'success' key without KeyError.""" + + @patch('vast.http_put') + def test_prepay_instance_no_success_key(self, mock_put, capsys): + """prepay__instance should not crash if 'success' key is missing.""" + import vast + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + # API returns dict WITHOUT 'success' key + mock_response.json.return_value = {"some_other_field": 123} + mock_put.return_value = mock_response + + args = argparse.Namespace( + api_key="test-key", url="https://console.vast.ai", + retry=3, raw=False, explain=False, + id=12345, amount=10.0 + ) + + # Should not raise KeyError + vast.prepay__instance(args) + + captured = capsys.readouterr() + # Since success is missing (falsy), it should print the error branch + assert "Unknown error" in captured.out + + @patch('vast.api_call') + def test_label_instance_no_success_key(self, mock_api_call, capsys): + """label__instance should not crash if 'success' key is missing.""" + import vast + + # api_call returns a dict without 'success' key + mock_api_call.return_value = {"some_other_field": 123} + + args = argparse.Namespace( + api_key="test-key", url="https://console.vast.ai", + retry=3, raw=False, explain=False, + id=12345, label="test-label" + ) + + # Should not raise KeyError + vast.label__instance(args) + + captured = capsys.readouterr() + # Since success is missing (falsy), it should print the error branch + assert "Unknown error" in captured.out + + +# --------------------------------------------------------------------------- # +# Functional tests: missing "msg" key # +# --------------------------------------------------------------------------- # + +class TestMissingMsgKey: + """Verify functions handle missing 'msg' key by printing fallback.""" + + @patch('vast.http_put') + def test_prepay_failure_no_msg(self, mock_put, capsys): + """When API returns success=False without msg, print 'Unknown error'.""" + import vast + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + mock_response.json.return_value = {"success": False} + mock_put.return_value = mock_response + + args = argparse.Namespace( + api_key="test-key", url="https://console.vast.ai", + retry=3, raw=False, explain=False, + id=999, amount=5.0 + ) + + vast.prepay__instance(args) + + captured = capsys.readouterr() + assert "Unknown error" in captured.out + + @patch('vast.http_post') + def test_create_overlay_no_msg(self, mock_post, capsys): + """create__overlay should print 'Unknown error' when msg key missing.""" + import vast + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + mock_response.json.return_value = {"status": "ok"} # no "msg" key + mock_post.return_value = mock_response + + args = argparse.Namespace( + api_key="test-key", url="https://console.vast.ai", + retry=3, raw=False, explain=False, + name="test-overlay", cluster_id=123 + ) + + vast.create__overlay(args) + + captured = capsys.readouterr() + assert "Unknown error" in captured.out + + +# --------------------------------------------------------------------------- # +# Functional tests: missing data extraction keys # +# --------------------------------------------------------------------------- # + +class TestMissingDataKeys: + """Verify functions handle missing data keys gracefully.""" + + @patch('vast.http_get') + def test_show_instances_missing_instances_key(self, mock_get, capsys): + """show__instances should handle missing 'instances' key without crash.""" + import vast + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + mock_response.json.return_value = {"success": True} # no "instances" + mock_get.return_value = mock_response + + args = argparse.Namespace( + api_key="test-key", url="https://console.vast.ai", + retry=3, raw=True, explain=False, + quiet=False + ) + + # Should return empty list, not KeyError + result = vast.show__instances(args) + assert result == [] + + @patch('vast.api_call') + def test_show_volumes_missing_volumes_key(self, mock_api_call, capsys): + """show__volumes should handle missing 'volumes' key without crash.""" + import vast + + # api_call returns a dict without 'volumes' key + mock_api_call.return_value = {"success": True} + + args = argparse.Namespace( + api_key="test-key", url="https://console.vast.ai", + retry=3, raw=True, explain=False, + quiet=False, type="all" + ) + + # Should return empty list, not KeyError + result = vast.show__volumes(args) + assert result == [] or result is None diff --git a/tests/regression/test_sdk_exception_handling.py b/tests/regression/test_sdk_exception_handling.py new file mode 100644 index 00000000..03af6887 --- /dev/null +++ b/tests/regression/test_sdk_exception_handling.py @@ -0,0 +1,59 @@ +"""Regression test: SDK wrapper must not use bare `except: pass`. + +The bug: The SDK wrapper at sdk.py has `except: pass` which catches +SystemExit (preventing CLI functions from exiting), KeyboardInterrupt +(preventing Ctrl+C), and all errors (returning empty string ''). + +The fix: Catch SystemExit separately (capture exit code), catch Exception +with logging, let KeyboardInterrupt propagate. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import re + + +def test_no_bare_except_in_sdk(): + """sdk.py should not contain bare `except:` clauses.""" + sdk_path = os.path.join(os.path.dirname(__file__), '..', '..', 'vastai', 'sdk.py') + with open(sdk_path, 'r') as f: + lines = f.readlines() + + bare_excepts = [] + for i, line in enumerate(lines, 1): + stripped = line.strip() + # Match `except:` but not `except SomeException:` or `except (A, B):` + if re.match(r'^except\s*:\s*$', stripped): + bare_excepts.append(f"Line {i}: {stripped}") + + assert not bare_excepts, ( + f"Found {len(bare_excepts)} bare except: clause(s) in sdk.py:\n" + + "\n".join(bare_excepts) + + "\nUse specific exception types instead." + ) + + +def test_sdk_catches_system_exit(): + """sdk.py catches SystemExit separately from other exceptions.""" + sdk_path = os.path.join(os.path.dirname(__file__), '..', '..', 'vastai', 'sdk.py') + with open(sdk_path, 'r') as f: + source = f.read() + + assert 'except SystemExit' in source, ( + "sdk.py should catch SystemExit separately to handle CLI sys.exit() calls. " + "Expected: `except SystemExit as e:`" + ) + + +def test_sdk_does_not_catch_keyboard_interrupt(): + """sdk.py should NOT catch KeyboardInterrupt (Ctrl+C must work).""" + sdk_path = os.path.join(os.path.dirname(__file__), '..', '..', 'vastai', 'sdk.py') + with open(sdk_path, 'r') as f: + source = f.read() + + # If it catches BaseException, that would include KeyboardInterrupt + assert 'except BaseException' not in source, ( + "sdk.py catches BaseException which includes KeyboardInterrupt. " + "Use `except Exception` instead to allow Ctrl+C to work." + ) diff --git a/tests/regression/test_sdk_integration.py b/tests/regression/test_sdk_integration.py new file mode 100644 index 00000000..7354d8a4 --- /dev/null +++ b/tests/regression/test_sdk_integration.py @@ -0,0 +1,456 @@ +""" +Regression tests for SDK integration. + +These tests verify the SDK wrapper integrates correctly with the live vast module +and supports all documented features, including method resolution, argument passing, +and output capture. +""" +import sys +import warnings +from io import StringIO +from pathlib import Path +from unittest.mock import MagicMock, patch + +# Add vast-cli to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + + +class TestSDKLiveModuleImport: + """SDK must import from live vast module, not frozen copy.""" + + def test_sentinel_attribute_visible(self): + """Adding sentinel to vast module should be visible through SDK.""" + import vast + + # Add sentinel attribute + vast._SDK_TEST_SENTINEL = "phase6_integration_test" + + # Import SDK internals + from vastai import sdk + + # Verify SDK's imported vast module sees the sentinel + assert hasattr(sdk._vast, '_SDK_TEST_SENTINEL'), \ + "SDK's vast module should see runtime-added attributes" + assert sdk._vast._SDK_TEST_SENTINEL == "phase6_integration_test", \ + "Sentinel value should match" + + # Cleanup + del vast._SDK_TEST_SENTINEL + + def test_parser_from_vast_module(self): + """Parser should be imported from vast module.""" + import vast + from vastai import sdk + + # The parser used by SDK should be the same object as vast.parser + assert sdk.parser is vast.parser, \ + "SDK parser should be the same object as vast.parser" + + def test_apikey_file_from_vast_module(self): + """APIKEY_FILE should be imported from vast module.""" + import vast + from vastai import sdk + + # APIKEY_FILE should match + assert sdk.APIKEY_FILE == vast.APIKEY_FILE, \ + "SDK APIKEY_FILE should match vast.APIKEY_FILE" + + +class TestSDKFeatureCompleteness: + """VastAI class must support all documented features.""" + + def test_instantiation_with_api_key(self): + """VastAI can be instantiated with api_key parameter.""" + from vastai import VastAI + + sdk = VastAI(api_key="test_key_12345") + + assert sdk.api_key == "test_key_12345" + assert sdk._creds == "CODE" # API key provided in code + + def test_raw_mode_default(self): + """VastAI defaults to raw=True for SDK usage.""" + from vastai import VastAI + + sdk = VastAI(api_key="test_key") + + assert sdk.raw is True, "SDK should default to raw=True" + + def test_raw_mode_toggle(self): + """VastAI raw mode can be toggled.""" + from vastai import VastAI + + sdk_raw = VastAI(api_key="test_key", raw=True) + sdk_human = VastAI(api_key="test_key", raw=False) + + assert sdk_raw.raw is True + assert sdk_human.raw is False + + def test_server_url_parameter(self): + """VastAI accepts server_url parameter.""" + from vastai import VastAI + + sdk = VastAI(api_key="test_key", server_url="https://custom.vast.ai") + + assert sdk.server_url == "https://custom.vast.ai" + + def test_retry_parameter(self): + """VastAI accepts retry parameter.""" + from vastai import VastAI + + sdk = VastAI(api_key="test_key", retry=5) + + assert sdk.retry == 5 + + def test_explain_parameter(self): + """VastAI accepts explain parameter.""" + from vastai import VastAI + + sdk = VastAI(api_key="test_key", explain=True) + + assert sdk.explain is True + + def test_quiet_parameter(self): + """VastAI accepts quiet parameter.""" + from vastai import VastAI + + sdk = VastAI(api_key="test_key", quiet=True) + + assert sdk.quiet is True + + def test_imported_methods_populated(self): + """VastAI should have imported_methods dict populated.""" + from vastai import VastAI + + sdk = VastAI(api_key="test_key") + + assert hasattr(sdk, 'imported_methods') + assert isinstance(sdk.imported_methods, dict) + # Should have many methods (vast.py has 115+ commands) + assert len(sdk.imported_methods) > 50, \ + f"Expected 50+ methods, got {len(sdk.imported_methods)}" + + def test_dynamic_method_binding(self): + """VastAI should have methods dynamically bound from vast.py.""" + from vastai import VastAI + + sdk = VastAI(api_key="test_key") + + # Check some well-known methods exist + assert hasattr(sdk, 'search_offers'), "search_offers should be bound" + assert hasattr(sdk, 'show_instances'), "show_instances should be bound" + assert hasattr(sdk, 'show_machines'), "show_machines should be bound" + assert callable(sdk.search_offers), "search_offers should be callable" + + def test_workergroup_and_autoscaler_aliases(self): + """Both workergroup and autoscaler/autogroup aliases should work.""" + from vastai import VastAI + + sdk = VastAI(api_key="test_key") + + # Workergroup naming (primary method names from CLI) + assert hasattr(sdk, 'create_workergroup') or 'create_workergroup' in sdk.imported_methods, \ + "create_workergroup should exist" + assert hasattr(sdk, 'show_workergroups') or 'show_workergroups' in sdk.imported_methods, \ + "show_workergroups should exist" + assert hasattr(sdk, 'delete_workergroup') or 'delete_workergroup' in sdk.imported_methods, \ + "delete_workergroup should exist" + assert hasattr(sdk, 'update_workergroup') or 'update_workergroup' in sdk.imported_methods, \ + "update_workergroup should exist" + + # Autoscaler/autogroup backwards compatibility aliases + # Base class provides aliases: create_autogroup, delete_autoscaler, show_autoscalers, update_autoscaler + assert hasattr(sdk, 'create_autogroup'), \ + "create_autogroup alias should exist" + assert hasattr(sdk, 'show_autoscalers'), \ + "show_autoscalers alias should exist" + assert hasattr(sdk, 'delete_autoscaler'), \ + "delete_autoscaler alias should exist" + assert hasattr(sdk, 'update_autoscaler'), \ + "update_autoscaler alias should exist" + + +class TestSDKMethodCoverage: + """Verify SDK method coverage against CLI commands.""" + + def test_method_count_matches_cli_commands(self): + """SDK should have methods for most CLI commands.""" + from vastai import VastAI + import vast + + sdk = VastAI(api_key="test_key") + + # Count CLI commands (functions with double underscore) + cli_commands = [ + name for name in dir(vast) + if callable(getattr(vast, name)) + and '__' in name + and not name.startswith('_') + ] + + # SDK should have at least 80% coverage + # (some commands like 'help' are excluded) + min_expected = int(len(cli_commands) * 0.80) + actual_count = len(sdk.imported_methods) + + assert actual_count >= min_expected, \ + f"SDK has {actual_count} methods but CLI has {len(cli_commands)} commands. " \ + f"Expected at least {min_expected} methods." + + def test_all_critical_methods_exist(self): + """SDK must have all commonly-used methods.""" + from vastai import VastAI + + sdk = VastAI(api_key="test_key") + + critical_methods = [ + 'search_offers', + 'create_instance', + 'destroy_instance', + 'show_instances', + 'start_instance', + 'stop_instance', + 'show_machines', + 'logs', + 'execute', + 'copy', + 'show_user', + ] + + for method in critical_methods: + assert hasattr(sdk, method) or method in sdk.imported_methods, \ + f"Critical method '{method}' missing from SDK" + + +class TestSDKMethodExecution: + """Verify SDK methods execute through the vast module correctly.""" + + def test_show_instances_method_exists_and_callable(self): + """show_instances method should exist and be callable.""" + from vastai import VastAI + + sdk = VastAI(api_key="test_key") + + # Method should exist as callable + assert hasattr(sdk, 'show_instances'), "show_instances should be bound" + assert callable(sdk.show_instances), "show_instances should be callable" + + def test_search_offers_method_exists_and_callable(self): + """search_offers method should exist and be callable.""" + from vastai import VastAI + + sdk = VastAI(api_key="test_key") + + assert hasattr(sdk, 'search_offers'), "search_offers should be bound" + assert callable(sdk.search_offers), "search_offers should be callable" + + def test_method_binding_uses_vast_functions(self): + """SDK methods should be bound from vast module functions.""" + from vastai import VastAI + import vast + + sdk = VastAI(api_key="test_key") + + # The method should be in imported_methods and callable + if 'show_instances' in sdk.imported_methods: + # The bound method comes from vast.show__instances + vast_func_name = 'show__instances' + assert hasattr(vast, vast_func_name), \ + f"vast.{vast_func_name} should exist for SDK to bind" + + def test_method_returns_callable_not_none(self): + """SDK methods should return callable functions, not None.""" + from vastai import VastAI + + sdk = VastAI(api_key="test_key") + + # Check several methods to ensure they're properly bound + methods_to_check = ['search_offers', 'show_instances', 'show_machines', 'show_user'] + + for method_name in methods_to_check: + method = getattr(sdk, method_name, None) + assert method is not None, f"{method_name} should not be None" + assert callable(method), f"{method_name} should be callable" + + +class TestSdkMethodResolution: + """SDK wrapper method resolution tests.""" + + def test_search_offers_method_exists(self): + """SDK has search_offers method.""" + from vastai import VastAI + sdk = VastAI(api_key="test-key") + assert hasattr(sdk, 'search_offers') + assert callable(sdk.search_offers) + + def test_show_instances_method_exists(self): + """SDK has show_instances method.""" + from vastai import VastAI + sdk = VastAI(api_key="test-key") + assert hasattr(sdk, 'show_instances') + assert callable(sdk.show_instances) + + def test_create_instance_method_exists(self): + """SDK has create_instance method.""" + from vastai import VastAI + sdk = VastAI(api_key="test-key") + assert hasattr(sdk, 'create_instance') + assert callable(sdk.create_instance) + + def test_destroy_instance_method_exists(self): + """SDK has destroy_instance method.""" + from vastai import VastAI + sdk = VastAI(api_key="test-key") + assert hasattr(sdk, 'destroy_instance') + assert callable(sdk.destroy_instance) + + def test_method_resolution_is_consistent(self): + """Same method should resolve identically across multiple SDK instances.""" + from vastai import VastAI + sdk1 = VastAI(api_key="test-key-1") + sdk2 = VastAI(api_key="test-key-2") + + # Both instances should have same method names in imported_methods + assert sdk1.imported_methods.keys() == sdk2.imported_methods.keys() + + +class TestSdkArgumentPassing: + """SDK wrapper argument passing tests.""" + + def test_api_key_stored_in_instance(self): + """SDK stores api_key in instance for use in requests.""" + from vastai import VastAI + + sdk = VastAI(api_key="test-api-key-12345") + + # API key should be accessible + assert sdk.api_key == "test-api-key-12345" + + def test_retry_parameter_passed(self): + """SDK retry parameter is stored and accessible.""" + from vastai import VastAI + + sdk = VastAI(api_key="test-key", retry=10) + assert sdk.retry == 10 + + def test_server_url_parameter_passed(self): + """SDK server_url parameter is stored and accessible.""" + from vastai import VastAI + + sdk = VastAI(api_key="test-key", server_url="https://custom.vast.ai") + assert sdk.server_url == "https://custom.vast.ai" + + def test_explain_parameter_passed(self): + """SDK explain parameter is stored and accessible.""" + from vastai import VastAI + + sdk = VastAI(api_key="test-key", explain=True) + assert sdk.explain is True + + def test_quiet_parameter_passed(self): + """SDK quiet parameter is stored and accessible.""" + from vastai import VastAI + + sdk = VastAI(api_key="test-key", quiet=True) + assert sdk.quiet is True + + +class TestSdkOutputCapture: + """SDK wrapper output capture tests.""" + + def test_sdk_instance_has_output_capture_mechanism(self): + """SDK should have mechanism for capturing output.""" + from vastai import VastAI + + sdk = VastAI(api_key="test-key", raw=True) + + # SDK uses raw=True by default to return data instead of printing + # This verifies the mechanism exists + assert sdk.raw is True + + def test_sdk_raw_mode_returns_data_type(self): + """SDK in raw mode should be configured to return data.""" + from vastai import VastAI + + sdk = VastAI(api_key="test-key", raw=True) + + # Verify the SDK is configured correctly for raw output + assert hasattr(sdk, 'raw') + assert sdk.raw is True + # Raw mode means CLI functions return data instead of printing + + def test_sdk_non_raw_mode_available(self): + """SDK can be set to non-raw mode for human-readable output.""" + from vastai import VastAI + + sdk = VastAI(api_key="test-key", raw=False) + assert sdk.raw is False + + +class TestSdkBackwardsCompatibility: + """SDK backwards compatibility tests.""" + + def test_autogroup_alias_exists(self): + """SDK has autogroup alias methods for backwards compatibility.""" + from vastai import VastAI + sdk = VastAI(api_key="test-key") + + # Both workergroup and autogroup/autoscaler names should work + if hasattr(sdk, 'show_workergroups') or 'show_workergroups' in sdk.imported_methods: + # Autoscaler aliases from base class + assert hasattr(sdk, 'show_autoscalers'), "show_autoscalers alias should exist" + + def test_autoscaler_crud_aliases_exist(self): + """SDK has all autoscaler CRUD aliases for backwards compatibility.""" + from vastai import VastAI + sdk = VastAI(api_key="test-key") + + # These aliases are defined in vastai_base.py + assert hasattr(sdk, 'create_autogroup'), "create_autogroup alias should exist" + assert hasattr(sdk, 'show_autoscalers'), "show_autoscalers alias should exist" + assert hasattr(sdk, 'delete_autoscaler'), "delete_autoscaler alias should exist" + assert hasattr(sdk, 'update_autoscaler'), "update_autoscaler alias should exist" + + def test_old_import_path_warning(self): + """Importing from vastai_sdk works (with or without deprecation warning).""" + # The vastai_sdk module may or may not emit a deprecation warning + # depending on implementation. Either behavior is acceptable. + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + try: + from vastai_sdk import VastAI as OldVastAI + # Import succeeded - either with warning or without + # Both are acceptable for backwards compatibility + assert OldVastAI is not None + except ImportError: + # If vastai_sdk shim doesn't exist, that's also acceptable + # as long as the main vastai import works + from vastai import VastAI + assert VastAI is not None + + def test_primary_import_path_works(self): + """Primary import path 'from vastai import VastAI' works.""" + from vastai import VastAI + sdk = VastAI(api_key="test-key") + assert sdk is not None + assert hasattr(sdk, 'api_key') + + +class TestSdkMethodDocstrings: + """SDK methods have docstrings for IDE autocomplete.""" + + def test_sdk_methods_have_docstrings(self): + """SDK methods should have docstrings for IDE support.""" + from vastai import VastAI + sdk = VastAI(api_key="test-key") + + # Check some key methods have docstrings + methods_to_check = ['search_offers', 'show_instances', 'create_instance'] + + for method_name in methods_to_check: + if hasattr(sdk, method_name): + method = getattr(sdk, method_name) + # Method should be callable and have docstring + assert callable(method), f"{method_name} should be callable" + # Docstring may come from vast.py function or be added by SDK + # Either is acceptable as long as method works diff --git a/tests/regression/test_sdk_naming_and_async.py b/tests/regression/test_sdk_naming_and_async.py new file mode 100644 index 00000000..35fc6baa --- /dev/null +++ b/tests/regression/test_sdk_naming_and_async.py @@ -0,0 +1,232 @@ +"""Regression tests for SDK naming consistency, typo fixes, and async migration.""" +import re +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + + +class TestNamingConsistency: + """Both workergroup and autogroup names should exist.""" + + def test_workergroup_methods_exist(self): + """SDK should have workergroup method names.""" + base_path = Path(__file__).parent.parent.parent / "vastai" / "vastai_base.py" + content = base_path.read_text() + + # Should have workergroup methods + workergroup_methods = re.findall(r'def\s+\w*workergroup\w*\s*\(', content, re.IGNORECASE) + assert len(workergroup_methods) > 0, "Should have workergroup methods" + + def test_autogroup_aliases_exist(self): + """SDK should have autogroup aliases for backwards compatibility.""" + base_path = Path(__file__).parent.parent.parent / "vastai" / "vastai_base.py" + content = base_path.read_text() + + # Should have autogroup names (either as primary or alias) + has_autogroup = "autogroup" in content.lower() + assert has_autogroup, "Should have autogroup names for backwards compatibility" + + def test_sdk_has_both_names(self): + """VastAI instance should have both naming conventions.""" + try: + from vastai import VastAI + v = VastAI(api_key="test") + + # Check for at least one workergroup and autogroup method + methods = dir(v) + has_workergroup = any("workergroup" in m.lower() for m in methods) + has_autogroup = any("autogroup" in m.lower() for m in methods) + + # At least one should exist (may not have full set in base) + assert has_workergroup or has_autogroup, "Should have group-related methods" + except ImportError: + # If import fails, just check file content + pass + + def test_all_workergroup_methods_have_aliases(self): + """Each workergroup method should have an autogroup alias.""" + base_path = Path(__file__).parent.parent.parent / "vastai" / "vastai_base.py" + content = base_path.read_text() + + # Find all workergroup method definitions + workergroup_defs = re.findall(r'def\s+(\w*workergroup\w*)\s*\(', content, re.IGNORECASE) + + # Each should have a corresponding alias comment or assignment + for method in workergroup_defs: + # Check for alias assignment pattern: xxx_autogroup = xxx_workergroup or xxx_autoscaler = xxx_workergroup + alias_patterns = [ + method.lower().replace("workergroup", "autogroup"), + method.lower().replace("workergroup", "autoscaler"), + ] + has_alias = any(p in content.lower() for p in alias_patterns) + assert has_alias, f"Method {method} should have a backwards compatibility alias" + + def test_naming_matches_cli(self): + """SDK method names should match CLI command naming.""" + # CLI uses: create__workergroup, delete__workergroup, show__workergroups, update__workergroup + # SDK should have: create_workergroup, delete_workergroup, show_workergroups, update_workergroup + base_path = Path(__file__).parent.parent.parent / "vastai" / "vastai_base.py" + content = base_path.read_text() + + expected_methods = [ + "create_workergroup", + "delete_workergroup", + "show_workergroups", + "update_workergroup", + ] + + for method in expected_methods: + pattern = rf'def\s+{method}\s*\(' + match = re.search(pattern, content) + assert match, f"Should have method {method} matching CLI command" + + +class TestTypoFixes: + """Response dict should have correct spelling.""" + + def test_no_reuqest_typo(self): + """Should not have 'reuqest' typo anywhere.""" + client_path = Path(__file__).parent.parent.parent / "vastai" / "serverless" / "client" / "client.py" + + if client_path.exists(): + content = client_path.read_text() + assert "reuqest" not in content, "Found 'reuqest' typo - should be 'request'" + + def test_request_idx_correct(self): + """Should have 'request_idx' with correct spelling.""" + client_path = Path(__file__).parent.parent.parent / "vastai" / "serverless" / "client" / "client.py" + + if client_path.exists(): + content = client_path.read_text() + # If request_idx is used, it should be spelled correctly + if "request_idx" in content or "reuqest_idx" in content: + assert "request_idx" in content, "Should have correctly spelled request_idx" + assert "reuqest_idx" not in content, "Should not have reuqest_idx typo" + + def test_response_dict_keys_spelled_correctly(self): + """Response dict keys should all be spelled correctly.""" + client_path = Path(__file__).parent.parent.parent / "vastai" / "serverless" / "client" / "client.py" + + if client_path.exists(): + content = client_path.read_text() + # Find the response dict assignment + if '"request_idx"' in content: + # Should have correct key + assert '"request_idx"' in content, "Response dict should have request_idx key" + + +class TestAsyncMigration: + """_fetch_pubkey should use proper HTTP, not subprocess curl.""" + + def test_no_subprocess_curl(self): + """Should not use subprocess to call curl.""" + backend_path = Path(__file__).parent.parent.parent / "vastai" / "serverless" / "server" / "lib" / "backend.py" + + if backend_path.exists(): + content = backend_path.read_text() + + # Should not have subprocess curl pattern + has_subprocess_curl = "subprocess.check_output" in content and "curl" in content + assert not has_subprocess_curl, "Should not use subprocess curl - use requests or aiohttp" + + def test_no_subprocess_import(self): + """Should not import subprocess (no longer needed).""" + backend_path = Path(__file__).parent.parent.parent / "vastai" / "serverless" / "server" / "lib" / "backend.py" + + if backend_path.exists(): + content = backend_path.read_text() + + # Check for subprocess import at module level + import_pattern = r'^import subprocess\s*$' + has_import = re.search(import_pattern, content, re.MULTILINE) + assert not has_import, "Should not import subprocess - no longer needed" + + def test_uses_requests_or_aiohttp(self): + """_fetch_pubkey should use requests or aiohttp.""" + backend_path = Path(__file__).parent.parent.parent / "vastai" / "serverless" / "server" / "lib" / "backend.py" + + if backend_path.exists(): + content = backend_path.read_text() + + # Should use requests (for sync) or aiohttp (for async) + uses_requests = "requests.get" in content + uses_aiohttp = "ClientSession" in content and "aiohttp" in content + + # _fetch_pubkey should use proper HTTP library + if "_fetch_pubkey" in content: + assert uses_requests or uses_aiohttp, "Should use requests or aiohttp for _fetch_pubkey" + + def test_has_async_variant(self): + """Should have async variant of _fetch_pubkey using aiohttp.""" + backend_path = Path(__file__).parent.parent.parent / "vastai" / "serverless" / "server" / "lib" / "backend.py" + + if backend_path.exists(): + content = backend_path.read_text() + + # Should have async version for use in async contexts + has_async = "async def _fetch_pubkey_async" in content + assert has_async, "Should have async variant _fetch_pubkey_async using aiohttp" + + def test_fetch_pubkey_has_timeout(self): + """_fetch_pubkey should have timeout to prevent hanging.""" + backend_path = Path(__file__).parent.parent.parent / "vastai" / "serverless" / "server" / "lib" / "backend.py" + + if backend_path.exists(): + content = backend_path.read_text() + + # Both sync and async versions should have timeout + assert "timeout" in content, "Should have timeout in HTTP requests" + + +class TestServerlessImports: + """Basic serverless module tests.""" + + def test_serverless_client_importable(self): + """Serverless client should be importable without aiohttp at vastai level.""" + # This tests that lazy imports work + try: + # Just import vastai - should not fail even without aiohttp + import vastai + assert vastai is not None + except ImportError as e: + if "aiohttp" in str(e): + # This is expected if aiohttp not installed + pass + else: + raise + + def test_vastai_base_importable(self): + """VastAIBase should be importable.""" + from vastai.vastai_base import VastAIBase + assert VastAIBase is not None + + def test_base_class_has_workergroup_methods(self): + """VastAIBase should define workergroup methods.""" + from vastai.vastai_base import VastAIBase + + # Check class attributes + assert hasattr(VastAIBase, "create_workergroup"), "Should have create_workergroup" + assert hasattr(VastAIBase, "delete_workergroup"), "Should have delete_workergroup" + assert hasattr(VastAIBase, "show_workergroups"), "Should have show_workergroups" + assert hasattr(VastAIBase, "update_workergroup"), "Should have update_workergroup" + + def test_base_class_has_autogroup_aliases(self): + """VastAIBase should have autogroup/autoscaler aliases for backwards compatibility.""" + from vastai.vastai_base import VastAIBase + + # Check aliases exist + assert hasattr(VastAIBase, "create_autogroup"), "Should have create_autogroup alias" + assert hasattr(VastAIBase, "delete_autoscaler"), "Should have delete_autoscaler alias" + assert hasattr(VastAIBase, "show_autoscalers"), "Should have show_autoscalers alias" + assert hasattr(VastAIBase, "update_autoscaler"), "Should have update_autoscaler alias" + + def test_aliases_point_to_same_method(self): + """Aliases should point to the same underlying method.""" + from vastai.vastai_base import VastAIBase + + # Aliases should be the same function + assert VastAIBase.create_autogroup is VastAIBase.create_workergroup + assert VastAIBase.delete_autoscaler is VastAIBase.delete_workergroup + assert VastAIBase.show_autoscalers is VastAIBase.show_workergroups + assert VastAIBase.update_autoscaler is VastAIBase.update_workergroup diff --git a/tests/regression/test_sdk_timezone_import.py b/tests/regression/test_sdk_timezone_import.py new file mode 100644 index 00000000..41ee4af7 --- /dev/null +++ b/tests/regression/test_sdk_timezone_import.py @@ -0,0 +1,56 @@ +"""Regression test: SDK copy must import timezone from datetime. + +The bug: vastai/vast.py imports date, datetime, timedelta but NOT +timezone. Any code using datetime.timezone.utc or timezone(...) fails with +NameError or ImportError. + +The fix: Add timezone to the import statement. +""" +import sys +import os + +import pytest + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +_SDK_VAST_PATH = os.path.join( + os.path.dirname(__file__), '..', '..', '..', 'vast-sdk', 'vastai', 'vast.py' +) + + +@pytest.mark.skipif( + not os.path.isfile(_SDK_VAST_PATH), + reason="vast-sdk sibling directory not present" +) +def test_frozen_sdk_has_timezone_import(): + """The frozen SDK copy imports timezone from datetime.""" + # Read the frozen SDK file directly + with open(_SDK_VAST_PATH, 'r', encoding='utf-8', errors='replace') as f: + source = f.read() + + assert 'timezone' in source, ( + "vast-sdk/vastai/vast.py does not import 'timezone' from datetime. " + "Add it: from datetime import date, datetime, timedelta, timezone" + ) + + # Also verify the import line specifically + import_found = False + for line in source.split('\n'): + if line.startswith('from datetime import') and 'timezone' in line: + import_found = True + break + + assert import_found, ( + "Could not find 'from datetime import ... timezone ...' in vast-sdk/vastai/vast.py" + ) + + +def test_live_vast_has_timezone_import(): + """The live vast.py already has the timezone import (sanity check).""" + vast_path = os.path.join(os.path.dirname(__file__), '..', '..', 'vast.py') + with open(vast_path, 'r') as f: + for line in f: + if line.startswith('from datetime import') and 'timezone' in line: + return # Found it, test passes + + assert False, "Live vast.py missing timezone import" diff --git a/tests/regression/test_sdk_wrapper_safety.py b/tests/regression/test_sdk_wrapper_safety.py new file mode 100644 index 00000000..e1949fa6 --- /dev/null +++ b/tests/regression/test_sdk_wrapper_safety.py @@ -0,0 +1,189 @@ +"""Regression tests for SDK wrapper safety. + +- stdout capture must use finally block for guaranteed restoration +- sys.exit() from CLI functions should return exit code, not crash SDK +- Exception handling should use specific types, not bare except +""" +import io +import re +import sys +from pathlib import Path +from unittest.mock import patch, MagicMock + +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + + +class TestStdoutCapture: + """stdout capture must use finally block.""" + + def test_finally_block_exists(self): + """SDK wrapper should have finally block for stdout restoration.""" + sdk_path = Path(__file__).parent.parent.parent / "vastai" / "sdk.py" + content = sdk_path.read_text() + + # Should have finally: followed by stdout restoration + assert "finally:" in content, "SDK should have finally block" + + # The finally block should restore stdout + finally_pattern = re.search(r'finally:.*?sys\.stdout\s*=', content, re.DOTALL) + assert finally_pattern is not None, "finally block should restore sys.stdout" + + def test_stdout_restoration_guaranteed(self): + """Verify stdout restoration code is in finally block, not outside.""" + sdk_path = Path(__file__).parent.parent.parent / "vastai" / "sdk.py" + content = sdk_path.read_text() + + # Check that stdout restoration is inside finally + finally_match = re.search( + r'finally:\s*\n\s*#.*restore stdout.*\n\s*if out_o is not None.*:\s*\n\s*sys\.stdout = out_o', + content, + re.IGNORECASE + ) + assert finally_match is not None, "stdout restoration should be inside finally block" + + def test_stdout_restored_after_sdk_init(self): + """sys.stdout should be the original after SDK operations.""" + from vastai import VastAI + + original_stdout = sys.stdout + + # Create SDK instance + v = VastAI(api_key="test_key_12345") + + # stdout should still be the original + assert sys.stdout is original_stdout, "stdout was corrupted after SDK init" + + +class TestSysExitHandling: + """sys.exit() should be caught and converted.""" + + def test_systemexit_caught_separately(self): + """SystemExit should have its own except block.""" + sdk_path = Path(__file__).parent.parent.parent / "vastai" / "sdk.py" + content = sdk_path.read_text() + + # Should catch SystemExit separately + assert "except SystemExit" in content, "Should catch SystemExit separately" + + def test_exit_code_extracted(self): + """sys.exit() code should be extracted, not re-raised.""" + sdk_path = Path(__file__).parent.parent.parent / "vastai" / "sdk.py" + content = sdk_path.read_text() + + # Should access e.code + assert "e.code" in content, "Should extract exit code from SystemExit" + + def test_systemexit_before_general_exception(self): + """SystemExit should be caught before general Exception.""" + sdk_path = Path(__file__).parent.parent.parent / "vastai" / "sdk.py" + content = sdk_path.read_text() + + # Find positions of both except blocks + systemexit_pos = content.find("except SystemExit") + exception_pos = content.find("except Exception") + + assert systemexit_pos != -1, "SystemExit handler not found" + assert exception_pos != -1, "Exception handler not found" + assert systemexit_pos < exception_pos, "SystemExit should be caught before general Exception" + + +class TestExceptionSpecificity: + """Exception handling should be specific, not broad.""" + + def test_no_bare_except(self): + """Should not have bare 'except:' without exception type.""" + sdk_path = Path(__file__).parent.parent.parent / "vastai" / "sdk.py" + content = sdk_path.read_text() + + # Find bare except: (not except Something:) + lines = content.split('\n') + bare_excepts = [] + for i, line in enumerate(lines, 1): + stripped = line.strip() + # Match "except:" but not "except Something:" or "except (A, B):" + if stripped == "except:" or stripped.startswith("except: "): + bare_excepts.append(f"Line {i}: {line}") + + assert len(bare_excepts) == 0, f"Found bare except: {bare_excepts}" + + def test_exception_handlers_have_types(self): + """except blocks should specify exception types.""" + sdk_path = Path(__file__).parent.parent.parent / "vastai" / "sdk.py" + content = sdk_path.read_text() + + # Find all except statements + except_lines = re.findall(r'^\s*except\s+.*:', content, re.MULTILINE) + + # All should have a type (not bare except:) + for line in except_lines: + stripped = line.strip() + # Allow "except Type:" or "except (A, B):" or "except Type as e:" + assert stripped != "except:", f"Found bare except: {line}" + # Verify it has some type specification + type_match = re.match(r'except\s+[\w\(\),\s]+', stripped) + assert type_match, f"Exception handler should have type: {line}" + + def test_specific_exceptions_in_queryformatter(self): + """queryFormatter should use specific exceptions (KeyError, TypeError).""" + sdk_path = Path(__file__).parent.parent.parent / "vastai" / "sdk.py" + content = sdk_path.read_text() + + # The queryFormatter function should have specific exceptions + assert "except (KeyError, TypeError):" in content, \ + "queryFormatter should catch specific KeyError and TypeError" + + +class TestSDKImport: + """Basic SDK import and instantiation tests.""" + + def test_import_vastai(self): + """from vastai import VastAI should work.""" + from vastai import VastAI + assert VastAI is not None + + def test_instantiate_with_api_key(self): + """VastAI(api_key='test') should not raise.""" + from vastai import VastAI + v = VastAI(api_key="test_key_12345") + assert v is not None + assert hasattr(v, "api_key") + + def test_sdk_has_expected_attributes(self): + """SDK instance should have standard attributes.""" + from vastai import VastAI + v = VastAI(api_key="test_key_12345") + + # Core attributes + assert hasattr(v, "api_key") + assert hasattr(v, "server_url") + assert hasattr(v, "retry") + assert hasattr(v, "raw") + assert hasattr(v, "last_output") + + def test_sdk_last_output_initialized(self): + """SDK should initialize last_output to None.""" + from vastai import VastAI + v = VastAI(api_key="test_key_12345") + assert v.last_output is None + + +class TestStdoutCaptureIntegration: + """Integration tests for stdout capture behavior.""" + + def test_stdout_not_leaked_on_multiple_operations(self): + """Multiple SDK operations should not leak stdout state.""" + from vastai import VastAI + + original_stdout = sys.stdout + v = VastAI(api_key="test_key_12345") + + # Perform multiple operations (even if they fail due to no network) + for _ in range(3): + try: + # These may fail due to no API key/network, but stdout should be safe + v.show_user() + except Exception: + pass + + # Check stdout is still original + assert sys.stdout is original_stdout, "stdout leaked after SDK operation" diff --git a/tests/regression/test_serverless_lazy_import.py b/tests/regression/test_serverless_lazy_import.py new file mode 100644 index 00000000..dc17499a --- /dev/null +++ b/tests/regression/test_serverless_lazy_import.py @@ -0,0 +1,225 @@ +""" +Regression tests for serverless lazy imports. + +These tests verify that serverless classes don't load heavy dependencies +until they are actually accessed. +""" +import sys +from pathlib import Path + +# Add vast-cli to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + + +class TestLazyImportConfiguration: + """Verify lazy import mechanism is configured correctly.""" + + def test_lazy_imports_dict_exists(self): + """vastai/__init__.py should have _LAZY_IMPORTS dict.""" + import vastai + + assert hasattr(vastai, '_LAZY_IMPORTS'), \ + "vastai/__init__.py should have _LAZY_IMPORTS dict" + assert isinstance(vastai._LAZY_IMPORTS, dict), \ + "_LAZY_IMPORTS should be a dict" + + def test_lazy_imports_contains_serverless_classes(self): + """_LAZY_IMPORTS should map serverless class names to modules.""" + import vastai + + expected_classes = [ + 'Serverless', + 'ServerlessRequest', + 'Endpoint', + 'Worker', + 'WorkerConfig', + 'HandlerConfig', + ] + + for cls_name in expected_classes: + assert cls_name in vastai._LAZY_IMPORTS, \ + f"'{cls_name}' should be in _LAZY_IMPORTS" + + def test_getattr_function_exists(self): + """vastai/__init__.py should have __getattr__ for PEP 562.""" + import vastai + + # PEP 562: Module-level __getattr__ is exposed in the module namespace + assert hasattr(vastai, '__getattr__'), \ + "vastai/__init__.py should have __getattr__ function" + assert callable(vastai.__getattr__), \ + "__getattr__ should be callable" + + def test_getattr_raises_for_unknown_attribute(self): + """__getattr__ should raise AttributeError for unknown names.""" + import vastai + + try: + _ = vastai.NonExistentClass + assert False, "Should have raised AttributeError" + except AttributeError as e: + assert "NonExistentClass" in str(e) + + +class TestBasicVastAIImport: + """Basic vastai import should work without serverless deps.""" + + def test_vastai_import_succeeds(self): + """Basic import vastai should succeed.""" + import vastai + + assert vastai is not None + + def test_vastai_class_available(self): + """VastAI class should be directly importable.""" + from vastai import VastAI + + assert VastAI is not None + + def test_vastai_instantiation_works(self): + """VastAI can be instantiated without serverless deps.""" + from vastai import VastAI + + sdk = VastAI(api_key="test_key") + + assert sdk is not None + assert sdk.api_key == "test_key" + + def test_all_exports_defined(self): + """__all__ should list all public exports.""" + import vastai + + assert hasattr(vastai, '__all__'), \ + "vastai should have __all__ defined" + assert 'VastAI' in vastai.__all__, \ + "VastAI should be in __all__" + assert 'Serverless' in vastai.__all__, \ + "Serverless should be in __all__ (even if lazy)" + + +class TestServerlessClassAccess: + """Test that serverless classes can be accessed when deps are available.""" + + def test_serverless_in_all(self): + """Serverless classes should be listed in __all__.""" + import vastai + + serverless_classes = [ + 'Serverless', + 'ServerlessRequest', + 'Endpoint', + 'Worker', + 'WorkerConfig', + 'HandlerConfig', + 'LogActionConfig', + 'BenchmarkConfig', + ] + + for cls_name in serverless_classes: + assert cls_name in vastai.__all__, \ + f"'{cls_name}' should be in __all__" + + def test_serverless_import_path_correct(self): + """Lazy import paths should point to correct modules.""" + import vastai + + # Serverless and ServerlessRequest should be in client.client + assert '.serverless.client.client' in vastai._LAZY_IMPORTS['Serverless'], \ + "Serverless should be in serverless.client.client" + + # Worker classes should be in server.worker + assert '.serverless.server.worker' in vastai._LAZY_IMPORTS['Worker'], \ + "Worker should be in serverless.server.worker" + + +class TestAiohttpFreeEnvironment: + """Test behavior in environment without aiohttp.""" + + def test_aiohttp_not_imported_on_basic_import(self): + """Importing vastai should not import aiohttp.""" + # Clear any cached imports + import sys + + # Note: This test documents expected behavior + # The actual lazy import only delays the import until class access + # So if aiohttp is installed, this test just verifies the mechanism + + # Check that vast and vastai can be imported + if 'vastai' in sys.modules: + del sys.modules['vastai'] + if 'vastai.sdk' in sys.modules: + del sys.modules['vastai.sdk'] + + # Record aiohttp import state before + aiohttp_before = 'aiohttp' in sys.modules + + # Import vastai + import vastai + from vastai import VastAI + + # Create SDK instance + sdk = VastAI(api_key="test") + + # Note: We can't truly test aiohttp-free behavior if aiohttp is installed + # This test documents that the lazy import mechanism is in place + assert True, "Basic vastai import and VastAI instantiation succeeded" + + def test_lazy_import_delays_module_load(self): + """Serverless modules should not be in sys.modules after basic import.""" + import sys + + # Clear caches + modules_to_clear = [ + 'vastai.serverless.client.client', + 'vastai.serverless.server.worker', + ] + for mod in modules_to_clear: + if mod in sys.modules: + del sys.modules[mod] + + # Import only VastAI + from vastai import VastAI + + # Check serverless modules are not yet loaded + # (This may vary depending on test order, so we just document behavior) + sdk = VastAI(api_key="test") + + # The SDK should work without serverless modules being loaded + assert sdk is not None + + +class TestServerlessIntegration: + """Test serverless framework integration.""" + + def test_serverless_class_loadable(self): + """Serverless class should be loadable (when deps available).""" + try: + from vastai import Serverless + + # If aiohttp is available, this should work + assert Serverless is not None + except ImportError as e: + # If aiohttp not available, should get ImportError from the module + assert 'aiohttp' in str(e).lower() or 'No module' in str(e), \ + f"Unexpected import error: {e}" + + def test_worker_class_loadable(self): + """Worker class should be loadable (when deps available).""" + try: + from vastai import Worker + + assert Worker is not None + except ImportError as e: + # Expected if dependencies not available + assert 'aiohttp' in str(e).lower() or 'No module' in str(e), \ + f"Unexpected import error: {e}" + + def test_endpoint_class_loadable(self): + """Endpoint class should be loadable.""" + try: + from vastai import Endpoint + + assert Endpoint is not None + except ImportError as e: + # May fail if deps not available + pass # Expected behavior diff --git a/tests/regression/test_shell_injection.py b/tests/regression/test_shell_injection.py new file mode 100644 index 00000000..40556aa0 --- /dev/null +++ b/tests/regression/test_shell_injection.py @@ -0,0 +1,116 @@ +"""No shell=True subprocess calls in vast.py. + +The bug: subprocess calls with shell=True are vulnerable to command injection +(CWE-78). User-controlled data (paths, instance IDs, addresses) could be +injected into shell commands. Additionally, subprocess.getoutput() implicitly +uses shell=True. + +The fix: Convert all shell=True calls to argument lists. Replace +subprocess.getoutput("echo $HOME") with os.path.expanduser("~"). Convert +get_update_command() to return a list instead of a string. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import re +import pytest + + +VAST_PY_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'vast.py') + + +class TestNoShellTrue: + """Lint-style tests verifying no shell=True remains in vast.py.""" + + def test_no_shell_true(self): + """No shell=True subprocess calls should exist in vast.py. + + shell=True enables command injection when user-controlled data + is passed to subprocess calls. + """ + with open(VAST_PY_PATH, encoding='utf-8') as f: + content = f.read() + assert 'shell=True' not in content, ( + "Found shell=True in vast.py. All subprocess calls must use " + "argument lists instead of shell command strings." + ) + + def test_no_subprocess_getoutput(self): + """No subprocess.getoutput() calls should exist in vast.py. + + subprocess.getoutput() implicitly uses shell=True and is + vulnerable to command injection. + """ + with open(VAST_PY_PATH, encoding='utf-8') as f: + content = f.read() + assert 'subprocess.getoutput' not in content, ( + "Found subprocess.getoutput in vast.py. Use os.path.expanduser " + "or subprocess.run with argument lists instead." + ) + + +class TestGetUpdateCommand: + """Verify get_update_command() returns a list, not a string.""" + + def test_returns_list_for_pip(self): + """get_update_command() should return a list when is_pip_package.""" + import vast + from unittest.mock import patch + + with patch.object(vast, 'is_pip_package', return_value=True): + result = vast.get_update_command("1.2.3") + assert isinstance(result, list), ( + f"get_update_command() returned {type(result).__name__}, expected list" + ) + assert all(isinstance(item, str) for item in result), ( + "All elements of the command list must be strings" + ) + assert "vastai==1.2.3" in result, ( + "Command list should include version-pinned package name" + ) + + def test_returns_list_for_git(self): + """get_update_command() should return a list when not pip.""" + import vast + from unittest.mock import patch + + with patch.object(vast, 'is_pip_package', return_value=False): + result = vast.get_update_command("1.2.3") + assert isinstance(result, list), ( + f"get_update_command() returned {type(result).__name__}, expected list" + ) + assert all(isinstance(item, str) for item in result), ( + "All elements of the command list must be strings" + ) + assert "git" in result, "Git command list should contain 'git'" + + def test_pip_command_has_no_shell_operators(self): + """The pip command list should not contain shell operators.""" + import vast + from unittest.mock import patch + + # Shell operators that indicate command chaining/injection + shell_operators = ['&&', '||', '|', ';', '>', '<', '$(', '`'] + with patch.object(vast, 'is_pip_package', return_value=True): + result = vast.get_update_command("1.2.3") + combined = " ".join(result) + for op in shell_operators: + assert op not in combined, ( + f"Shell operator {op!r} found in pip command: {combined!r}" + ) + + def test_git_command_has_no_shell_operators(self): + """The git command list should not contain && or | operators.""" + import vast + from unittest.mock import patch + + with patch.object(vast, 'is_pip_package', return_value=False): + result = vast.get_update_command("1.2.3") + combined = " ".join(result) + assert "&&" not in combined, ( + "Git command should not contain '&&' -- use separate subprocess calls" + ) + assert "|" not in combined, ( + "Git command should not contain pipe operators" + ) diff --git a/tests/regression/test_show_instances.py b/tests/regression/test_show_instances.py new file mode 100644 index 00000000..fad8df97 --- /dev/null +++ b/tests/regression/test_show_instances.py @@ -0,0 +1,90 @@ +"""show__instances() loop rebinds local variable without updating list. + +The bug: `for row in rows: row = {...}` rebinds the local `row` variable to a +new dict, but the original dict in `rows` is unchanged. The stripped strings +and computed duration are lost. + +The fix: Build a new list and reassign rows. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import argparse +import time +from unittest.mock import MagicMock, patch + + +def test_rows_are_modified_after_loop(): + """After show__instances processes rows, the returned rows have modified data.""" + from vast import show__instances + + mock_instances = [ + { + "id": 12345, + "start_date": time.time() - 3600, # started 1 hour ago + "extra_env": [["KEY1", "val1"], ["KEY2", "val2"]], + "status": "running", + "name": " test ", # has leading/trailing spaces + } + ] + + mock_response = MagicMock() + mock_response.json.return_value = {"instances": mock_instances} + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + + args = argparse.Namespace( + api_key="test", url="https://console.vast.ai", + retry=3, raw=True, explain=False, quiet=False, + curl=False, full=False, + ) + + with patch('vast.apiurl', return_value="https://console.vast.ai/api/v0/instances"), \ + patch('vast.http_get', return_value=mock_response): + result = show__instances(args, extra={}) + + # In raw mode, show__instances returns rows + assert result is not None, "show__instances returned None in raw mode" + assert len(result) > 0, "show__instances returned empty list" + row = result[0] + # The row should have 'duration' field computed from the loop + assert 'duration' in row, "Row missing 'duration' -- loop rebinding bug still present" + assert row['duration'] > 0, f"Duration should be positive, got {row['duration']}" + # extra_env should be converted from list-of-pairs to dict + assert isinstance(row['extra_env'], dict), "extra_env not converted to dict" + assert row['extra_env'].get('KEY1') == 'val1' + + +def test_rows_stripped_strings_preserved(): + """Verify strip_strings is applied and preserved in the returned rows.""" + from vast import show__instances + + mock_instances = [ + { + "id": 99, + "start_date": time.time() - 100, + "extra_env": [], + "status": " running ", + } + ] + + mock_response = MagicMock() + mock_response.json.return_value = {"instances": mock_instances} + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + + args = argparse.Namespace( + api_key="test", url="https://console.vast.ai", + retry=3, raw=True, explain=False, quiet=False, + curl=False, full=False, + ) + + with patch('vast.apiurl', return_value="https://console.vast.ai/api/v0/instances"), \ + patch('vast.http_get', return_value=mock_response): + result = show__instances(args, extra={}) + + assert result is not None + row = result[0] + # strip_strings should have trimmed the status value + assert row['status'] == 'running', f"Expected 'running', got '{row['status']}' -- strip not applied" diff --git a/tests/regression/test_show_machine.py b/tests/regression/test_show_machine.py new file mode 100644 index 00000000..f75d0b56 --- /dev/null +++ b/tests/regression/test_show_machine.py @@ -0,0 +1,58 @@ +"""show__machine() doesn't handle single dict response. + +The bug: api_call returns a dict for single-machine queries. The code +iterates `for row in rows` which iterates over dict KEYS (strings like 'id', +'gpu_name') instead of dicts. display_table also expects a list of dicts. + +The fix: Wrap single dict responses in a list. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import argparse +from unittest.mock import patch, MagicMock + + +def test_single_dict_response_handled(): + """show__machine wraps a single dict response in a list.""" + from vast import show__machine + + single_machine = { + "id": 42, + "gpu_name": "RTX 4090", + "num_gpus": 1, + } + + args = argparse.Namespace( + api_key="test", url="https://console.vast.ai", + retry=3, raw=True, explain=False, quiet=False, + curl=False, id=42, + ) + + with patch('vast.api_call', return_value=single_machine): + result = show__machine(args) + + # In raw mode, should return a list (wrapped dict) + assert isinstance(result, list), f"Expected list, got {type(result)}" + assert len(result) == 1 + assert result[0]['id'] == 42 + + +def test_list_response_unchanged(): + """show__machine leaves list responses as-is.""" + from vast import show__machine + + machine_list = [{"id": 42, "gpu_name": "RTX 4090"}] + + args = argparse.Namespace( + api_key="test", url="https://console.vast.ai", + retry=3, raw=True, explain=False, quiet=False, + curl=False, id=42, + ) + + with patch('vast.api_call', return_value=machine_list): + result = show__machine(args) + + assert isinstance(result, list) + assert len(result) == 1 diff --git a/tests/regression/test_timeout.py b/tests/regression/test_timeout.py new file mode 100644 index 00000000..330d994c --- /dev/null +++ b/tests/regression/test_timeout.py @@ -0,0 +1,129 @@ +"""Missing timeout on HTTP requests. + +The bug: session.send() has no timeout parameter. Requests can hang +indefinitely if the server never responds or the connection stalls. + +The fix: Add timeout=DEFAULT_TIMEOUT (30s) to http_request() and forward +it to session.send(). All wrapper functions (http_get, http_post, http_put, +http_del) accept and forward the timeout parameter. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import argparse +from unittest.mock import MagicMock, patch, call +import pytest + + +def _make_args(retry=1): + return argparse.Namespace(retry=retry, curl=False) + + +@patch('vast.time.sleep') +@patch('vast.requests.Session') +def test_default_timeout_forwarded(mock_session_cls, mock_sleep): + """http_request passes timeout=30 (DEFAULT_TIMEOUT) to session.send by default.""" + from vast import http_request, DEFAULT_TIMEOUT + + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + success_response = MagicMock() + success_response.status_code = 200 + mock_session.send.return_value = success_response + mock_session.prepare_request.return_value = MagicMock() + + args = _make_args(retry=1) + with patch('vast.ARGS', args): + result = http_request('GET', args, 'http://example.com/test') + + assert result.status_code == 200 + # Verify timeout was passed to session.send + mock_session.send.assert_called_once() + _, kwargs = mock_session.send.call_args + assert kwargs.get('timeout') == DEFAULT_TIMEOUT + assert kwargs.get('timeout') == 30 + + +@patch('vast.time.sleep') +@patch('vast.requests.Session') +def test_custom_timeout_forwarded(mock_session_cls, mock_sleep): + """http_request forwards a custom timeout value to session.send.""" + from vast import http_request + + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + success_response = MagicMock() + success_response.status_code = 200 + mock_session.send.return_value = success_response + mock_session.prepare_request.return_value = MagicMock() + + args = _make_args(retry=1) + with patch('vast.ARGS', args): + result = http_request('GET', args, 'http://example.com/test', timeout=120) + + assert result.status_code == 200 + _, kwargs = mock_session.send.call_args + assert kwargs.get('timeout') == 120 + + +@patch('vast.http_request') +def test_http_get_forwards_timeout(mock_http_request): + """http_get forwards timeout parameter to http_request.""" + from vast import http_get + + mock_http_request.return_value = MagicMock(status_code=200) + args = _make_args() + + http_get(args, 'http://example.com/test', timeout=60) + + mock_http_request.assert_called_once() + _, kwargs = mock_http_request.call_args + assert kwargs.get('timeout') == 60 + + +@patch('vast.http_request') +def test_http_post_forwards_timeout(mock_http_request): + """http_post forwards timeout parameter to http_request.""" + from vast import http_post + + mock_http_request.return_value = MagicMock(status_code=200) + args = _make_args() + + http_post(args, 'http://example.com/test', timeout=90) + + mock_http_request.assert_called_once() + _, kwargs = mock_http_request.call_args + assert kwargs.get('timeout') == 90 + + +@patch('vast.http_request') +def test_http_put_forwards_timeout(mock_http_request): + """http_put forwards timeout parameter to http_request.""" + from vast import http_put + + mock_http_request.return_value = MagicMock(status_code=200) + args = _make_args() + + http_put(args, 'http://example.com/test', timeout=45) + + mock_http_request.assert_called_once() + _, kwargs = mock_http_request.call_args + assert kwargs.get('timeout') == 45 + + +@patch('vast.http_request') +def test_http_del_forwards_timeout(mock_http_request): + """http_del forwards timeout parameter to http_request.""" + from vast import http_del + + mock_http_request.return_value = MagicMock(status_code=200) + args = _make_args() + + http_del(args, 'http://example.com/test', timeout=15) + + mock_http_request.assert_called_once() + _, kwargs = mock_http_request.call_args + assert kwargs.get('timeout') == 15 diff --git a/tests/regression/test_timezone_handling.py b/tests/regression/test_timezone_handling.py new file mode 100644 index 00000000..2bddcad0 --- /dev/null +++ b/tests/regression/test_timezone_handling.py @@ -0,0 +1,41 @@ +"""Timezone handling uses time.mktime() which interprets as local time. + +The bug: time.mktime() converts a time tuple to epoch using the LOCAL timezone. +For a user in PST (UTC-8), a date "01/15/2025" would produce an epoch value +that's 8 hours later than UTC midnight, giving wrong results. + +The fix: Use calendar.timegm() which always interprets the time tuple as UTC. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import calendar + + +def test_string_to_unix_epoch_utc(): + """string_to_unix_epoch returns UTC timestamps regardless of local timezone.""" + from vast import string_to_unix_epoch + + # 01/15/2025 00:00:00 UTC = 1736899200 + result = string_to_unix_epoch("01/15/2025") + expected = calendar.timegm((2025, 1, 15, 0, 0, 0, 0, 0, 0)) + assert expected == 1736899200, f"Sanity check: expected 1736899200, got {expected}" + assert result == expected, ( + f"string_to_unix_epoch('01/15/2025') returned {result}, expected {expected}. " + f"This likely means time.mktime() is still being used instead of calendar.timegm()." + ) + + +def test_string_to_unix_epoch_returns_float_passthrough(): + """string_to_unix_epoch returns float values as-is.""" + from vast import string_to_unix_epoch + + assert string_to_unix_epoch("1736899200") == 1736899200.0 + + +def test_string_to_unix_epoch_none(): + """string_to_unix_epoch returns None for None input.""" + from vast import string_to_unix_epoch + + assert string_to_unix_epoch(None) is None diff --git a/tests/regression/test_unreachable_code.py b/tests/regression/test_unreachable_code.py new file mode 100644 index 00000000..f70e3b93 --- /dev/null +++ b/tests/regression/test_unreachable_code.py @@ -0,0 +1,227 @@ +"""No unreachable code after raise_for_status(). + +The bug: 19 functions had `if (r.status_code == 200):` guards immediately +after `r.raise_for_status()`. Since raise_for_status() raises HTTPError for +non-2xx responses, the status_code check was always True and the `else` branch +(printing "failed with error {r.status_code}") was unreachable dead code. + +The fix: Remove the redundant status_code == 200 wrapper, de-indent the +success path, and remove the unreachable else branches. API-level success +checks (rj.get("success")) are preserved since those check the JSON body, +not the HTTP status. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import re +import argparse +import pytest +from unittest import mock +from unittest.mock import MagicMock, patch +from io import StringIO + +VAST_PY_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'vast.py') + + +# --------------------------------------------------------------------------- # +# Lint-style tests # +# --------------------------------------------------------------------------- # + +class TestNoUnreachableStatusCheck: + """Ensure no status_code == 200 checks follow raise_for_status().""" + + def test_no_status_check_after_raise_for_status(self): + """Pattern: raise_for_status() followed within 3 lines by status_code == 200.""" + with open(VAST_PY_PATH, encoding='utf-8') as f: + content = f.read() + # Match raise_for_status() followed (within a few lines) by status_code == 200 + pattern = r'raise_for_status\(\)\s*\n(?:\s*\n)*\s*if\s*\(?r\.status_code\s*==\s*200' + matches = re.findall(pattern, content) + assert len(matches) == 0, ( + f"Found {len(matches)} unreachable status_code == 200 checks after " + f"raise_for_status(). These are unreachable and should be removed." + ) + + def test_no_status_check_after_raise_for_status_response_var(self): + """Same pattern but with 'response' variable name.""" + with open(VAST_PY_PATH, encoding='utf-8') as f: + content = f.read() + pattern = r'raise_for_status\(\)\s*\n(?:\s*\n)*\s*if\s*\(?response\.status_code\s*==\s*200' + matches = re.findall(pattern, content) + assert len(matches) == 0, ( + f"Found {len(matches)} unreachable status_code == 200 checks (response var) " + f"after raise_for_status()." + ) + + def test_no_unreachable_failed_with_error_after_unconditional_raise(self): + """The 'failed with error' message should not appear after an + *unconditional* raise_for_status() call (same indentation as function body). + Conditional raise_for_status() calls (inside if blocks) are excluded.""" + with open(VAST_PY_PATH, encoding='utf-8') as f: + lines = f.readlines() + + # Find unconditional raise_for_status lines (indented exactly 4 spaces, + # i.e., at function-body level, not nested inside an if/else) + rfs_indices = [] + for i, line in enumerate(lines): + stripped = line.rstrip() + if 'raise_for_status()' in stripped and not stripped.lstrip().startswith('#'): + # Check indentation: unconditional means base function indent (4 spaces) + indent = len(line) - len(line.lstrip()) + if indent == 4: + rfs_indices.append(i) + + for rfs_idx in rfs_indices: + # Check within 20 lines after raise_for_status for the dead pattern + for offset in range(1, 20): + check_idx = rfs_idx + offset + if check_idx >= len(lines): + break + line = lines[check_idx] + # If we hit a new function def or decorator, stop scanning + if re.match(r'^def\s+', line) or re.match(r'^@', line): + break + if 'failed with error {r.status_code}' in line: + assert False, ( + f"Line {check_idx + 1}: Found unreachable 'failed with error' " + f"message after unconditional raise_for_status() at line {rfs_idx + 1}" + ) + + +# --------------------------------------------------------------------------- # +# Functional tests # +# --------------------------------------------------------------------------- # + +class TestStartInstanceBehavior: + """Verify start_instance works correctly after unreachable code removal.""" + + @patch('vast.http_put') + def test_success_prints_message(self, mock_put, capsys): + """When API returns 200 with success=True, print starting message.""" + import vast + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + mock_response.json.return_value = {"success": True} + mock_put.return_value = mock_response + + args = argparse.Namespace( + api_key="test-key", url="https://console.vast.ai", + retry=3, raw=False, explain=False + ) + + result = vast.start_instance(12345, args) + + assert result is True + captured = capsys.readouterr() + assert "starting instance" in captured.out + + @patch('vast.http_put') + def test_api_failure_prints_msg(self, mock_put, capsys): + """When API returns 200 with success=False, print the error msg.""" + import vast + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + mock_response.json.return_value = {"success": False, "msg": "instance not found"} + mock_put.return_value = mock_response + + args = argparse.Namespace( + api_key="test-key", url="https://console.vast.ai", + retry=3, raw=False, explain=False + ) + + result = vast.start_instance(12345, args) + + assert result is True + captured = capsys.readouterr() + assert "instance not found" in captured.out + + @patch('vast.http_put') + def test_http_error_raises(self, mock_put): + """When API returns 500, raise_for_status raises HTTPError.""" + import vast + import requests + + mock_response = MagicMock() + mock_response.status_code = 500 + mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError( + "500 Server Error" + ) + mock_put.return_value = mock_response + + args = argparse.Namespace( + api_key="test-key", url="https://console.vast.ai", + retry=3, raw=False, explain=False + ) + + with pytest.raises(requests.exceptions.HTTPError): + vast.start_instance(12345, args) + + @patch('vast.http_put') + def test_missing_msg_uses_default(self, mock_put, capsys): + """When API returns success=False without msg, use default error.""" + import vast + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + mock_response.json.return_value = {"success": False} + mock_put.return_value = mock_response + + args = argparse.Namespace( + api_key="test-key", url="https://console.vast.ai", + retry=3, raw=False, explain=False + ) + + vast.start_instance(12345, args) + + captured = capsys.readouterr() + assert "Unknown error" in captured.out + + +class TestDestroyInstanceBehavior: + """Verify destroy_instance preserves raw mode path after refactor.""" + + @patch('vast.http_del') + def test_raw_mode_returns_json(self, mock_del): + """In raw mode, destroy_instance returns parsed JSON.""" + import vast + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + mock_response.json.return_value = {"success": True} + mock_del.return_value = mock_response + + args = argparse.Namespace( + api_key="test-key", url="https://console.vast.ai", + retry=3, raw=True, explain=False + ) + + result = vast.destroy_instance(12345, args) + assert result == {"success": True} + + @patch('vast.http_del') + def test_non_raw_prints_destroying(self, mock_del, capsys): + """In non-raw mode, destroy_instance prints 'destroying instance'.""" + import vast + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + mock_response.json.return_value = {"success": True} + mock_del.return_value = mock_response + + args = argparse.Namespace( + api_key="test-key", url="https://console.vast.ai", + retry=3, raw=False, explain=False + ) + + vast.destroy_instance(12345, args) + + captured = capsys.readouterr() + assert "destroying instance" in captured.out diff --git a/tests/regression/test_utc_display.py b/tests/regression/test_utc_display.py new file mode 100644 index 00000000..0d31ca32 --- /dev/null +++ b/tests/regression/test_utc_display.py @@ -0,0 +1,97 @@ +"""UTC-labeled timestamps must actually display UTC. + +The bug: datetime.fromtimestamp(ts) without tz= returns LOCAL time, but +the columns are labeled "UTC" or the output is treated as UTC. Users in +non-UTC timezones see wrong times. + +The fix: All fromtimestamp() calls that produce UTC-labeled output now +pass tz=timezone.utc explicitly. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import re +import pytest +from datetime import datetime, timezone + + +VAST_PY_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'vast.py') + + +class TestUnixToReadableUTC: + """Verify unix_to_readable() produces UTC output.""" + + def test_known_epoch_returns_utc_midnight(self): + """1704067200 is 2024-01-01 00:00:00 UTC. Output must show 00:00:00.""" + from vast import unix_to_readable + result = unix_to_readable(1704067200) + assert "00:00:00" in result, ( + f"unix_to_readable(1704067200) should show 00:00:00 UTC, got: {result}" + ) + + def test_known_epoch_returns_correct_date(self): + """1704067200 is 2024-01-01 UTC. Output must contain Jan-01-2024.""" + from vast import unix_to_readable + result = unix_to_readable(1704067200) + assert "Jan-01-2024" in result, ( + f"unix_to_readable(1704067200) should contain Jan-01-2024, got: {result}" + ) + + def test_midday_epoch_shows_utc_time(self): + """1704110400 is 2024-01-01 12:00:00 UTC. Output must show 12:00:00.""" + from vast import unix_to_readable + result = unix_to_readable(1704110400) + assert "12:00:00" in result, ( + f"unix_to_readable(1704110400) should show 12:00:00 UTC, got: {result}" + ) + + +class TestFromtimestampCallsUseUTC: + """Lint-style test: every fromtimestamp() call must use tz=timezone.utc.""" + + def test_all_fromtimestamp_calls_have_tz_utc(self): + """Every fromtimestamp( call in vast.py should include tz=timezone.utc. + + This prevents regressions where a new fromtimestamp() call is added + without the timezone parameter, silently producing local time. + """ + with open(VAST_PY_PATH, encoding='utf-8') as f: + lines = f.readlines() + + violations = [] + for i, line in enumerate(lines, 1): + # Skip comments + stripped = line.strip() + if stripped.startswith('#'): + continue + # Find fromtimestamp( calls + if 'fromtimestamp(' in line and 'utcfromtimestamp' not in line: + if 'tz=timezone.utc' not in line: + violations.append(f"Line {i}: {stripped}") + + assert len(violations) == 0, ( + f"Found {len(violations)} fromtimestamp() calls without tz=timezone.utc:\n" + + "\n".join(violations) + ) + + +class TestCacheAgeUsesAwareDatetimes: + """Verify cache age calculation uses timezone-aware datetimes on both sides.""" + + def test_cache_age_uses_tz_aware_now(self): + """datetime.now() in cache age must use tz=timezone.utc.""" + with open(VAST_PY_PATH, encoding='utf-8') as f: + content = f.read() + + # Find the cache_age line + match = re.search(r'cache_age\s*=\s*(.+)', content) + assert match is not None, "cache_age assignment not found in vast.py" + + cache_age_line = match.group(1) + assert 'datetime.now(tz=timezone.utc)' in cache_age_line, ( + f"cache_age should use datetime.now(tz=timezone.utc), got: {cache_age_line}" + ) + assert 'fromtimestamp(' in cache_age_line and 'tz=timezone.utc' in cache_age_line, ( + f"cache_age should use fromtimestamp with tz=timezone.utc, got: {cache_age_line}" + ) diff --git a/tests/regression/test_utcfromtimestamp.py b/tests/regression/test_utcfromtimestamp.py new file mode 100644 index 00000000..4d5dd255 --- /dev/null +++ b/tests/regression/test_utcfromtimestamp.py @@ -0,0 +1,93 @@ +"""No deprecated utcfromtimestamp() calls in vast.py. + +The bug: datetime.utcfromtimestamp() is deprecated since Python 3.12 and +will be removed in a future version. It also returns a naive datetime that +is ambiguous about its timezone. + +The fix: Replace all utcfromtimestamp() calls with +datetime.fromtimestamp(ts, tz=timezone.utc), which returns an aware +datetime and is not deprecated. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import re +import pytest +from datetime import datetime, timezone + + +VAST_PY_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'vast.py') + + +class TestNoUtcfromtimestamp: + """Lint-style tests verifying utcfromtimestamp is not used anywhere.""" + + def test_no_utcfromtimestamp_in_source(self): + """utcfromtimestamp should not appear anywhere in vast.py. + + This deprecated method returns naive datetimes and will be removed + in a future Python version. Use fromtimestamp(ts, tz=timezone.utc). + """ + with open(VAST_PY_PATH, encoding='utf-8') as f: + content = f.read() + + matches = re.findall(r'utcfromtimestamp', content) + assert len(matches) == 0, ( + f"Found {len(matches)} occurrences of utcfromtimestamp in vast.py. " + "Use datetime.fromtimestamp(ts, tz=timezone.utc) instead." + ) + + def test_no_utcnow_in_source(self): + """utcnow() is also deprecated for the same reason as utcfromtimestamp. + + Prevent regression by ensuring neither deprecated UTC method is used. + """ + with open(VAST_PY_PATH, encoding='utf-8') as f: + content = f.read() + + # Match utcnow() but not in comments + lines = content.split('\n') + violations = [] + for i, line in enumerate(lines, 1): + stripped = line.strip() + if stripped.startswith('#'): + continue + if 'utcnow()' in line: + violations.append(f"Line {i}: {stripped}") + + assert len(violations) == 0, ( + f"Found {len(violations)} occurrences of utcnow() in vast.py. " + "Use datetime.now(tz=timezone.utc) instead.\n" + + "\n".join(violations) + ) + + +class TestReplacementProducesAwareDatetime: + """Verify the replacement fromtimestamp(ts, tz=timezone.utc) works correctly.""" + + def test_fromtimestamp_with_tz_returns_aware_datetime(self): + """datetime.fromtimestamp(ts, tz=timezone.utc) must return a tz-aware datetime.""" + ts = 1704067200 # 2024-01-01 00:00:00 UTC + dt = datetime.fromtimestamp(ts, tz=timezone.utc) + assert dt.tzinfo is not None, "Result should be timezone-aware" + assert dt.tzinfo == timezone.utc, "Result timezone should be UTC" + + def test_fromtimestamp_with_tz_correct_values(self): + """Verify the replacement produces correct year/month/day/hour.""" + ts = 1704067200 # 2024-01-01 00:00:00 UTC + dt = datetime.fromtimestamp(ts, tz=timezone.utc) + assert dt.year == 2024 + assert dt.month == 1 + assert dt.day == 1 + assert dt.hour == 0 + assert dt.minute == 0 + assert dt.second == 0 + + def test_fromtimestamp_with_tz_matches_expected_format(self): + """The replacement in schedule_maintenance should format correctly.""" + ts = 1704067200 # 2024-01-01 00:00:00 UTC + dt = datetime.fromtimestamp(ts, tz=timezone.utc) + # The schedule_maintenance function uses str(dt) implicitly in f-string + dt_str = str(dt) + assert "2024-01-01" in dt_str, f"Expected 2024-01-01 in {dt_str}" diff --git a/tests/smoke/__init__.py b/tests/smoke/__init__.py new file mode 100644 index 00000000..157ea08b --- /dev/null +++ b/tests/smoke/__init__.py @@ -0,0 +1 @@ +"""Smoke tests for CLI commands and standalone execution.""" diff --git a/tests/smoke/test_cli_commands.py b/tests/smoke/test_cli_commands.py new file mode 100644 index 00000000..dab965bf --- /dev/null +++ b/tests/smoke/test_cli_commands.py @@ -0,0 +1,429 @@ +"""Smoke tests for CLI commands. + +TEST-08: CLI smoke tests - major commands parse args and call correct endpoints. + +These tests verify that CLI commands: +1. Parse their arguments correctly +2. Call the expected API endpoints +3. Handle mocked responses appropriately + +All HTTP is mocked - no network calls are made. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import argparse +import json +import pytest +from unittest.mock import MagicMock, patch + + +def _make_base_args(**overrides): + """Create base args namespace with common fields.""" + args = argparse.Namespace( + api_key="test-key", + url="https://console.vast.ai", + retry=3, + raw=False, + explain=False, + quiet=False, + curl=False, + full=False, + no_color=True, + debugging=False, + ) + for key, value in overrides.items(): + setattr(args, key, value) + return args + + +@pytest.fixture(autouse=True) +def setup_vast_args(): + """Set up vast.ARGS to prevent NoneType errors in http_request.""" + import vast + old_args = vast.ARGS + vast.ARGS = _make_base_args() + yield + vast.ARGS = old_args + + +class TestSearchOffers: + """Smoke tests for 'search offers' command.""" + + @patch('vast.http_post') + def test_search_offers_calls_bundles_endpoint(self, mock_http_post): + """search offers calls /api/v0/bundles/ endpoint via POST.""" + import vast + + mock_response = MagicMock() + mock_response.json.return_value = {"offers": []} + mock_response.status_code = 200 + mock_response.headers = {"Content-Type": "application/json"} + mock_response.raise_for_status = MagicMock() + mock_http_post.return_value = mock_response + + args = _make_base_args( + query=["gpu_ram>=8"], + type="bid", + raw=True, + no_default=False, + new=False, + limit=None, + disable_bundling=False, + storage=5.0, + order="score-", + ) + vast.search__offers(args) + + mock_http_post.assert_called_once() + call_url = mock_http_post.call_args[0][1] + assert "/api/v0/bundles" in call_url + + @patch('vast.http_post') + def test_search_offers_with_gpu_name(self, mock_http_post): + """search offers parses gpu_name filter.""" + import vast + + mock_response = MagicMock() + mock_response.json.return_value = {"offers": []} + mock_response.status_code = 200 + mock_response.headers = {"Content-Type": "application/json"} + mock_response.raise_for_status = MagicMock() + mock_http_post.return_value = mock_response + + args = _make_base_args( + query=["gpu_name=RTX_4090"], + type="on-demand", + raw=True, + no_default=False, + new=False, + limit=None, + disable_bundling=False, + storage=5.0, + order="score-", + ) + vast.search__offers(args) + + mock_http_post.assert_called_once() + + +class TestShowInstances: + """Smoke tests for 'show instances' command.""" + + @patch('vast.http_get') + def test_show_instances_calls_instances_endpoint(self, mock_http_get): + """show instances calls /api/v0/instances/ endpoint.""" + import vast + + mock_response = MagicMock() + mock_response.json.return_value = {"instances": []} + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + mock_http_get.return_value = mock_response + + args = _make_base_args(raw=True) + vast.show__instances(args) + + mock_http_get.assert_called_once() + call_url = mock_http_get.call_args[0][1] + assert "/api/v0/instances" in call_url + + @patch('vast.http_get') + def test_show_instances_returns_list(self, mock_http_get): + """show instances handles list response correctly.""" + import vast + + mock_response = MagicMock() + mock_response.json.return_value = { + "instances": [ + {"id": 123, "status": "running", "start_date": 1700000000, "extra_env": []}, + {"id": 456, "status": "stopped", "start_date": 1700000000, "extra_env": []} + ] + } + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + mock_http_get.return_value = mock_response + + args = _make_base_args(raw=True) + result = vast.show__instances(args) + + # In raw mode, should return the data + assert result is not None or mock_http_get.called + + +class TestShowMachines: + """Smoke tests for 'show machines' command.""" + + @patch('vast.api_call') + def test_show_machines_calls_machines_endpoint(self, mock_api_call): + """show machines calls /api/v0/machines/ endpoint.""" + import vast + + mock_api_call.return_value = {"machines": []} + + args = _make_base_args(raw=True, quiet=True) + vast.show__machines(args) + + mock_api_call.assert_called_once() + call_args = mock_api_call.call_args + assert call_args[0][1] == "GET" + assert "/machines" in call_args[0][2] + + +class TestShowUser: + """Smoke tests for 'show user' command.""" + + @patch('vast.api_call') + def test_show_user_calls_users_endpoint(self, mock_api_call): + """show user calls /api/v0/users/current endpoint.""" + import vast + + mock_api_call.return_value = {"id": 12345, "username": "testuser", "api_key": "secret"} + + args = _make_base_args(raw=True, quiet=True) + vast.show__user(args) + + mock_api_call.assert_called_once() + call_args = mock_api_call.call_args + assert call_args[0][1] == "GET" + assert "/users/current" in call_args[0][2] + + +class TestCreateInstance: + """Smoke tests for 'create instance' command.""" + + @patch('vast.http_put') + def test_create_instance_calls_asks_endpoint(self, mock_http_put): + """create instance calls /api/v0/asks/{id}/ endpoint.""" + import vast + + mock_response = MagicMock() + mock_response.json.return_value = {"success": True, "new_contract": 789} + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + mock_http_put.return_value = mock_response + + args = _make_base_args( + id=12345, + bid_price=None, + disk=20.0, + image="pytorch/pytorch:latest", + raw=True, + onstart=None, + onstart_cmd=None, + entrypoint=None, + env=None, + args=None, + label=None, + extra=None, + jupyter=False, + jupyter_dir=None, + jupyter_lab=False, + lang_utf8=False, + python_utf8=False, + ssh=False, + direct=False, + cancel_unavail=False, + force=False, + login=None, + template_hash=None, + user=None, + create_volume=None, + link_volume=None, + ) + vast.create__instance(args) + + mock_http_put.assert_called_once() + call_url = mock_http_put.call_args[0][1] + assert "/api/v0/asks" in call_url + + +class TestDestroyInstance: + """Smoke tests for 'destroy instance' command.""" + + @patch('vast.http_del') + def test_destroy_instance_calls_instances_endpoint(self, mock_http_del): + """destroy instance calls DELETE /api/v0/instances/{id}/ endpoint.""" + import vast + + mock_response = MagicMock() + mock_response.json.return_value = {"success": True} + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + mock_http_del.return_value = mock_response + + args = _make_base_args(id=12345, raw=True) + vast.destroy__instance(args) + + mock_http_del.assert_called_once() + call_url = mock_http_del.call_args[0][1] + assert "/api/v0/instances" in call_url + assert "12345" in call_url + + +class TestLogsCommand: + """Smoke tests for 'logs' command.""" + + @patch('vast.http_get') + @patch('vast.http_put') + def test_logs_calls_instances_endpoint(self, mock_http_put, mock_http_get): + """logs command calls appropriate endpoint.""" + import vast + + # Mock the logs request + mock_put_response = MagicMock() + mock_put_response.json.return_value = {"result_url": "https://example.com/logs"} + mock_put_response.status_code = 200 + mock_put_response.raise_for_status = MagicMock() + mock_http_put.return_value = mock_put_response + + # Mock the log fetch + mock_get_response = MagicMock() + mock_get_response.status_code = 200 + mock_get_response.text = "Log output here" + mock_http_get.return_value = mock_get_response + + args = _make_base_args( + INSTANCE_ID=123, + raw=True, + tail=None, + filter=None, + daemon_logs=False, + ) + + # logs may print and not return in non-raw mode + vast.logs(args) + + # Should have made HTTP calls + assert mock_http_put.called + + +class TestSetApiKey: + """Smoke tests for 'set api-key' command.""" + + @patch('builtins.open', create=True) + @patch('os.path.exists') + def test_set_api_key_writes_file(self, mock_exists, mock_open): + """set api-key writes key to config file.""" + import vast + + mock_exists.return_value = False + mock_file = MagicMock() + mock_open.return_value.__enter__ = MagicMock(return_value=mock_file) + mock_open.return_value.__exit__ = MagicMock(return_value=False) + + args = _make_base_args(new_api_key="sk-test-key-12345") + vast.set__api_key(args) + + mock_open.assert_called() + + +class TestShowApiKeys: + """Smoke tests for 'show api-keys' command.""" + + @patch('vast.http_get') + def test_show_api_keys_calls_auth_endpoint(self, mock_http_get): + """show api-keys calls /api/v0/auth/apikeys/ endpoint.""" + import vast + + mock_response = MagicMock() + mock_response.json.return_value = [{"id": 1, "name": "test-key"}] + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + mock_http_get.return_value = mock_response + + args = _make_base_args(raw=True) + result = vast.show__api_keys(args) + + mock_http_get.assert_called_once() + call_url = mock_http_get.call_args[0][1] + assert "/api/v0/auth/apikeys" in call_url + + +class TestParserStructure: + """Tests verifying the argparse parser structure.""" + + def test_parser_has_subcommands(self): + """Parser has expected subcommand structure.""" + import vast + + # Parser should exist + assert hasattr(vast, 'parser') + + # Should be able to parse help + with pytest.raises(SystemExit) as exc_info: + vast.parser.parse_args(['--help']) + assert exc_info.value.code == 0 + + def test_search_offers_subcommand_exists(self): + """search offers subcommand is registered.""" + import vast + + # Should not raise - parse_args with --help raises SystemExit(0) + with pytest.raises(SystemExit) as exc_info: + vast.parser.parse_args(['search', 'offers', '--help']) + assert exc_info.value.code == 0 + + def test_show_instances_subcommand_exists(self): + """show instances subcommand is registered.""" + import vast + + # Parse minimal args (should work) + args = vast.parser.parse_args(['show', 'instances']) + assert args is not None + + def test_create_instance_subcommand_exists(self): + """create instance subcommand is registered.""" + import vast + + # Should parse with required args + with pytest.raises(SystemExit) as exc_info: + vast.parser.parse_args(['create', 'instance', '--help']) + assert exc_info.value.code == 0 + + def test_destroy_instance_subcommand_exists(self): + """destroy instance subcommand is registered.""" + import vast + + # Should parse with required args + with pytest.raises(SystemExit) as exc_info: + vast.parser.parse_args(['destroy', 'instance', '--help']) + assert exc_info.value.code == 0 + + def test_show_user_subcommand_exists(self): + """show user subcommand is registered.""" + import vast + + args = vast.parser.parse_args(['show', 'user']) + assert args is not None + + def test_show_machines_subcommand_exists(self): + """show machines subcommand is registered.""" + import vast + + args = vast.parser.parse_args(['show', 'machines']) + assert args is not None + + def test_set_api_key_subcommand_exists(self): + """set api-key subcommand is registered.""" + import vast + + with pytest.raises(SystemExit) as exc_info: + vast.parser.parse_args(['set', 'api-key', '--help']) + assert exc_info.value.code == 0 + + def test_logs_subcommand_exists(self): + """logs subcommand is registered.""" + import vast + + with pytest.raises(SystemExit) as exc_info: + vast.parser.parse_args(['logs', '--help']) + assert exc_info.value.code == 0 + + def test_show_api_keys_subcommand_exists(self): + """show api-keys subcommand is registered.""" + import vast + + args = vast.parser.parse_args(['show', 'api-keys']) + assert args is not None diff --git a/tests/smoke/test_standalone.py b/tests/smoke/test_standalone.py new file mode 100644 index 00000000..918689cd --- /dev/null +++ b/tests/smoke/test_standalone.py @@ -0,0 +1,275 @@ +"""Smoke tests for standalone vast.py execution. + +TEST-09: Standalone vast.py smoke test - python vast.py --help works without pip dependencies. + +This test verifies that vast.py can be executed as a standalone script +with only the minimal dependencies (requests, python-dateutil) available. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import subprocess +import pytest + + +# Get the path to vast.py relative to this test file +VAST_CLI_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) +VAST_PY_PATH = os.path.join(VAST_CLI_DIR, 'vast.py') + + +class TestStandaloneHelp: + """Tests for standalone vast.py --help execution.""" + + def test_vast_help_exits_zero(self): + """vast.py --help exits with code 0.""" + result = subprocess.run( + [sys.executable, VAST_PY_PATH, '--help'], + capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + + assert result.returncode == 0, f"Exit code was {result.returncode}, stderr: {result.stderr}" + + def test_vast_help_contains_usage(self): + """vast.py --help output contains usage information.""" + result = subprocess.run( + [sys.executable, VAST_PY_PATH, '--help'], + capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + + assert 'usage' in result.stdout.lower() or 'vast' in result.stdout.lower(), \ + f"Help output missing expected content: {result.stdout[:500]}" + + def test_vast_help_contains_commands(self): + """vast.py --help output lists available commands.""" + result = subprocess.run( + [sys.executable, VAST_PY_PATH, '--help'], + capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + + # Should mention some key commands + output = result.stdout.lower() + assert 'search' in output or 'show' in output or 'create' in output, \ + f"Help output missing command listings: {result.stdout[:500]}" + + +class TestSubcommandHelp: + """Tests for subcommand help output.""" + + def test_search_offers_help(self): + """vast.py search offers --help works.""" + result = subprocess.run( + [sys.executable, VAST_PY_PATH, 'search', 'offers', '--help'], + capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + + assert result.returncode == 0, f"Exit code was {result.returncode}, stderr: {result.stderr}" + assert 'search' in result.stdout.lower() or 'offers' in result.stdout.lower() + + def test_show_instances_help(self): + """vast.py show instances --help works.""" + result = subprocess.run( + [sys.executable, VAST_PY_PATH, 'show', 'instances', '--help'], + capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + + assert result.returncode == 0, f"Exit code was {result.returncode}, stderr: {result.stderr}" + + def test_create_instance_help(self): + """vast.py create instance --help works.""" + result = subprocess.run( + [sys.executable, VAST_PY_PATH, 'create', 'instance', '--help'], + capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + + assert result.returncode == 0, f"Exit code was {result.returncode}, stderr: {result.stderr}" + + def test_destroy_instance_help(self): + """vast.py destroy instance --help works.""" + result = subprocess.run( + [sys.executable, VAST_PY_PATH, 'destroy', 'instance', '--help'], + capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + + assert result.returncode == 0, f"Exit code was {result.returncode}, stderr: {result.stderr}" + + +class TestVersionFlag: + """Tests for version flag if implemented.""" + + def test_vast_version_or_help(self): + """vast.py responds to --version or --help without error.""" + # Try --version first + result = subprocess.run( + [sys.executable, VAST_PY_PATH, '--version'], + capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + + # If --version isn't implemented, that's okay - just verify it doesn't crash + # (might return non-zero for unrecognized flag, but shouldn't hang or throw) + assert result.returncode in [0, 1, 2], \ + f"Unexpected exit code {result.returncode}, stderr: {result.stderr}" + + +class TestInvalidCommand: + """Tests for invalid command handling.""" + + def test_invalid_command_exits_nonzero(self): + """vast.py with invalid command exits with non-zero code.""" + result = subprocess.run( + [sys.executable, VAST_PY_PATH, 'not_a_real_command'], + capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + + # Should exit non-zero for invalid command + assert result.returncode != 0, "Invalid command should exit non-zero" + + def test_invalid_command_prints_error(self): + """vast.py with invalid command prints error message.""" + result = subprocess.run( + [sys.executable, VAST_PY_PATH, 'not_a_real_command'], + capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + + # Should have some error output + combined_output = result.stdout + result.stderr + assert len(combined_output) > 0, "Should produce some output for invalid command" + + +class TestImportOnly: + """Tests that verify vast.py can be imported without side effects.""" + + def test_vast_importable(self): + """vast.py is importable as a module.""" + # This runs in the test process, verifying import works + import vast + + assert hasattr(vast, 'parser') + assert hasattr(vast, 'main') + + def test_vast_main_exists(self): + """vast.main() function exists.""" + import vast + + assert callable(vast.main) + + def test_vast_has_core_functions(self): + """vast module has core CLI functions.""" + import vast + + # Check for key command functions + assert hasattr(vast, 'search__offers') + assert hasattr(vast, 'show__instances') + assert hasattr(vast, 'create__instance') + assert hasattr(vast, 'destroy__instance') + + def test_vast_has_http_helpers(self): + """vast module has HTTP helper functions.""" + import vast + + assert hasattr(vast, 'http_get') + assert hasattr(vast, 'http_post') + assert hasattr(vast, 'http_put') + assert hasattr(vast, 'http_del') + + +class TestMinimalDependencies: + """Tests verifying vast.py works with minimal dependencies.""" + + def test_import_only_requires_requests(self): + """vast.py import only requires requests (and python-dateutil).""" + # Run a subprocess that only has requests available + # This is tested implicitly by the fact that we can import vast + # without extra dependencies installed + result = subprocess.run( + [sys.executable, '-c', 'import vast; print("OK")'], + capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + + assert result.returncode == 0, f"Import failed: {result.stderr}" + assert "OK" in result.stdout + + def test_help_runs_without_optional_deps(self): + """vast.py --help works even if optional dependencies are missing.""" + # argcomplete and curlify are optional + # This test verifies help still works + result = subprocess.run( + [sys.executable, VAST_PY_PATH, '--help'], + capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + + assert result.returncode == 0 + + +class TestExitCodes: + """Tests for proper exit codes.""" + + def test_help_exits_zero(self): + """--help exits with code 0.""" + result = subprocess.run( + [sys.executable, VAST_PY_PATH, '--help'], + capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + assert result.returncode == 0 + + def test_subcommand_help_exits_zero(self): + """Subcommand --help exits with code 0.""" + result = subprocess.run( + [sys.executable, VAST_PY_PATH, 'show', 'instances', '--help'], + capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + assert result.returncode == 0 + + def test_missing_required_arg_exits_nonzero(self): + """Missing required argument exits with non-zero code.""" + # create instance requires an ID + result = subprocess.run( + [sys.executable, VAST_PY_PATH, 'create', 'instance'], + capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + assert result.returncode != 0 diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 00000000..8ff8dac6 --- /dev/null +++ b/tests/unit/__init__.py @@ -0,0 +1 @@ +"""Unit tests for vast.py helper functions.""" diff --git a/tests/unit/test_http_helpers.py b/tests/unit/test_http_helpers.py new file mode 100644 index 00000000..3441f628 --- /dev/null +++ b/tests/unit/test_http_helpers.py @@ -0,0 +1,267 @@ +"""Unit tests for HTTP helper functions (api_call, http_*, retry logic). + +TEST-02: Unit tests for HTTP helper functions with mocked responses. + +These tests verify the helper functions in isolation, complementing the +regression tests that focus on specific bug fixes. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import argparse +import json +import pytest +from unittest.mock import MagicMock, patch, call +from requests.exceptions import ConnectionError, Timeout, HTTPError + + +def _make_args(retry=3, raw=False, explain=False, curl=False): + """Create minimal args namespace for testing.""" + return argparse.Namespace( + api_key="test-key", + url="https://console.vast.ai", + retry=retry, + raw=raw, + explain=explain, + quiet=False, + curl=curl, + ) + + +class TestApiCall: + """Tests for the api_call() helper function.""" + + @patch('vast.http_get') + def test_api_call_get_request(self, mock_http_get): + """api_call makes GET request and returns parsed JSON.""" + from vast import api_call + + mock_response = MagicMock() + mock_response.json.return_value = {"offers": [{"id": 1}]} + mock_http_get.return_value = mock_response + + args = _make_args() + # api_call signature: api_call(args, method, path, *, json_body=None, query_args=None) + result = api_call(args, "GET", "/api/v0/bundles") + + assert result == {"offers": [{"id": 1}]} + mock_http_get.assert_called_once() + + @patch('vast.http_post') + def test_api_call_post_request(self, mock_http_post): + """api_call makes POST request with JSON body.""" + from vast import api_call + + mock_response = MagicMock() + mock_response.json.return_value = {"success": True} + mock_http_post.return_value = mock_response + + args = _make_args() + # api_call signature: api_call(args, method, path, *, json_body=None, query_args=None) + result = api_call(args, "POST", "/api/v0/instances/123/", json_body={"action": "start"}) + + assert result == {"success": True} + mock_http_post.assert_called_once() + + @patch('vast.http_get') + def test_api_call_handles_http_error(self, mock_http_get): + """api_call propagates HTTPError from response.""" + from vast import api_call + + mock_response = MagicMock() + mock_response.raise_for_status.side_effect = HTTPError("404 Not Found") + mock_http_get.return_value = mock_response + + args = _make_args() + with pytest.raises(HTTPError): + # api_call signature: api_call(args, method, path, *, json_body=None, query_args=None) + api_call(args, "GET", "/api/v0/nonexistent") + + +class TestHttpHelpers: + """Tests for http_get, http_post, http_put, http_del functions.""" + + @patch('vast.http_request') + def test_http_get_constructs_correct_request(self, mock_http_request): + """http_get passes correct method and URL to http_request.""" + from vast import http_get + + mock_http_request.return_value = MagicMock(status_code=200) + args = _make_args() + + http_get(args, "https://example.com/api/test") + + mock_http_request.assert_called_once() + call_args = mock_http_request.call_args + assert call_args[0][0] == "GET" # method + assert call_args[0][2] == "https://example.com/api/test" # url + + @patch('vast.http_request') + def test_http_post_sends_json_body(self, mock_http_request): + """http_post includes JSON body in request.""" + from vast import http_post + + mock_http_request.return_value = MagicMock(status_code=200) + args = _make_args() + body = {"key": "value"} + + # http_post signature: http_post(args, req_url, headers=None, json=None, timeout=DEFAULT_TIMEOUT) + http_post(args, "https://example.com/api/test", json=body) + + # http_request is called as: http_request('POST', args, req_url, headers, json, timeout=timeout) + # json is passed as positional arg at index 4 + call_args = mock_http_request.call_args[0] + assert call_args[4] == body # json is 5th positional arg (index 4) + + @patch('vast.http_request') + def test_http_put_sends_json_body(self, mock_http_request): + """http_put includes JSON body in request.""" + from vast import http_put + + mock_http_request.return_value = MagicMock(status_code=200) + args = _make_args() + body = {"update": "data"} + + # http_put signature: http_put(args, req_url, headers=None, json=None, timeout=DEFAULT_TIMEOUT) + http_put(args, "https://example.com/api/test", json=body) + + # http_request is called as: http_request('PUT', args, req_url, headers, json, timeout=timeout) + # json is passed as positional arg at index 4 + call_args = mock_http_request.call_args[0] + assert call_args[4] == body # json is 5th positional arg (index 4) + + @patch('vast.http_request') + def test_http_del_makes_delete_request(self, mock_http_request): + """http_del uses DELETE method.""" + from vast import http_del + + mock_http_request.return_value = MagicMock(status_code=200) + args = _make_args() + + http_del(args, "https://example.com/api/resource/123") + + call_args = mock_http_request.call_args + assert call_args[0][0] == "DELETE" + + +class TestHttpRequestRetry: + """Tests for retry logic in http_request.""" + + @patch('vast.time.sleep') + @patch('vast.requests.Session') + def test_retry_on_connection_error(self, mock_session_cls, mock_sleep): + """http_request retries on ConnectionError.""" + from vast import http_request + + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + # First call fails, second succeeds + success_response = MagicMock(status_code=200) + mock_session.send.side_effect = [ + ConnectionError("Connection refused"), + success_response + ] + mock_session.prepare_request.return_value = MagicMock() + + args = _make_args(retry=3) + with patch('vast.ARGS', args): + result = http_request('GET', args, 'http://example.com/test') + + assert result.status_code == 200 + assert mock_session.send.call_count == 2 + + @patch('vast.time.sleep') + @patch('vast.requests.Session') + def test_retry_on_timeout(self, mock_session_cls, mock_sleep): + """http_request retries on Timeout.""" + from vast import http_request + + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + success_response = MagicMock(status_code=200) + mock_session.send.side_effect = [ + Timeout("Request timed out"), + success_response + ] + mock_session.prepare_request.return_value = MagicMock() + + args = _make_args(retry=3) + with patch('vast.ARGS', args): + result = http_request('GET', args, 'http://example.com/test') + + assert result.status_code == 200 + assert mock_session.send.call_count == 2 + + @patch('vast.time.sleep') + @patch('vast.requests.Session') + def test_retry_on_503_status(self, mock_session_cls, mock_sleep): + """http_request retries on 503 Service Unavailable.""" + from vast import http_request, RETRYABLE_STATUS_CODES + + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + error_response = MagicMock(status_code=503) + success_response = MagicMock(status_code=200) + mock_session.send.side_effect = [error_response, success_response] + mock_session.prepare_request.return_value = MagicMock() + + assert 503 in RETRYABLE_STATUS_CODES + + args = _make_args(retry=3) + with patch('vast.ARGS', args): + result = http_request('GET', args, 'http://example.com/test') + + assert result.status_code == 200 + assert mock_session.send.call_count == 2 + + @patch('vast.time.sleep') + @patch('vast.requests.Session') + def test_no_retry_on_400_status(self, mock_session_cls, mock_sleep): + """http_request does not retry on 400 Bad Request.""" + from vast import http_request + + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + error_response = MagicMock(status_code=400) + mock_session.send.return_value = error_response + mock_session.prepare_request.return_value = MagicMock() + + args = _make_args(retry=3) + with patch('vast.ARGS', args): + result = http_request('GET', args, 'http://example.com/test') + + assert result.status_code == 400 + assert mock_session.send.call_count == 1 # No retry + + +class TestTimeoutConstants: + """Tests for timeout constant values.""" + + def test_default_timeout_defined(self): + """DEFAULT_TIMEOUT constant is defined and reasonable.""" + from vast import DEFAULT_TIMEOUT + + assert DEFAULT_TIMEOUT == 30 + assert isinstance(DEFAULT_TIMEOUT, (int, float)) + + def test_long_timeout_defined(self): + """LONG_TIMEOUT constant is defined for file operations.""" + from vast import LONG_TIMEOUT + + assert LONG_TIMEOUT == 120 + assert LONG_TIMEOUT > 30 # Should be longer than default + + def test_retryable_status_codes_defined(self): + """RETRYABLE_STATUS_CODES contains expected HTTP statuses.""" + from vast import RETRYABLE_STATUS_CODES + + assert 429 in RETRYABLE_STATUS_CODES # Too Many Requests + assert 502 in RETRYABLE_STATUS_CODES # Bad Gateway + assert 503 in RETRYABLE_STATUS_CODES # Service Unavailable + assert 504 in RETRYABLE_STATUS_CODES # Gateway Timeout + assert 500 not in RETRYABLE_STATUS_CODES # 500 not retried (may have side effects) diff --git a/tests/unit/test_query_parser.py b/tests/unit/test_query_parser.py new file mode 100644 index 00000000..2fe89948 --- /dev/null +++ b/tests/unit/test_query_parser.py @@ -0,0 +1,160 @@ +"""Unit tests for query parsing functions (parse_query, field aliases). + +TEST-04: Unit tests for query parsing. + +These tests verify the query parser handles field names, operators, +aliases, and various query formats correctly. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import pytest + + +class TestParseQuery: + """Tests for parse_query() function.""" + + def test_simple_equality(self): + """Simple equality operator parses correctly.""" + from vast import parse_query, offers_fields, offers_alias + + result = parse_query("gpu_ram=8", {}, offers_fields, offers_alias) + + assert 'gpu_ram' in result + assert result['gpu_ram']['eq'] == '8' + + def test_greater_than_or_equal(self): + """Greater-than-or-equal operator parses correctly.""" + from vast import parse_query, offers_fields, offers_alias + + result = parse_query("gpu_ram>=16", {}, offers_fields, offers_alias) + + assert 'gpu_ram' in result + assert 'gte' in result['gpu_ram'] + assert result['gpu_ram']['gte'] == '16' + + def test_less_than_or_equal(self): + """Less-than-or-equal operator parses correctly.""" + from vast import parse_query, offers_fields, offers_alias + + result = parse_query("gpu_ram<=32", {}, offers_fields, offers_alias) + + assert 'gpu_ram' in result + assert 'lte' in result['gpu_ram'] + assert result['gpu_ram']['lte'] == '32' + + def test_multiple_conditions(self): + """Multiple space-separated conditions parse correctly.""" + from vast import parse_query, offers_fields, offers_alias + + result = parse_query("gpu_ram>=8 num_gpus>=2", {}, offers_fields, offers_alias) + + assert 'gpu_ram' in result + assert 'num_gpus' in result + + +class TestFieldAliases: + """Tests for field alias handling in parse_query.""" + + def test_cuda_vers_alias(self): + """cuda_vers is aliased to cuda_max_good.""" + from vast import parse_query, offers_fields, offers_alias + + result = parse_query("cuda_vers>=12.0", {}, offers_fields, offers_alias) + + # After alias resolution, cuda_vers should become cuda_max_good + assert 'cuda_max_good' in result + assert 'cuda_vers' not in result + assert 'gte' in result['cuda_max_good'] + + def test_alias_resolution_preserves_value(self): + """Alias resolution preserves the comparison value.""" + from vast import parse_query, offers_fields, offers_alias + + result = parse_query("cuda_vers>=11.8", {}, offers_fields, offers_alias) + + assert result['cuda_max_good']['gte'] == '11.8' + + +class TestOfferFieldsDefinition: + """Tests verifying offers_fields contains expected fields.""" + + def test_gpu_ram_field_exists(self): + """gpu_ram is a valid offer field.""" + from vast import offers_fields + + assert 'gpu_ram' in offers_fields + + def test_num_gpus_field_exists(self): + """num_gpus is a valid offer field.""" + from vast import offers_fields + + assert 'num_gpus' in offers_fields + + def test_cuda_max_good_field_exists(self): + """cuda_max_good is a valid offer field.""" + from vast import offers_fields + + assert 'cuda_max_good' in offers_fields + + def test_dph_total_field_exists(self): + """dph_total (price) is a valid offer field.""" + from vast import offers_fields + + assert 'dph_total' in offers_fields + + +class TestQueryEdgeCases: + """Tests for edge cases in query parsing.""" + + def test_empty_query(self): + """Empty query returns empty result.""" + from vast import parse_query, offers_fields, offers_alias + + result = parse_query("", {}, offers_fields, offers_alias) + + assert result == {} or result is not None + + def test_whitespace_handling(self): + """Whitespace around operators is handled.""" + from vast import parse_query, offers_fields, offers_alias + + # Query with extra whitespace should still parse + result = parse_query("gpu_ram >= 8", {}, offers_fields, offers_alias) + + # Should have gpu_ram field with gte constraint + assert 'gpu_ram' in result + + def test_decimal_values(self): + """Decimal values in queries parse correctly.""" + from vast import parse_query, offers_fields, offers_alias + + result = parse_query("dph_total<=0.5", {}, offers_fields, offers_alias) + + assert 'dph_total' in result + assert result['dph_total']['lte'] == '0.5' + + +class TestOfferAliasDefinition: + """Tests for offer field alias definitions.""" + + def test_offers_alias_is_dict(self): + """offers_alias is a dictionary.""" + from vast import offers_alias + + assert isinstance(offers_alias, dict) + + def test_cuda_vers_alias_defined(self): + """cuda_vers -> cuda_max_good alias is defined.""" + from vast import offers_alias + + assert 'cuda_vers' in offers_alias + assert offers_alias['cuda_vers'] == 'cuda_max_good' + + def test_dph_alias_defined(self): + """dph -> dph_total alias is defined.""" + from vast import offers_alias + + assert 'dph' in offers_alias + assert offers_alias['dph'] == 'dph_total' diff --git a/tests/unit/test_serverless_client.py b/tests/unit/test_serverless_client.py new file mode 100644 index 00000000..75c89051 --- /dev/null +++ b/tests/unit/test_serverless_client.py @@ -0,0 +1,257 @@ +"""Unit tests for vastai.serverless.client.client module. + +Tests for ServerlessRequest and Serverless class instantiation/configuration. +These tests focus on synchronous initialization code that does NOT require +actual network calls or async execution. +""" + +import pytest +import asyncio +from unittest.mock import MagicMock, patch, PropertyMock +import sys +import time +import logging +from pathlib import Path + +# Add parent directory to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from vastai.serverless.client.client import Serverless, ServerlessRequest + + +@pytest.fixture +def event_loop_for_request(): + """Create an event loop for ServerlessRequest tests (Future requires event loop).""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + yield loop + loop.close() + asyncio.set_event_loop(None) + + +class TestServerlessRequest: + """Tests for ServerlessRequest class.""" + + def test_init_defaults(self, event_loop_for_request): + """Test ServerlessRequest initializes with expected defaults.""" + before_time = time.time() + req = ServerlessRequest() + after_time = time.time() + + assert req.status == "New" + assert req.start_time is None + assert req.complete_time is None + assert req.req_idx == 0 + # create_time should be between before and after test execution + assert before_time <= req.create_time <= after_time + + def test_init_status_is_new(self, event_loop_for_request): + """Test that initial status is always 'New'.""" + req = ServerlessRequest() + assert req.status == "New" + + def test_then_returns_self(self, event_loop_for_request): + """Test that then() method returns self for chaining.""" + req = ServerlessRequest() + callback = MagicMock() + result = req.then(callback) + assert result is req + + def test_then_callback_called_on_result(self, event_loop_for_request): + """Test that callback is called when result is set.""" + req = ServerlessRequest() + callback = MagicMock() + req.then(callback) + + test_data = {"test": "data", "value": 123} + req.set_result(test_data) + # Run pending callbacks + event_loop_for_request.run_until_complete(asyncio.sleep(0)) + + callback.assert_called_once_with(test_data) + + def test_then_callback_not_called_on_exception(self, event_loop_for_request): + """Test that callback is NOT called when exception is set.""" + req = ServerlessRequest() + callback = MagicMock() + req.then(callback) + + # Set exception instead of result + req.set_exception(ValueError("test error")) + # Run pending callbacks + event_loop_for_request.run_until_complete(asyncio.sleep(0)) + + # Callback should not be called for exceptions + callback.assert_not_called() + + def test_multiple_then_callbacks(self, event_loop_for_request): + """Test that multiple callbacks can be chained.""" + req = ServerlessRequest() + callback1 = MagicMock() + callback2 = MagicMock() + + req.then(callback1).then(callback2) + + test_data = {"result": "success"} + req.set_result(test_data) + # Run pending callbacks + event_loop_for_request.run_until_complete(asyncio.sleep(0)) + + callback1.assert_called_once_with(test_data) + callback2.assert_called_once_with(test_data) + + +class TestServerlessInit: + """Tests for Serverless class __init__ method.""" + + def test_raises_on_none_api_key(self): + """Test that None api_key raises AttributeError.""" + with pytest.raises(AttributeError, match="API key missing"): + Serverless(api_key=None) + + def test_raises_on_empty_api_key(self): + """Test that empty string api_key raises AttributeError.""" + with pytest.raises(AttributeError, match="API key missing"): + Serverless(api_key="") + + def test_valid_api_key_stored(self): + """Test that valid api_key is stored correctly.""" + client = Serverless(api_key="test_api_key_123") + assert client.api_key == "test_api_key_123" + + def test_prod_instance_url(self): + """Test instance='prod' sets autoscaler_url to production URL.""" + client = Serverless(api_key="test_key", instance="prod") + assert client.autoscaler_url == "https://run.vast.ai" + + def test_alpha_instance_url(self): + """Test instance='alpha' sets autoscaler_url to alpha URL.""" + client = Serverless(api_key="test_key", instance="alpha") + assert client.autoscaler_url == "https://run-alpha.vast.ai" + + def test_local_instance_url(self): + """Test instance='local' sets autoscaler_url to localhost.""" + client = Serverless(api_key="test_key", instance="local") + assert client.autoscaler_url == "http://localhost:8080" + + def test_unknown_instance_defaults_to_prod(self): + """Test that unknown instance value defaults to production URL.""" + client = Serverless(api_key="test_key", instance="unknown_value") + assert client.autoscaler_url == "https://run.vast.ai" + + def test_connection_limit_stored(self): + """Test that connection_limit is stored correctly.""" + client = Serverless(api_key="test_key", connection_limit=100) + assert client.connection_limit == 100 + + def test_default_connection_limit(self): + """Test default connection_limit value.""" + client = Serverless(api_key="test_key") + assert client.connection_limit == 500 + + def test_default_request_timeout_stored_as_float(self): + """Test that default_request_timeout is stored as float.""" + client = Serverless(api_key="test_key", default_request_timeout=120) + assert client.default_request_timeout == 120.0 + assert isinstance(client.default_request_timeout, float) + + def test_max_poll_interval_stored_as_float(self): + """Test that max_poll_interval is stored as float.""" + client = Serverless(api_key="test_key", max_poll_interval=30) + assert client.max_poll_interval == 30.0 + assert isinstance(client.max_poll_interval, float) + + def test_debug_true_sets_logger_level(self): + """Test debug=True sets logger to DEBUG level.""" + client = Serverless(api_key="test_key", debug=True) + assert client.logger.level == logging.DEBUG + + def test_debug_false_uses_null_handler(self): + """Test debug=False adds NullHandler (no output).""" + client = Serverless(api_key="test_key", debug=False) + # Logger should have handlers, and at least one should be NullHandler + handlers = client.logger.handlers + assert any(isinstance(h, logging.NullHandler) for h in handlers) + + def test_session_initially_none(self): + """Test that _session is None after initialization.""" + client = Serverless(api_key="test_key") + assert client._session is None + + def test_ssl_context_initially_none(self): + """Test that _ssl_context is None after initialization.""" + client = Serverless(api_key="test_key") + assert client._ssl_context is None + + +class TestServerlessIsOpen: + """Tests for Serverless.is_open() method.""" + + def test_returns_false_when_session_is_none(self): + """Test is_open() returns False when _session is None.""" + client = Serverless(api_key="test_key") + assert client._session is None + assert client.is_open() is False + + def test_returns_true_when_session_open(self): + """Test is_open() returns True when session exists and not closed.""" + client = Serverless(api_key="test_key") + + # Create a mock session that is not closed + mock_session = MagicMock() + mock_session.closed = False + client._session = mock_session + + assert client.is_open() is True + + def test_returns_false_when_session_closed(self): + """Test is_open() returns False when session is closed.""" + client = Serverless(api_key="test_key") + + # Create a mock session that IS closed + mock_session = MagicMock() + mock_session.closed = True + client._session = mock_session + + assert client.is_open() is False + + +class TestServerlessGetAvgRequestTime: + """Tests for Serverless.get_avg_request_time() method.""" + + def test_returns_60_hardcoded(self): + """Test get_avg_request_time() returns 60.0 (hardcoded value).""" + client = Serverless(api_key="test_key") + result = client.get_avg_request_time() + assert result == 60.0 + assert isinstance(result, float) + + +class TestServerlessConstants: + """Tests for Serverless class constants.""" + + def test_ssl_cert_url(self): + """Test SSL_CERT_URL constant value.""" + assert Serverless.SSL_CERT_URL == "https://console.vast.ai/static/jvastai_root.cer" + + def test_vast_web_url(self): + """Test VAST_WEB_URL constant value.""" + assert Serverless.VAST_WEB_URL == "https://console.vast.ai" + + def test_vast_serverless_url(self): + """Test VAST_SERVERLESS_URL constant value.""" + assert Serverless.VAST_SERVERLESS_URL == "https://run.vast.ai" + + +class TestServerlessLatencies: + """Tests for latencies deque initialization.""" + + def test_latencies_initialized_empty(self): + """Test that latencies deque is empty after init.""" + client = Serverless(api_key="test_key") + assert len(client.latencies) == 0 + + def test_latencies_maxlen_is_50(self): + """Test that latencies deque has maxlen of 50.""" + client = Serverless(api_key="test_key") + assert client.latencies.maxlen == 50 diff --git a/tests/unit/test_serverless_connection.py b/tests/unit/test_serverless_connection.py new file mode 100644 index 00000000..9f3f34ab --- /dev/null +++ b/tests/unit/test_serverless_connection.py @@ -0,0 +1,284 @@ +"""Unit tests for vastai.serverless.client.connection module helper functions. + +Tests cover: +- _retryable: status code retry determination +- _backoff_delay: exponential backoff with jitter calculation +- _build_kwargs: HTTP request argument construction +""" +import pytest +from unittest.mock import MagicMock, patch +import sys +from pathlib import Path + +# Add vast-cli to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from vastai.serverless.client.connection import ( + _retryable, + _backoff_delay, + _build_kwargs, + _JITTER_CAP_SECONDS, +) + + +class TestRetryable: + """Tests for _retryable() function.""" + + def test_retryable_408_request_timeout(self): + """408 Request Timeout is retryable.""" + assert _retryable(408) is True + + def test_retryable_429_too_many_requests(self): + """429 Too Many Requests is retryable.""" + assert _retryable(429) is True + + def test_retryable_500_internal_server_error(self): + """500 Internal Server Error is retryable.""" + assert _retryable(500) is True + + def test_retryable_501_not_implemented(self): + """501 Not Implemented is retryable (5xx range).""" + assert _retryable(501) is True + + def test_retryable_502_bad_gateway(self): + """502 Bad Gateway is retryable.""" + assert _retryable(502) is True + + def test_retryable_503_service_unavailable(self): + """503 Service Unavailable is retryable.""" + assert _retryable(503) is True + + def test_retryable_504_gateway_timeout(self): + """504 Gateway Timeout is retryable.""" + assert _retryable(504) is True + + def test_retryable_599_max_5xx(self): + """599 (max 5xx) is retryable.""" + assert _retryable(599) is True + + def test_not_retryable_200_ok(self): + """200 OK is not retryable.""" + assert _retryable(200) is False + + def test_not_retryable_201_created(self): + """201 Created is not retryable.""" + assert _retryable(201) is False + + def test_not_retryable_400_bad_request(self): + """400 Bad Request is not retryable.""" + assert _retryable(400) is False + + def test_not_retryable_401_unauthorized(self): + """401 Unauthorized is not retryable.""" + assert _retryable(401) is False + + def test_not_retryable_403_forbidden(self): + """403 Forbidden is not retryable.""" + assert _retryable(403) is False + + def test_not_retryable_404_not_found(self): + """404 Not Found is not retryable.""" + assert _retryable(404) is False + + def test_not_retryable_600_boundary(self): + """600 is not retryable (beyond 5xx range).""" + assert _retryable(600) is False + + +class TestBackoffDelay: + """Tests for _backoff_delay() function.""" + + def test_backoff_attempt_1(self): + """Attempt 1 returns delay between 2.0 and 3.0.""" + delay = _backoff_delay(1) + # 2^1 + random.uniform(0, 1) = 2.0 to 3.0 + assert 2.0 <= delay <= 3.0 + + def test_backoff_attempt_2(self): + """Attempt 2 returns delay between 4.0 and 5.0.""" + delay = _backoff_delay(2) + # 2^2 + random.uniform(0, 1) = 4.0 to 5.0 (capped at 5.0) + assert 4.0 <= delay <= 5.0 + + def test_backoff_capped_at_jitter_cap(self): + """Large attempts are capped at _JITTER_CAP_SECONDS (5.0).""" + delay = _backoff_delay(10) # 2^10 = 1024 >> 5 + assert delay <= _JITTER_CAP_SECONDS + assert delay <= 5.0 + + def test_backoff_attempt_0(self): + """Attempt 0 returns delay between 1.0 and 2.0.""" + delay = _backoff_delay(0) + # 2^0 + random.uniform(0, 1) = 1.0 to 2.0 + assert 1.0 <= delay <= 2.0 + + def test_backoff_always_positive(self): + """All backoff delays are positive.""" + for attempt in range(10): + delay = _backoff_delay(attempt) + assert delay > 0 + + @patch('vastai.serverless.client.connection.random.uniform') + def test_backoff_uses_random_uniform(self, mock_uniform): + """_backoff_delay uses random.uniform for jitter.""" + mock_uniform.return_value = 0.5 + delay = _backoff_delay(1) + # 2^1 + 0.5 = 2.5 (not capped) + assert delay == 2.5 + mock_uniform.assert_called_once_with(0, 1) + + +class TestBuildKwargs: + """Tests for _build_kwargs() function.""" + + def test_includes_body_for_post(self): + """_build_kwargs includes json body for POST requests.""" + result = _build_kwargs( + headers={"Authorization": "Bearer test"}, + params={"key": "val"}, + ssl_context=None, + timeout=30, + body={"data": "test"}, + method="POST", + stream=False, + ) + assert "json" in result + assert result["json"] == {"data": "test"} + + def test_includes_body_for_put(self): + """_build_kwargs includes json body for PUT requests.""" + result = _build_kwargs( + headers={}, + params={}, + ssl_context=None, + timeout=30, + body={"data": "update"}, + method="PUT", + stream=False, + ) + assert "json" in result + assert result["json"] == {"data": "update"} + + def test_includes_body_for_delete(self): + """_build_kwargs includes json body for DELETE requests.""" + result = _build_kwargs( + headers={}, + params={}, + ssl_context=None, + timeout=30, + body={"id": 123}, + method="DELETE", + stream=False, + ) + assert "json" in result + assert result["json"] == {"id": 123} + + def test_excludes_body_for_get(self): + """_build_kwargs excludes body for GET requests.""" + result = _build_kwargs( + headers={}, + params={}, + ssl_context=None, + timeout=30, + body={"data": "test"}, + method="GET", + stream=False, + ) + assert "json" not in result + + def test_excludes_body_when_none(self): + """_build_kwargs excludes json when body is None.""" + result = _build_kwargs( + headers={}, + params={}, + ssl_context=None, + timeout=30, + body=None, + method="POST", + stream=False, + ) + assert "json" not in result + + def test_includes_headers(self): + """_build_kwargs includes headers in result.""" + headers = {"Authorization": "Bearer token", "Content-Type": "application/json"} + result = _build_kwargs( + headers=headers, + params={}, + ssl_context=None, + timeout=30, + body=None, + method="GET", + stream=False, + ) + assert result["headers"] == headers + + def test_includes_params(self): + """_build_kwargs includes params in result.""" + params = {"api_key": "abc123", "limit": "10"} + result = _build_kwargs( + headers={}, + params=params, + ssl_context=None, + timeout=30, + body=None, + method="GET", + stream=False, + ) + assert result["params"] == params + + def test_includes_ssl_context(self): + """_build_kwargs includes ssl context in result.""" + mock_ssl = MagicMock() + result = _build_kwargs( + headers={}, + params={}, + ssl_context=mock_ssl, + timeout=30, + body=None, + method="GET", + stream=False, + ) + assert result["ssl"] == mock_ssl + + def test_timeout_non_stream(self): + """_build_kwargs sets timeout for non-streaming requests.""" + result = _build_kwargs( + headers={}, + params={}, + ssl_context=None, + timeout=30, + body=None, + method="GET", + stream=False, + ) + assert result["timeout"].total == 30 + + def test_timeout_stream(self): + """_build_kwargs sets timeout to None for streaming requests.""" + result = _build_kwargs( + headers={}, + params={}, + ssl_context=None, + timeout=30, + body=None, + method="GET", + stream=True, + ) + assert result["timeout"].total is None + + def test_all_keys_present(self): + """_build_kwargs always includes headers, params, ssl, timeout.""" + result = _build_kwargs( + headers={}, + params={}, + ssl_context=None, + timeout=60, + body=None, + method="GET", + stream=False, + ) + assert "headers" in result + assert "params" in result + assert "ssl" in result + assert "timeout" in result diff --git a/tests/unit/test_serverless_data_types.py b/tests/unit/test_serverless_data_types.py new file mode 100644 index 00000000..201f1957 --- /dev/null +++ b/tests/unit/test_serverless_data_types.py @@ -0,0 +1,450 @@ +"""Unit tests for vastai.serverless.server.lib.data_types module. + +Tests cover: +- JsonDataException: message storage +- AuthData: from_json_msg with valid/missing/extra fields +- SystemMetrics: empty() factory, update_disk_usage(), reset() +- RequestMetrics: dataclass instantiation +- BenchmarkResult: is_successful property +- ModelMetrics: empty() factory, properties, set_errored(), reset() +- WorkerStatusData: dataclass instantiation +- LogAction enum: value validation +""" +import pytest +from unittest.mock import MagicMock, patch +import sys +import time +from pathlib import Path + +# Add vast-cli to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from vastai.serverless.server.lib.data_types import ( + JsonDataException, + AuthData, + SystemMetrics, + RequestMetrics, + BenchmarkResult, + ModelMetrics, + WorkerStatusData, + LogAction, +) + + +class TestJsonDataException: + """Tests for JsonDataException class.""" + + def test_stores_message_dict(self): + """JsonDataException stores the error message dict.""" + error_msg = {"field": "missing parameter"} + exc = JsonDataException(error_msg) + assert exc.message == {"field": "missing parameter"} + + def test_multiple_fields_in_message(self): + """JsonDataException can store multiple field errors.""" + error_msg = {"cost": "missing parameter", "endpoint": "invalid format"} + exc = JsonDataException(error_msg) + assert exc.message == error_msg + assert "cost" in exc.message + assert "endpoint" in exc.message + + def test_is_exception(self): + """JsonDataException is a proper Exception subclass.""" + exc = JsonDataException({"error": "test"}) + assert isinstance(exc, Exception) + + +class TestAuthData: + """Tests for AuthData dataclass.""" + + def test_from_json_msg_valid(self): + """AuthData.from_json_msg creates instance with valid data.""" + json_msg = { + "cost": "0.10", + "endpoint": "/api/test", + "reqnum": 1, + "request_idx": 0, + "signature": "abc123", + "url": "https://example.com", + } + auth = AuthData.from_json_msg(json_msg) + assert auth.cost == "0.10" + assert auth.endpoint == "/api/test" + assert auth.reqnum == 1 + assert auth.request_idx == 0 + assert auth.signature == "abc123" + assert auth.url == "https://example.com" + + def test_from_json_msg_missing_field_raises(self): + """AuthData.from_json_msg raises JsonDataException for missing fields.""" + json_msg = { + "cost": "0.10", + "endpoint": "/api/test", + # missing reqnum, request_idx, signature, url + } + with pytest.raises(JsonDataException) as exc_info: + AuthData.from_json_msg(json_msg) + assert "reqnum" in exc_info.value.message + assert "request_idx" in exc_info.value.message + assert "signature" in exc_info.value.message + assert "url" in exc_info.value.message + + def test_from_json_msg_single_missing_field(self): + """AuthData.from_json_msg reports single missing field.""" + json_msg = { + "cost": "0.10", + "endpoint": "/api/test", + "reqnum": 1, + "request_idx": 0, + "signature": "abc123", + # missing url + } + with pytest.raises(JsonDataException) as exc_info: + AuthData.from_json_msg(json_msg) + assert "url" in exc_info.value.message + assert exc_info.value.message["url"] == "missing parameter" + + def test_from_json_msg_extra_fields_ignored(self): + """AuthData.from_json_msg ignores extra fields.""" + json_msg = { + "cost": "0.10", + "endpoint": "/api/test", + "reqnum": 1, + "request_idx": 0, + "signature": "abc123", + "url": "https://example.com", + "extra_field": "should be ignored", + "another_extra": 12345, + } + auth = AuthData.from_json_msg(json_msg) + assert auth.cost == "0.10" + assert not hasattr(auth, "extra_field") + assert not hasattr(auth, "another_extra") + + +class TestSystemMetrics: + """Tests for SystemMetrics dataclass.""" + + @patch("vastai.serverless.server.lib.data_types.psutil.disk_usage") + @patch("vastai.serverless.server.lib.data_types.time.time") + def test_empty_factory(self, mock_time, mock_disk_usage): + """SystemMetrics.empty() creates instance with default values.""" + mock_time.return_value = 1000.0 + mock_disk_usage.return_value = MagicMock(used=10 * (2**30)) # 10 GB + + metrics = SystemMetrics.empty() + + assert metrics.model_loading_start == 1000.0 + assert metrics.model_loading_time is None + assert metrics.last_disk_usage == 10.0 # GB + assert metrics.additional_disk_usage == 0.0 + assert metrics.model_is_loaded is False + + @patch("vastai.serverless.server.lib.data_types.psutil.disk_usage") + def test_update_disk_usage(self, mock_disk_usage): + """SystemMetrics.update_disk_usage() updates disk usage stats.""" + # Initial disk usage: 10 GB + metrics = SystemMetrics( + model_loading_start=0.0, + model_loading_time=None, + last_disk_usage=10.0, + additional_disk_usage=0.0, + model_is_loaded=False, + ) + + # After update: 15 GB used + mock_disk_usage.return_value = MagicMock(used=15 * (2**30)) + metrics.update_disk_usage() + + assert metrics.last_disk_usage == 15.0 + assert metrics.additional_disk_usage == 5.0 # 15 - 10 + + def test_reset_clears_loading_time_when_expected(self): + """SystemMetrics.reset() clears model_loading_time when matching expected.""" + metrics = SystemMetrics( + model_loading_start=0.0, + model_loading_time=5.0, + last_disk_usage=10.0, + additional_disk_usage=0.0, + model_is_loaded=True, + ) + + metrics.reset(expected=5.0) + assert metrics.model_loading_time is None + + def test_reset_preserves_loading_time_when_different(self): + """SystemMetrics.reset() preserves model_loading_time when not matching expected.""" + metrics = SystemMetrics( + model_loading_start=0.0, + model_loading_time=5.0, + last_disk_usage=10.0, + additional_disk_usage=0.0, + model_is_loaded=True, + ) + + metrics.reset(expected=10.0) # Different from current value + assert metrics.model_loading_time == 5.0 # Preserved + + def test_reset_with_none_expected(self): + """SystemMetrics.reset() with None expected clears None loading time.""" + metrics = SystemMetrics( + model_loading_start=0.0, + model_loading_time=None, + last_disk_usage=10.0, + additional_disk_usage=0.0, + model_is_loaded=False, + ) + + metrics.reset(expected=None) + assert metrics.model_loading_time is None + + +class TestRequestMetrics: + """Tests for RequestMetrics dataclass.""" + + def test_instantiation_required_fields(self): + """RequestMetrics can be instantiated with required fields.""" + metrics = RequestMetrics( + request_idx=0, + reqnum=1, + workload=1.5, + status="processing", + ) + assert metrics.request_idx == 0 + assert metrics.reqnum == 1 + assert metrics.workload == 1.5 + assert metrics.status == "processing" + assert metrics.success is False # default + + def test_instantiation_with_success(self): + """RequestMetrics respects success parameter.""" + metrics = RequestMetrics( + request_idx=1, + reqnum=2, + workload=2.0, + status="completed", + success=True, + ) + assert metrics.success is True + + +class TestBenchmarkResult: + """Tests for BenchmarkResult dataclass.""" + + def test_is_successful_true(self): + """BenchmarkResult.is_successful returns True for 200 response.""" + mock_response = MagicMock() + mock_response.status = 200 + + mock_task = MagicMock() + result = BenchmarkResult( + request_idx=0, + workload=1.0, + task=mock_task, + response=mock_response, + ) + assert result.is_successful is True + + def test_is_successful_false_non_200(self): + """BenchmarkResult.is_successful returns False for non-200 response.""" + mock_response = MagicMock() + mock_response.status = 500 + + mock_task = MagicMock() + result = BenchmarkResult( + request_idx=0, + workload=1.0, + task=mock_task, + response=mock_response, + ) + assert result.is_successful is False + + def test_is_successful_false_no_response(self): + """BenchmarkResult.is_successful returns False when response is None.""" + mock_task = MagicMock() + result = BenchmarkResult( + request_idx=0, + workload=1.0, + task=mock_task, + response=None, # No response yet + ) + assert result.is_successful is False + + +class TestModelMetrics: + """Tests for ModelMetrics dataclass.""" + + def test_empty_factory(self): + """ModelMetrics.empty() creates instance with zero values.""" + metrics = ModelMetrics.empty() + assert metrics.workload_pending == 0.0 + assert metrics.workload_served == 0.0 + assert metrics.workload_cancelled == 0.0 + assert metrics.workload_errored == 0.0 + assert metrics.workload_rejected == 0.0 + assert metrics.workload_received == 0.0 + assert metrics.error_msg is None + assert metrics.max_throughput == 0.0 + + def test_workload_processing_property(self): + """ModelMetrics.workload_processing computes correctly.""" + metrics = ModelMetrics.empty() + metrics.workload_received = 10.0 + metrics.workload_cancelled = 3.0 + assert metrics.workload_processing == 7.0 # 10 - 3 + + def test_workload_processing_never_negative(self): + """ModelMetrics.workload_processing is clamped to 0.""" + metrics = ModelMetrics.empty() + metrics.workload_received = 5.0 + metrics.workload_cancelled = 10.0 # More cancelled than received + assert metrics.workload_processing == 0.0 + + def test_wait_time_empty_requests(self): + """ModelMetrics.wait_time returns 0 when no requests working.""" + metrics = ModelMetrics.empty() + assert metrics.wait_time == 0.0 + + def test_wait_time_calculation(self): + """ModelMetrics.wait_time computes workload/throughput.""" + metrics = ModelMetrics.empty() + metrics.max_throughput = 10.0 + metrics.requests_working = { + 0: RequestMetrics(request_idx=0, reqnum=1, workload=5.0, status="working"), + 1: RequestMetrics(request_idx=1, reqnum=2, workload=5.0, status="working"), + } + # wait_time = sum(workloads) / max_throughput = 10.0 / 10.0 = 1.0 + assert metrics.wait_time == 1.0 + + def test_wait_time_near_zero_throughput(self): + """ModelMetrics.wait_time handles near-zero throughput gracefully.""" + metrics = ModelMetrics.empty() + metrics.max_throughput = 0.0 # Will use 0.00001 denominator + metrics.requests_working = { + 0: RequestMetrics(request_idx=0, reqnum=1, workload=1.0, status="working"), + } + # Should not raise ZeroDivisionError + result = metrics.wait_time + assert result == 1.0 / 0.00001 # Very large but finite + + def test_cur_load_property(self): + """ModelMetrics.cur_load sums workload from working requests.""" + metrics = ModelMetrics.empty() + metrics.requests_working = { + 0: RequestMetrics(request_idx=0, reqnum=1, workload=2.0, status="working"), + 1: RequestMetrics(request_idx=1, reqnum=2, workload=3.0, status="working"), + } + assert metrics.cur_load == 5.0 + + def test_cur_load_empty(self): + """ModelMetrics.cur_load returns 0 for no working requests.""" + metrics = ModelMetrics.empty() + assert metrics.cur_load == 0.0 + + def test_working_request_idxs_property(self): + """ModelMetrics.working_request_idxs returns list of request indices.""" + metrics = ModelMetrics.empty() + metrics.requests_working = { + 0: RequestMetrics(request_idx=0, reqnum=1, workload=1.0, status="working"), + 5: RequestMetrics(request_idx=5, reqnum=6, workload=1.0, status="working"), + } + idxs = metrics.working_request_idxs + assert 0 in idxs + assert 5 in idxs + + def test_set_errored(self): + """ModelMetrics.set_errored resets metrics and sets error message.""" + metrics = ModelMetrics.empty() + metrics.workload_served = 10.0 + metrics.workload_received = 20.0 + + metrics.set_errored("Test error message") + + assert metrics.error_msg == "Test error message" + assert metrics.workload_served == 0 + assert metrics.workload_received == 0 + + def test_reset(self): + """ModelMetrics.reset clears workload counters.""" + metrics = ModelMetrics.empty() + metrics.workload_served = 10.0 + metrics.workload_received = 20.0 + metrics.workload_cancelled = 5.0 + metrics.workload_errored = 2.0 + metrics.workload_rejected = 1.0 + + metrics.reset() + + assert metrics.workload_served == 0 + assert metrics.workload_received == 0 + assert metrics.workload_cancelled == 0 + assert metrics.workload_errored == 0 + assert metrics.workload_rejected == 0 + + +class TestWorkerStatusData: + """Tests for WorkerStatusData dataclass.""" + + def test_instantiation_all_fields(self): + """WorkerStatusData can be instantiated with all fields.""" + status = WorkerStatusData( + id=123, + mtoken="token-abc", + version="1.0.0", + loadtime=5.0, + cur_load=2.0, + rej_load=0.5, + new_load=1.5, + error_msg="", + max_perf=10.0, + cur_perf=8.0, + cur_capacity=100.0, + max_capacity=200.0, + num_requests_working=3, + num_requests_recieved=10, + additional_disk_usage=1.5, + working_request_idxs=[0, 1, 2], + url="https://worker.example.com", + ) + + assert status.id == 123 + assert status.mtoken == "token-abc" + assert status.version == "1.0.0" + assert status.loadtime == 5.0 + assert status.cur_load == 2.0 + assert status.rej_load == 0.5 + assert status.new_load == 1.5 + assert status.error_msg == "" + assert status.max_perf == 10.0 + assert status.cur_perf == 8.0 + assert status.cur_capacity == 100.0 + assert status.max_capacity == 200.0 + assert status.num_requests_working == 3 + assert status.num_requests_recieved == 10 + assert status.additional_disk_usage == 1.5 + assert status.working_request_idxs == [0, 1, 2] + assert status.url == "https://worker.example.com" + + +class TestLogAction: + """Tests for LogAction enum.""" + + def test_model_loaded_value(self): + """LogAction.ModelLoaded has value 1.""" + assert LogAction.ModelLoaded.value == 1 + + def test_model_error_value(self): + """LogAction.ModelError has value 2.""" + assert LogAction.ModelError.value == 2 + + def test_info_value(self): + """LogAction.Info has value 3.""" + assert LogAction.Info.value == 3 + + def test_all_members_exist(self): + """All expected LogAction members exist.""" + members = list(LogAction) + assert len(members) == 3 + assert LogAction.ModelLoaded in members + assert LogAction.ModelError in members + assert LogAction.Info in members diff --git a/tests/unit/test_serverless_endpoint.py b/tests/unit/test_serverless_endpoint.py new file mode 100644 index 00000000..596251b5 --- /dev/null +++ b/tests/unit/test_serverless_endpoint.py @@ -0,0 +1,247 @@ +"""Unit tests for vastai.serverless.client.endpoint module. + +Tests for Endpoint and RouteResponse classes. +These tests focus on synchronous initialization and validation code. +""" + +import pytest +from unittest.mock import MagicMock +import sys +from pathlib import Path + +# Add parent directory to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from vastai.serverless.client.endpoint import Endpoint, RouteResponse + + +class TestEndpointInit: + """Tests for Endpoint class __init__ method.""" + + def test_raises_on_none_client(self): + """Test that None client raises ValueError.""" + with pytest.raises(ValueError, match="cannot be created without client"): + Endpoint(client=None, name="test", id=1, api_key="key") + + def test_raises_on_empty_name(self): + """Test that empty string name raises ValueError.""" + mock_client = MagicMock() + with pytest.raises(ValueError, match="name cannot be empty"): + Endpoint(client=mock_client, name="", id=1, api_key="key") + + def test_raises_on_falsy_name(self): + """Test that falsy name (None) raises ValueError.""" + mock_client = MagicMock() + with pytest.raises(ValueError, match="name cannot be empty"): + Endpoint(client=mock_client, name=None, id=1, api_key="key") + + def test_raises_on_none_id(self): + """Test that None id raises ValueError.""" + mock_client = MagicMock() + with pytest.raises(ValueError, match="id cannot be empty"): + Endpoint(client=mock_client, name="test", id=None, api_key="key") + + def test_valid_instantiation_stores_attributes(self): + """Test that valid instantiation stores all attributes correctly.""" + mock_client = MagicMock() + endpoint = Endpoint( + client=mock_client, + name="my-endpoint", + id=123, + api_key="secret_api_key" + ) + assert endpoint.client is mock_client + assert endpoint.name == "my-endpoint" + assert endpoint.id == 123 + assert endpoint.api_key == "secret_api_key" + + def test_zero_id_is_valid(self): + """Test that id=0 is valid (not None).""" + mock_client = MagicMock() + endpoint = Endpoint(client=mock_client, name="test", id=0, api_key="key") + assert endpoint.id == 0 + + def test_negative_id_is_valid(self): + """Test that negative id is stored (no validation against negative).""" + mock_client = MagicMock() + endpoint = Endpoint(client=mock_client, name="test", id=-5, api_key="key") + assert endpoint.id == -5 + + +class TestEndpointRepr: + """Tests for Endpoint __repr__ method.""" + + def test_repr_format(self): + """Test __repr__ returns expected string format.""" + mock_client = MagicMock() + endpoint = Endpoint(client=mock_client, name="test-ep", id=42, api_key="key") + assert repr(endpoint) == "" + + def test_repr_with_long_name(self): + """Test __repr__ handles long endpoint names.""" + mock_client = MagicMock() + endpoint = Endpoint( + client=mock_client, + name="very-long-endpoint-name-for-testing", + id=999, + api_key="key" + ) + assert repr(endpoint) == "" + + def test_repr_with_special_chars_in_name(self): + """Test __repr__ handles special characters in name.""" + mock_client = MagicMock() + endpoint = Endpoint(client=mock_client, name="my_test.endpoint", id=1, api_key="key") + assert repr(endpoint) == "" + + +class TestRouteResponseStatusReady: + """Tests for RouteResponse when URL is present (READY status).""" + + def test_status_ready_when_url_present(self): + """Test that presence of 'url' in body sets status to READY.""" + body = {"url": "https://worker.example.com", "request_idx": 5} + response = RouteResponse(body) + assert response.status == "READY" + + def test_request_idx_extracted_when_present(self): + """Test request_idx is extracted from body when present.""" + body = {"url": "https://worker.example.com", "request_idx": 42} + response = RouteResponse(body) + assert response.request_idx == 42 + + def test_body_stored_when_ready(self): + """Test that body dict is stored.""" + body = {"url": "https://worker.example.com", "extra": "data"} + response = RouteResponse(body) + assert response.body == body + + +class TestRouteResponseStatusWaiting: + """Tests for RouteResponse when URL is NOT present (WAITING status).""" + + def test_status_waiting_when_no_url(self): + """Test that absence of 'url' in body sets status to WAITING.""" + body = {"request_idx": 3} + response = RouteResponse(body) + assert response.status == "WAITING" + + def test_request_idx_extracted_when_waiting(self): + """Test request_idx is extracted from body when in waiting status.""" + body = {"request_idx": 7} + response = RouteResponse(body) + assert response.request_idx == 7 + + def test_body_stored_when_waiting(self): + """Test that body dict is stored when waiting.""" + body = {"request_idx": 3, "other": "info"} + response = RouteResponse(body) + assert response.body == body + + +class TestRouteResponseRequestIdx: + """Tests for RouteResponse request_idx handling.""" + + def test_request_idx_defaults_to_zero(self): + """Test request_idx defaults to 0 when not in body.""" + body = {} + response = RouteResponse(body) + assert response.request_idx == 0 + + def test_request_idx_zero_when_explicit(self): + """Test request_idx=0 when explicitly set.""" + body = {"request_idx": 0} + response = RouteResponse(body) + assert response.request_idx == 0 + + def test_request_idx_large_value(self): + """Test request_idx with large value.""" + body = {"request_idx": 999999} + response = RouteResponse(body) + assert response.request_idx == 999999 + + +class TestRouteResponseRepr: + """Tests for RouteResponse __repr__ method.""" + + def test_repr_format_ready(self): + """Test __repr__ returns expected format for READY status.""" + body = {"url": "https://test.com"} + response = RouteResponse(body) + assert repr(response) == "" + + def test_repr_format_waiting(self): + """Test __repr__ returns expected format for WAITING status.""" + body = {} + response = RouteResponse(body) + assert repr(response) == "" + + +class TestRouteResponseGetUrl: + """Tests for RouteResponse.get_url() method.""" + + def test_get_url_returns_url_from_body(self): + """Test get_url() returns url from body when present.""" + body = {"url": "https://example.com/worker"} + response = RouteResponse(body) + assert response.get_url() == "https://example.com/worker" + + def test_get_url_none_when_missing(self): + """Test get_url() returns None when url not in body.""" + body = {} + response = RouteResponse(body) + assert response.get_url() is None + + def test_get_url_with_complex_url(self): + """Test get_url() with complex URL including path and query.""" + body = {"url": "https://worker-123.vast.ai:8080/inference?model=gpt"} + response = RouteResponse(body) + assert response.get_url() == "https://worker-123.vast.ai:8080/inference?model=gpt" + + def test_get_url_empty_string(self): + """Test get_url() when url is empty string.""" + body = {"url": ""} + response = RouteResponse(body) + assert response.get_url() == "" + + +class TestEndpointRequest: + """Tests for Endpoint.request() method delegation.""" + + def test_request_delegates_to_client(self): + """Test that request() delegates to client.queue_endpoint_request().""" + mock_client = MagicMock() + endpoint = Endpoint(client=mock_client, name="test", id=1, api_key="key") + + endpoint.request("/route", {"data": "value"}) + + mock_client.queue_endpoint_request.assert_called_once() + call_kwargs = mock_client.queue_endpoint_request.call_args[1] + assert call_kwargs["endpoint"] is endpoint + assert call_kwargs["worker_route"] == "/route" + assert call_kwargs["worker_payload"] == {"data": "value"} + + def test_request_passes_optional_params(self): + """Test that request() passes optional parameters.""" + mock_client = MagicMock() + endpoint = Endpoint(client=mock_client, name="test", id=1, api_key="key") + + endpoint.request("/route", {"data": "value"}, cost=200, retry=False, stream=True) + + call_kwargs = mock_client.queue_endpoint_request.call_args[1] + assert call_kwargs["cost"] == 200 + assert call_kwargs["retry"] is False + assert call_kwargs["stream"] is True + + +class TestEndpointGetWorkers: + """Tests for Endpoint.get_workers() method delegation.""" + + def test_get_workers_delegates_to_client(self): + """Test that get_workers() delegates to client.get_endpoint_workers().""" + mock_client = MagicMock() + endpoint = Endpoint(client=mock_client, name="test", id=1, api_key="key") + + endpoint.get_workers() + + mock_client.get_endpoint_workers.assert_called_once_with(endpoint) diff --git a/tests/unit/test_serverless_metrics.py b/tests/unit/test_serverless_metrics.py new file mode 100644 index 00000000..ad7661d2 --- /dev/null +++ b/tests/unit/test_serverless_metrics.py @@ -0,0 +1,338 @@ +""" +Unit tests for vastai/serverless/server/lib/metrics.py + +Tests the Metrics class request lifecycle methods without network calls. +""" +import pytest +from unittest.mock import MagicMock, patch +import sys +import os +from pathlib import Path + +# Add vast-cli to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + + +@pytest.fixture +def mock_env(): + """Mock environment variables required by Metrics class.""" + env = { + "CONTAINER_ID": "12345", + "REPORT_ADDR": "https://run.vast.ai", + "PUBLIC_IPADDR": "192.168.1.100", + "WORKER_PORT": "8080", + "VAST_TCP_PORT_8080": "8080", + } + with patch.dict(os.environ, env, clear=False): + # Clear the cached get_url function so it picks up new env vars + from vastai.serverless.server.lib.metrics import get_url + get_url.cache_clear() + yield env + + +@pytest.fixture +def metrics(mock_env): + """Create Metrics instance with mocked environment.""" + from vastai.serverless.server.lib.metrics import Metrics + return Metrics() + + +@pytest.fixture +def request_metrics(): + """Create a RequestMetrics object for testing.""" + from vastai.serverless.server.lib.data_types import RequestMetrics + return RequestMetrics( + request_idx=1, + reqnum=100, + workload=50.0, + status="Created" + ) + + +class TestMetricsInstantiation: + """Tests for Metrics class instantiation.""" + + def test_metrics_created_with_env_defaults(self, metrics, mock_env): + """Metrics instance should use environment variables for defaults.""" + assert metrics.id == int(mock_env["CONTAINER_ID"]) + assert mock_env["REPORT_ADDR"] in metrics.report_addr + assert metrics.version == "0" + assert metrics.mtoken == "" + assert metrics.update_pending is False + + def test_metrics_url_constructed_from_env(self, metrics, mock_env): + """URL should be constructed from PUBLIC_IPADDR and port.""" + expected_url = f"http://{mock_env['PUBLIC_IPADDR']}:{mock_env['VAST_TCP_PORT_8080']}" + assert metrics.url == expected_url + + def test_metrics_system_metrics_initialized(self, metrics): + """System metrics should be initialized with empty factory.""" + assert metrics.system_metrics is not None + assert metrics.system_metrics.model_is_loaded is False + + def test_metrics_model_metrics_initialized(self, metrics): + """Model metrics should be initialized with empty factory.""" + assert metrics.model_metrics is not None + assert metrics.model_metrics.workload_pending == 0.0 + assert metrics.model_metrics.workload_served == 0.0 + + +class TestMetricsRequestStart: + """Tests for _request_start method.""" + + def test_updates_status_to_started(self, metrics, request_metrics): + """_request_start should set request status to Started.""" + metrics._request_start(request_metrics) + assert request_metrics.status == "Started" + + def test_increments_workload_pending(self, metrics, request_metrics): + """_request_start should increase workload_pending by request workload.""" + metrics._request_start(request_metrics) + assert metrics.model_metrics.workload_pending == request_metrics.workload + + def test_increments_workload_received(self, metrics, request_metrics): + """_request_start should increase workload_received by request workload.""" + metrics._request_start(request_metrics) + assert metrics.model_metrics.workload_received == request_metrics.workload + + def test_adds_to_requests_recieved(self, metrics, request_metrics): + """_request_start should add reqnum to requests_recieved set.""" + metrics._request_start(request_metrics) + assert request_metrics.reqnum in metrics.model_metrics.requests_recieved + + def test_adds_to_requests_working(self, metrics, request_metrics): + """_request_start should add request to requests_working dict.""" + metrics._request_start(request_metrics) + assert request_metrics.reqnum in metrics.model_metrics.requests_working + assert metrics.model_metrics.requests_working[request_metrics.reqnum] == request_metrics + + def test_sets_update_pending_true(self, metrics, request_metrics): + """_request_start should set update_pending to True.""" + metrics._request_start(request_metrics) + assert metrics.update_pending is True + + +class TestMetricsRequestEnd: + """Tests for _request_end method.""" + + def test_decrements_workload_pending(self, metrics, request_metrics): + """_request_end should decrease workload_pending by request workload.""" + # Start request first + metrics._request_start(request_metrics) + initial_pending = metrics.model_metrics.workload_pending + + metrics._request_end(request_metrics) + assert metrics.model_metrics.workload_pending == initial_pending - request_metrics.workload + + def test_removes_from_requests_working(self, metrics, request_metrics): + """_request_end should remove request from requests_working dict.""" + metrics._request_start(request_metrics) + metrics._request_end(request_metrics) + assert request_metrics.reqnum not in metrics.model_metrics.requests_working + + def test_adds_to_requests_deleting(self, metrics, request_metrics): + """_request_end should add request to requests_deleting list.""" + metrics._request_start(request_metrics) + metrics._request_end(request_metrics) + assert request_metrics in metrics.model_metrics.requests_deleting + + def test_updates_last_request_served(self, metrics, request_metrics): + """_request_end should update last_request_served timestamp.""" + metrics._request_start(request_metrics) + metrics._request_end(request_metrics) + assert metrics.last_request_served > 0 + + +class TestMetricsRequestSuccess: + """Tests for _request_success method.""" + + def test_increments_workload_served(self, metrics, request_metrics): + """_request_success should increase workload_served by request workload.""" + metrics._request_success(request_metrics) + assert metrics.model_metrics.workload_served == request_metrics.workload + + def test_sets_status_to_success(self, metrics, request_metrics): + """_request_success should set request status to Success.""" + metrics._request_success(request_metrics) + assert request_metrics.status == "Success" + + def test_sets_success_flag_true(self, metrics, request_metrics): + """_request_success should set request success flag to True.""" + metrics._request_success(request_metrics) + assert request_metrics.success is True + + def test_sets_update_pending_true(self, metrics, request_metrics): + """_request_success should set update_pending to True.""" + metrics._request_success(request_metrics) + assert metrics.update_pending is True + + +class TestMetricsRequestErrored: + """Tests for _request_errored method.""" + + def test_increments_workload_errored(self, metrics, request_metrics): + """_request_errored should increase workload_errored by request workload.""" + metrics._request_errored(request_metrics, "Test error") + assert metrics.model_metrics.workload_errored == request_metrics.workload + + def test_sets_status_to_error(self, metrics, request_metrics): + """_request_errored should set request status to Error.""" + metrics._request_errored(request_metrics, "Test error") + assert request_metrics.status == "Error" + + def test_sets_success_flag_false(self, metrics, request_metrics): + """_request_errored should set request success flag to False.""" + metrics._request_errored(request_metrics, "Test error") + assert request_metrics.success is False + + def test_sets_update_pending_true(self, metrics, request_metrics): + """_request_errored should set update_pending to True.""" + metrics._request_errored(request_metrics, "Test error") + assert metrics.update_pending is True + + +class TestMetricsRequestCanceled: + """Tests for _request_canceled method.""" + + def test_increments_workload_cancelled(self, metrics, request_metrics): + """_request_canceled should increase workload_cancelled by request workload.""" + metrics._request_canceled(request_metrics) + assert metrics.model_metrics.workload_cancelled == request_metrics.workload + + def test_sets_status_to_cancelled(self, metrics, request_metrics): + """_request_canceled should set request status to Cancelled.""" + metrics._request_canceled(request_metrics) + assert request_metrics.status == "Cancelled" + + def test_sets_success_flag_true(self, metrics, request_metrics): + """_request_canceled should set request success flag to True (cancelled = successful cleanup).""" + metrics._request_canceled(request_metrics) + assert request_metrics.success is True + + +class TestMetricsRequestReject: + """Tests for _request_reject method.""" + + def test_increments_workload_rejected(self, metrics, request_metrics): + """_request_reject should increase workload_rejected by request workload.""" + metrics._request_reject(request_metrics) + assert metrics.model_metrics.workload_rejected == request_metrics.workload + + def test_sets_status_to_rejected(self, metrics, request_metrics): + """_request_reject should set request status to Rejected.""" + metrics._request_reject(request_metrics) + assert request_metrics.status == "Rejected" + + def test_sets_success_flag_false(self, metrics, request_metrics): + """_request_reject should set request success flag to False.""" + metrics._request_reject(request_metrics) + assert request_metrics.success is False + + def test_adds_to_requests_recieved(self, metrics, request_metrics): + """_request_reject should add reqnum to requests_recieved set.""" + metrics._request_reject(request_metrics) + assert request_metrics.reqnum in metrics.model_metrics.requests_recieved + + def test_adds_to_requests_deleting(self, metrics, request_metrics): + """_request_reject should add request to requests_deleting list.""" + metrics._request_reject(request_metrics) + assert request_metrics in metrics.model_metrics.requests_deleting + + def test_sets_update_pending_true(self, metrics, request_metrics): + """_request_reject should set update_pending to True.""" + metrics._request_reject(request_metrics) + assert metrics.update_pending is True + + +class TestMetricsModelLoaded: + """Tests for _model_loaded method.""" + + def test_sets_model_is_loaded_true(self, metrics): + """_model_loaded should set system_metrics.model_is_loaded to True.""" + metrics._model_loaded(max_throughput=100.0) + assert metrics.system_metrics.model_is_loaded is True + + def test_sets_max_throughput(self, metrics): + """_model_loaded should set model_metrics.max_throughput.""" + metrics._model_loaded(max_throughput=150.5) + assert metrics.model_metrics.max_throughput == 150.5 + + def test_calculates_loading_time(self, metrics): + """_model_loaded should calculate model_loading_time from start time.""" + metrics._model_loaded(max_throughput=100.0) + # Loading time should be >= 0 (time since model_loading_start) + assert metrics.system_metrics.model_loading_time >= 0 + + +class TestMetricsModelErrored: + """Tests for _model_errored method.""" + + def test_sets_error_msg(self, metrics): + """_model_errored should set model_metrics.error_msg.""" + metrics._model_errored("Model failed to load") + assert metrics.model_metrics.error_msg == "Model failed to load" + + def test_sets_model_is_loaded_true(self, metrics): + """_model_errored should set system_metrics.model_is_loaded to True.""" + metrics._model_errored("Model failed to load") + assert metrics.system_metrics.model_is_loaded is True + + +class TestMetricsSetters: + """Tests for simple setter methods.""" + + def test_set_version(self, metrics): + """_set_version should update version field.""" + metrics._set_version("2.0.1") + assert metrics.version == "2.0.1" + + def test_set_mtoken(self, metrics): + """_set_mtoken should update mtoken field.""" + metrics._set_mtoken("secret-token-123") + assert metrics.mtoken == "secret-token-123" + + +class TestMetricsUrlWithSSL: + """Tests for URL construction with SSL enabled.""" + + def test_url_with_ssl_enabled(self): + """URL should use https when USE_SSL=true.""" + env = { + "CONTAINER_ID": "12345", + "REPORT_ADDR": "https://run.vast.ai", + "PUBLIC_IPADDR": "192.168.1.100", + "WORKER_PORT": "8080", + "VAST_TCP_PORT_8080": "8080", + "USE_SSL": "true", + } + with patch.dict(os.environ, env, clear=False): + from vastai.serverless.server.lib.metrics import get_url, Metrics + get_url.cache_clear() + metrics = Metrics() + assert metrics.url.startswith("https://") + + +class TestMetricsMultipleRequests: + """Tests for handling multiple concurrent requests.""" + + def test_multiple_requests_tracked_independently(self, metrics): + """Multiple requests should be tracked independently.""" + from vastai.serverless.server.lib.data_types import RequestMetrics + + req1 = RequestMetrics(request_idx=1, reqnum=100, workload=50.0, status="Created") + req2 = RequestMetrics(request_idx=2, reqnum=101, workload=75.0, status="Created") + + metrics._request_start(req1) + metrics._request_start(req2) + + # Both should be tracked + assert len(metrics.model_metrics.requests_working) == 2 + assert metrics.model_metrics.workload_pending == 125.0 # 50 + 75 + + # End one request + metrics._request_success(req1) + metrics._request_end(req1) + + # Only one should remain + assert len(metrics.model_metrics.requests_working) == 1 + assert req2.reqnum in metrics.model_metrics.requests_working diff --git a/tests/unit/test_serverless_worker_config.py b/tests/unit/test_serverless_worker_config.py new file mode 100644 index 00000000..065adf8b --- /dev/null +++ b/tests/unit/test_serverless_worker_config.py @@ -0,0 +1,455 @@ +""" +Unit tests for vastai/serverless/server/worker.py + +Tests the configuration dataclasses and EndpointHandlerFactory without network calls. +""" +import pytest +from unittest.mock import MagicMock +import sys +from pathlib import Path + +# Add vast-cli to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from vastai.serverless.server.worker import ( + LogActionConfig, + BenchmarkConfig, + HandlerConfig, + WorkerConfig, + EndpointHandlerFactory, +) +from vastai.serverless.server.lib.data_types import LogAction + + +class TestLogActionConfig: + """Tests for LogActionConfig dataclass.""" + + def test_default_empty_lists(self): + """Default factory should create empty lists for all action types.""" + config = LogActionConfig() + assert config.on_load == [] + assert config.on_error == [] + assert config.on_info == [] + + def test_log_actions_on_load_converts_to_tuples(self): + """on_load strings should convert to (LogAction.ModelLoaded, str) tuples.""" + config = LogActionConfig(on_load=["Server started"]) + actions = config.log_actions + assert (LogAction.ModelLoaded, "Server started") in actions + + def test_log_actions_on_error_converts_to_tuples(self): + """on_error strings should convert to (LogAction.ModelError, str) tuples.""" + config = LogActionConfig(on_error=["Critical failure"]) + actions = config.log_actions + assert (LogAction.ModelError, "Critical failure") in actions + + def test_log_actions_on_info_converts_to_tuples(self): + """on_info strings should convert to (LogAction.Info, str) tuples.""" + config = LogActionConfig(on_info=["Download progress"]) + actions = config.log_actions + assert (LogAction.Info, "Download progress") in actions + + def test_log_actions_combined_all_types(self): + """log_actions should combine all three action types.""" + config = LogActionConfig( + on_load=["Model loaded", "Ready to serve"], + on_error=["Load failed"], + on_info=["Starting download"] + ) + actions = config.log_actions + assert len(actions) == 4 + assert (LogAction.ModelLoaded, "Model loaded") in actions + assert (LogAction.ModelLoaded, "Ready to serve") in actions + assert (LogAction.ModelError, "Load failed") in actions + assert (LogAction.Info, "Starting download") in actions + + def test_log_actions_empty_returns_empty_list(self): + """Empty config should return empty log_actions list.""" + config = LogActionConfig() + assert config.log_actions == [] + + +class TestBenchmarkConfig: + """Tests for BenchmarkConfig dataclass.""" + + def test_default_values(self): + """Default values should be set correctly.""" + config = BenchmarkConfig() + assert config.dataset is None + assert config.generator is None + assert config.runs == 8 + assert config.concurrency == 10 + + def test_custom_dataset_stored(self): + """Custom dataset should be stored.""" + test_data = [{"input": "test1"}, {"input": "test2"}] + config = BenchmarkConfig(dataset=test_data) + assert config.dataset == test_data + + def test_custom_generator_stored(self): + """Custom generator function should be stored.""" + def my_generator(): + return {"input": "generated"} + config = BenchmarkConfig(generator=my_generator) + assert config.generator == my_generator + + def test_custom_runs_stored(self): + """Custom runs value should be stored.""" + config = BenchmarkConfig(runs=16) + assert config.runs == 16 + + def test_custom_concurrency_stored(self): + """Custom concurrency value should be stored.""" + config = BenchmarkConfig(concurrency=5) + assert config.concurrency == 5 + + +class TestHandlerConfig: + """Tests for HandlerConfig dataclass.""" + + def test_route_required_and_stored(self): + """Route should be stored correctly.""" + config = HandlerConfig(route="/inference") + assert config.route == "/inference" + + def test_default_allow_parallel_requests(self): + """Default allow_parallel_requests should be False.""" + config = HandlerConfig(route="/") + assert config.allow_parallel_requests is False + + def test_default_max_queue_time(self): + """Default max_queue_time should be 30.0.""" + config = HandlerConfig(route="/") + assert config.max_queue_time == 30.0 + + def test_default_benchmark_config_none(self): + """Default benchmark_config should be None.""" + config = HandlerConfig(route="/") + assert config.benchmark_config is None + + def test_default_handler_class_none(self): + """Default handler_class should be None.""" + config = HandlerConfig(route="/") + assert config.handler_class is None + + def test_default_payload_class_none(self): + """Default payload_class should be None.""" + config = HandlerConfig(route="/") + assert config.payload_class is None + + def test_default_request_parser_none(self): + """Default request_parser should be None.""" + config = HandlerConfig(route="/") + assert config.request_parser is None + + def test_default_response_generator_none(self): + """Default response_generator should be None.""" + config = HandlerConfig(route="/") + assert config.response_generator is None + + def test_default_workload_calculator_none(self): + """Default workload_calculator should be None.""" + config = HandlerConfig(route="/") + assert config.workload_calculator is None + + def test_custom_values_stored(self): + """Custom values should be stored correctly.""" + benchmark = BenchmarkConfig(runs=5) + config = HandlerConfig( + route="/custom", + allow_parallel_requests=True, + max_queue_time=60.0, + benchmark_config=benchmark + ) + assert config.route == "/custom" + assert config.allow_parallel_requests is True + assert config.max_queue_time == 60.0 + assert config.benchmark_config == benchmark + + +class TestWorkerConfig: + """Tests for WorkerConfig dataclass.""" + + def test_default_handlers_empty_list(self): + """Default handlers should be empty list.""" + config = WorkerConfig() + assert config.handlers == [] + + def test_default_log_action_config(self): + """Default log_action_config should be LogActionConfig instance.""" + config = WorkerConfig() + assert isinstance(config.log_action_config, LogActionConfig) + + def test_default_model_server_url_none(self): + """Default model_server_url should be None.""" + config = WorkerConfig() + assert config.model_server_url is None + + def test_default_model_server_port_none(self): + """Default model_server_port should be None.""" + config = WorkerConfig() + assert config.model_server_port is None + + def test_default_model_log_file_none(self): + """Default model_log_file should be None.""" + config = WorkerConfig() + assert config.model_log_file is None + + def test_default_model_healthcheck_url_none(self): + """Default model_healthcheck_url should be None.""" + config = WorkerConfig() + assert config.model_healthcheck_url is None + + def test_custom_values_stored(self): + """Custom values should be stored correctly.""" + handler = HandlerConfig(route="/inference") + log_config = LogActionConfig(on_load=["Ready"]) + config = WorkerConfig( + model_server_url="http://localhost", + model_server_port=8000, + model_log_file="/var/log/model.log", + model_healthcheck_url="http://localhost:8000/health", + handlers=[handler], + log_action_config=log_config + ) + assert config.model_server_url == "http://localhost" + assert config.model_server_port == 8000 + assert config.model_log_file == "/var/log/model.log" + assert config.model_healthcheck_url == "http://localhost:8000/health" + assert config.handlers == [handler] + assert config.log_action_config == log_config + + +class TestEndpointHandlerFactory: + """Tests for EndpointHandlerFactory class.""" + + def test_default_handler_created_when_no_handlers(self): + """Factory should create a default handler at '/' when no handlers configured.""" + config = WorkerConfig( + model_server_url="http://localhost", + model_server_port=8000, + ) + factory = EndpointHandlerFactory(config) + assert factory.has_handlers() is True + assert "/" in factory.get_all_handlers() + + def test_get_handler_returns_handler_for_route(self): + """get_handler should return handler for registered route.""" + config = WorkerConfig( + model_server_url="http://localhost", + model_server_port=8000, + ) + factory = EndpointHandlerFactory(config) + handler = factory.get_handler("/") + assert handler is not None + + def test_get_handler_returns_none_for_unknown_route(self): + """get_handler should return None for unregistered route.""" + config = WorkerConfig( + model_server_url="http://localhost", + model_server_port=8000, + ) + factory = EndpointHandlerFactory(config) + assert factory.get_handler("/unknown") is None + + def test_get_all_handlers_returns_copy(self): + """get_all_handlers should return a copy of handlers dict.""" + config = WorkerConfig( + model_server_url="http://localhost", + model_server_port=8000, + ) + factory = EndpointHandlerFactory(config) + handlers = factory.get_all_handlers() + # Modifying returned dict shouldn't affect internal state + handlers["/test"] = "fake_handler" + assert "/test" not in factory.get_all_handlers() + + def test_has_handlers_returns_true_when_handlers_exist(self): + """has_handlers should return True when handlers are registered.""" + config = WorkerConfig( + model_server_url="http://localhost", + model_server_port=8000, + ) + factory = EndpointHandlerFactory(config) + assert factory.has_handlers() is True + + def test_model_server_base_url_property(self): + """model_server_base_url should return formatted URL with port.""" + config = WorkerConfig( + model_server_url="http://localhost", + model_server_port=8000, + ) + factory = EndpointHandlerFactory(config) + assert factory.model_server_base_url == "http://localhost:8000" + + def test_builds_handlers_from_config_list(self): + """Factory should build handlers from config handlers list.""" + handler_configs = [ + HandlerConfig(route="/inference"), + HandlerConfig(route="/health"), + ] + config = WorkerConfig( + model_server_url="http://localhost", + model_server_port=8000, + handlers=handler_configs, + ) + factory = EndpointHandlerFactory(config) + handlers = factory.get_all_handlers() + assert "/inference" in handlers + assert "/health" in handlers + + def test_handler_uses_config_allow_parallel_requests(self): + """Created handler should use allow_parallel_requests from config.""" + handler_config = HandlerConfig( + route="/test", + allow_parallel_requests=True, + ) + config = WorkerConfig( + model_server_url="http://localhost", + model_server_port=8000, + handlers=[handler_config], + ) + factory = EndpointHandlerFactory(config) + handler = factory.get_handler("/test") + assert handler.allow_parallel_requests is True + + def test_handler_uses_config_max_queue_time(self): + """Created handler should use max_queue_time from config.""" + handler_config = HandlerConfig( + route="/test", + max_queue_time=60.0, + ) + config = WorkerConfig( + model_server_url="http://localhost", + model_server_port=8000, + handlers=[handler_config], + ) + factory = EndpointHandlerFactory(config) + handler = factory.get_handler("/test") + assert handler.max_queue_time == 60.0 + + def test_handler_endpoint_property_returns_route(self): + """Handler's endpoint property should return the route.""" + handler_config = HandlerConfig(route="/inference") + config = WorkerConfig( + model_server_url="http://localhost", + model_server_port=8000, + handlers=[handler_config], + ) + factory = EndpointHandlerFactory(config) + handler = factory.get_handler("/inference") + assert handler.endpoint == "/inference" + + def test_get_benchmark_handler_returns_handler_with_benchmark(self): + """get_benchmark_handler should return handler with BenchmarkConfig.""" + benchmark = BenchmarkConfig( + dataset=[{"input": "test"}], + runs=4, + concurrency=2, + ) + handler_config = HandlerConfig( + route="/inference", + benchmark_config=benchmark, + ) + config = WorkerConfig( + model_server_url="http://localhost", + model_server_port=8000, + handlers=[handler_config], + ) + factory = EndpointHandlerFactory(config) + benchmark_handler = factory.get_benchmark_handler() + assert benchmark_handler is not None + assert benchmark_handler.has_benchmark is True + + def test_get_benchmark_handler_raises_when_missing(self): + """get_benchmark_handler should raise when no BenchmarkConfig exists.""" + handler_config = HandlerConfig(route="/test") + config = WorkerConfig( + model_server_url="http://localhost", + model_server_port=8000, + handlers=[handler_config], + ) + factory = EndpointHandlerFactory(config) + with pytest.raises(Exception, match="Missing EndpointHandler with BenchmarkConfig"): + factory.get_benchmark_handler() + + def test_get_benchmark_handler_raises_when_multiple(self): + """get_benchmark_handler should raise when multiple BenchmarkConfigs exist.""" + benchmark1 = BenchmarkConfig(dataset=[{"input": "test1"}]) + benchmark2 = BenchmarkConfig(dataset=[{"input": "test2"}]) + handler_configs = [ + HandlerConfig(route="/inference1", benchmark_config=benchmark1), + HandlerConfig(route="/inference2", benchmark_config=benchmark2), + ] + config = WorkerConfig( + model_server_url="http://localhost", + model_server_port=8000, + handlers=handler_configs, + ) + factory = EndpointHandlerFactory(config) + with pytest.raises(Exception, match="Cannot define BenchmarkConfig for more than one"): + factory.get_benchmark_handler() + + +class TestHandlerBenchmarkProperties: + """Tests for handler benchmark-related properties.""" + + def test_handler_has_benchmark_false_without_config(self): + """Handler should have has_benchmark=False without BenchmarkConfig.""" + handler_config = HandlerConfig(route="/test") + config = WorkerConfig( + model_server_url="http://localhost", + model_server_port=8000, + handlers=[handler_config], + ) + factory = EndpointHandlerFactory(config) + handler = factory.get_handler("/test") + assert handler.has_benchmark is False + + def test_handler_has_benchmark_true_with_config(self): + """Handler should have has_benchmark=True with BenchmarkConfig.""" + benchmark = BenchmarkConfig(dataset=[{"input": "test"}]) + handler_config = HandlerConfig( + route="/test", + benchmark_config=benchmark, + ) + config = WorkerConfig( + model_server_url="http://localhost", + model_server_port=8000, + handlers=[handler_config], + ) + factory = EndpointHandlerFactory(config) + handler = factory.get_handler("/test") + assert handler.has_benchmark is True + + def test_handler_benchmark_runs_from_config(self): + """Handler should use runs value from BenchmarkConfig.""" + benchmark = BenchmarkConfig(dataset=[{"input": "test"}], runs=16) + handler_config = HandlerConfig( + route="/test", + benchmark_config=benchmark, + ) + config = WorkerConfig( + model_server_url="http://localhost", + model_server_port=8000, + handlers=[handler_config], + ) + factory = EndpointHandlerFactory(config) + handler = factory.get_handler("/test") + assert handler.benchmark_runs == 16 + + def test_handler_concurrency_from_config(self): + """Handler should use concurrency value from BenchmarkConfig.""" + benchmark = BenchmarkConfig(dataset=[{"input": "test"}], concurrency=5) + handler_config = HandlerConfig( + route="/test", + benchmark_config=benchmark, + ) + config = WorkerConfig( + model_server_url="http://localhost", + model_server_port=8000, + handlers=[handler_config], + ) + factory = EndpointHandlerFactory(config) + handler = factory.get_handler("/test") + assert handler.concurrency == 5 diff --git a/tests/unit/test_timezone.py b/tests/unit/test_timezone.py new file mode 100644 index 00000000..9ea33596 --- /dev/null +++ b/tests/unit/test_timezone.py @@ -0,0 +1,125 @@ +"""Unit tests for timezone handling functions. + +TEST-03: Unit tests for timezone handling functions. + +These tests verify that all timezone conversions produce correct UTC results +regardless of the local system timezone. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import calendar +from datetime import datetime, timezone +import pytest + + +class TestStringToUnixEpoch: + """Tests for string_to_unix_epoch() function.""" + + def test_date_string_to_utc_epoch(self): + """Date string converts to correct UTC epoch timestamp.""" + from vast import string_to_unix_epoch + + # 01/15/2025 00:00:00 UTC = 1736899200 + result = string_to_unix_epoch("01/15/2025") + expected = calendar.timegm((2025, 1, 15, 0, 0, 0, 0, 0, 0)) + + assert expected == 1736899200, f"Sanity check failed: expected 1736899200, got {expected}" + assert result == expected + + def test_numeric_string_passthrough(self): + """Numeric string is converted to float and returned.""" + from vast import string_to_unix_epoch + + assert string_to_unix_epoch("1736899200") == 1736899200.0 + assert string_to_unix_epoch("1736899200.5") == 1736899200.5 + + def test_none_returns_none(self): + """None input returns None.""" + from vast import string_to_unix_epoch + + assert string_to_unix_epoch(None) is None + + def test_empty_string_raises_value_error(self): + """Empty string raises ValueError (not numeric, can't parse as date).""" + from vast import string_to_unix_epoch + + # Empty string is not a valid float and not a valid date format + with pytest.raises(ValueError): + string_to_unix_epoch("") + + def test_various_date_formats(self): + """Various date formats parse correctly.""" + from vast import string_to_unix_epoch + + # Test MM/DD/YYYY format + result1 = string_to_unix_epoch("12/31/2024") + expected1 = calendar.timegm((2024, 12, 31, 0, 0, 0, 0, 0, 0)) + assert result1 == expected1 + + # Test boundary dates + result2 = string_to_unix_epoch("01/01/2025") + expected2 = calendar.timegm((2025, 1, 1, 0, 0, 0, 0, 0, 0)) + assert result2 == expected2 + + +class TestFromTimestampUtc: + """Tests for UTC-aware datetime.fromtimestamp usage.""" + + def test_fromtimestamp_with_utc(self): + """datetime.fromtimestamp with UTC timezone gives correct result.""" + # This tests the pattern used after the fix fix + epoch = 1736899200 # 01/15/2025 00:00:00 UTC + + # The correct pattern (after fix) + dt_utc = datetime.fromtimestamp(epoch, tz=timezone.utc) + + assert dt_utc.year == 2025 + assert dt_utc.month == 1 + assert dt_utc.day == 15 + assert dt_utc.hour == 0 + assert dt_utc.minute == 0 + assert dt_utc.second == 0 + assert dt_utc.tzinfo == timezone.utc + + def test_calendar_timegm_inverse(self): + """calendar.timegm is the inverse of datetime.fromtimestamp(tz=utc).""" + original_epoch = 1736899200 + + # Convert to datetime + dt = datetime.fromtimestamp(original_epoch, tz=timezone.utc) + + # Convert back to epoch + timetuple = dt.timetuple() + recovered_epoch = calendar.timegm(timetuple) + + assert recovered_epoch == original_epoch + + +class TestTimezoneConsistency: + """Tests verifying timezone handling is consistent across functions.""" + + def test_known_epoch_value(self): + """Test against a known epoch value that's verifiable.""" + # Unix epoch 0 is January 1, 1970, 00:00:00 UTC + epoch_zero = 0 + + dt = datetime.fromtimestamp(epoch_zero, tz=timezone.utc) + + assert dt.year == 1970 + assert dt.month == 1 + assert dt.day == 1 + assert dt.hour == 0 + + def test_y2k_epoch(self): + """Test Y2K timestamp (known reference point).""" + # January 1, 2000, 00:00:00 UTC = 946684800 + y2k_epoch = 946684800 + + dt = datetime.fromtimestamp(y2k_epoch, tz=timezone.utc) + + assert dt.year == 2000 + assert dt.month == 1 + assert dt.day == 1 + assert dt.hour == 0 diff --git a/tests/unit/test_vastai_base.py b/tests/unit/test_vastai_base.py new file mode 100644 index 00000000..1b5792b4 --- /dev/null +++ b/tests/unit/test_vastai_base.py @@ -0,0 +1,445 @@ +"""Unit tests for VastAIBase class method signatures and structure. + +This module verifies that VastAIBase methods exist with correct signatures, +ensuring backwards compatibility and API contract stability. +""" + +import pytest +import inspect +from abc import ABC +import sys +from pathlib import Path + +# Add vast-cli to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from vastai.vastai_base import VastAIBase + + +class TestVastAIBaseClassStructure: + """Verify VastAIBase class structure and attributes.""" + + def test_is_abstract_base_class(self): + """VastAIBase should inherit from ABC.""" + assert issubclass(VastAIBase, ABC) + + def test_has_docstring(self): + """VastAIBase should have a class docstring.""" + assert VastAIBase.__doc__ is not None + assert len(VastAIBase.__doc__) > 0 + + def test_docstring_mentions_sdk(self): + """Docstring should describe SDK purpose.""" + assert "SDK" in VastAIBase.__doc__ or "sdk" in VastAIBase.__doc__.lower() + + +class TestVastAIBaseInstanceMethodSignatures: + """Verify instance method signatures exist with correct parameters.""" + + def test_attach_ssh_signature(self): + """attach_ssh should have instance_id and ssh_key parameters.""" + sig = inspect.signature(VastAIBase.attach_ssh) + params = list(sig.parameters.keys()) + assert "self" in params + assert "instance_id" in params + assert "ssh_key" in params + + def test_cancel_copy_signature(self): + """cancel_copy should have dst parameter.""" + sig = inspect.signature(VastAIBase.cancel_copy) + params = list(sig.parameters.keys()) + assert "self" in params + assert "dst" in params + + def test_cancel_sync_signature(self): + """cancel_sync should have dst parameter.""" + sig = inspect.signature(VastAIBase.cancel_sync) + params = list(sig.parameters.keys()) + assert "self" in params + assert "dst" in params + + def test_change_bid_signature(self): + """change_bid should have id and optional price parameters.""" + sig = inspect.signature(VastAIBase.change_bid) + params = sig.parameters + assert "self" in params + assert "id" in params + assert "price" in params + assert params["price"].default is None + + def test_copy_signature(self): + """copy should have src, dst, and optional identity parameters.""" + sig = inspect.signature(VastAIBase.copy) + params = sig.parameters + assert "self" in params + assert "src" in params + assert "dst" in params + assert "identity" in params + + def test_create_instance_signature(self): + """create_instance should have id as int and disk with default 10.""" + sig = inspect.signature(VastAIBase.create_instance) + params = sig.parameters + assert "self" in params + assert "id" in params + assert "disk" in params + assert params["disk"].default == 10 + assert "image" in params + assert "ssh" in params + assert "jupyter" in params + + def test_create_workergroup_signature(self): + """create_workergroup should have expected parameters.""" + sig = inspect.signature(VastAIBase.create_workergroup) + params = sig.parameters + assert "self" in params + assert "test_workers" in params + assert params["test_workers"].default == 3 + assert "gpu_ram" in params + assert "template_hash" in params + assert "endpoint_name" in params + + def test_create_endpoint_signature(self): + """create_endpoint should have min_load, target_util, cold_mult defaults.""" + sig = inspect.signature(VastAIBase.create_endpoint) + params = sig.parameters + assert "self" in params + assert "min_load" in params + assert params["min_load"].default == 0.0 + assert "target_util" in params + assert params["target_util"].default == 0.9 + assert "cold_mult" in params + assert params["cold_mult"].default == 2.5 + + def test_destroy_instance_signature(self): + """destroy_instance should have id parameter.""" + sig = inspect.signature(VastAIBase.destroy_instance) + params = list(sig.parameters.keys()) + assert "self" in params + assert "id" in params + + def test_destroy_instances_signature(self): + """destroy_instances should have ids list parameter.""" + sig = inspect.signature(VastAIBase.destroy_instances) + params = list(sig.parameters.keys()) + assert "self" in params + assert "ids" in params + + def test_execute_signature(self): + """execute should have id and COMMAND parameters.""" + sig = inspect.signature(VastAIBase.execute) + params = list(sig.parameters.keys()) + assert "self" in params + assert "id" in params + assert "COMMAND" in params + + def test_search_offers_signature(self): + """search_offers should have type, query, limit and other parameters.""" + sig = inspect.signature(VastAIBase.search_offers) + params = sig.parameters + assert "self" in params + assert "type" in params + assert "query" in params + assert "limit" in params + assert "no_default" in params + assert params["no_default"].default is False + + def test_show_instances_signature(self): + """show_instances should have quiet parameter with False default.""" + sig = inspect.signature(VastAIBase.show_instances) + params = sig.parameters + assert "self" in params + assert "quiet" in params + assert params["quiet"].default is False + + def test_logs_signature(self): + """logs should have INSTANCE_ID and optional tail parameters.""" + sig = inspect.signature(VastAIBase.logs) + params = sig.parameters + assert "self" in params + assert "INSTANCE_ID" in params + assert "tail" in params + + +class TestVastAIBaseVolumeMethodSignatures: + """Verify volume method signatures.""" + + def test_clone_volume_signature(self): + """clone_volume should have source, dest, size, disable_compression.""" + sig = inspect.signature(VastAIBase.clone_volume) + params = sig.parameters + assert "self" in params + assert "source" in params + assert "dest" in params + assert "size" in params + assert "disable_compression" in params + assert params["disable_compression"].default is False + + def test_create_volume_signature(self): + """create_volume should have id, size with default 15, and optional name.""" + sig = inspect.signature(VastAIBase.create_volume) + params = sig.parameters + assert "self" in params + assert "id" in params + assert "size" in params + assert params["size"].default == 15 + assert "name" in params + + def test_delete_volume_signature(self): + """delete_volume should have id parameter.""" + sig = inspect.signature(VastAIBase.delete_volume) + params = list(sig.parameters.keys()) + assert "self" in params + assert "id" in params + + def test_search_volumes_signature(self): + """search_volumes should have query, no_default, limit, storage, order.""" + sig = inspect.signature(VastAIBase.search_volumes) + params = sig.parameters + assert "self" in params + assert "query" in params + assert "no_default" in params + assert "limit" in params + assert "storage" in params + assert params["storage"].default == 1.0 + assert "order" in params + assert params["order"].default == "score-" + + def test_show_volumes_signature(self): + """show_volumes should have type parameter with 'all' default.""" + sig = inspect.signature(VastAIBase.show_volumes) + params = sig.parameters + assert "self" in params + assert "type" in params + assert params["type"].default == "all" + + +class TestVastAIBaseClusterMethodSignatures: + """Verify cluster method signatures.""" + + def test_create_cluster_signature(self): + """create_cluster should have subnet and manager_id parameters.""" + sig = inspect.signature(VastAIBase.create_cluster) + params = list(sig.parameters.keys()) + assert "self" in params + assert "subnet" in params + assert "manager_id" in params + + def test_delete_cluster_signature(self): + """delete_cluster should have cluster_id parameter.""" + sig = inspect.signature(VastAIBase.delete_cluster) + params = list(sig.parameters.keys()) + assert "self" in params + assert "cluster_id" in params + + def test_join_cluster_signature(self): + """join_cluster should have cluster_id and machine_ids parameters.""" + sig = inspect.signature(VastAIBase.join_cluster) + params = list(sig.parameters.keys()) + assert "self" in params + assert "cluster_id" in params + assert "machine_ids" in params + + def test_show_clusters_signature(self): + """show_clusters should exist with self parameter only.""" + sig = inspect.signature(VastAIBase.show_clusters) + params = list(sig.parameters.keys()) + assert "self" in params + + +class TestVastAIBaseOverlayMethodSignatures: + """Verify overlay method signatures.""" + + def test_create_overlay_signature(self): + """create_overlay should have cluster_id and name parameters.""" + sig = inspect.signature(VastAIBase.create_overlay) + params = list(sig.parameters.keys()) + assert "self" in params + assert "cluster_id" in params + assert "name" in params + + def test_delete_overlay_signature(self): + """delete_overlay should have overlay_identifier parameter.""" + sig = inspect.signature(VastAIBase.delete_overlay) + params = list(sig.parameters.keys()) + assert "self" in params + assert "overlay_identifier" in params + + def test_join_overlay_signature(self): + """join_overlay should have name and instance_id parameters.""" + sig = inspect.signature(VastAIBase.join_overlay) + params = list(sig.parameters.keys()) + assert "self" in params + assert "name" in params + assert "instance_id" in params + + def test_show_overlays_signature(self): + """show_overlays should exist with self parameter only.""" + sig = inspect.signature(VastAIBase.show_overlays) + params = list(sig.parameters.keys()) + assert "self" in params + + +class TestVastAIBaseEnvVarMethodSignatures: + """Verify environment variable method signatures.""" + + def test_create_env_var_signature(self): + """create_env_var should have name and value parameters.""" + sig = inspect.signature(VastAIBase.create_env_var) + params = list(sig.parameters.keys()) + assert "self" in params + assert "name" in params + assert "value" in params + + def test_delete_env_var_signature(self): + """delete_env_var should have name parameter.""" + sig = inspect.signature(VastAIBase.delete_env_var) + params = list(sig.parameters.keys()) + assert "self" in params + assert "name" in params + + def test_update_env_var_signature(self): + """update_env_var should have name and value parameters.""" + sig = inspect.signature(VastAIBase.update_env_var) + params = list(sig.parameters.keys()) + assert "self" in params + assert "name" in params + assert "value" in params + + def test_show_env_vars_signature(self): + """show_env_vars should have show_values parameter with False default.""" + sig = inspect.signature(VastAIBase.show_env_vars) + params = sig.parameters + assert "self" in params + assert "show_values" in params + assert params["show_values"].default is False + + +class TestVastAIBaseBackwardsCompatibility: + """Verify backwards compatibility aliases exist and point to correct methods.""" + + def test_create_autogroup_is_create_workergroup(self): + """create_autogroup should be an alias for create_workergroup.""" + assert VastAIBase.create_autogroup is VastAIBase.create_workergroup + + def test_delete_autoscaler_is_delete_workergroup(self): + """delete_autoscaler should be an alias for delete_workergroup.""" + assert VastAIBase.delete_autoscaler is VastAIBase.delete_workergroup + + def test_update_autoscaler_is_update_workergroup(self): + """update_autoscaler should be an alias for update_workergroup.""" + assert VastAIBase.update_autoscaler is VastAIBase.update_workergroup + + def test_show_autoscalers_is_show_workergroups(self): + """show_autoscalers should be an alias for show_workergroups.""" + assert VastAIBase.show_autoscalers is VastAIBase.show_workergroups + + +class TestVastAIBaseMethodCount: + """Verify expected number of public methods exist.""" + + def test_has_many_methods(self): + """VastAIBase should have 100+ public methods.""" + methods = [ + m for m in dir(VastAIBase) + if not m.startswith('_') and callable(getattr(VastAIBase, m)) + ] + assert len(methods) >= 100, f"Expected 100+ methods, got {len(methods)}" + + def test_no_missing_core_methods(self): + """Essential methods should exist.""" + core_methods = [ + 'attach_ssh', 'cancel_copy', 'change_bid', 'copy', + 'create_instance', 'destroy_instance', 'execute', + 'search_offers', 'show_instances', 'logs', + 'create_volume', 'delete_volume', 'show_volumes', + 'create_cluster', 'show_clusters', + 'create_overlay', 'show_overlays', + 'create_env_var', 'show_env_vars', + ] + for method in core_methods: + assert hasattr(VastAIBase, method), f"Missing method: {method}" + assert callable(getattr(VastAIBase, method)), f"Not callable: {method}" + + +class TestVastAIBaseMethodDocstrings: + """Verify methods have docstrings for IDE autocomplete support.""" + + def test_attach_ssh_has_docstring(self): + """attach_ssh should have a docstring.""" + assert VastAIBase.attach_ssh.__doc__ is not None + + def test_create_instance_has_docstring(self): + """create_instance should have a docstring.""" + assert VastAIBase.create_instance.__doc__ is not None + + def test_search_offers_has_docstring(self): + """search_offers should have a docstring.""" + assert VastAIBase.search_offers.__doc__ is not None + + def test_create_volume_has_docstring(self): + """create_volume should have a docstring.""" + assert VastAIBase.create_volume.__doc__ is not None + + def test_create_cluster_has_docstring(self): + """create_cluster should have a docstring.""" + assert VastAIBase.create_cluster.__doc__ is not None + + def test_show_invoices_v1_has_docstring(self): + """show_invoices_v1 should have a docstring.""" + assert VastAIBase.show_invoices_v1.__doc__ is not None + + +class TestVastAIBaseAdditionalMethodSignatures: + """Additional method signature tests for coverage.""" + + def test_take_snapshot_signature(self): + """take_snapshot should have instance_id, repo, and other parameters.""" + sig = inspect.signature(VastAIBase.take_snapshot) + params = sig.parameters + assert "self" in params + assert "instance_id" in params + assert "repo" in params + assert "container_registry" in params + assert params["container_registry"].default == "docker.io" + + def test_update_instance_signature(self): + """update_instance should have id and optional template parameters.""" + sig = inspect.signature(VastAIBase.update_instance) + params = sig.parameters + assert "self" in params + assert "id" in params + assert "template_id" in params + assert "template_hash_id" in params + + def test_vm_copy_signature(self): + """vm_copy should have src and dst parameters.""" + sig = inspect.signature(VastAIBase.vm_copy) + params = list(sig.parameters.keys()) + assert "self" in params + assert "src" in params + assert "dst" in params + + def test_get_endpt_logs_signature(self): + """get_endpt_logs should have id, level, and tail parameters.""" + sig = inspect.signature(VastAIBase.get_endpt_logs) + params = sig.parameters + assert "self" in params + assert "id" in params + assert "level" in params + assert params["level"].default == 1 + assert "tail" in params + + def test_show_invoices_v1_signature(self): + """show_invoices_v1 should have extensive filtering parameters.""" + sig = inspect.signature(VastAIBase.show_invoices_v1) + params = sig.parameters + assert "self" in params + assert "invoices" in params + assert "charges" in params + assert "invoice_type" in params + assert "charge_type" in params + assert "limit" in params + assert params["limit"].default == 20 + assert "format" in params + assert params["format"].default == "table" diff --git a/vast.py b/vast.py index 4d737770..ded88cc7 100755 --- a/vast.py +++ b/vast.py @@ -9,7 +9,8 @@ import argparse import os import time -from typing import Dict, List, Tuple, Optional +import calendar +from typing import Any, Dict, List, Tuple, Optional from datetime import date, datetime, timedelta, timezone import hashlib import math @@ -21,6 +22,7 @@ from time import sleep from subprocess import PIPE import urllib3 +import ssl import atexit from contextlib import redirect_stdout, redirect_stderr from io import StringIO @@ -43,7 +45,7 @@ try: import argcomplete TABCOMPLETE = True -except: +except ImportError: # No tab-completion for you pass @@ -57,15 +59,7 @@ except ImportError: from urllib.parse import quote_plus # Python 3+ -try: - JSONDecodeError = json.JSONDecodeError -except AttributeError: - JSONDecodeError = ValueError - -try: - input = raw_input -except NameError: - pass +JSONDecodeError = json.JSONDecodeError #server_url_default = "https://vast.ai" @@ -80,6 +74,10 @@ format="%(levelname)s - %(message)s" ) +DEFAULT_TIMEOUT = 30 # seconds -- normal API calls +LONG_TIMEOUT = 120 # seconds -- file operations, large queries +RETRYABLE_STATUS_CODES = {429, 502, 503, 504} + def parse_version(version: str) -> tuple[int, ...]: parts = version.split(".") @@ -117,14 +115,15 @@ def is_pip_package(): except Exception: return False -def get_update_command(stable_version: str) -> str: +def get_update_command(stable_version: str) -> list: if is_pip_package(): + cmd = [sys.executable, "-m", "pip", "install", "--force-reinstall", "--no-cache-dir"] if "test.pypi.org" in PYPI_BASE_PATH: - return f"{sys.executable} -m pip install --force-reinstall --no-cache-dir -i {PYPI_BASE_PATH} vastai=={stable_version}" - else: - return f"{sys.executable} -m pip install --force-reinstall --no-cache-dir vastai=={stable_version}" + cmd.extend(["-i", PYPI_BASE_PATH]) + cmd.append(f"vastai=={stable_version}") + return cmd else: - return f"git fetch --all --tags --prune && git checkout tags/v{stable_version}" + return ["git", "fetch", "--all", "--tags", "--prune"] def get_local_version(): @@ -135,7 +134,7 @@ def get_local_version(): def get_project_data(project_name: str) -> dict[str, dict[str, str]]: url = PYPI_BASE_PATH + f"/pypi/{project_name}/json" - response = requests.get(url, headers={"Accept": "application/json"}) + response = requests.get(url, headers={"Accept": "application/json"}, timeout=10) # this will raise for HTTP status 4xx and 5xx response.raise_for_status() @@ -184,12 +183,20 @@ def check_for_update(): print("Updating...") _ = subprocess.run( update_command, - shell=True, check=True, text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) + if not is_pip_package(): + # git case: need a second command to checkout the tag + _ = subprocess.run( + ["git", "checkout", f"tags/v{pypi_version}"], + check=True, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) print("Update completed successfully!\nAttempt to run your command again!") sys.exit(0) @@ -197,12 +204,6 @@ def check_for_update(): APP_NAME = "vastai" VERSION = get_local_version() -# define emoji support and fallbacks -_HAS_EMOJI = sys.stdout.encoding and 'utf' in sys.stdout.encoding.lower() -SUCCESS = "✅" if _HAS_EMOJI else "[OK]" -WARN = "⚠️" if _HAS_EMOJI else "[!]" -FAIL = "❌" if _HAS_EMOJI else "[X]" -INFO = "ℹ️" if _HAS_EMOJI else "[i]" try: # Although xdg-base-dirs is the newer name, there's @@ -218,7 +219,7 @@ def check_for_update(): 'temp': xdg.xdg_cache_home() } -except: +except (ImportError, KeyError, OSError): # Reasonable defaults. DIRS = { 'config': os.path.join(os.getenv('HOME'), '.config'), @@ -237,6 +238,13 @@ def check_for_update(): APIKEY_FILE_HOME = os.path.expanduser("~/.vast_api_key") # Legacy TFAKEY_FILE = os.path.join(DIRS['config'], "vast_tfa_key") +# Emoji support with fallbacks for terminals that don't support Unicode +_HAS_EMOJI = sys.stdout.encoding and 'utf' in sys.stdout.encoding.lower() +SUCCESS = "\u2705" if _HAS_EMOJI else "[OK]" +WARN = "\u26a0\ufe0f" if _HAS_EMOJI else "[!]" +FAIL = "\u274c" if _HAS_EMOJI else "[X]" +INFO = "\u2139\ufe0f" if _HAS_EMOJI else "[i]" + if not os.path.exists(APIKEY_FILE) and os.path.exists(APIKEY_FILE_HOME): #print(f'copying key from {APIKEY_FILE_HOME} -> {APIKEY_FILE}') shutil.copyfile(APIKEY_FILE_HOME, APIKEY_FILE) @@ -287,11 +295,11 @@ def string_to_unix_epoch(date_string): except ValueError: # If not, parse it as a date string date_object = datetime.strptime(date_string, "%m/%d/%Y") - return time.mktime(date_object.timetuple()) + return calendar.timegm(date_object.timetuple()) def unix_to_readable(ts): # ts: integer or float, Unix timestamp - return datetime.fromtimestamp(ts).strftime('%H:%M:%S|%h-%d-%Y') + return datetime.fromtimestamp(ts, tz=timezone.utc).strftime('%H:%M:%S|%h-%d-%Y') def fix_date_fields(query: Dict[str, Dict], date_fields: List[str]): """Takes in a query and date fields to correct and returns query with appropriate epoch dates""" @@ -330,18 +338,13 @@ def __nonzero__(self): def append(self, x): self.l.append(x) -def http_request(verb, args, req_url, headers: dict[str, str] | None = None, json_data = None): +def http_request(verb, args, req_url, headers: dict[str, str] | None = None, json = None, timeout=DEFAULT_TIMEOUT): t = 0.15 + r = None for i in range(0, args.retry): - req = requests.Request(method=verb, url=req_url, headers=headers, json=json_data) + req = requests.Request(method=verb, url=req_url, headers=headers, json=json) session = requests.Session() prep = session.prepare_request(req) - if args.explain: - print(f"\n{INFO} Prepared Request:") - print(f"{prep.method} {prep.url}") - print(f"Headers: {json.dumps(headers, indent=1)}") - print(f"Body: {json.dumps(json_data, indent=1)}" + "\n" + "_"*100 + "\n") - if ARGS.curl: as_curl = curlify.to_curl(prep) simple = re.sub(r" -H '[^']*'", '', as_curl) @@ -352,26 +355,42 @@ def http_request(verb, args, req_url, headers: dict[str, str] | None = None, jso print("\n" + ' \\\n '.join(parts).strip() + "\n") sys.exit(0) else: - r = session.send(prep) + try: + r = session.send(prep, timeout=timeout) + except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e: + if i < args.retry - 1: + time.sleep(t) + t *= 1.5 + continue + raise + except requests.exceptions.RequestException as e: + # Non-retryable request errors (e.g., InvalidURL) + raise - if (r.status_code == 429): + if r.status_code in RETRYABLE_STATUS_CODES: time.sleep(t) t *= 1.5 else: break return r -def http_get(args, req_url, headers = None, json = None): - return http_request('GET', args, req_url, headers, json) +def http_get(args, req_url, headers=None, json=None, timeout=DEFAULT_TIMEOUT): + return http_request('GET', args, req_url, headers, json, timeout=timeout) -def http_put(args, req_url, headers = None, json = {}): - return http_request('PUT', args, req_url, headers, json) +def http_put(args, req_url, headers=None, json=None, timeout=DEFAULT_TIMEOUT): + if json is None: + json = {} + return http_request('PUT', args, req_url, headers, json, timeout=timeout) -def http_post(args, req_url, headers = None, json={}): - return http_request('POST', args, req_url, headers, json) +def http_post(args, req_url, headers=None, json=None, timeout=DEFAULT_TIMEOUT): + if json is None: + json = {} + return http_request('POST', args, req_url, headers, json, timeout=timeout) -def http_del(args, req_url, headers = None, json={}): - return http_request('DELETE', args, req_url, headers, json) +def http_del(args, req_url, headers=None, json=None, timeout=DEFAULT_TIMEOUT): + if json is None: + json = {} + return http_request('DELETE', args, req_url, headers, json, timeout=timeout) def load_permissions_from_file(file_path): @@ -630,7 +649,7 @@ def apiheaders(args: argparse.Namespace) -> Dict: return result -def deindent(message: str, add_separator: bool = True) -> str: +def deindent(message: str) -> str: """ Deindent a quoted string. Scans message and finds the smallest number of whitespace characters in any line and removes that many from the start of every line. @@ -642,13 +661,141 @@ def deindent(message: str, add_separator: bool = True) -> str: indents = [len(x) for x in re.findall("^ *(?=[^ ])", message, re.MULTILINE) if len(x)] a = min(indents) message = re.sub(r"^ {," + str(a) + "}", "", message, flags=re.MULTILINE) - if add_separator: - # For help epilogs - cleanly separating extra help from options - line_width = min(150, shutil.get_terminal_size((80, 20)).columns) - message = "_"*line_width + "\n"*2 + message.strip() + "\n" + "_"*line_width return message.strip() +def api_call( + args: argparse.Namespace, + method: str, + path: str, + *, + json_body: dict[str, Any] | None = None, + query_args: dict[str, Any] | None = None, +) -> dict[str, Any] | list[dict[str, Any]] | None: + """Centralized API call: URL construction + HTTP dispatch + status check. + + Args: + args: argparse.Namespace with url, api_key, explain, raw, retry, curl. + method: HTTP method string ("GET", "POST", "PUT", "DELETE"). + path: API path (e.g., "/instances/", "/auth/apikeys/{id}/"). + json_body: Optional dict for request body (POST/PUT/DELETE). + query_args: Optional dict for URL query parameters. + + Returns: + Parsed JSON response (dict or list), or None for empty responses. + + Raises: + requests.exceptions.HTTPError: On non-2xx status codes. + """ + url = apiurl(args, path, query_args) + dispatch = { + "GET": http_get, + "POST": http_post, + "PUT": http_put, + "DELETE": http_del, + } + http_fn = dispatch[method] + + if method == "GET": + r = http_fn(args, url, headers=headers, json=json_body) + else: + r = http_fn(args, url, headers=headers, json=json_body if json_body is not None else {}) + + r.raise_for_status() + + if r.content: + try: + return r.json() + except JSONDecodeError: + return {"_raw_text": r.text} + return None + + +def output_result( + args: argparse.Namespace, + data: list[dict[str, Any]] | dict[str, Any], + fields: list[tuple[str, str, str]] | None = None, +) -> list[dict[str, Any]] | dict[str, Any] | None: + """Unified output handler for command results. + + In raw mode: returns data for main() to serialize as JSON. + In table mode: calls display_table() if fields are provided. + In JSON mode: prints formatted JSON (when no fields defined). + + Args: + args: argparse.Namespace with raw flag. + data: The response data (dict, list, or None). + fields: Optional tuple of field definitions for display_table(). + + Returns: + data if in raw mode, None otherwise. + """ + if args.raw: + return data + if data is None: + return None + if fields: + rows = data if isinstance(data, list) else [data] + display_table(rows, fields) + else: + print(json.dumps(data, indent=1, sort_keys=True)) + return None + + +def error_output( + args: argparse.Namespace, + status_code: int, + message: str, + *, + detail: str | None = None, +) -> None: + """Output an error in the appropriate format for the current mode. + + In raw mode: prints JSON error object to stderr. + In non-raw mode: prints human-readable error to stderr. + + Args: + args: argparse.Namespace with raw flag. + status_code: HTTP status code or error code. + message: Error message string. + detail: Optional additional detail string. + """ + if getattr(args, 'raw', False): + error = {"error": True, "status_code": status_code, "msg": message} + if detail: + error["detail"] = detail + print(json.dumps(error), file=sys.stderr) + else: + print(f"failed with error {status_code}: {message}", file=sys.stderr) + + +def require_id(args: argparse.Namespace, field: str = "id") -> int | str: + """Extract and validate an ID argument. + + Args: + args: argparse.Namespace containing the ID field. + field: Name of the attribute on args (default "id"). + + Returns: + The value of the requested field. + + Raises: + SystemExit: If the field is None or missing. + """ + val = getattr(args, field, None) + if val is None: + print(f"Error: {field} is required", file=sys.stderr) + raise SystemExit(1) + return val + + +# Field definition tuples: (key, display_name, format_string, converter_or_None, left_justify) +# key: API response dict key +# display_name: Column header in table output +# format_string: Python format spec (e.g., ">8", "<16", ">10.4f") +# converter_or_None: Lambda to transform value, or None for raw value +# left_justify: Boolean, True for left-aligned columns + # These are the fields that are displayed when a search is run displayable_fields = ( # ("bw_nvlink", "Bandwidth NVLink", "{}", None, True), @@ -873,8 +1020,8 @@ def deindent(message: str, add_separator: bool = True) -> str: # These fields are displayed when you do 'show maints' maintenance_fields = ( ("machine_id", "Machine ID", "{}", None, True), - ("start_time", "Start (Date/Time)", "{}", lambda x: datetime.fromtimestamp(x).strftime('%Y-%m-%d/%H:%M'), True), - ("end_time", "End (Date/Time)", "{}", lambda x: datetime.fromtimestamp(x).strftime('%Y-%m-%d/%H:%M'), True), + ("start_time", "Start (Date/Time)", "{}", lambda x: datetime.fromtimestamp(x, tz=timezone.utc).strftime('%Y-%m-%d/%H:%M'), True), + ("end_time", "End (Date/Time)", "{}", lambda x: datetime.fromtimestamp(x, tz=timezone.utc).strftime('%Y-%m-%d/%H:%M'), True), ("duration_hours", "Duration (Hrs)", "{}", None, True), ("maintenance_category", "Category", "{}", None, True), ) @@ -899,8 +1046,8 @@ def deindent(message: str, add_separator: bool = True) -> str: ("id", "Scheduled Job ID", "{}", None, True), ("instance_id", "Instance ID", "{}", None, True), ("api_endpoint", "API Endpoint", "{}", None, True), - ("start_time", "Start (Date/Time in UTC)", "{}", lambda x: datetime.fromtimestamp(x).strftime('%Y-%m-%d/%H:%M'), True), - ("end_time", "End (Date/Time in UTC)", "{}", lambda x: datetime.fromtimestamp(x).strftime('%Y-%m-%d/%H:%M'), True), + ("start_time", "Start (Date/Time in UTC)", "{}", lambda x: datetime.fromtimestamp(x, tz=timezone.utc).strftime('%Y-%m-%d/%H:%M'), True), + ("end_time", "End (Date/Time in UTC)", "{}", lambda x: datetime.fromtimestamp(x, tz=timezone.utc).strftime('%Y-%m-%d/%H:%M'), True), ("day_of_the_week", "Day of the Week", "{}", None, True), ("hour_of_the_day", "Hour of the Day in UTC", "{}", None, True), ("min_of_the_hour", "Minute of the Hour", "{}", None, True), @@ -1105,13 +1252,16 @@ def parse_query(query_str: str, res: Dict = None, fields = {}, field_alias = {}, for field, op, _, value, _ in opts: value = value.strip(",[]") - v = res.setdefault(field, {}) op = op.strip() op_name = op_names.get(op) if field in field_alias: - res.pop(field) + old_field = field field = field_alias[field] + if old_field in res: + res[field] = res.pop(old_field) + + v = res.setdefault(field, {}) if (field == "driver_version") and ('.' in value): value = numeric_version(value) @@ -1163,23 +1313,35 @@ def parse_query(query_str: str, res: Dict = None, fields = {}, field_alias = {}, #print(res) return res -# ANSI escape codes for background/foreground colors -BG_DARK_GRAY = '\033[40m' # Dark gray background -BG_LIGHT_GRAY = '\033[48;5;240m' # Light gray background -FG_WHITE = '\033[97m' # Bright white text -BG_RESET = '\033[0m' # Reset all formatting -def display_table(rows: list, fields: Tuple, replace_spaces: bool = True, auto_width: bool = True) -> None: - """Basically takes a set of field names and rows containing the corresponding data and prints a nice tidy table - of it. +# ANSI color codes for table formatting +BG_DARK_GRAY = '\033[40m' # Dark gray background +BG_LIGHT_GRAY = '\033[48;5;240m' # Light gray background +FG_WHITE = '\033[97m' # Bright white text +BG_RESET = '\033[0m' # Reset all formatting - :param list rows: Each row is a dict with keys corresponding to the field names (first element) in the fields tuple. - :param Tuple fields: 5-tuple describing a field. First element is field name, second is human readable version, third is format string, fourth is a lambda function run on the data in that field, fifth is a bool determining text justification. True = left justify, False = right justify. Here is an example showing the tuples in action. +def display_table(rows: list, fields: Tuple, replace_spaces: bool = True, auto_width: bool = True) -> None: + """Display data as a formatted table with automatic column width management. - :rtype None: + Takes a set of field definitions and rows of data and prints a formatted table. + When auto_width is enabled, columns are grouped to fit within terminal width, + with alternating row colors for readability. - Example of 5-tuple: ("cpu_ram", "RAM", "{:0.1f}", lambda x: x / 1000, False) + Args: + rows: List of dicts with keys corresponding to field names in the fields tuple. + fields: Tuple of 5-tuples defining each column: + - field_name: API response dict key + - display_name: Column header text + - format_string: Python format spec (e.g., "{:0.1f}") + - converter: Lambda to transform value, or None for raw value + - left_justify: Boolean, True for left-aligned columns + replace_spaces: If True, replace spaces with underscores in cell values. + auto_width: If True, automatically group columns to fit terminal width + with colored alternating rows. If False, print simple table. + + Example field tuple: + ("cpu_ram", "RAM", "{:0.1f}", lambda x: x / 1000, False) """ header = [name for _, name, _, _, _ in fields] out_rows = [header] @@ -1200,7 +1362,7 @@ def display_table(rows: list, fields: Tuple, replace_spaces: bool = True, auto_w idx = len(row) lengths[idx] = max(len(s), lengths[idx]) row.append(s) - + if auto_width: width = shutil.get_terminal_size((80, 20)).columns start_col_idxs = [0] @@ -1210,7 +1372,7 @@ def display_table(rows: list, fields: Tuple, replace_spaces: bool = True, auto_w if total_len > width: start_col_idxs.append(i) # index for the start of the next group total_len = l + 6 # l + 2 + the 4 from the initial length - + groups = {} for row in out_rows: grp_num = 0 @@ -1219,7 +1381,7 @@ def display_table(rows: list, fields: Tuple, replace_spaces: bool = True, auto_w end = start_col_idxs[i+1]-1 if i+1 < len(start_col_idxs) else len(lengths) groups.setdefault(grp_num, []).append(row[start:end]) grp_num += 1 - + for i, group in groups.items(): idx = start_col_idxs[i] group_lengths = lengths[idx:idx+len(group[0])] @@ -1289,7 +1451,7 @@ def parse_vast_url(url_str): try: instance_id = int(path) path = "/" - except: + except (ValueError, TypeError): pass valid_unix_path_regex = re.compile('^(/)?([^/\0]+(/)?)+$') @@ -1317,7 +1479,7 @@ def get_ssh_key(argstr): has around 200 or so "base64" characters and ends with some-user@some-where. "Generate public ssh key" would be a good search term if you don't know how to do this. - """, add_separator=False)) + """)) if not ssh_key.lower().startswith('ssh'): raise ValueError(deindent(""" @@ -1330,7 +1492,7 @@ def get_ssh_key(argstr): {} And welp, that just don't look right. - """.format(ssh_key), add_separator=False)) + """.format(ssh_key))) return ssh_key @@ -1338,14 +1500,15 @@ def get_ssh_key(argstr): @parser.command( argument("instance_id", help="id of instance to attach to", type=int), argument("ssh_key", help="ssh key to attach to instance", type=str), + description="Attach an SSH key to an instance for remote access", usage="vastai attach ssh instance_id ssh_key", - help="Attach an ssh key to an instance. This will allow you to connect to the instance with the ssh key", + help="Attach an SSH key to an instance for remote access", epilog=deindent(""" Attach an ssh key to an instance. This will allow you to connect to the instance with the ssh key. Examples: - vastai attach "ssh 12371 ssh-rsa AAAAB3NzaC1yc2EAAA..." - vastai attach "ssh 12371 ssh-rsa $(cat ~/.ssh/id_rsa)" + vastai attach ssh 12371 ssh-rsa AAAAB3NzaC1yc2EAAA... + vastai attach ssh 12371 ssh-rsa $(cat ~/.ssh/id_rsa) """), ) def attach__ssh(args): @@ -1354,12 +1517,20 @@ def attach__ssh(args): req_json = {"ssh_key": ssh_key} r = http_post(args, url, headers=headers, json=req_json) r.raise_for_status() - print(r.json()) + try: + rj = r.json() + except JSONDecodeError: + print("Error: API returned invalid JSON response", file=sys.stderr) + return + if args.raw: + return rj + print(rj) @parser.command( argument("dst", help="instance_id:/path to target of copy operation", type=str), + description="Cancel an in-progress file copy operation", usage="vastai cancel copy DST", - help="Cancel a remote copy in progress, specified by DST id", + help="Cancel an in-progress file copy operation", epilog=deindent(""" Use this command to cancel any/all current remote copy operations copying to a specific named instance, given by DST. @@ -1388,21 +1559,24 @@ def cancel__copy(args: argparse.Namespace): req_json = { "client_id": "me", "dst_id": dst_id, } r = http_del(args, url, headers=headers,json=req_json) r.raise_for_status() - if (r.status_code == 200): + try: rj = r.json(); - if (rj["success"]): - print("Remote copy canceled - check instance status bar for progress updates (~30 seconds delayed).") - else: - print(rj["msg"]); + except JSONDecodeError: + print("Error: API returned invalid JSON response", file=sys.stderr) + return + if args.raw: + return rj + if rj.get("success"): + print("Remote copy canceled - check instance status bar for progress updates (~30 seconds delayed).") else: - print(r.text); - print("failed with error {r.status_code}".format(**locals())); + print(rj.get("msg", "Unknown error")); @parser.command( argument("dst", help="instance_id:/path to target of sync operation", type=str), + description="Cancel an in-progress file sync operation", usage="vastai cancel sync DST", - help="Cancel a remote copy in progress, specified by DST id", + help="Cancel an in-progress file sync operation", epilog=deindent(""" Use this command to cancel any/all current remote cloud sync operations copying to a specific named instance, given by DST. @@ -1431,15 +1605,17 @@ def cancel__sync(args: argparse.Namespace): req_json = { "client_id": "me", "dst_id": dst_id, } r = http_del(args, url, headers=headers,json=req_json) r.raise_for_status() - if (r.status_code == 200): + try: rj = r.json(); - if (rj["success"]): - print("Remote copy canceled - check instance status bar for progress updates (~30 seconds delayed).") - else: - print(rj["msg"]); + except JSONDecodeError: + print("Error: API returned invalid JSON response", file=sys.stderr) + return + if args.raw: + return rj + if rj.get("success"): + print("Remote copy canceled - check instance status bar for progress updates (~30 seconds delayed).") else: - print(r.text); - print("failed with error {r.status_code}".format(**locals())); + print(rj.get("msg", "Unknown error")); def default_start_date(): return datetime.now(timezone.utc).strftime("%Y-%m-%d") @@ -1491,6 +1667,7 @@ def parse_hour_cron_style(value): argument("--end_date", type=str, default=default_end_date(), help="End date/time in format 'YYYY-MM-DD HH:MM:SS PM' (UTC). Default is 7 days from now. (optional)"), argument("--day", type=parse_day_cron_style, help="Day of week you want scheduled job to run on (0-6, where 0=Sunday) or \"*\". Default will be 0. For ex. --day 0", default=0), argument("--hour", type=parse_hour_cron_style, help="Hour of day you want scheduled job to run on (0-23) or \"*\" (UTC). Default will be 0. For ex. --hour 16", default=0), + description="Change the bid price for a spot/interruptible instance", usage="vastai change bid id [--price PRICE]", help="Change the bid price for a spot/interruptible instance", epilog=deindent(""" @@ -1504,8 +1681,6 @@ def change__bid(args: argparse.Namespace): :param argparse.Namespace args: should supply all the command-line options :rtype int: """ - url = apiurl(args, "/instances/bid_price/{id}/".format(id=args.id)) - json_blob = {"client_id": "me", "price": args.price,} if (args.explain): print("request json: ") @@ -1516,12 +1691,13 @@ def change__bid(args: argparse.Namespace): cli_command = "change bid" api_endpoint = "/api/v0/instances/bid_price/{id}/".format(id=args.id) json_blob["instance_id"] = args.id - add_scheduled_job(args, json_blob, cli_command, api_endpoint, "PUT", instance_id=args.id) + add_scheduled_job(args, json_blob, cli_command, api_endpoint, "PUT", instance_id=args.id) return - - r = http_put(args, url, headers=headers, json=json_blob) - r.raise_for_status() - print("Per gpu bid price changed".format(r.json())) + + result = api_call(args, "PUT", "/instances/bid_price/{id}/".format(id=args.id), json_body=json_blob) + if args.raw: + return result + print("Per gpu bid price changed".format(result)) @@ -1530,8 +1706,9 @@ def change__bid(args: argparse.Namespace): argument("dest", help="id of volume offer volume is being copied to", type=int), argument("-s", "--size", help="Size of new volume contract, in GB. Must be greater than or equal to the source volume, and less than or equal to the destination offer.", type=float), argument("-d", "--disable_compression", action="store_true", help="Do not compress volume data before copying."), + description="Create a copy of an existing volume", usage="vastai copy volume [options]", - help="Clone an existing volume", + help="Create a copy of an existing volume", epilog=deindent(""" Create a new volume with the given offer, by copying the existing volume. Size defaults to the size of the existing volume, but can be increased if there is available space. @@ -1555,18 +1732,24 @@ def clone__volume(args: argparse.Namespace): print(json_blob) r = http_post(args, url, headers=headers,json=json_blob) r.raise_for_status() + try: + rj = r.json() + except JSONDecodeError: + print("Error: API returned invalid JSON response", file=sys.stderr) + return if args.raw: - return r + return rj else: - print("Created. {}".format(r.json())) + print("Created. {}".format(rj)) @parser.command( argument("src", help="Source location for copy operation (supports multiple formats)", type=str), argument("dst", help="Target location for copy operation (supports multiple formats)", type=str), argument("-i", "--identity", help="Location of ssh private key", type=str), + description="Copy files/directories between instances or between local and instance", usage="vastai copy SRC DST", - help="Copy directories between instances and/or local", + help="Copy files/directories between instances or between local and instance", epilog=deindent(""" Copies a directory from a source location to a target location. Each of source and destination directories can be either local or remote, subject to appropriate read and write @@ -1632,43 +1815,43 @@ def copy(args: argparse.Namespace): url = apiurl(args, f"/commands/copy_direct/") r = http_put(args, url, headers=headers,json=req_json) r.raise_for_status() - if (r.status_code == 200): + try: rj = r.json() - #print(json.dumps(rj, indent=1, sort_keys=True)) - if (rj["success"]) and ((src_id is None or src_id == "local") or (dst_id is None or dst_id == "local")): - homedir = subprocess.getoutput("echo $HOME") - #print(f"homedir: {homedir}") - remote_port = None - identity = f"-i {args.identity}" if (args.identity is not None) else "" - if (src_id is None or src_id == "local"): - #result = subprocess.run(f"mkdir -p {src_path}", shell=True) - remote_port = rj["dst_port"] - remote_addr = rj["dst_addr"] - cmd = f"rsync -arz -v --progress --rsh=ssh -e 'ssh {identity} -p {remote_port} -o StrictHostKeyChecking=no' {src_path} vastai_kaalia@{remote_addr}::{dst_id}/{dst_path}" - print(cmd) - result = subprocess.run(cmd, shell=True) - #result = subprocess.run(["sudo", "rsync" "-arz", "-v", "--progress", "-rsh=ssh", "-e 'sudo ssh -i {homedir}/.ssh/id_rsa -p {remote_port} -o StrictHostKeyChecking=no'", src_path, "vastai_kaalia@{remote_addr}::{dst_id}"], shell=True) - elif (dst_id is None or dst_id == "local"): - result = subprocess.run(f"mkdir -p {dst_path}", shell=True) - remote_port = rj["src_port"] - remote_addr = rj["src_addr"] - cmd = f"rsync -arz -v --progress --rsh=ssh -e 'ssh {identity} -p {remote_port} -o StrictHostKeyChecking=no' vastai_kaalia@{remote_addr}::{src_id}/{src_path} {dst_path}" - print(cmd) - result = subprocess.run(cmd, shell=True) - #result = subprocess.run(["sudo", "rsync" "-arz", "-v", "--progress", "-rsh=ssh", "-e 'ssh -i {homedir}/.ssh/id_rsa -p {remote_port} -o StrictHostKeyChecking=no'", "vastai_kaalia@{remote_addr}::{src_id}", dst_path], shell=True) + except JSONDecodeError: + print("Error: API returned invalid JSON response", file=sys.stderr) + return + #print(json.dumps(rj, indent=1, sort_keys=True)) + if rj.get("success") and ((src_id is None or src_id == "local") or (dst_id is None or dst_id == "local")): + homedir = os.path.expanduser("~") + #print(f"homedir: {homedir}") + remote_port = None + identity = f"-i {args.identity}" if (args.identity is not None) else "" + if (src_id is None or src_id == "local"): + remote_port = rj.get("dst_port") + remote_addr = rj.get("dst_addr") + ssh_cmd = f"ssh {identity} -p {remote_port} -o StrictHostKeyChecking=no".strip() + rsync_args = ["rsync", "-arz", "-v", "--progress", "-e", ssh_cmd, src_path, f"vastai_kaalia@{remote_addr}::{dst_id}/{dst_path}"] + print(" ".join(rsync_args)) + result = subprocess.run(rsync_args) + elif (dst_id is None or dst_id == "local"): + os.makedirs(dst_path, exist_ok=True) + remote_port = rj.get("src_port") + remote_addr = rj.get("src_addr") + ssh_cmd = f"ssh {identity} -p {remote_port} -o StrictHostKeyChecking=no".strip() + rsync_args = ["rsync", "-arz", "-v", "--progress", "-e", ssh_cmd, f"vastai_kaalia@{remote_addr}::{src_id}/{src_path}", dst_path] + print(" ".join(rsync_args)) + result = subprocess.run(rsync_args) + else: + if rj.get("success"): + print("Remote to Remote copy initiated - check instance status bar for progress updates (~30 seconds delayed).") else: - if (rj["success"]): - print("Remote to Remote copy initiated - check instance status bar for progress updates (~30 seconds delayed).") + msg = rj.get("msg", "Unknown error") + if msg == "src_path not supported VMs.": + print("copy between VM instances does not currently support subpaths (only full disk copy)") + elif msg == "dst_path not supported for VMs.": + print("copy between VM instances does not currently support subpaths (only full disk copy)") else: - if rj["msg"] == "src_path not supported VMs.": - print("copy between VM instances does not currently support subpaths (only full disk copy)") - elif rj["msg"] == "dst_path not supported for VMs.": - print("copy between VM instances does not currently support subpaths (only full disk copy)") - else: - print(rj["msg"]) - else: - print(r.text) - print("failed with error {r.status_code}".format(**locals())); + print(msg) ''' @@ -1709,20 +1892,21 @@ def vm__copy(args: argparse.Namespace): r = http_put(args, url, headers=headers,json=req_json) r.raise_for_status() - if (r.status_code == 200): + try: rj = r.json(); - if (rj["success"]): - print("Remote to Remote copy initiated - check instance status bar for progress updates (~30 seconds delayed).") - else: - if rj["msg"] == "Invalid src_path.": - print("src instance is not a VM") - elif rj["msg"] == "Invalid dst_path.": - print("dst instance is not a VM") - else: - print(rj["msg"]); + except JSONDecodeError: + print("Error: API returned invalid JSON response", file=sys.stderr) + return + if rj.get("success"): + print("Remote to Remote copy initiated - check instance status bar for progress updates (~30 seconds delayed).") else: - print(r.text); - print("failed with error {r.status_code}".format(**locals())); + msg = rj.get("msg", "Unknown error") + if msg == "Invalid src_path.": + print("src instance is not a VM") + elif msg == "Invalid dst_path.": + print("dst instance is not a VM") + else: + print(msg); ''' @parser.command( @@ -1741,8 +1925,9 @@ def vm__copy(args: argparse.Namespace): argument("--end_date", type=str, help="End date/time in format 'YYYY-MM-DD HH:MM:SS PM' (UTC). Default is contract's end. (optional)"), argument("--day", type=parse_day_cron_style, help="Day of week you want scheduled job to run on (0-6, where 0=Sunday) or \"*\". Default will be 0. For ex. --day 0", default=0), argument("--hour", type=parse_hour_cron_style, help="Hour of day you want scheduled job to run on (0-23) or \"*\" (UTC). Default will be 0. For ex. --hour 16", default=0), + description="Copy files between instances and cloud storage (S3, GCS, Azure)", usage="vastai cloud copy --src SRC --dst DST --instance INSTANCE_ID -connection CONNECTION_ID --transfer TRANSFER_TYPE", - help="Copy files/folders to and from cloud providers", + help="Copy files between instances and cloud storage (S3, GCS, Azure)", epilog=deindent(""" Copies a directory from a source location to a target location. Each of source and destination directories can be either local or remote, subject to appropriate read and write @@ -1810,8 +1995,13 @@ def cloud__copy(args: argparse.Namespace): req_url = apiurl(args, "/instances/{id}/".format(id=args.instance) , {"owner": "me"} ) r = http_get(args, req_url) r.raise_for_status() - row = r.json()["instances"] - + try: + rj = r.json() + except JSONDecodeError: + print("Error: API returned invalid JSON response", file=sys.stderr) + return + row = rj.get("instances") + if args.transfer.lower() == "instance to cloud": if row: # Get the cost per TB of internet upload @@ -1838,12 +2028,8 @@ def cloud__copy(args: argparse.Namespace): r = http_post(args, url, headers=headers,json=req_json) r.raise_for_status() - if (r.status_code == 200): - print("Cloud Copy Started - check instance status bar for progress updates (~30 seconds delayed).") - print("When the operation is finished you should see 'Cloud Copy Operation Finished' in the instance status bar.") - else: - print(r.text); - print("failed with error {r.status_code}".format(**locals())); + print("Cloud Copy Started - check instance status bar for progress updates (~30 seconds delayed).") + print("When the operation is finished you should see 'Cloud Copy Operation Finished' in the instance status bar.") @parser.command( @@ -1853,10 +2039,11 @@ def cloud__copy(args: argparse.Namespace): argument("--docker_login_user",help="Username for container registry with repo", type=str), argument("--docker_login_pass",help="Password or token for container registry with repo", type=str), argument("--pause", help="Pause container's processes being executed by the CPU to take snapshot (true/false). Default will be true", type=str, default="true"), + description="Create a snapshot of a running container and push to registry", usage="vastai take snapshot INSTANCE_ID " "--repo REPO --docker_login_user USER --docker_login_pass PASS" "[--container_registry REGISTRY] [--pause true|false]", - help="Schedule a snapshot of a running container and push it to your repo in a container registry", + help="Create a snapshot of a running container and push to registry", epilog=deindent(""" Takes a snapshot of a running container instance and pushes snapshot to the specified repository in container registry. @@ -1901,16 +2088,15 @@ def take__snapshot(args: argparse.Namespace): # POST to the snapshot endpoint r = http_post(args, url, headers=headers, json=req_json) r.raise_for_status() - - if r.status_code == 200: + try: data = r.json() - if data.get("success"): - print(f"Snapshot request sent successfully. Please check your repo {repo} in container registry {container_registry} in 5-10 mins. It can take longer than 5-10 mins to push your snapshot image to your repo depending on the size of your image.") - else: - print(data.get("msg", "Unknown error with snapshot request")) + except JSONDecodeError: + print("Error: API returned invalid JSON response", file=sys.stderr) + return + if data.get("success"): + print(f"Snapshot request sent successfully. Please check your repo {repo} in container registry {container_registry} in 5-10 mins. It can take longer than 5-10 mins to push your snapshot image to your repo depending on the size of your image.") else: - print(r.text); - print("failed with error {r.status_code}".format(**locals())); + print(data.get("msg", "Unknown error with snapshot request")) def validate_frequency_values(day_of_the_week, hour_of_the_day, frequency): @@ -1961,7 +2147,7 @@ def add_scheduled_job(args, req_json, cli_command, api_endpoint, request_method, "instance_id": instance_id } # Send a POST request - response = requests.post(schedule_job_url, headers=headers, json=request_body) + response = http_post(args, schedule_job_url, headers=headers, json=request_body) if args.explain: print("request json: ") @@ -1975,29 +2161,33 @@ def add_scheduled_job(args, req_json, cli_command, api_endpoint, request_method, elif response.status_code == 422: user_input = input("Existing scheduled job found. Do you want to update it (y|n)? ") if user_input.strip().lower() == "y": - scheduled_job_id = response.json()["scheduled_job_id"] + try: + resp_data = response.json() + except JSONDecodeError: + print("Error: API returned invalid JSON response", file=sys.stderr) + return + scheduled_job_id = resp_data.get("scheduled_job_id") + if not scheduled_job_id: + print("Error: API response missing required 'scheduled_job_id' field", file=sys.stderr) + return schedule_job_url = apiurl(args, f"/commands/schedule_job/{scheduled_job_id}/") - response = update_scheduled_job(cli_command, schedule_job_url, frequency, args.start_date, args.end_date, request_body) + response = update_scheduled_job(args, cli_command, schedule_job_url, frequency, args.start_date, args.end_date, request_body) else: print("Job update aborted by the user.") else: # print(r.text) print(f"add_scheduled_job insert: failed error: {response.status_code}. Response body: {response.text}") -def update_scheduled_job(cli_command, schedule_job_url, frequency, start_date, end_date, request_body): - response = requests.put(schedule_job_url, headers=headers, json=request_body) +def update_scheduled_job(args, cli_command, schedule_job_url, frequency, start_date, end_date, request_body): + response = http_put(args, schedule_job_url, headers=headers, json=request_body) - # Raise an exception for HTTP errors + # Raise an exception for HTTP errors response.raise_for_status() - if response.status_code == 200: - print(f"add_scheduled_job update: success - Scheduling {frequency} job to {cli_command} from {start_date} UTC to {end_date} UTC") - print(response.json()) - elif response.status_code == 401: - print(f"add_scheduled_job update: failed status_code: {response.status_code}. It could be because you aren't using a valid api_key.") - else: - # print(r.text) - print(f"add_scheduled_job update: failed status_code: {response.status_code}.") + print(f"add_scheduled_job update: success - Scheduling {frequency} job to {cli_command} from {start_date} UTC to {end_date} UTC") + try: print(response.json()) + except JSONDecodeError: + print(response.text) return response @@ -2006,8 +2196,9 @@ def update_scheduled_job(cli_command, schedule_job_url, frequency, start_date, e argument("--name", help="name of the api-key", type=str), argument("--permission_file", help="file path for json encoded permissions, see https://vast.ai/docs/cli/roles-and-permissions for more information", type=str), argument("--key_params", help="optional wildcard key params for advanced keys", type=str), + description="Create a new API key with custom permissions", usage="vastai create api-key --name NAME --permission_file PERMISSIONS", - help="Create a new api-key with restricted permissions. Can be sent to other users and teammates", + help="Create a new API key with custom permissions", epilog=deindent(""" In order to create api keys you must understand how permissions must be sent via json format. You can find more information about permissions here: https://vast.ai/docs/cli/roles-and-permissions @@ -2019,7 +2210,14 @@ def create__api_key(args): permissions = load_permissions_from_file(args.permission_file) r = http_post(args, url, headers=headers, json={"name": args.name, "permissions": permissions, "key_params": args.key_params}) r.raise_for_status() - print("api-key created {}".format(r.json())) + try: + rj = r.json() + except JSONDecodeError: + print("Error: API returned invalid JSON response", file=sys.stderr) + return + if args.raw: + return rj + print("api-key created {}".format(rj)) except FileNotFoundError: print("Error: Permission file '{}' not found.".format(args.permission_file)) except requests.exceptions.RequestException as e: @@ -2031,8 +2229,9 @@ def create__api_key(args): @parser.command( argument("subnet", help="local subnet for cluster, ex: '0.0.0.0/24'", type=str), argument("manager_id", help="Machine ID of manager node in cluster. Must exist already.", type=int), + description="[Beta] Create a new machine cluster", usage="vastai create cluster SUBNET MANAGER_ID", - help="Create Vast cluster", + help="[Beta] Create a new machine cluster", epilog=deindent(""" Create Vast Cluster by defining a local subnet and manager id.""") ) @@ -2052,25 +2251,31 @@ def create__cluster(args: argparse.Namespace): r = http_post(args, req_url, json=json_blob) r.raise_for_status() + try: + rj = r.json() + except JSONDecodeError: + print("Error: API returned invalid JSON response", file=sys.stderr) + return + if args.raw: - return r + return rj - print(r.json()["msg"]) + print(rj.get("msg", "Unknown error")) @parser.command( argument("name", help="Environment variable name", type=str), argument("value", help="Environment variable value", type=str), + description="Create a new account-level environment variable", usage="vastai create env-var ", - help="Create a new user environment variable", + help="Create a new account-level environment variable", ) def create__env_var(args): """Create a new environment variable for the current user.""" - url = apiurl(args, "/secrets/") data = {"key": args.name, "value": args.value} - r = http_post(args, url, headers=headers, json=data) - r.raise_for_status() + result = api_call(args, "POST", "/secrets/", json_body=data) - result = r.json() + if args.raw: + return result if result.get("success"): print(result.get("msg", "Environment variable created successfully.")) else: @@ -2079,8 +2284,9 @@ def create__env_var(args): @parser.command( argument("ssh_key", help="add your existing ssh public key to your account (from the .pub file). If no public key is provided, a new key pair will be generated.", type=str, nargs='?'), argument("-y", "--yes", help="automatically answer yes to prompts", action="store_true"), + description="Add an SSH public key to your account", usage="vastai create ssh-key [ssh_public_key] [-y]", - help="Create a new ssh-key", + help="Add an SSH public key to your account", epilog=deindent(""" You may use this command to add an existing public key, or create a new ssh key pair and add that public key, to your Vast account. @@ -2101,20 +2307,20 @@ def create__env_var(args): def create__ssh_key(args): ssh_key_content = args.ssh_key - + # If no SSH key provided, generate one if not ssh_key_content: ssh_key_content = generate_ssh_key(args.yes) else: print("Adding provided SSH public key to account...") - + # Send the SSH key to the API - url = apiurl(args, "/ssh/") - r = http_post(args, url, headers=headers, json={"ssh_key": ssh_key_content}) - r.raise_for_status() - + result = api_call(args, "POST", "/ssh/", json_body={"ssh_key": ssh_key_content}) + + if args.raw: + return result # Print json response - print("ssh-key created {}\nNote: You may need to add the new public key to any pre-existing instances".format(r.json())) + print("ssh-key created {}\nNote: You may need to add the new public key to any pre-existing instances".format(result)) def generate_ssh_key(auto_yes=False): @@ -2245,8 +2451,9 @@ def generate_ssh_key(auto_yes=False): argument("--cold_mult", help="[NOTE: this field isn't currently used at the workergroup level]cold/stopped instance capacity target as multiple of hot capacity target (default 2.0)", type=float), argument("--cold_workers", help="min number of workers to keep 'cold' for this workergroup", type=int), argument("--auto_instance", help=argparse.SUPPRESS, type=str, default="prod"), + description="Create an autoscaling worker group for serverless inference", usage="vastai workergroup create [OPTIONS]", - help="Create a new autoscale group", + help="Create an autoscaling worker group for serverless inference", epilog=deindent(""" Create a new autoscaling group to manage a pool of worker instances. @@ -2275,7 +2482,10 @@ def create__workergroup(args): r.raise_for_status() if 'application/json' in r.headers.get('Content-Type', ''): try: - print("workergroup create {}".format(r.json())) + rj = r.json() + if args.raw: + return rj + print("workergroup create {}".format(rj)) except requests.exceptions.JSONDecodeError: print("The response is not valid JSON.") print(r) @@ -2295,8 +2505,9 @@ def create__workergroup(args): argument("--endpoint_name", help="deployment endpoint name (allows multiple autoscale groups to share same deployment endpoint)", type=str), argument("--auto_instance", help=argparse.SUPPRESS, type=str, default="prod"), + description="Create a serverless inference endpoint", usage="vastai create endpoint [OPTIONS]", - help="Create a new endpoint group", + help="Create a serverless inference endpoint", epilog=deindent(""" Create a new endpoint group to manage many autoscaling groups @@ -2311,11 +2522,14 @@ def create__endpoint(args): if (args.explain): print("request json: ") print(json_blob) - r = requests.post(url, headers=headers,json=json_blob) + r = http_post(args, url, headers=headers, json=json_blob) r.raise_for_status() if 'application/json' in r.headers.get('Content-Type', ''): try: - print("create endpoint {}".format(r.json())) + rj = r.json() + if args.raw: + return rj + print("create endpoint {}".format(rj)) except requests.exceptions.JSONDecodeError: print("The response is not valid JSON.") print(r) @@ -2414,8 +2628,9 @@ def validate_portal_config(json_blob): argument("--mount-path", help="The path to the volume from within the new instance container. e.g. /root/volume", type=str), argument("--volume-label", help="(optional) A name to give the new volume. Only usable with --create-volume", type=str), + description="Create a new GPU instance from an offer", usage="vastai create instance ID [OPTIONS] [--args ...]", - help="Create a new instance", + help="Create a new GPU instance from an offer", epilog=deindent(""" Performs the same action as pressing the "RENT" button on the website at https://console.vast.ai/create/ Creates an instance from an offer ID (which is returned from "search offers"). Each offer ID can only be used to create one instance. @@ -2507,7 +2722,7 @@ def create__instance(args: argparse.Namespace): r = http_put(args, url, headers=headers,json=json_blob) r.raise_for_status() if args.raw: - return r + return r.json() else: print("Started. {}".format(r.json())) @@ -2516,8 +2731,9 @@ def create__instance(args: argparse.Namespace): argument("--username", help="username to use for login", type=str), argument("--password", help="password to use for login", type=str), argument("--type", help="host/client", type=str), + description="Create a subaccount for delegated access", usage="vastai create subaccount --email EMAIL --username USERNAME --password PASSWORD --type TYPE", - help="Create a subaccount", + help="Create a subaccount for delegated access", epilog=deindent(""" Creates a new account that is considered a child of your current account as defined via the API key. @@ -2531,7 +2747,7 @@ def create__subaccount(args): """ # Default value for host_only, can adjust based on expected default behavior host_only = False - + # Only process the --account_type argument if it's provided if args.type: host_only = args.type.lower() == "host" @@ -2554,20 +2770,18 @@ def create__subaccount(args): url = apiurl(args, "/users/") r = http_post(args, url, headers=headers, json=json_blob) r.raise_for_status() - - if r.status_code == 200: - rj = r.json() - print(rj) - else: - print(r.text) - print(f"Failed with error {r.status_code}") + rj = r.json() + if args.raw: + return rj + print(rj) @parser.command( argument("--team_name", help="name of the team", type=str), + description="Create a new team", usage="vastai create-team --team_name TEAM_NAME", help="Create a new team", epilog=deindent(""" - Creates a new team under your account. + Creates a new team under your account. Unlike legacy teams, this command does NOT convert your personal account into a team. Each team is created as a separate account, and you can be a member of multiple teams. @@ -2578,13 +2792,9 @@ def create__subaccount(args): - Default roles (owner, manager, member) are automatically created. - You can invite others, assign roles, and manage resources within the team. - Optional: - You can transfer a portion of your existing personal credits to the team by using - the `--transfer_credit` flag. Example: - vastai create-team --team_name myteam --transfer_credit 25 - Notes: - You cannot create a team from within another team account. + - To transfer credits to a team, use `vastai transfer credit ` after team creation. For more details, see: https://vast.ai/docs/teams-quickstart @@ -2592,27 +2802,28 @@ def create__subaccount(args): ) def create__team(args): - url = apiurl(args, "/team/") - r = http_post(args, url, headers=headers, json={"team_name": args.team_name}) - r.raise_for_status() - print(r.json()) + result = api_call(args, "POST", "/team/", json_body={"team_name": args.team_name}) + if args.raw: + return result + print(result) @parser.command( argument("--name", help="name of the role", type=str), argument("--permissions", help="file path for json encoded permissions, look in the docs for more information", type=str), + description="Create a custom role with specific permissions", usage="vastai create team-role --name NAME --permissions PERMISSIONS", - help="Add a new role to your team", + help="Create a custom role with specific permissions", epilog=deindent(""" Creating a new team role involves understanding how permissions must be sent via json format. You can find more information about permissions here: https://vast.ai/docs/cli/roles-and-permissions """) ) def create__team_role(args): - url = apiurl(args, "/team/roles/") permissions = load_permissions_from_file(args.permissions) - r = http_post(args, url, headers=headers, json={"name": args.name, "permissions": permissions}) - r.raise_for_status() - print(r.json()) + result = api_call(args, "POST", "/team/roles/", json_body={"name": args.name, "permissions": permissions}) + if args.raw: + return result + print(result) def get_template_arguments(): return [ @@ -2640,8 +2851,9 @@ def get_template_arguments(): @parser.command( *get_template_arguments(), + description="Create a reusable instance configuration template", usage="vastai create template", - help="Create a new template", + help="Create a reusable instance configuration template", epilog=deindent(""" Create a template that can be used to create instances with @@ -2668,7 +2880,7 @@ def create__template(args): default_search_query = {} if not args.no_default: default_search_query = {"verified": {"eq": True}, "external": {"eq": False}, "rentable": {"eq": True}, "rented": {"eq": False}} - + extra_filters = parse_query(args.search_params, default_search_query, offers_fields, offers_alias, offers_mult) template = { "name" : args.name, @@ -2701,10 +2913,12 @@ def create__template(args): r.raise_for_status() try: rj = r.json() - if rj["success"]: - print(f"New Template: {rj['template']}") + if args.raw: + return rj + if rj.get("success"): + print(f"New Template: {rj.get('template', '')}") else: - print(rj['msg']) + print(rj.get('msg', 'Unknown error')) except requests.exceptions.JSONDecodeError: print("The response is not valid JSON.") @@ -2714,15 +2928,16 @@ def create__template(args): argument("-s", "--size", help="size in GB of volume. Default %(default)s GB.", default=15, type=float), argument("-n", "--name", help="Optional name of volume.", type=str), + description="Create a new persistent storage volume", usage="vastai create volume ID [options]", - help="Create a new volume", + help="Create a new persistent storage volume", epilog=deindent(""" Creates a volume from an offer ID (which is returned from "search volumes"). Each offer ID can be used to create multiple volumes, provided the size of all volumes does not exceed the size of the offer. """) ) def create__volume(args: argparse.Namespace): - + json_blob ={ "size": int(args.size), "id": int(args.id) @@ -2738,7 +2953,7 @@ def create__volume(args: argparse.Namespace): r = http_put(args, url, headers=headers,json=json_blob) r.raise_for_status() if args.raw: - return r + return r.json() else: print("Created. {}".format(r.json())) @@ -2748,8 +2963,9 @@ def create__volume(args: argparse.Namespace): argument("-s", "--size", help="size in GB of network volume. Default %(default)s GB.", default=15, type=float), argument("-n", "--name", help="Optional name of network volume.", type=str), + description="[Host] [Beta] Create a new network-attached storage volume", usage="vastai create network volume ID [options]", - help="Create a new network volume", + help="[Host] [Beta] Create a new network-attached storage volume", epilog=deindent(""" Creates a network volume from an offer ID (which is returned from "search network volumes"). Each offer ID can be used to create multiple volumes, provided the size of all volumes does not exceed the size of the offer. @@ -2772,15 +2988,16 @@ def create__network_volume(args: argparse.Namespace): r = http_put(args, url, headers=headers,json=json_blob) r.raise_for_status() if args.raw: - return r + return r.json() else: print("Created. {}".format(r.json())) @parser.command( argument("cluster_id", help="ID of cluster to create overlay on top of", type=int), argument("name", help="overlay network name"), + description="[Beta] Create a virtual overlay network on a cluster", usage="vastai create overlay CLUSTER_ID OVERLAY_NAME", - help="Creates overlay network on top of a physical cluster", + help="[Beta] Create a virtual overlay network on a cluster", epilog=deindent(""" Creates an overlay network to allow local networking between instances on a physical cluster""") ) @@ -2798,49 +3015,53 @@ def create__overlay(args: argparse.Namespace): r.raise_for_status() if args.raw: - return r + return r.json() - print(r.json()["msg"]) + print(r.json().get("msg", "Unknown error")) @parser.command( argument("id", help="id of apikey to remove", type=int), + description="Delete an API key", usage="vastai delete api-key ID", - help="Remove an api-key", + help="Delete an API key", ) def delete__api_key(args): - url = apiurl(args, "/auth/apikeys/{id}/".format(id=args.id)) - r = http_del(args, url, headers=headers) - r.raise_for_status() - print(r.json()) + result = api_call(args, "DELETE", f"/auth/apikeys/{args.id}/") + if args.raw: + return result + print(result) @parser.command( argument("id", help="id ssh key to delete", type=int), + description="Remove an SSH key from your account", usage="vastai delete ssh-key ID", - help="Remove an ssh-key", + help="Remove an SSH key from your account", ) def delete__ssh_key(args): - url = apiurl(args, "/ssh/{id}/".format(id=args.id)) - r = http_del(args, url, headers=headers) - r.raise_for_status() - print(r.json()) + result = api_call(args, "DELETE", f"/ssh/{args.id}/") + if args.raw: + return result + print(result) @parser.command( argument("id", help="id of scheduled job to remove", type=int), + description="Delete a scheduled job", usage="vastai delete scheduled-job ID", help="Delete a scheduled job", ) def delete__scheduled_job(args): - url = apiurl(args, "/commands/schedule_job/{id}/".format(id=args.id)) - r = http_del(args, url, headers=headers) - r.raise_for_status() - print(r.json()) + result = api_call(args, "DELETE", f"/commands/schedule_job/{args.id}/") + if args.raw: + return result + print(result) @parser.command( argument("cluster_id", help="ID of cluster to delete", type=int), + description="[Beta] Delete a machine cluster", usage="vastai delete cluster CLUSTER_ID", - help="Delete Cluster", + help="[Beta] Delete a machine cluster", epilog=deindent(""" Delete Vast Cluster""") ) @@ -2852,28 +3073,27 @@ def delete__cluster(args: argparse.Namespace): if args.explain: print("request json:", json_blob) - req_url = apiurl(args, "/cluster/") - r = http_del(args, req_url, json=json_blob) - r.raise_for_status() + result = api_call(args, "DELETE", "/cluster/", json_body=json_blob) if args.raw: - return r + return result - print(r.json()["msg"]) + print(result.get("msg", "Unknown error")) @parser.command( argument("id", help="id of group to delete", type=int), + description="Delete an autoscaling worker group", usage="vastai delete workergroup ID ", - help="Delete a workergroup group", + help="Delete an autoscaling worker group", epilog=deindent(""" Note that deleting a workergroup doesn't automatically destroy all the instances that are associated with your workergroup. Example: vastai delete workergroup 4242 """), ) def delete__workergroup(args): - id = args.id - url = apiurl(args, f"/autojobs/{id}/" ) + workergroup_id = args.id + url = apiurl(args, f"/autojobs/{workergroup_id}/") json_blob = {"client_id": "me", "autojob_id": args.id} if (args.explain): print("request json: ") @@ -2882,7 +3102,10 @@ def delete__workergroup(args): r.raise_for_status() if 'application/json' in r.headers.get('Content-Type', ''): try: - print("workergroup delete {}".format(r.json())) + rj = r.json() + if args.raw: + return rj + print("workergroup delete {}".format(rj)) except requests.exceptions.JSONDecodeError: print("The response is not valid JSON.") print(r) @@ -2893,15 +3116,16 @@ def delete__workergroup(args): @parser.command( argument("id", help="id of endpoint group to delete", type=int), + description="Delete a serverless inference endpoint", usage="vastai delete endpoint ID ", - help="Delete an endpoint group", + help="Delete a serverless inference endpoint", epilog=deindent(""" Example: vastai delete endpoint 4242 """), ) def delete__endpoint(args): - id = args.id - url = apiurl(args, f"/endptjobs/{id}/" ) + endpoint_id = args.id + url = apiurl(args, f"/endptjobs/{endpoint_id}/") json_blob = {"client_id": "me", "endptjob_id": args.id} if (args.explain): print("request json: ") @@ -2910,7 +3134,10 @@ def delete__endpoint(args): r.raise_for_status() if 'application/json' in r.headers.get('Content-Type', ''): try: - print("delete endpoint {}".format(r.json())) + rj = r.json() + if args.raw: + return rj + print("delete endpoint {}".format(rj)) except requests.exceptions.JSONDecodeError: print("The response is not valid JSON.") print(r) @@ -2921,6 +3148,7 @@ def delete__endpoint(args): @parser.command( argument("name", help="Environment variable name to delete", type=str), + description="Delete a user environment variable", usage="vastai delete env-var ", help="Delete a user environment variable", ) @@ -2931,7 +3159,13 @@ def delete__env_var(args): r = http_del(args, url, headers=headers, json=data) r.raise_for_status() - result = r.json() + try: + result = r.json() + except JSONDecodeError: + print("Error: API returned invalid JSON response", file=sys.stderr) + return + if args.raw: + return result if result.get("success"): print(result.get("msg", "Environment variable deleted successfully.")) else: @@ -2939,8 +3173,9 @@ def delete__env_var(args): @parser.command( argument("overlay_identifier", help="ID (int) or name (str) of overlay to delete", nargs="?"), + description="[Beta] Delete an overlay network and its instances", usage="vastai delete overlay OVERLAY_IDENTIFIER", - help="Deletes overlay and removes all of its associated instances" + help="[Beta] Delete an overlay network and its instances" ) def delete__overlay(args: argparse.Namespace): identifier = args.overlay_identifier @@ -2957,20 +3192,19 @@ def delete__overlay(args: argparse.Namespace): if args.explain: print("request json:", json_blob) - req_url = apiurl(args, "/overlay/") - r = http_del(args, req_url, json=json_blob) - r.raise_for_status() + result = api_call(args, "DELETE", "/overlay/", json_body=json_blob) if args.raw: - return r + return result - print(r.json()["msg"]) + print(result.get("msg", "Unknown error")) @parser.command( argument("--template-id", help="Template ID of Template to Delete", type=int), argument("--hash-id", help="Hash ID of Template to Delete", type=str), + description="Delete a template", usage="vastai delete template [--template-id | --hash-id ]", - help="Delete a Template", + help="Delete a template", epilog=deindent(""" Note: Deleting a template only removes the user's replationship to a template. It does not get destroyed Example: vastai delete template --template-id 12345 @@ -2979,7 +3213,7 @@ def delete__overlay(args: argparse.Namespace): ) def delete__template(args): url = apiurl(args, f"/template/" ) - + if args.hash_id: json_blob = { "hash_id": args.hash_id } elif args.template_id: @@ -2987,18 +3221,20 @@ def delete__template(args): else: print('ERROR: Must Specify either Template ID or Hash ID to delete a template') return - + if (args.explain): print("request json: ") print(json_blob) print(args) print(url) r = http_del(args, url, headers=headers,json=json_blob) - print(r) # r.raise_for_status() if 'application/json' in r.headers.get('Content-Type', ''): try: - print(r.json()['msg']) + rj = r.json() + if args.raw: + return rj + print(rj.get('msg', 'Unknown error')) except requests.exceptions.JSONDecodeError: print("The response is not valid JSON.") print(r) @@ -3010,20 +3246,19 @@ def delete__template(args): @parser.command( argument("id", help="id of volume contract", type=int), + description="Delete a persistent storage volume", usage="vastai delete volume ID", - help="Delete a volume", + help="Delete a persistent storage volume", epilog=deindent(""" Deletes volume with the given ID. All instances using the volume must be destroyed before the volume can be deleted. """) ) def delete__volume(args: argparse.Namespace): - url = apiurl(args, "/volumes/", query_args={"id": args.id}) - r = http_del(args, url, headers=headers) - r.raise_for_status() + result = api_call(args, "DELETE", "/volumes/", query_args={"id": args.id}) if args.raw: - return r + return result else: - print("Deleted. {}".format(r.json())) + print("Deleted. {}".format(result)) def destroy_instance(id,args): @@ -3031,25 +3266,32 @@ def destroy_instance(id,args): r = http_del(args, url, headers=headers,json={}) r.raise_for_status() if args.raw: - return r - elif (r.status_code == 200): - rj = r.json(); - if (rj["success"]): - print("destroying instance {id}.".format(**(locals()))); - else: - print(rj["msg"]); + return r.json() + rj = r.json(); + if rj.get("success"): + print("destroying instance {id}.".format(**(locals()))); else: - print(r.text); - print("failed with error {r.status_code}".format(**locals())); + print(rj.get("msg", "Unknown error")); @parser.command( argument("id", help="id of instance to delete", type=int), + description="Destroy an instance (irreversible, deletes data)", usage="vastai destroy instance id [-h] [--api-key API_KEY] [--raw]", help="Destroy an instance (irreversible, deletes data)", epilog=deindent(""" - Perfoms the same action as pressing the "DESTROY" button on the website at https://console.vast.ai/instances/ - Example: vastai destroy instance 4242 + Performs the same action as pressing the "DESTROY" button on the website at https://console.vast.ai/instances/ + + WARNING: This action is IMMEDIATE and IRREVERSIBLE. All data on the instance will be permanently + deleted unless you have saved it to a persistent volume or external storage. + + Examples: + vastai destroy instance 12345 # Destroy instance with ID 12345 + + Before destroying: + - Save any important data using 'vastai copy' or by mounting a persistent volume + - Check instance ID carefully with 'vastai show instances' + - Consider using 'vastai stop instance' if you want to pause without data loss """), ) def destroy__instance(args): @@ -3060,8 +3302,9 @@ def destroy__instance(args): destroy_instance(args.id,args) @parser.command( - argument("ids", help="ids of instance to destroy", type=int, nargs='+'), - usage="vastai destroy instances [--raw] ", + argument("ids", help="ids of instances to destroy", type=int, nargs='+'), + description="Destroy a list of instances (irreversible, deletes data)", + usage="vastai destroy instances IDS [OPTIONS]", help="Destroy a list of instances (irreversible, deletes data)", ) def destroy__instances(args): @@ -3071,20 +3314,22 @@ def destroy__instances(args): destroy_instance(id, args) @parser.command( + description="Delete your team and remove all members", usage="vastai destroy team", - help="Destroy your team", + help="Delete your team and remove all members", ) def destroy__team(args): - url = apiurl(args, "/team/") - r = http_del(args, url, headers=headers) - r.raise_for_status() - print(r.json()) + result = api_call(args, "DELETE", "/team/") + if args.raw: + return result + print(result) @parser.command( argument("instance_id", help="id of the instance", type=int), argument("ssh_key_id", help="id of the key to detach to the instance", type=str), + description="Remove an SSH key from an instance", usage="vastai detach instance_id ssh_key_id", - help="Detach an ssh key from an instance", + help="Remove an SSH key from an instance", epilog=deindent(""" Example: vastai detach 99999 12345 """) @@ -3093,7 +3338,13 @@ def detach__ssh(args): url = apiurl(args, "/instances/{id}/ssh/{ssh_key_id}/".format(id=args.instance_id, ssh_key_id=args.ssh_key_id)) r = http_del(args, url, headers=headers) r.raise_for_status() - print(r.json()) + try: + rj = r.json() + except JSONDecodeError: + rj = {"response": r.text} + if args.raw: + return rj + print(rj) @parser.command( argument("id", help="id of instance to execute on", type=int), @@ -3103,8 +3354,9 @@ def detach__ssh(args): argument("--end_date", type=str, default=default_end_date(), help="End date/time in format 'YYYY-MM-DD HH:MM:SS PM' (UTC). Default is 7 days from now. (optional)"), argument("--day", type=parse_day_cron_style, help="Day of week you want scheduled job to run on (0-6, where 0=Sunday) or \"*\". Default will be 0. For ex. --day 0", default=0), argument("--hour", type=parse_hour_cron_style, help="Hour of day you want scheduled job to run on (0-23) or \"*\" (UTC). Default will be 0. For ex. --hour 16", default=0), + description="Execute a command on a running instance", usage="vastai execute id COMMAND", - help="Execute a (constrained) remote command on a machine", + help="Execute a command on a running instance", epilog=deindent(""" Examples: vastai execute 99999 'ls -l -o -r' @@ -3141,22 +3393,21 @@ def execute(args): add_scheduled_job(args, json_blob, cli_command, api_endpoint, "PUT", instance_id=args.id) return - if (r.status_code == 200): - rj = r.json() - if (rj["success"]): - for i in range(0,30): - time.sleep(0.3) - url = rj["result_url"] - r = requests.get(url) - if (r.status_code == 200): - filtered_text = r.text.replace(rj["writeable_path"], ''); - print(filtered_text) - break - else: - print(rj); + rj = r.json() + if rj.get("success"): + url = rj.get("result_url") + if not url: + print("Error: API response missing required 'result_url' field", file=sys.stderr) + return + for i in range(0,30): + time.sleep(0.3) + r = http_get(args, url) + if (r.status_code == 200): + filtered_text = r.text.replace(rj.get("writeable_path", ''), ''); + print(filtered_text) + break else: - print(r.text); - print("failed with error {r.status_code}".format(**locals())); + print(rj); @@ -3164,8 +3415,9 @@ def execute(args): argument("id", help="id of endpoint group to fetch logs from", type=int), argument("--level", help="log detail level (0 to 3)", type=int, default=1), argument("--tail", help="", type=int, default=None), + description="Get logs for a serverless endpoint", usage="vastai get endpt-logs ID [--api-key API_KEY]", - help="Fetch logs for a specific serverless endpoint group", + help="Get logs for a serverless endpoint", epilog=deindent(""" Example: vastai get endpt-logs 382 """), @@ -3175,7 +3427,7 @@ def get__endpt_logs(args): if args.url == server_url_default: args.url = None url = (args.url or "https://run.vast.ai") + "/get_endpoint_logs/" - json_blob = {"id": args.id, "api_key": args.api_key} + json_blob = {"id": args.id} if args.tail: json_blob["tail"] = args.tail if (args.explain): print(f"{url} with request json: ") @@ -3185,29 +3437,27 @@ def get__endpt_logs(args): r.raise_for_status() levels = {0 : "info0", 1: "info1", 2: "trace", 3: "debug"} - if (r.status_code == 200): - rj = None - try: - rj = r.json() - except Exception as e: - print(str(e)) - print(r.text) - if args.raw: - # sort_keys - return rj or r.text - else: - dbg_lvl = levels[args.level] - if rj and dbg_lvl: print(rj[dbg_lvl]) - #print(json.dumps(rj, indent=1, sort_keys=True)) - else: + rj = None + try: + rj = r.json() + except Exception as e: + print(str(e)) print(r.text) + if args.raw: + # sort_keys + return rj or r.text + else: + dbg_lvl = levels[args.level] + if rj and dbg_lvl: print(rj[dbg_lvl]) + #print(json.dumps(rj, indent=1, sort_keys=True)) @parser.command( argument("id", help="id of endpoint group to fetch logs from", type=int), argument("--level", help="log detail level (0 to 3)", type=int, default=1), argument("--tail", help="", type=int, default=None), + description="Get logs for an autoscaling worker group", usage="vastai get wrkgrp-logs ID [--api-key API_KEY]", - help="Fetch logs for a specific serverless worker group group", + help="Get logs for an autoscaling worker group", epilog=deindent(""" Example: vastai get endpt-logs 382 """), @@ -3217,7 +3467,7 @@ def get__wrkgrp_logs(args): if args.url == server_url_default: args.url = None url = (args.url or "https://run.vast.ai") + "/get_autogroup_logs/" - json_blob = {"id": args.id, "api_key": args.api_key} + json_blob = {"id": args.id} if args.tail: json_blob["tail"] = args.tail if (args.explain): print(f"{url} with request json: ") @@ -3227,45 +3477,46 @@ def get__wrkgrp_logs(args): r.raise_for_status() levels = {0 : "info0", 1: "info1", 2: "trace", 3: "debug"} - if (r.status_code == 200): - rj = None - try: - rj = r.json() - except Exception as e: - print(str(e)) - print(r.text) - if args.raw: - # sort_keys - return rj or r.text - else: - dbg_lvl = levels[args.level] - if rj and dbg_lvl: print(rj[dbg_lvl]) - #print(json.dumps(rj, indent=1, sort_keys=True)) - else: + rj = None + try: + rj = r.json() + except Exception as e: + print(str(e)) print(r.text) + if args.raw: + # sort_keys + return rj or r.text + else: + dbg_lvl = levels[args.level] + if rj and dbg_lvl: print(rj[dbg_lvl]) + #print(json.dumps(rj, indent=1, sort_keys=True)) @parser.command( argument("--email", help="email of user to be invited", type=str), argument("--role", help="role of user to be invited", type=str), + description="Invite a user to join your team", usage="vastai invite member --email EMAIL --role ROLE", - help="Invite a team member", + help="Invite a user to join your team", ) def invite__member(args): url = apiurl(args, "/team/invite/", query_args={"email": args.email, "role": args.role}) r = http_post(args, url, headers=headers) r.raise_for_status() - if (r.status_code == 200): - print(f"successfully invited {args.email} to your current team") - else: - print(r.text); - print(f"failed with error {r.status_code}") + try: + rj = r.json() + except JSONDecodeError: + rj = {"success": True, "email": args.email} + if args.raw: + return rj + print(f"successfully invited {args.email} to your current team") @parser.command( argument("cluster_id", help="ID of cluster to add machine to", type=int), argument("machine_ids", help="machine id(s) to join cluster", type=int, nargs="+"), + description="[Beta] Add a machine to an existing cluster", usage="vastai join cluster CLUSTER_ID MACHINE_IDS", - help="Join Machine to Cluster", + help="[Beta] Add a machine to an existing cluster", epilog=deindent(""" Join's Machine to Vast Cluster """) @@ -3284,16 +3535,17 @@ def join__cluster(args: argparse.Namespace): r.raise_for_status() if args.raw: - return r + return r.json() - print(r.json()["msg"]) + print(r.json().get("msg", "Unknown error")) @parser.command( argument("name", help="Overlay network name to join instance to.", type=str), argument("instance_id", help="Instance ID to add to overlay.", type=int), + description="[Beta] Connect an instance to an overlay network", usage="vastai join overlay OVERLAY_NAME INSTANCE_ID", - help="Adds instance to an overlay network", + help="[Beta] Connect an instance to an overlay network", epilog=deindent(""" Adds an instance to a compatible overlay network.""") ) @@ -3311,15 +3563,16 @@ def join__overlay(args: argparse.Namespace): r.raise_for_status() if args.raw: - return r + return r.json() - print(r.json()["msg"]) + print(r.json().get("msg", "Unknown error")) @parser.command( argument("id", help="id of instance to label", type=int), argument("label", help="label to set", type=str), + description="Assign a string label to an instance", usage="vastai label instance