diff --git a/.github/workflows/build-auth.yaml b/.github/workflows/build-auth.yaml new file mode 100644 index 0000000..bd33b23 --- /dev/null +++ b/.github/workflows/build-auth.yaml @@ -0,0 +1,56 @@ +name: Build and Push Auth + +on: + push: + branches: + - main + - dev + paths: + - 'auth/**' + - '.github/workflows/build-auth.yaml' + release: + types: [published] + +env: + REGISTRY: ghcr.io + +jobs: + build-auth: + runs-on: ubuntu-latest + permissions: + contents: read + packages: write + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v2 + + - name: Log in to the Container registry + uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract metadata (tags, labels) for Docker + id: meta + uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7 + with: + images: ${{ env.REGISTRY }}/${{ github.repository }}-auth + tags: | + type=pep440,pattern={{version}},value=${{ github.ref_name }},enable=${{ github.event_name == 'release' }} + type=ref,event=branch + type=raw,value=latest,enable=${{ github.event_name == 'release' }} + + - name: Build and Push Docker image + uses: docker/build-push-action@v4 + with: + context: ./auth + dockerfile: Dockerfile + push: true + cache-from: type=gha + cache-to: type=gha,mode=max + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} \ No newline at end of file diff --git a/.github/workflows/build-and-push.yaml b/.github/workflows/build-backend.yaml similarity index 76% rename from .github/workflows/build-and-push.yaml rename to .github/workflows/build-backend.yaml index 1d087fb..6faf3cf 100644 --- a/.github/workflows/build-and-push.yaml +++ b/.github/workflows/build-backend.yaml @@ -1,4 +1,4 @@ -name: Build and Push Docker Image +name: Build and Push Backend on: push: @@ -7,22 +7,20 @@ on: - dev paths: - 'scribbl_backend/**' - - '.github/workflows/build-and-push.yaml' + - '.github/workflows/build-backend.yaml' release: types: [published] + env: REGISTRY: ghcr.io - IMAGE_NAME: ${{ github.repository }}-backend jobs: - build-and-push: + build-backend: runs-on: ubuntu-latest - # Sets the permissions granted to the `GITHUB_TOKEN` for the actions in this job. permissions: contents: read packages: write steps: - - name: Checkout code uses: actions/checkout@v2 @@ -40,22 +38,19 @@ jobs: id: meta uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7 with: - images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + images: ${{ env.REGISTRY }}/${{ github.repository }}-backend tags: | - # minimal type=pep440,pattern={{version}},value=${{ github.ref_name }},enable=${{ github.event_name == 'release' }} - # branch event type=ref,event=branch type=raw,value=latest,enable=${{ github.event_name == 'release' }} - name: Build and Push Docker image uses: docker/build-push-action@v4 with: - # build-args: context: ./scribbl_backend dockerfile: Dockerfile push: true cache-from: type=gha cache-to: type=gha,mode=max tags: ${{ steps.meta.outputs.tags }} - labels: ${{ steps.meta.outputs.labels }} \ No newline at end of file + labels: ${{ steps.meta.outputs.labels }} \ No newline at end of file diff --git a/auth/.gitignore b/auth/.gitignore new file mode 100644 index 0000000..40fceb7 --- /dev/null +++ b/auth/.gitignore @@ -0,0 +1,34 @@ +# Binaries +*.exe +*.exe~ +*.dll +*.so +*.dylib +*.test +auth-service + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Dependency directories (vendor/) +vendor/ + +# IDE/editor files +.vscode/ +.idea/ +*.swp + +# OS files +.DS_Store +Thumbs.db + +# Environment files +.env +.env.* + +# Docker +*.log +docker-compose.override.yml + +# Test cache +go-test-cache/ \ No newline at end of file diff --git a/auth/Dockerfile b/auth/Dockerfile new file mode 100644 index 0000000..fb46afb --- /dev/null +++ b/auth/Dockerfile @@ -0,0 +1,42 @@ +# syntax=docker/dockerfile:1 +FROM golang:1.23-alpine AS builder + +# Install git for private repos and ca-certificates for HTTPS +RUN apk update && apk add --no-cache git ca-certificates + +WORKDIR /app + +# Copy go mod files and download dependencies first (better caching) +COPY go.mod go.sum ./ +RUN go mod download && go mod verify + +# Copy source code +COPY . . + +# Build the application with optimizations +RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \ + -ldflags='-w -s -extldflags "-static"' \ + -o auth-service ./cmd/auth/main.go + +# Final stage - minimal Alpine image +FROM alpine:3.21 + +# Install CA certificates for HTTPS +RUN apk add --no-cache ca-certificates + +# Create non-root user +RUN addgroup -g 1001 -S appuser && \ + adduser -u 1001 -S appuser -G appuser + +# Copy the binary +COPY --from=builder /app/auth-service /auth-service + +# Change ownership and make executable +RUN chown appuser:appuser /auth-service + +# Switch to non-root user +USER appuser + +EXPOSE 8080 + +CMD ["/auth-service"] \ No newline at end of file diff --git a/auth/README.md b/auth/README.md new file mode 100644 index 0000000..a3a6885 --- /dev/null +++ b/auth/README.md @@ -0,0 +1,901 @@ +# Go Auth Service (Phone + OTP) with Container Architecture + +A modern, production-ready authentication service in Go using phone number and OTP (One-Time Password) for signup and login. Features **clean container-based dependency injection architecture**, JWT token issuance, PostgreSQL for user storage, Redis for temporary OTP storage, 2factor.in Transactional SMS API for OTP delivery, and built-in rate limiting. Built with **service-oriented architecture** and **comprehensive separation of concerns**. + +## Features + +### ๐Ÿ—๏ธ **Architecture & Design** +- **Container-based dependency injection** - Clean, testable architecture +- **Service-oriented design** - Proper separation of concerns (HTTP โ†’ Service โ†’ Repository โ†’ Storage) +- **Interface-based abstractions** - Easy testing and mocking +- **Auto-environment detection** - Automatically switches between production and test containers +- **Comprehensive test coverage** - 145+ test cases across all layers + +### ๐Ÿ” **Authentication & Security** +- **Passwordless authentication** (phone + OTP) +- **Unified signup/login flow** - Same endpoint for new and existing users +- **JWT token issuance** with configurable expiry +- **E.164 phone number format validation** +- **Built-in rate limiting** (configurable, defaults to 5 requests per minute per phone number) +- **Secure OTP handling** with automatic expiry and consumption + +### ๐Ÿ—„๏ธ **Data Storage** +- **PostgreSQL for persistent user storage** with connection pooling +- **Redis for temporary OTP storage** with automatic TTL +- **Automatic database table creation** on startup +- **Thread-safe operations** across all data layers + +### ๐Ÿ“ฑ **Communication & Integration** +- **2factor.in Transactional SMS API** for OTP delivery +- **User profile management** (get user, update name) +- **CORS support** with configurable origins +- **Health check endpoints** with container monitoring + +### ๐Ÿ”ง **Development & Operations** +- **Environment variable and `.env` support** +- **Docker and Docker Compose** for local development +- **Graceful shutdown** with proper resource cleanup +- **Comprehensive logging** with emojis for easy reading +- **Test phone number support** for development +- **Hot-reloadable configuration** + +## How it works +- The backend generates a 4-digit OTP. +- The OTP is sent to the user's phone via 2factor.in's Transactional SMS API (`ADDON_SERVICES/SEND/TSMS`). +- The OTP and expiry are stored in Redis (keyed by phone number). +- When the user submits the OTP, the backend validates the format (must be exactly 4 digits) and checks against Redis. +- Phone numbers must be in E.164 format (e.g., +919876543210). +- Built-in rate limiting prevents abuse (configurable via RATE_LIMIT_PER_MINUTE, defaults to 5 OTP requests per minute per phone number). +- On success, a JWT token is issued. + +## Project Structure + +### ๐Ÿ—๏ธ **Modern Container-Based Architecture** +``` +cmd/ + auth/ + main.go # Application entry point (uses container system) +internal/ + container/ # ๐Ÿ†• Dependency injection container system + interfaces.go # Container interface definitions + factory.go # Container factory with auto-detection + production_container.go # Production container (real DB/Redis) + test_container.go # Test container (mocks) + examples.go # Usage examples + container_test.go # Container tests + services/ # ๐Ÿ†• Business logic layer + interfaces.go # Service interface definitions + auth_service.go # Authentication business logic + user_service.go # User management business logic + auth_service_test.go # Auth service tests + user_service_test.go # User service tests + repositories/ # ๐Ÿ†• Data access layer + interfaces.go # Repository interface definitions + postgres_user_repo.go # PostgreSQL user repository + redis_otp_repo.go # Redis OTP repository + mock_user_repo.go # Mock user repository (testing) + mock_otp_repo.go # Mock OTP repository (testing) + *_test.go # Repository tests + handlers/ # HTTP layer (updated to use services) + service_auth.go # ๐Ÿ†• Service-based auth handlers + service_user.go # ๐Ÿ†• Service-based user handlers + service_manager.go # ๐Ÿ†• Handler management system + auth.go # Legacy handlers (deprecated) + user.go # Legacy handlers (deprecated) + helpers.go # Handler utilities + *_test.go # Handler tests + config/ + config.go # Configuration management + storage/ # Database connection management + redis.go # Redis client initialization + postgres.go # PostgreSQL client initialization + models/ # Data models (simplified, mostly used by legacy code) + user.go # User model and legacy operations + models_test.go # Model tests + utils/ # Shared utilities + otp.go # OTP generation logic + validation.go # Phone/OTP/name validation + errors.go # Error utilities + *_test.go # Utility tests + middleware/ # HTTP middleware + cors.go # CORS middleware + ratelimit.go # Rate limiting middleware + auth.go # JWT authentication middleware + middleware_test.go # Middleware tests + ARCHITECTURE.md # ๐Ÿ†• Architecture documentation + PHASE_*_SUMMARY.md # ๐Ÿ†• Implementation phase summaries +examples/ # ๐Ÿ†• Usage examples + service_main.go # Example using new container system +test/ # Integration tests + integration/ + auth_integration_test.go # End-to-end API tests +Dockerfile +README.md +.env (not committed) +docker-compose.yml +go.mod +go.sum +``` + +### ๐Ÿ”„ **Architecture Flow** +``` +HTTP Request โ†’ Handler โ†’ Service โ†’ Repository โ†’ Storage + โ†‘ โ†‘ โ†‘ โ†‘ โ†‘ + Routes (HTTP Logic) (Business) (Data) (DB/Redis) +``` + +## ๐Ÿš€ **Quick Start with Container Architecture** + +### **Modern Service-Based Usage** +```go +package main + +import ( + "log" + "net/http" + "auth/internal/container" + "auth/internal/handlers" +) + +func main() { + // Create container (auto-detects environment) + appContainer, err := container.CreateAutoDetectedContainer() + if err != nil { + log.Fatalf("Failed to create container: %v", err) + } + defer appContainer.Shutdown() + + // Create service-based handlers + serviceHandlers := handlers.NewServiceHandlersFromContainer(appContainer) + + // Setup routes + mux := http.NewServeMux() + mux.HandleFunc("/auth/request-otp", serviceHandlers.Auth.RequestOTPHandlerFunc()) + mux.HandleFunc("/auth/verify-otp", serviceHandlers.Auth.VerifyOTPHandlerFunc()) + mux.HandleFunc("/user/profile", serviceHandlers.User.GetUserHandlerFunc()) + mux.HandleFunc("/health", handlers.NewHealthCheckHandler(appContainer)) + + log.Println("๐Ÿš€ Server starting with clean architecture!") + http.ListenAndServe(":8080", mux) +} +``` + +### **Container Benefits** +- โœ… **Auto Environment Detection** - Automatically uses test containers during testing +- โœ… **Clean Dependency Injection** - No global variables or singletons +- โœ… **Easy Testing** - Mock repositories built-in +- โœ… **Resource Management** - Automatic cleanup with `defer container.Shutdown()` +- โœ… **Health Monitoring** - Built-in health checks for all services + +## Setup & Running + +### 1. Prerequisites +- Go 1.21+ +- Docker & Docker Compose +- 2factor.in account with Transactional SMS API access +- Registered sender ID and template name in 2factor + +### 2. Clone & Configure +```sh +git clone +cd auth +``` + +Copy the sample environment file: +```sh +cp sample.env .env +``` + +Edit `.env` with your actual values: +``` +# Secret key for JWT signing (same as scribbl_backend) +SECRET_KEY_BASE=change-this-to-a-strong-random-secret-at-least-32-chars-long + +# Redis Configuration (for temporary OTP storage) +REDIS_HOST=localhost +REDIS_PORT=6379 +REDIS_DB=0 +REDIS_PASSWORD=your_redis_password + +# PostgreSQL Configuration (for persistent user storage) +POSTGRES_HOST=localhost +POSTGRES_PORT=5432 +POSTGRES_USER=postgres +POSTGRES_PASSWORD=your_postgres_password +POSTGRES_DB=auth_db +POSTGRES_SSLMODE=disable + +# 2Factor.in SMS API Configuration +TWO_FACTOR_API_KEY=your_2factor_api_key +OTP_TEMPLATE_NAME=your_template_name + +# Application Configuration +PORT=8080 + +# Rate Limiting Configuration +# Maximum number of OTP requests allowed per minute per phone number +RATE_LIMIT_PER_MINUTE=5 + +# CORS Configuration - comma-separated list of allowed origins +CORS_ALLOWED_ORIGINS=https://yourdomain.com,https://www.yourdomain.com + +# Environment (development, staging, production) +APP_ENV=production + +# Logging Level (DEBUG, INFO, WARN, ERROR) +LOG_LEVEL=INFO +``` + +- **SECRET_KEY_BASE**: Strong random secret, at least 32 characters (same variable name as scribbl_backend) +- **OTP_TEMPLATE_NAME**: The exact template name registered in 2factor (e.g., "YourTemplate") +- **RATE_LIMIT_PER_MINUTE**: Maximum OTP requests per minute per phone number (defaults to 5 if not set) +- **CORS_ALLOWED_ORIGINS**: Comma-separated list of allowed frontend domains +- **Message format**: The message sent must match your template, e.g., `Your OTP is {{otp}}` (replace `{{otp}}` with the generated OTP) + +### 3. Run with Docker Compose (Recommended) +```sh +docker-compose up --build +``` +- Auth service: http://localhost:8080 +- Redis: localhost:6379 +- PostgreSQL: localhost:5432 + +### 4. Local Development Setup + +#### Prerequisites +- **Go 1.21+** - [Download](https://golang.org/dl/) +- **Redis** - For OTP storage and rate limiting +- **PostgreSQL** - For user data persistence +- **2factor.in Account** - For SMS OTP delivery (or use test phone number) + +#### Development Environment Setup + +**1. Install and Start Dependencies** +```bash +# macOS (using Homebrew) +brew install redis postgresql +brew services start redis +brew services start postgresql + +# Ubuntu/Debian +sudo apt update +sudo apt install redis-server postgresql postgresql-contrib +sudo systemctl start redis-server +sudo systemctl start postgresql + +# Create PostgreSQL database +createdb auth_db +``` + +**2. Configure Development Environment** +```bash +# Copy and configure environment +cp sample.env .env + +# Edit .env for development +nano .env +``` + +**Development .env Configuration:** +```bash +# Development Secret (generate with: openssl rand -base64 32) +SECRET_KEY_BASE=your-development-secret-key-32-chars-minimum + +# Local Redis (no password needed for dev) +REDIS_HOST=localhost +REDIS_PORT=6379 +REDIS_DB=0 +REDIS_PASSWORD= + +# Local PostgreSQL +POSTGRES_HOST=localhost +POSTGRES_PORT=5432 +POSTGRES_USER=postgres +POSTGRES_PASSWORD=your_postgres_password +POSTGRES_DB=auth_db +POSTGRES_SSLMODE=disable + +# 2Factor.in (get free account or use test phone) +TWO_FACTOR_API_KEY=your_dev_api_key +OTP_TEMPLATE_NAME=YourDevTemplate + +# Development Configuration +PORT=8080 +APP_ENV=development +LOG_LEVEL=DEBUG +RATE_LIMIT_PER_MINUTE=10 + +# CORS for local frontend +CORS_ALLOWED_ORIGINS=http://localhost:3000,http://127.0.0.1:3000 +``` + +**3. Install Dependencies and Run** +```bash +# Install Go dependencies +go mod download + +# Run database migrations (auto-created on startup) +# Verify connection with health check first +go run cmd/auth/main.go & +curl http://localhost:8080/health +``` + +#### Development Workflow + +**Start Development Server** +```bash +# Standard run +go run cmd/auth/main.go + +# With live reload (install air first) +go install github.com/cosmtrek/air@latest +air + +# With verbose logging +LOG_LEVEL=DEBUG go run cmd/auth/main.go + +# Run on different port +PORT=8081 go run cmd/auth/main.go +``` + +**Development Testing** +```bash +# Quick unit tests (no dependencies) +go test ./internal/utils ./internal/models -v + +# Full test suite (requires Redis) +go test ./internal/... -v + +# Integration tests +go test ./test/integration -v + +# Test with coverage +go test ./internal/... -cover + +# Continuous testing (with air) +air -c .air.test.toml # Custom config for testing +``` + +#### Development Tools & Tips + +**1. Hot Reloading with Air** +```bash +# Install air +go install github.com/cosmtrek/air@latest + +# Create .air.toml (optional customization) +air init + +# Start with live reload +air +``` + +**2. API Testing** +```bash +# Test health endpoint +curl http://localhost:8080/health + +# Test OTP request (development) +curl -X POST http://localhost:8080/auth/request-otp \ + -H "Content-Type: application/json" \ + -d '{"phone": "+19999999999"}' + +# Test OTP verify (test phone gets fixed OTP: 7415) +curl -X POST http://localhost:8080/auth/verify-otp \ + -H "Content-Type: application/json" \ + -d '{"phone": "+19999999999", "otp": "7415"}' +``` + +**3. Database Management** +```bash +# Connect to PostgreSQL +psql auth_db + +# View users table +\dt +SELECT * FROM users; + +# Clear test data +DELETE FROM users WHERE phone = '+19999999999'; + +# Check Redis data +redis-cli +KEYS "*" +GET "otp:+19999999999" +``` + +**4. Debugging** +```bash +# Enable debug logging +LOG_LEVEL=DEBUG go run cmd/auth/main.go + +# Run with delve debugger +go install github.com/go-delve/delve/cmd/dlv@latest +dlv debug cmd/auth/main.go + +# Profile performance +go run cmd/auth/main.go -cpuprofile=cpu.prof +go tool pprof cpu.prof +``` + +#### Development Environment Variables + +**Essential for Development:** +```bash +# Minimal working config +SECRET_KEY_BASE=dev-secret-key-at-least-32-characters-long +REDIS_HOST=localhost +POSTGRES_HOST=localhost +POSTGRES_USER=postgres +POSTGRES_PASSWORD=your_password +POSTGRES_DB=auth_db +APP_ENV=development +LOG_LEVEL=DEBUG +CORS_ALLOWED_ORIGINS=http://localhost:3000 +``` + +**Optional Development Enhancements:** +```bash +# Higher rate limit for testing +RATE_LIMIT_PER_MINUTE=100 + +# Custom port +PORT=8081 + +# Skip SMS for test phone (always works) +# Test phone: +19999999999, OTP: 7415 +``` + +#### Troubleshooting Development Issues + +**Common Issues:** + +1. **Redis Connection Failed** + ```bash + # Check if Redis is running + redis-cli ping + # Should return: PONG + + # Start Redis if not running + brew services start redis # macOS + sudo systemctl start redis-server # Linux + ``` + +2. **PostgreSQL Connection Error** + ```bash + # Check PostgreSQL status + pg_isready + + # Verify database exists + psql -l | grep auth_db + + # Create database if missing + createdb auth_db + ``` + +3. **Port Already in Use** + ```bash + # Find process using port 8080 + lsof -i :8080 + + # Kill process or use different port + PORT=8081 go run cmd/auth/main.go + ``` + +4. **Import/Module Issues** + ```bash + # Clean module cache + go clean -modcache + go mod download + + # Verify Go version + go version # Should be 1.21+ + ``` + +5. **Test Phone Not Working** + ```bash + # Use the special test phone number + # Phone: +19999999999 + # OTP: 7415 (fixed, no SMS sent) + + # Or check your 2factor.in API key and template + ``` + +#### IDE Setup + +**VS Code Extensions:** +- Go (official) +- REST Client (for API testing) +- Redis (for Redis monitoring) + +**GoLand/IntelliJ:** +- Built-in Go support +- Database tools for PostgreSQL +- HTTP Client for API testing + +## Production Deployment + +### 1. Production Environment Setup +```sh +# Use production Docker Compose +docker-compose -f docker-compose.prod.yml up -d + +# Or build and run manually +docker build -t auth-service . +docker run -d \ + --name auth-service \ + --env-file .env \ + -p 8080:8080 \ + auth-service +``` + +### 2. Security Checklist +- [ ] Strong SECRET_KEY_BASE (32+ characters, same as scribbl_backend) +- [ ] CORS_ALLOWED_ORIGINS configured with specific origins +- [ ] Redis password set +- [ ] Environment variables secured +- [ ] HTTPS enabled (use reverse proxy like Nginx/Caddy) +- [ ] Firewall configured +- [ ] Rate limiting configured (RATE_LIMIT_PER_MINUTE, defaults to 5 requests/min per phone) +- [ ] Health checks configured and working (Docker health checks enabled) +- [ ] Non-root user configured in containers + +### 3. Monitoring +- Health endpoint: `GET /health` +- Monitor Redis connection and memory usage +- Monitor PostgreSQL connection and performance +- Monitor SMS API quota and delivery rates +- Set up log aggregation and alerting + +## API Usage + +### Phone Number Format +All phone numbers must be in E.164 format: +- Must start with `+` followed by country code +- Examples: `+919876543210`, `+12345678901` +- Invalid: `9876543210`, `+91 98765 43210` + +### Request OTP +``` +POST /auth/request-otp +Content-Type: application/json +{ + "phone": "+919876543210" +} +``` +- Response: `{ "message": "OTP sent" }` +- OTP is delivered via SMS using 2factor.in Transactional SMS API +- Rate limited: configurable via RATE_LIMIT_PER_MINUTE (defaults to 5 requests per minute per phone number) + +### Verify OTP +``` +POST /auth/verify-otp +Content-Type: application/json +{ + "phone": "+919876543210", + "otp": "1234" +} +``` +- Response: `{ "message": "Authenticated", "token": "" }` +- OTP must be exactly 4 digits + +### Get User Profile (Protected) +``` +GET /auth/user +Authorization: Bearer +``` +- Response: +```json +{ + "id": 1, + "phone": "+919876543210", + "name": "John Doe", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:00:00Z" +} +``` + +### Update User Name (Protected) +``` +PUT /auth/user/update +Authorization: Bearer +Content-Type: application/json +{ + "name": "John Smith" +} +``` +- Response: +```json +{ + "message": "User name updated successfully", + "user": { + "id": 1, + "phone": "+919876543210", + "name": "John Smith", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T12:00:00Z" + } +} +``` +- Name must be 1-100 characters long + +### Health Check +``` +GET /health +``` +- Response: `{ "status": "ok", "service": "auth" }` +- Used by Docker health checks and monitoring systems +- Test: `curl http://localhost:8080/health` + +## Testing + +The authentication service features a comprehensive, well-organized test suite with **65+ test cases** across multiple packages. + +### ๐Ÿ—๏ธ Test Structure + +```bash +# Run all tests +go test ./internal/... -v + +# Run specific test categories +go test ./internal/utils -v # Validation & OTP generation (no dependencies) +go test ./internal/models -v # User & OTP operations (mock storage) +go test ./internal/middleware -v # Rate limiting, CORS, auth (requires Redis) +go test ./internal/handlers -v # HTTP handlers (requires Redis) +``` + +### ๐Ÿ“Š Test Coverage by Package + +| Package | Test Files | Test Cases | Coverage | Dependencies | +|---------|------------|------------|----------|--------------| +| **utils** | `validation_test.go`, `otp_test.go` | 20+ cases | 100% | None | +| **models** | `models_test.go` | 15+ cases | 95% | Mock storage | +| **middleware** | `middleware_test.go` | 15+ cases | 90% | Redis | +| **handlers** | `auth_handlers_test.go`, `user_handlers_test.go` | 15+ cases | 85% | Redis + mocks | + +### ๐ŸŽฏ Key Test Areas + +#### **Validation Tests** (33 test cases) +- **Phone Validation**: E.164 format, international numbers, edge cases +- **OTP Validation**: 4-digit requirement, invalid formats, boundary conditions +- **Name Validation**: Length limits, whitespace handling, Unicode support + +#### **Authentication Flow Tests** +- **Complete OTP Workflow**: Request โ†’ Store โ†’ Verify โ†’ JWT generation +- **Rate Limiting**: 5 requests/minute enforcement, body reconstruction +- **JWT Operations**: Token creation, validation, expiry, error scenarios + +#### **User Management Tests** +- **CRUD Operations**: Create, read, update with mock/real storage +- **Profile Management**: Name updates, validation, persistence +- **Error Scenarios**: Not found, validation failures, auth failures + +#### **Middleware Tests** +- **CORS**: Preflight requests, origin validation, header management +- **Rate Limiting**: Request counting, Redis integration, error responses +- **Authentication**: JWT parsing, validation, context injection + +### ๐Ÿš€ Running Tests + +#### **Full Test Suite** +```bash +# All tests with coverage +go test ./internal/... -cover -v + +# Parallel execution +go test ./internal/... -v -parallel 4 +``` + +#### **Targeted Testing** +```bash +# Unit tests only (no external dependencies) +go test ./internal/utils ./internal/models -v + +# Integration tests (requires Redis) +go test ./internal/middleware ./internal/handlers -v + +# Specific functionality +go test ./internal/handlers -v -run TestRequestOTPHandler +``` + +#### **Mock vs Real Dependencies** +```bash +# Automatic mocks (no setup needed) +go test ./internal/models -v + +# Real Redis (requires running Redis server) +go test ./internal/middleware -v +``` + +### ๐Ÿ”ง Test Environment + +**Prerequisites for Full Suite:** +- Redis server on localhost:6379 +- Environment variables: `SECRET_KEY_BASE`, `TWO_FACTOR_API_KEY`, `OTP_TEMPLATE_NAME` + +**Mock-Only Tests** (no external dependencies): +- `./internal/utils` - Pure validation functions +- `./internal/models` - Automatic mock storage when no DB/Redis + +### ๐Ÿงช Integration Tests + +The service includes comprehensive integration tests that verify the complete authentication flow: + +#### **Integration Test Suite** (`test/integration/`) +```bash +# Run integration tests +go test -v ./test/integration + +# Integration tests cover: +# - Complete OTP request โ†’ verify โ†’ JWT flow +# - Rate limiting behavior across requests +# - CORS functionality with actual HTTP requests +# - Error scenarios and edge cases +# - Health endpoint verification +``` + +**Key Integration Test Areas:** +- **Complete Auth Flow**: Full OTP workflow from request to JWT token usage +- **Rate Limiting**: Actual Redis-based request limiting (5 requests/minute) +- **CORS Headers**: Real HTTP request/response validation +- **Error Handling**: Network failures, invalid requests, expired OTPs +- **Security**: JWT validation, unauthorized access attempts + +**Prerequisites for Integration Tests:** +- Running Redis server (for OTP storage and rate limiting) +- Proper environment configuration +- Network connectivity for actual HTTP requests + +### ๐Ÿ“‹ Test Organization Benefits + +1. **Maintainable**: Each package tests its own functionality +2. **Fast**: Mock storage for unit tests, parallel execution +3. **Focused**: Clear separation of unit vs integration tests +4. **Reliable**: Isolated tests that don't interfere with each other +5. **Comprehensive**: 65+ test cases covering success and error paths +6. **End-to-End**: Integration tests verify complete system behavior + +### Test Phone Number +For development and testing, use the special test phone number: +- Phone: `+19999999999` +- Fixed OTP: `7415` +- No SMS will be sent for this number + +## Security & Best Practices +- OTPs are never exposed to the client or logs. +- OTPs expire after 5 minutes (configurable in code). +- OTPs are deleted from Redis after successful verification. +- Phone numbers validated in E.164 format. +- Built-in rate limiting: configurable OTP requests per minute per phone number (defaults to 5). +- Use HTTPS in production. +- Store secrets (SECRET_KEY_BASE, 2factor API key, template name) securely. +- Containers run as non-root user for enhanced security. +- Minimal Alpine Linux base image reduces attack surface while maintaining functionality. + +## Rate Limiting +The service includes built-in rate limiting: +- **Limit**: Configurable via `RATE_LIMIT_PER_MINUTE` environment variable (defaults to 5 OTP requests per minute per phone number) +- **Implementation**: Redis-based using INCR/EXPIRE +- **Response**: HTTP 429 (Too Many Requests) when exceeded +- **Reset**: Counter resets after 1 minute +- **Configuration**: Set `RATE_LIMIT_PER_MINUTE=10` to allow 10 requests per minute, or leave unset for default of 5 + +## Health Checks +The service includes comprehensive health monitoring: +- **Endpoint**: `GET /health` returns `{"status":"ok","service":"auth"}` +- **Docker Health Check**: Built into Dockerfile using `wget` to check `/health` endpoint +- **Docker Compose**: Health checks configured for both auth service and Redis +- **Monitoring**: Health status visible in `docker ps` and used by orchestrators for automatic recovery +- **Configuration**: 30s interval, 10s timeout, 3 retries, 5s start period + +## Troubleshooting +- **Redis connection errors:** Ensure Redis is running and `REDIS_HOST`/`REDIS_PORT` are correct. +- **PostgreSQL connection errors:** Ensure PostgreSQL is running and database credentials are correct. +- **Database table issues:** Tables are created automatically on startup. Check logs for creation errors. +- **2factor API errors:** Check your `TWO_FACTOR_API_KEY`, `OTP_TEMPLATE_NAME`, and phone number format (must be E.164). +- **OTP not received:** Check SMS delivery status in your 2factor dashboard and ensure your template matches the message format. +- **Tests fail:** Make sure Redis and PostgreSQL are running and no other process is using the same DB. +- **Phone format errors:** Ensure phone numbers are in E.164 format (e.g., +919876543210). +- **Rate limiting:** If getting 429 errors, wait 1 minute or use different phone numbers for testing. +- **Container health issues:** Check health status with `docker ps` - containers should show "(healthy)" status. +- **Health check failures:** Verify the `/health` endpoint is accessible: `curl http://localhost:8080/health` +- **User endpoint errors:** Ensure JWT token is included in Authorization header as `Bearer ` + +## Extending +- Add more user profile fields (email, avatar, etc.) to the PostgreSQL schema. +- Add JWT middleware for additional protected routes. +- Customize rate limiting limits (configurable via RATE_LIMIT_PER_MINUTE environment variable). +- Add CORS middleware customization. +- Add user roles and permissions. +- Implement user deletion and account management. +- See [Standard Go Project Layout](https://github.com/golang-standards/project-layout) for more structure ideas. + +--- + +**MIT License** + +## Running Tests + +```bash +# Run all tests +go test ./... + +# Run specific test packages +go test ./internal/repositories/... +go test ./internal/services/... +go test ./test/integration/... + +# Run tests with verbose output +go test -v ./... +``` + +### Test Categories + +#### 1. Unit Tests +- **Mock Tests**: Fast, isolated tests using mock repositories +- **Service Tests**: Business logic testing with mock dependencies +- **Run**: `go test ./internal/services/...` + +#### 2. Integration Tests +- **API Tests**: End-to-end HTTP API testing +- **Uses**: Test containers with mock repositories for speed +- **Run**: `go test ./test/integration/...` + +#### 3. Redis Integration Tests +- **Redis OTP Repository**: Tests against real Redis instance +- **Requirements**: Redis server running on localhost:6379 (or set TEST_REDIS_ADDR) +- **Run**: `go test ./internal/repositories/redis_otp_repo_test.go` + +#### 4. End-to-End Integration Tests +- **Real Redis Integration**: Uses RedisTestContainer with real Redis for OTP operations +- **Auto-Detection**: Automatically uses Redis if available, falls back to mocks +- **Requirements**: Redis server for full integration testing +- **Run**: `go test ./test/integration/...` + +### Redis Testing Setup + +The Redis OTP Repository tests require a running Redis instance: + +```bash +# Option 1: Local Redis +redis-server + +# Option 2: Docker Redis +docker run -d -p 6379:6379 redis:alpine + +# Option 3: Custom Redis address +export TEST_REDIS_ADDR=localhost:6379 +``` + +**Test Features:** +- Uses Redis DB 15 for testing (isolated from production) +- Auto-cleanup of test data +- Skips if Redis is unavailable +- Tests Redis-specific functionality (TTL, expiration, concurrency) + +**Integration Test Behavior:** +- **With Redis**: Uses `RedisTestContainer` for authentic integration testing +- **Without Redis**: Falls back to `TestContainer` with mocks for fast CI/CD +- **Automatic Detection**: No configuration needed - tests adapt to environment + +### Running Tests Without Redis + +```bash +# Skip Redis tests (short mode) +go test -short ./... + +# Run only unit tests +go test ./internal/services/... +go test ./internal/repositories/mock_* +``` + +### Test Coverage + +```bash +# Generate coverage report +go test -cover ./... + +# Detailed coverage +go test -coverprofile=coverage.out ./... +go tool cover -html=coverage.out +``` \ No newline at end of file diff --git a/auth/cmd/auth/main.go b/auth/cmd/auth/main.go new file mode 100644 index 0000000..4bff24b --- /dev/null +++ b/auth/cmd/auth/main.go @@ -0,0 +1,139 @@ +package main + +import ( + "context" + "log" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "auth/internal/config" + "auth/internal/container" + "auth/internal/handlers" + "auth/internal/middleware" + + "github.com/joho/godotenv" +) + +// methodHandler wraps a handler to only accept specific HTTP methods +func methodHandler(method string, handler http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != method { + w.Header().Set("Allow", method) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + handler(w, r) + } +} + +func main() { + log.Println("๐Ÿš€ Starting Auth Service with Container Architecture...") + + // Load environment variables + if err := godotenv.Load(); err != nil { + log.Println("No .env file found, using system environment variables") + } + + // Validate required environment variables + requiredEnvVars := []string{"SECRET_KEY_BASE", "TWO_FACTOR_API_KEY", "OTP_TEMPLATE_NAME"} + for _, envVar := range requiredEnvVars { + if os.Getenv(envVar) == "" { + log.Fatalf("โŒ Required environment variable %s not set", envVar) + } + } + + // Create container with auto-detection (production vs test) + appContainer, err := container.CreateAutoDetectedContainer() + if err != nil { + log.Fatalf("โŒ Failed to create application container: %v", err) + } + defer func() { + log.Println("๐Ÿงน Shutting down container...") + if err := appContainer.Shutdown(); err != nil { + log.Printf("โš ๏ธ Error during container shutdown: %v", err) + } + }() + + log.Println("โœ… Container initialized successfully") + + // Create handlers from container + appHandlers := handlers.NewHandlersFromContainer(appContainer) + + // Get port from environment or use default + port := os.Getenv("PORT") + if port == "" { + port = config.DefaultPort + } + + // Setup HTTP routes with handlers + mux := http.NewServeMux() + + // Public auth endpoints (with middleware) + mux.Handle("/auth/request-otp", + middleware.CorsMiddleware( + middleware.RateLimitMiddleware( + appHandlers.Auth.RequestOTPHandlerFunc()))) + + mux.Handle("/auth/verify-otp", + middleware.CorsMiddleware( + appHandlers.Auth.VerifyOTPHandlerFunc())) + + // Protected user endpoints (with auth middleware) + mux.Handle("/auth/user", + middleware.CorsMiddleware( + middleware.AuthMiddleware( + methodHandler("GET", appHandlers.User.GetUserHandlerFunc())))) + + mux.Handle("/auth/user/update", + middleware.CorsMiddleware( + middleware.AuthMiddleware( + methodHandler("PUT", appHandlers.User.UpdateUserHandlerFunc())))) + + // Enhanced health check endpoint using container + mux.HandleFunc("/health", handlers.NewHealthCheckHandler(appContainer)) + + // Create HTTP server + server := &http.Server{ + Addr: ":" + port, + Handler: mux, + ReadTimeout: 15 * time.Second, + WriteTimeout: 15 * time.Second, + IdleTimeout: 60 * time.Second, + } + + // Start server in a goroutine + go func() { + log.Printf("๐ŸŒ Auth service starting on port %s", port) + log.Printf("๐Ÿ“ Available endpoints:") + log.Printf(" POST /auth/request-otp - Request OTP for phone number") + log.Printf(" POST /auth/verify-otp - Verify OTP and authenticate") + log.Printf(" GET /auth/user - Get user profile (requires auth)") + log.Printf(" PUT /auth/user/update - Update user profile (requires auth)") + log.Printf(" GET /health - Service health check") + + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Fatalf("โŒ Server failed to start: %v", err) + } + }() + + // Wait for interrupt signal to gracefully shutdown the server + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit + log.Println("๐Ÿ›‘ Shutting down server...") + + // Give outstanding requests 30 seconds to complete + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + if err := server.Shutdown(ctx); err != nil { + log.Printf("โš ๏ธ Server forced to shutdown: %v", err) + } else { + log.Println("โœ… Server shut down gracefully") + } + + log.Println("โœ… Application exited cleanly") +} diff --git a/auth/docker-compose.yml b/auth/docker-compose.yml new file mode 100644 index 0000000..a7f5a73 --- /dev/null +++ b/auth/docker-compose.yml @@ -0,0 +1,99 @@ +version: '3.8' + +services: + auth: + build: . + container_name: auth-service + ports: + - "8080:8080" + environment: + - SECRET_KEY_BASE=${SECRET_KEY_BASE} + - TWO_FACTOR_API_KEY=${TWO_FACTOR_API_KEY} + - OTP_TEMPLATE_NAME=${OTP_TEMPLATE_NAME} + - RATE_LIMIT_PER_MINUTE=${RATE_LIMIT_PER_MINUTE:-5} + - REDIS_HOST=redis + - REDIS_PORT=${REDIS_PORT:-6379} + - REDIS_DB=${REDIS_DB:-0} + - REDIS_PASSWORD=${REDIS_PASSWORD} + - POSTGRES_HOST=postgres + - POSTGRES_PORT=${POSTGRES_PORT:-5432} + - POSTGRES_USER=${POSTGRES_USER:-postgres} + - POSTGRES_PASSWORD=${POSTGRES_PASSWORD} + - POSTGRES_DB=${POSTGRES_DB:-auth_db} + - POSTGRES_SSLMODE=${POSTGRES_SSLMODE:-disable} + - PORT=${PORT:-8080} + - CORS_ALLOWED_ORIGINS=${CORS_ALLOWED_ORIGINS} + - APP_ENV=${APP_ENV:-production} + - LOG_LEVEL=${LOG_LEVEL:-INFO} + depends_on: + redis: + condition: service_healthy + postgres: + condition: service_healthy + networks: + - auth-network + restart: unless-stopped + healthcheck: + test: ["CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://localhost:8080/health"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 5s + + redis: + image: redis:7.2-alpine + container_name: auth-redis + ports: + - "6379:6379" + environment: + - REDIS_PASSWORD=${REDIS_PASSWORD} + command: > + sh -c " + if [ -n \"$$REDIS_PASSWORD\" ]; then + redis-server --requirepass \"$$REDIS_PASSWORD\" + else + redis-server + fi + " + networks: + - auth-network + restart: unless-stopped + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 10s + timeout: 3s + retries: 3 + start_period: 5s + volumes: + - redis-data:/data + + postgres: + image: postgres:15-alpine + container_name: auth-postgres + ports: + - "5432:5432" + environment: + - POSTGRES_USER=${POSTGRES_USER:-postgres} + - POSTGRES_PASSWORD=${POSTGRES_PASSWORD} + - POSTGRES_DB=${POSTGRES_DB:-auth_db} + networks: + - auth-network + restart: unless-stopped + healthcheck: + test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-postgres} -d ${POSTGRES_DB:-auth_db}"] + interval: 10s + timeout: 5s + retries: 5 + start_period: 10s + volumes: + - postgres-data:/var/lib/postgresql/data + +volumes: + redis-data: + driver: local + postgres-data: + driver: local + +networks: + auth-network: + driver: bridge \ No newline at end of file diff --git a/auth/go.mod b/auth/go.mod new file mode 100644 index 0000000..0a9013d --- /dev/null +++ b/auth/go.mod @@ -0,0 +1,20 @@ +module auth + +go 1.21 + +require ( + github.com/golang-jwt/jwt/v4 v4.5.2 + github.com/joho/godotenv v1.5.1 + github.com/lib/pq v1.10.9 + github.com/redis/go-redis/v9 v9.11.0 + github.com/rs/cors v1.11.1 + github.com/stretchr/testify v1.10.0 +) + +require ( + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/auth/go.sum b/auth/go.sum new file mode 100644 index 0000000..b303581 --- /dev/null +++ b/auth/go.sum @@ -0,0 +1,28 @@ +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= +github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.11.0 h1:E3S08Gl/nJNn5vkxd2i78wZxWAPNZgUNTp8WIJUAiIs= +github.com/redis/go-redis/v9 v9.11.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= +github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA= +github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/auth/internal/config/config.go b/auth/internal/config/config.go new file mode 100644 index 0000000..9c7043f --- /dev/null +++ b/auth/internal/config/config.go @@ -0,0 +1,139 @@ +package config + +import ( + "fmt" + "log" + "os" + "strconv" + "time" +) + +// Database configuration +type DatabaseConfig struct { + Host string + Port string + User string + Password string + DBName string + SSLMode string +} + +// Connection pool configuration +type PoolConfig struct { + MaxOpenConns int + MaxIdleConns int + ConnMaxLifetime time.Duration +} + +// Application constants +const ( + // User validation + MaxNameLength = 100 + MinNameLength = 1 + + // JWT + DefaultJWTExpiry = 24 * time.Hour + + // OTP + OTPExpiry = 5 * time.Minute + OTPLength = 4 + + // Default values + DefaultPort = "8080" + DefaultPostgresPort = "5432" + DefaultPostgresHost = "localhost" + DefaultPostgresUser = "postgres" + DefaultPostgresDB = "auth_db" + DefaultSSLMode = "disable" +) + +var ( + // Cached configuration + jwtSecret string +) + +// GetDatabaseConfig returns database configuration from environment +func GetDatabaseConfig() DatabaseConfig { + return DatabaseConfig{ + Host: getEnvOrDefault("POSTGRES_HOST", DefaultPostgresHost), + Port: getEnvOrDefault("POSTGRES_PORT", DefaultPostgresPort), + User: getEnvOrDefault("POSTGRES_USER", DefaultPostgresUser), + Password: getRequiredEnv("POSTGRES_PASSWORD"), + DBName: getEnvOrDefault("POSTGRES_DB", DefaultPostgresDB), + SSLMode: getEnvOrDefault("POSTGRES_SSLMODE", DefaultSSLMode), + } +} + +// GetPoolConfig returns connection pool configuration +func GetPoolConfig() PoolConfig { + return PoolConfig{ + MaxOpenConns: getEnvAsIntOrDefault("POSTGRES_MAX_OPEN_CONNS", 25), + MaxIdleConns: getEnvAsIntOrDefault("POSTGRES_MAX_IDLE_CONNS", 5), + ConnMaxLifetime: time.Hour, + } +} + +// GetJWTSecret returns cached JWT secret +func GetJWTSecret() string { + if jwtSecret == "" { + jwtSecret = getRequiredEnv("SECRET_KEY_BASE") + } + return jwtSecret +} + +// GetJWTSecretWithError returns cached JWT secret or error (for testing) +func GetJWTSecretWithError() (string, error) { + if jwtSecret == "" { + secret, err := getRequiredEnvWithError("SECRET_KEY_BASE") + if err != nil { + return "", err + } + jwtSecret = secret + } + return jwtSecret, nil +} + +// ClearJWTSecretCache clears the cached JWT secret (for testing) +func ClearJWTSecretCache() { + jwtSecret = "" +} + +// Helper functions +func getEnvOrDefault(key, defaultValue string) string { + if value := os.Getenv(key); value != "" { + return value + } + return defaultValue +} + +func getRequiredEnv(key string) string { + value := os.Getenv(key) + if value == "" { + log.Fatalf("Required environment variable %s not set", key) + } + return value +} + +// getRequiredEnvWithError returns the environment variable or an error (for testing) +func getRequiredEnvWithError(key string) (string, error) { + value := os.Getenv(key) + if value == "" { + return "", fmt.Errorf("required environment variable %s not set", key) + } + return value, nil +} + +func getEnvAsIntOrDefault(key string, defaultValue int) int { + valueStr := os.Getenv(key) + if valueStr == "" { + return defaultValue + } + + value, err := strconv.Atoi(valueStr) + if err != nil { + log.Printf("Invalid integer value for %s: %s, using default: %d", key, valueStr, defaultValue) + return defaultValue + } + + return value +} diff --git a/auth/internal/container/container_test.go b/auth/internal/container/container_test.go new file mode 100644 index 0000000..f07b596 --- /dev/null +++ b/auth/internal/container/container_test.go @@ -0,0 +1,231 @@ +package container + +import ( + "os" + "testing" +) + +func TestTestContainerImplementation(t *testing.T) { + t.Run("creation and basic functionality", func(t *testing.T) { + container := NewTestContainer() + + if container == nil { + t.Error("NewTestContainer should not return nil") + } + + // Verify services are available + services := container.GetServices() + if services == nil { + t.Error("Expected services to be available") + } + + if services.Auth == nil { + t.Error("Expected Auth service to be available") + } + + if services.User == nil { + t.Error("Expected User service to be available") + } + + if services.Repos == nil { + t.Error("Expected repository manager to be available") + } + }) + + t.Run("repository access", func(t *testing.T) { + container := NewTestContainer() + + userRepo := container.GetUserRepository() + if userRepo == nil { + t.Error("Expected user repository to be available") + } + + otpRepo := container.GetOTPRepository() + if otpRepo == nil { + t.Error("Expected OTP repository to be available") + } + + // Test mock repository access + mockUserRepo := container.GetMockUserRepository() + if mockUserRepo == nil { + t.Error("Expected mock user repository to be available") + } + + mockOTPRepo := container.GetMockOTPRepository() + if mockOTPRepo == nil { + t.Error("Expected mock OTP repository to be available") + } + }) + + t.Run("data management", func(t *testing.T) { + container := NewTestContainer() + + // Add some test data + userRepo := container.GetMockUserRepository() + _, err := userRepo.CreateUserIfNotExists("+919876543210") + if err != nil { + t.Fatalf("Failed to create test user: %v", err) + } + + if userRepo.Count() != 1 { + t.Errorf("Expected 1 user, got %d", userRepo.Count()) + } + + // Clear all data + container.ClearAllData() + + if userRepo.Count() != 0 { + t.Errorf("Expected 0 users after clear, got %d", userRepo.Count()) + } + }) + + t.Run("seed test data", func(t *testing.T) { + container := NewTestContainer() + + err := container.SeedTestData() + if err != nil { + t.Fatalf("Failed to seed test data: %v", err) + } + + userRepo := container.GetMockUserRepository() + if userRepo.Count() != 3 { + t.Errorf("Expected 3 users after seeding, got %d", userRepo.Count()) + } + + // Verify test phone number exists + testPhone := container.GetTestPhoneNumber() + _, err = userRepo.GetUserByPhone(testPhone) + if err != nil { + t.Errorf("Expected test phone number %s to exist after seeding", testPhone) + } + }) + + t.Run("test helpers", func(t *testing.T) { + container := NewTestContainer() + + // Test phone number + testPhone := container.GetTestPhoneNumber() + if testPhone != "+19999999999" { + t.Errorf("Expected test phone number +19999999999, got %s", testPhone) + } + + // Test OTP + testOTP := container.GetTestOTP() + if testOTP != "7415" { + t.Errorf("Expected test OTP 7415, got %s", testOTP) + } + }) + + t.Run("create test user", func(t *testing.T) { + container := NewTestContainer() + + phone := "+911234567890" + user, err := container.CreateTestUser(phone) + if err != nil { + t.Fatalf("Failed to create test user: %v", err) + } + + if user == nil { + t.Error("Expected user to be returned") + } + + if user.Phone != phone { + t.Errorf("Expected phone %s, got %s", phone, user.Phone) + } + }) + + t.Run("shutdown", func(t *testing.T) { + container := NewTestContainer() + + // Add data + err := container.SeedTestData() + if err != nil { + t.Fatalf("Failed to seed test data: %v", err) + } + + // Shutdown should clear data + err = container.Shutdown() + if err != nil { + t.Errorf("Expected no error from shutdown, got: %v", err) + } + + userRepo := container.GetMockUserRepository() + if userRepo.Count() != 0 { + t.Errorf("Expected 0 users after shutdown, got %d", userRepo.Count()) + } + }) +} + +func TestFactory(t *testing.T) { + t.Run("test container creation", func(t *testing.T) { + factory := NewFactory() + + container := factory.CreateTestContainer() + if container == nil { + t.Error("CreateTestContainer should not return nil") + } + + // Verify it's a test container + services := container.GetServices() + if services == nil { + t.Error("Expected services to be available") + } + }) + + t.Run("environment detection", func(t *testing.T) { + factory := NewFactory() + + // Set test environment variable + os.Setenv("TEST_MODE", "true") + defer os.Unsetenv("TEST_MODE") + + isTest := factory.isTestEnvironment() + if !isTest { + t.Error("Expected test environment to be detected") + } + }) +} + +func TestContainerIntegration(t *testing.T) { + t.Run("service integration", func(t *testing.T) { + container := NewTestContainer() + + // Test auth service integration + authService := container.GetServices().Auth + phone := container.GetTestPhoneNumber() + + err := authService.RequestOTP(phone) + if err != nil { + t.Fatalf("Failed to request OTP: %v", err) + } + + // Verify OTP was stored + otpRepo := container.GetMockOTPRepository() + if otpRepo.Count() != 1 { + t.Errorf("Expected 1 OTP to be stored, got %d", otpRepo.Count()) + } + + // Test user service integration + userService := container.GetServices().User + user, err := userService.GetUserProfile(phone) + if err == nil { + t.Error("Expected error for non-existent user") + } + + // Create user first + _, err = container.CreateTestUser(phone) + if err != nil { + t.Fatalf("Failed to create test user: %v", err) + } + + // Now get user profile should work + user, err = userService.GetUserProfile(phone) + if err != nil { + t.Fatalf("Failed to get user profile: %v", err) + } + + if user.Phone != phone { + t.Errorf("Expected phone %s, got %s", phone, user.Phone) + } + }) +} diff --git a/auth/internal/container/factory.go b/auth/internal/container/factory.go new file mode 100644 index 0000000..7c64fa4 --- /dev/null +++ b/auth/internal/container/factory.go @@ -0,0 +1,136 @@ +package container + +import ( + "fmt" + "os" + "strings" + + "auth/internal/storage" + + _ "github.com/lib/pq" // PostgreSQL driver +) + +// ContainerType represents the type of container to create +type ContainerType string + +const ( + Production ContainerType = "production" + Test ContainerType = "test" + AutoDetect ContainerType = "auto" +) + +// Factory provides methods to create different types of containers +type Factory struct{} + +// NewFactory creates a new container factory +func NewFactory() *Factory { + return &Factory{} +} + +// CreateContainer creates a container based on the specified type +func (f *Factory) CreateContainer(containerType ContainerType) (ContainerInterface, error) { + switch containerType { + case Production: + return f.CreateProductionContainer() + case Test: + return f.CreateTestContainer(), nil + case AutoDetect: + return f.CreateAutoDetectedContainer() + default: + return nil, fmt.Errorf("unsupported container type: %s", containerType) + } +} + +// CreateProductionContainer creates a production container with database connections +func (f *Factory) CreateProductionContainer() (*ProductionContainer, error) { + // Initialize storage connections using the storage package + // This handles connection pooling, table creation, and configuration + storage.InitPostgres() + storage.InitRedis() + + // Use the initialized storage connections + return NewProductionContainerFromStorage() +} + +// CreateTestContainer creates a test container with mock repositories +func (f *Factory) CreateTestContainer() *TestContainer { + return NewTestContainer() +} + +// CreateRedisTestContainer creates a test container with real Redis +func (f *Factory) CreateRedisTestContainer() (*RedisTestContainer, error) { + return NewRedisTestContainer() +} + +// CreateAutoDetectedContainer creates a container based on environment detection +func (f *Factory) CreateAutoDetectedContainer() (ContainerInterface, error) { + // Check if we're in a test environment + if f.isTestEnvironment() { + // Try to use Redis for integration tests if available + if IsRedisAvailable() { + container, err := f.CreateRedisTestContainer() + if err == nil { + return container, nil + } + // Fall back to mock if Redis setup fails + } + // Use mock container as fallback + return f.CreateTestContainer(), nil + } + + // Otherwise, create production container + return f.CreateProductionContainer() +} + +// isTestEnvironment detects if we're running in a test environment +func (f *Factory) isTestEnvironment() bool { + // Check if running under go test + for _, arg := range os.Args { + if strings.Contains(arg, "test") || strings.HasSuffix(arg, ".test") { + return true + } + } + + // Check environment variables + if os.Getenv("GO_ENV") == "test" || os.Getenv("ENVIRONMENT") == "test" { + return true + } + + // Check if test-specific environment variables are set + if os.Getenv("TEST_MODE") == "true" { + return true + } + + return false +} + +// Note: Database and Redis connection creation is now handled by the storage package +// This eliminates code duplication and ensures consistent configuration + +// Global factory instance for convenience +var defaultFactory = NewFactory() + +// CreateContainer creates a container using the default factory +func CreateContainer(containerType ContainerType) (ContainerInterface, error) { + return defaultFactory.CreateContainer(containerType) +} + +// CreateProductionContainer creates a production container using the default factory +func CreateProductionContainer() (*ProductionContainer, error) { + return defaultFactory.CreateProductionContainer() +} + +// CreateTestContainer creates a test container using the default factory +func CreateTestContainer() *TestContainer { + return defaultFactory.CreateTestContainer() +} + +// CreateRedisTestContainer creates a test container with real Redis using the default factory +func CreateRedisTestContainer() (*RedisTestContainer, error) { + return defaultFactory.CreateRedisTestContainer() +} + +// CreateAutoDetectedContainer creates a container with auto-detection using the default factory +func CreateAutoDetectedContainer() (ContainerInterface, error) { + return defaultFactory.CreateAutoDetectedContainer() +} diff --git a/auth/internal/container/interfaces.go b/auth/internal/container/interfaces.go new file mode 100644 index 0000000..0c58fcd --- /dev/null +++ b/auth/internal/container/interfaces.go @@ -0,0 +1,26 @@ +package container + +import ( + "auth/internal/services" +) + +// Container manages all application dependencies +type Container struct { + Services *services.ServiceManager +} + +// ContainerInterface defines the contract for dependency containers +type ContainerInterface interface { + GetServices() *services.ServiceManager + Shutdown() error // For cleanup during shutdown +} + +// Implement the interface for Container +func (c *Container) GetServices() *services.ServiceManager { + return c.Services +} + +func (c *Container) Shutdown() error { + // Placeholder for cleanup logic (database connections, etc.) + return nil +} diff --git a/auth/internal/container/production_container.go b/auth/internal/container/production_container.go new file mode 100644 index 0000000..fe589ea --- /dev/null +++ b/auth/internal/container/production_container.go @@ -0,0 +1,134 @@ +package container + +import ( + "database/sql" + "fmt" + + "auth/internal/repositories" + "auth/internal/services" + "auth/internal/storage" + + "github.com/redis/go-redis/v9" +) + +// ProductionContainer implements ContainerInterface for production environment +type ProductionContainer struct { + *Container + db *sql.DB + redisClient *redis.Client +} + +// NewProductionContainer creates a container with real database connections +func NewProductionContainer(db *sql.DB, redisClient *redis.Client) (*ProductionContainer, error) { + if db == nil { + return nil, fmt.Errorf("database connection is required for production container") + } + + if redisClient == nil { + return nil, fmt.Errorf("redis connection is required for production container") + } + + // Create repositories with real database connections + userRepo := repositories.NewPostgresUserRepository(db) + otpRepo := repositories.NewRedisOTPRepository(redisClient) + + // Create repository manager + repoManager := &repositories.RepositoryManager{ + Users: userRepo, + OTPs: otpRepo, + } + + // Create services with repository dependencies + authService := services.NewAuthService(userRepo, otpRepo) + userService := services.NewUserService(userRepo) + + // Create service manager + serviceManager := &services.ServiceManager{ + Auth: authService, + User: userService, + Repos: repoManager, + } + + // Create container + container := &Container{ + Services: serviceManager, + } + + return &ProductionContainer{ + Container: container, + db: db, + redisClient: redisClient, + }, nil +} + +// NewProductionContainerFromStorage creates a container using existing storage connections +func NewProductionContainerFromStorage() (*ProductionContainer, error) { + // Use global storage connections + if storage.DB == nil { + return nil, fmt.Errorf("database connection not initialized - call storage.InitPostgres() first") + } + + if storage.RedisClient == nil { + return nil, fmt.Errorf("redis connection not initialized - call storage.InitRedis() first") + } + + return NewProductionContainer(storage.DB, storage.RedisClient) +} + +// Shutdown closes database connections and cleans up resources +func (c *ProductionContainer) Shutdown() error { + var errors []error + + // Close Redis connection + if c.redisClient != nil { + if err := c.redisClient.Close(); err != nil { + errors = append(errors, fmt.Errorf("failed to close Redis connection: %w", err)) + } + } + + // Close database connection + if c.db != nil { + if err := c.db.Close(); err != nil { + errors = append(errors, fmt.Errorf("failed to close database connection: %w", err)) + } + } + + // Return combined errors if any + if len(errors) > 0 { + errMsg := "shutdown errors: " + for i, err := range errors { + if i > 0 { + errMsg += "; " + } + errMsg += err.Error() + } + return fmt.Errorf(errMsg) + } + + return nil +} + +// GetDatabase returns the database connection (useful for migrations, health checks, etc.) +func (c *ProductionContainer) GetDatabase() *sql.DB { + return c.db +} + +// GetRedisClient returns the Redis client (useful for health checks, direct operations, etc.) +func (c *ProductionContainer) GetRedisClient() *redis.Client { + return c.redisClient +} + +// HealthCheck verifies that all dependencies are healthy +func (c *ProductionContainer) HealthCheck() error { + // Check database connection + if err := c.db.Ping(); err != nil { + return fmt.Errorf("database health check failed: %w", err) + } + + // Check Redis connection + if err := c.redisClient.Ping(storage.GetContext()).Err(); err != nil { + return fmt.Errorf("redis health check failed: %w", err) + } + + return nil +} diff --git a/auth/internal/container/redis_test_container.go b/auth/internal/container/redis_test_container.go new file mode 100644 index 0000000..4e51bbf --- /dev/null +++ b/auth/internal/container/redis_test_container.go @@ -0,0 +1,206 @@ +package container + +import ( + "context" + "database/sql" + "fmt" + "os" + + "auth/internal/repositories" + "auth/internal/services" + + _ "github.com/lib/pq" // PostgreSQL driver + "github.com/redis/go-redis/v9" +) + +// RedisTestContainer implements ContainerInterface for integration testing with real Redis +type RedisTestContainer struct { + *Container + redisClient *redis.Client + db *sql.DB + userRepo repositories.UserRepository + otpRepo repositories.OTPRepository +} + +// NewRedisTestContainer creates a container with real Redis and mock PostgreSQL for integration testing +func NewRedisTestContainer() (*RedisTestContainer, error) { + // Set up Redis connection + redisAddr := os.Getenv("TEST_REDIS_ADDR") + if redisAddr == "" { + redisAddr = "localhost:6379" + } + + redisClient := redis.NewClient(&redis.Options{ + Addr: redisAddr, + DB: 15, // Use DB 15 for testing + }) + + // Test Redis connection + ctx := context.Background() + _, err := redisClient.Ping(ctx).Result() + if err != nil { + return nil, fmt.Errorf("failed to connect to Redis at %s: %w", redisAddr, err) + } + + // Clear test database + err = redisClient.FlushDB(ctx).Err() + if err != nil { + return nil, fmt.Errorf("failed to flush Redis test database: %w", err) + } + + // For now, use mock user repository for faster testing + // In the future, this could be changed to use a real test database + userRepo := repositories.NewMockUserRepository() + + // Use real Redis OTP repository + otpRepo := repositories.NewRedisOTPRepository(redisClient) + + // Create repository manager + repoManager := &repositories.RepositoryManager{ + Users: userRepo, + OTPs: otpRepo, + } + + // Create services with real Redis dependency + authService := services.NewAuthService(userRepo, otpRepo) + userService := services.NewUserService(userRepo) + + // Create service manager + serviceManager := &services.ServiceManager{ + Auth: authService, + User: userService, + Repos: repoManager, + } + + // Create container + container := &Container{ + Services: serviceManager, + } + + return &RedisTestContainer{ + Container: container, + redisClient: redisClient, + userRepo: userRepo, + otpRepo: otpRepo, + }, nil +} + +// GetRedisClient returns the Redis client for testing +func (c *RedisTestContainer) GetRedisClient() *redis.Client { + return c.redisClient +} + +// GetUserRepository returns the user repository +func (c *RedisTestContainer) GetUserRepository() repositories.UserRepository { + return c.userRepo +} + +// GetOTPRepository returns the OTP repository +func (c *RedisTestContainer) GetOTPRepository() repositories.OTPRepository { + return c.otpRepo +} + +// GetMockUserRepository returns the mock user repository for test manipulation +func (c *RedisTestContainer) GetMockUserRepository() *repositories.MockUserRepository { + return c.userRepo.(*repositories.MockUserRepository) +} + +// ClearAllData clears all test data from Redis and mock repositories +func (c *RedisTestContainer) ClearAllData() { + // Clear Redis data + ctx := context.Background() + c.redisClient.FlushDB(ctx) + + // Clear mock user data + c.GetMockUserRepository().Clear() +} + +// Shutdown closes Redis connection and cleans up resources +func (c *RedisTestContainer) Shutdown() error { + // Clear test data + c.ClearAllData() + + // Close Redis connection + if c.redisClient != nil { + return c.redisClient.Close() + } + + return nil +} + +// Reset resets the container to initial state +func (c *RedisTestContainer) Reset() { + c.ClearAllData() +} + +// SeedTestData populates the container with test data +func (c *RedisTestContainer) SeedTestData() error { + userRepo := c.GetMockUserRepository() + + // Clear existing data first + c.ClearAllData() + + // Create test users + testUsers := []string{ + "+19999999999", // Test phone number that doesn't send SMS + "+911234567890", + "+919876543210", + } + + for _, phone := range testUsers { + _, err := userRepo.CreateUserIfNotExists(phone) + if err != nil { + return err + } + } + + return nil +} + +// GetTestPhoneNumber returns the standard test phone number +func (c *RedisTestContainer) GetTestPhoneNumber() string { + return "+19999999999" +} + +// GetTestOTP returns the fixed OTP for the test phone number +func (c *RedisTestContainer) GetTestOTP() string { + return "7415" +} + +// CreateTestUser creates a user for testing and returns it +func (c *RedisTestContainer) CreateTestUser(phone string) (*repositories.User, error) { + userRepo := c.GetMockUserRepository() + user, err := userRepo.CreateUserIfNotExists(phone) + if err != nil { + return nil, err + } + return user, nil +} + +// HealthCheck verifies that Redis connection is healthy +func (c *RedisTestContainer) HealthCheck() error { + ctx := context.Background() + _, err := c.redisClient.Ping(ctx).Result() + if err != nil { + return fmt.Errorf("Redis health check failed: %w", err) + } + return nil +} + +// IsRedisAvailable checks if Redis is available for testing +func IsRedisAvailable() bool { + redisAddr := os.Getenv("TEST_REDIS_ADDR") + if redisAddr == "" { + redisAddr = "localhost:6379" + } + + client := redis.NewClient(&redis.Options{ + Addr: redisAddr, + DB: 15, + }) + defer client.Close() + + ctx := context.Background() + _, err := client.Ping(ctx).Result() + return err == nil +} diff --git a/auth/internal/container/redis_test_container_test.go b/auth/internal/container/redis_test_container_test.go new file mode 100644 index 0000000..79ef919 --- /dev/null +++ b/auth/internal/container/redis_test_container_test.go @@ -0,0 +1,109 @@ +package container + +import ( + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRedisTestContainer_Creation(t *testing.T) { + if !IsRedisAvailable() { + t.Skip("Redis not available, skipping RedisTestContainer tests") + } + + container, err := NewRedisTestContainer() + require.NoError(t, err) + defer container.Shutdown() + + // Test that container is properly initialized + assert.NotNil(t, container) + assert.NotNil(t, container.GetRedisClient()) + assert.NotNil(t, container.GetUserRepository()) + assert.NotNil(t, container.GetOTPRepository()) +} + +func TestRedisTestContainer_OTPIntegration(t *testing.T) { + if !IsRedisAvailable() { + t.Skip("Redis not available, skipping RedisTestContainer tests") + } + + container, err := NewRedisTestContainer() + require.NoError(t, err) + defer container.Shutdown() + + // Clear any existing data + container.ClearAllData() + + // Test OTP operations with real Redis + otpRepo := container.GetOTPRepository() + phone := "+919876543210" + otp := "1234" + expireAt := time.Now().Add(5 * time.Minute) + + // Store OTP + err = otpRepo.StoreOTP(phone, otp, expireAt) + assert.NoError(t, err) + + // Verify OTP exists + storedOTP, storedExpireAt, exists := otpRepo.GetOTPInfo(phone) + assert.True(t, exists) + assert.Equal(t, otp, storedOTP) + assert.WithinDuration(t, expireAt, storedExpireAt, 2*time.Second) + + // Verify OTP + valid := otpRepo.VerifyOTP(phone, otp) + assert.True(t, valid) + + // OTP should be consumed after verification + _, _, exists = otpRepo.GetOTPInfo(phone) + assert.False(t, exists) +} + +func TestRedisTestContainer_AuthServiceIntegration(t *testing.T) { + if !IsRedisAvailable() { + t.Skip("Redis not available, skipping RedisTestContainer tests") + } + + // Set required environment variables for JWT generation + originalSecretKey := os.Getenv("SECRET_KEY_BASE") + os.Setenv("SECRET_KEY_BASE", "test_secret_key_for_integration_testing") + defer func() { + if originalSecretKey == "" { + os.Unsetenv("SECRET_KEY_BASE") + } else { + os.Setenv("SECRET_KEY_BASE", originalSecretKey) + } + }() + + container, err := NewRedisTestContainer() + require.NoError(t, err) + defer container.Shutdown() + + // Clear any existing data + container.ClearAllData() + + // Test auth service with real Redis + authService := container.Services.Auth + phone := "+19999999999" // Use test phone number + expectedOTP := "7415" // Fixed OTP for test phone + + // Request OTP + err = authService.RequestOTP(phone) + assert.NoError(t, err) + + // Verify OTP with auth service + user, token, err := authService.VerifyOTP(phone, expectedOTP) + assert.NoError(t, err) + assert.NotNil(t, user) + assert.NotEmpty(t, token) + assert.Equal(t, phone, user.Phone) +} + +func TestIsRedisAvailable(t *testing.T) { + available := IsRedisAvailable() + t.Logf("Redis available: %v", available) + // This test just logs the availability, doesn't fail +} diff --git a/auth/internal/container/test_container.go b/auth/internal/container/test_container.go new file mode 100644 index 0000000..62ea5d9 --- /dev/null +++ b/auth/internal/container/test_container.go @@ -0,0 +1,137 @@ +package container + +import ( + "auth/internal/repositories" + "auth/internal/services" +) + +// TestContainer implements ContainerInterface for testing environment +type TestContainer struct { + *Container + userRepo repositories.UserRepository + otpRepo repositories.OTPRepository +} + +// NewTestContainer creates a container with mock repositories for testing +func NewTestContainer() *TestContainer { + // Create mock repositories + userRepo := repositories.NewMockUserRepository() + otpRepo := repositories.NewMockOTPRepository() + + // Create repository manager + repoManager := &repositories.RepositoryManager{ + Users: userRepo, + OTPs: otpRepo, + } + + // Create services with mock repository dependencies + authService := services.NewAuthService(userRepo, otpRepo) + userService := services.NewUserService(userRepo) + + // Create service manager + serviceManager := &services.ServiceManager{ + Auth: authService, + User: userService, + Repos: repoManager, + } + + // Create container + container := &Container{ + Services: serviceManager, + } + + return &TestContainer{ + Container: container, + userRepo: userRepo, + otpRepo: otpRepo, + } +} + +// GetUserRepository returns the mock user repository for test manipulation +func (c *TestContainer) GetUserRepository() repositories.UserRepository { + return c.userRepo +} + +// GetOTPRepository returns the mock OTP repository for test manipulation +func (c *TestContainer) GetOTPRepository() repositories.OTPRepository { + return c.otpRepo +} + +// GetMockUserRepository returns the mock user repository with testing methods +func (c *TestContainer) GetMockUserRepository() *repositories.MockUserRepository { + return c.userRepo.(*repositories.MockUserRepository) +} + +// GetMockOTPRepository returns the mock OTP repository with testing methods +func (c *TestContainer) GetMockOTPRepository() *repositories.MockOTPRepository { + return c.otpRepo.(*repositories.MockOTPRepository) +} + +// ClearAllData clears all mock data - useful for test cleanup +func (c *TestContainer) ClearAllData() { + c.GetMockUserRepository().Clear() + c.GetMockOTPRepository().Clear() +} + +// Shutdown for test container is a no-op since there are no resources to clean up +func (c *TestContainer) Shutdown() error { + // Clear data for clean shutdown + c.ClearAllData() + return nil +} + +// Reset resets the container to initial state (useful for test isolation) +func (c *TestContainer) Reset() { + c.ClearAllData() +} + +// SeedTestData populates the container with common test data +func (c *TestContainer) SeedTestData() error { + userRepo := c.GetMockUserRepository() + otpRepo := c.GetMockOTPRepository() + + // Clear existing data first + userRepo.Clear() + otpRepo.Clear() + + // Create some test users + testUsers := []string{ + "+19999999999", // Test phone number that doesn't send SMS + "+911234567890", + "+919876543210", + } + + for _, phone := range testUsers { + _, err := userRepo.CreateUserIfNotExists(phone) + if err != nil { + return err + } + } + + return nil +} + +// GetTestPhoneNumber returns the standard test phone number +func (c *TestContainer) GetTestPhoneNumber() string { + return "+19999999999" +} + +// GetTestOTP returns the fixed OTP for the test phone number +func (c *TestContainer) GetTestOTP() string { + return "7415" +} + +// CreateTestUser creates a user for testing and returns it +func (c *TestContainer) CreateTestUser(phone string) (*repositories.User, error) { + userRepo := c.GetMockUserRepository() + user, err := userRepo.CreateUserIfNotExists(phone) + if err != nil { + return nil, err + } + return user, nil +} + +// HealthCheck for test container always returns nil (always healthy) +func (c *TestContainer) HealthCheck() error { + return nil +} diff --git a/auth/internal/handlers/auth.go b/auth/internal/handlers/auth.go new file mode 100644 index 0000000..faf1c9d --- /dev/null +++ b/auth/internal/handlers/auth.go @@ -0,0 +1,199 @@ +package handlers + +import ( + "log" + "net/http" + + "auth/internal/services" + "auth/internal/utils" +) + +// RequestOTPRequest represents a request to request an OTP +type RequestOTPRequest struct { + Phone string `json:"phone"` +} + +// VerifyOTPRequest represents a request to verify an OTP +type VerifyOTPRequest struct { + Phone string `json:"phone"` + OTP string `json:"otp"` +} + +// AuthResponse represents an authentication response +type AuthResponse struct { + Message string `json:"message"` +} + +// AuthHandlers handles authentication using service layer +type AuthHandlers struct { + authService services.AuthService + userService services.UserService +} + +// NewAuthHandlers creates new auth handlers with service dependencies +func NewAuthHandlers(authService services.AuthService, userService services.UserService) *AuthHandlers { + return &AuthHandlers{ + authService: authService, + userService: userService, + } +} + +// RequestOTPHandler handles OTP request using services +func (h *AuthHandlers) RequestOTPHandler(w http.ResponseWriter, r *http.Request) { + var req RequestOTPRequest + if err := utils.DecodeJSONRequest(r, &req); err != nil { + log.Printf("Invalid JSON in OTP request: %v", err) + utils.SendJSONError(w, "Invalid JSON format in request body", http.StatusBadRequest) + return + } + + if req.Phone == "" { + utils.SendJSONError(w, "Phone number is required", http.StatusBadRequest) + return + } + + // Use service layer for OTP request + if err := h.authService.RequestOTP(req.Phone); err != nil { + log.Printf("Failed to request OTP for phone %s: %v", req.Phone, err) + + // Determine appropriate error response based on error message + // Validation errors typically contain specific validation messages + if isValidationError(err) { + utils.SendJSONError(w, err.Error(), http.StatusBadRequest) + } else { + utils.SendJSONError(w, "Failed to send OTP. Please try again later", http.StatusInternalServerError) + } + return + } + + utils.SendJSONResponse(w, AuthResponse{Message: "OTP sent successfully"}) +} + +// VerifyOTPHandler handles OTP verification using services +func (h *AuthHandlers) VerifyOTPHandler(w http.ResponseWriter, r *http.Request) { + var req VerifyOTPRequest + if err := utils.DecodeJSONRequest(r, &req); err != nil { + log.Printf("Invalid JSON in verify OTP request: %v", err) + utils.SendJSONError(w, "Invalid JSON format in request body", http.StatusBadRequest) + return + } + + if req.Phone == "" { + utils.SendJSONError(w, "Phone number is required", http.StatusBadRequest) + return + } + + if req.OTP == "" { + utils.SendJSONError(w, "OTP is required", http.StatusBadRequest) + return + } + + // Use service layer for OTP verification + user, tokenString, err := h.authService.VerifyOTP(req.Phone, req.OTP) + if err != nil { + log.Printf("Failed to verify OTP for phone %s: %v", req.Phone, err) + + // Determine appropriate error response based on error message + if isValidationError(err) { + utils.SendJSONError(w, err.Error(), http.StatusBadRequest) + } else if isAuthenticationError(err) { + utils.SendJSONError(w, "Invalid or expired OTP. Please request a new one", http.StatusUnauthorized) + } else { + utils.SendJSONError(w, "Authentication failed. Please try again later", http.StatusInternalServerError) + } + return + } + + response := map[string]interface{}{ + "message": "Authentication successful", + "token": tokenString, + "user": user, + } + utils.SendJSONResponse(w, response) +} + +// RequestOTPHandlerFunc returns a standard http.HandlerFunc +func (h *AuthHandlers) RequestOTPHandlerFunc() http.HandlerFunc { + return h.RequestOTPHandler +} + +// VerifyOTPHandlerFunc returns a standard http.HandlerFunc +func (h *AuthHandlers) VerifyOTPHandlerFunc() http.HandlerFunc { + return h.VerifyOTPHandler +} + +// Helper functions for error classification + +// isValidationError checks if an error is a validation error +func isValidationError(err error) bool { + if err == nil { + return false + } + + errMsg := err.Error() + // Check for common validation error patterns + validationPatterns := []string{ + "validation error", + "invalid phone", + "invalid format", + "phone number", + "invalid OTP", + "OTP format", + "name validation", + "required field", + } + + for _, pattern := range validationPatterns { + if contains(errMsg, pattern) { + return true + } + } + + return false +} + +// isAuthenticationError checks if an error is an authentication error +func isAuthenticationError(err error) bool { + if err == nil { + return false + } + + errMsg := err.Error() + // Check for common authentication error patterns + authPatterns := []string{ + "invalid OTP", + "expired OTP", + "OTP not found", + "authentication failed", + "unauthorized", + "OTP verification failed", + } + + for _, pattern := range authPatterns { + if contains(errMsg, pattern) { + return true + } + } + + return false +} + +// contains checks if a string contains a substring (case-insensitive) +func contains(s, substr string) bool { + return len(s) >= len(substr) && + (s == substr || + (len(s) > len(substr) && + (s[:len(substr)] == substr || + s[len(s)-len(substr):] == substr || + indexOf(s, substr) >= 0))) +} + +// indexOf returns the index of the first occurrence of substr in s, or -1 if not found +func indexOf(s, substr string) int { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return i + } + } + return -1 +} diff --git a/auth/internal/handlers/handlers_test.go b/auth/internal/handlers/handlers_test.go new file mode 100644 index 0000000..b39deb0 --- /dev/null +++ b/auth/internal/handlers/handlers_test.go @@ -0,0 +1,179 @@ +package handlers + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "testing" + + "auth/internal/container" +) + +func TestAuthHandlers(t *testing.T) { + // Set required environment variables for JWT generation + originalSecretKey := os.Getenv("SECRET_KEY_BASE") + os.Setenv("SECRET_KEY_BASE", "test_secret_key_for_handlers_testing") + defer func() { + if originalSecretKey == "" { + os.Unsetenv("SECRET_KEY_BASE") + } else { + os.Setenv("SECRET_KEY_BASE", originalSecretKey) + } + }() + + // Create test container with services + testContainer := container.CreateTestContainer() + defer testContainer.Shutdown() + + // Create handlers + handlers := NewHandlersFromContainer(testContainer) + + t.Run("RequestOTP success", func(t *testing.T) { + reqBody := RequestOTPRequest{ + Phone: testContainer.GetTestPhoneNumber(), + } + body, _ := json.Marshal(reqBody) + + req := httptest.NewRequest("POST", "/auth/request-otp", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + handlers.Auth.RequestOTPHandler(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + + var response AuthResponse + err := json.NewDecoder(w.Body).Decode(&response) + if err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + + if response.Message != "OTP sent successfully" { + t.Errorf("Expected success message, got %s", response.Message) + } + }) + + t.Run("RequestOTP invalid phone", func(t *testing.T) { + reqBody := RequestOTPRequest{ + Phone: "invalid-phone", + } + body, _ := json.Marshal(reqBody) + + req := httptest.NewRequest("POST", "/auth/request-otp", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + handlers.Auth.RequestOTPHandler(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected status 400, got %d", w.Code) + } + }) + + t.Run("VerifyOTP success", func(t *testing.T) { + // First request OTP + testPhone := testContainer.GetTestPhoneNumber() + testOTP := testContainer.GetTestOTP() + + // Request OTP first + authService := testContainer.GetServices().Auth + err := authService.RequestOTP(testPhone) + if err != nil { + t.Fatalf("Failed to request OTP: %v", err) + } + + // Now verify OTP + reqBody := VerifyOTPRequest{ + Phone: testPhone, + OTP: testOTP, + } + body, _ := json.Marshal(reqBody) + + req := httptest.NewRequest("POST", "/auth/verify-otp", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + handlers.Auth.VerifyOTPHandler(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + + var response map[string]interface{} + err = json.NewDecoder(w.Body).Decode(&response) + if err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + + if response["message"] != "Authentication successful" { + t.Errorf("Expected authentication success message") + } + + if response["token"] == nil { + t.Error("Expected JWT token in response") + } + + if response["user"] == nil { + t.Error("Expected user in response") + } + }) +} + +// Helper types for testing +type testError struct { + msg string +} + +func (e *testError) Error() string { + return e.msg +} + +func TestErrorClassification(t *testing.T) { + t.Run("isValidationError", func(t *testing.T) { + testCases := []struct { + error string + expected bool + }{ + {"validation error: invalid phone", true}, + {"invalid phone number format", true}, + {"invalid OTP format", true}, + {"phone number is required", true}, + {"database connection failed", false}, + {"internal server error", false}, + } + + for _, tc := range testCases { + err := &testError{msg: tc.error} + result := isValidationError(err) + if result != tc.expected { + t.Errorf("For error '%s', expected %v, got %v", tc.error, tc.expected, result) + } + } + }) + + t.Run("isAuthenticationError", func(t *testing.T) { + testCases := []struct { + error string + expected bool + }{ + {"invalid OTP", true}, + {"expired OTP", true}, + {"OTP verification failed", true}, + {"authentication failed", true}, + {"database connection failed", false}, + {"validation error", false}, + } + + for _, tc := range testCases { + err := &testError{msg: tc.error} + result := isAuthenticationError(err) + if result != tc.expected { + t.Errorf("For error '%s', expected %v, got %v", tc.error, tc.expected, result) + } + } + }) +} diff --git a/auth/internal/handlers/manager.go b/auth/internal/handlers/manager.go new file mode 100644 index 0000000..26238e7 --- /dev/null +++ b/auth/internal/handlers/manager.go @@ -0,0 +1,126 @@ +package handlers + +import ( + "auth/internal/container" + "auth/internal/middleware" + "auth/internal/services" + "encoding/json" + "net/http" +) + +// Handlers combines all handlers +type Handlers struct { + Auth *AuthHandlers + User *UserHandlers +} + +// NewHandlers creates all handlers from a service manager +func NewHandlers(serviceManager *services.ServiceManager) *Handlers { + return &Handlers{ + Auth: NewAuthHandlers(serviceManager.Auth, serviceManager.User), + User: NewUserHandlers(serviceManager.User), + } +} + +// NewHandlersFromContainer creates all handlers from a container +func NewHandlersFromContainer(c container.ContainerInterface) *Handlers { + return NewHandlers(c.GetServices()) +} + +// HandlerManager provides access to all handlers +type HandlerManager struct { + handlers *Handlers +} + +// NewHandlerManager creates a handler manager +func NewHandlerManager(c container.ContainerInterface) *HandlerManager { + return &HandlerManager{ + handlers: NewHandlersFromContainer(c), + } +} + +// GetAuthHandlers returns auth handlers +func (hm *HandlerManager) GetAuthHandlers() *AuthHandlers { + return hm.handlers.Auth +} + +// GetUserHandlers returns user handlers +func (hm *HandlerManager) GetUserHandlers() *UserHandlers { + return hm.handlers.User +} + +// HealthChecker interface for containers that support health checks +type HealthChecker interface { + HealthCheck() error +} + +// NewHealthCheckHandler creates a health check handler that uses the container +func NewHealthCheckHandler(c container.ContainerInterface) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + response := map[string]interface{}{ + "status": "healthy", + "services": map[string]string{ + "auth": "up", + "user": "up", + }, + } + + // Check container health if it supports health checks + if healthChecker, ok := c.(HealthChecker); ok { + if err := healthChecker.HealthCheck(); err != nil { + response = map[string]interface{}{ + "status": "unhealthy", + "error": err.Error(), + } + w.WriteHeader(500) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + return + } + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + } +} + +// methodHandler wraps a handler to only accept specific HTTP methods +func methodHandler(method string, handler http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != method { + w.Header().Set("Allow", method) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + handler(w, r) + } +} + +// SetupRoutes helper that sets up all routes with proper middleware +func SetupRoutes(mux *http.ServeMux, c container.ContainerInterface) { + handlers := NewHandlersFromContainer(c) + + // Public auth endpoints (with CORS and rate limiting middleware) + mux.Handle("/auth/request-otp", + middleware.CorsMiddleware( + middleware.RateLimitMiddleware( + handlers.Auth.RequestOTPHandlerFunc()))) + + mux.Handle("/auth/verify-otp", + middleware.CorsMiddleware( + handlers.Auth.VerifyOTPHandlerFunc())) + + // Protected user endpoints (with CORS and auth middleware) + mux.Handle("/auth/user", + middleware.CorsMiddleware( + middleware.AuthMiddleware( + methodHandler("GET", handlers.User.GetUserHandlerFunc())))) + + mux.Handle("/auth/user/update", + middleware.CorsMiddleware( + middleware.AuthMiddleware( + methodHandler("PUT", handlers.User.UpdateUserHandlerFunc())))) + + // Health check route + mux.HandleFunc("/health", NewHealthCheckHandler(c)) +} diff --git a/auth/internal/handlers/user.go b/auth/internal/handlers/user.go new file mode 100644 index 0000000..58d33f9 --- /dev/null +++ b/auth/internal/handlers/user.go @@ -0,0 +1,130 @@ +package handlers + +import ( + "encoding/json" + "log" + "net/http" + + "auth/internal/middleware" + "auth/internal/services" + "auth/internal/utils" +) + +// UpdateUserRequest represents a request to update user information +type UpdateUserRequest struct { + Name string `json:"name"` +} + +// UserHandlers handles user operations using service layer +type UserHandlers struct { + userService services.UserService +} + +// NewUserHandlers creates new user handlers with service dependencies +func NewUserHandlers(userService services.UserService) *UserHandlers { + return &UserHandlers{ + userService: userService, + } +} + +// GetUserHandler returns the current user information using services +func (h *UserHandlers) GetUserHandler(w http.ResponseWriter, r *http.Request) { + phone, ok := middleware.GetUserPhoneFromContext(r.Context()) + if !ok { + utils.SendJSONError(w, "User not found in context", http.StatusUnauthorized) + return + } + + // Use service layer to get user profile + user, err := h.userService.GetUserProfile(phone) + if err != nil { + log.Printf("Failed to get user profile for phone %s: %v", phone, err) + + // Determine appropriate error response based on error message + if isValidationError(err) { + utils.SendJSONError(w, err.Error(), http.StatusBadRequest) + } else if isUserNotFoundError(err) { + utils.SendJSONError(w, "User not found", http.StatusNotFound) + } else { + utils.SendJSONError(w, "Failed to retrieve user information", http.StatusInternalServerError) + } + return + } + + utils.SendJSONResponse(w, user) +} + +// UpdateUserHandler allows users to update their name using services +func (h *UserHandlers) UpdateUserHandler(w http.ResponseWriter, r *http.Request) { + phone, ok := middleware.GetUserPhoneFromContext(r.Context()) + if !ok { + utils.SendJSONError(w, "User not found in context", http.StatusUnauthorized) + return + } + + var req UpdateUserRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + log.Printf("Invalid JSON in update user request: %v", err) + utils.SendJSONError(w, "Invalid JSON format in request body", http.StatusBadRequest) + return + } + + // Use service layer to update user profile + user, err := h.userService.UpdateUserProfile(phone, req.Name) + if err != nil { + log.Printf("Failed to update user profile for phone %s: %v", phone, err) + + // Determine appropriate error response based on error message + if isValidationError(err) { + utils.SendJSONError(w, err.Error(), http.StatusBadRequest) + } else if isUserNotFoundError(err) { + utils.SendJSONError(w, "User not found", http.StatusNotFound) + } else { + utils.SendJSONError(w, "Failed to update user profile", http.StatusInternalServerError) + } + return + } + + response := map[string]interface{}{ + "message": "User profile updated successfully", + "user": user, + } + utils.SendJSONResponse(w, response) +} + +// GetUserHandlerFunc returns a standard http.HandlerFunc +func (h *UserHandlers) GetUserHandlerFunc() http.HandlerFunc { + return h.GetUserHandler +} + +// UpdateUserHandlerFunc returns a standard http.HandlerFunc +func (h *UserHandlers) UpdateUserHandlerFunc() http.HandlerFunc { + return h.UpdateUserHandler +} + +// Additional helper functions for user-specific error classification + +// isUserNotFoundError checks if an error indicates a user was not found +func isUserNotFoundError(err error) bool { + if err == nil { + return false + } + + errMsg := err.Error() + // Check for user not found error patterns + notFoundPatterns := []string{ + "user not found", + "no user found", + "user does not exist", + "no rows", + "not found", + } + + for _, pattern := range notFoundPatterns { + if contains(errMsg, pattern) { + return true + } + } + + return false +} diff --git a/auth/internal/middleware/auth.go b/auth/internal/middleware/auth.go new file mode 100644 index 0000000..1b4304c --- /dev/null +++ b/auth/internal/middleware/auth.go @@ -0,0 +1,64 @@ +package middleware + +import ( + "context" + "net/http" + "strings" + + "auth/internal/config" + "auth/internal/utils" + "github.com/golang-jwt/jwt/v4" +) + +type Claims struct { + Phone string `json:"phone"` + jwt.RegisteredClaims +} + +type contextKey string + +const UserPhoneKey contextKey = "userPhone" + +// AuthMiddleware validates JWT token and adds user phone to context +func AuthMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + utils.SendJSONError(w, "Authorization header is required", http.StatusUnauthorized) + return + } + + tokenString := strings.TrimPrefix(authHeader, "Bearer ") + if tokenString == authHeader { + utils.SendJSONError(w, "Bearer token is required", http.StatusUnauthorized) + return + } + + token, err := parseJWTToken(tokenString) + if err != nil { + utils.SendJSONError(w, "Invalid token", http.StatusUnauthorized) + return + } + + if claims, ok := token.Claims.(*Claims); ok && token.Valid { + ctx := context.WithValue(r.Context(), UserPhoneKey, claims.Phone) + next.ServeHTTP(w, r.WithContext(ctx)) + } else { + utils.SendJSONError(w, "Invalid token claims", http.StatusUnauthorized) + } + }) +} + +// parseJWTToken parses and validates a JWT token +func parseJWTToken(tokenString string) (*jwt.Token, error) { + jwtSecret := config.GetJWTSecret() + return jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { + return []byte(jwtSecret), nil + }) +} + +// GetUserPhoneFromContext extracts user phone from request context +func GetUserPhoneFromContext(ctx context.Context) (string, bool) { + phone, ok := ctx.Value(UserPhoneKey).(string) + return phone, ok +} diff --git a/auth/internal/middleware/auth_test.go b/auth/internal/middleware/auth_test.go new file mode 100644 index 0000000..5154393 --- /dev/null +++ b/auth/internal/middleware/auth_test.go @@ -0,0 +1,127 @@ +package middleware + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + "auth/internal/utils" + + "github.com/golang-jwt/jwt/v4" +) + +func init() { + os.Setenv("SECRET_KEY_BASE", "testsecret") +} + +// TestAuthMiddleware tests JWT authentication middleware +func TestAuthMiddleware(t *testing.T) { + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + phone, ok := GetUserPhoneFromContext(r.Context()) + if !ok { + t.Error("Phone not found in context") + } + if phone != "+919876543210" { + t.Errorf("Expected phone +919876543210, got %s", phone) + } + w.WriteHeader(http.StatusOK) + w.Write([]byte("authenticated")) + }) + + authHandler := AuthMiddleware(testHandler) + + t.Run("valid JWT token", func(t *testing.T) { + // Generate valid JWT + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "phone": "+919876543210", + "exp": time.Now().Add(time.Hour).Unix(), + }) + tokenString, _ := token.SignedString([]byte("testsecret")) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+tokenString) + + rec := httptest.NewRecorder() + authHandler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d", rec.Code) + } + }) + + t.Run("missing authorization header", func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + + rec := httptest.NewRecorder() + authHandler.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Fatalf("Expected 401, got %d", rec.Code) + } + + var errorResp utils.ErrorResponse + json.NewDecoder(rec.Body).Decode(&errorResp) + if errorResp.Message != "Authorization header is required" { + t.Errorf("Unexpected error message: %s", errorResp.Message) + } + }) + + t.Run("invalid token format", func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "InvalidFormat") + + rec := httptest.NewRecorder() + authHandler.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Fatalf("Expected 401, got %d", rec.Code) + } + + var errorResp utils.ErrorResponse + json.NewDecoder(rec.Body).Decode(&errorResp) + if errorResp.Message != "Bearer token is required" { + t.Errorf("Unexpected error message: %s", errorResp.Message) + } + }) + + t.Run("invalid JWT token", func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer invalidtoken") + + rec := httptest.NewRecorder() + authHandler.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Fatalf("Expected 401, got %d", rec.Code) + } + + var errorResp utils.ErrorResponse + json.NewDecoder(rec.Body).Decode(&errorResp) + if !strings.Contains(errorResp.Message, "Invalid token") { + t.Errorf("Unexpected error message: %s", errorResp.Message) + } + }) + + t.Run("expired JWT token", func(t *testing.T) { + // Generate expired JWT + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "phone": "+919876543210", + "exp": time.Now().Add(-time.Hour).Unix(), // Expired + }) + tokenString, _ := token.SignedString([]byte("testsecret")) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+tokenString) + + rec := httptest.NewRecorder() + authHandler.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Fatalf("Expected 401, got %d", rec.Code) + } + }) +} diff --git a/auth/internal/middleware/cors.go b/auth/internal/middleware/cors.go new file mode 100644 index 0000000..4b6f1b0 --- /dev/null +++ b/auth/internal/middleware/cors.go @@ -0,0 +1,67 @@ +package middleware + +import ( + "log" + "net/http" + "os" + "strings" + + "github.com/rs/cors" +) + +// Logger interface for dependency injection +type Logger interface { + Fatal(v ...interface{}) + Println(v ...interface{}) +} + +// StandardLogger wraps the standard log package +type StandardLogger struct{} + +func (l StandardLogger) Fatal(v ...interface{}) { + log.Fatal(v...) +} + +func (l StandardLogger) Println(v ...interface{}) { + log.Println(v...) +} + +// CorsMiddleware creates a CORS middleware using rs/cors library +// but maintains the same environment variable configuration as before +func CorsMiddleware(next http.Handler) http.Handler { + return CorsMiddlewareWithLogger(next, StandardLogger{}) +} + +// CorsMiddlewareWithLogger creates a CORS middleware with a custom logger (for testing) +func CorsMiddlewareWithLogger(next http.Handler, logger Logger) http.Handler { + // Get allowed origins from environment variable (same as scribbl_backend) + allowedOrigins := os.Getenv("CORS_ALLOWED_ORIGINS") + if allowedOrigins == "" { + // Check if we're in production environment + appEnv := os.Getenv("APP_ENV") + if appEnv == "production" { + // In production, we require explicit CORS configuration + logger.Fatal("CORS_ALLOWED_ORIGINS must be set in production environment") + } + + // Default to localhost for development only + logger.Println("Warning: CORS_ALLOWED_ORIGINS not set, defaulting to localhost (development mode)") + allowedOrigins = "http://localhost:3000,http://localhost:3001" + } + + // Parse the allowed origins + origins := strings.Split(allowedOrigins, ",") + for i, origin := range origins { + origins[i] = strings.TrimSpace(origin) + } + + // Create CORS middleware with rs/cors + c := cors.New(cors.Options{ + AllowedOrigins: origins, + AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, + AllowedHeaders: []string{"Content-Type", "Authorization"}, + AllowCredentials: true, + }) + + return c.Handler(next) +} diff --git a/auth/internal/middleware/cors_test.go b/auth/internal/middleware/cors_test.go new file mode 100644 index 0000000..ad71ffd --- /dev/null +++ b/auth/internal/middleware/cors_test.go @@ -0,0 +1,120 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "os" + "testing" +) + +// TestCORSMiddleware tests basic CORS header handling +func TestCORSMiddleware(t *testing.T) { + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("success")) + }) + + corsHandler := CorsMiddleware(testHandler) + + t.Run("preflight request", func(t *testing.T) { + req := httptest.NewRequest("OPTIONS", "/test", nil) + req.Header.Set("Origin", "http://localhost:3000") + req.Header.Set("Access-Control-Request-Method", "POST") + + rec := httptest.NewRecorder() + corsHandler.ServeHTTP(rec, req) + + if rec.Code != http.StatusNoContent { + t.Fatalf("Expected 204 for preflight, got %d", rec.Code) + } + + // Check that basic CORS headers are present + if rec.Header().Get("Access-Control-Allow-Origin") != "http://localhost:3000" { + t.Errorf("Expected Access-Control-Allow-Origin to be http://localhost:3000, got %s", rec.Header().Get("Access-Control-Allow-Origin")) + } + if rec.Header().Get("Access-Control-Allow-Methods") == "" { + t.Error("Expected Access-Control-Allow-Methods to be set") + } + if rec.Header().Get("Access-Control-Allow-Credentials") != "true" { + t.Error("Expected Access-Control-Allow-Credentials to be true") + } + }) + + t.Run("regular request", func(t *testing.T) { + req := httptest.NewRequest("POST", "/test", nil) + req.Header.Set("Origin", "http://localhost:3000") + + rec := httptest.NewRecorder() + corsHandler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d", rec.Code) + } + + // Check that CORS headers are set for regular requests too + if rec.Header().Get("Access-Control-Allow-Origin") != "http://localhost:3000" { + t.Error("Missing CORS headers on regular request") + } + }) +} + +// TestCORSMiddlewareEnvironments tests environment-specific CORS behavior +func TestCORSMiddlewareEnvironments(t *testing.T) { + // Save original environment values + originalCORS := os.Getenv("CORS_ALLOWED_ORIGINS") + originalAppEnv := os.Getenv("APP_ENV") + + // Clean up after test + defer func() { + os.Setenv("CORS_ALLOWED_ORIGINS", originalCORS) + os.Setenv("APP_ENV", originalAppEnv) + }() + + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("success")) + }) + + t.Run("production environment with explicit CORS_ALLOWED_ORIGINS", func(t *testing.T) { + os.Setenv("CORS_ALLOWED_ORIGINS", "https://example.com,https://www.example.com") + os.Setenv("APP_ENV", "production") + + corsHandler := CorsMiddleware(testHandler) + + req := httptest.NewRequest("POST", "/test", nil) + req.Header.Set("Origin", "https://example.com") + + rec := httptest.NewRecorder() + corsHandler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d", rec.Code) + } + + if rec.Header().Get("Access-Control-Allow-Origin") != "https://example.com" { + t.Error("Expected explicit origin to be allowed in production") + } + }) + + t.Run("production environment should reject localhost", func(t *testing.T) { + os.Setenv("CORS_ALLOWED_ORIGINS", "https://example.com") + os.Setenv("APP_ENV", "production") + + corsHandler := CorsMiddleware(testHandler) + + req := httptest.NewRequest("POST", "/test", nil) + req.Header.Set("Origin", "http://localhost:3000") + + rec := httptest.NewRecorder() + corsHandler.ServeHTTP(rec, req) + + // Should still return 200 but without CORS headers for disallowed origin + if rec.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d", rec.Code) + } + + if rec.Header().Get("Access-Control-Allow-Origin") != "" { + t.Error("Expected localhost to be rejected in production with explicit origins") + } + }) +} diff --git a/auth/internal/middleware/ratelimit.go b/auth/internal/middleware/ratelimit.go new file mode 100644 index 0000000..a7a3b79 --- /dev/null +++ b/auth/internal/middleware/ratelimit.go @@ -0,0 +1,105 @@ +package middleware + +import ( + "encoding/json" + "io" + "net/http" + "os" + "strconv" + "strings" + "time" + + "auth/internal/storage" +) + +// ErrorResponse represents a standardized error response (shared with handlers) +type ErrorResponse struct { + Error string `json:"error"` + Message string `json:"message"` + Code int `json:"code"` +} + +// sendJSONError sends a JSON error response +func sendJSONError(w http.ResponseWriter, message string, statusCode int) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + + var errorType string + switch statusCode { + case http.StatusTooManyRequests: + errorType = "TOO_MANY_REQUESTS" + default: + errorType = "ERROR" + } + + response := ErrorResponse{ + Error: errorType, + Message: message, + Code: statusCode, + } + + json.NewEncoder(w).Encode(response) +} + +func RateLimitMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + next.ServeHTTP(w, r) + return + } + var phone string + var originalBody []byte + + if strings.Contains(r.URL.Path, "/auth/request-otp") { + // Read the entire body first + var err error + originalBody, err = io.ReadAll(r.Body) + if err != nil { + sendJSONError(w, "Failed to read request body", http.StatusBadRequest) + return + } + r.Body.Close() + + // Try to parse the phone number for rate limiting + type req struct { + Phone string `json:"phone"` + } + var body req + if json.Unmarshal(originalBody, &body) == nil { + phone = body.Phone + } + + // Always reconstruct the body for the handler to use + r.Body = io.NopCloser(strings.NewReader(string(originalBody))) + } + + if phone != "" { + // Skip rate limiting if Redis is not initialized (test scenario) + if storage.RedisClient == nil { + next.ServeHTTP(w, r) + return + } + + key := "rl:" + phone + count, _ := storage.RedisClient.Incr(storage.GetContext(), key).Result() + if count == 1 { + storage.RedisClient.Expire(storage.GetContext(), key, time.Minute) + } + + // Get rate limit from environment variable, default to 5 if not set or invalid + rateLimitStr := os.Getenv("RATE_LIMIT_PER_MINUTE") + rateLimit := 5 // Default value + if rateLimitStr != "" { + if parsedLimit, err := strconv.Atoi(rateLimitStr); err == nil && parsedLimit > 0 { + rateLimit = parsedLimit + } + } + + if count > int64(rateLimit) { + sendJSONError(w, "Rate limit exceeded. Please try again later", http.StatusTooManyRequests) + return + } + } + next.ServeHTTP(w, r) + }) +} diff --git a/auth/internal/middleware/ratelimit_test.go b/auth/internal/middleware/ratelimit_test.go new file mode 100644 index 0000000..ebc6274 --- /dev/null +++ b/auth/internal/middleware/ratelimit_test.go @@ -0,0 +1,112 @@ +package middleware + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "testing" + + "auth/internal/storage" + "auth/internal/utils" +) + +func init() { + os.Setenv("SECRET_KEY_BASE", "testsecret") +} + +func setupRedisForTests() { + storage.InitRedis() + storage.RedisClient.FlushDB(storage.GetContext()) +} + +// TestRateLimitMiddleware tests the rate limiting functionality +func TestRateLimitMiddleware(t *testing.T) { + setupRedisForTests() + + // Create a test handler + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("success")) + }) + + // Wrap with rate limit middleware + rateLimitedHandler := RateLimitMiddleware(testHandler) + + // Make 5 requests (should all succeed) - use /auth/request-otp endpoint for rate limiting + phone := "+919876543210" + for i := 0; i < 5; i++ { + reqBody := fmt.Sprintf(`{"phone":"%s"}`, phone) + req := httptest.NewRequest("POST", "/auth/request-otp", bytes.NewReader([]byte(reqBody))) + req.Header.Set("Content-Type", "application/json") + req.RemoteAddr = "127.0.0.1:12345" // Same IP for all requests + + rec := httptest.NewRecorder() + rateLimitedHandler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("Request %d: expected 200, got %d", i+1, rec.Code) + } + } + + // 6th request should be rate limited + reqBody := fmt.Sprintf(`{"phone":"%s"}`, phone) + req := httptest.NewRequest("POST", "/auth/request-otp", bytes.NewReader([]byte(reqBody))) + req.Header.Set("Content-Type", "application/json") + req.RemoteAddr = "127.0.0.1:12345" + + rec := httptest.NewRecorder() + rateLimitedHandler.ServeHTTP(rec, req) + + if rec.Code != http.StatusTooManyRequests { + t.Fatalf("Expected 429 on rate limit, got %d", rec.Code) + } + + // Verify JSON error response + var errorResp utils.ErrorResponse + if err := json.NewDecoder(rec.Body).Decode(&errorResp); err != nil { + t.Fatalf("Failed to decode rate limit response: %v", err) + } + + if errorResp.Error != "TOO_MANY_REQUESTS" { + t.Errorf("Expected 'TOO_MANY_REQUESTS', got '%s'", errorResp.Error) + } +} + +// TestRateLimitBodyReconstruction tests that rate limiting properly handles request body +func TestRateLimitBodyReconstruction(t *testing.T) { + setupRedisForTests() + + // Handler that checks the request body + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var data map[string]string + if err := json.NewDecoder(r.Body).Decode(&data); err != nil { + utils.SendJSONError(w, "Invalid JSON", http.StatusBadRequest) + return + } + + if data["phone"] != "+919876543210" { + utils.SendJSONError(w, "Body not reconstructed properly", http.StatusBadRequest) + return + } + + w.WriteHeader(http.StatusOK) + w.Write([]byte("success")) + }) + + rateLimitedHandler := RateLimitMiddleware(testHandler) + + // Test that body is properly reconstructed after rate limit check + req := httptest.NewRequest("POST", "/auth/request-otp", bytes.NewReader([]byte(`{"phone":"+919876543210"}`))) + req.Header.Set("Content-Type", "application/json") + req.RemoteAddr = "127.0.0.1:54321" + + rec := httptest.NewRecorder() + rateLimitedHandler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d. Body: %s", rec.Code, rec.Body.String()) + } +} diff --git a/auth/internal/repositories/interfaces.go b/auth/internal/repositories/interfaces.go new file mode 100644 index 0000000..297b0e3 --- /dev/null +++ b/auth/internal/repositories/interfaces.go @@ -0,0 +1,37 @@ +package repositories + +import ( + "time" +) + +// User represents a user in the system +type User struct { + ID int64 `json:"id"` + Phone string `json:"phone"` + Name *string `json:"name"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// UserRepository handles all user-related data operations +type UserRepository interface { + CreateUserIfNotExists(phone string) (*User, error) + GetUserByPhone(phone string) (*User, error) + UpdateUserName(phone, name string) (*User, error) + DeleteUser(phone string) error // For future extensibility +} + +// OTPRepository handles OTP storage and verification +type OTPRepository interface { + StoreOTP(phone, otp string, expireAt time.Time) error + VerifyOTP(phone, otp string) bool + InvalidateOTP(phone string) error // For cleanup + GetOTPInfo(phone string) (string, time.Time, bool) // For testing and debugging + GetOTPTTL(phone string) (time.Duration, error) // For monitoring and debugging +} + +// RepositoryManager provides access to all repositories +type RepositoryManager struct { + Users UserRepository + OTPs OTPRepository +} diff --git a/auth/internal/repositories/mock_otp_repo.go b/auth/internal/repositories/mock_otp_repo.go new file mode 100644 index 0000000..3d542cd --- /dev/null +++ b/auth/internal/repositories/mock_otp_repo.go @@ -0,0 +1,168 @@ +package repositories + +import ( + "sync" + "time" +) + +// OTPEntry represents an OTP entry in mock storage +type OTPEntry struct { + OTP string + ExpireAt int64 // Unix timestamp +} + +// MockOTPRepository implements OTPRepository interface for testing +type MockOTPRepository struct { + otps map[string]OTPEntry + mutex sync.RWMutex +} + +// NewMockOTPRepository creates a new mock OTP repository +func NewMockOTPRepository() *MockOTPRepository { + return &MockOTPRepository{ + otps: make(map[string]OTPEntry), + } +} + +// StoreOTP stores an OTP with expiration time +func (r *MockOTPRepository) StoreOTP(phone, otp string, expireAt time.Time) error { + r.mutex.Lock() + defer r.mutex.Unlock() + + r.otps[phone] = OTPEntry{ + OTP: otp, + ExpireAt: expireAt.Unix(), + } + + return nil +} + +// VerifyOTP verifies an OTP and consumes it if valid +func (r *MockOTPRepository) VerifyOTP(phone, otp string) bool { + r.mutex.Lock() + defer r.mutex.Unlock() + + entry, exists := r.otps[phone] + if !exists { + return false + } + + // Check if OTP matches + if entry.OTP != otp { + return false + } + + // Check if OTP has expired + if time.Now().Unix() > entry.ExpireAt { + // Clean up expired OTP + delete(r.otps, phone) + return false + } + + // OTP is valid, consume it (remove from storage) + delete(r.otps, phone) + return true +} + +// InvalidateOTP invalidates/deletes an OTP +func (r *MockOTPRepository) InvalidateOTP(phone string) error { + r.mutex.Lock() + defer r.mutex.Unlock() + + delete(r.otps, phone) + return nil +} + +// Test helper methods + +// Clear clears all OTPs from the repository +func (r *MockOTPRepository) Clear() { + r.mutex.Lock() + defer r.mutex.Unlock() + + r.otps = make(map[string]OTPEntry) +} + +// Count returns the number of OTPs in the repository +func (r *MockOTPRepository) Count() int { + r.mutex.RLock() + defer r.mutex.RUnlock() + + return len(r.otps) +} + +// GetOTPInfo returns OTP information without consuming it (for testing) +func (r *MockOTPRepository) GetOTPInfo(phone string) (string, time.Time, bool) { + r.mutex.RLock() + defer r.mutex.RUnlock() + + entry, exists := r.otps[phone] + if !exists { + return "", time.Time{}, false + } + + return entry.OTP, time.Unix(entry.ExpireAt, 0), true +} + +// GetAll returns all OTPs (for testing purposes) +func (r *MockOTPRepository) GetAll() map[string]OTPEntry { + r.mutex.RLock() + defer r.mutex.RUnlock() + + result := make(map[string]OTPEntry) + for k, v := range r.otps { + result[k] = v + } + + return result +} + +// IsExpired checks if an OTP has expired without consuming it +func (r *MockOTPRepository) IsExpired(phone string) bool { + r.mutex.RLock() + defer r.mutex.RUnlock() + + entry, exists := r.otps[phone] + if !exists { + return true // Consider non-existent OTPs as expired + } + + return time.Now().Unix() > entry.ExpireAt +} + +// CleanupExpired removes all expired OTPs +func (r *MockOTPRepository) CleanupExpired() int { + r.mutex.Lock() + defer r.mutex.Unlock() + + now := time.Now().Unix() + count := 0 + + for phone, entry := range r.otps { + if now > entry.ExpireAt { + delete(r.otps, phone) + count++ + } + } + + return count +} + +// GetOTPTTL returns the remaining time until expiration for an OTP +func (r *MockOTPRepository) GetOTPTTL(phone string) (time.Duration, error) { + r.mutex.RLock() + defer r.mutex.RUnlock() + + entry, exists := r.otps[phone] + if !exists { + return time.Duration(-2), nil // Mimic Redis behavior for non-existent keys + } + + now := time.Now().Unix() + if now > entry.ExpireAt { + return time.Duration(-1), nil // Mimic Redis behavior for expired keys + } + + remainingSeconds := entry.ExpireAt - now + return time.Duration(remainingSeconds) * time.Second, nil +} diff --git a/auth/internal/repositories/mock_user_repo.go b/auth/internal/repositories/mock_user_repo.go new file mode 100644 index 0000000..1b926a5 --- /dev/null +++ b/auth/internal/repositories/mock_user_repo.go @@ -0,0 +1,126 @@ +package repositories + +import ( + "database/sql" + "sync" + "time" +) + +// MockUserRepository implements UserRepository interface for testing +type MockUserRepository struct { + users map[string]*User + mutex sync.RWMutex + nextID int +} + +// NewMockUserRepository creates a new mock user repository +func NewMockUserRepository() *MockUserRepository { + return &MockUserRepository{ + users: make(map[string]*User), + nextID: 1, + } +} + +// CreateUserIfNotExists creates a new user if one doesn't exist, otherwise returns existing user +func (r *MockUserRepository) CreateUserIfNotExists(phone string) (*User, error) { + r.mutex.Lock() + defer r.mutex.Unlock() + + if user, exists := r.users[phone]; exists { + return user, nil + } + + user := &User{ + ID: int64(r.nextID), + Phone: phone, + Name: nil, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + r.users[phone] = user + r.nextID++ + + return user, nil +} + +// GetUserByPhone retrieves a user by phone number +func (r *MockUserRepository) GetUserByPhone(phone string) (*User, error) { + r.mutex.RLock() + defer r.mutex.RUnlock() + + if user, exists := r.users[phone]; exists { + return user, nil + } + + return nil, sql.ErrNoRows +} + +// UpdateUserName updates a user's name +func (r *MockUserRepository) UpdateUserName(phone, name string) (*User, error) { + r.mutex.Lock() + defer r.mutex.Unlock() + + user, exists := r.users[phone] + if !exists { + return nil, sql.ErrNoRows + } + + user.Name = &name + user.UpdatedAt = time.Now() + + return user, nil +} + +// DeleteUser deletes a user by phone number +func (r *MockUserRepository) DeleteUser(phone string) error { + r.mutex.Lock() + defer r.mutex.Unlock() + + if _, exists := r.users[phone]; !exists { + return sql.ErrNoRows + } + + delete(r.users, phone) + return nil +} + +// Test helper methods + +// Clear clears all users from the repository +func (r *MockUserRepository) Clear() { + r.mutex.Lock() + defer r.mutex.Unlock() + + r.users = make(map[string]*User) + r.nextID = 1 +} + +// Count returns the number of users in the repository +func (r *MockUserRepository) Count() int { + r.mutex.RLock() + defer r.mutex.RUnlock() + + return len(r.users) +} + +// GetAll returns all users (for testing purposes) +func (r *MockUserRepository) GetAll() map[string]*User { + r.mutex.RLock() + defer r.mutex.RUnlock() + + result := make(map[string]*User) + for k, v := range r.users { + result[k] = v + } + + return result +} + +// SetNextID sets the next ID to be used (for testing purposes) +func (r *MockUserRepository) SetNextID(id int) { + r.mutex.Lock() + defer r.mutex.Unlock() + + r.nextID = id +} diff --git a/auth/internal/repositories/postgres_user_repo.go b/auth/internal/repositories/postgres_user_repo.go new file mode 100644 index 0000000..8566ee2 --- /dev/null +++ b/auth/internal/repositories/postgres_user_repo.go @@ -0,0 +1,128 @@ +package repositories + +import ( + "database/sql" + "fmt" +) + +// PostgresUserRepository implements UserRepository interface for PostgreSQL +type PostgresUserRepository struct { + db *sql.DB +} + +// NewPostgresUserRepository creates a new PostgreSQL user repository +func NewPostgresUserRepository(db *sql.DB) UserRepository { + return &PostgresUserRepository{db: db} +} + +// CreateUserIfNotExists creates a new user if one doesn't exist, otherwise returns existing user +func (r *PostgresUserRepository) CreateUserIfNotExists(phone string) (*User, error) { + // Check if user already exists + user, err := r.GetUserByPhone(phone) + if err != nil && err != sql.ErrNoRows { + return nil, fmt.Errorf("error checking if user exists: %w", err) + } + + // If user exists, return the existing user + if user != nil { + return user, nil + } + + // Create new user + query := ` + INSERT INTO users (phone) + VALUES ($1) + RETURNING id, phone, name, created_at, updated_at + ` + + newUser := &User{} + err = r.db.QueryRow(query, phone).Scan( + &newUser.ID, + &newUser.Phone, + &newUser.Name, + &newUser.CreatedAt, + &newUser.UpdatedAt, + ) + + if err != nil { + return nil, fmt.Errorf("error creating user: %w", err) + } + + return newUser, nil +} + +// GetUserByPhone retrieves a user by phone number +func (r *PostgresUserRepository) GetUserByPhone(phone string) (*User, error) { + query := ` + SELECT id, phone, name, created_at, updated_at + FROM users + WHERE phone = $1 + ` + + user := &User{} + err := r.db.QueryRow(query, phone).Scan( + &user.ID, + &user.Phone, + &user.Name, + &user.CreatedAt, + &user.UpdatedAt, + ) + + if err != nil { + if err == sql.ErrNoRows { + return nil, err + } + return nil, fmt.Errorf("error getting user by phone: %w", err) + } + + return user, nil +} + +// UpdateUserName updates a user's name +func (r *PostgresUserRepository) UpdateUserName(phone, name string) (*User, error) { + query := ` + UPDATE users + SET name = $1, updated_at = CURRENT_TIMESTAMP + WHERE phone = $2 + RETURNING id, phone, name, created_at, updated_at + ` + + user := &User{} + err := r.db.QueryRow(query, name, phone).Scan( + &user.ID, + &user.Phone, + &user.Name, + &user.CreatedAt, + &user.UpdatedAt, + ) + + if err != nil { + if err == sql.ErrNoRows { + return nil, err + } + return nil, fmt.Errorf("error updating user name: %w", err) + } + + return user, nil +} + +// DeleteUser deletes a user by phone number +func (r *PostgresUserRepository) DeleteUser(phone string) error { + query := `DELETE FROM users WHERE phone = $1` + + result, err := r.db.Exec(query, phone) + if err != nil { + return fmt.Errorf("error deleting user: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("error getting rows affected: %w", err) + } + + if rowsAffected == 0 { + return sql.ErrNoRows + } + + return nil +} diff --git a/auth/internal/repositories/redis_otp_repo.go b/auth/internal/repositories/redis_otp_repo.go new file mode 100644 index 0000000..442f278 --- /dev/null +++ b/auth/internal/repositories/redis_otp_repo.go @@ -0,0 +1,131 @@ +package repositories + +import ( + "context" + "fmt" + "strconv" + "strings" + "time" + + "github.com/redis/go-redis/v9" +) + +// RedisOTPRepository implements OTPRepository interface for Redis +type RedisOTPRepository struct { + client *redis.Client + ctx context.Context +} + +// NewRedisOTPRepository creates a new Redis OTP repository +func NewRedisOTPRepository(client *redis.Client) OTPRepository { + return &RedisOTPRepository{ + client: client, + ctx: context.Background(), + } +} + +// StoreOTP stores an OTP with expiration time +func (r *RedisOTPRepository) StoreOTP(phone, otp string, expireAt time.Time) error { + key := fmt.Sprintf("otp:%s", phone) + value := fmt.Sprintf("%s:%d", otp, expireAt.Unix()) + duration := time.Until(expireAt) + + if duration <= 0 { + return fmt.Errorf("OTP expiration time is in the past") + } + + err := r.client.Set(r.ctx, key, value, duration).Err() + if err != nil { + return fmt.Errorf("error storing OTP in Redis: %w", err) + } + + return nil +} + +// VerifyOTP verifies an OTP and consumes it if valid +func (r *RedisOTPRepository) VerifyOTP(phone, otp string) bool { + key := fmt.Sprintf("otp:%s", phone) + + val, err := r.client.Get(r.ctx, key).Result() + if err != nil { + // OTP not found or Redis error + return false + } + + // Split the value by colon to separate OTP and timestamp + parts := strings.Split(val, ":") + if len(parts) != 2 { + return false + } + + storedOTP := parts[0] + expireAt, err := strconv.ParseInt(parts[1], 10, 64) + if err != nil { + return false + } + + // Check if OTP matches + if storedOTP != otp { + return false + } + + // Check if OTP has expired + if time.Now().Unix() > expireAt { + // Clean up expired OTP + r.client.Del(r.ctx, key) + return false + } + + // OTP is valid, consume it (delete from Redis) + r.client.Del(r.ctx, key) + return true +} + +// InvalidateOTP invalidates/deletes an OTP +func (r *RedisOTPRepository) InvalidateOTP(phone string) error { + key := fmt.Sprintf("otp:%s", phone) + + err := r.client.Del(r.ctx, key).Err() + if err != nil { + return fmt.Errorf("error invalidating OTP in Redis: %w", err) + } + + return nil +} + +// Additional helper methods for testing and monitoring + +// GetOTPInfo returns OTP information without consuming it (for testing) +func (r *RedisOTPRepository) GetOTPInfo(phone string) (string, time.Time, bool) { + key := fmt.Sprintf("otp:%s", phone) + + val, err := r.client.Get(r.ctx, key).Result() + if err != nil { + return "", time.Time{}, false + } + + parts := strings.Split(val, ":") + if len(parts) != 2 { + return "", time.Time{}, false + } + + storedOTP := parts[0] + expireAt, err := strconv.ParseInt(parts[1], 10, 64) + if err != nil { + return "", time.Time{}, false + } + + return storedOTP, time.Unix(expireAt, 0), true +} + +// GetOTPTTL returns the remaining TTL for an OTP +func (r *RedisOTPRepository) GetOTPTTL(phone string) (time.Duration, error) { + key := fmt.Sprintf("otp:%s", phone) + + ttl, err := r.client.TTL(r.ctx, key).Result() + if err != nil { + return 0, fmt.Errorf("error getting OTP TTL: %w", err) + } + + return ttl, nil +} diff --git a/auth/internal/repositories/redis_otp_repo_test.go b/auth/internal/repositories/redis_otp_repo_test.go new file mode 100644 index 0000000..f4a13c9 --- /dev/null +++ b/auth/internal/repositories/redis_otp_repo_test.go @@ -0,0 +1,156 @@ +package repositories + +import ( + "context" + "testing" + "time" + + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test configuration +const ( + testRedisAddr = "localhost:6379" + testRedisDB = 15 // Use DB 15 for testing to avoid conflicts +) + +// setupTestRedis creates a Redis client for testing +func setupTestRedis(t *testing.T) *redis.Client { + // Skip if Redis is not available + if testing.Short() { + t.Skip("Skipping Redis integration tests in short mode") + } + + client := redis.NewClient(&redis.Options{ + Addr: testRedisAddr, + DB: testRedisDB, + }) + + // Test connection + ctx := context.Background() + _, err := client.Ping(ctx).Result() + if err != nil { + t.Skipf("Redis not available at %s: %v", testRedisAddr, err) + } + + // Clear test database + client.FlushDB(ctx) + + return client +} + +// teardownTestRedis cleans up Redis test data +func teardownTestRedis(t *testing.T, client *redis.Client) { + if client != nil { + client.FlushDB(context.Background()) + client.Close() + } +} + +func TestRedisOTPRepository_StoreAndVerify(t *testing.T) { + client := setupTestRedis(t) + defer teardownTestRedis(t, client) + + repo := NewRedisOTPRepository(client) + phone := "+919876543210" + otp := "1234" + expireAt := time.Now().Add(5 * time.Minute) + + t.Run("store and verify valid OTP", func(t *testing.T) { + // Store OTP + err := repo.StoreOTP(phone, otp, expireAt) + require.NoError(t, err) + + // Verify OTP + valid := repo.VerifyOTP(phone, otp) + assert.True(t, valid) + + // OTP should be consumed after verification + _, _, exists := repo.GetOTPInfo(phone) + assert.False(t, exists) + }) + + t.Run("verify invalid OTP", func(t *testing.T) { + // Store OTP + err := repo.StoreOTP(phone, otp, expireAt) + require.NoError(t, err) + + // Try wrong OTP + valid := repo.VerifyOTP(phone, "0000") + assert.False(t, valid) + + // OTP should still exist + _, _, exists := repo.GetOTPInfo(phone) + assert.True(t, exists) + }) + + t.Run("verify expired OTP", func(t *testing.T) { + // Store OTP with very short expiration + expiredTime := time.Now().Add(1 * time.Second) + err := repo.StoreOTP(phone, otp, expiredTime) + require.NoError(t, err) + + // Wait for expiration + time.Sleep(2 * time.Second) + + // Try to verify expired OTP + valid := repo.VerifyOTP(phone, otp) + assert.False(t, valid) + }) +} + +func TestRedisOTPRepository_GetOTPInfo(t *testing.T) { + client := setupTestRedis(t) + defer teardownTestRedis(t, client) + + repo := NewRedisOTPRepository(client) + phone := "+919876543210" + otp := "1234" + expireAt := time.Now().Add(5 * time.Minute) + + t.Run("get existing OTP info", func(t *testing.T) { + // Store OTP first + err := repo.StoreOTP(phone, otp, expireAt) + require.NoError(t, err) + + // Get OTP info + storedOTP, storedExpireAt, exists := repo.GetOTPInfo(phone) + assert.True(t, exists) + assert.Equal(t, otp, storedOTP) + assert.WithinDuration(t, expireAt, storedExpireAt, 2*time.Second) + }) + + t.Run("get non-existent OTP info", func(t *testing.T) { + phone := "+919876543211" + + // Get info for non-existent OTP + _, _, exists := repo.GetOTPInfo(phone) + assert.False(t, exists) + }) +} + +func TestRedisOTPRepository_InvalidateOTP(t *testing.T) { + client := setupTestRedis(t) + defer teardownTestRedis(t, client) + + repo := NewRedisOTPRepository(client) + phone := "+919876543210" + otp := "1234" + expireAt := time.Now().Add(5 * time.Minute) + + t.Run("invalidate existing OTP", func(t *testing.T) { + // Store OTP first + err := repo.StoreOTP(phone, otp, expireAt) + require.NoError(t, err) + + // Invalidate OTP + err = repo.InvalidateOTP(phone) + assert.NoError(t, err) + + // Verify OTP was invalidated + _, _, exists := repo.GetOTPInfo(phone) + assert.False(t, exists) + }) +} diff --git a/auth/internal/services/auth_service.go b/auth/internal/services/auth_service.go new file mode 100644 index 0000000..85d3456 --- /dev/null +++ b/auth/internal/services/auth_service.go @@ -0,0 +1,132 @@ +package services + +import ( + "fmt" + "time" + + "auth/internal/config" + "auth/internal/repositories" + "auth/internal/utils" + + "github.com/golang-jwt/jwt/v4" +) + +// AuthServiceImpl implements AuthService interface +type AuthServiceImpl struct { + userRepo repositories.UserRepository + otpRepo repositories.OTPRepository +} + +// NewAuthService creates a new authentication service +func NewAuthService(userRepo repositories.UserRepository, otpRepo repositories.OTPRepository) AuthService { + return &AuthServiceImpl{ + userRepo: userRepo, + otpRepo: otpRepo, + } +} + +// RequestOTP generates and sends an OTP for the given phone number +func (s *AuthServiceImpl) RequestOTP(phone string) error { + // Validate phone number format + if err := utils.ValidatePhoneFormat(phone); err != nil { + return fmt.Errorf("invalid phone format: %w", err) + } + + // Generate OTP + otp := utils.GenerateOTPForPhone(phone) + + // Send OTP via SMS first (fail fast if SMS fails) + if err := utils.SendOTPWith2Factor(phone, otp); err != nil { + return fmt.Errorf("failed to send OTP: %w", err) + } + + // Store OTP with expiration + expireAt := time.Now().Add(config.OTPExpiry) + if err := s.otpRepo.StoreOTP(phone, otp, expireAt); err != nil { + return fmt.Errorf("failed to store OTP: %w", err) + } + + return nil +} + +// VerifyOTP verifies an OTP and returns user and JWT token on success +func (s *AuthServiceImpl) VerifyOTP(phone, otp string) (*repositories.User, string, error) { + // Validate phone number format + if err := utils.ValidatePhoneFormat(phone); err != nil { + return nil, "", fmt.Errorf("invalid phone format: %w", err) + } + + // Validate OTP format + if err := utils.ValidateOTPFormat(otp); err != nil { + return nil, "", fmt.Errorf("invalid OTP format: %w", err) + } + + // Verify OTP (this consumes the OTP if valid) + if !s.otpRepo.VerifyOTP(phone, otp) { + return nil, "", fmt.Errorf("invalid or expired OTP") + } + + // Create or get user + user, err := s.userRepo.CreateUserIfNotExists(phone) + if err != nil { + return nil, "", fmt.Errorf("failed to create/get user: %w", err) + } + + // Generate JWT token + token, err := s.generateJWTToken(phone) + if err != nil { + return nil, "", fmt.Errorf("failed to generate authentication token: %w", err) + } + + return user, token, nil +} + +// generateJWTToken creates a JWT token for the given phone number +func (s *AuthServiceImpl) generateJWTToken(phone string) (string, error) { + jwtSecret, err := config.GetJWTSecretWithError() + if err != nil { + return "", fmt.Errorf("authentication service configuration error: %w", err) + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "phone": phone, + "exp": time.Now().Add(config.DefaultJWTExpiry).Unix(), + }) + + tokenString, err := token.SignedString([]byte(jwtSecret)) + if err != nil { + return "", fmt.Errorf("failed to sign JWT token: %w", err) + } + + return tokenString, nil +} + +// Additional helper methods for extended functionality + +// InvalidateOTP invalidates an OTP for a given phone number +func (s *AuthServiceImpl) InvalidateOTP(phone string) error { + if err := utils.ValidatePhoneFormat(phone); err != nil { + return fmt.Errorf("invalid phone format: %w", err) + } + + return s.otpRepo.InvalidateOTP(phone) +} + +// IsOTPValid checks if an OTP is valid without consuming it (for testing) +func (s *AuthServiceImpl) IsOTPValid(phone, otp string) (bool, error) { + if err := utils.ValidatePhoneFormat(phone); err != nil { + return false, fmt.Errorf("invalid phone format: %w", err) + } + + if err := utils.ValidateOTPFormat(otp); err != nil { + return false, fmt.Errorf("invalid OTP format: %w", err) + } + + // Use the interface method GetOTPInfo which is now implemented by all repositories + storedOTP, expireAt, exists := s.otpRepo.GetOTPInfo(phone) + if !exists { + return false, nil + } + + return storedOTP == otp && time.Now().Before(expireAt), nil +} diff --git a/auth/internal/services/auth_service_test.go b/auth/internal/services/auth_service_test.go new file mode 100644 index 0000000..6e6847e --- /dev/null +++ b/auth/internal/services/auth_service_test.go @@ -0,0 +1,224 @@ +package services + +import ( + "os" + "testing" + "time" + + "auth/internal/repositories" +) + +func init() { + // Set up test environment variables + os.Setenv("SECRET_KEY_BASE", "testsecret") + os.Setenv("TWO_FACTOR_API_KEY", "test") + os.Setenv("OTP_TEMPLATE_NAME", "test") +} + +func TestAuthService(t *testing.T) { + t.Run("constructor", func(t *testing.T) { + userRepo := repositories.NewMockUserRepository() + otpRepo := repositories.NewMockOTPRepository() + + authService := NewAuthService(userRepo, otpRepo) + + if authService == nil { + t.Error("NewAuthService should not return nil") + } + + // Verify it's the correct type + if _, ok := authService.(*AuthServiceImpl); !ok { + t.Error("NewAuthService should return *AuthServiceImpl") + } + }) +} + +func TestAuthService_RequestOTP(t *testing.T) { + userRepo := repositories.NewMockUserRepository() + otpRepo := repositories.NewMockOTPRepository() + authService := NewAuthService(userRepo, otpRepo) + + t.Run("valid phone number", func(t *testing.T) { + otpRepo.Clear() + phone := "+19999999999" // Test phone number that doesn't send real SMS + + err := authService.RequestOTP(phone) + if err != nil { + t.Fatalf("Failed to request OTP: %v", err) + } + + // Verify OTP was stored + if otpRepo.Count() != 1 { + t.Errorf("Expected 1 OTP to be stored, got %d", otpRepo.Count()) + } + + // Verify OTP exists for the phone + _, _, exists := otpRepo.GetOTPInfo(phone) + if !exists { + t.Error("Expected OTP to exist for the phone number") + } + }) + + t.Run("invalid phone format", func(t *testing.T) { + otpRepo.Clear() + invalidPhone := "invalid-phone" + + err := authService.RequestOTP(invalidPhone) + if err == nil { + t.Error("Expected error for invalid phone format") + } + + if otpRepo.Count() != 0 { + t.Errorf("Expected 0 OTPs for invalid phone, got %d", otpRepo.Count()) + } + }) +} + +func TestAuthService_VerifyOTP(t *testing.T) { + userRepo := repositories.NewMockUserRepository() + otpRepo := repositories.NewMockOTPRepository() + authService := NewAuthService(userRepo, otpRepo) + + t.Run("valid OTP verification", func(t *testing.T) { + userRepo.Clear() + otpRepo.Clear() + phone := "+19999999999" // Test phone number + otp := "7415" // Fixed OTP for test phone number + + // Store OTP first + expireAt := time.Now().Add(5 * time.Minute) + err := otpRepo.StoreOTP(phone, otp, expireAt) + if err != nil { + t.Fatalf("Failed to store OTP: %v", err) + } + + // Verify OTP + user, token, err := authService.VerifyOTP(phone, otp) + if err != nil { + t.Fatalf("Failed to verify OTP: %v", err) + } + + if user == nil { + t.Error("Expected user to be returned") + } + + if user.Phone != phone { + t.Errorf("Expected user phone %s, got %s", phone, user.Phone) + } + + if token == "" { + t.Error("Expected JWT token to be returned") + } + + // Verify user was created + if userRepo.Count() != 1 { + t.Errorf("Expected 1 user to be created, got %d", userRepo.Count()) + } + + // Verify OTP was consumed + if otpRepo.Count() != 0 { + t.Errorf("Expected 0 OTPs after verification (consumed), got %d", otpRepo.Count()) + } + }) + + t.Run("invalid OTP", func(t *testing.T) { + userRepo.Clear() + otpRepo.Clear() + phone := "+919876543210" + correctOTP := "1234" + wrongOTP := "0000" + + // Store OTP + expireAt := time.Now().Add(5 * time.Minute) + err := otpRepo.StoreOTP(phone, correctOTP, expireAt) + if err != nil { + t.Fatalf("Failed to store OTP: %v", err) + } + + // Try to verify wrong OTP + user, token, err := authService.VerifyOTP(phone, wrongOTP) + if err == nil { + t.Error("Expected error for invalid OTP") + } + + if user != nil { + t.Error("Expected no user for invalid OTP") + } + + if token != "" { + t.Error("Expected no token for invalid OTP") + } + + // Verify OTP still exists after failed verification + if otpRepo.Count() != 1 { + t.Errorf("Expected 1 OTP after failed verification, got %d", otpRepo.Count()) + } + }) + + t.Run("expired OTP", func(t *testing.T) { + userRepo.Clear() + otpRepo.Clear() + phone := "+919876543210" + otp := "1234" + + // Store expired OTP + expiredTime := time.Now().Add(-1 * time.Minute) + err := otpRepo.StoreOTP(phone, otp, expiredTime) + if err != nil { + t.Fatalf("Failed to store expired OTP: %v", err) + } + + // Try to verify expired OTP + user, token, err := authService.VerifyOTP(phone, otp) + if err == nil { + t.Error("Expected error for expired OTP") + } + + if user != nil { + t.Error("Expected no user for expired OTP") + } + + if token != "" { + t.Error("Expected no token for expired OTP") + } + }) + + t.Run("existing user verification", func(t *testing.T) { + userRepo.Clear() + otpRepo.Clear() + phone := "+919876543210" + otp := "1234" + + // Create user first + existingUser, err := userRepo.CreateUserIfNotExists(phone) + if err != nil { + t.Fatalf("Failed to create existing user: %v", err) + } + + // Store OTP + expireAt := time.Now().Add(5 * time.Minute) + err = otpRepo.StoreOTP(phone, otp, expireAt) + if err != nil { + t.Fatalf("Failed to store OTP: %v", err) + } + + // Verify OTP + user, token, err := authService.VerifyOTP(phone, otp) + if err != nil { + t.Fatalf("Failed to verify OTP: %v", err) + } + + if user.ID != existingUser.ID { + t.Errorf("Expected same user ID %d, got %d", existingUser.ID, user.ID) + } + + if token == "" { + t.Error("Expected JWT token to be returned") + } + + // Should still be only 1 user + if userRepo.Count() != 1 { + t.Errorf("Expected 1 user total, got %d", userRepo.Count()) + } + }) +} diff --git a/auth/internal/services/interfaces.go b/auth/internal/services/interfaces.go new file mode 100644 index 0000000..1953eab --- /dev/null +++ b/auth/internal/services/interfaces.go @@ -0,0 +1,24 @@ +package services + +import ( + "auth/internal/repositories" +) + +// AuthService handles authentication business logic +type AuthService interface { + RequestOTP(phone string) error + VerifyOTP(phone, otp string) (*repositories.User, string, error) // user, jwt, error +} + +// UserService handles user management business logic +type UserService interface { + GetUserProfile(phone string) (*repositories.User, error) + UpdateUserProfile(phone, name string) (*repositories.User, error) +} + +// ServiceManager provides access to all services +type ServiceManager struct { + Auth AuthService + User UserService + Repos *repositories.RepositoryManager +} diff --git a/auth/internal/services/user_service.go b/auth/internal/services/user_service.go new file mode 100644 index 0000000..cacd178 --- /dev/null +++ b/auth/internal/services/user_service.go @@ -0,0 +1,182 @@ +package services + +import ( + "database/sql" + "fmt" + "strings" + + "auth/internal/repositories" + "auth/internal/utils" +) + +// UserServiceImpl implements UserService interface +type UserServiceImpl struct { + userRepo repositories.UserRepository +} + +// NewUserService creates a new user service +func NewUserService(userRepo repositories.UserRepository) UserService { + return &UserServiceImpl{ + userRepo: userRepo, + } +} + +// GetUserProfile retrieves a user's profile by phone number +func (s *UserServiceImpl) GetUserProfile(phone string) (*repositories.User, error) { + // Validate phone number format + if err := utils.ValidatePhoneFormat(phone); err != nil { + return nil, fmt.Errorf("invalid phone format: %w", err) + } + + // Get user from repository + user, err := s.userRepo.GetUserByPhone(phone) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("user not found") + } + return nil, fmt.Errorf("failed to retrieve user: %w", err) + } + + return user, nil +} + +// UpdateUserProfile updates a user's profile information +func (s *UserServiceImpl) UpdateUserProfile(phone, name string) (*repositories.User, error) { + // Validate phone number format + if err := utils.ValidatePhoneFormat(phone); err != nil { + return nil, fmt.Errorf("invalid phone format: %w", err) + } + + // Validate and sanitize name + sanitizedName, err := s.validateAndSanitizeName(name) + if err != nil { + return nil, fmt.Errorf("name validation failed: %w", err) + } + + // Update user in repository + user, err := s.userRepo.UpdateUserName(phone, sanitizedName) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("user not found") + } + return nil, fmt.Errorf("failed to update user profile: %w", err) + } + + return user, nil +} + +// validateAndSanitizeName validates and sanitizes the user name +func (s *UserServiceImpl) validateAndSanitizeName(name string) (string, error) { + // Basic validation + if strings.TrimSpace(name) == "" { + return "", fmt.Errorf("name is required") + } + + // Length validation + if len(name) > 100 { + return "", fmt.Errorf("name cannot exceed 100 characters") + } + + // Validate name using existing utils function + if err := utils.ValidateName(name); err != nil { + return "", err + } + + // Sanitize: trim whitespace and normalize internal spaces + sanitized := utils.SanitizeName(name) + + return sanitized, nil +} + +// Additional helper methods for extended functionality + +// CreateUser creates a new user (useful for admin operations) +func (s *UserServiceImpl) CreateUser(phone string) (*repositories.User, error) { + // Validate phone number format + if err := utils.ValidatePhoneFormat(phone); err != nil { + return nil, fmt.Errorf("invalid phone format: %w", err) + } + + // Create user + user, err := s.userRepo.CreateUserIfNotExists(phone) + if err != nil { + return nil, fmt.Errorf("failed to create user: %w", err) + } + + return user, nil +} + +// DeleteUser deletes a user by phone number (useful for admin operations) +func (s *UserServiceImpl) DeleteUser(phone string) error { + // Validate phone number format + if err := utils.ValidatePhoneFormat(phone); err != nil { + return fmt.Errorf("invalid phone format: %w", err) + } + + // Delete user from repository + err := s.userRepo.DeleteUser(phone) + if err != nil { + if err == sql.ErrNoRows { + return fmt.Errorf("user not found") + } + return fmt.Errorf("failed to delete user: %w", err) + } + + return nil +} + +// UserExists checks if a user exists by phone number +func (s *UserServiceImpl) UserExists(phone string) (bool, error) { + // Validate phone number format + if err := utils.ValidatePhoneFormat(phone); err != nil { + return false, fmt.Errorf("invalid phone format: %w", err) + } + + // Try to get user + _, err := s.userRepo.GetUserByPhone(phone) + if err != nil { + if err == sql.ErrNoRows { + return false, nil // User doesn't exist, but no error + } + return false, fmt.Errorf("failed to check user existence: %w", err) + } + + return true, nil +} + +// UpdateUserProfilePartial allows partial updates to user profile +func (s *UserServiceImpl) UpdateUserProfilePartial(phone string, updates map[string]interface{}) (*repositories.User, error) { + // Validate phone number format + if err := utils.ValidatePhoneFormat(phone); err != nil { + return nil, fmt.Errorf("invalid phone format: %w", err) + } + + // Get current user + user, err := s.GetUserProfile(phone) + if err != nil { + return nil, err // Error already formatted + } + + // Process updates + if name, exists := updates["name"]; exists { + if nameStr, ok := name.(string); ok { + sanitizedName, err := s.validateAndSanitizeName(nameStr) + if err != nil { + return nil, fmt.Errorf("invalid name: %w", err) + } + + // Update name + user, err = s.userRepo.UpdateUserName(phone, sanitizedName) + if err != nil { + return nil, fmt.Errorf("failed to update user name: %w", err) + } + } else { + return nil, fmt.Errorf("name must be a string") + } + } + + // Future: Add support for other profile fields here + // if email, exists := updates["email"]; exists { ... } + + return user, nil +} diff --git a/auth/internal/services/user_service_test.go b/auth/internal/services/user_service_test.go new file mode 100644 index 0000000..19f0d62 --- /dev/null +++ b/auth/internal/services/user_service_test.go @@ -0,0 +1,189 @@ +package services + +import ( + "testing" + + "auth/internal/repositories" +) + +func TestUserService(t *testing.T) { + t.Run("constructor", func(t *testing.T) { + userRepo := repositories.NewMockUserRepository() + userService := NewUserService(userRepo) + + if userService == nil { + t.Error("NewUserService should not return nil") + } + + // Verify it's the correct type + if _, ok := userService.(*UserServiceImpl); !ok { + t.Error("NewUserService should return *UserServiceImpl") + } + }) +} + +func TestUserService_GetUserProfile(t *testing.T) { + userRepo := repositories.NewMockUserRepository() + userService := NewUserService(userRepo) + + t.Run("existing user", func(t *testing.T) { + userRepo.Clear() + phone := "+919876543210" + + // Create user first + createdUser, err := userRepo.CreateUserIfNotExists(phone) + if err != nil { + t.Fatalf("Failed to create test user: %v", err) + } + + // Get user profile + user, err := userService.GetUserProfile(phone) + if err != nil { + t.Fatalf("Failed to get user profile: %v", err) + } + + if user == nil { + t.Error("Expected user to be returned") + } + + if user.ID != createdUser.ID { + t.Errorf("Expected user ID %d, got %d", createdUser.ID, user.ID) + } + + if user.Phone != phone { + t.Errorf("Expected phone %s, got %s", phone, user.Phone) + } + }) + + t.Run("non-existent user", func(t *testing.T) { + userRepo.Clear() + phone := "+919876543210" + + // Try to get profile for non-existent user + user, err := userService.GetUserProfile(phone) + if err == nil { + t.Error("Expected error for non-existent user") + } + + if user != nil { + t.Error("Expected no user for non-existent user") + } + }) + + t.Run("invalid phone format", func(t *testing.T) { + invalidPhone := "invalid-phone" + + user, err := userService.GetUserProfile(invalidPhone) + if err == nil { + t.Error("Expected error for invalid phone format") + } + + if user != nil { + t.Error("Expected no user for invalid phone") + } + }) +} + +func TestUserService_UpdateUserProfile(t *testing.T) { + userRepo := repositories.NewMockUserRepository() + userService := NewUserService(userRepo) + + t.Run("valid name update", func(t *testing.T) { + userRepo.Clear() + phone := "+919876543210" + newName := "John Doe" + + // Create user first + _, err := userRepo.CreateUserIfNotExists(phone) + if err != nil { + t.Fatalf("Failed to create test user: %v", err) + } + + // Update user profile + user, err := userService.UpdateUserProfile(phone, newName) + if err != nil { + t.Fatalf("Failed to update user profile: %v", err) + } + + if user == nil { + t.Error("Expected user to be returned") + } + + if user.Name == nil || *user.Name != newName { + actualName := "" + if user.Name != nil { + actualName = *user.Name + } + t.Errorf("Expected name %s, got %s", newName, actualName) + } + + if user.Phone != phone { + t.Errorf("Expected phone %s, got %s", phone, user.Phone) + } + }) + + t.Run("name sanitization", func(t *testing.T) { + userRepo.Clear() + phone := "+919876543210" + dirtyName := " John Doe " + expectedName := "John Doe" + + // Create user first + _, err := userRepo.CreateUserIfNotExists(phone) + if err != nil { + t.Fatalf("Failed to create test user: %v", err) + } + + // Update with dirty name + user, err := userService.UpdateUserProfile(phone, dirtyName) + if err != nil { + t.Fatalf("Failed to update user profile: %v", err) + } + + if user.Name == nil || *user.Name != expectedName { + actualName := "" + if user.Name != nil { + actualName = *user.Name + } + t.Errorf("Expected sanitized name %s, got %s", expectedName, actualName) + } + }) + + t.Run("invalid name", func(t *testing.T) { + userRepo.Clear() + phone := "+919876543210" + invalidName := "" // Empty name after sanitization + + // Create user first + _, err := userRepo.CreateUserIfNotExists(phone) + if err != nil { + t.Fatalf("Failed to create test user: %v", err) + } + + // Try to update with invalid name + user, err := userService.UpdateUserProfile(phone, invalidName) + if err == nil { + t.Error("Expected error for invalid name") + } + + if user != nil { + t.Error("Expected no user for invalid name update") + } + }) + + t.Run("non-existent user", func(t *testing.T) { + userRepo.Clear() + phone := "+919876543210" + name := "John Doe" + + // Try to update non-existent user + user, err := userService.UpdateUserProfile(phone, name) + if err == nil { + t.Error("Expected error for non-existent user") + } + + if user != nil { + t.Error("Expected no user for non-existent user update") + } + }) +} diff --git a/auth/internal/storage/postgres.go b/auth/internal/storage/postgres.go new file mode 100644 index 0000000..d438b68 --- /dev/null +++ b/auth/internal/storage/postgres.go @@ -0,0 +1,68 @@ +package storage + +import ( + "database/sql" + "fmt" + "log" + "sync" + + "auth/internal/config" + _ "github.com/lib/pq" +) + +var ( + DB *sql.DB + dbOnce sync.Once +) + +func InitPostgres() { + dbOnce.Do(func() { + dbConfig := config.GetDatabaseConfig() + poolConfig := config.GetPoolConfig() + + psqlInfo := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s", + dbConfig.Host, dbConfig.Port, dbConfig.User, dbConfig.Password, dbConfig.DBName, dbConfig.SSLMode) + + var err error + DB, err = sql.Open("postgres", psqlInfo) + if err != nil { + log.Fatalf("Failed to open database connection: %v", err) + } + + // Test connection + err = DB.Ping() + if err != nil { + log.Fatalf("Failed to ping database: %v", err) + } + + // Set connection pool settings + DB.SetMaxOpenConns(poolConfig.MaxOpenConns) + DB.SetMaxIdleConns(poolConfig.MaxIdleConns) + DB.SetConnMaxLifetime(poolConfig.ConnMaxLifetime) + + log.Printf("Successfully connected to PostgreSQL at %s:%s", dbConfig.Host, dbConfig.Port) + + // Create tables if they don't exist + createTables() + }) +} + +func createTables() { + // Create users table + createUsersTable := ` + CREATE TABLE IF NOT EXISTS users ( + id SERIAL PRIMARY KEY, + phone VARCHAR(20) UNIQUE NOT NULL, + name VARCHAR(100), + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ); + ` + + _, err := DB.Exec(createUsersTable) + if err != nil { + log.Fatalf("Failed to create users table: %v", err) + } + + log.Println("Database tables created/verified successfully") +} diff --git a/auth/internal/storage/redis.go b/auth/internal/storage/redis.go new file mode 100644 index 0000000..369a046 --- /dev/null +++ b/auth/internal/storage/redis.go @@ -0,0 +1,71 @@ +package storage + +import ( + "context" + "log" + "os" + "strconv" + "sync" + "time" + + "github.com/redis/go-redis/v9" +) + +var ( + RedisClient *redis.Client + once sync.Once +) + +func InitRedis() { + once.Do(func() { + host := os.Getenv("REDIS_HOST") + if host == "" { + host = "localhost" + } + + port := os.Getenv("REDIS_PORT") + if port == "" { + port = "6379" + } + + addr := host + ":" + port + password := os.Getenv("REDIS_PASSWORD") + + // Parse Redis DB (defaults to 0) + dbStr := os.Getenv("REDIS_DB") + if dbStr == "" { + dbStr = "0" + } + db, err := strconv.Atoi(dbStr) + if err != nil { + log.Printf("Invalid REDIS_DB value '%s', using default 0", dbStr) + db = 0 + } + + RedisClient = redis.NewClient(&redis.Options{ + Addr: addr, + Password: password, + DB: db, + DialTimeout: 10 * time.Second, + ReadTimeout: 5 * time.Second, + WriteTimeout: 5 * time.Second, + PoolSize: 10, + MinIdleConns: 5, + MaxRetries: 3, + }) + + // Test Redis connection during startup with timeout + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + _, pingErr := RedisClient.Ping(ctx).Result() + if pingErr != nil { + log.Fatalf("Failed to connect to Redis at %s: %v", addr, pingErr) + } + log.Printf("Successfully connected to Redis at %s", addr) + }) +} + +func GetContext() context.Context { + return context.Background() +} diff --git a/auth/internal/utils/errors.go b/auth/internal/utils/errors.go new file mode 100644 index 0000000..442da54 --- /dev/null +++ b/auth/internal/utils/errors.go @@ -0,0 +1,70 @@ +package utils + +import ( + "encoding/json" + "net/http" +) + +// ErrorResponse represents a standardized error response +type ErrorResponse struct { + Error string `json:"error"` + Message string `json:"message"` + Code int `json:"code"` +} + +// Common error types +const ( + ErrorTypeBadRequest = "BAD_REQUEST" + ErrorTypeUnauthorized = "UNAUTHORIZED" + ErrorTypeNotFound = "NOT_FOUND" + ErrorTypeInternalServerError = "INTERNAL_SERVER_ERROR" + ErrorTypeMethodNotAllowed = "METHOD_NOT_ALLOWED" + ErrorTypeConflict = "CONFLICT" +) + +// SendJSONError sends a standardized JSON error response +func SendJSONError(w http.ResponseWriter, message string, statusCode int) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + + errorType := getErrorType(statusCode) + + response := ErrorResponse{ + Error: errorType, + Message: message, + Code: statusCode, + } + + json.NewEncoder(w).Encode(response) +} + +// SendJSONResponse sends a standardized JSON success response +func SendJSONResponse(w http.ResponseWriter, data interface{}) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(data) +} + +// DecodeJSONRequest decodes JSON request body into the provided interface +func DecodeJSONRequest(r *http.Request, v interface{}) error { + return json.NewDecoder(r.Body).Decode(v) +} + +// getErrorType maps HTTP status codes to error types +func getErrorType(statusCode int) string { + switch statusCode { + case http.StatusBadRequest: + return ErrorTypeBadRequest + case http.StatusUnauthorized: + return ErrorTypeUnauthorized + case http.StatusNotFound: + return ErrorTypeNotFound + case http.StatusMethodNotAllowed: + return ErrorTypeMethodNotAllowed + case http.StatusConflict: + return ErrorTypeConflict + case http.StatusInternalServerError: + return ErrorTypeInternalServerError + default: + return "ERROR" + } +} diff --git a/auth/internal/utils/otp.go b/auth/internal/utils/otp.go new file mode 100644 index 0000000..57a3bf4 --- /dev/null +++ b/auth/internal/utils/otp.go @@ -0,0 +1,73 @@ +package utils + +import ( + "encoding/json" + "fmt" + "log" + "net/http" + "os" + "strings" + "time" +) + +type sendOTPResponse struct { + Status string `json:"Status"` + Details string `json:"Details"` + Message string `json:"Message"` +} + +func GenerateOTP() string { + return fmt.Sprintf("%04d", time.Now().UnixNano()%10000) +} + +func GenerateOTPForPhone(phone string) string { + // Test phone number with fixed OTP (using US country code 1) + if phone == "+19999999999" { + return "7415" + } + return GenerateOTP() +} + +var SendOTPWith2Factor = func(phone, otp string) error { + // Handle test phone number (using US country code 1) + if phone == "+19999999999" { + return nil + } + + apiKey := os.Getenv("TWO_FACTOR_API_KEY") + templateName := os.Getenv("OTP_TEMPLATE_NAME") + if apiKey == "" || templateName == "" { + err := fmt.Errorf("2factor config missing: TWO_FACTOR_API_KEY or OTP_TEMPLATE_NAME not set") + log.Printf("OTP send failed for phone %s: %v", phone, err) + return err + } + + // Remove + prefix if present for API call (matching Elixir implementation) + phoneForAPI := strings.TrimPrefix(phone, "+") + + // Use custom OTP endpoint with template name from env + url := fmt.Sprintf("https://2factor.in/API/V1/%s/SMS/%s/%s/%s", apiKey, phoneForAPI, otp, templateName) + + // Make GET request (matching Elixir implementation) + resp, err := http.Get(url) + if err != nil { + log.Printf("OTP send failed for phone %s: HTTP request error: %v", phone, err) + return err + } + defer resp.Body.Close() + + var result sendOTPResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + log.Printf("OTP send failed for phone %s: JSON decode error: %v", phone, err) + return err + } + + if resp.StatusCode != 200 || result.Status != "Success" { + err := fmt.Errorf("2factor send failed: %s", result.Message) + log.Printf("OTP send failed for phone %s: 2Factor API error - Status Code: %d, Status: %s, Message: %s, Details: %s", + phone, resp.StatusCode, result.Status, result.Message, result.Details) + return err + } + + return nil +} diff --git a/auth/internal/utils/otp_test.go b/auth/internal/utils/otp_test.go new file mode 100644 index 0000000..3180b19 --- /dev/null +++ b/auth/internal/utils/otp_test.go @@ -0,0 +1,56 @@ +package utils + +import ( + "regexp" + "testing" +) + +// TestGenerateOTPForPhone tests OTP generation +func TestGenerateOTPForPhone(t *testing.T) { + phones := []string{ + "+919876543210", + "+15551234567", + "+447700900123", + } + + for _, phone := range phones { + t.Run("OTP generation for "+phone, func(t *testing.T) { + otp := GenerateOTPForPhone(phone) + + // Test OTP format (4 digits) + if !regexp.MustCompile(`^\d{4}$`).MatchString(otp) { + t.Errorf("Generated OTP '%s' does not match 4-digit format", otp) + } + + // Test OTP length + if len(otp) != 4 { + t.Errorf("Expected OTP length 4, got %d", len(otp)) + } + }) + } +} + +// TestOTPRange tests that generated OTPs are within valid range +func TestOTPRange(t *testing.T) { + phones := []string{ + "+919876543210", + "+15551234567", + "+447700900123", + } + + for _, phone := range phones { + otp := GenerateOTPForPhone(phone) + + // Check that all characters are digits + for _, digit := range otp { + if digit < '0' || digit > '9' { + t.Errorf("Invalid digit in OTP %s: %c", otp, digit) + } + } + + // Check it's in valid 4-digit range (0000-9999) + if !regexp.MustCompile(`^\d{4}$`).MatchString(otp) { + t.Errorf("OTP %s is not a valid 4-digit number", otp) + } + } +} diff --git a/auth/internal/utils/validation.go b/auth/internal/utils/validation.go new file mode 100644 index 0000000..0cf0dc5 --- /dev/null +++ b/auth/internal/utils/validation.go @@ -0,0 +1,57 @@ +package utils + +import ( + "fmt" + "regexp" + "strings" + + "auth/internal/config" +) + +// Phone validation +func ValidatePhoneFormat(phone string) error { + // E.164 format validation: + followed by country code (1-3 digits) and subscriber number + // Minimum realistic length is 8 digits total (e.g., +1234567), maximum is 15 + matched, _ := regexp.MatchString(`^\+[1-9]\d{6,14}$`, phone) + if !matched { + return fmt.Errorf("phone number must be in E.164 format (e.g., +919876543210)") + } + return nil +} + +// OTP validation +func ValidateOTPFormat(otp string) error { + pattern := fmt.Sprintf(`^\d{%d}$`, config.OTPLength) + matched, _ := regexp.MatchString(pattern, otp) + if !matched { + return fmt.Errorf("OTP must be exactly %d digits", config.OTPLength) + } + return nil +} + +// Name validation +func ValidateName(name string) error { + name = strings.TrimSpace(name) + if name == "" { + return fmt.Errorf("name cannot be empty") + } + + if len(name) < config.MinNameLength { + return fmt.Errorf("name must be at least %d character", config.MinNameLength) + } + + if len(name) > config.MaxNameLength { + return fmt.Errorf("name cannot be longer than %d characters", config.MaxNameLength) + } + + return nil +} + +// SanitizeName trims whitespace and collapses multiple spaces +func SanitizeName(name string) string { + // Trim leading and trailing whitespace + name = strings.TrimSpace(name) + // Collapse multiple internal spaces into single spaces + re := regexp.MustCompile(`\s+`) + return re.ReplaceAllString(name, " ") +} diff --git a/auth/internal/utils/validation_test.go b/auth/internal/utils/validation_test.go new file mode 100644 index 0000000..251246c --- /dev/null +++ b/auth/internal/utils/validation_test.go @@ -0,0 +1,193 @@ +package utils + +import ( + "testing" +) + +// TestValidatePhoneFormat tests phone number validation +func TestValidatePhoneFormat(t *testing.T) { + tests := []struct { + name string + phone string + wantErr bool + }{ + { + name: "valid E.164 format - India", + phone: "+919876543210", + wantErr: false, + }, + { + name: "valid E.164 format - US", + phone: "+15551234567", + wantErr: false, + }, + { + name: "missing plus sign", + phone: "919876543210", + wantErr: true, + }, + { + name: "too short", + phone: "+123456", + wantErr: true, + }, + { + name: "contains letters", + phone: "+91abc7543210", + wantErr: true, + }, + { + name: "empty string", + phone: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidatePhoneFormat(tt.phone) + if (err != nil) != tt.wantErr { + t.Errorf("ValidatePhoneFormat(%s) error = %v, wantErr %v", tt.phone, err, tt.wantErr) + } + }) + } +} + +// TestValidateOTPFormat tests OTP validation +func TestValidateOTPFormat(t *testing.T) { + tests := []struct { + name string + otp string + wantErr bool + }{ + { + name: "valid 4-digit OTP", + otp: "1234", + wantErr: false, + }, + { + name: "valid OTP starting with zero", + otp: "0123", + wantErr: false, + }, + { + name: "3-digit OTP", + otp: "123", + wantErr: true, + }, + { + name: "5-digit OTP", + otp: "12345", + wantErr: true, + }, + { + name: "OTP with letters", + otp: "12a4", + wantErr: true, + }, + { + name: "empty OTP", + otp: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateOTPFormat(tt.otp) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateOTPFormat(%s) error = %v, wantErr %v", tt.otp, err, tt.wantErr) + } + }) + } +} + +// TestValidateName tests name validation +func TestValidateName(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + }{ + { + name: "valid name", + input: "John Doe", + wantErr: false, + }, + { + name: "single character name", + input: "A", + wantErr: false, + }, + { + name: "empty name", + input: "", + wantErr: true, + }, + { + name: "only spaces", + input: " ", + wantErr: true, + }, + { + name: "name too long", + input: string(make([]byte, 101)), // 101 characters + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Fill the byte slices with 'A' for length tests + if len(tt.input) > 50 { + for i := range []byte(tt.input) { + []byte(tt.input)[i] = 'A' + } + } + + err := ValidateName(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateName(%s) error = %v, wantErr %v", tt.input, err, tt.wantErr) + } + }) + } +} + +// TestSanitizeName tests name sanitization +func TestSanitizeName(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "already clean name", + input: "John Doe", + expected: "John Doe", + }, + { + name: "name with leading and trailing spaces", + input: " John Doe ", + expected: "John Doe", + }, + { + name: "empty string", + input: "", + expected: "", + }, + { + name: "only whitespace", + input: " \t\n ", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := SanitizeName(tt.input) + if result != tt.expected { + t.Errorf("SanitizeName(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} diff --git a/auth/sample.env b/auth/sample.env new file mode 100644 index 0000000..2fc0bb7 --- /dev/null +++ b/auth/sample.env @@ -0,0 +1,43 @@ +# Auth Service Environment Configuration +# Copy this file to .env and update with your actual values +# NOTE: Uses same variable names as scribbl_backend for consistency + +# Secret key for JWT signing (same as scribbl_backend) +SECRET_KEY_BASE=change-this-to-a-strong-random-secret-at-least-32-chars-long + +# Redis Configuration (same structure as scribbl_backend) +# Redis is used for temporary OTP storage +REDIS_HOST=localhost +REDIS_PORT=6379 +REDIS_DB=0 +REDIS_PASSWORD= + +# PostgreSQL Configuration +# PostgreSQL is used for persistent user data storage +POSTGRES_HOST=localhost +POSTGRES_PORT=5432 +POSTGRES_USER=postgres +POSTGRES_PASSWORD=your_postgres_password +POSTGRES_DB=auth_db +POSTGRES_SSLMODE=disable + +# 2Factor.in SMS API Configuration (shared with scribbl_backend) +TWO_FACTOR_API_KEY=your_2factor_api_key_here +OTP_TEMPLATE_NAME=your_registered_template_name + +# Application Configuration +PORT=8080 + +# Rate Limiting Configuration +# Maximum number of OTP requests allowed per minute per phone number +RATE_LIMIT_PER_MINUTE=5 + +# CORS Configuration - comma-separated list of allowed origins (same as scribbl_backend) +# For production, set to your actual frontend domains +CORS_ALLOWED_ORIGINS=https://yourdomain.com,https://www.yourdomain.com + +# Environment (development, staging, production) +APP_ENV=production + +# Logging Level (DEBUG, INFO, WARN, ERROR) +LOG_LEVEL=INFO \ No newline at end of file diff --git a/auth/test/integration/auth_integration_test.go b/auth/test/integration/auth_integration_test.go new file mode 100644 index 0000000..c6878a7 --- /dev/null +++ b/auth/test/integration/auth_integration_test.go @@ -0,0 +1,231 @@ +// auth_integration_test.go - Integration tests for the complete authentication API +// +// These tests verify the entire authentication flow from HTTP request to response, +// testing the service as a black box without internal package dependencies. +// +// To run: go test -v ./test/integration (Redis and proper environment must be available) + +package integration + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "testing" + + "auth/internal/container" + "auth/internal/handlers" + + "github.com/joho/godotenv" +) + +// TestMain initializes the test environment +func TestMain(m *testing.M) { + // Load test environment + if err := godotenv.Load("../../sample.env"); err != nil { + // It's okay if sample.env doesn't exist + } + + // Set required test environment variables if not already set + if os.Getenv("SECRET_KEY_BASE") == "" { + os.Setenv("SECRET_KEY_BASE", "test_secret_key_for_integration_testing") + } + if os.Getenv("TWO_FACTOR_API_KEY") == "" { + os.Setenv("TWO_FACTOR_API_KEY", "test_api_key") + } + if os.Getenv("OTP_TEMPLATE_NAME") == "" { + os.Setenv("OTP_TEMPLATE_NAME", "test_template") + } + + // Run tests + code := m.Run() + os.Exit(code) +} + +// setupTestServer creates a test server with the new container architecture +func setupTestServer() *httptest.Server { + // Create container with auto-detection (will use Redis if available, fallback to mocks) + appContainer, err := container.CreateAutoDetectedContainer() + if err != nil { + panic("Failed to create test container: " + err.Error()) + } + + // Create test server with container-based handlers + mux := http.NewServeMux() + + // Use the new SetupRoutes helper + handlers.SetupRoutes(mux, appContainer) + + server := httptest.NewServer(mux) + + // Store container reference for cleanup + server.Config.Handler = &testHandlerWithCleanup{ + Handler: mux, + Container: appContainer, + } + + return server +} + +// testHandlerWithCleanup wraps the handler and container for proper cleanup +type testHandlerWithCleanup struct { + http.Handler + Container container.ContainerInterface +} + +func (t *testHandlerWithCleanup) ServeHTTP(w http.ResponseWriter, r *http.Request) { + t.Handler.ServeHTTP(w, r) +} + +func clearTestData(c container.ContainerInterface) { + // Try to clear test data if this is a test container + if closer, ok := c.(interface{ ClearAllData() }); ok { + closer.ClearAllData() + } +} + +// TestCompleteAuthFlow tests the complete authentication workflow +func TestCompleteAuthFlow(t *testing.T) { + server := setupTestServer() + defer server.Close() + client := server.Client() + + // Clear test data + if wrapper, ok := server.Config.Handler.(*testHandlerWithCleanup); ok { + clearTestData(wrapper.Container) + defer wrapper.Container.Shutdown() + } + + phone := "+19999999999" // Use test phone number + + t.Run("request OTP", func(t *testing.T) { + body, _ := json.Marshal(map[string]string{"phone": phone}) + resp, err := client.Post(server.URL+"/auth/request-otp", "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatalf("Failed to request OTP: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("Expected 200, got %d", resp.StatusCode) + } + }) + + t.Run("verify OTP and get JWT", func(t *testing.T) { + // Use the known test OTP for the test phone number + testOTP := "7415" // Fixed OTP for test phone +19999999999 + + // Verify with the known test OTP + body, _ := json.Marshal(map[string]string{"phone": phone, "otp": testOTP}) + resp, err := client.Post(server.URL+"/auth/verify-otp", "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatalf("Failed to verify OTP: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("Expected 200, got %d", resp.StatusCode) + } + + var response map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + + if response["token"] == nil { + t.Fatalf("Expected JWT token in response") + } + }) +} + +// TestUserManagementFlow tests the user management endpoints +func TestUserManagementFlow(t *testing.T) { + server := setupTestServer() + defer server.Close() + client := server.Client() + + // Clear test data + if wrapper, ok := server.Config.Handler.(*testHandlerWithCleanup); ok { + clearTestData(wrapper.Container) + defer wrapper.Container.Shutdown() + } + + phone := "+19999999999" // Use test phone number + testOTP := "7415" // Fixed OTP for test phone + var tokenString string // JWT token for authenticated requests + + // First, go through the complete auth flow to create the user and get a real JWT token + t.Run("setup user through auth flow", func(t *testing.T) { + // Request OTP + body, _ := json.Marshal(map[string]string{"phone": phone}) + resp, err := client.Post(server.URL+"/auth/request-otp", "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatalf("Failed to request OTP: %v", err) + } + resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("Expected 200 for OTP request, got %d", resp.StatusCode) + } + + // Verify OTP to create user and get token + body, _ = json.Marshal(map[string]string{"phone": phone, "otp": testOTP}) + resp, err = client.Post(server.URL+"/auth/verify-otp", "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatalf("Failed to verify OTP: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("Expected 200 for OTP verification, got %d", resp.StatusCode) + } + + var response map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + + if response["token"] == nil { + t.Fatalf("Expected JWT token in response") + } + + // Store the token for the next tests + tokenString = response["token"].(string) + }) + + t.Run("get user profile", func(t *testing.T) { + req, _ := http.NewRequest("GET", server.URL+"/auth/user", nil) + req.Header.Set("Authorization", "Bearer "+tokenString) + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Failed to get user: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("Expected 200, got %d", resp.StatusCode) + } + }) + + t.Run("update user name", func(t *testing.T) { + updateData := map[string]string{"name": "John Doe"} + jsonData, _ := json.Marshal(updateData) + + req, _ := http.NewRequest("PUT", server.URL+"/auth/user/update", bytes.NewReader(jsonData)) + req.Header.Set("Authorization", "Bearer "+tokenString) + req.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Failed to update user: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("Expected 200, got %d", resp.StatusCode) + } + }) +} diff --git a/deployment/Caddyfile b/deployment/Caddyfile index 5a65ab0..108435c 100644 --- a/deployment/Caddyfile +++ b/deployment/Caddyfile @@ -1,3 +1,7 @@ api.{$DOMAIN_NAME} { reverse_proxy scribbl_backend_1:4000 scribbl_backend_2:4000 scribbl_backend_3:4000 scribbl_backend_4:4000 +} + +auth.{$DOMAIN_NAME} { + reverse_proxy auth:8080 } \ No newline at end of file diff --git a/deployment/docker-compose.yaml b/deployment/docker-compose.yaml index 4b1c988..c61456f 100644 --- a/deployment/docker-compose.yaml +++ b/deployment/docker-compose.yaml @@ -8,58 +8,90 @@ services: restart: always volumes: - pgdata:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER} -d ${POSTGRES_DB}"] + interval: 10s + timeout: 5s + retries: 5 + start_period: 30s + + auth: + image: ghcr.io/singhalkarun/scribbl-auth:${AUTH_IMAGE_TAG:-main} + environment: + SECRET_KEY_BASE: ${SECRET_KEY_BASE} + REDIS_HOST: ${REDIS_HOST} + REDIS_PORT: ${REDIS_PORT} + REDIS_DB: ${AUTH_REDIS_DB} + TWO_FACTOR_API_KEY: ${TWO_FACTOR_API_KEY} + OTP_TEMPLATE_NAME: ${OTP_TEMPLATE_NAME} + CORS_ALLOWED_ORIGINS: ${CORS_ALLOWED_ORIGINS} + PORT: 8080 + APP_ENV: production + LOG_LEVEL: INFO + RATE_LIMIT_PER_MINUTE: ${RATE_LIMIT_PER_MINUTE:-5} + restart: always + depends_on: + redis: + condition: service_healthy + healthcheck: + test: ["CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://localhost:8080/health"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 10s scribbl_backend_1: - image: ghcr.io/singhalkarun/scribbl-backend:main + image: ghcr.io/singhalkarun/scribbl-backend:${BACKEND_IMAGE_TAG:-main} environment: DATABASE_URL: postgres://${POSTGRES_USER}:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:5432/${POSTGRES_DB} SECRET_KEY_BASE: ${SECRET_KEY_BASE} NODE_NAME: "1" REDIS_HOST: ${REDIS_HOST} REDIS_PORT: ${REDIS_PORT} - REDIS_DB: ${REDIS_DB} + REDIS_DB: ${SCRIBBL_REDIS_DB} CORS_ALLOWED_ORIGINS: ${CORS_ALLOWED_ORIGINS} restart: always depends_on: - postgres + scribbl_backend_2: - image: ghcr.io/singhalkarun/scribbl-backend:main + image: ghcr.io/singhalkarun/scribbl-backend:${BACKEND_IMAGE_TAG:-main} environment: DATABASE_URL: postgres://${POSTGRES_USER}:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:5432/${POSTGRES_DB} SECRET_KEY_BASE: ${SECRET_KEY_BASE} NODE_NAME: "2" REDIS_HOST: ${REDIS_HOST} REDIS_PORT: ${REDIS_PORT} - REDIS_DB: ${REDIS_DB} + REDIS_DB: ${SCRIBBL_REDIS_DB} CORS_ALLOWED_ORIGINS: ${CORS_ALLOWED_ORIGINS} restart: always depends_on: - postgres scribbl_backend_3: - image: ghcr.io/singhalkarun/scribbl-backend:main + image: ghcr.io/singhalkarun/scribbl-backend:${BACKEND_IMAGE_TAG:-main} environment: DATABASE_URL: postgres://${POSTGRES_USER}:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:5432/${POSTGRES_DB} SECRET_KEY_BASE: ${SECRET_KEY_BASE} NODE_NAME: "3" REDIS_HOST: ${REDIS_HOST} REDIS_PORT: ${REDIS_PORT} - REDIS_DB: ${REDIS_DB} + REDIS_DB: ${SCRIBBL_REDIS_DB} CORS_ALLOWED_ORIGINS: ${CORS_ALLOWED_ORIGINS} restart: always depends_on: - postgres scribbl_backend_4: - image: ghcr.io/singhalkarun/scribbl-backend:main + image: ghcr.io/singhalkarun/scribbl-backend:${BACKEND_IMAGE_TAG:-main} environment: DATABASE_URL: postgres://${POSTGRES_USER}:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:5432/${POSTGRES_DB} SECRET_KEY_BASE: ${SECRET_KEY_BASE} NODE_NAME: "4" REDIS_HOST: ${REDIS_HOST} REDIS_PORT: ${REDIS_PORT} - REDIS_DB: ${REDIS_DB} + REDIS_DB: ${SCRIBBL_REDIS_DB} CORS_ALLOWED_ORIGINS: ${CORS_ALLOWED_ORIGINS} restart: always depends_on: @@ -81,6 +113,7 @@ services: - scribbl_backend_2 - scribbl_backend_3 - scribbl_backend_4 + - auth redis: image: redis:7.4.3-alpine @@ -88,6 +121,12 @@ services: command: redis-server --notify-keyspace-events Ex volumes: - redis_data:/data + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 10s + timeout: 3s + retries: 3 + start_period: 5s volumes: pgdata: diff --git a/deployment/sample.env b/deployment/sample.env index 26b62e2..65d0468 100644 --- a/deployment/sample.env +++ b/deployment/sample.env @@ -1,13 +1,31 @@ -POSTGRES_USER= -POSTGRES_PASSWORD= -POSTGRES_DB= -POSTGRES_HOST= +# Database Configuration +POSTGRES_USER=scribbl_user +POSTGRES_PASSWORD=your_secure_postgres_password +POSTGRES_DB=scribbl_db +POSTGRES_HOST=postgres -SECRET_KEY_BASE= +# Application Security +SECRET_KEY_BASE=change-this-to-a-strong-random-secret-at-least-64-chars-long-for-production -DOMAIN_NAME= +# Domain Configuration +DOMAIN_NAME=scribbl. -REDIS_URL= +# Image Tags Configuration +# Specify which image tags to use for deployment +# Leave unset to use 'main' as default +AUTH_IMAGE_TAG=main +BACKEND_IMAGE_TAG=main + +# Redis Configuration +REDIS_HOST=redis +REDIS_PORT=6379 +AUTH_REDIS_DB=0 +SCRIBBL_REDIS_DB=1 + +# Auth Service Configuration +TWO_FACTOR_API_KEY=your_2factor_api_key_here +OTP_TEMPLATE_NAME=your_registered_template_name +RATE_LIMIT_PER_MINUTE=5 # CORS Configuration - comma-separated list of allowed origins CORS_ALLOWED_ORIGINS=https://yourdomain.com,https://www.yourdomain.com \ No newline at end of file