diff --git a/SECURITY_APIS_HEIMDALL.md b/SECURITY_APIS_HEIMDALL.md new file mode 100644 index 00000000000..a98348c88ae --- /dev/null +++ b/SECURITY_APIS_HEIMDALL.md @@ -0,0 +1,471 @@ +# Security APIs and Heimdall Auto-Enforcement + +This document describes the enhanced security APIs, automated enforcement system, and Heimdall integration for real-time blocking and protection. + +## Overview + +The security system now includes: +- **Device Fingerprinting** - Track and cluster suspicious devices +- **IP Clustering** - Identify and monitor suspicious IP addresses +- **Anomaly Detection** - Detect and track security anomalies with severity levels +- **Automated Enforcement** - Auto-trigger defensive actions (block/redirect/ban) +- **Heimdall Middleware** - Real-time enforcement of security directives +- **Manual Review** - Approve or ignore anomalies with audit trail + +## Architecture + +``` +┌─────────────────┐ +│ User Request │ +└────────┬────────┘ + │ + ▼ +┌─────────────────┐ +│ TokenAuth │ ◄─── Authenticate user +└────────┬────────┘ + │ + ▼ +┌─────────────────┐ +│ Heimdall │ ◄─── Check blocklist (Redis) +└────────┬────────┘ + │ + ▼ +┌─────────────────┐ +│ Governance │ ◄─── Detect violations +└────────┬────────┘ + │ + ▼ +┌─────────────────┐ +│ Create Anomaly │ ◄─── Log security event +└────────┬────────┘ + │ + ▼ +┌─────────────────┐ +│ Auto-Enforce │ ◄─── Trigger action if high severity +└────────┬────────┘ + │ + ▼ +┌─────────────────┐ +│ Update Redis │ ◄─── Publish directive to Heimdall +└─────────────────┘ +``` + +## API Endpoints + +### Dashboard + +**GET /api/security/dashboard** + +Returns comprehensive security metrics including: +- Total violations +- Unique users with violations +- Today's violation count +- Top keywords matched +- Daily trends +- Device clusters (top 10 suspicious) +- Suspicious IPs (top 10) +- Anomaly counts by severity +- Anomaly trends + +Query Parameters: +- `start_time` (optional): RFC3339 timestamp (default: 7 days ago) +- `end_time` (optional): RFC3339 timestamp (default: now) + +Response: +```json +{ + "success": true, + "data": { + "total_count": 150, + "unique_users": 25, + "today_count": 10, + "top_keywords": [...], + "daily_trend": [...], + "device_clusters": [...], + "suspicious_ips": [...], + "anomaly_counts": { + "malicious": 15, + "violation": 135 + }, + "anomaly_trends": [...] + } +} +``` + +### Device Clusters + +**GET /api/security/devices** + +Returns device fingerprint clusters with risk scores. + +Query Parameters: +- `page` (default: 1) +- `page_size` (default: 10) +- `blocked` (optional): "true" or "false" + +Response: +```json +{ + "success": true, + "data": { + "devices": [ + { + "id": 1, + "user_id": 123, + "fingerprint": "abc123...", + "user_agent": "Mozilla/5.0...", + "ip_address": "192.168.1.1", + "first_seen_at": "2024-01-01T00:00:00Z", + "last_seen_at": "2024-01-02T00:00:00Z", + "request_count": 150, + "is_blocked": false, + "risk_score": 75 + } + ], + "total": 100, + "page": 1, + "page_size": 10 + } +} +``` + +### IP Clusters + +**GET /api/security/ip-clusters** + +Returns IP clusters with suspicious activity metrics. + +Query Parameters: +- `page` (default: 1) +- `page_size` (default: 10) +- `blocked` (optional): "true" or "false" +- `min_risk_score` (default: 0) + +Response: +```json +{ + "success": true, + "data": { + "clusters": [ + { + "id": 1, + "ip_address": "192.168.1.1", + "country": "US", + "city": "San Francisco", + "unique_users": 50, + "total_requests": 1000, + "violation_count": 10, + "is_blocked": false, + "risk_score": 60, + "first_seen_at": "2024-01-01T00:00:00Z", + "last_seen_at": "2024-01-02T00:00:00Z" + } + ], + "total": 50, + "page": 1, + "page_size": 10 + } +} +``` + +### Security Anomalies + +**GET /api/security/anomalies** + +Returns detected anomalies with filtering options. + +Query Parameters: +- `page` (default: 1) +- `page_size` (default: 10) +- `user_id` (optional): Filter by user +- `severity` (optional): "malicious" or "violation" +- `anomaly_type` (optional): Type of anomaly +- `status` (optional): "pending", "actioned", "approved", "ignored" +- `start_time` (optional): RFC3339 timestamp +- `end_time` (optional): RFC3339 timestamp + +Response: +```json +{ + "success": true, + "data": { + "anomalies": [ + { + "id": 1, + "user_id": 123, + "token_id": 456, + "detected_at": "2024-01-01T12:00:00Z", + "anomaly_type": "high_rpm", + "severity": "malicious", + "description": "Excessive request rate detected", + "metadata": "{\"rpm\":150}", + "ip_address": "192.168.1.1", + "device_id": "device-fp-123", + "risk_score": 85, + "action_taken": "ban", + "actioned_at": "2024-01-01T12:00:01Z", + "status": "actioned", + "reviewed_by": null, + "reviewed_at": null, + "review_decision": "", + "review_rationale": "" + } + ], + "total": 100, + "page": 1, + "page_size": 10 + } +} +``` + +### Manual Override + +**POST /api/security/anomalies/:id/approve** + +Approve an anomaly (mark as legitimate activity). + +Request Body: +```json +{ + "rationale": "Verified user, false positive" +} +``` + +Response: +```json +{ + "success": true, + "message": "Anomaly approved successfully" +} +``` + +**POST /api/security/anomalies/:id/ignore** + +Ignore an anomaly and rollback any actions taken. + +Request Body: +```json +{ + "rationale": "Testing activity, can be ignored" +} +``` + +Response: +```json +{ + "success": true, + "message": "Anomaly ignored successfully" +} +``` + +## Automated Enforcement + +### Configuration + +Update security settings to enable auto-enforcement: + +**PUT /api/security/settings** + +```json +{ + "auto_enforcement_enabled": true, + "auto_ban_enabled": true, + "auto_block_enabled": true, + "auto_ban_threshold": 10, + "violation_redirect_model": "gpt-3.5-turbo" +} +``` + +### Enforcement Actions + +Based on severity and risk score, the system automatically takes actions: + +| Risk Score | Action | Description | +|------------|------------|----------------------------------------------| +| > 80 | Ban | User banned, all requests blocked | +| 51-80 | Block | Temporary block via Heimdall | +| 31-50 | Redirect | Requests redirected to safe model | +| ≤ 30 | Log | Only logged, no enforcement | + +### Heimdall Integration + +Heimdall middleware enforces security directives in real-time: + +1. **Detection**: Governance middleware detects violation +2. **Anomaly Creation**: Security anomaly is created with severity +3. **Auto-Enforcement**: If severity is "malicious", trigger action +4. **Redis Directive**: Directive published to Redis +5. **Heimdall Check**: Next request checked against blocklist +6. **Block/Allow**: Request either blocked or allowed to proceed + +Redis Keys: +- `heimdall:directive:{user_id}` - User-specific directive +- Channel: `heimdall:directives` - Pub/sub for real-time updates + +## Database Schema + +### device_fingerprints + +```sql +CREATE TABLE device_fingerprints ( + id INT PRIMARY KEY AUTO_INCREMENT, + user_id INT NOT NULL, + fingerprint VARCHAR(256) NOT NULL, + user_agent TEXT, + ip_address VARCHAR(45), + first_seen_at TIMESTAMP NOT NULL, + last_seen_at TIMESTAMP NOT NULL, + request_count INT DEFAULT 0, + is_blocked BOOLEAN DEFAULT FALSE, + risk_score INT DEFAULT 0, + INDEX idx_user_id (user_id), + INDEX idx_fingerprint (fingerprint), + INDEX idx_ip_address (ip_address), + INDEX idx_is_blocked (is_blocked) +); +``` + +### ip_clusters + +```sql +CREATE TABLE ip_clusters ( + id INT PRIMARY KEY AUTO_INCREMENT, + ip_address VARCHAR(45) UNIQUE NOT NULL, + country VARCHAR(2), + city VARCHAR(100), + unique_users INT DEFAULT 0, + total_requests INT DEFAULT 0, + violation_count INT DEFAULT 0, + is_blocked BOOLEAN DEFAULT FALSE, + risk_score INT DEFAULT 0, + first_seen_at TIMESTAMP NOT NULL, + last_seen_at TIMESTAMP NOT NULL, + INDEX idx_is_blocked (is_blocked) +); +``` + +### security_anomalies + +```sql +CREATE TABLE security_anomalies ( + id INT PRIMARY KEY AUTO_INCREMENT, + user_id INT NOT NULL, + token_id INT, + detected_at TIMESTAMP NOT NULL, + anomaly_type VARCHAR(50), + severity VARCHAR(20), + description TEXT, + metadata TEXT, + ip_address VARCHAR(45), + device_id VARCHAR(256), + risk_score INT DEFAULT 0, + action_taken VARCHAR(50), + actioned_at TIMESTAMP, + status VARCHAR(20) DEFAULT 'pending', + reviewed_by INT, + reviewed_at TIMESTAMP, + review_decision VARCHAR(20), + review_rationale TEXT, + INDEX idx_user_id (user_id), + INDEX idx_token_id (token_id), + INDEX idx_detected_at (detected_at), + INDEX idx_severity (severity), + INDEX idx_status (status), + INDEX idx_ip_address (ip_address), + INDEX idx_device_id (device_id) +); +``` + +## Testing + +### Unit Tests + +Run security enforcement tests: +```bash +go test ./service -run TestSecurity +go test ./controller -run TestSecurity +``` + +### Integration Tests + +Test the full workflow: +```bash +go test ./middleware -run TestHeimdall +``` + +### Manual Testing + +1. **Create an anomaly**: +```bash +curl -X POST http://localhost:3000/api/security/anomalies \ + -H "Authorization: Bearer $ADMIN_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{ + "user_id": 123, + "anomaly_type": "test", + "severity": "malicious", + "description": "Test anomaly", + "risk_score": 90 + }' +``` + +2. **Verify enforcement**: Try making a request as that user - should be blocked + +3. **Check dashboard**: +```bash +curl http://localhost:3000/api/security/dashboard \ + -H "Authorization: Bearer $ADMIN_TOKEN" +``` + +4. **Review and ignore**: +```bash +curl -X POST http://localhost:3000/api/security/anomalies/1/ignore \ + -H "Authorization: Bearer $ADMIN_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"rationale": "False positive"}' +``` + +## Monitoring + +### Key Metrics + +Monitor these Redis keys: +- `heimdall:directive:*` - Active enforcement directives +- Channel `heimdall:directives` - Real-time directive stream + +### Logs + +Check logs for: +- `heimdall block` - User blocked by Heimdall +- `governance flag triggered` - Violation detected +- `security action taken` - Enforcement action executed + +## Best Practices + +1. **Tune Risk Scores**: Adjust risk score thresholds based on your use case +2. **Review Regularly**: Check pending anomalies and approve/ignore as needed +3. **Monitor False Positives**: Track approval/ignore ratios +4. **Configure Redis**: Ensure Redis is enabled for real-time enforcement +5. **Audit Trail**: Review `reviewed_by` and `review_rationale` fields +6. **Alert on High Severity**: Set up notifications for malicious anomalies + +## Troubleshooting + +### Heimdall not blocking users + +1. Check Redis is enabled: `REDIS_CONN_STRING` environment variable +2. Verify directive in Redis: `redis-cli GET heimdall:directive:{user_id}` +3. Check middleware is loaded in relay router + +### Auto-enforcement not triggering + +1. Verify settings: `GET /api/security/settings` +2. Check anomaly severity is "malicious" +3. Ensure `auto_enforcement_enabled` is true + +### False positives + +1. Review anomaly metadata for detection details +2. Adjust governance keyword policies +3. Use manual ignore to prevent future similar blocks +4. Consider increasing risk score thresholds diff --git a/common/constants.go b/common/constants.go index 832ef6bb558..c329a2c1df6 100644 --- a/common/constants.go +++ b/common/constants.go @@ -226,7 +226,9 @@ const ( const ( // Security Center Options - OptionViolationRedirectModel = "ViolationRedirectModel" - OptionAutobanEnabled = "AutobanEnabled" - OptionAutobanThreshold = "AutobanThreshold" + OptionViolationRedirectModel = "ViolationRedirectModel" + OptionAutobanEnabled = "AutobanEnabled" + OptionAutobanThreshold = "AutobanThreshold" + OptionAutoEnforcementEnabled = "AutoEnforcementEnabled" + OptionAutoBlockEnabled = "AutoBlockEnabled" ) diff --git a/common/redis.go b/common/redis.go index c72878378fc..b09adb236a6 100644 --- a/common/redis.go +++ b/common/redis.go @@ -1,327 +1,335 @@ package common import ( - "context" - "errors" - "fmt" - "os" - "reflect" - "strconv" - "time" - - "github.com/go-redis/redis/v8" - "gorm.io/gorm" + "context" + "errors" + "fmt" + "os" + "reflect" + "strconv" + "time" + + "github.com/go-redis/redis/v8" + "gorm.io/gorm" ) var RDB *redis.Client var RedisEnabled = true func RedisKeyCacheSeconds() int { - return SyncFrequency + return SyncFrequency } // InitRedisClient This function is called after init() func InitRedisClient() (err error) { - if os.Getenv("REDIS_CONN_STRING") == "" { - RedisEnabled = false - SysLog("REDIS_CONN_STRING not set, Redis is not enabled") - return nil - } - if os.Getenv("SYNC_FREQUENCY") == "" { - SysLog("SYNC_FREQUENCY not set, use default value 60") - SyncFrequency = 60 - } - SysLog("Redis is enabled") - opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) - if err != nil { - FatalLog("failed to parse Redis connection string: " + err.Error()) - } - opt.PoolSize = GetEnvOrDefault("REDIS_POOL_SIZE", 10) - RDB = redis.NewClient(opt) - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - _, err = RDB.Ping(ctx).Result() - if err != nil { - FatalLog("Redis ping test failed: " + err.Error()) - } - if DebugEnabled { - SysLog(fmt.Sprintf("Redis connected to %s", opt.Addr)) - SysLog(fmt.Sprintf("Redis database: %d", opt.DB)) - } - return err + if os.Getenv("REDIS_CONN_STRING") == "" { + RedisEnabled = false + SysLog("REDIS_CONN_STRING not set, Redis is not enabled") + return nil + } + if os.Getenv("SYNC_FREQUENCY") == "" { + SysLog("SYNC_FREQUENCY not set, use default value 60") + SyncFrequency = 60 + } + SysLog("Redis is enabled") + opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) + if err != nil { + FatalLog("failed to parse Redis connection string: " + err.Error()) + } + opt.PoolSize = GetEnvOrDefault("REDIS_POOL_SIZE", 10) + RDB = redis.NewClient(opt) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + _, err = RDB.Ping(ctx).Result() + if err != nil { + FatalLog("Redis ping test failed: " + err.Error()) + } + if DebugEnabled { + SysLog(fmt.Sprintf("Redis connected to %s", opt.Addr)) + SysLog(fmt.Sprintf("Redis database: %d", opt.DB)) + } + return err } func ParseRedisOption() *redis.Options { - opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) - if err != nil { - FatalLog("failed to parse Redis connection string: " + err.Error()) - } - return opt + opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) + if err != nil { + FatalLog("failed to parse Redis connection string: " + err.Error()) + } + return opt } func RedisSet(key string, value string, expiration time.Duration) error { - if DebugEnabled { - SysLog(fmt.Sprintf("Redis SET: key=%s, value=%s, expiration=%v", key, value, expiration)) - } - ctx := context.Background() - return RDB.Set(ctx, key, value, expiration).Err() + if DebugEnabled { + SysLog(fmt.Sprintf("Redis SET: key=%s, value=%s, expiration=%v", key, value, expiration)) + } + ctx := context.Background() + return RDB.Set(ctx, key, value, expiration).Err() } func RedisGet(key string) (string, error) { - if DebugEnabled { - SysLog(fmt.Sprintf("Redis GET: key=%s", key)) - } - ctx := context.Background() - val, err := RDB.Get(ctx, key).Result() - return val, err + if DebugEnabled { + SysLog(fmt.Sprintf("Redis GET: key=%s", key)) + } + ctx := context.Background() + val, err := RDB.Get(ctx, key).Result() + return val, err } //func RedisExpire(key string, expiration time.Duration) error { -// ctx := context.Background() -// return RDB.Expire(ctx, key, expiration).Err() +// ctx := context.Background() +// return RDB.Expire(ctx, key, expiration).Err() //} // //func RedisGetEx(key string, expiration time.Duration) (string, error) { -// ctx := context.Background() -// return RDB.GetSet(ctx, key, expiration).Result() +// ctx := context.Background() +// return RDB.GetSet(ctx, key, expiration).Result() //} func RedisDel(key string) error { - if DebugEnabled { - SysLog(fmt.Sprintf("Redis DEL: key=%s", key)) - } - ctx := context.Background() - return RDB.Del(ctx, key).Err() + if DebugEnabled { + SysLog(fmt.Sprintf("Redis DEL: key=%s", key)) + } + ctx := context.Background() + return RDB.Del(ctx, key).Err() } func RedisDelKey(key string) error { - if DebugEnabled { - SysLog(fmt.Sprintf("Redis DEL Key: key=%s", key)) - } - ctx := context.Background() - return RDB.Del(ctx, key).Err() + if DebugEnabled { + SysLog(fmt.Sprintf("Redis DEL Key: key=%s", key)) + } + ctx := context.Background() + return RDB.Del(ctx, key).Err() } func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error { - if DebugEnabled { - SysLog(fmt.Sprintf("Redis HSET: key=%s, obj=%+v, expiration=%v", key, obj, expiration)) - } - ctx := context.Background() - - data := make(map[string]interface{}) - - // 使用反射遍历结构体字段 - v := reflect.ValueOf(obj).Elem() - t := v.Type() - for i := 0; i < v.NumField(); i++ { - field := t.Field(i) - value := v.Field(i) - - // Skip DeletedAt field - if field.Type.String() == "gorm.DeletedAt" { - continue - } - - // 处理指针类型 - if value.Kind() == reflect.Ptr { - if value.IsNil() { - data[field.Name] = "" - continue - } - value = value.Elem() - } - - // 处理布尔类型 - if value.Kind() == reflect.Bool { - data[field.Name] = strconv.FormatBool(value.Bool()) - continue - } - - // 其他类型直接转换为字符串 - data[field.Name] = fmt.Sprintf("%v", value.Interface()) - } - - txn := RDB.TxPipeline() - txn.HSet(ctx, key, data) - - // 只有在 expiration 大于 0 时才设置过期时间 - if expiration > 0 { - txn.Expire(ctx, key, expiration) - } - - _, err := txn.Exec(ctx) - if err != nil { - return fmt.Errorf("failed to execute transaction: %w", err) - } - return nil + if DebugEnabled { + SysLog(fmt.Sprintf("Redis HSET: key=%s, obj=%+v, expiration=%v", key, obj, expiration)) + } + ctx := context.Background() + + data := make(map[string]interface{}) + + // 使用反射遍历结构体字段 + v := reflect.ValueOf(obj).Elem() + t := v.Type() + for i := 0; i < v.NumField(); i++ { + field := t.Field(i) + value := v.Field(i) + + // Skip DeletedAt field + if field.Type.String() == "gorm.DeletedAt" { + continue + } + + // 处理指针类型 + if value.Kind() == reflect.Ptr { + if value.IsNil() { + data[field.Name] = "" + continue + } + value = value.Elem() + } + + // 处理布尔类型 + if value.Kind() == reflect.Bool { + data[field.Name] = strconv.FormatBool(value.Bool()) + continue + } + + // 其他类型直接转换为字符串 + data[field.Name] = fmt.Sprintf("%v", value.Interface()) + } + + txn := RDB.TxPipeline() + txn.HSet(ctx, key, data) + + // 只有在 expiration 大于 0 时才设置过期时间 + if expiration > 0 { + txn.Expire(ctx, key, expiration) + } + + _, err := txn.Exec(ctx) + if err != nil { + return fmt.Errorf("failed to execute transaction: %w", err) + } + return nil } func RedisHGetObj(key string, obj interface{}) error { - if DebugEnabled { - SysLog(fmt.Sprintf("Redis HGETALL: key=%s", key)) - } - ctx := context.Background() - - result, err := RDB.HGetAll(ctx, key).Result() - if err != nil { - return fmt.Errorf("failed to load hash from Redis: %w", err) - } - - if len(result) == 0 { - return fmt.Errorf("key %s not found in Redis", key) - } - - // Handle both pointer and non-pointer values - val := reflect.ValueOf(obj) - if val.Kind() != reflect.Ptr { - return fmt.Errorf("obj must be a pointer to a struct, got %T", obj) - } - - v := val.Elem() - if v.Kind() != reflect.Struct { - return fmt.Errorf("obj must be a pointer to a struct, got pointer to %T", v.Interface()) - } - - t := v.Type() - for i := 0; i < v.NumField(); i++ { - field := t.Field(i) - fieldName := field.Name - if value, ok := result[fieldName]; ok { - fieldValue := v.Field(i) - - // Handle pointer types - if fieldValue.Kind() == reflect.Ptr { - if value == "" { - continue - } - if fieldValue.IsNil() { - fieldValue.Set(reflect.New(fieldValue.Type().Elem())) - } - fieldValue = fieldValue.Elem() - } - - // Enhanced type handling for Token struct - switch fieldValue.Kind() { - case reflect.String: - fieldValue.SetString(value) - case reflect.Int, reflect.Int64: - intValue, err := strconv.ParseInt(value, 10, 64) - if err != nil { - return fmt.Errorf("failed to parse int field %s: %w", fieldName, err) - } - fieldValue.SetInt(intValue) - case reflect.Bool: - boolValue, err := strconv.ParseBool(value) - if err != nil { - return fmt.Errorf("failed to parse bool field %s: %w", fieldName, err) - } - fieldValue.SetBool(boolValue) - case reflect.Struct: - // Special handling for gorm.DeletedAt - if fieldValue.Type().String() == "gorm.DeletedAt" { - if value != "" { - timeValue, err := time.Parse(time.RFC3339, value) - if err != nil { - return fmt.Errorf("failed to parse DeletedAt field %s: %w", fieldName, err) - } - fieldValue.Set(reflect.ValueOf(gorm.DeletedAt{Time: timeValue, Valid: true})) - } - } - default: - return fmt.Errorf("unsupported field type: %s for field %s", fieldValue.Kind(), fieldName) - } - } - } - - return nil + if DebugEnabled { + SysLog(fmt.Sprintf("Redis HGETALL: key=%s", key)) + } + ctx := context.Background() + + result, err := RDB.HGetAll(ctx, key).Result() + if err != nil { + return fmt.Errorf("failed to load hash from Redis: %w", err) + } + + if len(result) == 0 { + return fmt.Errorf("key %s not found in Redis", key) + } + + // Handle both pointer and non-pointer values + val := reflect.ValueOf(obj) + if val.Kind() != reflect.Ptr { + return fmt.Errorf("obj must be a pointer to a struct, got %T", obj) + } + + v := val.Elem() + if v.Kind() != reflect.Struct { + return fmt.Errorf("obj must be a pointer to a struct, got pointer to %T", v.Interface()) + } + + t := v.Type() + for i := 0; i < v.NumField(); i++ { + field := t.Field(i) + fieldName := field.Name + if value, ok := result[fieldName]; ok { + fieldValue := v.Field(i) + + // Handle pointer types + if fieldValue.Kind() == reflect.Ptr { + if value == "" { + continue + } + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(fieldValue.Type().Elem())) + } + fieldValue = fieldValue.Elem() + } + + // Enhanced type handling for Token struct + switch fieldValue.Kind() { + case reflect.String: + fieldValue.SetString(value) + case reflect.Int, reflect.Int64: + intValue, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return fmt.Errorf("failed to parse int field %s: %w", fieldName, err) + } + fieldValue.SetInt(intValue) + case reflect.Bool: + boolValue, err := strconv.ParseBool(value) + if err != nil { + return fmt.Errorf("failed to parse bool field %s: %w", fieldName, err) + } + fieldValue.SetBool(boolValue) + case reflect.Struct: + // Special handling for gorm.DeletedAt + if fieldValue.Type().String() == "gorm.DeletedAt" { + if value != "" { + timeValue, err := time.Parse(time.RFC3339, value) + if err != nil { + return fmt.Errorf("failed to parse DeletedAt field %s: %w", fieldName, err) + } + fieldValue.Set(reflect.ValueOf(gorm.DeletedAt{Time: timeValue, Valid: true})) + } + } + default: + return fmt.Errorf("unsupported field type: %s for field %s", fieldValue.Kind(), fieldName) + } + } + } + + return nil } // RedisIncr Add this function to handle atomic increments func RedisIncr(key string, delta int64) error { - if DebugEnabled { - SysLog(fmt.Sprintf("Redis INCR: key=%s, delta=%d", key, delta)) - } - // 检查键的剩余生存时间 - ttlCmd := RDB.TTL(context.Background(), key) - ttl, err := ttlCmd.Result() - if err != nil && !errors.Is(err, redis.Nil) { - return fmt.Errorf("failed to get TTL: %w", err) - } - - // 只有在 key 存在且有 TTL 时才需要特殊处理 - if ttl > 0 { - ctx := context.Background() - // 开始一个Redis事务 - txn := RDB.TxPipeline() - - // 减少余额 - decrCmd := txn.IncrBy(ctx, key, delta) - if err := decrCmd.Err(); err != nil { - return err // 如果减少失败,则直接返回错误 - } - - // 重新设置过期时间,使用原来的过期时间 - txn.Expire(ctx, key, ttl) - - // 执行事务 - _, err = txn.Exec(ctx) - return err - } - return nil + if DebugEnabled { + SysLog(fmt.Sprintf("Redis INCR: key=%s, delta=%d", key, delta)) + } + // 检查键的剩余生存时间 + ttlCmd := RDB.TTL(context.Background(), key) + ttl, err := ttlCmd.Result() + if err != nil && !errors.Is(err, redis.Nil) { + return fmt.Errorf("failed to get TTL: %w", err) + } + + // 只有在 key 存在且有 TTL 时才需要特殊处理 + if ttl > 0 { + ctx := context.Background() + // 开始一个Redis事务 + txn := RDB.TxPipeline() + + // 减少余额 + decrCmd := txn.IncrBy(ctx, key, delta) + if err := decrCmd.Err(); err != nil { + return err // 如果减少失败,则直接返回错误 + } + + // 重新设置过期时间,使用原来的过期时间 + txn.Expire(ctx, key, ttl) + + // 执行事务 + _, err = txn.Exec(ctx) + return err + } + return nil } func RedisHIncrBy(key, field string, delta int64) error { - if DebugEnabled { - SysLog(fmt.Sprintf("Redis HINCRBY: key=%s, field=%s, delta=%d", key, field, delta)) - } - ttlCmd := RDB.TTL(context.Background(), key) - ttl, err := ttlCmd.Result() - if err != nil && !errors.Is(err, redis.Nil) { - return fmt.Errorf("failed to get TTL: %w", err) - } - - if ttl > 0 { - ctx := context.Background() - txn := RDB.TxPipeline() - - incrCmd := txn.HIncrBy(ctx, key, field, delta) - if err := incrCmd.Err(); err != nil { - return err - } - - txn.Expire(ctx, key, ttl) - - _, err = txn.Exec(ctx) - return err - } - return nil + if DebugEnabled { + SysLog(fmt.Sprintf("Redis HINCRBY: key=%s, field=%s, delta=%d", key, field, delta)) + } + ttlCmd := RDB.TTL(context.Background(), key) + ttl, err := ttlCmd.Result() + if err != nil && !errors.Is(err, redis.Nil) { + return fmt.Errorf("failed to get TTL: %w", err) + } + + if ttl > 0 { + ctx := context.Background() + txn := RDB.TxPipeline() + + incrCmd := txn.HIncrBy(ctx, key, field, delta) + if err := incrCmd.Err(); err != nil { + return err + } + + txn.Expire(ctx, key, ttl) + + _, err = txn.Exec(ctx) + return err + } + return nil } func RedisHSetField(key, field string, value interface{}) error { - if DebugEnabled { - SysLog(fmt.Sprintf("Redis HSET field: key=%s, field=%s, value=%v", key, field, value)) - } - ttlCmd := RDB.TTL(context.Background(), key) - ttl, err := ttlCmd.Result() - if err != nil && !errors.Is(err, redis.Nil) { - return fmt.Errorf("failed to get TTL: %w", err) - } - - if ttl > 0 { - ctx := context.Background() - txn := RDB.TxPipeline() - - hsetCmd := txn.HSet(ctx, key, field, value) - if err := hsetCmd.Err(); err != nil { - return err - } - - txn.Expire(ctx, key, ttl) - - _, err = txn.Exec(ctx) - return err - } - return nil + if DebugEnabled { + SysLog(fmt.Sprintf("Redis HSET field: key=%s, field=%s, value=%v", key, field, value)) + } + ttlCmd := RDB.TTL(context.Background(), key) + ttl, err := ttlCmd.Result() + if err != nil && !errors.Is(err, redis.Nil) { + return fmt.Errorf("failed to get TTL: %w", err) + } + + if ttl > 0 { + ctx := context.Background() + txn := RDB.TxPipeline() + + hsetCmd := txn.HSet(ctx, key, field, value) + if err := hsetCmd.Err(); err != nil { + return err + } + + txn.Expire(ctx, key, ttl) + + _, err = txn.Exec(ctx) + return err + } + return nil +} + +func RedisPublish(channel, message string) error { + if DebugEnabled { + SysLog(fmt.Sprintf("Redis PUBLISH: channel=%s, message=%s", channel, message)) + } + ctx := context.Background() + return RDB.Publish(ctx, channel, message).Err() } diff --git a/controller/security.go b/controller/security.go index 7711cedeecf..0951449b051 100644 --- a/controller/security.go +++ b/controller/security.go @@ -1,333 +1,544 @@ package controller import ( - "net/http" - "strconv" - "time" + "net/http" + "strconv" + "time" - "github.com/QuantumNous/new-api/model" - "github.com/QuantumNous/new-api/service" - "github.com/gin-gonic/gin" + "github.com/QuantumNous/new-api/model" + "github.com/QuantumNous/new-api/service" + "github.com/gin-gonic/gin" ) // GetSecurityDashboard returns security dashboard statistics func GetSecurityDashboard(c *gin.Context) { - // Parse time range from query params - startTimeStr := c.Query("start_time") - endTimeStr := c.Query("end_time") - - var startTime, endTime time.Time - var err error - - if startTimeStr != "" { - startTime, err = time.Parse(time.RFC3339, startTimeStr) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "success": false, - "message": "Invalid start_time format", - }) - return - } - } else { - // Default to last 7 days - startTime = time.Now().AddDate(0, 0, -7) - } - - if endTimeStr != "" { - endTime, err = time.Parse(time.RFC3339, endTimeStr) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "success": false, - "message": "Invalid end_time format", - }) - return - } - } else { - endTime = time.Now() - } - - stats, err := service.GetDashboardStats(startTime, endTime) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } - - c.JSON(http.StatusOK, gin.H{ - "success": true, - "data": stats, - }) + // Parse time range from query params + startTimeStr := c.Query("start_time") + endTimeStr := c.Query("end_time") + + var startTime, endTime time.Time + var err error + + if startTimeStr != "" { + startTime, err = time.Parse(time.RFC3339, startTimeStr) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "success": false, + "message": "Invalid start_time format", + }) + return + } + } else { + // Default to last 7 days + startTime = time.Now().AddDate(0, 0, -7) + } + + if endTimeStr != "" { + endTime, err = time.Parse(time.RFC3339, endTimeStr) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "success": false, + "message": "Invalid end_time format", + }) + return + } + } else { + endTime = time.Now() + } + + stats, err := service.GetDashboardStats(startTime, endTime) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": stats, + }) } // GetSecurityViolations returns paginated violation records func GetSecurityViolations(c *gin.Context) { - page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) - pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "10")) - userId, _ := strconv.Atoi(c.Query("user_id")) - keyword := c.Query("keyword") - startTimeStr := c.Query("start_time") - endTimeStr := c.Query("end_time") - - offset := (page - 1) * pageSize - - var startTime, endTime *time.Time - if startTimeStr != "" { - t, err := time.Parse(time.RFC3339, startTimeStr) - if err == nil { - startTime = &t - } - } - if endTimeStr != "" { - t, err := time.Parse(time.RFC3339, endTimeStr) - if err == nil { - endTime = &t - } - } - - violations, total, err := model.GetSecurityViolations(offset, pageSize, userId, startTime, endTime, keyword) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } - - c.JSON(http.StatusOK, gin.H{ - "success": true, - "data": gin.H{ - "violations": violations, - "total": total, - "page": page, - "page_size": pageSize, - }, - }) + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "10")) + userId, _ := strconv.Atoi(c.Query("user_id")) + keyword := c.Query("keyword") + startTimeStr := c.Query("start_time") + endTimeStr := c.Query("end_time") + + offset := (page - 1) * pageSize + + var startTime, endTime *time.Time + if startTimeStr != "" { + t, err := time.Parse(time.RFC3339, startTimeStr) + if err == nil { + startTime = &t + } + } + if endTimeStr != "" { + t, err := time.Parse(time.RFC3339, endTimeStr) + if err == nil { + endTime = &t + } + } + + violations, total, err := model.GetSecurityViolations(offset, pageSize, userId, startTime, endTime, keyword) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": gin.H{ + "violations": violations, + "total": total, + "page": page, + "page_size": pageSize, + }, + }) } // DeleteSecurityViolation deletes a violation record func DeleteSecurityViolation(c *gin.Context) { - id, err := strconv.Atoi(c.Param("id")) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "success": false, - "message": "Invalid violation ID", - }) - return - } - - err = model.DeleteSecurityViolation(id) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } - - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "Violation record deleted successfully", - }) + id, err := strconv.Atoi(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "success": false, + "message": "Invalid violation ID", + }) + return + } + + err = model.DeleteSecurityViolation(id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "Violation record deleted successfully", + }) } // GetSecurityUsers returns users with violations func GetSecurityUsers(c *gin.Context) { - page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) - pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "10")) - bannedOnly := c.Query("banned_only") == "true" - - offset := (page - 1) * pageSize - - userSecList, total, err := model.GetAllUserSecurity(offset, pageSize, bannedOnly) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } - - // Enrich with user info - var enrichedUsers []map[string]interface{} - for _, userSec := range userSecList { - user, err := model.GetUserById(userSec.UserId, false) - if err != nil { - continue - } - - enrichedUsers = append(enrichedUsers, map[string]interface{}{ - "user_id": userSec.UserId, - "username": user.Username, - "display_name": user.DisplayName, - "is_banned": userSec.IsBanned, - "redirect_model": userSec.RedirectModel, - "violation_count": userSec.ViolationCount, - "last_violation_at": userSec.LastViolationAt, - }) - } - - c.JSON(http.StatusOK, gin.H{ - "success": true, - "data": gin.H{ - "users": enrichedUsers, - "total": total, - "page": page, - "page_size": pageSize, - }, - }) + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "10")) + bannedOnly := c.Query("banned_only") == "true" + + offset := (page - 1) * pageSize + + userSecList, total, err := model.GetAllUserSecurity(offset, pageSize, bannedOnly) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + // Enrich with user info + var enrichedUsers []map[string]interface{} + for _, userSec := range userSecList { + user, err := model.GetUserById(userSec.UserId, false) + if err != nil { + continue + } + + enrichedUsers = append(enrichedUsers, map[string]interface{}{ + "user_id": userSec.UserId, + "username": user.Username, + "display_name": user.DisplayName, + "is_banned": userSec.IsBanned, + "redirect_model": userSec.RedirectModel, + "violation_count": userSec.ViolationCount, + "last_violation_at": userSec.LastViolationAt, + }) + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": gin.H{ + "users": enrichedUsers, + "total": total, + "page": page, + "page_size": pageSize, + }, + }) } // BanUser bans a user func BanUser(c *gin.Context) { - userId, err := strconv.Atoi(c.Param("userId")) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "success": false, - "message": "Invalid user ID", - }) - return - } - - err = service.BanUser(userId) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } - - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "User banned successfully", - }) + userId, err := strconv.Atoi(c.Param("userId")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "success": false, + "message": "Invalid user ID", + }) + return + } + + err = service.BanUser(userId) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "User banned successfully", + }) } // UnbanUser unbans a user func UnbanUser(c *gin.Context) { - userId, err := strconv.Atoi(c.Param("userId")) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "success": false, - "message": "Invalid user ID", - }) - return - } - - err = service.UnbanUser(userId) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } - - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "User unbanned successfully", - }) + userId, err := strconv.Atoi(c.Param("userId")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "success": false, + "message": "Invalid user ID", + }) + return + } + + err = service.UnbanUser(userId) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "User unbanned successfully", + }) } // SetUserRedirect sets redirect model for a user func SetUserRedirect(c *gin.Context) { - userId, err := strconv.Atoi(c.Param("userId")) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "success": false, - "message": "Invalid user ID", - }) - return - } - - var req struct { - Model string `json:"model" binding:"required"` - } - - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "success": false, - "message": "Invalid request body", - }) - return - } - - err = service.SetUserRedirect(userId, req.Model) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } - - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "User redirect set successfully", - }) + userId, err := strconv.Atoi(c.Param("userId")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "success": false, + "message": "Invalid user ID", + }) + return + } + + var req struct { + Model string `json:"model" binding:"required"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "success": false, + "message": "Invalid request body", + }) + return + } + + err = service.SetUserRedirect(userId, req.Model) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "User redirect set successfully", + }) } // ClearUserRedirect removes redirect for a user func ClearUserRedirect(c *gin.Context) { - userId, err := strconv.Atoi(c.Param("userId")) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "success": false, - "message": "Invalid user ID", - }) - return - } - - err = service.ClearUserRedirect(userId) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } - - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "User redirect cleared successfully", - }) + userId, err := strconv.Atoi(c.Param("userId")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "success": false, + "message": "Invalid user ID", + }) + return + } + + err = service.ClearUserRedirect(userId) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "User redirect cleared successfully", + }) } // GetSecuritySettings returns security settings func GetSecuritySettings(c *gin.Context) { - settings := service.GetSecuritySettings() + settings := service.GetSecuritySettings() - c.JSON(http.StatusOK, gin.H{ - "success": true, - "data": settings, - }) + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": settings, + }) } // UpdateSecuritySettings updates security settings func UpdateSecuritySettings(c *gin.Context) { - var settings map[string]interface{} - - if err := c.ShouldBindJSON(&settings); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "success": false, - "message": "Invalid request body", - }) - return - } - - err := service.UpdateSecuritySettings(settings) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } - - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "Settings updated successfully", - }) + var settings map[string]interface{} + + if err := c.ShouldBindJSON(&settings); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "success": false, + "message": "Invalid request body", + }) + return + } + + err := service.UpdateSecuritySettings(settings) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "Settings updated successfully", + }) +} + +// GetDeviceClusters returns device fingerprint clusters +func GetDeviceClusters(c *gin.Context) { + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "10")) + blockedStr := c.Query("blocked") + + offset := (page - 1) * pageSize + + var blocked *bool + if blockedStr == "true" { + b := true + blocked = &b + } else if blockedStr == "false" { + b := false + blocked = &b + } + + devices, total, err := model.GetDeviceClusters(offset, pageSize, blocked) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": gin.H{ + "devices": devices, + "total": total, + "page": page, + "page_size": pageSize, + }, + }) +} + +// GetIPClusters returns IP cluster data +func GetIPClusters(c *gin.Context) { + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "10")) + blockedStr := c.Query("blocked") + minRiskScore, _ := strconv.Atoi(c.DefaultQuery("min_risk_score", "0")) + + offset := (page - 1) * pageSize + + var blocked *bool + if blockedStr == "true" { + b := true + blocked = &b + } else if blockedStr == "false" { + b := false + blocked = &b + } + + clusters, total, err := model.GetIPClusters(offset, pageSize, blocked, minRiskScore) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": gin.H{ + "clusters": clusters, + "total": total, + "page": page, + "page_size": pageSize, + }, + }) +} + +// GetSecurityAnomalies returns security anomalies with filters +func GetSecurityAnomalies(c *gin.Context) { + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "10")) + userId, _ := strconv.Atoi(c.Query("user_id")) + severity := c.Query("severity") + anomalyType := c.Query("anomaly_type") + status := c.Query("status") + startTimeStr := c.Query("start_time") + endTimeStr := c.Query("end_time") + + offset := (page - 1) * pageSize + + filters := make(map[string]interface{}) + if userId > 0 { + filters["user_id"] = userId + } + if severity != "" { + filters["severity"] = severity + } + if anomalyType != "" { + filters["anomaly_type"] = anomalyType + } + if status != "" { + filters["status"] = status + } + + if startTimeStr != "" { + t, err := time.Parse(time.RFC3339, startTimeStr) + if err == nil { + filters["start_time"] = &t + } + } + if endTimeStr != "" { + t, err := time.Parse(time.RFC3339, endTimeStr) + if err == nil { + filters["end_time"] = &t + } + } + + anomalies, total, err := model.GetSecurityAnomalies(offset, pageSize, filters) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": gin.H{ + "anomalies": anomalies, + "total": total, + "page": page, + "page_size": pageSize, + }, + }) +} + +// ApproveAnomaly approves a security anomaly +func ApproveAnomaly(c *gin.Context) { + anomalyId, err := strconv.Atoi(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "success": false, + "message": "Invalid anomaly ID", + }) + return + } + + var req struct { + Rationale string `json:"rationale"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "success": false, + "message": "Invalid request body", + }) + return + } + + userId := c.GetInt("id") + err = service.ApproveAnomaly(anomalyId, userId, req.Rationale) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "Anomaly approved successfully", + }) +} + +// IgnoreAnomaly ignores a security anomaly +func IgnoreAnomaly(c *gin.Context) { + anomalyId, err := strconv.Atoi(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "success": false, + "message": "Invalid anomaly ID", + }) + return + } + + var req struct { + Rationale string `json:"rationale" binding:"required"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "success": false, + "message": "Invalid request body", + }) + return + } + + userId := c.GetInt("id") + err = service.IgnoreAnomaly(anomalyId, userId, req.Rationale) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "Anomaly ignored successfully", + }) } diff --git a/controller/security_test.go b/controller/security_test.go new file mode 100644 index 00000000000..acd50bcebc0 --- /dev/null +++ b/controller/security_test.go @@ -0,0 +1,179 @@ +package controller + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/QuantumNous/new-api/model" + "github.com/QuantumNous/new-api/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" +) + +func setupTestRouter() *gin.Engine { + gin.SetMode(gin.TestMode) + router := gin.New() + return router +} + +func TestGetSecurityDashboard(t *testing.T) { + router := setupTestRouter() + router.GET("/api/security/dashboard", GetSecurityDashboard) + + req, _ := http.NewRequest("GET", "/api/security/dashboard", nil) + resp := httptest.NewRecorder() + + router.ServeHTTP(resp, req) + + assert.Equal(t, http.StatusOK, resp.Code) + + var response map[string]interface{} + err := json.Unmarshal(resp.Body.Bytes(), &response) + assert.NoError(t, err) + assert.True(t, response["success"].(bool)) + assert.NotNil(t, response["data"]) +} + +func TestGetDeviceClusters(t *testing.T) { + router := setupTestRouter() + router.GET("/api/security/devices", GetDeviceClusters) + + req, _ := http.NewRequest("GET", "/api/security/devices?page=1&page_size=10", nil) + resp := httptest.NewRecorder() + + router.ServeHTTP(resp, req) + + assert.Equal(t, http.StatusOK, resp.Code) + + var response map[string]interface{} + err := json.Unmarshal(resp.Body.Bytes(), &response) + assert.NoError(t, err) + assert.True(t, response["success"].(bool)) + + data := response["data"].(map[string]interface{}) + assert.NotNil(t, data["devices"]) + assert.NotNil(t, data["total"]) +} + +func TestGetIPClusters(t *testing.T) { + router := setupTestRouter() + router.GET("/api/security/ip-clusters", GetIPClusters) + + req, _ := http.NewRequest("GET", "/api/security/ip-clusters?page=1&page_size=10", nil) + resp := httptest.NewRecorder() + + router.ServeHTTP(resp, req) + + assert.Equal(t, http.StatusOK, resp.Code) + + var response map[string]interface{} + err := json.Unmarshal(resp.Body.Bytes(), &response) + assert.NoError(t, err) + assert.True(t, response["success"].(bool)) + + data := response["data"].(map[string]interface{}) + assert.NotNil(t, data["clusters"]) + assert.NotNil(t, data["total"]) +} + +func TestGetSecurityAnomalies(t *testing.T) { + router := setupTestRouter() + router.GET("/api/security/anomalies", GetSecurityAnomalies) + + req, _ := http.NewRequest("GET", "/api/security/anomalies?page=1&page_size=10&severity=malicious", nil) + resp := httptest.NewRecorder() + + router.ServeHTTP(resp, req) + + assert.Equal(t, http.StatusOK, resp.Code) + + var response map[string]interface{} + err := json.Unmarshal(resp.Body.Bytes(), &response) + assert.NoError(t, err) + assert.True(t, response["success"].(bool)) + + data := response["data"].(map[string]interface{}) + assert.NotNil(t, data["anomalies"]) + assert.NotNil(t, data["total"]) +} + +func TestApproveAnomaly(t *testing.T) { + router := setupTestRouter() + router.POST("/api/security/anomalies/:id/approve", func(c *gin.Context) { + c.Set("id", 1) + ApproveAnomaly(c) + }) + + body := `{"rationale": "False positive, user verified"}` + req, _ := http.NewRequest("POST", "/api/security/anomalies/1/approve", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + resp := httptest.NewRecorder() + + router.ServeHTTP(resp, req) + + assert.Equal(t, http.StatusOK, resp.Code) + + var response map[string]interface{} + err := json.Unmarshal(resp.Body.Bytes(), &response) + assert.NoError(t, err) + assert.True(t, response["success"].(bool)) +} + +func TestIgnoreAnomaly(t *testing.T) { + router := setupTestRouter() + router.POST("/api/security/anomalies/:id/ignore", func(c *gin.Context) { + c.Set("id", 1) + IgnoreAnomaly(c) + }) + + body := `{"rationale": "Testing anomaly, can be ignored"}` + req, _ := http.NewRequest("POST", "/api/security/anomalies/1/ignore", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + resp := httptest.NewRecorder() + + router.ServeHTTP(resp, req) + + assert.Equal(t, http.StatusOK, resp.Code) + + var response map[string]interface{} + err := json.Unmarshal(resp.Body.Bytes(), &response) + assert.NoError(t, err) + assert.True(t, response["success"].(bool)) +} + +func TestEnforcementWorkflow(t *testing.T) { + userId := 999 + + anomaly, err := service.CreateAnomaly( + userId, + nil, + service.AnomalyTypeHighRPM, + "malicious", + "Test anomaly for enforcement", + map[string]interface{}{ + "rpm": 150, + }, + "192.168.1.100", + "test-device-123", + 85, + ) + + assert.NoError(t, err) + assert.NotNil(t, anomaly) + assert.Equal(t, userId, anomaly.UserId) + assert.Equal(t, "malicious", anomaly.Severity) + + time.Sleep(100 * time.Millisecond) + + updatedAnomaly, err := model.GetSecurityAnomaly(anomaly.Id) + assert.NoError(t, err) + + if updatedAnomaly.Status == service.StatusActioned { + assert.NotEmpty(t, updatedAnomaly.ActionTaken) + assert.NotNil(t, updatedAnomaly.ActionedAt) + } +} diff --git a/middleware/heimdall.go b/middleware/heimdall.go new file mode 100644 index 00000000000..d9275221e18 --- /dev/null +++ b/middleware/heimdall.go @@ -0,0 +1,80 @@ +package middleware + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/constant" + "github.com/QuantumNous/new-api/service" + "github.com/gin-gonic/gin" +) + +func Heimdall() gin.HandlerFunc { + return func(c *gin.Context) { + userId := common.GetContextKeyInt(c, constant.ContextKeyUserId) + if userId == 0 { + c.Next() + return + } + + blocked, err := checkHeimdallBlock(userId) + if err != nil { + common.SysLog(fmt.Sprintf("heimdall check error: %v", err)) + c.Next() + return + } + + if blocked { + c.JSON(http.StatusForbidden, gin.H{ + "error": gin.H{ + "message": "Access blocked due to security policy", + "type": "security_violation", + "code": "heimdall_block", + }, + }) + c.Abort() + return + } + + banned, err := service.CheckUserBanned(userId) + if err == nil && banned { + c.JSON(http.StatusForbidden, gin.H{ + "error": gin.H{ + "message": "User account has been banned", + "type": "security_violation", + "code": "user_banned", + }, + }) + c.Abort() + return + } + + c.Next() + } +} + +func checkHeimdallBlock(userId int) (bool, error) { + if !common.RedisEnabled { + return false, nil + } + + key := fmt.Sprintf("heimdall:directive:%d", userId) + data, err := common.RedisGet(key) + if err != nil || data == "" { + return false, nil + } + + var directive map[string]interface{} + if err := json.Unmarshal([]byte(data), &directive); err != nil { + return false, err + } + + action, ok := directive["action"].(string) + if !ok { + return false, nil + } + + return action == "block" || action == "ban", nil +} diff --git a/middleware/heimdall_test.go b/middleware/heimdall_test.go new file mode 100644 index 00000000000..e084ab7f0c3 --- /dev/null +++ b/middleware/heimdall_test.go @@ -0,0 +1,120 @@ +package middleware + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/constant" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" +) + +func TestHeimdallMiddleware(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + userId int + setupDirective func(int) + expectedStatus int + expectBlock bool + }{ + { + name: "No block directive - user passes through", + userId: 1, + setupDirective: func(userId int) { + }, + expectedStatus: http.StatusOK, + expectBlock: false, + }, + { + name: "Block directive exists - user blocked", + userId: 2, + setupDirective: func(userId int) { + if common.RedisEnabled { + directive := map[string]interface{}{ + "user_id": userId, + "action": "block", + } + data, _ := json.Marshal(directive) + common.RedisSet("heimdall:directive:2", string(data), 3600) + } + }, + expectedStatus: http.StatusForbidden, + expectBlock: true, + }, + { + name: "Ban directive exists - user blocked", + userId: 3, + setupDirective: func(userId int) { + if common.RedisEnabled { + directive := map[string]interface{}{ + "user_id": userId, + "action": "ban", + } + data, _ := json.Marshal(directive) + common.RedisSet("heimdall:directive:3", string(data), 3600) + } + }, + expectedStatus: http.StatusForbidden, + expectBlock: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + router := gin.New() + router.Use(func(c *gin.Context) { + common.SetContextKey(c, constant.ContextKeyUserId, tt.userId) + c.Next() + }) + router.Use(Heimdall()) + router.GET("/test", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"message": "success"}) + }) + + tt.setupDirective(tt.userId) + + req, _ := http.NewRequest("GET", "/test", nil) + resp := httptest.NewRecorder() + + router.ServeHTTP(resp, req) + + assert.Equal(t, tt.expectedStatus, resp.Code) + + if tt.expectBlock { + var response map[string]interface{} + err := json.Unmarshal(resp.Body.Bytes(), &response) + assert.NoError(t, err) + assert.NotNil(t, response["error"]) + } + }) + } +} + +func TestCheckHeimdallBlock(t *testing.T) { + if !common.RedisEnabled { + t.Skip("Redis not enabled, skipping test") + } + + userId := 100 + directive := map[string]interface{}{ + "user_id": userId, + "action": "block", + } + data, _ := json.Marshal(directive) + common.RedisSet("heimdall:directive:100", string(data), 3600) + + blocked, err := checkHeimdallBlock(userId) + assert.NoError(t, err) + assert.True(t, blocked) + + common.RedisDel("heimdall:directive:100") + + blocked, err = checkHeimdallBlock(userId) + assert.NoError(t, err) + assert.False(t, blocked) +} diff --git a/model/device_fingerprint.go b/model/device_fingerprint.go new file mode 100644 index 00000000000..d715b5127f7 --- /dev/null +++ b/model/device_fingerprint.go @@ -0,0 +1,67 @@ +package model + +import ( + "time" +) + +type DeviceFingerprint struct { + Id int `json:"id" gorm:"primaryKey;autoIncrement"` + UserId int `json:"user_id" gorm:"index;not null"` + Fingerprint string `json:"fingerprint" gorm:"type:varchar(256);index;not null"` + UserAgent string `json:"user_agent" gorm:"type:text"` + IpAddress string `json:"ip_address" gorm:"type:varchar(45);index"` + FirstSeenAt time.Time `json:"first_seen_at" gorm:"not null"` + LastSeenAt time.Time `json:"last_seen_at" gorm:"not null"` + RequestCount int `json:"request_count" gorm:"default:0"` + IsBlocked bool `json:"is_blocked" gorm:"default:false;index"` + RiskScore int `json:"risk_score" gorm:"default:0"` +} + +func (DeviceFingerprint) TableName() string { + return "device_fingerprints" +} + +func CreateDeviceFingerprint(device *DeviceFingerprint) error { + return DB.Create(device).Error +} + +func GetDeviceFingerprint(fingerprint string, userId int) (*DeviceFingerprint, error) { + var device DeviceFingerprint + err := DB.Where("fingerprint = ? AND user_id = ?", fingerprint, userId).First(&device).Error + return &device, err +} + +func UpdateDeviceFingerprint(device *DeviceFingerprint) error { + return DB.Save(device).Error +} + +func GetDeviceClusters(offset, limit int, blocked *bool) ([]*DeviceFingerprint, int64, error) { + var devices []*DeviceFingerprint + var total int64 + + query := DB.Model(&DeviceFingerprint{}) + if blocked != nil { + query = query.Where("is_blocked = ?", *blocked) + } + + err := query.Count(&total).Error + if err != nil { + return nil, 0, err + } + + err = query.Order("risk_score DESC, last_seen_at DESC"). + Offset(offset). + Limit(limit). + Find(&devices).Error + + return devices, total, err +} + +func GetTopSuspiciousDevices(limit int) ([]*DeviceFingerprint, error) { + var devices []*DeviceFingerprint + err := DB.Where("risk_score > ?", 50). + Order("risk_score DESC"). + Limit(limit). + Find(&devices).Error + return devices, err +} diff --git a/model/ip_cluster.go b/model/ip_cluster.go new file mode 100644 index 00000000000..6d68b16b106 --- /dev/null +++ b/model/ip_cluster.go @@ -0,0 +1,71 @@ +package model + +import ( + "time" +) + +type IPCluster struct { + Id int `json:"id" gorm:"primaryKey;autoIncrement"` + IpAddress string `json:"ip_address" gorm:"type:varchar(45);uniqueIndex;not null"` + Country string `json:"country" gorm:"type:varchar(2)"` + City string `json:"city" gorm:"type:varchar(100)"` + UniqueUsers int `json:"unique_users" gorm:"default:0"` + TotalRequests int `json:"total_requests" gorm:"default:0"` + ViolationCount int `json:"violation_count" gorm:"default:0"` + IsBlocked bool `json:"is_blocked" gorm:"default:false;index"` + RiskScore int `json:"risk_score" gorm:"default:0"` + FirstSeenAt time.Time `json:"first_seen_at" gorm:"not null"` + LastSeenAt time.Time `json:"last_seen_at" gorm:"not null"` +} + +func (IPCluster) TableName() string { + return "ip_clusters" +} + +func CreateIPCluster(cluster *IPCluster) error { + return DB.Create(cluster).Error +} + +func GetIPCluster(ipAddress string) (*IPCluster, error) { + var cluster IPCluster + err := DB.Where("ip_address = ?", ipAddress).First(&cluster).Error + return &cluster, err +} + +func UpdateIPCluster(cluster *IPCluster) error { + return DB.Save(cluster).Error +} + +func GetIPClusters(offset, limit int, blocked *bool, minRiskScore int) ([]*IPCluster, int64, error) { + var clusters []*IPCluster + var total int64 + + query := DB.Model(&IPCluster{}) + if blocked != nil { + query = query.Where("is_blocked = ?", *blocked) + } + if minRiskScore > 0 { + query = query.Where("risk_score >= ?", minRiskScore) + } + + err := query.Count(&total).Error + if err != nil { + return nil, 0, err + } + + err = query.Order("risk_score DESC, last_seen_at DESC"). + Offset(offset). + Limit(limit). + Find(&clusters).Error + + return clusters, total, err +} + +func GetTopSuspiciousIPs(limit int) ([]*IPCluster, error) { + var clusters []*IPCluster + err := DB.Where("risk_score > ?", 50). + Order("risk_score DESC"). + Limit(limit). + Find(&clusters).Error + return clusters, err +} diff --git a/model/main.go b/model/main.go index 5b9d046d958..25c528f0aa0 100644 --- a/model/main.go +++ b/model/main.go @@ -288,6 +288,9 @@ func migrateDB() error { &LotteryRecord{}, &SecurityViolation{}, &UserSecurity{}, + &DeviceFingerprint{}, + &IPCluster{}, + &SecurityAnomaly{}, &Ticket{}, ) if err != nil { @@ -337,6 +340,9 @@ func migrateDBFast() error { {&LotteryRecord{}, "LotteryRecord"}, {&SecurityViolation{}, "SecurityViolation"}, {&UserSecurity{}, "UserSecurity"}, + {&DeviceFingerprint{}, "DeviceFingerprint"}, + {&IPCluster{}, "IPCluster"}, + {&SecurityAnomaly{}, "SecurityAnomaly"}, {&Ticket{}, "Ticket"}, } // 动态计算migration数量,确保errChan缓冲区足够大 diff --git a/model/security_anomaly.go b/model/security_anomaly.go new file mode 100644 index 00000000000..1eaa039d630 --- /dev/null +++ b/model/security_anomaly.go @@ -0,0 +1,161 @@ +package model + +import ( + "time" +) + +type SecurityAnomaly struct { + Id int `json:"id" gorm:"primaryKey;autoIncrement"` + UserId int `json:"user_id" gorm:"index;not null"` + TokenId *int `json:"token_id" gorm:"index"` + DetectedAt time.Time `json:"detected_at" gorm:"index;not null"` + AnomalyType string `json:"anomaly_type" gorm:"type:varchar(50);index"` + Severity string `json:"severity" gorm:"type:varchar(20);index"` + Description string `json:"description" gorm:"type:text"` + Metadata string `json:"metadata" gorm:"type:text"` + IpAddress string `json:"ip_address" gorm:"type:varchar(45);index"` + DeviceId string `json:"device_id" gorm:"type:varchar(256);index"` + RiskScore int `json:"risk_score" gorm:"default:0"` + ActionTaken string `json:"action_taken" gorm:"type:varchar(50)"` + ActionedAt *time.Time `json:"actioned_at"` + Status string `json:"status" gorm:"type:varchar(20);default:'pending';index"` + ReviewedBy *int `json:"reviewed_by"` + ReviewedAt *time.Time `json:"reviewed_at"` + ReviewDecision string `json:"review_decision" gorm:"type:varchar(20)"` + ReviewRationale string `json:"review_rationale" gorm:"type:text"` +} + +func (SecurityAnomaly) TableName() string { + return "security_anomalies" +} + +func CreateSecurityAnomaly(anomaly *SecurityAnomaly) error { + if anomaly.DetectedAt.IsZero() { + anomaly.DetectedAt = time.Now() + } + if anomaly.Status == "" { + anomaly.Status = "pending" + } + return DB.Create(anomaly).Error +} + +func GetSecurityAnomaly(id int) (*SecurityAnomaly, error) { + var anomaly SecurityAnomaly + err := DB.Where("id = ?", id).First(&anomaly).Error + return &anomaly, err +} + +func UpdateSecurityAnomaly(anomaly *SecurityAnomaly) error { + return DB.Save(anomaly).Error +} + +func GetSecurityAnomalies(offset, limit int, filters map[string]interface{}) ([]*SecurityAnomaly, int64, error) { + var anomalies []*SecurityAnomaly + var total int64 + + query := DB.Model(&SecurityAnomaly{}) + + if userId, ok := filters["user_id"].(int); ok && userId > 0 { + query = query.Where("user_id = ?", userId) + } + if severity, ok := filters["severity"].(string); ok && severity != "" { + query = query.Where("severity = ?", severity) + } + if anomalyType, ok := filters["anomaly_type"].(string); ok && anomalyType != "" { + query = query.Where("anomaly_type = ?", anomalyType) + } + if status, ok := filters["status"].(string); ok && status != "" { + query = query.Where("status = ?", status) + } + if startTime, ok := filters["start_time"].(*time.Time); ok && startTime != nil { + query = query.Where("detected_at >= ?", startTime) + } + if endTime, ok := filters["end_time"].(*time.Time); ok && endTime != nil { + query = query.Where("detected_at <= ?", endTime) + } + + err := query.Count(&total).Error + if err != nil { + return nil, 0, err + } + + err = query.Order("detected_at DESC"). + Offset(offset). + Limit(limit). + Find(&anomalies).Error + + return anomalies, total, err +} + +func GetAnomalyCountsBySeverity(startTime, endTime time.Time) (map[string]int64, error) { + type SeverityCount struct { + Severity string + Count int64 + } + + var results []SeverityCount + err := DB.Model(&SecurityAnomaly{}). + Select("severity, COUNT(*) as count"). + Where("detected_at >= ? AND detected_at <= ?", startTime, endTime). + Group("severity"). + Scan(&results).Error + + if err != nil { + return nil, err + } + + counts := make(map[string]int64) + for _, r := range results { + counts[r.Severity] = r.Count + } + + return counts, nil +} + +func GetAnomalyTrends(startTime, endTime time.Time) ([]map[string]interface{}, error) { + type DailyTrend struct { + Date string + Severity string + Count int64 + } + + var results []DailyTrend + + dateFormat := "DATE_FORMAT(detected_at, '%Y-%m-%d')" + if UsingPostgreSQL { + dateFormat = "TO_CHAR(detected_at, 'YYYY-MM-DD')" + } else if UsingSQLite { + dateFormat = "DATE(detected_at)" + } + + err := DB.Model(&SecurityAnomaly{}). + Select(dateFormat+" as date, severity, COUNT(*) as count"). + Where("detected_at >= ? AND detected_at <= ?", startTime, endTime). + Group("date, severity"). + Order("date DESC, severity"). + Scan(&results).Error + + if err != nil { + return nil, err + } + + trends := make([]map[string]interface{}, len(results)) + for i, r := range results { + trends[i] = map[string]interface{}{ + "date": r.Date, + "severity": r.Severity, + "count": r.Count, + } + } + + return trends, nil +} + +func GetPendingHighSeverityAnomalies(limit int) ([]*SecurityAnomaly, error) { + var anomalies []*SecurityAnomaly + err := DB.Where("status = ? AND severity = ?", "pending", "malicious"). + Order("detected_at ASC"). + Limit(limit). + Find(&anomalies).Error + return anomalies, err +} diff --git a/router/api-router.go b/router/api-router.go index 20e4e4cc001..fa1cbaee570 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -374,6 +374,11 @@ func SetApiRouter(router *gin.Engine) { securityRoute.DELETE("/users/:userId/redirect", controller.ClearUserRedirect) securityRoute.GET("/settings", controller.GetSecuritySettings) securityRoute.PUT("/settings", controller.UpdateSecuritySettings) + securityRoute.GET("/devices", controller.GetDeviceClusters) + securityRoute.GET("/ip-clusters", controller.GetIPClusters) + securityRoute.GET("/anomalies", controller.GetSecurityAnomalies) + securityRoute.POST("/anomalies/:id/approve", controller.ApproveAnomaly) + securityRoute.POST("/anomalies/:id/ignore", controller.IgnoreAnomaly) } ticketRoute := apiRouter.Group("/ticket") diff --git a/router/relay-router.go b/router/relay-router.go index c762b3215d8..448d56b80b1 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -62,6 +62,7 @@ func SetRelayRouter(router *gin.Engine) { } relayV1Router := router.Group("/v1") relayV1Router.Use(middleware.TokenAuth()) + relayV1Router.Use(middleware.Heimdall()) relayV1Router.Use(middleware.ModelRequestRateLimit()) { // WebSocket 路由(统一到 Relay) diff --git a/service/security.go b/service/security.go index 9afa4dbc546..7b4c8ad675f 100644 --- a/service/security.go +++ b/service/security.go @@ -1,234 +1,281 @@ package service import ( - "fmt" - "strings" - "time" + "fmt" + "strings" + "time" - "github.com/QuantumNous/new-api/common" - "github.com/QuantumNous/new-api/model" - governanceSvc "github.com/QuantumNous/new-api/service/governance" + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/model" + governanceSvc "github.com/QuantumNous/new-api/service/governance" ) // CheckContentViolation checks if content violates policies func CheckContentViolation(content string) (bool, []string, string) { - if strings.TrimSpace(content) == "" { - return false, nil, "" - } + if strings.TrimSpace(content) == "" { + return false, nil, "" + } - // Use existing governance detection - result := governanceSvc.DetectKeywordPolicy(content) - if result.Triggered { - return true, result.Reasons, result.Severity - } + // Use existing governance detection + result := governanceSvc.DetectKeywordPolicy(content) + if result.Triggered { + return true, result.Reasons, result.Severity + } - return false, nil, "" + return false, nil, "" } // RecordViolation records a security violation func RecordViolation(userId int, tokenId *int, content string, keywords []string, model, ipAddress, requestId, severity, action string) error { - // Sanitize content snippet (limit to 500 chars, mask sensitive parts) - snippet := sanitizeContent(content) - - violation := &model.SecurityViolation{ - UserId: userId, - TokenId: tokenId, - ViolatedAt: time.Now(), - ContentSnippet: snippet, - MatchedKeywords: strings.Join(keywords, ", "), - Model: model, - IpAddress: ipAddress, - RequestId: requestId, - Severity: severity, - ActionTaken: action, - } - - err := model.CreateSecurityViolation(violation) - if err != nil { - return err - } - - // Increment user violation count - return model.IncrementViolationCount(userId) + // Sanitize content snippet (limit to 500 chars, mask sensitive parts) + snippet := sanitizeContent(content) + + violation := &model.SecurityViolation{ + UserId: userId, + TokenId: tokenId, + ViolatedAt: time.Now(), + ContentSnippet: snippet, + MatchedKeywords: strings.Join(keywords, ", "), + Model: model, + IpAddress: ipAddress, + RequestId: requestId, + Severity: severity, + ActionTaken: action, + } + + err := model.CreateSecurityViolation(violation) + if err != nil { + return err + } + + // Increment user violation count + return model.IncrementViolationCount(userId) } // sanitizeContent masks sensitive information in content func sanitizeContent(content string) string { - // Limit length - if len(content) > 500 { - content = content[:500] + "..." - } - - // Basic sanitization - mask potential sensitive data patterns - // Email addresses - content = maskPattern(content, `[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}`, "***@***.***") - // Phone numbers - content = maskPattern(content, `\b\d{3}[-.]?\d{3}[-.]?\d{4}\b`, "***-***-****") - // Credit card numbers - content = maskPattern(content, `\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b`, "****-****-****-****") - - return content + // Limit length + if len(content) > 500 { + content = content[:500] + "..." + } + + // Basic sanitization - mask potential sensitive data patterns + // Email addresses + content = maskPattern(content, `[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}`, "***@***.***") + // Phone numbers + content = maskPattern(content, `\b\d{3}[-.]?\d{3}[-.]?\d{4}\b`, "***-***-****") + // Credit card numbers + content = maskPattern(content, `\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b`, "****-****-****-****") + + return content } // maskPattern replaces pattern matches with mask func maskPattern(text, pattern, mask string) string { - // Simple implementation - in production use regex - // This is a placeholder for demonstration - return text + // Simple implementation - in production use regex + // This is a placeholder for demonstration + return text } // GetDashboardStats retrieves security dashboard statistics func GetDashboardStats(startTime, endTime time.Time) (map[string]interface{}, error) { - stats, err := model.GetViolationStatsByDateRange(startTime, endTime) - if err != nil { - return nil, err - } - - // Add today's count - today := time.Now().Truncate(24 * time.Hour) - todayEnd := today.Add(24 * time.Hour) - todayStats, err := model.GetViolationStatsByDateRange(today, todayEnd) - if err == nil { - stats["today_count"] = todayStats["total_count"] - } - - return stats, nil + stats, err := model.GetViolationStatsByDateRange(startTime, endTime) + if err != nil { + return nil, err + } + + // Add today's count + today := time.Now().Truncate(24 * time.Hour) + todayEnd := today.Add(24 * time.Hour) + todayStats, err := model.GetViolationStatsByDateRange(today, todayEnd) + if err == nil { + stats["today_count"] = todayStats["total_count"] + } + + // Add device clusters + topDevices, err := model.GetTopSuspiciousDevices(10) + if err == nil { + stats["device_clusters"] = topDevices + } + + // Add suspicious IPs + topIPs, err := model.GetTopSuspiciousIPs(10) + if err == nil { + stats["suspicious_ips"] = topIPs + } + + // Add anomaly counts by severity + anomalyCounts, err := model.GetAnomalyCountsBySeverity(startTime, endTime) + if err == nil { + stats["anomaly_counts"] = anomalyCounts + } + + // Add anomaly trends + anomalyTrends, err := model.GetAnomalyTrends(startTime, endTime) + if err == nil { + stats["anomaly_trends"] = anomalyTrends + } + + return stats, nil } // BanUser bans a user from making requests func BanUser(userId int) error { - return model.BanUser(userId) + return model.BanUser(userId) } // UnbanUser removes ban from a user func UnbanUser(userId int) error { - return model.UnbanUser(userId) + return model.UnbanUser(userId) } // SetUserRedirect sets model redirect for a user func SetUserRedirect(userId int, targetModel string) error { - if targetModel == "" { - return fmt.Errorf("target model cannot be empty") - } - return model.SetUserRedirect(userId, targetModel) + if targetModel == "" { + return fmt.Errorf("target model cannot be empty") + } + return model.SetUserRedirect(userId, targetModel) } // ClearUserRedirect removes model redirect for a user func ClearUserRedirect(userId int) error { - return model.ClearUserRedirect(userId) + return model.ClearUserRedirect(userId) } // GetUserSecurity retrieves security status for a user func GetUserSecurity(userId int) (*model.UserSecurity, error) { - // Try Redis first - if userSec, found := model.GetUserSecurityFromRedis(userId); found { - return userSec, nil - } + // Try Redis first + if userSec, found := model.GetUserSecurityFromRedis(userId); found { + return userSec, nil + } - // Fall back to database - return model.GetUserSecurity(userId) + // Fall back to database + return model.GetUserSecurity(userId) } // CheckUserBanned checks if a user is banned func CheckUserBanned(userId int) (bool, error) { - userSec, err := GetUserSecurity(userId) - if err != nil { - return false, err - } - return userSec.IsBanned, nil + userSec, err := GetUserSecurity(userId) + if err != nil { + return false, err + } + return userSec.IsBanned, nil } // GetUserRedirectModel gets the redirect model for a user func GetUserRedirectModel(userId int) (string, error) { - userSec, err := GetUserSecurity(userId) - if err != nil { - return "", err - } - return userSec.RedirectModel, nil + userSec, err := GetUserSecurity(userId) + if err != nil { + return "", err + } + return userSec.RedirectModel, nil } // GetViolationRedirectModel gets global violation redirect model from options func GetViolationRedirectModel() string { - // Get from system options - model := model.GetOptionValue(common.OptionViolationRedirectModel) - if model == "" { - // Fall back to governance config - model = common.GovernanceViolationFallbackAlias - } - return model + // Get from system options + model := model.GetOptionValue(common.OptionViolationRedirectModel) + if model == "" { + // Fall back to governance config + model = common.GovernanceViolationFallbackAlias + } + return model } // SetViolationRedirectModel sets global violation redirect model func SetViolationRedirectModel(targetModel string) error { - return model.UpdateOption(common.OptionViolationRedirectModel, targetModel) + return model.UpdateOption(common.OptionViolationRedirectModel, targetModel) } // GetSecuritySettings retrieves all security settings func GetSecuritySettings() map[string]interface{} { - settings := make(map[string]interface{}) + settings := make(map[string]interface{}) + + settings["violation_redirect_model"] = GetViolationRedirectModel() + settings["auto_ban_enabled"] = model.GetOptionValue(common.OptionAutobanEnabled) == "true" + + threshold := model.GetOptionValue(common.OptionAutobanThreshold) + if threshold == "" { + threshold = "10" + } + settings["auto_ban_threshold"] = threshold - settings["violation_redirect_model"] = GetViolationRedirectModel() - settings["auto_ban_enabled"] = model.GetOptionValue(common.OptionAutobanEnabled) == "true" - - threshold := model.GetOptionValue(common.OptionAutobanThreshold) - if threshold == "" { - threshold = "10" - } - settings["auto_ban_threshold"] = threshold + settings["auto_enforcement_enabled"] = model.GetOptionValue(common.OptionAutoEnforcementEnabled) == "true" + settings["auto_block_enabled"] = model.GetOptionValue(common.OptionAutoBlockEnabled) == "true" - return settings + return settings } // UpdateSecuritySettings updates security settings func UpdateSecuritySettings(settings map[string]interface{}) error { - if model, ok := settings["violation_redirect_model"].(string); ok { - if err := SetViolationRedirectModel(model); err != nil { - return err - } - } - - if enabled, ok := settings["auto_ban_enabled"].(bool); ok { - value := "false" - if enabled { - value = "true" - } - if err := model.UpdateOption(common.OptionAutobanEnabled, value); err != nil { - return err - } - } - - if threshold, ok := settings["auto_ban_threshold"].(float64); ok { - if err := model.UpdateOption(common.OptionAutobanThreshold, fmt.Sprintf("%.0f", threshold)); err != nil { - return err - } - } - - return nil + if model, ok := settings["violation_redirect_model"].(string); ok { + if err := SetViolationRedirectModel(model); err != nil { + return err + } + } + + if enabled, ok := settings["auto_ban_enabled"].(bool); ok { + value := "false" + if enabled { + value = "true" + } + if err := model.UpdateOption(common.OptionAutobanEnabled, value); err != nil { + return err + } + } + + if threshold, ok := settings["auto_ban_threshold"].(float64); ok { + if err := model.UpdateOption(common.OptionAutobanThreshold, fmt.Sprintf("%.0f", threshold)); err != nil { + return err + } + } + + if enabled, ok := settings["auto_enforcement_enabled"].(bool); ok { + value := "false" + if enabled { + value = "true" + } + if err := model.UpdateOption(common.OptionAutoEnforcementEnabled, value); err != nil { + return err + } + } + + if enabled, ok := settings["auto_block_enabled"].(bool); ok { + value := "false" + if enabled { + value = "true" + } + if err := model.UpdateOption(common.OptionAutoBlockEnabled, value); err != nil { + return err + } + } + + return nil } // CheckAutoban checks if user should be auto-banned based on violations func CheckAutoban(userId int) error { - enabled := model.GetOptionValue(common.OptionAutobanEnabled) == "true" - if !enabled { - return nil - } + enabled := model.GetOptionValue(common.OptionAutobanEnabled) == "true" + if !enabled { + return nil + } - thresholdStr := model.GetOptionValue(common.OptionAutobanThreshold) - if thresholdStr == "" { - return nil - } + thresholdStr := model.GetOptionValue(common.OptionAutobanThreshold) + if thresholdStr == "" { + return nil + } - threshold := 10 // default - fmt.Sscanf(thresholdStr, "%d", &threshold) + threshold := 10 // default + fmt.Sscanf(thresholdStr, "%d", &threshold) - userSec, err := GetUserSecurity(userId) - if err != nil { - return err - } + userSec, err := GetUserSecurity(userId) + if err != nil { + return err + } - if userSec.ViolationCount >= threshold && !userSec.IsBanned { - return BanUser(userId) - } + if userSec.ViolationCount >= threshold && !userSec.IsBanned { + return BanUser(userId) + } - return nil + return nil } diff --git a/service/security_enforcement.go b/service/security_enforcement.go new file mode 100644 index 00000000000..48657486644 --- /dev/null +++ b/service/security_enforcement.go @@ -0,0 +1,300 @@ +package service + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/model" +) + +const ( + AnomalyTypeHighRPM = "high_rpm" + AnomalyTypeSuspiciousIP = "suspicious_ip" + AnomalyTypeDeviceAnomaly = "device_anomaly" + AnomalyTypeContentViolation = "content_violation" + + ActionBlock = "block" + ActionRedirect = "redirect" + ActionBan = "ban" + ActionLog = "log" + + StatusPending = "pending" + StatusActioned = "actioned" + StatusApproved = "approved" + StatusIgnored = "ignored" +) + +type EnforcementAction struct { + UserId int + Action string + Reason string + AnomalyId int + Severity string + Metadata map[string]interface{} +} + +func CreateAnomaly(userId int, tokenId *int, anomalyType, severity, description string, metadata map[string]interface{}, ipAddress, deviceId string, riskScore int) (*model.SecurityAnomaly, error) { + metadataJSON := "" + if metadata != nil { + data, err := json.Marshal(metadata) + if err == nil { + metadataJSON = string(data) + } + } + + anomaly := &model.SecurityAnomaly{ + UserId: userId, + TokenId: tokenId, + AnomalyType: anomalyType, + Severity: severity, + Description: description, + Metadata: metadataJSON, + IpAddress: ipAddress, + DeviceId: deviceId, + RiskScore: riskScore, + Status: StatusPending, + } + + if err := model.CreateSecurityAnomaly(anomaly); err != nil { + return nil, err + } + + if severity == "malicious" && isAutoEnforcementEnabled() { + go func() { + if err := ProcessAnomaly(anomaly); err != nil { + common.SysLog(fmt.Sprintf("failed to process anomaly %d: %v", anomaly.Id, err)) + } + }() + } + + return anomaly, nil +} + +func ProcessAnomaly(anomaly *model.SecurityAnomaly) error { + if anomaly.Status != StatusPending { + return nil + } + + action := determineAction(anomaly) + if action == ActionLog { + return nil + } + + if err := executeAction(anomaly.UserId, action, anomaly); err != nil { + return err + } + + now := time.Now() + anomaly.ActionTaken = action + anomaly.ActionedAt = &now + anomaly.Status = StatusActioned + + return model.UpdateSecurityAnomaly(anomaly) +} + +func determineAction(anomaly *model.SecurityAnomaly) string { + settings := GetSecuritySettings() + + if anomaly.Severity == "malicious" { + if autoBanEnabled, ok := settings["auto_ban_enabled"].(bool); ok && autoBanEnabled { + return ActionBan + } + if blockEnabled, ok := settings["auto_block_enabled"].(bool); ok && blockEnabled { + return ActionBlock + } + } + + if anomaly.RiskScore > 80 { + return ActionBan + } else if anomaly.RiskScore > 50 { + return ActionBlock + } else if anomaly.RiskScore > 30 { + return ActionRedirect + } + + return ActionLog +} + +func executeAction(userId int, action string, anomaly *model.SecurityAnomaly) error { + switch action { + case ActionBan: + if err := BanUser(userId); err != nil { + return fmt.Errorf("failed to ban user: %w", err) + } + publishHeimdallDirective(userId, "ban", anomaly) + sendNotification(userId, "banned", anomaly) + + case ActionBlock: + publishHeimdallDirective(userId, "block", anomaly) + sendNotification(userId, "blocked", anomaly) + + case ActionRedirect: + redirectModel := GetViolationRedirectModel() + if redirectModel != "" { + if err := SetUserRedirect(userId, redirectModel); err != nil { + return fmt.Errorf("failed to set redirect: %w", err) + } + } + publishHeimdallDirective(userId, "redirect", anomaly) + sendNotification(userId, "redirected", anomaly) + } + + return nil +} + +func publishHeimdallDirective(userId int, action string, anomaly *model.SecurityAnomaly) { + if !common.RedisEnabled { + return + } + + directive := map[string]interface{}{ + "user_id": userId, + "action": action, + "anomaly_id": anomaly.Id, + "severity": anomaly.Severity, + "timestamp": time.Now().Unix(), + "description": anomaly.Description, + } + + data, err := json.Marshal(directive) + if err != nil { + common.SysLog(fmt.Sprintf("failed to marshal heimdall directive: %v", err)) + return + } + + key := fmt.Sprintf("heimdall:directive:%d", userId) + common.RedisSet(key, string(data), 3600) + + channel := "heimdall:directives" + if err := common.RedisPublish(channel, string(data)); err != nil { + common.SysLog(fmt.Sprintf("failed to publish heimdall directive: %v", err)) + } +} + +func sendNotification(userId int, actionType string, anomaly *model.SecurityAnomaly) { + user, err := model.GetUserById(userId, false) + if err != nil { + return + } + + message := fmt.Sprintf("Security action taken: %s due to %s (severity: %s)", + actionType, anomaly.AnomalyType, anomaly.Severity) + + notificationSettings := GetUserNotificationSettings(userId) + if notificationSettings != nil && notificationSettings.SecurityAlerts { + common.SysLog(fmt.Sprintf("notification sent to user %s: %s", user.Username, message)) + } +} + +func ApproveAnomaly(anomalyId int, reviewerId int, rationale string) error { + anomaly, err := model.GetSecurityAnomaly(anomalyId) + if err != nil { + return err + } + + now := time.Now() + anomaly.Status = StatusApproved + anomaly.ReviewedBy = &reviewerId + anomaly.ReviewedAt = &now + anomaly.ReviewDecision = "approved" + anomaly.ReviewRationale = rationale + + return model.UpdateSecurityAnomaly(anomaly) +} + +func IgnoreAnomaly(anomalyId int, reviewerId int, rationale string) error { + anomaly, err := model.GetSecurityAnomaly(anomalyId) + if err != nil { + return err + } + + now := time.Now() + anomaly.Status = StatusIgnored + anomaly.ReviewedBy = &reviewerId + anomaly.ReviewedAt = &now + anomaly.ReviewDecision = "ignored" + anomaly.ReviewRationale = rationale + + if anomaly.ActionTaken != "" && anomaly.ActionTaken != ActionLog { + if err := rollbackAction(anomaly); err != nil { + return fmt.Errorf("failed to rollback action: %w", err) + } + } + + return model.UpdateSecurityAnomaly(anomaly) +} + +func rollbackAction(anomaly *model.SecurityAnomaly) error { + switch anomaly.ActionTaken { + case ActionBan: + return UnbanUser(anomaly.UserId) + case ActionRedirect: + return ClearUserRedirect(anomaly.UserId) + } + return nil +} + +func isAutoEnforcementEnabled() bool { + settings := GetSecuritySettings() + if enabled, ok := settings["auto_enforcement_enabled"].(bool); ok { + return enabled + } + return true +} + +type NotificationSettings struct { + SecurityAlerts bool +} + +func GetUserNotificationSettings(userId int) *NotificationSettings { + return &NotificationSettings{ + SecurityAlerts: true, + } +} + +func TrackDeviceFingerprint(userId int, fingerprint, userAgent, ipAddress string) error { + device, err := model.GetDeviceFingerprint(fingerprint, userId) + if err != nil { + now := time.Now() + device = &model.DeviceFingerprint{ + UserId: userId, + Fingerprint: fingerprint, + UserAgent: userAgent, + IpAddress: ipAddress, + FirstSeenAt: now, + LastSeenAt: now, + RequestCount: 1, + RiskScore: 0, + } + return model.CreateDeviceFingerprint(device) + } + + device.LastSeenAt = time.Now() + device.RequestCount++ + device.IpAddress = ipAddress + + return model.UpdateDeviceFingerprint(device) +} + +func TrackIPCluster(ipAddress string, userId int) error { + cluster, err := model.GetIPCluster(ipAddress) + if err != nil { + now := time.Now() + cluster = &model.IPCluster{ + IpAddress: ipAddress, + UniqueUsers: 1, + TotalRequests: 1, + RiskScore: 0, + FirstSeenAt: now, + LastSeenAt: now, + } + return model.CreateIPCluster(cluster) + } + + cluster.LastSeenAt = time.Now() + cluster.TotalRequests++ + + return model.UpdateIPCluster(cluster) +} diff --git a/service/security_enforcement_test.go b/service/security_enforcement_test.go new file mode 100644 index 00000000000..cfc37a510cd --- /dev/null +++ b/service/security_enforcement_test.go @@ -0,0 +1,191 @@ +package service + +import ( + "testing" + "time" + + "github.com/QuantumNous/new-api/model" + "github.com/stretchr/testify/assert" +) + +func TestCreateAnomaly(t *testing.T) { + anomaly, err := CreateAnomaly( + 1, + nil, + AnomalyTypeHighRPM, + "malicious", + "High request rate detected", + map[string]interface{}{ + "rpm": 120, + }, + "192.168.1.1", + "device-fp-123", + 75, + ) + + assert.NoError(t, err) + assert.NotNil(t, anomaly) + assert.Equal(t, 1, anomaly.UserId) + assert.Equal(t, "malicious", anomaly.Severity) + assert.Equal(t, StatusPending, anomaly.Status) +} + +func TestDetermineAction(t *testing.T) { + tests := []struct { + name string + anomaly *model.SecurityAnomaly + expected string + }{ + { + name: "High risk score triggers ban", + anomaly: &model.SecurityAnomaly{ + Severity: "malicious", + RiskScore: 85, + }, + expected: ActionBan, + }, + { + name: "Medium risk score triggers block", + anomaly: &model.SecurityAnomaly{ + Severity: "violation", + RiskScore: 60, + }, + expected: ActionBlock, + }, + { + name: "Low risk score triggers redirect", + anomaly: &model.SecurityAnomaly{ + Severity: "violation", + RiskScore: 40, + }, + expected: ActionRedirect, + }, + { + name: "Very low risk score only logs", + anomaly: &model.SecurityAnomaly{ + Severity: "violation", + RiskScore: 20, + }, + expected: ActionLog, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + action := determineAction(tt.anomaly) + assert.Equal(t, tt.expected, action) + }) + } +} + +func TestTrackDeviceFingerprint(t *testing.T) { + err := TrackDeviceFingerprint( + 1, + "device-fp-test", + "Mozilla/5.0", + "192.168.1.1", + ) + + assert.NoError(t, err) + + device, err := model.GetDeviceFingerprint("device-fp-test", 1) + assert.NoError(t, err) + assert.Equal(t, "device-fp-test", device.Fingerprint) + assert.Equal(t, 1, device.UserId) +} + +func TestTrackIPCluster(t *testing.T) { + err := TrackIPCluster("192.168.1.1", 1) + assert.NoError(t, err) + + cluster, err := model.GetIPCluster("192.168.1.1") + assert.NoError(t, err) + assert.Equal(t, "192.168.1.1", cluster.IpAddress) + assert.GreaterOrEqual(t, cluster.TotalRequests, 1) +} + +func TestApproveAndIgnoreAnomaly(t *testing.T) { + anomaly := &model.SecurityAnomaly{ + UserId: 1, + AnomalyType: AnomalyTypeHighRPM, + Severity: "malicious", + Description: "Test anomaly", + Status: StatusPending, + } + + err := model.CreateSecurityAnomaly(anomaly) + assert.NoError(t, err) + + err = ApproveAnomaly(anomaly.Id, 100, "Reviewed and approved") + assert.NoError(t, err) + + updated, err := model.GetSecurityAnomaly(anomaly.Id) + assert.NoError(t, err) + assert.Equal(t, StatusApproved, updated.Status) + assert.Equal(t, "approved", updated.ReviewDecision) + assert.NotNil(t, updated.ReviewedAt) + assert.NotNil(t, updated.ReviewedBy) + + anomaly2 := &model.SecurityAnomaly{ + UserId: 1, + AnomalyType: AnomalyTypeHighRPM, + Severity: "violation", + Description: "Test anomaly 2", + Status: StatusPending, + } + + err = model.CreateSecurityAnomaly(anomaly2) + assert.NoError(t, err) + + err = IgnoreAnomaly(anomaly2.Id, 100, "False positive") + assert.NoError(t, err) + + updated2, err := model.GetSecurityAnomaly(anomaly2.Id) + assert.NoError(t, err) + assert.Equal(t, StatusIgnored, updated2.Status) + assert.Equal(t, "ignored", updated2.ReviewDecision) + assert.Equal(t, "False positive", updated2.ReviewRationale) +} + +func TestProcessAnomalyWithAutoEnforcement(t *testing.T) { + anomaly := &model.SecurityAnomaly{ + UserId: 1, + AnomalyType: AnomalyTypeHighRPM, + Severity: "malicious", + Description: "Auto enforcement test", + Status: StatusPending, + RiskScore: 90, + } + + err := model.CreateSecurityAnomaly(anomaly) + assert.NoError(t, err) + + err = ProcessAnomaly(anomaly) + assert.NoError(t, err) + + updated, err := model.GetSecurityAnomaly(anomaly.Id) + assert.NoError(t, err) + + if updated.Status == StatusActioned { + assert.NotEmpty(t, updated.ActionTaken) + assert.NotNil(t, updated.ActionedAt) + } +} + +func TestGetAnomalyTrends(t *testing.T) { + startTime := time.Now().AddDate(0, 0, -7) + endTime := time.Now() + + trends, err := model.GetAnomalyTrends(startTime, endTime) + assert.NoError(t, err) + assert.NotNil(t, trends) +} + +func TestGetAnomalyCountsBySeverity(t *testing.T) { + startTime := time.Now().AddDate(0, 0, -7) + endTime := time.Now() + + counts, err := model.GetAnomalyCountsBySeverity(startTime, endTime) + assert.NoError(t, err) + assert.NotNil(t, counts) +}