Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 129 additions & 0 deletions service/internal/auth/protovalidate_interceptor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package auth

import (
"context"
"fmt"

"buf.build/go/protovalidate"
"connectrpc.com/connect"
"github.com/lestrrat-go/jwx/v2/jwt"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
)

// ProtoAttrMapper extracts selected proto fields and converts them to
// casbin-request attributes. Enforces whitelist-only access to ensure
// ONLY configured fields are available to authorization policies.
type ProtoAttrMapper struct {
// Allowed fields to extract and expose to Casbin (whitelist-only)
Allowed []string
// RequiredFields that must exist on the request (subset of Allowed)
RequiredFields []string
// Validate controls whether to run protovalidate on the incoming message
Validate bool
}

// Interceptor returns a ConnectRPC unary interceptor that validates the
// request protobuf using protovalidate and attaches a map[string]string of
// attributes to the context for downstream enforcement.
func (p *ProtoAttrMapper) Interceptor(e *Enforcer) connect.UnaryInterceptorFunc {
interceptor := func(next connect.UnaryFunc) connect.UnaryFunc {
return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
// Validate proto message using protovalidate if available
if any := req.Any(); any != nil {

Check failure on line 33 in service/internal/auth/protovalidate_interceptor.go

View workflow job for this annotation

GitHub Actions / go (service)

redefines-builtin-id: redefinition of the built-in type any (revive)
if m, ok := any.(proto.Message); ok {
if p.Validate {
v, err := protovalidate.New()
if err == nil {
if err := v.Validate(m); err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("protovalidate failed: %w", err))
}
}
}

// Build attributes map - whitelist-only extraction
attrs := map[string]string{}
mr := m.ProtoReflect()
for _, allow := range p.Allowed {
if val, ok := lookupProtoFieldString(mr, allow); ok {

Check failure on line 48 in service/internal/auth/protovalidate_interceptor.go

View workflow job for this annotation

GitHub Actions / go (service)

shadow: declaration of "ok" shadows declaration at line 34 (govet)
attrs[allow] = val
}
}

// Validate required fields are present
for _, required := range p.RequiredFields {
if _, exists := attrs[required]; !exists {
return nil, connect.NewError(
connect.CodeInvalidArgument,
fmt.Errorf("required field %q is missing or invalid", required),
)
}
}

// Attach attrs to context for downstream (store under key "casbin_attrs")
// SECURITY: Only whitelisted fields are in this map - no other request
// fields are accessible to Casbin policy evaluation
ctx = context.WithValue(ctx, casbinContextKey("casbin_attrs"), attrs)

// Optionally perform synchronous enforcement: derive resource/action
if e != nil {
if tk, ok := ctx.Value(tokenContextKey{}).(jwt.Token); ok {

Check failure on line 70 in service/internal/auth/protovalidate_interceptor.go

View workflow job for this annotation

GitHub Actions / go (service)

shadow: declaration of "ok" shadows declaration at line 34 (govet)
res := req.Spec().Procedure
act := req.Spec().Procedure
_, _ = e.Enforce(tk, res, act)
}
}
}
}
return next(ctx, req)
})
}
return connect.UnaryInterceptorFunc(interceptor)
}

// helper to lookup a dot-separated path on a protoreflect.Message and
// return its string value if present.
func lookupProtoFieldString(m protoreflect.Message, path string) (string, bool) {
// Only support single-level fields for now to keep simple
fld := m.Descriptor().Fields().ByName(protoreflect.Name(path))
if fld == nil {
return "", false
}
v := m.Get(fld)
if !v.IsValid() {
return "", false
}
// Convert scalar to string if possible
switch fld.Kind() {

Check failure on line 97 in service/internal/auth/protovalidate_interceptor.go

View workflow job for this annotation

GitHub Actions / go (service)

missing cases in switch of type protoreflect.Kind: protoreflect.EnumKind, protoreflect.Sint32Kind, protoreflect.Sint64Kind, protoreflect.Sfixed32Kind, protoreflect.Fixed32Kind, protoreflect.FloatKind, protoreflect.Sfixed64Kind, protoreflect.Fixed64Kind, protoreflect.DoubleKind, protoreflect.BytesKind, protoreflect.MessageKind, protoreflect.GroupKind (exhaustive)
case protoreflect.StringKind:
s := v.String()
// Treat empty strings as missing for required field validation
if s == "" {
return "", false
}
return s, true
case protoreflect.Int32Kind, protoreflect.Int64Kind:
return fmt.Sprintf("%d", v.Int()), true

Check failure on line 106 in service/internal/auth/protovalidate_interceptor.go

View workflow job for this annotation

GitHub Actions / go (service)

integer-format: fmt.Sprintf can be replaced with faster strconv.FormatInt (perfsprint)
case protoreflect.Uint32Kind, protoreflect.Uint64Kind:
return fmt.Sprintf("%d", v.Uint()), true

Check failure on line 108 in service/internal/auth/protovalidate_interceptor.go

View workflow job for this annotation

GitHub Actions / go (service)

integer-format: fmt.Sprintf can be replaced with faster strconv.FormatUint (perfsprint)
case protoreflect.BoolKind:
return fmt.Sprintf("%t", v.Bool()), true

Check failure on line 110 in service/internal/auth/protovalidate_interceptor.go

View workflow job for this annotation

GitHub Actions / go (service)

bool-format: fmt.Sprintf can be replaced with faster strconv.FormatBool (perfsprint)
default:
return "", false
}
}

// context keys
type casbinContextKey string

Check failure on line 117 in service/internal/auth/protovalidate_interceptor.go

View workflow job for this annotation

GitHub Actions / go (service)

File is not properly formatted (gofumpt)
type tokenContextKey struct{}

// GetCasbinAttrsFromContext retrieves the extracted proto attributes from the context.
// Returns the attributes map and true if present, or nil and false if not found.
func GetCasbinAttrsFromContext(ctx context.Context) (map[string]string, bool) {
v := ctx.Value(casbinContextKey("casbin_attrs"))
if v == nil {
return nil, false
}
attrs, ok := v.(map[string]string)
return attrs, ok
}
229 changes: 229 additions & 0 deletions service/internal/auth/protovalidate_interceptor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
package auth

import (
"context"
"testing"

"connectrpc.com/connect"
"github.com/opentdf/platform/protocol/go/common"
"github.com/opentdf/platform/service/logger"
"github.com/stretchr/testify/require"
)

func Test_ProtoAttrMapper_Interceptor(t *testing.T) {
mapper := &ProtoAttrMapper{Allowed: []string{"name", "id"}, Validate: false}

// create a simple proto message from policy namespace that has string fields
msg := &common.IdNameIdentifier{
Id: "abc",
Name: "example",
}

// create a no-op next handler that checks context for attrs
next := func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {

Check failure on line 23 in service/internal/auth/protovalidate_interceptor_test.go

View workflow job for this annotation

GitHub Actions / go (service)

unused-parameter: parameter 'req' seems to be unused, consider removing or renaming it as _ (revive)
v := ctx.Value(casbinContextKey("casbin_attrs"))
require.NotNil(t, v)
m, ok := v.(map[string]string)
require.True(t, ok)
require.Equal(t, "example", m["name"])
require.Equal(t, "abc", m["id"])
return connect.NewResponse[any](nil), nil
}

interceptor := mapper.Interceptor(nil)
wrapped := interceptor(next)

// Build a connect request wrapper
req := connect.NewRequest(msg)
_, err := wrapped(context.Background(), req)
require.NoError(t, err)
}

func Test_ProtoAttrMapper_RequiredFields(t *testing.T) {
t.Run("missing required field should fail", func(t *testing.T) {
mapper := &ProtoAttrMapper{
Allowed: []string{"name", "id"},
RequiredFields: []string{"name", "id"},
Validate: false,
}

// Message missing 'name' field (empty string)
msg := &common.IdNameIdentifier{
Id: "abc",
Name: "", // empty/missing
}

next := func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {

Check failure on line 56 in service/internal/auth/protovalidate_interceptor_test.go

View workflow job for this annotation

GitHub Actions / go (service)

unused-parameter: parameter 'ctx' seems to be unused, consider removing or renaming it as _ (revive)
t.Fatal("should not reach next handler")
return connect.NewResponse[any](nil), nil
}

interceptor := mapper.Interceptor(nil)
wrapped := interceptor(next)

req := connect.NewRequest(msg)
_, err := wrapped(context.Background(), req)
require.Error(t, err)
require.Contains(t, err.Error(), "required field")
require.Contains(t, err.Error(), "name")
})

t.Run("all required fields present should succeed", func(t *testing.T) {
mapper := &ProtoAttrMapper{
Allowed: []string{"name", "id"},
RequiredFields: []string{"name"},
Validate: false,
}

msg := &common.IdNameIdentifier{
Id: "abc",
Name: "example",
}

next := func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
v := ctx.Value(casbinContextKey("casbin_attrs"))
require.NotNil(t, v)
return connect.NewResponse[any](nil), nil
}

interceptor := mapper.Interceptor(nil)
wrapped := interceptor(next)

req := connect.NewRequest(msg)
_, err := wrapped(context.Background(), req)
require.NoError(t, err)
})
}

func Test_ProtoAttrMapper_WhitelistOnly(t *testing.T) {
t.Run("only whitelisted fields should be in attrs", func(t *testing.T) {
// Only allow 'name', not 'id'
mapper := &ProtoAttrMapper{
Allowed: []string{"name"},
Validate: false,
}

msg := &common.IdNameIdentifier{
Id: "secret-id-should-not-be-exposed",
Name: "example",
}

next := func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
v := ctx.Value(casbinContextKey("casbin_attrs"))
require.NotNil(t, v)
m, ok := v.(map[string]string)
require.True(t, ok)

// SECURITY TEST: only 'name' should be present
require.Equal(t, "example", m["name"])
require.NotContains(t, m, "id", "id should NOT be in attrs - security violation")
require.Len(t, m, 1, "only whitelisted fields should be present")
return connect.NewResponse[any](nil), nil
}

interceptor := mapper.Interceptor(nil)
wrapped := interceptor(next)

req := connect.NewRequest(msg)
_, err := wrapped(context.Background(), req)
require.NoError(t, err)
})
}

func Test_ProtoAttrMapper_EnforcementIntegration(t *testing.T) {
t.Run("enforcement with attribute-based policy", func(t *testing.T) {
// Create a Casbin enforcer with an attribute-aware model
modelConf := `
[request_definition]
r = sub, res, act, owner

[policy_definition]
p = sub, res, act, owner, eft

[role_definition]
g = _, _

[policy_effect]
e = some(where (p.eft == allow)) && !some(where (p.eft == deny))

[matchers]
m = g(r.sub, p.sub) && keyMatch(r.res, p.res) && keyMatch(r.act, p.act) && (p.owner == "*" || r.owner == p.owner)
`

policyCSV := `
p, role:admin, /policy/*, read, *, allow
p, role:user, /policy/*, read, user123, allow
g, admin-user, role:admin
g, regular-user, role:user
`
loggerInstance, err := logger.NewLogger(logger.Config{
Level: "error",
Output: "stdout",
Type: "json",
})
require.NoError(t, err)
require.NotNil(t, loggerInstance)

casbinCfg := CasbinConfig{
PolicyConfig: PolicyConfig{
Model: modelConf,
Csv: policyCSV,
},
}
enforcer, err := NewCasbinEnforcer(casbinCfg, loggerInstance)
require.NoError(t, err)
require.NotNil(t, enforcer)

// Test 1: Admin can access any resource
allowed, err := enforcer.Enforcer.Enforce("role:admin", "/policy/attributes", "read", "*")
require.NoError(t, err)
require.True(t, allowed, "admin should have access")

// Test 2: User can only access their own resources
allowed, err = enforcer.Enforcer.Enforce("role:user", "/policy/attributes", "read", "user123")
require.NoError(t, err)
require.True(t, allowed, "user should have access to their own resource")

// Test 3: User cannot access other user's resources
allowed, err = enforcer.Enforcer.Enforce("role:user", "/policy/attributes", "read", "user456")
require.NoError(t, err)
require.False(t, allowed, "user should NOT have access to other user's resource")

t.Log("Attribute-based enforcement working correctly")
})

t.Run("interceptor extracts attrs for enforcement", func(t *testing.T) {
mapper := &ProtoAttrMapper{
Allowed: []string{"name", "id"},
RequiredFields: []string{"id"},
Validate: false,
}

msg := &common.IdNameIdentifier{
Id: "user123",
Name: "test-resource",
}

next := func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
v := ctx.Value(casbinContextKey("casbin_attrs"))
require.NotNil(t, v)
attrs, ok := v.(map[string]string)
require.True(t, ok)

// Verify extracted attributes are ready for enforcement
require.Equal(t, "user123", attrs["id"])
require.Equal(t, "test-resource", attrs["name"])

// These attrs can now be passed to Casbin Enforce with extended signature
// e.g., enforcer.Enforce(subject, resource, action, attrs["id"])
return connect.NewResponse[any](nil), nil
}

interceptor := mapper.Interceptor(nil)
wrapped := interceptor(next)

req := connect.NewRequest(msg)
_, err := wrapped(context.Background(), req)
require.NoError(t, err)
})
}
Loading