From db1c1b71f59236cb5987df9a88df548fbb4a0758 Mon Sep 17 00:00:00 2001 From: Jarod Date: Tue, 5 Aug 2025 23:31:21 +0800 Subject: [PATCH 1/5] feat: add a SPNEGO Kerberos Authentication Middleware for Fiber v2 and v3 --- spnego/README.md | 197 ++++++++++++++++++++++++++++++++++++++ spnego/README.zh-CN.md | 198 +++++++++++++++++++++++++++++++++++++++ spnego/config/config.go | 61 ++++++++++++ spnego/doc.go | 51 ++++++++++ spnego/example.go | 52 ++++++++++ spnego/go.mod | 41 ++++++++ spnego/go.sum | 116 +++++++++++++++++++++++ spnego/v2/spnego.go | 105 +++++++++++++++++++++ spnego/v2/spnego_test.go | 162 ++++++++++++++++++++++++++++++++ spnego/v3/spnego.go | 104 ++++++++++++++++++++ spnego/v3/spnego_test.go | 162 ++++++++++++++++++++++++++++++++ 11 files changed, 1249 insertions(+) create mode 100644 spnego/README.md create mode 100644 spnego/README.zh-CN.md create mode 100644 spnego/config/config.go create mode 100644 spnego/doc.go create mode 100644 spnego/example.go create mode 100644 spnego/go.mod create mode 100644 spnego/go.sum create mode 100644 spnego/v2/spnego.go create mode 100644 spnego/v2/spnego_test.go create mode 100644 spnego/v3/spnego.go create mode 100644 spnego/v3/spnego_test.go diff --git a/spnego/README.md b/spnego/README.md new file mode 100644 index 000000000..bd600874a --- /dev/null +++ b/spnego/README.md @@ -0,0 +1,197 @@ +# SPNEGO Kerberos Authentication Middleware for Fiber + +[中文版本](README.zh-CN.md) + +This middleware provides SPNEGO (Simple and Protected GSSAPI Negotiation Mechanism) authentication for Fiber applications, enabling Kerberos authentication for HTTP requests. + +## Features + +- Kerberos authentication via SPNEGO mechanism +- Flexible keytab lookup system +- Support for dynamic keytab retrieval from various sources +- Integration with Fiber context for authenticated identity storage +- Configurable logging + +## Version Compatibility + +This middleware is available in two versions to support different Fiber releases: + +- **v2**: Compatible with Fiber v2 +- **v3**: Compatible with Fiber v3 + +## Installation + +```bash +# For Fiber v3 +go get github.com/gofiber/contrib/spnego/v3 + +# For Fiber v2 +go get github.com/gofiber/contrib/spnego/v2 +``` + +## Usage + +### For Fiber v3 + +```go +package main + +import ( + flog "github.com/gofiber/fiber/v3/log" + "fmt" + + "github.com/jcmturner/gokrb5/v8/keytab" + "github.com/gofiber/fiber/v3" + "github.com/gofiber/contrib/spnego/v3" +) + +func main() { + app := fiber.New() + + // Create a configuration with a keytab lookup function + cfg := &spnego.Config{ + // Use a function to look up keytab from files + KeytabLookup: func() (*keytab.Keytab, error) { + // Implement your keytab lookup logic here + // This could be from files, database, or other sources + kt, err := spnego.NewKeytabFileLookupFunc("/path/to/keytab/file.keytab") + if err != nil { + return nil, err + } + return kt() + }, + // Optional: Set a custom logger + Log: flog.DefaultLogger().Logger().(*log.Logger), + } + + // Create the middleware + authMiddleware, err := v3.NewSpnegoKrb5AuthenticateMiddleware(cfg) + if err != nil { + flog.Fatalf("Failed to create middleware: %v", err) + } + + // Apply the middleware to protected routes + app.Use("/protected", authMiddleware) + + // Access authenticated identity + app.Get("/protected/resource", func(c fiber.Ctx) error { + identity, ok := v3.GetAuthenticatedIdentityFromContext(c) + if !ok { + return c.Status(fiber.StatusUnauthorized).SendString("Unauthorized") + } + return c.SendString(fmt.Sprintf("Hello, %s!", identity.UserName())) + }) + + app.Listen(":3000") +} +``` + +### For Fiber v2 + +```go +package main + +import ( + "fmt" + "log" + "os" + + "github.com/jcmturner/gokrb5/v8/keytab" + "github.com/gofiber/fiber/v2" + "github.com/gofiber/contrib/spnego/v2" +) + +func main() { + app := fiber.New() + + // Create a configuration with a keytab lookup function + cfg := &spnego.Config{ + // Use a function to look up keytab from files + KeytabLookup: func() (*keytab.Keytab, error) { + // Implement your keytab lookup logic here + // This could be from files, database, or other sources + kt, err := spnego.NewKeytabFileLookupFunc("/path/to/keytab/file.keytab") + if err != nil { + return nil, err + } + return kt() + }, + // Optional: Set a custom logger + Log: log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile|log.Lmicroseconds), + } + + // Create the middleware + authMiddleware, err := v2.NewSpnegoKrb5AuthenticateMiddleware(cfg) + if err != nil { + log.Fatalf("Failed to create middleware: %v", err) + } + + // Apply the middleware to protected routes + app.Use("/protected", authMiddleware) + + // Access authenticated identity + app.Get("/protected/resource", func(c *fiber.Ctx) error { + identity, ok := v2.GetAuthenticatedIdentityFromContext(c) + if !ok { + return c.Status(fiber.StatusUnauthorized).SendString("Unauthorized") + } + return c.SendString(fmt.Sprintf("Hello, %s!", identity.UserName())) + }) + + app.Listen(":3000") +} +``` + +## Dynamic Keytab Lookup + +The middleware is designed with extensibility in mind, allowing keytab retrieval from various sources beyond static files: + +```go +// Example: Retrieve keytab from a database +func dbKeytabLookup() (*keytab.Keytab, error) { + // Your database lookup logic here + // ... + return keytabFromDatabase, nil +} + +// Example: Retrieve keytab from a remote service +func remoteKeytabLookup() (*keytab.Keytab, error) { + // Your remote service call logic here + // ... + return keytabFromRemote, nil +} +``` + +## API Reference + +### `NewSpnegoKrb5AuthenticateMiddleware(cfg *Config) (fiber.Handler, error)` + +Creates a new SPNEGO authentication middleware. + +### `GetAuthenticatedIdentityFromContext(ctx fiber.Ctx) (goidentity.Identity, bool)` + +Retrieves the authenticated identity from the Fiber context. + +### `NewKeytabFileLookupFunc(keytabFiles ...string) (KeytabLookupFunc, error)` + +Creates a new KeytabLookupFunc that loads keytab files. + +## Configuration + +The `Config` struct supports the following fields: + +- `KeytabLookup`: A function that retrieves the keytab (required) +- `Log`: The logger used for middleware logging (optional, defaults to Fiber's default logger) + +## Requirements + +- Go 1.21 or higher +- For v3: Fiber v3 +- For v2: Fiber v2 +- Kerberos infrastructure + +## Notes + +- Ensure your Kerberos infrastructure is properly configured +- The middleware handles the SPNEGO negotiation process +- Authenticated identities are stored in the Fiber context using `config.ContextKeyOfIdentity` diff --git a/spnego/README.zh-CN.md b/spnego/README.zh-CN.md new file mode 100644 index 000000000..0693ccae4 --- /dev/null +++ b/spnego/README.zh-CN.md @@ -0,0 +1,198 @@ +# SPNEGO Kerberos 认证中间件 for Fiber + +[English Version](README.md) + +该中间件为Fiber应用提供SPNEGO(简单受保护GSSAPI协商机制)认证,使HTTP请求能够使用Kerberos认证。 + +## 功能特点 + +- 通过SPNEGO机制实现Kerberos认证 +- 灵活的keytab查找系统 +- 支持从各种来源动态检索keytab +- 与Fiber上下文集成用于存储认证身份 +- 可配置日志 + +## 版本兼容性 + +该中间件提供两个版本以支持不同的Fiber版本: + +- **v2**:兼容Fiber v2 +- **v3**:兼容Fiber v3 + +## 安装 + +```bash +# 对于Fiber v3 + go get github.com/gofiber/contrib/spnego/v3 + +# 对于Fiber v2 + go get github.com/gofiber/contrib/spnego/v2 +``` + +## 使用方法 + +### 对于Fiber v3 + +```go +package main + +import ( + flog "github.com/gofiber/fiber/v3/log" + "fmt" + "log" + + "github.com/jcmturner/gokrb5/v8/keytab" + "github.com/gofiber/fiber/v3" + "github.com/gofiber/contrib/spnego/v3" +) + +func main() { + app := fiber.New() + + // 创建带有keytab查找函数的配置 + cfg := &v3.Config{ + // 使用函数从文件查找keytab + KeytabLookup: func() (*keytab.Keytab, error) { + // 在此实现您的keytab查找逻辑 + // 可以从文件、数据库或其他来源获取 + kt, err := v3.NewKeytabFileLookupFunc("/path/to/keytab/file.keytab") + if err != nil { + return nil, err + } + return kt() + }, + // 可选:设置自定义日志器 + Log: flog.DefaultLogger().Logger().(*log.Logger), + } + + // 创建中间件 + authMiddleware, err := v3.NewSpnegoKrb5AuthenticateMiddleware(cfg) + if err != nil { + flog.Fatalf("创建中间件失败: %v", err) + } + + // 将中间件应用于受保护的路由 + app.Use("/protected", authMiddleware) + + // 访问认证身份 + app.Get("/protected/resource", func(c fiber.Ctx) error { + identity, ok := v3.GetAuthenticatedIdentityFromContext(c) + if !ok { + return c.Status(fiber.StatusUnauthorized).SendString("未授权") + } + return c.SendString(fmt.Sprintf("你好, %s!", identity.UserName())) + }) + + app.Listen(":3000") +} +``` + +### 对于Fiber v2 + +```go +package main + +import ( + "fmt" + "log" + "os" + + "github.com/jcmturner/gokrb5/v8/keytab" + "github.com/gofiber/fiber/v2" + "github.com/gofiber/contrib/spnego/v2" +) + +func main() { + app := fiber.New() + + // 创建带有keytab查找函数的配置 + cfg := &v2.Config{ + // 使用函数从文件查找keytab + KeytabLookup: func() (*keytab.Keytab, error) { + // 在此实现您的keytab查找逻辑 + // 可以从文件、数据库或其他来源获取 + kt, err := v2.NewKeytabFileLookupFunc("/path/to/keytab/file.keytab") + if err != nil { + return nil, err + } + return kt() + }, + // 可选:设置自定义日志器 + Log: log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile|log.Lmicroseconds), + } + + // 创建中间件 + authMiddleware, err := v2.NewSpnegoKrb5AuthenticateMiddleware(cfg) + if err != nil { + log.Fatalf("创建中间件失败: %v", err) + } + + // 将中间件应用于受保护的路由 + app.Use("/protected", authMiddleware) + + // 访问认证身份 + app.Get("/protected/resource", func(c *fiber.Ctx) error { + identity, ok := v2.GetAuthenticatedIdentityFromContext(c) + if !ok { + return c.Status(fiber.StatusUnauthorized).SendString("未授权") + } + return c.SendString(fmt.Sprintf("你好, %s!", identity.UserName())) + }) + + app.Listen(":3000") +} +``` + +## 动态Keytab查找 + +该中间件设计具有可扩展性,允许从静态文件以外的各种来源检索keytab: + +```go +// 示例:从数据库检索keytab +func dbKeytabLookup() (*keytab.Keytab, error) { + // 此处实现数据库查找逻辑 + // ... + return keytabFromDatabase, nil +} + +// 示例:从远程服务检索keytab +func remoteKeytabLookup() (*keytab.Keytab, error) { + // 此处实现远程服务调用逻辑 + // ... + return keytabFromRemote, nil +} +``` + +## API 参考 + +### `NewSpnegoKrb5AuthenticateMiddleware(cfg *Config) (fiber.Handler, error)` + +创建一个新的SPNEGO认证中间件。 + +### `GetAuthenticatedIdentityFromContext(ctx fiber.Ctx) (goidentity.Identity, bool)` + +从Fiber上下文中检索已认证的身份。 + +### `NewKeytabFileLookupFunc(keytabFiles ...string) (KeytabLookupFunc, error)` + +创建一个加载keytab文件的KeytabLookupFunc。 + +## 配置 + +`Config`结构体支持以下字段: + +- `KeytabLookup`: 检索keytab的函数(必需) +- `Log`: 用于中间件日志记录的日志器(可选,默认为Fiber的默认日志器) + +## 要求 + +- Go 1.21或更高版本 +- 对于v3:Fiber v3 +- 对于v2:Fiber v2 +- Kerberos基础设施 + +## 注意事项 + +- 确保您的Kerberos基础设施已正确配置 +- 中间件处理SPNEGO协商过程 +- 已认证的身份使用`contextKeyOfIdentity`存储在Fiber上下文中 diff --git a/spnego/config/config.go b/spnego/config/config.go new file mode 100644 index 000000000..cb213de76 --- /dev/null +++ b/spnego/config/config.go @@ -0,0 +1,61 @@ +package config + +import ( + "errors" + "fmt" + "log" + + "github.com/jcmturner/gokrb5/v8/keytab" +) + +// ErrConfigInvalidOfKeytabLookupFunctionRequired is returned when the KeytabLookup function is not set in Config +var ErrConfigInvalidOfKeytabLookupFunctionRequired = errors.New("config invalid: keytab lookup function is required") + +// ErrLookupKeytabFailed is returned when the keytab lookup fails +var ErrLookupKeytabFailed = errors.New("keytab lookup failed") + +// ErrConvertRequestFailed is returned when the request conversion to HTTP request fails +var ErrConvertRequestFailed = errors.New("convert request failed") + +// ErrConfigInvalidOfAtLeastOneKeytabFileRequired is returned when no keytab files are provided +var ( + ErrConfigInvalidOfAtLeastOneKeytabFileRequired = errors.New("config invalid: at least one keytab file required") + ErrLoadKeytabFileFailed = errors.New("load keytab failed") +) + +// ContextKeyOfIdentity is the key used to store the authenticated identity in the Fiber context +const ContextKeyOfIdentity = "middleware.spnego.Identity" + +// KeytabLookupFunc is a function type that returns a keytab or an error +// It's used to look up the keytab dynamically when needed +// This design allows for extensibility, enabling keytab retrieval from various sources +// such as databases, remote services, or other custom implementations beyond static files +type KeytabLookupFunc func() (*keytab.Keytab, error) + +// Config holds the configuration for the SPNEGO middleware +// It includes the keytab lookup function and a logger +type Config struct { + // KeytabLookup is a function that retrieves the keytab + KeytabLookup KeytabLookupFunc + // Log is the logger used for middleware logging + Log *log.Logger +} + +// NewKeytabFileLookupFunc creates a new KeytabLookupFunc that loads keytab files +// It accepts one or more keytab file paths and returns a function that loads them +func NewKeytabFileLookupFunc(keytabFiles ...string) (KeytabLookupFunc, error) { + if len(keytabFiles) == 0 { + return nil, ErrConfigInvalidOfAtLeastOneKeytabFileRequired + } + var mergeKeytab keytab.Keytab + for _, keytabFile := range keytabFiles { + kt, err := keytab.Load(keytabFile) + if err != nil { + return nil, fmt.Errorf("%w: file %s load failed: %w", ErrLoadKeytabFileFailed, keytabFile, err) + } + mergeKeytab.Entries = append(mergeKeytab.Entries, kt.Entries...) + } + return func() (*keytab.Keytab, error) { + return &mergeKeytab, nil + }, nil +} diff --git a/spnego/doc.go b/spnego/doc.go new file mode 100644 index 000000000..4e3fa55d3 --- /dev/null +++ b/spnego/doc.go @@ -0,0 +1,51 @@ +// Package spnego provides SPNEGO (Simple and Protected GSSAPI Negotiation Mechanism) +// Package spnego provides SPNEGO (Simple and Protected GSSAPI Negotiation Mechanism) +// authentication middleware for Fiber applications. It enables Kerberos authentication +// for HTTP requests, allowing seamless integration with Active Directory and other +// Kerberos-based authentication systems. +// +// Version Compatibility: +// - v2 package: Compatible with Fiber v2 +// - v3 package: Compatible with Fiber v3 +// +// Example Usage: +// +// import ( +// "fmt" +// "github.com/gofiber/contrib/spnego/config" +// v3 "github.com/gofiber/contrib/spnego/v3" +// "github.com/gofiber/fiber/v3" +// ) +// +// func main() { +// app := fiber.New() +// +// // Create keytab lookup function +// keytabLookup, err := config.NewKeytabFileLookupFunc("/path/to/keytab.keytab") +// if err != nil { +// panic(err) +// } +// +// // Create SPNEGO middleware +// authMiddleware, err := v3.NewSpnegoKrb5AuthenticateMiddleware(&config.Config{ +// KeytabLookup: keytabLookup, +// }) +// if err != nil { +// panic(err) +// } +// +// // Apply middleware to protected routes +// app.Use("/protected", authMiddleware) +// +// // Access authenticated identity +// app.Get("/protected/resource", func(c fiber.Ctx) error { +// identity, ok := v3.GetAuthenticatedIdentityFromContext(c) +// if !ok { +// return c.Status(fiber.StatusUnauthorized).SendString("Unauthorized") +// } +// return c.SendString(fmt.Sprintf("Hello, %s!", identity.UserName())) +// }) +// +// app.Listen(":3000") +// } +package spnego diff --git a/spnego/example.go b/spnego/example.go new file mode 100644 index 000000000..0d262f1c0 --- /dev/null +++ b/spnego/example.go @@ -0,0 +1,52 @@ +package spnego + +import ( + "fmt" + "time" + + "github.com/gofiber/contrib/spnego/config" + v3 "github.com/gofiber/contrib/spnego/v3" + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/log" +) + +func ExampleNewSpnegoKrb5AuthenticateMiddleware() { + app := fiber.New() + keytabLookup, err := config.NewKeytabFileLookupFunc("/keytabFile/one.keytab", "/keytabFile/two.keyta") + if err != nil { + panic(fmt.Errorf("create keytab lookup function failed: %w", err)) + } + authMiddleware, err := v3.NewSpnegoKrb5AuthenticateMiddleware(&config.Config{ + KeytabLookup: keytabLookup, + }) + if err != nil { + panic(fmt.Errorf("create spnego middleware failed: %w", err)) + } + // Apply the middleware to protected routes + app.Use("/protected", authMiddleware) + + // Access authenticated identity + app.Get("/protected/resource", func(c fiber.Ctx) error { + identity, ok := v3.GetAuthenticatedIdentityFromContext(c) + if !ok { + return c.Status(fiber.StatusUnauthorized).SendString("Unauthorized") + } + return c.SendString(fmt.Sprintf("Hello, %s!", identity.UserName())) + }) + log.Info("Server is running on :3000") + go func() { + <-time.After(time.Second * 1) + fmt.Println("use curl -kv --negotiate http://sso.example.local:3000/protected/resource") + fmt.Println("if response is 401, execute `klist` to check use kerberos session") + <-time.After(time.Second * 2) + fmt.Println("close server") + if err = app.Shutdown(); err != nil { + panic(fmt.Errorf("shutdown server failed: %w", err)) + } + }() + if err := app.Listen("sso.example.local:3000"); err != nil { + panic(fmt.Errorf("start server failed: %w", err)) + } + + // Output: Server is running on :3000 +} diff --git a/spnego/go.mod b/spnego/go.mod new file mode 100644 index 000000000..3ef0c9086 --- /dev/null +++ b/spnego/go.mod @@ -0,0 +1,41 @@ +module github.com/gofiber/contrib/spnego + +go 1.24.0 + +toolchain go1.24.2 + +require ( + github.com/gofiber/fiber/v2 v2.52.9 + github.com/gofiber/fiber/v3 v3.0.0-beta.5 + github.com/jcmturner/goidentity/v6 v6.0.1 + github.com/jcmturner/gokrb5/v8 v8.4.4 + github.com/stretchr/testify v1.10.0 + github.com/valyala/fasthttp v1.64.0 +) + +require ( + github.com/andybalholm/brotli v1.2.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/gofiber/schema v1.6.0 // indirect + github.com/gofiber/utils/v2 v2.0.0-beta.13 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/hashicorp/go-uuid v1.0.3 // indirect + github.com/jcmturner/aescts/v2 v2.0.0 // indirect + github.com/jcmturner/dnsutils/v2 v2.0.0 // indirect + github.com/jcmturner/gofork v1.7.6 // indirect + github.com/jcmturner/rpc/v2 v2.0.3 // indirect + github.com/klauspost/compress v1.18.0 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-runewidth v0.0.16 // indirect + github.com/philhofer/fwd v1.2.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rivo/uniseg v0.2.0 // indirect + github.com/tinylib/msgp v1.3.0 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + golang.org/x/crypto v0.40.0 // indirect + golang.org/x/net v0.42.0 // indirect + golang.org/x/sys v0.34.0 // indirect + golang.org/x/text v0.27.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/spnego/go.sum b/spnego/go.sum new file mode 100644 index 000000000..96e060f85 --- /dev/null +++ b/spnego/go.sum @@ -0,0 +1,116 @@ +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM= +github.com/fxamacker/cbor/v2 v2.9.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ= +github.com/gofiber/fiber/v2 v2.52.9 h1:YjKl5DOiyP3j0mO61u3NTmK7or8GzzWzCFzkboyP5cw= +github.com/gofiber/fiber/v2 v2.52.9/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw= +github.com/gofiber/fiber/v3 v3.0.0-beta.5 h1:MSGbiQZEYiYOqti2Ip2zMRkN4VvZw7Vo7dwZBa1Qjk8= +github.com/gofiber/fiber/v3 v3.0.0-beta.5/go.mod h1:XmI2Agulde26YcQrA2n8X499I1p98/zfCNbNObVUeP8= +github.com/gofiber/schema v1.6.0 h1:rAgVDFwhndtC+hgV7Vu5ItQCn7eC2mBA4Eu1/ZTiEYY= +github.com/gofiber/schema v1.6.0/go.mod h1:WNZWpQx8LlPSK7ZaX0OqOh+nQo/eW2OevsXs1VZfs/s= +github.com/gofiber/utils/v2 v2.0.0-beta.13 h1:dlpbGFLveQ9OduL2UHw4dtu4lXE+Gb3bHMc+8Yxp/dk= +github.com/gofiber/utils/v2 v2.0.0-beta.13/go.mod h1:qEZ175nSOkl5xciHmqxwNDsWzwiB39gB8RgU1d3U4mQ= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= +github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= +github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= +github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= +github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= +github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= +github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= +github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo= +github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM= +github.com/jcmturner/gofork v1.7.6 h1:QH0l3hzAU1tfT3rZCnW5zXl+orbkNMMRGJfdJjHVETg= +github.com/jcmturner/gofork v1.7.6/go.mod h1:1622LH6i/EZqLloHfE7IeZ0uEJwMSUyQ/nDd82IeqRo= +github.com/jcmturner/goidentity/v6 v6.0.1 h1:VKnZd2oEIMorCTsFBnJWbExfNN7yZr3EhJAxwOkZg6o= +github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg= +github.com/jcmturner/gokrb5/v8 v8.4.4 h1:x1Sv4HaTpepFkXbt2IkL29DXRf8sOfZXo8eRKh687T8= +github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs= +github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY= +github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= +github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/philhofer/fwd v1.2.0 h1:e6DnBTl7vGY+Gz322/ASL4Gyp1FspeMvx1RNDoToZuM= +github.com/philhofer/fwd v1.2.0/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/shamaton/msgpack/v2 v2.2.3 h1:uDOHmxQySlvlUYfQwdjxyybAOzjlQsD1Vjy+4jmO9NM= +github.com/shamaton/msgpack/v2 v2.2.3/go.mod h1:6khjYnkx73f7VQU7wjcFS9DFjs+59naVWJv1TB7qdOI= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tinylib/msgp v1.3.0 h1:ULuf7GPooDaIlbyvgAxBV/FI7ynli6LZ1/nVUNu+0ww= +github.com/tinylib/msgp v1.3.0/go.mod h1:ykjzy2wzgrlvpDCRc4LA8UXy6D8bzMSuAF3WD57Gok0= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.64.0 h1:QBygLLQmiAyiXuRhthf0tuRkqAFcrC42dckN2S+N3og= +github.com/valyala/fasthttp v1.64.0/go.mod h1:dGmFxwkWXSK0NbOSJuF7AMVzU+lkHz0wQVvVITv2UQA= +github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= +github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= +golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= +golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= +golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= +golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= +golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/spnego/v2/spnego.go b/spnego/v2/spnego.go new file mode 100644 index 000000000..22c30ae08 --- /dev/null +++ b/spnego/v2/spnego.go @@ -0,0 +1,105 @@ +// Package v2 provides SPNEGO authentication middleware for Fiber v2. +// This middleware enables Kerberos authentication for incoming requests +// using the SPNEGO protocol, allowing seamless integration with Active Directory +// and other Kerberos-based authentication systems. +package v2 + +import ( + "fmt" + "log" + "net/http" + "os" + + "github.com/gofiber/contrib/spnego/config" + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/adaptor" + "github.com/jcmturner/goidentity/v6" + "github.com/jcmturner/gokrb5/v8/service" + "github.com/jcmturner/gokrb5/v8/spnego" +) + +// NewSpnegoKrb5AuthenticateMiddleware creates a new SPNEGO authentication middleware. +// It takes a Config struct and returns a Fiber handler or an error. +// The middleware handles Kerberos authentication for incoming requests using the +// SPNEGO protocol, verifying client credentials against the configured keytab. +func NewSpnegoKrb5AuthenticateMiddleware(cfg *config.Config) (fiber.Handler, error) { + // Validate configuration + if cfg == nil { + cfg = &config.Config{} + } + if cfg.KeytabLookup == nil { + return nil, config.ErrConfigInvalidOfKeytabLookupFunctionRequired + } + // Set default logger if not provided + if cfg.Log == nil { + // Due to differences between Fiber v2 and v3 versions, internal log.Log cannot be obtained, so a new one is created in the same way + cfg.Log = log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile|log.Lmicroseconds) + } + // Return the middleware handler + return func(ctx *fiber.Ctx) error { + // Look up the keytab + kt, err := cfg.KeytabLookup() + if err != nil { + return fmt.Errorf("%w: %w", config.ErrLookupKeytabFailed, err) + } + // Create the SPNEGO handler using the keytab + var handleErr error + handler := spnego.SPNEGOKRB5Authenticate(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + // Set the authenticated identity in the Fiber context + setAuthenticatedIdentityToContext(ctx, goidentity.FromHTTPRequestContext(r)) + // Call the next handler in the chain + handleErr = ctx.Next() + }), kt, service.Logger(cfg.Log)) + // Convert Fiber context to HTTP request + rawReq, err := adaptor.ConvertRequest(ctx, true) + if err != nil { + return fmt.Errorf("%w: %w", config.ErrConvertRequestFailed, err) + } + // Serve the request using the SPNEGO handler + handler.ServeHTTP(wrapCtx{ctx}, rawReq) + return handleErr + }, nil +} + +// setAuthenticatedIdentityToContext stores the authenticated identity in the Fiber context. +// It takes a Fiber context pointer and an identity, and sets it using the ContextKeyOfIdentity key +// for later retrieval by other handlers in the request chain. +func setAuthenticatedIdentityToContext(ctx *fiber.Ctx, identity goidentity.Identity) { + ctx.Locals(config.ContextKeyOfIdentity, identity) +} + +// GetAuthenticatedIdentityFromContext retrieves the authenticated identity from the Fiber context. +// It returns the identity and a boolean indicating if it was found. +// This function should be used by subsequent handlers to access the authenticated user's information. +// +// Example: +// +// user, ok := GetAuthenticatedIdentityFromContext(ctx) +// if ok { +// fmt.Printf("Authenticated user: %s\n", user.UserName()) +// } +func GetAuthenticatedIdentityFromContext(ctx *fiber.Ctx) (goidentity.Identity, bool) { + id, ok := ctx.Locals(config.ContextKeyOfIdentity).(goidentity.Identity) + return id, ok +} + +// wrapCtx wraps a Fiber context pointer to implement the http.ResponseWriter interface. +// This adapter allows the Fiber context to be used with standard HTTP handlers +// that expect an http.ResponseWriter, bridging the gap between Fiber's context +// model and the standard library's HTTP interfaces. + +type wrapCtx struct { + *fiber.Ctx +} + +// Header returns the request headers from the wrapped Fiber context. +// This method implements the http.ResponseWriter interface. +func (w wrapCtx) Header() http.Header { + return w.Ctx.GetReqHeaders() +} + +// WriteHeader sets the HTTP status code on the wrapped Fiber context. +// This method implements the http.ResponseWriter interface. +func (w wrapCtx) WriteHeader(statusCode int) { + w.Ctx.Status(statusCode) +} diff --git a/spnego/v2/spnego_test.go b/spnego/v2/spnego_test.go new file mode 100644 index 000000000..cc49a1681 --- /dev/null +++ b/spnego/v2/spnego_test.go @@ -0,0 +1,162 @@ +package v2 + +import ( + "errors" + "fmt" + "net/http" + "net/http/httptest" + "os" + "testing" + "time" + + "github.com/gofiber/contrib/spnego/config" + "github.com/gofiber/fiber/v2" + "github.com/jcmturner/goidentity/v6" + "github.com/jcmturner/gokrb5/v8/keytab" + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" +) + +func TestNewSpnegoKrb5AuthenticateMiddleware(t *testing.T) { + t.Run("test for keytab lookup function not set", func(t *testing.T) { + _, err := NewSpnegoKrb5AuthenticateMiddleware(nil) + require.ErrorIs(t, err, config.ErrConfigInvalidOfKeytabLookupFunctionRequired) + }) + t.Run("test for keytab lookup failed", func(t *testing.T) { + middleware, err := NewSpnegoKrb5AuthenticateMiddleware(&config.Config{ + KeytabLookup: func() (*keytab.Keytab, error) { + return nil, errors.New("mock keytab lookup error") + }, + }) + require.NoError(t, err) + app := fiber.New() + app.Get("/authenticate", middleware, func(c *fiber.Ctx) error { + return c.SendString("authenticated") + }) + handler := app.Handler() + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.SetRequestURI("/authenticate") + handler(ctx) + require.Equal(t, http.StatusInternalServerError, ctx.Response.StatusCode()) + require.Equal(t, fmt.Sprintf("%s: mock keytab lookup error", config.ErrLookupKeytabFailed), string(ctx.Response.Body())) + }) + t.Run("test for keytab lookup function is set", func(t *testing.T) { + var keytabFiles []string + for i := 0; i < 5; i++ { + kt, clean, err := newKeytabTempFile(fmt.Sprintf("HTTP/sso%d.example.com", i), "KRB5.TEST", 18, 19) + require.NoError(t, err) + t.Cleanup(clean) + keytabFiles = append(keytabFiles, kt) + } + lookupFunc, err := config.NewKeytabFileLookupFunc(keytabFiles...) + require.NoError(t, err) + middleware, err := NewSpnegoKrb5AuthenticateMiddleware(&config.Config{ + KeytabLookup: lookupFunc, + }) + require.NoError(t, err) + app := fiber.New() + app.Get("/authenticate", middleware, func(c *fiber.Ctx) error { + user, ok := GetAuthenticatedIdentityFromContext(c) + if ok { + t.Logf("username: %s\ndomain: %s\n", user.UserName(), user.Domain()) + } + return c.SendString("authenticated") + }) + handler := app.Handler() + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.SetRequestURI("/authenticate") + handler(ctx) + require.Equal(t, fasthttp.StatusUnauthorized, ctx.Response.StatusCode()) + }) +} + +func TestNewKeytabFileLookupFunc(t *testing.T) { + t.Run("test for empty keytab files", func(t *testing.T) { + _, err := config.NewKeytabFileLookupFunc() + require.ErrorIs(t, err, config.ErrConfigInvalidOfAtLeastOneKeytabFileRequired) + }) + t.Run("test for has invalid keytab file", func(t *testing.T) { + kt1, clean, err := newKeytabTempFile("HTTP/sso.example.com", "KRB5.TEST", 18, 19) + require.NoError(t, err) + t.Cleanup(clean) + kt2, clean, err := newBadKeytabTempFile("HTTP/sso1.example.com", "KRB5.TEST", 18, 19) + require.NoError(t, err) + t.Cleanup(clean) + _, err = config.NewKeytabFileLookupFunc(kt1, kt2) + require.ErrorIs(t, err, config.ErrLoadKeytabFileFailed) + }) + t.Run("test for some keytab files", func(t *testing.T) { + var keytabFiles []string + for i := 0; i < 5; i++ { + kt, clean, err := newKeytabTempFile(fmt.Sprintf("HTTP/sso%d.example.com", i), "KRB5.TEST", 18, 19) + require.NoError(t, err) + t.Cleanup(clean) + keytabFiles = append(keytabFiles, kt) + } + lookupFunc, err := config.NewKeytabFileLookupFunc(keytabFiles...) + require.NoError(t, err) + _, err = lookupFunc() + require.NoError(t, err) + }) +} + +func newBadKeytabTempFile(principal string, realm string, et ...int32) (filename string, clean func(), err error) { + filename = fmt.Sprintf("./tmp_%d.keytab", time.Now().Unix()) + clean = func() { + os.Remove(filename) + } + var kt keytab.Keytab + for _, e := range et { + kt.AddEntry(principal, realm, "abcdefg", time.Now(), 2, e) + } + file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, 0o666) + if err != nil { + return filename, clean, fmt.Errorf("open file failed: %w", err) + } + if _, err = kt.Write(file); err != nil { + return filename, clean, fmt.Errorf("write file failed: %w", err) + } + file.Close() + return filename, clean, nil +} + +func newKeytabTempFile(principal string, realm string, et ...int32) (filename string, clean func(), err error) { + filename = fmt.Sprintf("./tmp_%d.keytab", time.Now().Unix()) + clean = func() { + os.Remove(filename) + } + kt := keytab.New() + for _, e := range et { + kt.AddEntry(principal, realm, "abcdefg", time.Now(), 2, e) + } + file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, 0o666) + if err != nil { + return filename, clean, fmt.Errorf("open file failed: %w", err) + } + if _, err = kt.Write(file); err != nil { + return filename, clean, fmt.Errorf("write file failed: %w", err) + } + file.Close() + return filename, clean, nil +} + +func TestGetAuthenticatedIdentityFromContext(t *testing.T) { + app := fiber.New() + app.Use("/testContext", func(ctx *fiber.Ctx) error { + user := goidentity.NewUser("test-user") + user.SetDomain("example.com") + _, ok := GetAuthenticatedIdentityFromContext(ctx) + require.False(t, ok) + setAuthenticatedIdentityToContext(ctx, &user) + user1, ok := GetAuthenticatedIdentityFromContext(ctx) + require.True(t, ok) + require.Equal(t, user.UserName(), user1.UserName()) + require.Equal(t, user.Domain(), user1.Domain()) + return ctx.SendStatus(200) + }) + resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/testContext", nil)) + require.NoError(t, err) + require.Equal(t, resp.StatusCode, 200) +} diff --git a/spnego/v3/spnego.go b/spnego/v3/spnego.go new file mode 100644 index 000000000..aa1dca0f8 --- /dev/null +++ b/spnego/v3/spnego.go @@ -0,0 +1,104 @@ +// Package v3 provides SPNEGO authentication middleware for Fiber v3. +// This middleware enables Kerberos authentication for incoming requests +// using the SPNEGO protocol, allowing seamless integration with Active Directory +// and other Kerberos-based authentication systems. +package v3 + +import ( + "fmt" + "log" + "net/http" + + "github.com/gofiber/contrib/spnego/config" + "github.com/gofiber/fiber/v3" + flog "github.com/gofiber/fiber/v3/log" + "github.com/gofiber/fiber/v3/middleware/adaptor" + "github.com/jcmturner/goidentity/v6" + "github.com/jcmturner/gokrb5/v8/service" + "github.com/jcmturner/gokrb5/v8/spnego" +) + +// NewSpnegoKrb5AuthenticateMiddleware creates a new SPNEGO authentication middleware. +// It takes a Config struct and returns a Fiber handler or an error. +// The middleware handles Kerberos authentication for incoming requests using the +// SPNEGO protocol, verifying client credentials against the configured keytab. +func NewSpnegoKrb5AuthenticateMiddleware(cfg *config.Config) (fiber.Handler, error) { + // Validate configuration + if cfg == nil { + cfg = &config.Config{} + } + if cfg.KeytabLookup == nil { + return nil, config.ErrConfigInvalidOfKeytabLookupFunctionRequired + } + // Set default logger if not provided + if cfg.Log == nil { + cfg.Log = flog.DefaultLogger().Logger().(*log.Logger) + } + // Return the middleware handler + return func(ctx fiber.Ctx) error { + // Look up the keytab + kt, err := cfg.KeytabLookup() + if err != nil { + return fmt.Errorf("%w: %w", config.ErrLookupKeytabFailed, err) + } + // Create the SPNEGO handler using the keytab + var handleErr error + handler := spnego.SPNEGOKRB5Authenticate(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + // Set the authenticated identity in the Fiber context + setAuthenticatedIdentityToContext(ctx, goidentity.FromHTTPRequestContext(r)) + // Call the next handler in the chain + handleErr = ctx.Next() + }), kt, service.Logger(cfg.Log)) + // Convert Fiber context to HTTP request + rawReq, err := adaptor.ConvertRequest(ctx, true) + if err != nil { + return fmt.Errorf("%w: %w", config.ErrConvertRequestFailed, err) + } + // Serve the request using the SPNEGO handler + handler.ServeHTTP(wrapCtx{ctx}, rawReq) + return handleErr + }, nil +} + +// setAuthenticatedIdentityToContext stores the authenticated identity in the Fiber context. +// It takes a Fiber context and an identity, and sets it using the ContextKeyOfIdentity key +// for later retrieval by other handlers in the request chain. +func setAuthenticatedIdentityToContext(ctx fiber.Ctx, identity goidentity.Identity) { + ctx.Locals(config.ContextKeyOfIdentity, identity) +} + +// GetAuthenticatedIdentityFromContext retrieves the authenticated identity from the Fiber context. +// It returns the identity and a boolean indicating if it was found. +// This function should be used by subsequent handlers to access the authenticated user's information. +// +// Example: +// +// user, ok := GetAuthenticatedIdentityFromContext(ctx) +// if ok { +// fmt.Printf("Authenticated user: %s\n", user.UserName()) +// } +func GetAuthenticatedIdentityFromContext(ctx fiber.Ctx) (goidentity.Identity, bool) { + id, ok := ctx.Locals(config.ContextKeyOfIdentity).(goidentity.Identity) + return id, ok +} + +// wrapCtx wraps a Fiber context to implement the http.ResponseWriter interface. +// This adapter allows the Fiber context to be used with standard HTTP handlers +// that expect an http.ResponseWriter, bridging the gap between Fiber's context +// model and the standard library's HTTP interfaces. + +type wrapCtx struct { + fiber.Ctx +} + +// Header returns the request headers from the wrapped Fiber context. +// This method implements the http.ResponseWriter interface. +func (w wrapCtx) Header() http.Header { + return w.Ctx.GetReqHeaders() +} + +// WriteHeader sets the HTTP status code on the wrapped Fiber context. +// This method implements the http.ResponseWriter interface. +func (w wrapCtx) WriteHeader(statusCode int) { + w.Ctx.Status(statusCode) +} diff --git a/spnego/v3/spnego_test.go b/spnego/v3/spnego_test.go new file mode 100644 index 000000000..0986950df --- /dev/null +++ b/spnego/v3/spnego_test.go @@ -0,0 +1,162 @@ +package v3 + +import ( + "errors" + "fmt" + "net/http" + "net/http/httptest" + "os" + "testing" + "time" + + "github.com/gofiber/contrib/spnego/config" + "github.com/gofiber/fiber/v3" + "github.com/jcmturner/goidentity/v6" + "github.com/jcmturner/gokrb5/v8/keytab" + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" +) + +func TestNewSpnegoKrb5AuthenticateMiddleware(t *testing.T) { + t.Run("test for keytab lookup function not set", func(t *testing.T) { + _, err := NewSpnegoKrb5AuthenticateMiddleware(nil) + require.ErrorIs(t, err, config.ErrConfigInvalidOfKeytabLookupFunctionRequired) + }) + t.Run("test for keytab lookup failed", func(t *testing.T) { + middleware, err := NewSpnegoKrb5AuthenticateMiddleware(&config.Config{ + KeytabLookup: func() (*keytab.Keytab, error) { + return nil, errors.New("mock keytab lookup error") + }, + }) + require.NoError(t, err) + app := fiber.New() + app.Get("/authenticate", middleware, func(c fiber.Ctx) error { + return c.SendString("authenticated") + }) + handler := app.Handler() + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.SetRequestURI("/authenticate") + handler(ctx) + require.Equal(t, http.StatusInternalServerError, ctx.Response.StatusCode()) + require.Equal(t, fmt.Sprintf("%s: mock keytab lookup error", config.ErrLookupKeytabFailed), string(ctx.Response.Body())) + }) + t.Run("test for keytab lookup function is set", func(t *testing.T) { + var keytabFiles []string + for i := 0; i < 5; i++ { + kt, clean, err := newKeytabTempFile(fmt.Sprintf("HTTP/sso%d.example.com", i), "KRB5.TEST", 18, 19) + require.NoError(t, err) + t.Cleanup(clean) + keytabFiles = append(keytabFiles, kt) + } + lookupFunc, err := config.NewKeytabFileLookupFunc(keytabFiles...) + require.NoError(t, err) + middleware, err := NewSpnegoKrb5AuthenticateMiddleware(&config.Config{ + KeytabLookup: lookupFunc, + }) + require.NoError(t, err) + app := fiber.New() + app.Get("/authenticate", middleware, func(c fiber.Ctx) error { + user, ok := GetAuthenticatedIdentityFromContext(c) + if ok { + t.Logf("username: %s\ndomain: %s\n", user.UserName(), user.Domain()) + } + return c.SendString("authenticated") + }) + handler := app.Handler() + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.SetRequestURI("/authenticate") + handler(ctx) + require.Equal(t, fasthttp.StatusUnauthorized, ctx.Response.StatusCode()) + }) +} + +func TestNewKeytabFileLookupFunc(t *testing.T) { + t.Run("test for empty keytab files", func(t *testing.T) { + _, err := config.NewKeytabFileLookupFunc() + require.ErrorIs(t, err, config.ErrConfigInvalidOfAtLeastOneKeytabFileRequired) + }) + t.Run("test for has invalid keytab file", func(t *testing.T) { + kt1, clean, err := newKeytabTempFile("HTTP/sso.example.com", "KRB5.TEST", 18, 19) + require.NoError(t, err) + t.Cleanup(clean) + kt2, clean, err := newBadKeytabTempFile("HTTP/sso1.example.com", "KRB5.TEST", 18, 19) + require.NoError(t, err) + t.Cleanup(clean) + _, err = config.NewKeytabFileLookupFunc(kt1, kt2) + require.ErrorIs(t, err, config.ErrLoadKeytabFileFailed) + }) + t.Run("test for some keytab files", func(t *testing.T) { + var keytabFiles []string + for i := 0; i < 5; i++ { + kt, clean, err := newKeytabTempFile(fmt.Sprintf("HTTP/sso%d.example.com", i), "KRB5.TEST", 18, 19) + require.NoError(t, err) + t.Cleanup(clean) + keytabFiles = append(keytabFiles, kt) + } + lookupFunc, err := config.NewKeytabFileLookupFunc(keytabFiles...) + require.NoError(t, err) + _, err = lookupFunc() + require.NoError(t, err) + }) +} + +func newBadKeytabTempFile(principal string, realm string, et ...int32) (filename string, clean func(), err error) { + filename = fmt.Sprintf("./tmp_%d.keytab", time.Now().Unix()) + clean = func() { + os.Remove(filename) + } + var kt keytab.Keytab + for _, e := range et { + kt.AddEntry(principal, realm, "abcdefg", time.Now(), 2, e) + } + file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, 0o666) + if err != nil { + return filename, clean, fmt.Errorf("open file failed: %w", err) + } + if _, err = kt.Write(file); err != nil { + return filename, clean, fmt.Errorf("write file failed: %w", err) + } + file.Close() + return filename, clean, nil +} + +func newKeytabTempFile(principal string, realm string, et ...int32) (filename string, clean func(), err error) { + filename = fmt.Sprintf("./tmp_%d.keytab", time.Now().Unix()) + clean = func() { + os.Remove(filename) + } + kt := keytab.New() + for _, e := range et { + kt.AddEntry(principal, realm, "abcdefg", time.Now(), 2, e) + } + file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, 0o666) + if err != nil { + return filename, clean, fmt.Errorf("open file failed: %w", err) + } + if _, err = kt.Write(file); err != nil { + return filename, clean, fmt.Errorf("write file failed: %w", err) + } + file.Close() + return filename, clean, nil +} + +func TestGetAuthenticatedIdentityFromContext(t *testing.T) { + app := fiber.New() + app.Use("/testContext", func(ctx fiber.Ctx) error { + user := goidentity.NewUser("test-user") + user.SetDomain("example.com") + _, ok := GetAuthenticatedIdentityFromContext(ctx) + require.False(t, ok) + setAuthenticatedIdentityToContext(ctx, &user) + user1, ok := GetAuthenticatedIdentityFromContext(ctx) + require.True(t, ok) + require.Equal(t, user.UserName(), user1.UserName()) + require.Equal(t, user.Domain(), user1.Domain()) + return ctx.SendStatus(200) + }) + resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/testContext", nil)) + require.NoError(t, err) + require.Equal(t, resp.StatusCode, 200) +} From 8a1aa70e2796eb80d1e63e249d451151a9a9ed5a Mon Sep 17 00:00:00 2001 From: Jarod Date: Tue, 26 Aug 2025 01:03:41 +0800 Subject: [PATCH 2/5] refactor(spnego): Restructure SPNEGO authentication middleware codebase Move configuration and utility functions from config package to main package, optimize code organization Add new error types and utility functions, improve test coverage Simplify documentation content, remove redundant information --- spnego/config.go | 45 ++++++++ spnego/config_test.go | 132 ++++++++++++++++++++++ spnego/doc.go | 51 +-------- spnego/error.go | 18 +++ spnego/{ => example}/example.go | 27 ++++- spnego/identity.go | 28 +++++ spnego/identity_test.go | 57 ++++++++++ spnego/utils/adapter.go | 57 ++++++++++ spnego/utils/adapter_test.go | 61 ++++++++++ spnego/utils/keytab.go | 77 +++++++++++++ spnego/utils/keytab_test.go | 46 ++++++++ spnego/utils/mock_keytab.go | 150 +++++++++++++++++++++++++ spnego/utils/mock_keytab_test.go | 185 +++++++++++++++++++++++++++++++ spnego/v2/spnego.go | 61 ++-------- spnego/v2/spnego_test.go | 143 ++++++------------------ spnego/v3/spnego.go | 61 ++-------- spnego/v3/spnego_test.go | 141 ++++++----------------- 17 files changed, 966 insertions(+), 374 deletions(-) create mode 100644 spnego/config.go create mode 100644 spnego/config_test.go create mode 100644 spnego/error.go rename spnego/{ => example}/example.go (64%) create mode 100644 spnego/identity.go create mode 100644 spnego/identity_test.go create mode 100644 spnego/utils/adapter.go create mode 100644 spnego/utils/adapter_test.go create mode 100644 spnego/utils/keytab.go create mode 100644 spnego/utils/keytab_test.go create mode 100644 spnego/utils/mock_keytab.go create mode 100644 spnego/utils/mock_keytab_test.go diff --git a/spnego/config.go b/spnego/config.go new file mode 100644 index 000000000..20ded9856 --- /dev/null +++ b/spnego/config.go @@ -0,0 +1,45 @@ +package spnego + +import ( + "fmt" + "log" + + "github.com/jcmturner/gokrb5/v8/keytab" +) + +// contextKeyOfIdentity is the key used to store the authenticated identity in the Fiber context +const contextKeyOfIdentity = "middleware.spnego.Identity" + +// KeytabLookupFunc is a function type that returns a keytab or an error +// It's used to look up the keytab dynamically when needed +// This design allows for extensibility, enabling keytab retrieval from various sources +// such as databases, remote services, or other custom implementations beyond static files +type KeytabLookupFunc func() (*keytab.Keytab, error) + +// Config holds the configuration for the SPNEGO middleware +// It includes the keytab lookup function and a logger +type Config struct { + // KeytabLookup is a function that retrieves the keytab + KeytabLookup KeytabLookupFunc + // Log is the logger used for middleware logging + Log *log.Logger +} + +// NewKeytabFileLookupFunc creates a new KeytabLookupFunc that loads keytab files +// It accepts one or more keytab file paths and returns a function that loads them +func NewKeytabFileLookupFunc(keytabFiles ...string) (KeytabLookupFunc, error) { + if len(keytabFiles) == 0 { + return nil, ErrConfigInvalidOfAtLeastOneKeytabFileRequired + } + return func() (*keytab.Keytab, error) { + var mergeKeytab keytab.Keytab + for _, keytabFile := range keytabFiles { + kt, err := keytab.Load(keytabFile) + if err != nil { + return nil, fmt.Errorf("%w: file %s load failed: %w", ErrLoadKeytabFileFailed, keytabFile, err) + } + mergeKeytab.Entries = append(mergeKeytab.Entries, kt.Entries...) + } + return &mergeKeytab, nil + }, nil +} diff --git a/spnego/config_test.go b/spnego/config_test.go new file mode 100644 index 000000000..75c6cb330 --- /dev/null +++ b/spnego/config_test.go @@ -0,0 +1,132 @@ +package spnego + +import ( + "os" + "testing" + "time" + + "github.com/gofiber/contrib/spnego/utils" + "github.com/stretchr/testify/require" +) + +func TestNewKeytabFileLookupFunc(t *testing.T) { + t.Run("test didn't give any keytab files", func(t *testing.T) { + _, err := NewKeytabFileLookupFunc() + require.ErrorIs(t, err, ErrConfigInvalidOfAtLeastOneKeytabFileRequired) + }) + t.Run("test not found keytab file", func(t *testing.T) { + err := os.WriteFile("./invalid.keytab", []byte("12345"), 0600) + require.NoError(t, err) + t.Cleanup(func() { + os.Remove("./invalid.keytab") + }) + fn, err := NewKeytabFileLookupFunc("./invalid.keytab") + require.NoError(t, err) + _, err = fn() + require.ErrorIs(t, err, ErrLoadKeytabFileFailed) + }) + t.Run("test one keytab file", func(t *testing.T) { + tm := time.Now() + _, clean, err := utils.NewMockKeytab( + utils.WithPrincipal("HTTP/sso.example.com"), + utils.WithRealm("TEST.LOCAL"), + utils.WithPairs(utils.EncryptTypePair{ + Version: 2, + EncryptType: 18, + CreateTime: tm, + }), + utils.WithFilename("./temp.keytab"), + ) + require.NoError(t, err) + t.Cleanup(clean) + fn, err := NewKeytabFileLookupFunc("./temp.keytab") + require.NoError(t, err) + kt1, err := fn() + require.NoError(t, err) + info := utils.GetKeytabInfo(kt1) + require.Len(t, info, 1) + require.Equal(t, info[0].PrincipalName, "HTTP/sso.example.com@TEST.LOCAL") + require.Equal(t, info[0].Realm, "TEST.LOCAL") + require.Len(t, info[0].Pairs, 1) + require.Equal(t, info[0].Pairs[0].Version, uint8(2)) + require.Equal(t, info[0].Pairs[0].EncryptType, int32(18)) + // Note: The creation time of keytab is only accurate to the second. + require.Equal(t, info[0].Pairs[0].CreateTime.Unix(), tm.Unix()) + }) + t.Run("test multiple keytab file but has invalid keytab", func(t *testing.T) { + tm := time.Now() + _, clean, err := utils.NewMockKeytab( + utils.WithPrincipal("HTTP/sso.example.com"), + utils.WithRealm("TEST.LOCAL"), + utils.WithPairs(utils.EncryptTypePair{ + Version: 2, + EncryptType: 18, + CreateTime: tm, + }), + utils.WithFilename("./temp.keytab"), + ) + require.NoError(t, err) + t.Cleanup(clean) + err = os.WriteFile("./invalid1.keytab", []byte("12345"), 0600) + require.NoError(t, err) + t.Cleanup(func() { + os.Remove("./invalid1.keytab") + }) + fn, err := NewKeytabFileLookupFunc("./temp.keytab", "./invalid1.keytab") + require.NoError(t, err) + _, err = fn() + require.ErrorIs(t, err, ErrLoadKeytabFileFailed) + }) + t.Run("test multiple keytab file", func(t *testing.T) { + tm := time.Now() + _, clean1, err1 := utils.NewMockKeytab( + utils.WithPrincipal("HTTP/sso.example1.com"), + utils.WithRealm("TEST.LOCAL"), + utils.WithPairs(utils.EncryptTypePair{ + Version: 2, + EncryptType: 18, + CreateTime: tm, + }), + utils.WithFilename("./temp1.keytab"), + ) + require.NoError(t, err1) + t.Cleanup(clean1) + _, clean2, err2 := utils.NewMockKeytab( + utils.WithPrincipal("HTTP/sso.example2.com"), + utils.WithRealm("TEST.LOCAL"), + utils.WithPairs(utils.EncryptTypePair{ + Version: 2, + EncryptType: 17, + CreateTime: tm, + }, utils.EncryptTypePair{ + Version: 2, + EncryptType: 18, + CreateTime: tm, + }), + utils.WithFilename("./temp2.keytab"), + ) + require.NoError(t, err2) + t.Cleanup(clean2) + fn, err := NewKeytabFileLookupFunc("./temp1.keytab", "./temp2.keytab") + require.NoError(t, err) + kt2, err := fn() + require.NoError(t, err) + info := utils.GetKeytabInfo(kt2) + require.Len(t, info, 2) + require.Equal(t, info[0].PrincipalName, "HTTP/sso.example1.com@TEST.LOCAL") + require.Equal(t, info[0].Realm, "TEST.LOCAL") + require.Len(t, info[0].Pairs, 1) + require.Equal(t, info[0].Pairs[0].Version, uint8(2)) + require.Equal(t, info[0].Pairs[0].EncryptType, int32(18)) + require.Equal(t, info[0].Pairs[0].CreateTime.Unix(), tm.Unix()) + require.Equal(t, info[1].PrincipalName, "HTTP/sso.example2.com@TEST.LOCAL") + require.Equal(t, info[1].Realm, "TEST.LOCAL") + require.Len(t, info[1].Pairs, 2) + require.Equal(t, info[1].Pairs[0].Version, uint8(2)) + require.Equal(t, info[1].Pairs[0].EncryptType, int32(17)) + require.Equal(t, info[1].Pairs[0].CreateTime.Unix(), tm.Unix()) + require.Equal(t, info[1].Pairs[1].Version, uint8(2)) + require.Equal(t, info[1].Pairs[1].EncryptType, int32(18)) + require.Equal(t, info[1].Pairs[1].CreateTime.Unix(), tm.Unix()) + }) +} diff --git a/spnego/doc.go b/spnego/doc.go index 4e3fa55d3..ea28c342c 100644 --- a/spnego/doc.go +++ b/spnego/doc.go @@ -1,51 +1,4 @@ // Package spnego provides SPNEGO (Simple and Protected GSSAPI Negotiation Mechanism) -// Package spnego provides SPNEGO (Simple and Protected GSSAPI Negotiation Mechanism) -// authentication middleware for Fiber applications. It enables Kerberos authentication -// for HTTP requests, allowing seamless integration with Active Directory and other -// Kerberos-based authentication systems. -// -// Version Compatibility: -// - v2 package: Compatible with Fiber v2 -// - v3 package: Compatible with Fiber v3 -// -// Example Usage: -// -// import ( -// "fmt" -// "github.com/gofiber/contrib/spnego/config" -// v3 "github.com/gofiber/contrib/spnego/v3" -// "github.com/gofiber/fiber/v3" -// ) -// -// func main() { -// app := fiber.New() -// -// // Create keytab lookup function -// keytabLookup, err := config.NewKeytabFileLookupFunc("/path/to/keytab.keytab") -// if err != nil { -// panic(err) -// } -// -// // Create SPNEGO middleware -// authMiddleware, err := v3.NewSpnegoKrb5AuthenticateMiddleware(&config.Config{ -// KeytabLookup: keytabLookup, -// }) -// if err != nil { -// panic(err) -// } -// -// // Apply middleware to protected routes -// app.Use("/protected", authMiddleware) -// -// // Access authenticated identity -// app.Get("/protected/resource", func(c fiber.Ctx) error { -// identity, ok := v3.GetAuthenticatedIdentityFromContext(c) -// if !ok { -// return c.Status(fiber.StatusUnauthorized).SendString("Unauthorized") -// } -// return c.SendString(fmt.Sprintf("Hello, %s!", identity.UserName())) -// }) -// -// app.Listen(":3000") -// } +// authentication middleware for Fiber applications. +// It enables Kerberos authentication for HTTP requests. package spnego diff --git a/spnego/error.go b/spnego/error.go new file mode 100644 index 000000000..4d05705bf --- /dev/null +++ b/spnego/error.go @@ -0,0 +1,18 @@ +package spnego + +import "errors" + +// ErrConfigInvalidOfKeytabLookupFunctionRequired is returned when the KeytabLookup function is not set in Config +var ErrConfigInvalidOfKeytabLookupFunctionRequired = errors.New("config invalid: keytab lookup function is required") + +// ErrLookupKeytabFailed is returned when the keytab lookup fails +var ErrLookupKeytabFailed = errors.New("keytab lookup failed") + +// ErrConvertRequestFailed is returned when the request conversion to HTTP request fails +var ErrConvertRequestFailed = errors.New("convert request failed") + +// ErrConfigInvalidOfAtLeastOneKeytabFileRequired is returned when no keytab files are provided +var ErrConfigInvalidOfAtLeastOneKeytabFileRequired = errors.New("config invalid: at least one keytab file required") + +// ErrLoadKeytabFileFailed is returned when load keytab files failed +var ErrLoadKeytabFileFailed = errors.New("load keytab failed") diff --git a/spnego/example.go b/spnego/example/example.go similarity index 64% rename from spnego/example.go rename to spnego/example/example.go index 0d262f1c0..7c85d1a50 100644 --- a/spnego/example.go +++ b/spnego/example/example.go @@ -1,10 +1,11 @@ -package spnego +package example import ( "fmt" "time" - "github.com/gofiber/contrib/spnego/config" + "github.com/gofiber/contrib/spnego" + "github.com/gofiber/contrib/spnego/utils" v3 "github.com/gofiber/contrib/spnego/v3" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/log" @@ -12,11 +13,27 @@ import ( func ExampleNewSpnegoKrb5AuthenticateMiddleware() { app := fiber.New() - keytabLookup, err := config.NewKeytabFileLookupFunc("/keytabFile/one.keytab", "/keytabFile/two.keyta") + // create mock keytab file + // you must use a real keytab file + _, clean, err := utils.NewMockKeytab( + utils.WithPrincipal("HTTP/sso1.example.com"), + utils.WithRealm("EXAMPLE.LOCAL"), + utils.WithFilename("./temp-sso1.keytab"), + utils.WithPairs(utils.EncryptTypePair{ + Version: 2, + EncryptType: 18, + CreateTime: time.Now(), + }), + ) + if err != nil { + log.Fatalf("create mock keytab error: %v", err) + } + defer clean() + keytabLookup, err := spnego.NewKeytabFileLookupFunc("./temp-sso1.keytab") if err != nil { panic(fmt.Errorf("create keytab lookup function failed: %w", err)) } - authMiddleware, err := v3.NewSpnegoKrb5AuthenticateMiddleware(&config.Config{ + authMiddleware, err := v3.NewSpnegoKrb5AuthenticateMiddleware(spnego.Config{ KeytabLookup: keytabLookup, }) if err != nil { @@ -27,7 +44,7 @@ func ExampleNewSpnegoKrb5AuthenticateMiddleware() { // Access authenticated identity app.Get("/protected/resource", func(c fiber.Ctx) error { - identity, ok := v3.GetAuthenticatedIdentityFromContext(c) + identity, ok := spnego.GetAuthenticatedIdentityFromContext(c) if !ok { return c.Status(fiber.StatusUnauthorized).SendString("Unauthorized") } diff --git a/spnego/identity.go b/spnego/identity.go new file mode 100644 index 000000000..f3c6c18b0 --- /dev/null +++ b/spnego/identity.go @@ -0,0 +1,28 @@ +package spnego + +import "github.com/jcmturner/goidentity/v6" + +type FiberContext interface { + Locals(key any, value ...any) any +} + +// SetAuthenticatedIdentityToContext stores the authenticated identity in the Fiber context. +// It takes a Fiber context and an identity, and sets it using the contextKeyOfIdentity key +// for later retrieval by other handlers in the request chain. +func SetAuthenticatedIdentityToContext[T FiberContext](ctx T, identity goidentity.Identity) { + ctx.Locals(contextKeyOfIdentity, identity) +} + +// GetAuthenticatedIdentityFromContext retrieves the authenticated identity from the Fiber context. +// It returns the identity and a boolean indicating if it was found. +// This function should be used by subsequent handlers to access the authenticated user's information. +// Example: +// +// user, ok := GetAuthenticatedIdentityFromContext(ctx) +// if ok { +// fmt.Printf("Authenticated user: %s\n", user.UserName()) +// } +func GetAuthenticatedIdentityFromContext[T FiberContext](ctx T) (goidentity.Identity, bool) { + id, ok := ctx.Locals(contextKeyOfIdentity).(goidentity.Identity) + return id, ok +} diff --git a/spnego/identity_test.go b/spnego/identity_test.go new file mode 100644 index 000000000..9a8eeb36c --- /dev/null +++ b/spnego/identity_test.go @@ -0,0 +1,57 @@ +package spnego + +import ( + "net/http/httptest" + "testing" + + fiberV2 "github.com/gofiber/fiber/v2" + fiberV3 "github.com/gofiber/fiber/v3" + "github.com/jcmturner/goidentity/v6" + "github.com/stretchr/testify/require" +) + +func TestGetAndSetAuthenticatedIdentityFromContextForFiberV2(t *testing.T) { + app := fiberV2.New() + id := goidentity.NewUser("test@TEST.LOCAL") + app.Use("/identity", func(ctx *fiberV2.Ctx) error { + SetAuthenticatedIdentityToContext(ctx, &id) + return ctx.Next() + }) + app.Get("/test", func(ctx *fiberV2.Ctx) error { + _, ok := GetAuthenticatedIdentityFromContext(ctx) + require.False(t, ok) + return ctx.SendStatus(fiberV2.StatusOK) + }) + app.Get("/identity/test", func(ctx *fiberV2.Ctx) error { + user, ok := GetAuthenticatedIdentityFromContext(ctx) + require.True(t, ok) + require.Equal(t, id.UserName(), user.UserName()) + require.Equal(t, id.Domain(), user.Domain()) + return ctx.SendStatus(fiberV2.StatusOK) + }) + app.Test(httptest.NewRequest("GET", "/test", nil)) + app.Test(httptest.NewRequest("GET", "/identity/test", nil)) +} + +func TestGetAndSetAuthenticatedIdentityFromContextForFiberV3(t *testing.T) { + app := fiberV3.New() + id := goidentity.NewUser("test@TEST.LOCAL") + app.Use("/identity", func(ctx fiberV3.Ctx) error { + SetAuthenticatedIdentityToContext(ctx, &id) + return ctx.Next() + }) + app.Get("/test", func(ctx fiberV3.Ctx) error { + _, ok := GetAuthenticatedIdentityFromContext(ctx) + require.False(t, ok) + return ctx.SendStatus(fiberV3.StatusOK) + }) + app.Get("/identity/test", func(ctx fiberV3.Ctx) error { + user, ok := GetAuthenticatedIdentityFromContext(ctx) + require.True(t, ok) + require.Equal(t, id.UserName(), user.UserName()) + require.Equal(t, id.Domain(), user.Domain()) + return ctx.SendStatus(fiberV3.StatusOK) + }) + app.Test(httptest.NewRequest("GET", "/test", nil)) + app.Test(httptest.NewRequest("GET", "/identity/test", nil)) +} diff --git a/spnego/utils/adapter.go b/spnego/utils/adapter.go new file mode 100644 index 000000000..674f4239c --- /dev/null +++ b/spnego/utils/adapter.go @@ -0,0 +1,57 @@ +package utils + +import ( + "net/http" + + "github.com/valyala/fasthttp" +) + +// FiberContext defines the minimal interface required from a Fiber context +// for the adapter to function properly. +// T represents the type of the Fiber context (v2 or v3 compatible). +type FiberContext[T any] interface { + // Response returns the underlying fasthttp.Response + Response() *fasthttp.Response + // Write writes bytes to the response body + Write(bytes []byte) (int, error) + // Status sets the HTTP status code and returns the context + Status(status int) T +} + +// WrapFiberContextAdaptHttpResponseWriter adapts a Fiber context to the http.ResponseWriter interface +// This allows Fiber to work with libraries that expect the standard http.ResponseWriter +// T represents the type of the Fiber context (v2 or v3 compatible). +type WrapFiberContextAdaptHttpResponseWriter[T FiberContext[T]] struct { + ctx T +} + +// Header returns the response headers from the Fiber context +// in the standard http.Header format +// note: write header must using fiber context +func (f *WrapFiberContextAdaptHttpResponseWriter[T]) Header() http.Header { + headers := http.Header{} + for k, v := range f.ctx.Response().Header.All() { + headers.Set(string(k), string(v)) + } + return headers +} + +// Write writes bytes to the response body using the Fiber context's Write method +func (f *WrapFiberContextAdaptHttpResponseWriter[T]) Write(bytes []byte) (int, error) { + return f.ctx.Write(bytes) +} + +// WriteHeader sets the HTTP status code using the Fiber context's Status method +func (f *WrapFiberContextAdaptHttpResponseWriter[T]) WriteHeader(statusCode int) { + f.ctx.Status(statusCode) +} + +// NewWrapFiberContext creates a new adapter instance that wraps a Fiber context +// to implement the http.ResponseWriter interface +// ctx: The Fiber context to wrap +// Returns: A new adapter instance +func NewWrapFiberContext[T FiberContext[T]](ctx T) *WrapFiberContextAdaptHttpResponseWriter[T] { + return &WrapFiberContextAdaptHttpResponseWriter[T]{ + ctx: ctx, + } +} diff --git a/spnego/utils/adapter_test.go b/spnego/utils/adapter_test.go new file mode 100644 index 000000000..4044b9e63 --- /dev/null +++ b/spnego/utils/adapter_test.go @@ -0,0 +1,61 @@ +package utils + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "strconv" + "testing" + "time" + + v2Fiber "github.com/gofiber/fiber/v2" + v2Adaptor "github.com/gofiber/fiber/v2/middleware/adaptor" + v3Fiber "github.com/gofiber/fiber/v3" + v3Adaptor "github.com/gofiber/fiber/v3/middleware/adaptor" + "github.com/stretchr/testify/require" +) + +func TestNewWrapFiberContextOfFiverV2(t *testing.T) { + app := v2Fiber.New() + reqId := strconv.FormatInt(time.Now().UnixNano(), 16) + app.Get("/test", func(ctx *v2Fiber.Ctx) error { + ctx.Response().Header.Set("X-Request-Id", reqId) + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + require.Equal(t, reqId, w.Header().Get("X-Request-Id")) + w.Write([]byte(fmt.Sprintf("reqId: %s", reqId))) + }) + rawReq, err := v2Adaptor.ConvertRequest(ctx, true) + require.NoError(t, err) + handler.ServeHTTP(NewWrapFiberContext(ctx), rawReq) + return nil + }) + resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/test", nil)) + require.NoError(t, err) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, fmt.Sprintf("reqId: %s", reqId), string(body)) +} + +func TestNewWrapFiberContextOfFiverV3(t *testing.T) { + app := v3Fiber.New() + reqId := strconv.FormatInt(time.Now().UnixNano(), 16) + app.Get("/test", func(ctx v3Fiber.Ctx) error { + ctx.Response().Header.Set("X-Request-Id", reqId) + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + require.Equal(t, reqId, w.Header().Get("X-Request-Id")) + w.Write([]byte(fmt.Sprintf("reqId: %s", reqId))) + }) + rawReq, err := v3Adaptor.ConvertRequest(ctx, true) + require.NoError(t, err) + handler.ServeHTTP(NewWrapFiberContext(ctx), rawReq) + return nil + }) + resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/test", nil)) + require.NoError(t, err) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, fmt.Sprintf("reqId: %s", reqId), string(body)) +} diff --git a/spnego/utils/keytab.go b/spnego/utils/keytab.go new file mode 100644 index 000000000..45ce82605 --- /dev/null +++ b/spnego/utils/keytab.go @@ -0,0 +1,77 @@ +package utils + +import ( + "time" + + "maps" + "sort" + + "github.com/jcmturner/gokrb5/v8/keytab" +) + +// KeytabInfo represents information about a principal in a Kerberos keytab +// It contains the principal name, realm, and associated encryption type pairs +type KeytabInfo struct { + PrincipalName string // The Kerberos principal name (e.g., HTTP/service.example.com) + Realm string // The Kerberos realm (e.g., EXAMPLE.COM) + Pairs []EncryptTypePair // List of encryption type pairs for this principal +} + +// EncryptTypePair represents an encryption type entry in a Kerberos keytab +// It contains the version, encryption type, and creation timestamp +type EncryptTypePair struct { + Version uint8 // The key version number + EncryptType int32 // The encryption type (e.g., 18 for AES-256-CTS-HMAC-SHA1-96) + CreateTime time.Time // The timestamp when this key was created +} + +// MultiKeytabInfo is a slice of KeytabInfo structures +// Used to represent multiple principal entries from a keytab +type MultiKeytabInfo []KeytabInfo + +// GetKeytabInfo extracts information from a Kerberos keytab and returns it in a structured format +// It organizes keytab entries by principal name and sorts them alphabetically +// +// Parameters: +// kt - A pointer to a keytab.Keytab instance (can be nil) +// +// Returns: +// MultiKeytabInfo - A sorted slice of KeytabInfo structures containing principal information +// +// Example usage: +// kt, _ := keytab.Load("/path/to/keytab") +// info := GetKeytabInfo(kt) +// for _, principal := range info { +// fmt.Printf("Principal: %s@%s\n", principal.PrincipalName, principal.Realm) +// for _, pair := range principal.Pairs { +// fmt.Printf(" EncryptType: %d, Version: %d, Created: %v\n", pair.EncryptType, pair.Version, pair.CreateTime) +// } +// } +func GetKeytabInfo(kt *keytab.Keytab) MultiKeytabInfo { + keytabMap := make(map[string]KeytabInfo) + if kt != nil { + for _, entry := range kt.Entries { + item, ok := keytabMap[entry.Principal.String()] + if !ok { + item = KeytabInfo{ + PrincipalName: entry.Principal.String(), + Realm: entry.Principal.Realm, + } + } + item.Pairs = append(item.Pairs, EncryptTypePair{ + Version: entry.KVNO8, + EncryptType: entry.Key.KeyType, + CreateTime: entry.Timestamp, + }) + keytabMap[entry.Principal.String()] = item + } + } + var mk = make(MultiKeytabInfo, 0, len(keytabMap)) + for item := range maps.Values(keytabMap) { + mk = append(mk, item) + } + sort.Slice(mk, func(i, j int) bool { + return mk[i].PrincipalName < mk[j].PrincipalName + }) + return mk +} diff --git a/spnego/utils/keytab_test.go b/spnego/utils/keytab_test.go new file mode 100644 index 000000000..7411c5c70 --- /dev/null +++ b/spnego/utils/keytab_test.go @@ -0,0 +1,46 @@ +package utils + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestGetKeytabInfo(t *testing.T) { + tm := time.Now() + kt, _, err := NewMockKeytab( + WithRealm("EXAMPLE.LOCAL"), + WithPrincipal("HTTP/sso-test.example.com"), + WithPassword("abcd1234"), + WithPairs(EncryptTypePair{ + Version: 3, + EncryptType: 17, + CreateTime: tm, + }, EncryptTypePair{ + Version: 3, + EncryptType: 18, + CreateTime: tm, + }), + ) + require.NoError(t, err) + err = kt.AddEntry("HTTP/sso-test2.example.com", "EXAMPLE.LOCAL", "qwer1234", tm.Add(-time.Minute), 2, 18) + require.NoError(t, err) + info := GetKeytabInfo(kt) + require.Len(t, info, 2) + require.Equal(t, info[0].PrincipalName, "HTTP/sso-test.example.com@EXAMPLE.LOCAL") + require.Equal(t, info[0].Realm, "EXAMPLE.LOCAL") + require.Len(t, info[0].Pairs, 2) + require.Equal(t, info[0].Pairs[0].Version, uint8(3)) + require.Equal(t, info[0].Pairs[0].EncryptType, int32(17)) + require.Equal(t, info[0].Pairs[0].CreateTime.Unix(), tm.Unix()) + require.Equal(t, info[0].Pairs[1].Version, uint8(3)) + require.Equal(t, info[0].Pairs[1].EncryptType, int32(18)) + require.Equal(t, info[0].Pairs[1].CreateTime.Unix(), tm.Unix()) + require.Equal(t, info[1].PrincipalName, "HTTP/sso-test2.example.com@EXAMPLE.LOCAL") + require.Equal(t, info[1].Realm, "EXAMPLE.LOCAL") + require.Len(t, info[1].Pairs, 1) + require.Equal(t, info[1].Pairs[0].Version, uint8(2)) + require.Equal(t, info[1].Pairs[0].EncryptType, int32(18)) + require.Equal(t, info[1].Pairs[0].CreateTime.Unix(), tm.Add(-time.Minute).Unix()) +} diff --git a/spnego/utils/mock_keytab.go b/spnego/utils/mock_keytab.go new file mode 100644 index 000000000..4c576d389 --- /dev/null +++ b/spnego/utils/mock_keytab.go @@ -0,0 +1,150 @@ +package utils + +import ( + "fmt" + "os" + + "github.com/jcmturner/gokrb5/v8/keytab" +) + +// mockOptions contains configuration parameters for creating a mock keytab +// It allows customization of principal name, realm, password, filename, and encryption type pairs +// used for testing SPNEGO authentication middleware +type mockOptions struct { + PrincipalName string // Kerberos principal name + Realm string // Kerberos realm + Password string // Password for generating encryption keys + Filename string // Optional filename to write the mock keytab + Pairs []EncryptTypePair // Encryption type pairs to add to the keytab +} + +// apply applies the given options to the mockOptions +// This method iterates over all provided options and applies them in sequence +// allowing for flexible configuration of the mock keytab +func (m *mockOptions) apply(opts ...MockOption) { + for _, opt := range opts { + opt(m) + } +} + +// WithPrincipal sets the Kerberos principal name for the mock keytab +// Example: WithPrincipal("HTTP/service.example.com") +func WithPrincipal(principalName string) MockOption { + return func(options *mockOptions) { + options.PrincipalName = principalName + } +} + +// WithRealm sets the Kerberos realm for the mock keytab +// Example: WithRealm("EXAMPLE.COM") +func WithRealm(realm string) MockOption { + return func(options *mockOptions) { + options.Realm = realm + } +} + +// WithFilename specifies the filename to write the mock keytab to +// If provided, the keytab will be written to this file +// Example: WithFilename("test.keytab") +func WithFilename(filename string) MockOption { + return func(options *mockOptions) { + options.Filename = filename + } +} + +// WithPairs adds encryption type pairs to the mock keytab +// Each pair specifies an encryption type and associated parameters +// Example: WithPairs(EncryptTypePair{EncryptType: 18, CreateTime: time.Now(), Version: 1}) +func WithPairs(pairs ...EncryptTypePair) MockOption { + return func(options *mockOptions) { + options.Pairs = append(options.Pairs, pairs...) + } +} + +// WithPassword sets the password used to generate encryption keys +// This password is used with the principal name and realm to create keys +// Example: WithPassword("securePassword123") +func WithPassword(password string) MockOption { + return func(options *mockOptions) { + options.Password = password + } +} + +// MockOption defines a function type for configuring mockOptions +// Used to implement the option pattern for flexible configuration +type MockOption func(*mockOptions) + +// newDefaultMockOptions creates mockOptions with default values +// Default realm is "TEST.LOCAL" and default password is "abcdef" +// These defaults can be overridden using WithXXX option functions +func newDefaultMockOptions() *mockOptions { + return &mockOptions{ + Realm: "TEST.LOCAL", + Password: "abcdef", + } +} + +type fileOperator interface { + OpenFile(filename string, flag int, perm os.FileMode) (*os.File, error) + Remove(filename string) error +} + +type myFileOperator struct{} + +func (m myFileOperator) OpenFile(filename string, flag int, perm os.FileMode) (*os.File, error) { + return os.OpenFile(filename, flag, perm) +} + +func (m myFileOperator) Remove(filename string) error { + return os.Remove(filename) +} + +var defaultFileOperator fileOperator = myFileOperator{} + +// NewMockKeytab creates a mock keytab for testing purposes +// It allows customization through option functions and returns: +// - A keytab.Keytab instance populated with the specified entries +// - A cleanup function that removes any created files +// - An error if the keytab creation fails +// +// Example usage: +// +// kt, cleanup, err := NewMockKeytab( +// WithPrincipal("HTTP/service.example.com"), +// WithRealm("EXAMPLE.COM"), +// WithPassword("secret"), +// WithFilename("test.keytab"), +// WithPairs(EncryptTypePair{EncryptType: 18}) +// ) +// defer cleanup() +// if err != nil { +// // handle error +// } +func NewMockKeytab(opts ...MockOption) (*keytab.Keytab, func(), error) { + opt := newDefaultMockOptions() + opt.apply(opts...) + kt := keytab.New() + var err error + for _, pair := range opt.Pairs { + if err = kt.AddEntry(opt.PrincipalName, opt.Realm, opt.Password, pair.CreateTime, pair.Version, pair.EncryptType); err != nil { + return nil, nil, fmt.Errorf("error adding entry: %v", err) + } + } + var clean = func() {} + if len(opt.Filename) > 0 { + file, err := defaultFileOperator.OpenFile(opt.Filename, os.O_RDWR|os.O_CREATE, 0o666) + if err != nil { + return nil, nil, fmt.Errorf("error opening file: %w", err) + } + clean = func() { + _ = defaultFileOperator.Remove(opt.Filename) + } + if _, err = kt.Write(file); err != nil { + clean() + return nil, nil, fmt.Errorf("error writing to file: %w", err) + } + _ = file.Close() + return kt, clean, nil + } + return kt, clean, nil +} diff --git a/spnego/utils/mock_keytab_test.go b/spnego/utils/mock_keytab_test.go new file mode 100644 index 000000000..c2114adfa --- /dev/null +++ b/spnego/utils/mock_keytab_test.go @@ -0,0 +1,185 @@ +package utils + +import ( + "os" + "testing" + "time" + + "github.com/jcmturner/gokrb5/v8/keytab" + "github.com/jcmturner/gokrb5/v8/types" + "github.com/stretchr/testify/require" +) + +type mockFileOperator struct { + flag int +} + +func (m mockFileOperator) OpenFile(filename string, flag int, perm os.FileMode) (*os.File, error) { + if m.flag&0x01 != 0 { + return nil, os.ErrPermission + } + file, err := os.OpenFile(filename, flag, perm) + if err != nil { + return nil, err + } + if m.flag&0x02 != 0 { + file.Close() + } + return file, nil +} + +func (m mockFileOperator) Remove(filename string) error { + return os.Remove(filename) +} + +func TestNewMockKeytab(t *testing.T) { + t.Run("test add keytab entry failed", func(t *testing.T) { + _, _, err := NewMockKeytab( + WithPrincipal("HTTP/sso.example.com"), + WithRealm("TEST.LOCAL"), + WithPairs(EncryptTypePair{ + Version: 3, + EncryptType: 18, + CreateTime: time.Now(), + }, EncryptTypePair{ + Version: 3, + EncryptType: 0xffff, + CreateTime: time.Now(), + }), + ) + require.Error(t, err) + }) + t.Run("test none file created", func(t *testing.T) { + tm := time.Now() + kt, clean, err := NewMockKeytab( + WithPrincipal("HTTP/sso.example.com"), + WithRealm("TEST.LOCAL"), + WithPairs(EncryptTypePair{ + Version: 3, + EncryptType: 18, + CreateTime: tm, + }), + ) + require.NoError(t, err) + t.Cleanup(clean) + _, kv, err := kt.GetEncryptionKey(types.NewPrincipalName(1, "HTTP/sso.example.com"), "TEST.LOCAL", 3, 18) + require.NoError(t, err) + require.Equal(t, 3, kv) + }) + t.Run("test file open failed", func(t *testing.T) { + defaultFileOperator = mockFileOperator{flag: 0x01} + _, _, err := NewMockKeytab( + WithPrincipal("HTTP/sso.example.com"), + WithRealm("TEST.LOCAL"), + WithPairs(EncryptTypePair{ + Version: 3, + EncryptType: 18, + CreateTime: time.Now(), + }), + WithFilename("./temp.keytab"), + ) + require.ErrorIs(t, err, os.ErrPermission) + require.NoFileExists(t, "./temp.keytab") + }) + t.Run("test file write failed", func(t *testing.T) { + defaultFileOperator = mockFileOperator{flag: 0x02} + _, _, err := NewMockKeytab( + WithPrincipal("HTTP/sso.example.com"), + WithRealm("TEST.LOCAL"), + WithPairs(EncryptTypePair{ + Version: 3, + EncryptType: 18, + CreateTime: time.Now(), + }), + WithFilename("./temp.keytab"), + ) + require.ErrorIs(t, err, os.ErrClosed) + require.NoFileExists(t, "./temp.keytab") + }) + t.Run("test file created", func(t *testing.T) { + defaultFileOperator = myFileOperator{} + tm := time.Now() + _, clean, err := NewMockKeytab( + WithPrincipal("HTTP/sso.example.com"), + WithRealm("TEST.LOCAL"), + WithPairs(EncryptTypePair{ + Version: 3, + EncryptType: 18, + CreateTime: tm, + }), + WithFilename("./temp.keytab"), + ) + require.NoError(t, err) + t.Cleanup(clean) + require.FileExists(t, "./temp.keytab") + kt, err := keytab.Load("./temp.keytab") + require.NoError(t, err) + _, kv, err := kt.GetEncryptionKey(types.NewPrincipalName(1, "HTTP/sso.example.com"), "TEST.LOCAL", 3, 18) + require.NoError(t, err) + require.Equal(t, 3, kv) + }) +} + +func TestWithFilename(t *testing.T) { + opts := mockOptions{} + require.Empty(t, opts.Filename) + WithFilename("/tmp/test.keytab")(&opts) + require.Equal(t, "/tmp/test.keytab", opts.Filename) +} + +func TestWithPairs(t *testing.T) { + opts := mockOptions{} + tm := time.Now() + require.Len(t, opts.Pairs, 0) + WithPairs(EncryptTypePair{ + Version: 2, + EncryptType: 17, + CreateTime: tm.Add(-time.Minute), + }, EncryptTypePair{ + Version: 2, + EncryptType: 18, + CreateTime: tm.Add(-time.Minute), + })(&opts) + require.Len(t, opts.Pairs, 2) + WithPairs(EncryptTypePair{ + Version: 3, + EncryptType: 18, + CreateTime: tm, + })(&opts) + require.Len(t, opts.Pairs, 3) + require.Equal(t, opts.Pairs, []EncryptTypePair{ + {Version: 2, EncryptType: 17, CreateTime: tm.Add(-time.Minute)}, + {Version: 2, EncryptType: 18, CreateTime: tm.Add(-time.Minute)}, + {Version: 3, EncryptType: 18, CreateTime: tm}, + }) +} + +func TestWithPassword(t *testing.T) { + opts := mockOptions{} + require.Empty(t, opts.Password) + WithPassword("abcd1234")(&opts) + require.Equal(t, "abcd1234", opts.Password) +} + +func TestWithPrincipal(t *testing.T) { + opts := mockOptions{} + require.Empty(t, opts.PrincipalName) + WithPrincipal("HTTP/sso.example.local")(&opts) + require.Equal(t, "HTTP/sso.example.local", opts.PrincipalName) +} + +func TestWithRealm(t *testing.T) { + opts := mockOptions{} + require.Empty(t, opts.Realm) + WithRealm("EXAMPLE.LOCAL")(&opts) + require.Equal(t, "EXAMPLE.LOCAL", opts.Realm) +} + +func Test_mockOptions_apply(t *testing.T) { + opts := mockOptions{} + require.Empty(t, opts.Filename) + require.Empty(t, opts.Realm) + opts.apply(WithFilename("/tmp/test.keytab"), WithRealm("TEST.LOCAL")) + require.Equal(t, "/tmp/test.keytab", opts.Filename) + require.Equal(t, "TEST.LOCAL", opts.Realm) +} diff --git a/spnego/v2/spnego.go b/spnego/v2/spnego.go index 22c30ae08..1c45cb6cf 100644 --- a/spnego/v2/spnego.go +++ b/spnego/v2/spnego.go @@ -10,7 +10,8 @@ import ( "net/http" "os" - "github.com/gofiber/contrib/spnego/config" + spnego2 "github.com/gofiber/contrib/spnego" + "github.com/gofiber/contrib/spnego/utils" "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/adaptor" "github.com/jcmturner/goidentity/v6" @@ -22,13 +23,10 @@ import ( // It takes a Config struct and returns a Fiber handler or an error. // The middleware handles Kerberos authentication for incoming requests using the // SPNEGO protocol, verifying client credentials against the configured keytab. -func NewSpnegoKrb5AuthenticateMiddleware(cfg *config.Config) (fiber.Handler, error) { +func NewSpnegoKrb5AuthenticateMiddleware(cfg spnego2.Config) (fiber.Handler, error) { // Validate configuration - if cfg == nil { - cfg = &config.Config{} - } if cfg.KeytabLookup == nil { - return nil, config.ErrConfigInvalidOfKeytabLookupFunctionRequired + return nil, spnego2.ErrConfigInvalidOfKeytabLookupFunctionRequired } // Set default logger if not provided if cfg.Log == nil { @@ -40,66 +38,23 @@ func NewSpnegoKrb5AuthenticateMiddleware(cfg *config.Config) (fiber.Handler, err // Look up the keytab kt, err := cfg.KeytabLookup() if err != nil { - return fmt.Errorf("%w: %w", config.ErrLookupKeytabFailed, err) + return fmt.Errorf("%w: %w", spnego2.ErrLookupKeytabFailed, err) } // Create the SPNEGO handler using the keytab var handleErr error handler := spnego.SPNEGOKRB5Authenticate(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { // Set the authenticated identity in the Fiber context - setAuthenticatedIdentityToContext(ctx, goidentity.FromHTTPRequestContext(r)) + spnego2.SetAuthenticatedIdentityToContext(ctx, goidentity.FromHTTPRequestContext(r)) // Call the next handler in the chain handleErr = ctx.Next() }), kt, service.Logger(cfg.Log)) // Convert Fiber context to HTTP request rawReq, err := adaptor.ConvertRequest(ctx, true) if err != nil { - return fmt.Errorf("%w: %w", config.ErrConvertRequestFailed, err) + return fmt.Errorf("%w: %w", spnego2.ErrConvertRequestFailed, err) } // Serve the request using the SPNEGO handler - handler.ServeHTTP(wrapCtx{ctx}, rawReq) + handler.ServeHTTP(utils.NewWrapFiberContext(ctx), rawReq) return handleErr }, nil } - -// setAuthenticatedIdentityToContext stores the authenticated identity in the Fiber context. -// It takes a Fiber context pointer and an identity, and sets it using the ContextKeyOfIdentity key -// for later retrieval by other handlers in the request chain. -func setAuthenticatedIdentityToContext(ctx *fiber.Ctx, identity goidentity.Identity) { - ctx.Locals(config.ContextKeyOfIdentity, identity) -} - -// GetAuthenticatedIdentityFromContext retrieves the authenticated identity from the Fiber context. -// It returns the identity and a boolean indicating if it was found. -// This function should be used by subsequent handlers to access the authenticated user's information. -// -// Example: -// -// user, ok := GetAuthenticatedIdentityFromContext(ctx) -// if ok { -// fmt.Printf("Authenticated user: %s\n", user.UserName()) -// } -func GetAuthenticatedIdentityFromContext(ctx *fiber.Ctx) (goidentity.Identity, bool) { - id, ok := ctx.Locals(config.ContextKeyOfIdentity).(goidentity.Identity) - return id, ok -} - -// wrapCtx wraps a Fiber context pointer to implement the http.ResponseWriter interface. -// This adapter allows the Fiber context to be used with standard HTTP handlers -// that expect an http.ResponseWriter, bridging the gap between Fiber's context -// model and the standard library's HTTP interfaces. - -type wrapCtx struct { - *fiber.Ctx -} - -// Header returns the request headers from the wrapped Fiber context. -// This method implements the http.ResponseWriter interface. -func (w wrapCtx) Header() http.Header { - return w.Ctx.GetReqHeaders() -} - -// WriteHeader sets the HTTP status code on the wrapped Fiber context. -// This method implements the http.ResponseWriter interface. -func (w wrapCtx) WriteHeader(statusCode int) { - w.Ctx.Status(statusCode) -} diff --git a/spnego/v2/spnego_test.go b/spnego/v2/spnego_test.go index cc49a1681..e9a1c2205 100644 --- a/spnego/v2/spnego_test.go +++ b/spnego/v2/spnego_test.go @@ -4,14 +4,12 @@ import ( "errors" "fmt" "net/http" - "net/http/httptest" - "os" "testing" "time" - "github.com/gofiber/contrib/spnego/config" + "github.com/gofiber/contrib/spnego" + "github.com/gofiber/contrib/spnego/utils" "github.com/gofiber/fiber/v2" - "github.com/jcmturner/goidentity/v6" "github.com/jcmturner/gokrb5/v8/keytab" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" @@ -19,11 +17,11 @@ import ( func TestNewSpnegoKrb5AuthenticateMiddleware(t *testing.T) { t.Run("test for keytab lookup function not set", func(t *testing.T) { - _, err := NewSpnegoKrb5AuthenticateMiddleware(nil) - require.ErrorIs(t, err, config.ErrConfigInvalidOfKeytabLookupFunctionRequired) + _, err := NewSpnegoKrb5AuthenticateMiddleware(spnego.Config{}) + require.ErrorIs(t, err, spnego.ErrConfigInvalidOfKeytabLookupFunctionRequired) }) t.Run("test for keytab lookup failed", func(t *testing.T) { - middleware, err := NewSpnegoKrb5AuthenticateMiddleware(&config.Config{ + middleware, err := NewSpnegoKrb5AuthenticateMiddleware(spnego.Config{ KeytabLookup: func() (*keytab.Keytab, error) { return nil, errors.New("mock keytab lookup error") }, @@ -39,30 +37,50 @@ func TestNewSpnegoKrb5AuthenticateMiddleware(t *testing.T) { ctx.Request.SetRequestURI("/authenticate") handler(ctx) require.Equal(t, http.StatusInternalServerError, ctx.Response.StatusCode()) - require.Equal(t, fmt.Sprintf("%s: mock keytab lookup error", config.ErrLookupKeytabFailed), string(ctx.Response.Body())) + require.Equal(t, fmt.Sprintf("%s: mock keytab lookup error", spnego.ErrLookupKeytabFailed), string(ctx.Response.Body())) }) t.Run("test for keytab lookup function is set", func(t *testing.T) { - var keytabFiles []string - for i := 0; i < 5; i++ { - kt, clean, err := newKeytabTempFile(fmt.Sprintf("HTTP/sso%d.example.com", i), "KRB5.TEST", 18, 19) - require.NoError(t, err) - t.Cleanup(clean) - keytabFiles = append(keytabFiles, kt) - } - lookupFunc, err := config.NewKeytabFileLookupFunc(keytabFiles...) + tm := time.Now() + _, clean1, err1 := utils.NewMockKeytab( + utils.WithPrincipal("HTTP/sso1.example.com"), + utils.WithRealm("EXAMPLE.LOCAL"), + utils.WithFilename("./temp-sso1.keytab"), + utils.WithPairs(utils.EncryptTypePair{ + Version: 2, + EncryptType: 18, + CreateTime: tm, + }), + ) + require.NoError(t, err1) + t.Cleanup(clean1) + _, clean2, err2 := utils.NewMockKeytab( + utils.WithPrincipal("HTTP/sso2.example.com"), + utils.WithRealm("EXAMPLE.LOCAL"), + utils.WithFilename("./temp-sso2.keytab"), + utils.WithPairs(utils.EncryptTypePair{ + Version: 2, + EncryptType: 18, + CreateTime: tm, + }), + ) + require.NoError(t, err2) + t.Cleanup(clean2) + lookupFunc, err := spnego.NewKeytabFileLookupFunc("./temp-sso1.keytab", "./temp-sso2.keytab") require.NoError(t, err) - middleware, err := NewSpnegoKrb5AuthenticateMiddleware(&config.Config{ + middleware, err := NewSpnegoKrb5AuthenticateMiddleware(spnego.Config{ KeytabLookup: lookupFunc, }) require.NoError(t, err) app := fiber.New() app.Get("/authenticate", middleware, func(c *fiber.Ctx) error { - user, ok := GetAuthenticatedIdentityFromContext(c) + user, ok := spnego.GetAuthenticatedIdentityFromContext(c) + require.True(t, ok) if ok { t.Logf("username: %s\ndomain: %s\n", user.UserName(), user.Domain()) } return c.SendString("authenticated") }) + handler := app.Handler() ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) @@ -71,92 +89,3 @@ func TestNewSpnegoKrb5AuthenticateMiddleware(t *testing.T) { require.Equal(t, fasthttp.StatusUnauthorized, ctx.Response.StatusCode()) }) } - -func TestNewKeytabFileLookupFunc(t *testing.T) { - t.Run("test for empty keytab files", func(t *testing.T) { - _, err := config.NewKeytabFileLookupFunc() - require.ErrorIs(t, err, config.ErrConfigInvalidOfAtLeastOneKeytabFileRequired) - }) - t.Run("test for has invalid keytab file", func(t *testing.T) { - kt1, clean, err := newKeytabTempFile("HTTP/sso.example.com", "KRB5.TEST", 18, 19) - require.NoError(t, err) - t.Cleanup(clean) - kt2, clean, err := newBadKeytabTempFile("HTTP/sso1.example.com", "KRB5.TEST", 18, 19) - require.NoError(t, err) - t.Cleanup(clean) - _, err = config.NewKeytabFileLookupFunc(kt1, kt2) - require.ErrorIs(t, err, config.ErrLoadKeytabFileFailed) - }) - t.Run("test for some keytab files", func(t *testing.T) { - var keytabFiles []string - for i := 0; i < 5; i++ { - kt, clean, err := newKeytabTempFile(fmt.Sprintf("HTTP/sso%d.example.com", i), "KRB5.TEST", 18, 19) - require.NoError(t, err) - t.Cleanup(clean) - keytabFiles = append(keytabFiles, kt) - } - lookupFunc, err := config.NewKeytabFileLookupFunc(keytabFiles...) - require.NoError(t, err) - _, err = lookupFunc() - require.NoError(t, err) - }) -} - -func newBadKeytabTempFile(principal string, realm string, et ...int32) (filename string, clean func(), err error) { - filename = fmt.Sprintf("./tmp_%d.keytab", time.Now().Unix()) - clean = func() { - os.Remove(filename) - } - var kt keytab.Keytab - for _, e := range et { - kt.AddEntry(principal, realm, "abcdefg", time.Now(), 2, e) - } - file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, 0o666) - if err != nil { - return filename, clean, fmt.Errorf("open file failed: %w", err) - } - if _, err = kt.Write(file); err != nil { - return filename, clean, fmt.Errorf("write file failed: %w", err) - } - file.Close() - return filename, clean, nil -} - -func newKeytabTempFile(principal string, realm string, et ...int32) (filename string, clean func(), err error) { - filename = fmt.Sprintf("./tmp_%d.keytab", time.Now().Unix()) - clean = func() { - os.Remove(filename) - } - kt := keytab.New() - for _, e := range et { - kt.AddEntry(principal, realm, "abcdefg", time.Now(), 2, e) - } - file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, 0o666) - if err != nil { - return filename, clean, fmt.Errorf("open file failed: %w", err) - } - if _, err = kt.Write(file); err != nil { - return filename, clean, fmt.Errorf("write file failed: %w", err) - } - file.Close() - return filename, clean, nil -} - -func TestGetAuthenticatedIdentityFromContext(t *testing.T) { - app := fiber.New() - app.Use("/testContext", func(ctx *fiber.Ctx) error { - user := goidentity.NewUser("test-user") - user.SetDomain("example.com") - _, ok := GetAuthenticatedIdentityFromContext(ctx) - require.False(t, ok) - setAuthenticatedIdentityToContext(ctx, &user) - user1, ok := GetAuthenticatedIdentityFromContext(ctx) - require.True(t, ok) - require.Equal(t, user.UserName(), user1.UserName()) - require.Equal(t, user.Domain(), user1.Domain()) - return ctx.SendStatus(200) - }) - resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/testContext", nil)) - require.NoError(t, err) - require.Equal(t, resp.StatusCode, 200) -} diff --git a/spnego/v3/spnego.go b/spnego/v3/spnego.go index aa1dca0f8..7e0173c03 100644 --- a/spnego/v3/spnego.go +++ b/spnego/v3/spnego.go @@ -9,7 +9,8 @@ import ( "log" "net/http" - "github.com/gofiber/contrib/spnego/config" + spnego2 "github.com/gofiber/contrib/spnego" + "github.com/gofiber/contrib/spnego/utils" "github.com/gofiber/fiber/v3" flog "github.com/gofiber/fiber/v3/log" "github.com/gofiber/fiber/v3/middleware/adaptor" @@ -22,13 +23,10 @@ import ( // It takes a Config struct and returns a Fiber handler or an error. // The middleware handles Kerberos authentication for incoming requests using the // SPNEGO protocol, verifying client credentials against the configured keytab. -func NewSpnegoKrb5AuthenticateMiddleware(cfg *config.Config) (fiber.Handler, error) { +func NewSpnegoKrb5AuthenticateMiddleware(cfg spnego2.Config) (fiber.Handler, error) { // Validate configuration - if cfg == nil { - cfg = &config.Config{} - } if cfg.KeytabLookup == nil { - return nil, config.ErrConfigInvalidOfKeytabLookupFunctionRequired + return nil, spnego2.ErrConfigInvalidOfKeytabLookupFunctionRequired } // Set default logger if not provided if cfg.Log == nil { @@ -39,66 +37,23 @@ func NewSpnegoKrb5AuthenticateMiddleware(cfg *config.Config) (fiber.Handler, err // Look up the keytab kt, err := cfg.KeytabLookup() if err != nil { - return fmt.Errorf("%w: %w", config.ErrLookupKeytabFailed, err) + return fmt.Errorf("%w: %w", spnego2.ErrLookupKeytabFailed, err) } // Create the SPNEGO handler using the keytab var handleErr error handler := spnego.SPNEGOKRB5Authenticate(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { // Set the authenticated identity in the Fiber context - setAuthenticatedIdentityToContext(ctx, goidentity.FromHTTPRequestContext(r)) + spnego2.SetAuthenticatedIdentityToContext(ctx, goidentity.FromHTTPRequestContext(r)) // Call the next handler in the chain handleErr = ctx.Next() }), kt, service.Logger(cfg.Log)) // Convert Fiber context to HTTP request rawReq, err := adaptor.ConvertRequest(ctx, true) if err != nil { - return fmt.Errorf("%w: %w", config.ErrConvertRequestFailed, err) + return fmt.Errorf("%w: %w", spnego2.ErrConvertRequestFailed, err) } // Serve the request using the SPNEGO handler - handler.ServeHTTP(wrapCtx{ctx}, rawReq) + handler.ServeHTTP(utils.NewWrapFiberContext(ctx), rawReq) return handleErr }, nil } - -// setAuthenticatedIdentityToContext stores the authenticated identity in the Fiber context. -// It takes a Fiber context and an identity, and sets it using the ContextKeyOfIdentity key -// for later retrieval by other handlers in the request chain. -func setAuthenticatedIdentityToContext(ctx fiber.Ctx, identity goidentity.Identity) { - ctx.Locals(config.ContextKeyOfIdentity, identity) -} - -// GetAuthenticatedIdentityFromContext retrieves the authenticated identity from the Fiber context. -// It returns the identity and a boolean indicating if it was found. -// This function should be used by subsequent handlers to access the authenticated user's information. -// -// Example: -// -// user, ok := GetAuthenticatedIdentityFromContext(ctx) -// if ok { -// fmt.Printf("Authenticated user: %s\n", user.UserName()) -// } -func GetAuthenticatedIdentityFromContext(ctx fiber.Ctx) (goidentity.Identity, bool) { - id, ok := ctx.Locals(config.ContextKeyOfIdentity).(goidentity.Identity) - return id, ok -} - -// wrapCtx wraps a Fiber context to implement the http.ResponseWriter interface. -// This adapter allows the Fiber context to be used with standard HTTP handlers -// that expect an http.ResponseWriter, bridging the gap between Fiber's context -// model and the standard library's HTTP interfaces. - -type wrapCtx struct { - fiber.Ctx -} - -// Header returns the request headers from the wrapped Fiber context. -// This method implements the http.ResponseWriter interface. -func (w wrapCtx) Header() http.Header { - return w.Ctx.GetReqHeaders() -} - -// WriteHeader sets the HTTP status code on the wrapped Fiber context. -// This method implements the http.ResponseWriter interface. -func (w wrapCtx) WriteHeader(statusCode int) { - w.Ctx.Status(statusCode) -} diff --git a/spnego/v3/spnego_test.go b/spnego/v3/spnego_test.go index 0986950df..523c6f7f4 100644 --- a/spnego/v3/spnego_test.go +++ b/spnego/v3/spnego_test.go @@ -4,14 +4,12 @@ import ( "errors" "fmt" "net/http" - "net/http/httptest" - "os" "testing" "time" - "github.com/gofiber/contrib/spnego/config" + "github.com/gofiber/contrib/spnego" + "github.com/gofiber/contrib/spnego/utils" "github.com/gofiber/fiber/v3" - "github.com/jcmturner/goidentity/v6" "github.com/jcmturner/gokrb5/v8/keytab" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" @@ -19,11 +17,11 @@ import ( func TestNewSpnegoKrb5AuthenticateMiddleware(t *testing.T) { t.Run("test for keytab lookup function not set", func(t *testing.T) { - _, err := NewSpnegoKrb5AuthenticateMiddleware(nil) - require.ErrorIs(t, err, config.ErrConfigInvalidOfKeytabLookupFunctionRequired) + _, err := NewSpnegoKrb5AuthenticateMiddleware(spnego.Config{}) + require.ErrorIs(t, err, spnego.ErrConfigInvalidOfKeytabLookupFunctionRequired) }) t.Run("test for keytab lookup failed", func(t *testing.T) { - middleware, err := NewSpnegoKrb5AuthenticateMiddleware(&config.Config{ + middleware, err := NewSpnegoKrb5AuthenticateMiddleware(spnego.Config{ KeytabLookup: func() (*keytab.Keytab, error) { return nil, errors.New("mock keytab lookup error") }, @@ -39,25 +37,43 @@ func TestNewSpnegoKrb5AuthenticateMiddleware(t *testing.T) { ctx.Request.SetRequestURI("/authenticate") handler(ctx) require.Equal(t, http.StatusInternalServerError, ctx.Response.StatusCode()) - require.Equal(t, fmt.Sprintf("%s: mock keytab lookup error", config.ErrLookupKeytabFailed), string(ctx.Response.Body())) + require.Equal(t, fmt.Sprintf("%s: mock keytab lookup error", spnego.ErrLookupKeytabFailed), string(ctx.Response.Body())) }) t.Run("test for keytab lookup function is set", func(t *testing.T) { - var keytabFiles []string - for i := 0; i < 5; i++ { - kt, clean, err := newKeytabTempFile(fmt.Sprintf("HTTP/sso%d.example.com", i), "KRB5.TEST", 18, 19) - require.NoError(t, err) - t.Cleanup(clean) - keytabFiles = append(keytabFiles, kt) - } - lookupFunc, err := config.NewKeytabFileLookupFunc(keytabFiles...) + tm := time.Now() + _, clean1, err1 := utils.NewMockKeytab( + utils.WithPrincipal("HTTP/sso1.example.com"), + utils.WithRealm("EXAMPLE.LOCAL"), + utils.WithFilename("./temp-sso1.keytab"), + utils.WithPairs(utils.EncryptTypePair{ + Version: 2, + EncryptType: 18, + CreateTime: tm, + }), + ) + require.NoError(t, err1) + t.Cleanup(clean1) + _, clean2, err2 := utils.NewMockKeytab( + utils.WithPrincipal("HTTP/sso2.example.com"), + utils.WithRealm("EXAMPLE.LOCAL"), + utils.WithFilename("./temp-sso2.keytab"), + utils.WithPairs(utils.EncryptTypePair{ + Version: 2, + EncryptType: 18, + CreateTime: tm, + }), + ) + require.NoError(t, err2) + t.Cleanup(clean2) + lookupFunc, err := spnego.NewKeytabFileLookupFunc("./temp-sso1.keytab", "./temp-sso2.keytab") require.NoError(t, err) - middleware, err := NewSpnegoKrb5AuthenticateMiddleware(&config.Config{ + middleware, err := NewSpnegoKrb5AuthenticateMiddleware(spnego.Config{ KeytabLookup: lookupFunc, }) require.NoError(t, err) app := fiber.New() app.Get("/authenticate", middleware, func(c fiber.Ctx) error { - user, ok := GetAuthenticatedIdentityFromContext(c) + user, ok := spnego.GetAuthenticatedIdentityFromContext(c) if ok { t.Logf("username: %s\ndomain: %s\n", user.UserName(), user.Domain()) } @@ -71,92 +87,3 @@ func TestNewSpnegoKrb5AuthenticateMiddleware(t *testing.T) { require.Equal(t, fasthttp.StatusUnauthorized, ctx.Response.StatusCode()) }) } - -func TestNewKeytabFileLookupFunc(t *testing.T) { - t.Run("test for empty keytab files", func(t *testing.T) { - _, err := config.NewKeytabFileLookupFunc() - require.ErrorIs(t, err, config.ErrConfigInvalidOfAtLeastOneKeytabFileRequired) - }) - t.Run("test for has invalid keytab file", func(t *testing.T) { - kt1, clean, err := newKeytabTempFile("HTTP/sso.example.com", "KRB5.TEST", 18, 19) - require.NoError(t, err) - t.Cleanup(clean) - kt2, clean, err := newBadKeytabTempFile("HTTP/sso1.example.com", "KRB5.TEST", 18, 19) - require.NoError(t, err) - t.Cleanup(clean) - _, err = config.NewKeytabFileLookupFunc(kt1, kt2) - require.ErrorIs(t, err, config.ErrLoadKeytabFileFailed) - }) - t.Run("test for some keytab files", func(t *testing.T) { - var keytabFiles []string - for i := 0; i < 5; i++ { - kt, clean, err := newKeytabTempFile(fmt.Sprintf("HTTP/sso%d.example.com", i), "KRB5.TEST", 18, 19) - require.NoError(t, err) - t.Cleanup(clean) - keytabFiles = append(keytabFiles, kt) - } - lookupFunc, err := config.NewKeytabFileLookupFunc(keytabFiles...) - require.NoError(t, err) - _, err = lookupFunc() - require.NoError(t, err) - }) -} - -func newBadKeytabTempFile(principal string, realm string, et ...int32) (filename string, clean func(), err error) { - filename = fmt.Sprintf("./tmp_%d.keytab", time.Now().Unix()) - clean = func() { - os.Remove(filename) - } - var kt keytab.Keytab - for _, e := range et { - kt.AddEntry(principal, realm, "abcdefg", time.Now(), 2, e) - } - file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, 0o666) - if err != nil { - return filename, clean, fmt.Errorf("open file failed: %w", err) - } - if _, err = kt.Write(file); err != nil { - return filename, clean, fmt.Errorf("write file failed: %w", err) - } - file.Close() - return filename, clean, nil -} - -func newKeytabTempFile(principal string, realm string, et ...int32) (filename string, clean func(), err error) { - filename = fmt.Sprintf("./tmp_%d.keytab", time.Now().Unix()) - clean = func() { - os.Remove(filename) - } - kt := keytab.New() - for _, e := range et { - kt.AddEntry(principal, realm, "abcdefg", time.Now(), 2, e) - } - file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, 0o666) - if err != nil { - return filename, clean, fmt.Errorf("open file failed: %w", err) - } - if _, err = kt.Write(file); err != nil { - return filename, clean, fmt.Errorf("write file failed: %w", err) - } - file.Close() - return filename, clean, nil -} - -func TestGetAuthenticatedIdentityFromContext(t *testing.T) { - app := fiber.New() - app.Use("/testContext", func(ctx fiber.Ctx) error { - user := goidentity.NewUser("test-user") - user.SetDomain("example.com") - _, ok := GetAuthenticatedIdentityFromContext(ctx) - require.False(t, ok) - setAuthenticatedIdentityToContext(ctx, &user) - user1, ok := GetAuthenticatedIdentityFromContext(ctx) - require.True(t, ok) - require.Equal(t, user.UserName(), user1.UserName()) - require.Equal(t, user.Domain(), user1.Domain()) - return ctx.SendStatus(200) - }) - resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/testContext", nil)) - require.NoError(t, err) - require.Equal(t, resp.StatusCode, 200) -} From b79516107fb8dbe8865a958040e1ac35e6da7336 Mon Sep 17 00:00:00 2001 From: Jarod Date: Tue, 26 Aug 2025 01:28:54 +0800 Subject: [PATCH 3/5] docs(spnego): Update README documentation - Update English README document structure, add badges and document ID - Synchronize update of Chinese translation version - Unify keytab lookup implementation in example code --- spnego/README.md | 89 +++++++++++++++++++-------------------- spnego/README.zh-CN.md | 94 +++++++++++++++++++----------------------- 2 files changed, 88 insertions(+), 95 deletions(-) diff --git a/spnego/README.md b/spnego/README.md index bd600874a..8b53b1bcc 100644 --- a/spnego/README.md +++ b/spnego/README.md @@ -1,8 +1,16 @@ +--- +id: spnego +--- + # SPNEGO Kerberos Authentication Middleware for Fiber -[中文版本](README.zh-CN.md) +![Release](https://img.shields.io/github/v/tag/gofiber/contrib?filter=spnego*) +[![Discord](https://img.shields.io/discord/704680098577514527?style=flat&label=%F0%9F%92%AC%20discord&color=00ACD7)](https://gofiber.io/discord) +![Test](https://github.com/gofiber/contrib/workflows/Test%20spnego/badge.svg) -This middleware provides SPNEGO (Simple and Protected GSSAPI Negotiation Mechanism) authentication for Fiber applications, enabling Kerberos authentication for HTTP requests. +This middleware provides SPNEGO (Simple and Protected GSSAPI Negotiation Mechanism) authentication for [Fiber](https://github.com/gofiber/fiber) applications, enabling Kerberos authentication for HTTP requests and inspired by [gokrb5](https://github.com/jcmturner/gokrb5) + +[中文版本](README.zh-CN.md) ## Features @@ -37,37 +45,33 @@ go get github.com/gofiber/contrib/spnego/v2 package main import ( - flog "github.com/gofiber/fiber/v3/log" "fmt" + "time" - "github.com/jcmturner/gokrb5/v8/keytab" + "github.com/gofiber/contrib/spnego" + "github.com/gofiber/contrib/spnego/utils" + v3 "github.com/gofiber/contrib/spnego/v3" "github.com/gofiber/fiber/v3" - "github.com/gofiber/contrib/spnego/v3" + "github.com/gofiber/fiber/v3/log" ) func main() { app := fiber.New() - + // Create a configuration with a keytab lookup function - cfg := &spnego.Config{ - // Use a function to look up keytab from files - KeytabLookup: func() (*keytab.Keytab, error) { - // Implement your keytab lookup logic here - // This could be from files, database, or other sources - kt, err := spnego.NewKeytabFileLookupFunc("/path/to/keytab/file.keytab") - if err != nil { - return nil, err - } - return kt() - }, - // Optional: Set a custom logger - Log: flog.DefaultLogger().Logger().(*log.Logger), + // For testing, you can create a mock keytab file using utils.NewMockKeytab + // In production, use a real keytab file + keytabLookup, err := spnego.NewKeytabFileLookupFunc("/path/to/keytab/file.keytab") + if err != nil { + log.Fatalf("Failed to create keytab lookup function: %v", err) } - + // Create the middleware - authMiddleware, err := v3.NewSpnegoKrb5AuthenticateMiddleware(cfg) + authMiddleware, err := v3.NewSpnegoKrb5AuthenticateMiddleware(spnego.Config{ + KeytabLookup: keytabLookup, + }) if err != nil { - flog.Fatalf("Failed to create middleware: %v", err) + log.Fatalf("Failed to create middleware: %v", err) } // Apply the middleware to protected routes @@ -75,7 +79,7 @@ func main() { // Access authenticated identity app.Get("/protected/resource", func(c fiber.Ctx) error { - identity, ok := v3.GetAuthenticatedIdentityFromContext(c) + identity, ok := spnego.GetAuthenticatedIdentityFromContext(c) if !ok { return c.Status(fiber.StatusUnauthorized).SendString("Unauthorized") } @@ -96,32 +100,29 @@ import ( "log" "os" - "github.com/jcmturner/gokrb5/v8/keytab" + "github.com/gofiber/contrib/spnego" + "github.com/gofiber/contrib/spnego/utils" + v2 "github.com/gofiber/contrib/spnego/v2" "github.com/gofiber/fiber/v2" - "github.com/gofiber/contrib/spnego/v2" ) func main() { app := fiber.New() - + // Create a configuration with a keytab lookup function - cfg := &spnego.Config{ - // Use a function to look up keytab from files - KeytabLookup: func() (*keytab.Keytab, error) { - // Implement your keytab lookup logic here - // This could be from files, database, or other sources - kt, err := spnego.NewKeytabFileLookupFunc("/path/to/keytab/file.keytab") - if err != nil { - return nil, err - } - return kt() - }, - // Optional: Set a custom logger - Log: log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile|log.Lmicroseconds), + // For testing, you can create a mock keytab file using utils.NewMockKeytab + // In production, use a real keytab file + keytabLookup, err := spnego.NewKeytabFileLookupFunc("/path/to/keytab/file.keytab") + if err != nil { + log.Fatalf("Failed to create keytab lookup function: %v", err) } - + // Create the middleware - authMiddleware, err := v2.NewSpnegoKrb5AuthenticateMiddleware(cfg) + authMiddleware, err := v2.NewSpnegoKrb5AuthenticateMiddleware(spnego.Config{ + KeytabLookup: keytabLookup, + // Optional: Set a custom logger + Log: log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile|log.Lmicroseconds), + }) if err != nil { log.Fatalf("Failed to create middleware: %v", err) } @@ -131,7 +132,7 @@ func main() { // Access authenticated identity app.Get("/protected/resource", func(c *fiber.Ctx) error { - identity, ok := v2.GetAuthenticatedIdentityFromContext(c) + identity, ok := spnego.GetAuthenticatedIdentityFromContext(c) if !ok { return c.Status(fiber.StatusUnauthorized).SendString("Unauthorized") } @@ -164,7 +165,7 @@ func remoteKeytabLookup() (*keytab.Keytab, error) { ## API Reference -### `NewSpnegoKrb5AuthenticateMiddleware(cfg *Config) (fiber.Handler, error)` +### `NewSpnegoKrb5AuthenticateMiddleware(cfg spnego.Config) (fiber.Handler, error)` Creates a new SPNEGO authentication middleware. @@ -194,4 +195,4 @@ The `Config` struct supports the following fields: - Ensure your Kerberos infrastructure is properly configured - The middleware handles the SPNEGO negotiation process -- Authenticated identities are stored in the Fiber context using `config.ContextKeyOfIdentity` +- Authenticated identities are stored in the Fiber context using `spnego.contextKeyOfIdentity` diff --git a/spnego/README.zh-CN.md b/spnego/README.zh-CN.md index 0693ccae4..b0a617359 100644 --- a/spnego/README.zh-CN.md +++ b/spnego/README.zh-CN.md @@ -37,38 +37,33 @@ package main import ( - flog "github.com/gofiber/fiber/v3/log" "fmt" - "log" + "time" - "github.com/jcmturner/gokrb5/v8/keytab" + "github.com/gofiber/contrib/spnego" + "github.com/gofiber/contrib/spnego/utils" + v3 "github.com/gofiber/contrib/spnego/v3" "github.com/gofiber/fiber/v3" - "github.com/gofiber/contrib/spnego/v3" + "github.com/gofiber/fiber/v3/log" ) func main() { app := fiber.New() - + // 创建带有keytab查找函数的配置 - cfg := &v3.Config{ - // 使用函数从文件查找keytab - KeytabLookup: func() (*keytab.Keytab, error) { - // 在此实现您的keytab查找逻辑 - // 可以从文件、数据库或其他来源获取 - kt, err := v3.NewKeytabFileLookupFunc("/path/to/keytab/file.keytab") - if err != nil { - return nil, err - } - return kt() - }, - // 可选:设置自定义日志器 - Log: flog.DefaultLogger().Logger().(*log.Logger), + // 测试环境下,您可以使用utils.NewMockKeytab创建模拟keytab文件 + // 生产环境下,请使用真实的keytab文件 + keytabLookup, err := spnego.NewKeytabFileLookupFunc("/path/to/keytab/file.keytab") + if err != nil { + log.Fatalf("创建keytab查找函数失败: %v", err) } - + // 创建中间件 - authMiddleware, err := v3.NewSpnegoKrb5AuthenticateMiddleware(cfg) + authMiddleware, err := v3.NewSpnegoKrb5AuthenticateMiddleware(spnego.Config{ + KeytabLookup: keytabLookup, + }) if err != nil { - flog.Fatalf("创建中间件失败: %v", err) + log.Fatalf("创建中间件失败: %v", err) } // 将中间件应用于受保护的路由 @@ -76,7 +71,7 @@ func main() { // 访问认证身份 app.Get("/protected/resource", func(c fiber.Ctx) error { - identity, ok := v3.GetAuthenticatedIdentityFromContext(c) + identity, ok := spnego.GetAuthenticatedIdentityFromContext(c) if !ok { return c.Status(fiber.StatusUnauthorized).SendString("未授权") } @@ -97,32 +92,29 @@ import ( "log" "os" - "github.com/jcmturner/gokrb5/v8/keytab" + "github.com/gofiber/contrib/spnego" + "github.com/gofiber/contrib/spnego/utils" + v2 "github.com/gofiber/contrib/spnego/v2" "github.com/gofiber/fiber/v2" - "github.com/gofiber/contrib/spnego/v2" ) func main() { app := fiber.New() - + // 创建带有keytab查找函数的配置 - cfg := &v2.Config{ - // 使用函数从文件查找keytab - KeytabLookup: func() (*keytab.Keytab, error) { - // 在此实现您的keytab查找逻辑 - // 可以从文件、数据库或其他来源获取 - kt, err := v2.NewKeytabFileLookupFunc("/path/to/keytab/file.keytab") - if err != nil { - return nil, err - } - return kt() - }, - // 可选:设置自定义日志器 - Log: log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile|log.Lmicroseconds), + // 测试环境下,您可以使用utils.NewMockKeytab创建模拟keytab文件 + // 生产环境下,请使用真实的keytab文件 + keytabLookup, err := spnego.NewKeytabFileLookupFunc("/path/to/keytab/file.keytab") + if err != nil { + log.Fatalf("创建keytab查找函数失败: %v", err) } - + // 创建中间件 - authMiddleware, err := v2.NewSpnegoKrb5AuthenticateMiddleware(cfg) + authMiddleware, err := v2.NewSpnegoKrb5AuthenticateMiddleware(spnego.Config{ + KeytabLookup: keytabLookup, + // 可选:设置自定义日志器 + Log: log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile|log.Lmicroseconds), + }) if err != nil { log.Fatalf("创建中间件失败: %v", err) } @@ -132,7 +124,7 @@ func main() { // 访问认证身份 app.Get("/protected/resource", func(c *fiber.Ctx) error { - identity, ok := v2.GetAuthenticatedIdentityFromContext(c) + identity, ok := spnego.GetAuthenticatedIdentityFromContext(c) if !ok { return c.Status(fiber.StatusUnauthorized).SendString("未授权") } @@ -141,31 +133,30 @@ func main() { app.Listen(":3000") } -``` ## 动态Keytab查找 -该中间件设计具有可扩展性,允许从静态文件以外的各种来源检索keytab: +中间件的设计考虑了扩展性,允许从静态文件以外的各种来源检索keytab: ```go // 示例:从数据库检索keytab func dbKeytabLookup() (*keytab.Keytab, error) { - // 此处实现数据库查找逻辑 + // 您的数据库查找逻辑 // ... return keytabFromDatabase, nil } // 示例:从远程服务检索keytab func remoteKeytabLookup() (*keytab.Keytab, error) { - // 此处实现远程服务调用逻辑 + // 您的远程服务调用逻辑 // ... return keytabFromRemote, nil } ``` -## API 参考 +## API参考 -### `NewSpnegoKrb5AuthenticateMiddleware(cfg *Config) (fiber.Handler, error)` +### `NewSpnegoKrb5AuthenticateMiddleware(cfg spnego.Config) (fiber.Handler, error)` 创建一个新的SPNEGO认证中间件。 @@ -175,14 +166,14 @@ func remoteKeytabLookup() (*keytab.Keytab, error) { ### `NewKeytabFileLookupFunc(keytabFiles ...string) (KeytabLookupFunc, error)` -创建一个加载keytab文件的KeytabLookupFunc。 +创建一个加载keytab文件的新KeytabLookupFunc。 ## 配置 `Config`结构体支持以下字段: -- `KeytabLookup`: 检索keytab的函数(必需) -- `Log`: 用于中间件日志记录的日志器(可选,默认为Fiber的默认日志器) +- `KeytabLookup`:检索keytab的函数(必需) +- `Log`:用于中间件日志记录的日志器(可选,默认为Fiber的默认日志器) ## 要求 @@ -195,4 +186,5 @@ func remoteKeytabLookup() (*keytab.Keytab, error) { - 确保您的Kerberos基础设施已正确配置 - 中间件处理SPNEGO协商过程 -- 已认证的身份使用`contextKeyOfIdentity`存储在Fiber上下文中 +- 已认证的身份使用`spnego.contextKeyOfIdentity`存储在Fiber上下文中 +``` From 5197bf6964a36fcbcca9e3e7d30a765d4581d7ec Mon Sep 17 00:00:00 2001 From: Jarod Date: Tue, 26 Aug 2025 09:56:25 +0800 Subject: [PATCH 4/5] fix(utils): Fix adding TRUNC flag when creating mock keytab files test: Update test cases to use temporary directory instead of current directory refactor: Remove unused config package --- spnego/config/config.go | 61 -------------------------------- spnego/utils/mock_keytab.go | 2 +- spnego/utils/mock_keytab_test.go | 17 ++++++--- spnego/v2/spnego_test.go | 9 +++-- spnego/v3/spnego_test.go | 9 +++-- 5 files changed, 26 insertions(+), 72 deletions(-) delete mode 100644 spnego/config/config.go diff --git a/spnego/config/config.go b/spnego/config/config.go deleted file mode 100644 index cb213de76..000000000 --- a/spnego/config/config.go +++ /dev/null @@ -1,61 +0,0 @@ -package config - -import ( - "errors" - "fmt" - "log" - - "github.com/jcmturner/gokrb5/v8/keytab" -) - -// ErrConfigInvalidOfKeytabLookupFunctionRequired is returned when the KeytabLookup function is not set in Config -var ErrConfigInvalidOfKeytabLookupFunctionRequired = errors.New("config invalid: keytab lookup function is required") - -// ErrLookupKeytabFailed is returned when the keytab lookup fails -var ErrLookupKeytabFailed = errors.New("keytab lookup failed") - -// ErrConvertRequestFailed is returned when the request conversion to HTTP request fails -var ErrConvertRequestFailed = errors.New("convert request failed") - -// ErrConfigInvalidOfAtLeastOneKeytabFileRequired is returned when no keytab files are provided -var ( - ErrConfigInvalidOfAtLeastOneKeytabFileRequired = errors.New("config invalid: at least one keytab file required") - ErrLoadKeytabFileFailed = errors.New("load keytab failed") -) - -// ContextKeyOfIdentity is the key used to store the authenticated identity in the Fiber context -const ContextKeyOfIdentity = "middleware.spnego.Identity" - -// KeytabLookupFunc is a function type that returns a keytab or an error -// It's used to look up the keytab dynamically when needed -// This design allows for extensibility, enabling keytab retrieval from various sources -// such as databases, remote services, or other custom implementations beyond static files -type KeytabLookupFunc func() (*keytab.Keytab, error) - -// Config holds the configuration for the SPNEGO middleware -// It includes the keytab lookup function and a logger -type Config struct { - // KeytabLookup is a function that retrieves the keytab - KeytabLookup KeytabLookupFunc - // Log is the logger used for middleware logging - Log *log.Logger -} - -// NewKeytabFileLookupFunc creates a new KeytabLookupFunc that loads keytab files -// It accepts one or more keytab file paths and returns a function that loads them -func NewKeytabFileLookupFunc(keytabFiles ...string) (KeytabLookupFunc, error) { - if len(keytabFiles) == 0 { - return nil, ErrConfigInvalidOfAtLeastOneKeytabFileRequired - } - var mergeKeytab keytab.Keytab - for _, keytabFile := range keytabFiles { - kt, err := keytab.Load(keytabFile) - if err != nil { - return nil, fmt.Errorf("%w: file %s load failed: %w", ErrLoadKeytabFileFailed, keytabFile, err) - } - mergeKeytab.Entries = append(mergeKeytab.Entries, kt.Entries...) - } - return func() (*keytab.Keytab, error) { - return &mergeKeytab, nil - }, nil -} diff --git a/spnego/utils/mock_keytab.go b/spnego/utils/mock_keytab.go index 4c576d389..a326eff6b 100644 --- a/spnego/utils/mock_keytab.go +++ b/spnego/utils/mock_keytab.go @@ -132,7 +132,7 @@ func NewMockKeytab(opts ...MockOption) (*keytab.Keytab, func(), error) { } var clean = func() {} if len(opt.Filename) > 0 { - file, err := defaultFileOperator.OpenFile(opt.Filename, os.O_RDWR|os.O_CREATE, 0o666) + file, err := defaultFileOperator.OpenFile(opt.Filename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o666) if err != nil { return nil, nil, fmt.Errorf("error opening file: %w", err) } diff --git a/spnego/utils/mock_keytab_test.go b/spnego/utils/mock_keytab_test.go index c2114adfa..2a7b00d15 100644 --- a/spnego/utils/mock_keytab_test.go +++ b/spnego/utils/mock_keytab_test.go @@ -2,6 +2,7 @@ package utils import ( "os" + "path" "testing" "time" @@ -67,7 +68,11 @@ func TestNewMockKeytab(t *testing.T) { require.Equal(t, 3, kv) }) t.Run("test file open failed", func(t *testing.T) { + prevFileOperator := defaultFileOperator defaultFileOperator = mockFileOperator{flag: 0x01} + t.Cleanup(func() { + defaultFileOperator = prevFileOperator + }) _, _, err := NewMockKeytab( WithPrincipal("HTTP/sso.example.com"), WithRealm("TEST.LOCAL"), @@ -82,7 +87,11 @@ func TestNewMockKeytab(t *testing.T) { require.NoFileExists(t, "./temp.keytab") }) t.Run("test file write failed", func(t *testing.T) { + prevFileOperator := defaultFileOperator defaultFileOperator = mockFileOperator{flag: 0x02} + t.Cleanup(func() { + defaultFileOperator = prevFileOperator + }) _, _, err := NewMockKeytab( WithPrincipal("HTTP/sso.example.com"), WithRealm("TEST.LOCAL"), @@ -97,7 +106,7 @@ func TestNewMockKeytab(t *testing.T) { require.NoFileExists(t, "./temp.keytab") }) t.Run("test file created", func(t *testing.T) { - defaultFileOperator = myFileOperator{} + filename := path.Join(t.TempDir(), "temp.keytab") tm := time.Now() _, clean, err := NewMockKeytab( WithPrincipal("HTTP/sso.example.com"), @@ -107,12 +116,12 @@ func TestNewMockKeytab(t *testing.T) { EncryptType: 18, CreateTime: tm, }), - WithFilename("./temp.keytab"), + WithFilename(filename), ) require.NoError(t, err) t.Cleanup(clean) - require.FileExists(t, "./temp.keytab") - kt, err := keytab.Load("./temp.keytab") + require.FileExists(t, filename) + kt, err := keytab.Load(filename) require.NoError(t, err) _, kv, err := kt.GetEncryptionKey(types.NewPrincipalName(1, "HTTP/sso.example.com"), "TEST.LOCAL", 3, 18) require.NoError(t, err) diff --git a/spnego/v2/spnego_test.go b/spnego/v2/spnego_test.go index e9a1c2205..9cc8dbd39 100644 --- a/spnego/v2/spnego_test.go +++ b/spnego/v2/spnego_test.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net/http" + "path" "testing" "time" @@ -41,10 +42,12 @@ func TestNewSpnegoKrb5AuthenticateMiddleware(t *testing.T) { }) t.Run("test for keytab lookup function is set", func(t *testing.T) { tm := time.Now() + filename1 := path.Join(t.TempDir(), "temp-sso1.keytab") + filename2 := path.Join(t.TempDir(), "temp-sso2.keytab") _, clean1, err1 := utils.NewMockKeytab( utils.WithPrincipal("HTTP/sso1.example.com"), utils.WithRealm("EXAMPLE.LOCAL"), - utils.WithFilename("./temp-sso1.keytab"), + utils.WithFilename(filename1), utils.WithPairs(utils.EncryptTypePair{ Version: 2, EncryptType: 18, @@ -56,7 +59,7 @@ func TestNewSpnegoKrb5AuthenticateMiddleware(t *testing.T) { _, clean2, err2 := utils.NewMockKeytab( utils.WithPrincipal("HTTP/sso2.example.com"), utils.WithRealm("EXAMPLE.LOCAL"), - utils.WithFilename("./temp-sso2.keytab"), + utils.WithFilename(filename2), utils.WithPairs(utils.EncryptTypePair{ Version: 2, EncryptType: 18, @@ -65,7 +68,7 @@ func TestNewSpnegoKrb5AuthenticateMiddleware(t *testing.T) { ) require.NoError(t, err2) t.Cleanup(clean2) - lookupFunc, err := spnego.NewKeytabFileLookupFunc("./temp-sso1.keytab", "./temp-sso2.keytab") + lookupFunc, err := spnego.NewKeytabFileLookupFunc(filename1, filename2) require.NoError(t, err) middleware, err := NewSpnegoKrb5AuthenticateMiddleware(spnego.Config{ KeytabLookup: lookupFunc, diff --git a/spnego/v3/spnego_test.go b/spnego/v3/spnego_test.go index 523c6f7f4..a7c400212 100644 --- a/spnego/v3/spnego_test.go +++ b/spnego/v3/spnego_test.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net/http" + "path" "testing" "time" @@ -41,10 +42,12 @@ func TestNewSpnegoKrb5AuthenticateMiddleware(t *testing.T) { }) t.Run("test for keytab lookup function is set", func(t *testing.T) { tm := time.Now() + filename1 := path.Join(t.TempDir(), "temp-sso1.keytab") + filename2 := path.Join(t.TempDir(), "temp-sso2.keytab") _, clean1, err1 := utils.NewMockKeytab( utils.WithPrincipal("HTTP/sso1.example.com"), utils.WithRealm("EXAMPLE.LOCAL"), - utils.WithFilename("./temp-sso1.keytab"), + utils.WithFilename(filename1), utils.WithPairs(utils.EncryptTypePair{ Version: 2, EncryptType: 18, @@ -56,7 +59,7 @@ func TestNewSpnegoKrb5AuthenticateMiddleware(t *testing.T) { _, clean2, err2 := utils.NewMockKeytab( utils.WithPrincipal("HTTP/sso2.example.com"), utils.WithRealm("EXAMPLE.LOCAL"), - utils.WithFilename("./temp-sso2.keytab"), + utils.WithFilename(filename2), utils.WithPairs(utils.EncryptTypePair{ Version: 2, EncryptType: 18, @@ -65,7 +68,7 @@ func TestNewSpnegoKrb5AuthenticateMiddleware(t *testing.T) { ) require.NoError(t, err2) t.Cleanup(clean2) - lookupFunc, err := spnego.NewKeytabFileLookupFunc("./temp-sso1.keytab", "./temp-sso2.keytab") + lookupFunc, err := spnego.NewKeytabFileLookupFunc(filename1, filename2) require.NoError(t, err) middleware, err := NewSpnegoKrb5AuthenticateMiddleware(spnego.Config{ KeytabLookup: lookupFunc, From 952aaf6010b7359a229639a80f8257c6bba34bd1 Mon Sep 17 00:00:00 2001 From: Jarod Date: Tue, 26 Aug 2025 10:02:15 +0800 Subject: [PATCH 5/5] fix: fix when write mock file fail, open file is not close before remove --- spnego/utils/mock_keytab.go | 1 + 1 file changed, 1 insertion(+) diff --git a/spnego/utils/mock_keytab.go b/spnego/utils/mock_keytab.go index a326eff6b..3daf52dc8 100644 --- a/spnego/utils/mock_keytab.go +++ b/spnego/utils/mock_keytab.go @@ -140,6 +140,7 @@ func NewMockKeytab(opts ...MockOption) (*keytab.Keytab, func(), error) { _ = defaultFileOperator.Remove(opt.Filename) } if _, err = kt.Write(file); err != nil { + file.Close() clean() return nil, nil, fmt.Errorf("error writing to file: %w", err) }