Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/release-notes/release-notes-0.21.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@

## RPC Updates

* [Enabled rich gRPC status error support in the middleware
API](https://github.com/lightningnetwork/lnd/pull/10458), allowing middleware
to inspect and modify full gRPC error details including error codes, not just
plain error strings.

## lncli Updates

## Breaking Changes
Expand Down
75 changes: 71 additions & 4 deletions rpcperms/middleware_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,26 @@ import (
"github.com/btcsuite/btcd/chaincfg"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/macaroons"
spb "google.golang.org/genproto/googleapis/rpc/status"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"gopkg.in/macaroon.v2"
)

const (
// StatusTypeNameError is the type name used for plain error strings
// that are not gRPC status errors.
StatusTypeNameError = "error"

// StatusTypeNameStatus is the fully qualified name of the
// google.rpc.Status proto message that is used to represent rich gRPC
// errors with proper error codes and details.
StatusTypeNameStatus = "google.rpc.Status"
)

var (
// ErrShuttingDown is the error that's returned when the server is
// shutting down and a request cannot be served anymore.
Expand Down Expand Up @@ -276,9 +289,20 @@ func (h *MiddlewareHandler) sendInterceptRequests(errChan chan error,
// proto message?
response.replace = true
if requestInfo.request.IsError {
response.replacement = errors.New(
string(t.ReplacementSerialized),
// Check if the original error was a
// rich gRPC status error. If so, we
// need to parse the replacement as a
// Status proto and reconstruct the
// error with the proper gRPC code.
replacement, err := parseErrorReplacement(
requestInfo.request.ProtoTypeName,
t.ReplacementSerialized,
)
if err != nil {
response.err = err
break
}
response.replacement = replacement

break
}
Expand Down Expand Up @@ -432,10 +456,27 @@ func NewMessageInterceptionRequest(ctx context.Context,
req.ProtoTypeName = string(proto.MessageName(t))

case error:
req.ProtoSerialized = []byte(t.Error())
req.ProtoTypeName = "error"
req.IsError = true

// Check if the error is a gRPC status error. If so, we
// serialize the underlying Status proto to allow middleware to
// inspect and modify the full error details including the gRPC
// error code.
st, ok := status.FromError(t)
if ok {
req.ProtoSerialized, err = proto.Marshal(st.Proto())
if err != nil {
return nil, fmt.Errorf("cannot marshal "+
"status proto: %w", err)
}
req.ProtoTypeName = StatusTypeNameStatus
} else {
// Not a gRPC status error, fall back to plain error
// string serialization.
req.ProtoSerialized = []byte(t.Error())
req.ProtoTypeName = StatusTypeNameError
}

default:
return nil, fmt.Errorf("unsupported type for interception "+
"request: %v", m)
Expand Down Expand Up @@ -582,6 +623,32 @@ func parseProto(typeName string, serialized []byte) (proto.Message, error) {
return msg.Interface(), nil
}

// parseErrorReplacement parses a replacement error from its serialized form.
// If the original error was a rich gRPC status error (indicated by the
// StatusTypeNameStatus type name), it will parse the replacement as a
// google.rpc.Status proto and reconstruct a proper gRPC status error.
// Otherwise, it treats the replacement as a plain error string.
func parseErrorReplacement(typeName string, serialized []byte) (error, error) {
// If the original error was a rich gRPC status, parse the replacement
// as a Status proto and reconstruct the error with the proper gRPC
// code and details.
if typeName == StatusTypeNameStatus {
// Unmarshal directly into the google.rpc.Status proto type.
statusProto := &spb.Status{}
if err := proto.Unmarshal(serialized, statusProto); err != nil {
return nil, fmt.Errorf("cannot parse status proto: %w",
err)
}

// Convert the proto back to a gRPC status error.
st := status.FromProto(statusProto)
return st.Err(), nil
}

// For plain error strings, just create a new error from the bytes.
return errors.New(string(serialized)), nil
}

// replaceProtoMsg replaces the given target message with the content of the
// replacement message.
func replaceProtoMsg(target interface{}, replacement interface{}) error {
Expand Down
135 changes: 135 additions & 0 deletions rpcperms/middleware_handler_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
package rpcperms

import (
"context"
"encoding/json"
"errors"
"testing"

"github.com/lightningnetwork/lnd/lnrpc"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)

// TestReplaceProtoMsg makes sure the proto message replacement works as
Expand Down Expand Up @@ -88,3 +93,133 @@ func jsonEqual(t *testing.T, expected, actual interface{}) {

require.JSONEq(t, string(expectedJSON), string(actualJSON))
}

// TestParseErrorReplacement tests that parseErrorReplacement correctly parses
// both plain error strings and rich gRPC status errors.
func TestParseErrorReplacement(t *testing.T) {
testCases := []struct {
name string
typeName string
serialized []byte
expectedErrMsg string
expectParseErr bool
}{{
name: "plain error string",
typeName: StatusTypeNameError,
serialized: []byte("this is a plain error"),
expectedErrMsg: "this is a plain error",
}, {
name: "empty error string",
typeName: StatusTypeNameError,
serialized: []byte(""),
expectedErrMsg: "",
}, {
name: "invalid status proto",
typeName: StatusTypeNameStatus,
serialized: []byte("not a valid proto"),
expectParseErr: true,
}}

for _, tc := range testCases {
t.Run(tc.name, func(tt *testing.T) {
resultErr, parseErr := parseErrorReplacement(
tc.typeName, tc.serialized,
)

if tc.expectParseErr {
require.Error(tt, parseErr)
return
}

require.NoError(tt, parseErr)
require.Equal(tt, tc.expectedErrMsg, resultErr.Error())
})
}
}

// TestParseErrorReplacementWithStatus tests that parseErrorReplacement
// correctly handles gRPC status errors with proper error codes.
func TestParseErrorReplacementWithStatus(t *testing.T) {
// Create a gRPC status error with a specific code and message.
st := status.New(codes.NotFound, "resource not found")
statusProto := st.Proto()

// Serialize the status proto.
serialized, err := proto.Marshal(statusProto)
require.NoError(t, err)

// Parse it back.
resultErr, parseErr := parseErrorReplacement(
StatusTypeNameStatus, serialized,
)
require.NoError(t, parseErr)
require.Error(t, resultErr)

// Verify we can extract the status back.
resultStatus, ok := status.FromError(resultErr)
require.True(t, ok)
require.Equal(t, codes.NotFound, resultStatus.Code())
require.Equal(t, "resource not found", resultStatus.Message())
}

// TestNewMessageInterceptionRequestWithStatusError tests that
// NewMessageInterceptionRequest correctly serializes gRPC status errors
// as google.rpc.Status protos instead of plain error strings.
func TestNewMessageInterceptionRequestWithStatusError(t *testing.T) {
testCases := []struct {
name string
err error
expectedTypeName string
isStatusError bool
}{{
name: "plain error",
err: errors.New("this is a plain error"),
expectedTypeName: StatusTypeNameError,
isStatusError: false,
}, {
name: "gRPC status error",
err: status.Error(codes.NotFound, "resource not found"),
expectedTypeName: StatusTypeNameStatus,
isStatusError: true,
}, {
name: "gRPC status error with different code",
err: status.Error(codes.PermissionDenied, "access denied"),
expectedTypeName: StatusTypeNameStatus,
isStatusError: true,
}}

for _, tc := range testCases {
t.Run(tc.name, func(tt *testing.T) {
ctx := context.Background()
req, err := NewMessageInterceptionRequest(
ctx, TypeResponse, false, "/test/Method",
tc.err,
)
require.NoError(tt, err)
require.True(tt, req.IsError)
require.Equal(tt, tc.expectedTypeName, req.ProtoTypeName)

if tc.isStatusError {
// Verify we can parse the status back.
resultErr, parseErr := parseErrorReplacement(
req.ProtoTypeName, req.ProtoSerialized,
)
require.NoError(tt, parseErr)

// Verify the error code is preserved.
resultStatus, ok := status.FromError(resultErr)
require.True(tt, ok)

originalStatus, _ := status.FromError(tc.err)
require.Equal(
tt, originalStatus.Code(),
resultStatus.Code(),
)
require.Equal(
tt, originalStatus.Message(),
resultStatus.Message(),
)
}
})
}
}