diff --git a/.env.test b/.env.test new file mode 100644 index 0000000..3bc1baa --- /dev/null +++ b/.env.test @@ -0,0 +1,6 @@ +# .env.test +DB_HOST=localhost +DB_PORT=5432 +DB_NAME=inventory_test +DB_USER=postgres +DB_PASSWORD=postgres \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5751128 --- /dev/null +++ b/.gitignore @@ -0,0 +1,31 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] + +# Virtual environments +.venv/ +venv/ +env/ + +# IDE / Editor +.vscode/ +.idea/ + +# Testing / coverage +.pytest_cache/ +.mypy_cache/ +.coverage +coverage.xml + +# Packaging / build +build/ +dist/ +*.egg-info/ + +# Environment / local config +# .env +# .env.* + +# OS files +.DS_Store + diff --git a/API-SPEC.md b/API-SPEC.md new file mode 100644 index 0000000..e5a0c55 --- /dev/null +++ b/API-SPEC.md @@ -0,0 +1,23 @@ +# HTTP API Spec + +Base URL: `http://localhost:8000` + +## Server object +- Fields: `id`, `hostname`, `ip_address`, `datacenter`, `state`, `created_at`, `updated_at` +- Allowed states: `active`, `offline`, `retired` +- Validation: hostname must be unique; IP must be valid IPv4/IPv6; state must be allowed. + +## Endpoints +- `POST /servers` — create + - Body: `{ "hostname": str, "ip_address": str, "datacenter": str, "state": str }` + - Responses: `201` with server; `409` duplicate hostname; `422` validation. +- `GET /servers` — list all + - Responses: `200` with `[]` when empty. +- `GET /servers/{id}` — fetch one + - Responses: `200` with server; `404` if missing. +- `PUT /servers/{id}` — update (partial allowed) + - Body: any subset of `hostname`, `ip_address`, `datacenter`, `state` + - Responses: `200` updated server; `404` if missing; `409` on hostname conflict; `422` on validation. +- `DELETE /servers/{id}` — delete + - Responses: `200` `{ "detail": "Server deleted" }`; `404` if missing. 
+ diff --git a/API.md b/API.md new file mode 100644 index 0000000..d9504ee --- /dev/null +++ b/API.md @@ -0,0 +1,65 @@ +# Inventory Guide + +How to run the stack locally and where to find detailed specs. It also shows how to test the stack using the CLI. + +## Run the stack +- **Prereqs:** Docker + Docker Compose, Python 3.12, pip. +- **Start everything (API + Postgres):** `docker-compose up --build`. API +listens on `http://localhost:8000`. +- Create a Python virtual environment for the CLI: + - `python -m venv .venv && source .venv/bin/activate` + - `pip install -r requirements.txt` +- **Default DB env vars:** `DB_HOST=localhost`, `DB_PORT=5432`, `DB_NAME=inventory`, `DB_USER=postgres`, `DB_PASSWORD=postgres`. +- **API docs:** `http://localhost:8000/docs` (Swagger) or `http://localhost:8000/redoc`. +- **Tests:** `pytest` (requires a running Postgres; the suite auto-creates an `inventory_test` database). + +## Specs +- HTTP API spec: see `API-SPEC.md`. +- CLI spec: see `CLI-SPEC.md`. + +## Example Testing Scenario Using CLI + +### Prepare the environment +```bash +# open terminal in the root of the repo +docker compose up +# open a new terminal or a new tab in the root of the repo +python -m venv .venv && source .venv/bin/activate +# alternatively use this command python3 -m venv .venv && source .venv/bin/activate +pip install -r requirements.txt +pytest tests/ +``` + +```bash +# Test basic commands +./cli.sh create srv-1 10.0.0.1 us-east active +./cli.sh list +./cli.sh get 1 +./cli.sh update 1 --state offline +./cli.sh update 1 --state retired +./cli.sh delete 1 --confirm + +# Test Validation +# Hostname unique +./cli.sh create srv-dup 10.0.0.2 us-east active +./cli.sh create srv-dup 10.0.0.3 us-east active +# IP must look like IP +./cli.sh create bad-ip not-an-ip us-east active +./cli.sh create bad-ip 1111.1.1.1 us-east active +# State must be active/offline/retired: +./cli.sh create bad-state 10.0.0.4 us-east retiring +# Update validation +./cli.sh update 2 --ip-address not-an-ip +./cli.sh update 2 --ip-address 1111.1.1.1 
+./cli.sh update 2 --state retiring +# Update hostname uniqueness: +./cli.sh create srv-a 10.0.0.5 us-east active +./cli.sh create srv-b 10.0.0.6 us-east active +# Then try to duplicate: +./cli.sh update 4 --hostname srv-a +``` + + +## Running Unit Tests + +To run all unit tests, run the command `pytest tests/`. \ No newline at end of file diff --git a/CLI-SPEC.md b/CLI-SPEC.md new file mode 100644 index 0000000..00b53a2 --- /dev/null +++ b/CLI-SPEC.md @@ -0,0 +1,22 @@ +# CLI Spec + +For ease of use there is a cli.sh script that runs the python cli. It replaces `python -m cli.main` with `./cli.sh ...` e.g. `python -m cli.main create srv-1 10.0.0.1 us-east active` is replaced by `./cli.sh create srv-1 10.0.0.1 us-east active` + +Entrypoint: `python -m cli.main ...` +Uses `API_BASE_URL` to target the API (default `http://localhost:8000`). Add `--help` to any command for usage. + +- `create <hostname> <ip_address> <datacenter> <state>` — create a server; prints JSON. +- `list` — list all servers in a table. +- `get <id>` — show one server as JSON. +- `update <id> [--hostname ...] [--ip-address ...] [--datacenter ...] [--state ...]` — partial update; requires at least one option. +- `delete <id> [--confirm]` — delete; prompts unless `--confirm` set. 
+ +## Common workflows +- Happy-path CRUD (with API running): + - `python -m cli.main create srv-1 10.0.0.1 us-east active` + - `python -m cli.main list` + - `python -m cli.main get 1` + - `python -m cli.main update 1 --state offline` + - `python -m cli.main delete 1 --confirm` +- Override API base URL: `API_BASE_URL=https://api.example.com python -m cli.main list` + diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/db/__init__.py b/app/db/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/db/connection.py b/app/db/connection.py new file mode 100644 index 0000000..7242d70 --- /dev/null +++ b/app/db/connection.py @@ -0,0 +1,76 @@ +import os +from psycopg2.pool import SimpleConnectionPool +from psycopg2.extensions import connection as Connection + + +def get_database_url() -> str: + """ + Construct database URL from environment variables. + Returns PostgreSQL connection string. + """ + db_host = os.getenv("DB_HOST", "localhost") + db_port = os.getenv("DB_PORT", "5432") + db_name = os.getenv("DB_NAME", "inventory") + db_user = os.getenv("DB_USER", "postgres") + db_password = os.getenv("DB_PASSWORD", "postgres") + + return f"postgresql://{db_user}:{db_password}@{db_host}:{db_port}/{db_name}" + + +def create_connection_pool() -> SimpleConnectionPool: + """ + Create and configure a database connection pool. + Returns a psycopg2 connection pool instance. + """ + database_url = get_database_url() + + # Create a connection pool with min 1 and max 10 connections + pool = SimpleConnectionPool( + minconn=1, + maxconn=10, + dsn=database_url + ) + + return pool + + +def init_db() -> None: + """ + Initialize the database schema. + Creates the servers table if it doesn't exist. + Should be called on application startup. 
+ """ + pool = create_connection_pool() + conn = pool.getconn() + + try: + with conn.cursor() as cursor: + # Create servers table with all required fields and constraints + cursor.execute(""" + CREATE TABLE IF NOT EXISTS servers ( + id SERIAL PRIMARY KEY, + hostname VARCHAR(255) UNIQUE NOT NULL, + ip_address VARCHAR(45) NOT NULL, + datacenter VARCHAR(255) NOT NULL, + state VARCHAR(50) NOT NULL, + created_at TIMESTAMP DEFAULT NOW(), + updated_at TIMESTAMP DEFAULT NOW() + ) + """) + + # Create an index on hostname for faster lookups + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_servers_hostname + ON servers(hostname) + """) + + # Create an index on state for filtering + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_servers_state + ON servers(state) + """) + + conn.commit() + finally: + pool.putconn(conn) + pool.closeall() \ No newline at end of file diff --git a/app/db/queries.py b/app/db/queries.py new file mode 100644 index 0000000..978e1f6 --- /dev/null +++ b/app/db/queries.py @@ -0,0 +1,158 @@ +from typing import List, Optional, Any +from psycopg2.extensions import connection as Connection +from datetime import datetime + + +def row_to_dict(cursor, row) -> dict: + """ + Convert a database row to a dictionary using cursor description. + """ + if row is None: + return None + + columns = [desc[0] for desc in cursor.description] + return dict(zip(columns, row)) + + +def insert_server(conn: Connection, hostname: str, ip_address: str, + datacenter: str, state: str) -> int: + """ + Execute SQL INSERT to create a new server record. + Returns the ID of the newly created server. 
+ """ + with conn.cursor() as cursor: + cursor.execute(""" + INSERT INTO servers (hostname, ip_address, datacenter, state) + VALUES (%s, %s, %s, %s) + RETURNING id + """, (hostname, ip_address, datacenter, state)) + + server_id = cursor.fetchone()[0] + conn.commit() + + return server_id + + +def select_all_servers(conn: Connection) -> List[dict]: + """ + Execute SQL SELECT to retrieve all server records. + Returns list of dictionaries representing server rows. + """ + # TODO: Add pagination support (LIMIT and OFFSET parameters) + with conn.cursor() as cursor: + cursor.execute(""" + SELECT id, hostname, ip_address, datacenter, state, + created_at, updated_at + FROM servers + ORDER BY id ASC + """) + + rows = cursor.fetchall() + return [row_to_dict(cursor, row) for row in rows] + + +def select_server_by_id(conn: Connection, server_id: int) -> Optional[dict]: + """ + Execute SQL SELECT to retrieve a single server by ID. + Returns dictionary representing the server row, or None if not found. + """ + with conn.cursor() as cursor: + cursor.execute(""" + SELECT id, hostname, ip_address, datacenter, state, + created_at, updated_at + FROM servers + WHERE id = %s + """, (server_id,)) + + row = cursor.fetchone() + + if row is None: + return None + + return row_to_dict(cursor, row) + + +def select_server_by_hostname(conn: Connection, hostname: str) -> Optional[dict]: + """ + Execute SQL SELECT to retrieve a server by hostname. + Used for hostname uniqueness validation. + Returns dictionary representing the server row, or None if not found. + """ + with conn.cursor() as cursor: + cursor.execute(""" + SELECT id, hostname, ip_address, datacenter, state, + created_at, updated_at + FROM servers + WHERE hostname = %s + """, (hostname,)) + + row = cursor.fetchone() + + if row is None: + return None + + return row_to_dict(cursor, row) + + +def update_server_by_id(conn: Connection, server_id: int, **fields) -> Optional[dict]: + """ + Execute SQL UPDATE to modify server fields. 
+ Dynamically builds UPDATE statement based on provided fields. + Returns the updated server record, or None if server not found. + """ + if not fields: + # If no fields provided, just return the current record + return select_server_by_id(conn, server_id) + + # Valid fields that can be updated + valid_fields = {'hostname', 'ip_address', 'datacenter', 'state'} + + # Filter out invalid fields + update_fields = {k: v for k, v in fields.items() if k in valid_fields} + + if not update_fields: + # If no valid fields to update, just return the current record + return select_server_by_id(conn, server_id) + + # Always update the updated_at timestamp + update_fields['updated_at'] = datetime.now() + + # Build the SET clause dynamically + set_clause = ', '.join([f"{field} = %s" for field in update_fields.keys()]) + values = list(update_fields.values()) + values.append(server_id) # Add server_id for WHERE clause + + with conn.cursor() as cursor: + cursor.execute(f""" + UPDATE servers + SET {set_clause} + WHERE id = %s + RETURNING id, hostname, ip_address, datacenter, state, + created_at, updated_at + """, values) + + row = cursor.fetchone() + conn.commit() + + if row is None: + return None + + return row_to_dict(cursor, row) + + +def delete_server_by_id(conn: Connection, server_id: int) -> bool: + """ + Execute SQL DELETE to remove a server record. + Returns True if server was deleted, False if not found. 
+ """ + with conn.cursor() as cursor: + cursor.execute(""" + DELETE FROM servers + WHERE id = %s + RETURNING id + """, (server_id,)) + + deleted_row = cursor.fetchone() + conn.commit() + + return deleted_row is not None \ No newline at end of file diff --git a/app/main.py b/app/main.py new file mode 100644 index 0000000..6dc540e --- /dev/null +++ b/app/main.py @@ -0,0 +1,232 @@ +from typing import Generator, List + +from fastapi import FastAPI, APIRouter, Depends, HTTPException, status +from psycopg2 import IntegrityError +from psycopg2.extensions import connection as Connection + +from app.db.connection import create_connection_pool, init_db +from app.db.queries import ( + insert_server, + select_all_servers, + select_server_by_id, + update_server_by_id, + delete_server_by_id, +) +from app.models.schemas import ServerCreate, ServerUpdate, ServerResponse +from app.validators import validate_hostname_unique + + +connection_pool = None # Will be initialized inside create_app + + +def create_app() -> FastAPI: + """ + Initialize and configure the FastAPI application. + Sets up middleware, exception handlers, and includes routers. + Returns the configured FastAPI instance. 
+ """ + global connection_pool + + app = FastAPI( + title="Inventory Management API", + version="1.0.0", + description="CRUD API for managing server inventory across datacenters.", + ) + + # Initialize database pool and schema on startup + connection_pool = create_connection_pool() + init_db() + + router = APIRouter(prefix="/servers", tags=["servers"]) + + @router.post( + "", + response_model=ServerResponse, + status_code=status.HTTP_201_CREATED, + ) + async def create_server_route( + server: ServerCreate, db: Connection = Depends(get_db_connection) + ): + return await create_server(server, db) + + @router.get("", response_model=List[ServerResponse]) + async def list_servers_route(db: Connection = Depends(get_db_connection)): + return await list_servers(db) + + @router.get("/{server_id}", response_model=ServerResponse) + async def get_server_route( + server_id: int, db: Connection = Depends(get_db_connection) + ): + return await get_server(server_id, db) + + @router.put("/{server_id}", response_model=ServerResponse) + async def update_server_route( + server_id: int, server: ServerUpdate, db: Connection = Depends(get_db_connection) + ): + return await update_server(server_id, server, db) + + @router.delete("/{server_id}") + async def delete_server_route( + server_id: int, db: Connection = Depends(get_db_connection) + ): + return await delete_server(server_id, db) + + app.include_router(router) + + @app.on_event("shutdown") + def shutdown_pool() -> None: + if connection_pool: + connection_pool.closeall() + + return app + + +# app/api/routes/servers.py + +async def create_server(server: ServerCreate, db: Connection) -> ServerResponse: + """ + Create a new server in the database. + + Validates: + - hostname uniqueness + - IP address format + - state is one of: active, offline, retired + + Returns the created server with generated ID. + Raises HTTPException 400 for validation errors, 409 for duplicate hostname. 
+ """ + # Validation handled by Pydantic model; duplicate hostname handled here + if not validate_hostname_unique(db, server.hostname): + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Hostname already exists", + ) + + try: + server_id = insert_server( + db, + server.hostname, + server.ip_address, + server.datacenter, + server.state, + ) + except IntegrityError: + # Defensive: catch race conditions on hostname uniqueness + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Hostname already exists", + ) + + created = select_server_by_id(db, server_id) + return ServerResponse(**created) + + +async def list_servers(db: Connection) -> List[ServerResponse]: + """ + Retrieve all servers from the database. + Returns a list of all server records. + """ + servers = select_all_servers(db) + return [ServerResponse(**server) for server in servers] + + +async def get_server(server_id: int, db: Connection) -> ServerResponse: + """ + Retrieve a single server by ID. + Raises HTTPException 404 if server not found. + """ + server = select_server_by_id(db, server_id) + if not server: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Server not found", + ) + return ServerResponse(**server) + + +async def update_server(server_id: int, server: ServerUpdate, db: Connection) -> ServerResponse: + """ + Update an existing server's information. + + Validates: + - hostname uniqueness (if hostname is being changed) + - IP address format (if IP is being changed) + - state is valid (if state is being changed) + + Returns the updated server. + Raises HTTPException 404 if server not found, 400 for validation errors, 409 for duplicate hostname. 
+ """ + current = select_server_by_id(db, server_id) + if not current: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Server not found", + ) + + update_data = server.model_dump(exclude_unset=True) + + if "hostname" in update_data and not validate_hostname_unique( + db, update_data["hostname"], exclude_id=server_id + ): + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Hostname already exists", + ) + + try: + updated = update_server_by_id(db, server_id, **update_data) + except IntegrityError: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Hostname already exists", + ) + + if not updated: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Server not found", + ) + + return ServerResponse(**updated) + + +async def delete_server(server_id: int, db: Connection) -> dict: + """ + Delete a server from the database. + Returns a success message. + Raises HTTPException 404 if server not found. + """ + deleted = delete_server_by_id(db, server_id) + if not deleted: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Server not found", + ) + return {"detail": "Server deleted"} + + +# app/api/dependencies.py + +async def get_db_connection() -> Generator[Connection, None, None]: + """ + Dependency that provides a database connection. + Handles connection creation, transaction management, and cleanup. + Yields a psycopg2 connection object. + """ + if connection_pool is None: + # Lazily initialize if create_app hasn't been called (e.g. 
during tests) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Database connection pool not initialized", + ) + + conn = connection_pool.getconn() + try: + yield conn + finally: + conn.rollback() + connection_pool.putconn(conn) + + +# Instantiate the FastAPI application for ASGI servers +app = create_app() \ No newline at end of file diff --git a/app/models/schemas.py b/app/models/schemas.py new file mode 100644 index 0000000..83c0f5c --- /dev/null +++ b/app/models/schemas.py @@ -0,0 +1,78 @@ +from datetime import datetime +from typing import Optional + +from pydantic import BaseModel, Field, ConfigDict, field_validator + +from app.validators import ( + ALLOWED_STATES, + validate_ip_address as is_valid_ip, + validate_state as is_valid_state, +) + + +class ServerBase(BaseModel): + """ + Base Pydantic model with common server fields. + """ + + hostname: str = Field(..., min_length=1, max_length=255) + ip_address: str = Field(..., min_length=7, max_length=45) + datacenter: str = Field(..., min_length=1, max_length=255) + state: str = Field(..., description="Server lifecycle state") + + @field_validator("state") + def validate_state(cls, value: str) -> str: + if not is_valid_state(value): + raise ValueError(f"Invalid state '{value}'. Allowed: {sorted(ALLOWED_STATES)}") + return value + + @field_validator("ip_address") + def validate_ip(cls, value: str) -> str: + if not is_valid_ip(value): + raise ValueError("Invalid IP address format") + return value + + +class ServerCreate(ServerBase): + """ + Pydantic model for server creation requests. + """ + + +class ServerUpdate(BaseModel): + """ + Pydantic model for server update requests. + All fields are optional to allow partial updates. 
+ """ + + hostname: Optional[str] = Field(None, min_length=1, max_length=255) + ip_address: Optional[str] = Field(None, min_length=7, max_length=45) + datacenter: Optional[str] = Field(None, min_length=1, max_length=255) + state: Optional[str] = Field(None, description="Server lifecycle state") + + @field_validator("state") + def validate_state(cls, value: Optional[str]) -> Optional[str]: + if value is not None and not is_valid_state(value): + raise ValueError(f"Invalid state '{value}'. Allowed: {sorted(ALLOWED_STATES)}") + return value + + @field_validator("ip_address") + def validate_ip(cls, value: Optional[str]) -> Optional[str]: + if value is None: + return value + if not is_valid_ip(value): + raise ValueError("Invalid IP address format") + return value + + +class ServerResponse(ServerBase): + """ + Pydantic model for server responses. + Includes all fields plus id and timestamps. + """ + + id: int + created_at: datetime + updated_at: datetime + + model_config = ConfigDict(from_attributes=True) \ No newline at end of file diff --git a/app/validators.py b/app/validators.py new file mode 100644 index 0000000..3f5ca65 --- /dev/null +++ b/app/validators.py @@ -0,0 +1,45 @@ +from typing import Optional, Set +import ipaddress + +from psycopg2.extensions import connection as Connection + +from app.db.queries import select_server_by_hostname + +# Allowed lifecycle states for servers +ALLOWED_STATES: Set[str] = {"active", "offline", "retired"} + + +def validate_ip_address(ip_address: str) -> bool: + """ + Validate that a string is a valid IPv4 or IPv6 address. + Returns True if valid, False otherwise. + """ + try: + ipaddress.ip_address(ip_address) + return True + except ValueError: + return False + + +def validate_state(state: str) -> bool: + """ + Validate that state is one of the allowed values. + Returns True if valid, False otherwise. 
+ """ + return state in ALLOWED_STATES + + +def validate_hostname_unique( + conn: Connection, hostname: str, exclude_id: Optional[int] = None +) -> bool: + """ + Check if hostname is unique in the database. + exclude_id parameter is used during updates to exclude the current server. + Returns True if unique, False if duplicate exists. + """ + existing = select_server_by_hostname(conn, hostname) + if not existing: + return True + if exclude_id is not None and existing["id"] == exclude_id: + return True + return False \ No newline at end of file diff --git a/cli.sh b/cli.sh new file mode 100755 index 0000000..37b3cc4 --- /dev/null +++ b/cli.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Simple wrapper so you can call the CLI without remembering `python -m`. +# It ensures the project root is on PYTHONPATH, then delegates to the Typer app. + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$SCRIPT_DIR" + +# Prepend repo root to PYTHONPATH if it's not already there +case ":${PYTHONPATH-}:" in + *":$REPO_ROOT:"*) ;; + *) export PYTHONPATH="$REPO_ROOT${PYTHONPATH:+:$PYTHONPATH}" ;; +esac + +python -m cli.main "$@" + diff --git a/cli/main.py b/cli/main.py new file mode 100644 index 0000000..115a171 --- /dev/null +++ b/cli/main.py @@ -0,0 +1,220 @@ +import json +import os +from typing import List, Optional + +import httpx +import typer + + +def create_cli_app() -> typer.Typer: + """ + Initialize and configure the Typer CLI application. + Registers all command groups and sets up configuration. + Returns the configured Typer instance. 
+ """ + app = typer.Typer(help="CLI for the Inventory Management API") + + app.command("create")(create_server_command) + app.command("list")(list_servers_command) + app.command("get")(get_server_command) + app.command("update")(update_server_command) + app.command("delete")(delete_server_command) + + return app + + +def get_api_base_url() -> str: + """ + Get the API base URL from environment variable or use default. + Returns the base URL for making API requests. + """ + base = os.getenv("API_BASE_URL", "http://localhost:8000") + return base.rstrip("/") + + +# cli/commands.py + +def create_server_command( + hostname: str = typer.Argument(..., help="Unique hostname"), + ip_address: str = typer.Argument(..., help="IPv4/IPv6 address"), + datacenter: str = typer.Argument(..., help="Datacenter identifier"), + state: str = typer.Argument(..., help="Server state (active/offline/retired)"), +) -> None: + """ + CLI command to create a new server. + Makes POST request to /servers endpoint. + Displays the created server or error message. + """ + payload = { + "hostname": hostname, + "ip_address": ip_address, + "datacenter": datacenter, + "state": state, + } + + with httpx.Client(base_url=get_api_base_url(), timeout=5.0) as client: + try: + resp = client.post("/servers", json=payload) + except httpx.RequestError as exc: + typer.echo(f"Request failed: {exc}") + raise typer.Exit(code=1) + + if resp.is_success: + typer.echo(json.dumps(resp.json(), indent=2)) + else: + handle_api_error(resp) + + +def list_servers_command() -> None: + """ + CLI command to list all servers. + Makes GET request to /servers endpoint. + Displays servers in a formatted table. 
+ """ + with httpx.Client(base_url=get_api_base_url(), timeout=5.0) as client: + try: + resp = client.get("/servers") + except httpx.RequestError as exc: + typer.echo(f"Request failed: {exc}") + raise typer.Exit(code=1) + + if resp.is_success: + format_server_table(resp.json()) + else: + handle_api_error(resp) + + +def get_server_command(server_id: int = typer.Argument(..., help="Server ID")) -> None: + """ + CLI command to retrieve a single server by ID. + Makes GET request to /servers/{id} endpoint. + Displays the server details or error message. + """ + with httpx.Client(base_url=get_api_base_url(), timeout=5.0) as client: + try: + resp = client.get(f"/servers/{server_id}") + except httpx.RequestError as exc: + typer.echo(f"Request failed: {exc}") + raise typer.Exit(code=1) + + if resp.is_success: + typer.echo(json.dumps(resp.json(), indent=2)) + else: + handle_api_error(resp) + + +def update_server_command( + server_id: int = typer.Argument(..., help="Server ID"), + hostname: Optional[str] = typer.Option(None, "--hostname", help="New hostname"), + ip_address: Optional[str] = typer.Option(None, "--ip-address", help="New IP address"), + datacenter: Optional[str] = typer.Option(None, "--datacenter", help="New datacenter"), + state: Optional[str] = typer.Option(None, "--state", help="New state (active/offline/retired)"), +) -> None: + """ + CLI command to update a server. + Builds update payload from provided arguments. + Makes PUT request to /servers/{id} endpoint. + Displays the updated server or error message. 
+ """ + payload = { + "hostname": hostname, + "ip_address": ip_address, + "datacenter": datacenter, + "state": state, + } + payload = {k: v for k, v in payload.items() if v is not None} + + if not payload: + typer.echo("No fields provided to update.") + raise typer.Exit(code=1) + + with httpx.Client(base_url=get_api_base_url(), timeout=5.0) as client: + try: + resp = client.put(f"/servers/{server_id}", json=payload) + except httpx.RequestError as exc: + typer.echo(f"Request failed: {exc}") + raise typer.Exit(code=1) + + if resp.is_success: + typer.echo(json.dumps(resp.json(), indent=2)) + else: + handle_api_error(resp) + + +def delete_server_command( + server_id: int = typer.Argument(..., help="Server ID"), + confirm: bool = typer.Option(False, "--confirm", help="Skip confirmation prompt", is_flag=True), +) -> None: + """ + CLI command to delete a server. + Optionally prompts for confirmation if --confirm flag not provided. + Makes DELETE request to /servers/{id} endpoint. + Displays success or error message. + """ + if not confirm: + confirmed = typer.confirm(f"Delete server {server_id}?") + if not confirmed: + typer.echo("Cancelled.") + raise typer.Exit() + + with httpx.Client(base_url=get_api_base_url(), timeout=5.0) as client: + try: + resp = client.delete(f"/servers/{server_id}") + except httpx.RequestError as exc: + typer.echo(f"Request failed: {exc}") + raise typer.Exit(code=1) + + if resp.is_success: + typer.echo(resp.json().get("detail", "Deleted")) + else: + handle_api_error(resp) + + +# cli/utils.py + +def format_server_table(servers: List[dict]) -> None: + """ + Format and display a list of servers as a table. + Uses simple fixed-width columns to avoid extra dependencies. 
+ """ + if not servers: + typer.echo("No servers found.") + return + + headers = ["id", "hostname", "ip_address", "datacenter", "state"] + col_widths = {h: max(len(h), max(len(str(item.get(h, ""))) for item in servers)) for h in headers} + + def _row(values): + return " ".join(str(values[h]).ljust(col_widths[h]) for h in headers) + + typer.echo(_row({h: h for h in headers})) + typer.echo("-" * (sum(col_widths.values()) + 8)) + for item in servers: + typer.echo(_row(item)) + + +def handle_api_error(response: httpx.Response) -> None: + """ + Parse and display API error responses in a user-friendly format. + Extracts error details from response and formats for CLI output. + """ + try: + data = response.json() + except ValueError: + typer.echo(f"Error {response.status_code}: {response.text}") + raise typer.Exit(code=1) + + detail = data.get("detail") if isinstance(data, dict) else None + if detail: + typer.echo(f"Error {response.status_code}: {detail}") + else: + typer.echo(f"Error {response.status_code}: {json.dumps(data, indent=2)}") + + raise typer.Exit(code=1) + + +app = create_cli_app() + + +if __name__ == "__main__": + app() \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..36cb53c --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,46 @@ +version: "3.9" + +services: + db: + image: postgres:16 + container_name: ben-postgres + environment: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: inventory + ports: + - "5432:5432" + volumes: + - db_data:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U postgres"] + interval: 5s + timeout: 5s + retries: 5 + start_period: 10s + + api: + image: python:3.12-slim + container_name: ben-inventory-api + working_dir: /app + depends_on: + db: + condition: service_healthy + environment: + DB_HOST: db + DB_PORT: 5432 + DB_NAME: inventory + DB_USER: postgres + DB_PASSWORD: postgres + PYTHONUNBUFFERED: "1" + volumes: + - 
./:/app + command: > + sh -c "pip install --no-cache-dir -r requirements.txt && + uvicorn app.main:app --host 0.0.0.0 --port 8000" + ports: + - "8000:8000" + +volumes: + db_data: + diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..55300ca --- /dev/null +++ b/pytest.ini @@ -0,0 +1,8 @@ +# pytest.ini +[pytest] +testpaths = tests +pythonpath = . +python_files = test_*.py +python_classes = Test* +python_functions = test_* +addopts = -v --tb=short \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..e9dc4ec --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +# requirements.txt +psycopg2-binary==2.9.10 +pytest==8.3.4 +pytest-cov==6.0.0 +pytest-sugar==1.0.0 +pydantic==2.12.5 +fastapi==0.124.4 +httpx==0.28.1 +typer==0.12.5 +click==8.1.7 +uvicorn==0.32.0 \ No newline at end of file diff --git a/tests/db/__init__.py b/tests/db/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/db/conftest.py b/tests/db/conftest.py new file mode 100644 index 0000000..ce56ab0 --- /dev/null +++ b/tests/db/conftest.py @@ -0,0 +1,157 @@ +# tests/db/conftest.py +import pytest +import os +import psycopg2 +from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT +from psycopg2.pool import SimpleConnectionPool + + +def create_test_database_if_not_exists(): + """Create test database if it doesn't exist.""" + db_host = os.getenv("TEST_DB_HOST", "localhost") + db_port = os.getenv("TEST_DB_PORT", "5432") + db_name = os.getenv("TEST_DB_NAME", "inventory_test") + db_user = os.getenv("TEST_DB_USER", "postgres") + db_password = os.getenv("TEST_DB_PASSWORD", "postgres") + + try: + # Connect to default postgres database + conn = psycopg2.connect( + host=db_host, + port=db_port, + user=db_user, + password=db_password, + database="postgres" + ) + conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) + cursor = conn.cursor() + + # Check if database exists + cursor.execute( + "SELECT 1 FROM pg_database 
@pytest.fixture(scope="session")
def test_db_pool(setup_test_database):
    """
    Session-scoped connection pool against the test database.

    Points the app's DB_* environment variables at the TEST_DB_* values
    (falling back to local defaults) before the pool is built, yields the
    pool to the whole session, and closes every connection afterwards.
    """
    env_mapping = {
        "DB_HOST": ("TEST_DB_HOST", "localhost"),
        "DB_PORT": ("TEST_DB_PORT", "5432"),
        "DB_NAME": ("TEST_DB_NAME", "inventory_test"),
        "DB_USER": ("TEST_DB_USER", "postgres"),
        "DB_PASSWORD": ("TEST_DB_PASSWORD", "postgres"),
    }
    for target, (source, fallback) in env_mapping.items():
        os.environ[target] = os.getenv(source, fallback)

    # Imported lazily so the environment variables above take effect
    # before the app reads its DB configuration.
    from app.db.connection import create_connection_pool

    pool = create_connection_pool()

    yield pool

    # Teardown: release every pooled connection.
    pool.closeall()
@pytest.fixture
def db_conn(test_db_pool, init_test_db):
    """
    Provide a database connection for each test.

    Function-scoped: each test checks a connection out of the session
    pool and returns it afterwards.  Teardown rolls back any open
    transaction and truncates the servers table (resetting the id
    sequence via RESTART IDENTITY) so tests never see each other's rows.
    """
    conn = test_db_pool.getconn()

    yield conn

    # NOTE(review): a stale comment here previously claimed cleanup was
    # "disabled temporarily to inspect DB state" — the rollback and
    # truncation below are in fact active.  Comment corrected to match
    # the code.
    conn.rollback()
    with conn.cursor() as cursor:
        cursor.execute("TRUNCATE TABLE servers RESTART IDENTITY CASCADE")
    conn.commit()

    test_db_pool.putconn(conn)
+ """ + required_keys = {"DB_HOST", "DB_PORT", "DB_NAME", "DB_USER", "DB_PASSWORD"} + env_path = Path(__file__).resolve().parents[2] / env_filename + if not env_path.exists(): + raise FileNotFoundError(f"Env file not found: {env_path}") + + env_vars: Dict[str, str] = {} + with env_path.open() as env_file: + for raw_line in env_file: + line = raw_line.strip() + if not line or line.startswith("#"): + continue + + if line.startswith("export "): + line = line[len("export ") :] + + key, sep, value = line.partition("=") + if not sep: + continue + + cleaned_value = value.strip().strip('"').strip("'") + env_vars[key.strip()] = cleaned_value + os.environ[key.strip()] = cleaned_value + + missing = required_keys - set(env_vars) + if missing: + raise KeyError(f"Missing required env keys in {env_path}: {sorted(missing)}") + + return env_vars + + +TEST_ENV_VARS = load_test_env() + + +def reset_test_env() -> Dict[str, str]: + """Reload test env vars into os.environ and update the cached dict.""" + global TEST_ENV_VARS + TEST_ENV_VARS = load_test_env() + return TEST_ENV_VARS + + +def test_get_database_url(): + """Test that database URL is constructed correctly from environment variables.""" + expected_url = ( + f"postgresql://{TEST_ENV_VARS['DB_USER']}:{TEST_ENV_VARS['DB_PASSWORD']}" + f"@{TEST_ENV_VARS['DB_HOST']}:{TEST_ENV_VARS['DB_PORT']}/{TEST_ENV_VARS['DB_NAME']}" + ) + actual_url = get_database_url() + + assert actual_url == expected_url + + +def test_get_database_url_with_defaults(): + """Test that database URL uses default values when env vars are not set.""" + # Clear environment variables + for key in ["DB_HOST", "DB_PORT", "DB_NAME", "DB_USER", "DB_PASSWORD"]: + if key in os.environ: + del os.environ[key] + + expected_url = "postgresql://postgres:postgres@localhost:5432/inventory" + actual_url = get_database_url() + + assert actual_url == expected_url + + # Restore test environment for subsequent tests + reset_test_env() + + +def test_create_connection_pool(): + """Test 
that connection pool is created successfully.""" + pool = create_connection_pool() + + assert isinstance(pool, SimpleConnectionPool) + assert pool.minconn == 1 + assert pool.maxconn == 10 + + # Test that we can get a connection from the pool + conn = pool.getconn() + assert conn is not None + + # Return connection and close pool + pool.putconn(conn) + pool.closeall() + + +def test_init_db(test_db_pool): + """Test that database initialization creates the servers table.""" + init_db() + + conn = test_db_pool.getconn() + + try: + with conn.cursor() as cursor: + # Check that servers table exists + cursor.execute(""" + SELECT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_name = 'servers' + ) + """) + + table_exists = cursor.fetchone()[0] + assert table_exists is True + + # Check that the table has the correct columns + cursor.execute(""" + SELECT column_name, data_type, is_nullable + FROM information_schema.columns + WHERE table_name = 'servers' + ORDER BY ordinal_position + """) + + columns = cursor.fetchall() + column_names = [col[0] for col in columns] + + expected_columns = [ + 'id', 'hostname', 'ip_address', 'datacenter', + 'state', 'created_at', 'updated_at' + ] + + assert column_names == expected_columns + + # Check that hostname has unique constraint + cursor.execute(""" + SELECT constraint_name, constraint_type + FROM information_schema.table_constraints + WHERE table_name = 'servers' AND constraint_type = 'UNIQUE' + """) + + unique_constraints = cursor.fetchall() + assert len(unique_constraints) > 0 + + finally: + test_db_pool.putconn(conn) \ No newline at end of file diff --git a/tests/db/test_queries.py b/tests/db/test_queries.py new file mode 100644 index 0000000..03e1c7c --- /dev/null +++ b/tests/db/test_queries.py @@ -0,0 +1,449 @@ +import pytest +from datetime import datetime +from psycopg2 import IntegrityError +from app.db.queries import ( + insert_server, + select_all_servers, + select_server_by_id, + select_server_by_hostname, + 
update_server_by_id, + delete_server_by_id, + row_to_dict +) + + +class TestInsertServer: + """Tests for insert_server function.""" + + def test_insert_server_success(self, db_conn, sample_server_data): + """Test successful server insertion.""" + server_id = insert_server( + db_conn, + sample_server_data["hostname"], + sample_server_data["ip_address"], + sample_server_data["datacenter"], + sample_server_data["state"] + ) + + assert isinstance(server_id, int) + assert server_id > 0 + + def test_insert_server_duplicate_hostname(self, db_conn, sample_server_data): + """Test that inserting duplicate hostname raises IntegrityError.""" + # Insert first server + insert_server( + db_conn, + sample_server_data["hostname"], + sample_server_data["ip_address"], + sample_server_data["datacenter"], + sample_server_data["state"] + ) + + # Try to insert server with same hostname + with pytest.raises(IntegrityError): + insert_server( + db_conn, + sample_server_data["hostname"], # Same hostname + "192.168.1.101", # Different IP + sample_server_data["datacenter"], + sample_server_data["state"] + ) + + def test_insert_server_creates_timestamps(self, db_conn, sample_server_data): + """Test that created_at and updated_at are set automatically.""" + server_id = insert_server( + db_conn, + sample_server_data["hostname"], + sample_server_data["ip_address"], + sample_server_data["datacenter"], + sample_server_data["state"] + ) + + server = select_server_by_id(db_conn, server_id) + + assert server["created_at"] is not None + assert server["updated_at"] is not None + assert isinstance(server["created_at"], datetime) + assert isinstance(server["updated_at"], datetime) + + +class TestSelectAllServers: + """Tests for select_all_servers function.""" + + def test_select_all_servers_empty(self, db_conn): + """Test selecting all servers when table is empty.""" + servers = select_all_servers(db_conn) + + assert isinstance(servers, list) + assert len(servers) == 0 + + def 
test_select_all_servers_single(self, db_conn, sample_server_data): + """Test selecting all servers with one server.""" + insert_server( + db_conn, + sample_server_data["hostname"], + sample_server_data["ip_address"], + sample_server_data["datacenter"], + sample_server_data["state"] + ) + + servers = select_all_servers(db_conn) + + assert len(servers) == 1 + assert servers[0]["hostname"] == sample_server_data["hostname"] + assert servers[0]["ip_address"] == sample_server_data["ip_address"] + + def test_select_all_servers_multiple(self, db_conn, multiple_servers_data): + """Test selecting all servers with multiple servers.""" + for server_data in multiple_servers_data: + insert_server( + db_conn, + server_data["hostname"], + server_data["ip_address"], + server_data["datacenter"], + server_data["state"] + ) + + servers = select_all_servers(db_conn) + + assert len(servers) == len(multiple_servers_data) + + # Verify all servers are present + hostnames = [s["hostname"] for s in servers] + expected_hostnames = [s["hostname"] for s in multiple_servers_data] + + assert set(hostnames) == set(expected_hostnames) + + def test_select_all_servers_ordered_by_id(self, db_conn, multiple_servers_data): + """Test that servers are returned ordered by id.""" + inserted_ids = [] + for server_data in multiple_servers_data: + server_id = insert_server( + db_conn, + server_data["hostname"], + server_data["ip_address"], + server_data["datacenter"], + server_data["state"] + ) + inserted_ids.append(server_id) + + servers = select_all_servers(db_conn) + returned_ids = [s["id"] for s in servers] + + assert returned_ids == sorted(inserted_ids) + + +class TestSelectServerById: + """Tests for select_server_by_id function.""" + + def test_select_server_by_id_success(self, db_conn, sample_server_data): + """Test successful server selection by ID.""" + server_id = insert_server( + db_conn, + sample_server_data["hostname"], + sample_server_data["ip_address"], + sample_server_data["datacenter"], + 
sample_server_data["state"] + ) + + server = select_server_by_id(db_conn, server_id) + + assert server is not None + assert server["id"] == server_id + assert server["hostname"] == sample_server_data["hostname"] + assert server["ip_address"] == sample_server_data["ip_address"] + assert server["datacenter"] == sample_server_data["datacenter"] + assert server["state"] == sample_server_data["state"] + + def test_select_server_by_id_not_found(self, db_conn): + """Test selecting server with non-existent ID.""" + server = select_server_by_id(db_conn, 99999) + + assert server is None + + def test_select_server_by_id_returns_all_fields(self, db_conn, sample_server_data): + """Test that all fields are returned.""" + server_id = insert_server( + db_conn, + sample_server_data["hostname"], + sample_server_data["ip_address"], + sample_server_data["datacenter"], + sample_server_data["state"] + ) + + server = select_server_by_id(db_conn, server_id) + + expected_fields = [ + "id", "hostname", "ip_address", "datacenter", + "state", "created_at", "updated_at" + ] + + for field in expected_fields: + assert field in server + + +class TestSelectServerByHostname: + """Tests for select_server_by_hostname function.""" + + def test_select_server_by_hostname_success(self, db_conn, sample_server_data): + """Test successful server selection by hostname.""" + server_id = insert_server( + db_conn, + sample_server_data["hostname"], + sample_server_data["ip_address"], + sample_server_data["datacenter"], + sample_server_data["state"] + ) + + server = select_server_by_hostname(db_conn, sample_server_data["hostname"]) + + assert server is not None + assert server["id"] == server_id + assert server["hostname"] == sample_server_data["hostname"] + + def test_select_server_by_hostname_not_found(self, db_conn): + """Test selecting server with non-existent hostname.""" + server = select_server_by_hostname(db_conn, "non-existent-hostname") + + assert server is None + + def 
test_select_server_by_hostname_case_sensitive(self, db_conn, sample_server_data): + """Test that hostname search is case-sensitive.""" + insert_server( + db_conn, + sample_server_data["hostname"], + sample_server_data["ip_address"], + sample_server_data["datacenter"], + sample_server_data["state"] + ) + + # Try with different case + server = select_server_by_hostname( + db_conn, + sample_server_data["hostname"].upper() + ) + + # This should not find the server (case-sensitive) + assert server is None + + +class TestUpdateServerById: + """Tests for update_server_by_id function.""" + + def test_update_server_single_field(self, db_conn, sample_server_data): + """Test updating a single field.""" + server_id = insert_server( + db_conn, + sample_server_data["hostname"], + sample_server_data["ip_address"], + sample_server_data["datacenter"], + sample_server_data["state"] + ) + + updated_server = update_server_by_id(db_conn, server_id, state="offline") + + assert updated_server is not None + assert updated_server["state"] == "offline" + assert updated_server["hostname"] == sample_server_data["hostname"] + assert updated_server["ip_address"] == sample_server_data["ip_address"] + + def test_update_server_multiple_fields(self, db_conn, sample_server_data): + """Test updating multiple fields.""" + server_id = insert_server( + db_conn, + sample_server_data["hostname"], + sample_server_data["ip_address"], + sample_server_data["datacenter"], + sample_server_data["state"] + ) + + updated_server = update_server_by_id( + db_conn, + server_id, + state="retired", + datacenter="EU-WEST", + ip_address="10.0.0.50" + ) + + assert updated_server is not None + assert updated_server["state"] == "retired" + assert updated_server["datacenter"] == "EU-WEST" + assert updated_server["ip_address"] == "10.0.0.50" + assert updated_server["hostname"] == sample_server_data["hostname"] + + def test_update_server_updates_timestamp(self, db_conn, sample_server_data): + """Test that updated_at timestamp 
is modified.""" + server_id = insert_server( + db_conn, + sample_server_data["hostname"], + sample_server_data["ip_address"], + sample_server_data["datacenter"], + sample_server_data["state"] + ) + + original_server = select_server_by_id(db_conn, server_id) + original_updated_at = original_server["updated_at"] + + # Small delay to ensure timestamp difference + import time + time.sleep(0.1) + + updated_server = update_server_by_id(db_conn, server_id, state="offline") + + assert updated_server["updated_at"] > original_updated_at + + def test_update_server_not_found(self, db_conn): + """Test updating non-existent server.""" + updated_server = update_server_by_id(db_conn, 99999, state="offline") + + assert updated_server is None + + def test_update_server_no_fields(self, db_conn, sample_server_data): + """Test update with no fields returns current record.""" + server_id = insert_server( + db_conn, + sample_server_data["hostname"], + sample_server_data["ip_address"], + sample_server_data["datacenter"], + sample_server_data["state"] + ) + + server = update_server_by_id(db_conn, server_id) + + assert server is not None + assert server["id"] == server_id + + def test_update_server_invalid_fields_ignored(self, db_conn, sample_server_data): + """Test that invalid fields are ignored.""" + server_id = insert_server( + db_conn, + sample_server_data["hostname"], + sample_server_data["ip_address"], + sample_server_data["datacenter"], + sample_server_data["state"] + ) + + # Try to update with invalid field + updated_server = update_server_by_id( + db_conn, + server_id, + state="offline", + invalid_field="should_be_ignored" + ) + + assert updated_server is not None + assert updated_server["state"] == "offline" + assert "invalid_field" not in updated_server + + def test_update_server_duplicate_hostname(self, db_conn, multiple_servers_data): + """Test that updating to duplicate hostname raises IntegrityError.""" + # Insert two servers + server_id_1 = insert_server( + db_conn, + 
multiple_servers_data[0]["hostname"], + multiple_servers_data[0]["ip_address"], + multiple_servers_data[0]["datacenter"], + multiple_servers_data[0]["state"] + ) + + server_id_2 = insert_server( + db_conn, + multiple_servers_data[1]["hostname"], + multiple_servers_data[1]["ip_address"], + multiple_servers_data[1]["datacenter"], + multiple_servers_data[1]["state"] + ) + + # Try to update server_id_2 with hostname of server_id_1 + with pytest.raises(IntegrityError): + update_server_by_id( + db_conn, + server_id_2, + hostname=multiple_servers_data[0]["hostname"] + ) + + +class TestDeleteServerById: + """Tests for delete_server_by_id function.""" + + def test_delete_server_success(self, db_conn, sample_server_data): + """Test successful server deletion.""" + server_id = insert_server( + db_conn, + sample_server_data["hostname"], + sample_server_data["ip_address"], + sample_server_data["datacenter"], + sample_server_data["state"] + ) + + result = delete_server_by_id(db_conn, server_id) + + assert result is True + + # Verify server is deleted + server = select_server_by_id(db_conn, server_id) + assert server is None + + def test_delete_server_not_found(self, db_conn): + """Test deleting non-existent server.""" + result = delete_server_by_id(db_conn, 99999) + + assert result is False + + def test_delete_server_removes_from_list(self, db_conn, multiple_servers_data): + """Test that deleted server is removed from select_all_servers.""" + # Insert multiple servers + server_ids = [] + for server_data in multiple_servers_data: + server_id = insert_server( + db_conn, + server_data["hostname"], + server_data["ip_address"], + server_data["datacenter"], + server_data["state"] + ) + server_ids.append(server_id) + + # Delete one server + delete_server_by_id(db_conn, server_ids[0]) + + # Check remaining servers + servers = select_all_servers(db_conn) + + assert len(servers) == len(multiple_servers_data) - 1 + + returned_ids = [s["id"] for s in servers] + assert server_ids[0] not in 
returned_ids + + +class TestRowToDict: + """Tests for row_to_dict helper function.""" + + def test_row_to_dict_converts_correctly(self, db_conn, sample_server_data): + """Test that row_to_dict converts rows to dictionaries correctly.""" + server_id = insert_server( + db_conn, + sample_server_data["hostname"], + sample_server_data["ip_address"], + sample_server_data["datacenter"], + sample_server_data["state"] + ) + + with db_conn.cursor() as cursor: + cursor.execute("SELECT * FROM servers WHERE id = %s", (server_id,)) + row = cursor.fetchone() + result = row_to_dict(cursor, row) + + assert isinstance(result, dict) + assert "id" in result + assert "hostname" in result + assert result["id"] == server_id + + def test_row_to_dict_none_input(self, db_conn): + """Test that row_to_dict handles None input.""" + with db_conn.cursor() as cursor: + cursor.execute("SELECT * FROM servers WHERE id = %s", (99999,)) + row = cursor.fetchone() + result = row_to_dict(cursor, row) + + assert result is None \ No newline at end of file diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..2686db2 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,321 @@ +import importlib +from datetime import datetime, timedelta +from typing import Any, Dict + +import pytest +from fastapi.testclient import TestClient +from psycopg2 import IntegrityError + + +def _now_iso() -> str: + return datetime.utcnow().isoformat() + + +class FakeConnection: + def rollback(self) -> None: + return None + + +class FakePool: + def __init__(self) -> None: + self._closed = False + self._connections = [] + + def getconn(self) -> Any: + conn = FakeConnection() + self._connections.append(conn) + return conn + + def putconn(self, conn: Any) -> None: + if conn in self._connections: + self._connections.remove(conn) + + def closeall(self) -> None: + self._closed = True + + +@pytest.fixture +def api_client(monkeypatch): + """ + Provide a TestClient with all DB interactions stubbed to avoid real 
IO. + Each test gets isolated in-memory state. + """ + + # Track create_connection_pool calls + call_counts = {"create_pool": 0} + fake_pool = FakePool() + + def fake_create_pool(): + call_counts["create_pool"] += 1 + return fake_pool + + # Patch startup dependencies before importing the app + monkeypatch.setattr("app.db.connection.create_connection_pool", fake_create_pool) + monkeypatch.setattr("app.db.connection.init_db", lambda: None) + + main_module = importlib.import_module("app.main") + + # Fresh in-memory store + store: Dict[int, Dict[str, Any]] = {} + next_id = {"value": 1} + + def make_record(hostname: str, ip: str, dc: str, state: str, id_override=None): + record_id = id_override or next_id["value"] + now = _now_iso() + record = { + "id": record_id, + "hostname": hostname, + "ip_address": ip, + "datacenter": dc, + "state": state, + "created_at": now, + "updated_at": now, + } + store[record_id] = record + next_id["value"] = record_id + 1 + return record_id + + def insert_server(conn, hostname, ip, dc, state): + return make_record(hostname, ip, dc, state) + + def select_all_servers(conn): + return list(store.values()) + + def select_server_by_id(conn, server_id: int): + return store.get(server_id) + + def update_server_by_id(conn, server_id: int, **updates): + record = store.get(server_id) + if not record: + return None + record = {**record, **updates, "updated_at": _now_iso()} + store[server_id] = record + return record + + def delete_server_by_id(conn, server_id: int): + return store.pop(server_id, None) is not None + + def validate_hostname_unique(conn, hostname: str, exclude_id=None): + for rid, record in store.items(): + if record["hostname"] == hostname and rid != exclude_id: + return False + return True + + # Bind stubs into the main module + monkeypatch.setattr(main_module, "connection_pool", fake_pool) + monkeypatch.setattr(main_module, "insert_server", insert_server) + monkeypatch.setattr(main_module, "select_all_servers", select_all_servers) + 
monkeypatch.setattr(main_module, "select_server_by_id", select_server_by_id) + monkeypatch.setattr(main_module, "update_server_by_id", update_server_by_id) + monkeypatch.setattr(main_module, "delete_server_by_id", delete_server_by_id) + monkeypatch.setattr(main_module, "validate_hostname_unique", validate_hostname_unique) + + client = TestClient(main_module.app) + + yield client, store, call_counts, fake_pool, main_module + + client.close() + + +def test_startup_initializes_pool_once(api_client): + _, _, call_counts, _, _ = api_client + assert call_counts["create_pool"] == 1 + + +def test_shutdown_closes_pool(api_client): + client, _, _, pool, _ = api_client + with client: + client.get("/servers") + assert pool._closed is True + + +def test_create_server_success(api_client): + client, _, _, _, _ = api_client + payload = { + "hostname": "srv-1", + "ip_address": "10.0.0.1", + "datacenter": "us-east", + "state": "active", + } + resp = client.post("/servers", json=payload) + assert resp.status_code == 201 + body = resp.json() + assert body["hostname"] == payload["hostname"] + assert body["state"] == "active" + + +def test_create_server_conflict_on_hostname(api_client, monkeypatch): + client, store, _, _, main_module = api_client + + # Seed an existing record + store[1] = { + "id": 1, + "hostname": "srv-1", + "ip_address": "10.0.0.1", + "datacenter": "us-east", + "state": "active", + "created_at": _now_iso(), + "updated_at": _now_iso(), + } + + monkeypatch.setattr(main_module, "validate_hostname_unique", lambda conn, hostname, exclude_id=None: False) + + payload = { + "hostname": "srv-1", + "ip_address": "10.0.0.2", + "datacenter": "us-east", + "state": "active", + } + resp = client.post("/servers", json=payload) + assert resp.status_code == 409 + + +def test_create_server_integrity_error(api_client, monkeypatch): + client, _, _, _, main_module = api_client + + def boom(*args, **kwargs): + raise IntegrityError("duplicate", None, None) + + 
monkeypatch.setattr(main_module, "insert_server", boom) + + payload = { + "hostname": "srv-2", + "ip_address": "10.0.0.3", + "datacenter": "us-east", + "state": "active", + } + resp = client.post("/servers", json=payload) + assert resp.status_code == 409 + + +def test_list_servers_empty(api_client): + client, _, _, _, _ = api_client + resp = client.get("/servers") + assert resp.status_code == 200 + assert resp.json() == [] + + +def test_list_servers_populated(api_client, monkeypatch): + client, _, _, _, main_module = api_client + + def fake_select_all(conn): + now = _now_iso() + return [ + {"id": 1, "hostname": "srv-1", "ip_address": "10.0.0.1", "datacenter": "dc", "state": "active", "created_at": now, "updated_at": now}, + {"id": 2, "hostname": "srv-2", "ip_address": "10.0.0.2", "datacenter": "dc", "state": "offline", "created_at": now, "updated_at": now}, + ] + + monkeypatch.setattr(main_module, "select_all_servers", fake_select_all) + + resp = client.get("/servers") + assert resp.status_code == 200 + body = resp.json() + assert len(body) == 2 + assert body[0]["hostname"] == "srv-1" + + +def test_get_server_found(api_client, monkeypatch): + client, _, _, _, main_module = api_client + now = _now_iso() + + monkeypatch.setattr( + main_module, + "select_server_by_id", + lambda conn, sid: {"id": sid, "hostname": "srv-1", "ip_address": "10.0.0.1", "datacenter": "dc", "state": "active", "created_at": now, "updated_at": now}, + ) + + resp = client.get("/servers/1") + assert resp.status_code == 200 + assert resp.json()["id"] == 1 + + +def test_get_server_not_found(api_client, monkeypatch): + client, _, _, _, main_module = api_client + monkeypatch.setattr(main_module, "select_server_by_id", lambda conn, sid: None) + resp = client.get("/servers/999") + assert resp.status_code == 404 + + +def test_update_server_success(api_client, monkeypatch): + client, _, _, _, main_module = api_client + now = _now_iso() + base = {"id": 1, "hostname": "srv-1", "ip_address": "10.0.0.1", 
"datacenter": "dc", "state": "active", "created_at": now, "updated_at": now} + + monkeypatch.setattr(main_module, "select_server_by_id", lambda conn, sid: base if sid == 1 else None) + monkeypatch.setattr(main_module, "update_server_by_id", lambda conn, sid, **u: {**base, **u, "updated_at": _now_iso()}) + + resp = client.put("/servers/1", json={"state": "offline"}) + assert resp.status_code == 200 + assert resp.json()["state"] == "offline" + + +def test_update_server_not_found(api_client, monkeypatch): + client, _, _, _, main_module = api_client + monkeypatch.setattr(main_module, "select_server_by_id", lambda conn, sid: None) + resp = client.put("/servers/123", json={"state": "offline"}) + assert resp.status_code == 404 + + +def test_update_server_conflict(api_client, monkeypatch): + client, _, _, _, main_module = api_client + now = _now_iso() + base = {"id": 1, "hostname": "srv-1", "ip_address": "10.0.0.1", "datacenter": "dc", "state": "active", "created_at": now, "updated_at": now} + + monkeypatch.setattr(main_module, "select_server_by_id", lambda conn, sid: base) + monkeypatch.setattr(main_module, "validate_hostname_unique", lambda conn, hostname, exclude_id=None: False) + + resp = client.put("/servers/1", json={"hostname": "dup"}) + assert resp.status_code == 409 + + +def test_update_server_integrity_error(api_client, monkeypatch): + client, _, _, _, main_module = api_client + now = _now_iso() + base = {"id": 1, "hostname": "srv-1", "ip_address": "10.0.0.1", "datacenter": "dc", "state": "active", "created_at": now, "updated_at": now} + + monkeypatch.setattr(main_module, "select_server_by_id", lambda conn, sid: base) + + def boom(*args, **kwargs): + raise IntegrityError("duplicate", None, None) + + monkeypatch.setattr(main_module, "update_server_by_id", boom) + + resp = client.put("/servers/1", json={"hostname": "dup"}) + assert resp.status_code == 409 + + +def test_delete_server_success(api_client, monkeypatch): + client, _, _, _, main_module = api_client + 
monkeypatch.setattr(main_module, "delete_server_by_id", lambda conn, sid: True) + resp = client.delete("/servers/1") + assert resp.status_code == 200 + assert resp.json() == {"detail": "Server deleted"} + + +def test_delete_server_not_found(api_client, monkeypatch): + client, _, _, _, main_module = api_client + monkeypatch.setattr(main_module, "delete_server_by_id", lambda conn, sid: False) + resp = client.delete("/servers/999") + assert resp.status_code == 404 + + +def test_get_db_connection_fails_when_pool_none(monkeypatch): + import app.main as main_module + importlib.reload(main_module) + main_module.connection_pool = None + client = TestClient(main_module.app) + resp = client.get("/servers") + assert resp.status_code == 500 + + +def test_request_validation_rejects_bad_ip(api_client): + client, _, _, _, _ = api_client + payload = { + "hostname": "srv-1", + "ip_address": "bad-ip", + "datacenter": "us-east", + "state": "active", + } + resp = client.post("/servers", json=payload) + assert resp.status_code == 422 + diff --git a/tests/test_schemas.py b/tests/test_schemas.py new file mode 100644 index 0000000..a0c1903 --- /dev/null +++ b/tests/test_schemas.py @@ -0,0 +1,83 @@ +import pytest +from pydantic import ValidationError + +from app.models.schemas import ( + ServerBase, + ServerCreate, + ServerResponse, + ServerUpdate, +) + + +def test_server_base_accepts_valid_data(): + data = { + "hostname": "web-1", + "ip_address": "10.0.0.1", + "datacenter": "us-east", + "state": "active", + } + model = ServerBase(**data) + assert model.hostname == data["hostname"] + assert model.ip_address == data["ip_address"] + assert model.state == data["state"] + + +@pytest.mark.parametrize("state", ["bad", "", "retiring"]) +def test_server_base_rejects_invalid_state(state): + with pytest.raises(ValidationError): + ServerBase( + hostname="web-1", + ip_address="10.0.0.1", + datacenter="us-east", + state=state, + ) + + +@pytest.mark.parametrize("ip", ["999.0.0.1", "not-an-ip", "1234", 
""]) +def test_server_base_rejects_invalid_ip(ip): + with pytest.raises(ValidationError): + ServerBase( + hostname="web-1", + ip_address=ip, + datacenter="us-east", + state="active", + ) + + +def test_server_create_inherits_validation(): + with pytest.raises(ValidationError): + ServerCreate( + hostname="web-1", + ip_address="invalid", + datacenter="us-east", + state="active", + ) + + +def test_server_update_allows_partial_and_validates_present_fields(): + # Partial update: no validation triggered for missing fields + partial = ServerUpdate(state=None) + assert partial.state is None + + # Provided fields still validate + with pytest.raises(ValidationError): + ServerUpdate(ip_address="not-an-ip") + + with pytest.raises(ValidationError): + ServerUpdate(state="invalid-state") + + +def test_server_response_accepts_full_payload(): + data = { + "hostname": "web-1", + "ip_address": "10.0.0.1", + "datacenter": "us-east", + "state": "offline", + "id": 123, + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-02T00:00:00Z", + } + model = ServerResponse(**data) + assert model.id == 123 + assert model.state == "offline" + diff --git a/tests/test_validators.py b/tests/test_validators.py new file mode 100644 index 0000000..a62acec --- /dev/null +++ b/tests/test_validators.py @@ -0,0 +1,64 @@ +import types +from typing import Any, Dict + +import pytest + +from app.validators import ( + ALLOWED_STATES, + validate_hostname_unique, + validate_ip_address, + validate_state, +) + + +class DummyConnection: + """Lightweight stand-in for a psycopg2 connection object.""" + + pass + + +def test_validate_ip_address_accepts_ipv4_and_ipv6(): + assert validate_ip_address("192.168.0.1") is True + assert validate_ip_address("2001:0db8:85a3:0000:0000:8a2e:0370:7334") is True + + +@pytest.mark.parametrize("ip", ["999.0.0.1", "not-an-ip", "", "1234"]) +def test_validate_ip_address_rejects_invalid(ip: str): + assert validate_ip_address(ip) is False + + +def 
test_validate_state_allows_expected_values(): + for state in ALLOWED_STATES: + assert validate_state(state) is True + + +def test_validate_state_rejects_unknown_value(): + assert validate_state("decommissioned") is False + + +def test_validate_hostname_unique_returns_true_when_none(monkeypatch): + def fake_select(conn: Any, hostname: str) -> None: + return None + + monkeypatch.setattr("app.validators.select_server_by_hostname", fake_select) + + assert validate_hostname_unique(DummyConnection(), "new-host") is True + + +def test_validate_hostname_unique_returns_false_for_duplicate(monkeypatch): + def fake_select(conn: Any, hostname: str) -> Dict[str, Any]: + return {"id": 1, "hostname": hostname} + + monkeypatch.setattr("app.validators.select_server_by_hostname", fake_select) + + assert validate_hostname_unique(DummyConnection(), "existing-host") is False + + +def test_validate_hostname_unique_allows_same_id_on_update(monkeypatch): + def fake_select(conn: Any, hostname: str) -> Dict[str, Any]: + return {"id": 5, "hostname": hostname} + + monkeypatch.setattr("app.validators.select_server_by_hostname", fake_select) + + assert validate_hostname_unique(DummyConnection(), "existing-host", exclude_id=5) is True +