Skip to content

Commit 43b78df

Browse files
committed
rpcperms: enable rich gRPC status errors in middleware API
1 parent 91423ee commit 43b78df

2 files changed

Lines changed: 206 additions & 4 deletions

File tree

rpcperms/middleware_handler.go

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,26 @@ import (
1212
"github.com/btcsuite/btcd/chaincfg"
1313
"github.com/lightningnetwork/lnd/lnrpc"
1414
"github.com/lightningnetwork/lnd/macaroons"
15+
spb "google.golang.org/genproto/googleapis/rpc/status"
1516
"google.golang.org/grpc/metadata"
17+
"google.golang.org/grpc/status"
1618
"google.golang.org/protobuf/proto"
1719
"google.golang.org/protobuf/reflect/protoreflect"
1820
"google.golang.org/protobuf/reflect/protoregistry"
1921
"gopkg.in/macaroon.v2"
2022
)
2123

24+
const (
25+
// StatusTypeNameError is the type name used for plain error strings
26+
// that are not gRPC status errors.
27+
StatusTypeNameError = "error"
28+
29+
// StatusTypeNameStatus is the fully qualified name of the
30+
// google.rpc.Status proto message that is used to represent rich gRPC
31+
// errors with proper error codes and details.
32+
StatusTypeNameStatus = "google.rpc.Status"
33+
)
34+
2235
var (
2336
// ErrShuttingDown is the error that's returned when the server is
2437
// shutting down and a request cannot be served anymore.
@@ -276,9 +289,20 @@ func (h *MiddlewareHandler) sendInterceptRequests(errChan chan error,
276289
// proto message?
277290
response.replace = true
278291
if requestInfo.request.IsError {
279-
response.replacement = errors.New(
280-
string(t.ReplacementSerialized),
292+
// Check if the original error was a
293+
// rich gRPC status error. If so, we
294+
// need to parse the replacement as a
295+
// Status proto and reconstruct the
296+
// error with the proper gRPC code.
297+
replacement, err := parseErrorReplacement(
298+
requestInfo.request.ProtoTypeName,
299+
t.ReplacementSerialized,
281300
)
301+
if err != nil {
302+
response.err = err
303+
break
304+
}
305+
response.replacement = replacement
282306

283307
break
284308
}
@@ -432,10 +456,27 @@ func NewMessageInterceptionRequest(ctx context.Context,
432456
req.ProtoTypeName = string(proto.MessageName(t))
433457

434458
case error:
435-
req.ProtoSerialized = []byte(t.Error())
436-
req.ProtoTypeName = "error"
437459
req.IsError = true
438460

461+
// Check if the error is a gRPC status error. If so, we
462+
// serialize the underlying Status proto to allow middleware to
463+
// inspect and modify the full error details including the gRPC
464+
// error code.
465+
st, ok := status.FromError(t)
466+
if ok {
467+
req.ProtoSerialized, err = proto.Marshal(st.Proto())
468+
if err != nil {
469+
return nil, fmt.Errorf("cannot marshal "+
470+
"status proto: %w", err)
471+
}
472+
req.ProtoTypeName = StatusTypeNameStatus
473+
} else {
474+
// Not a gRPC status error, fall back to plain error
475+
// string serialization.
476+
req.ProtoSerialized = []byte(t.Error())
477+
req.ProtoTypeName = StatusTypeNameError
478+
}
479+
439480
default:
440481
return nil, fmt.Errorf("unsupported type for interception "+
441482
"request: %v", m)
@@ -582,6 +623,32 @@ func parseProto(typeName string, serialized []byte) (proto.Message, error) {
582623
return msg.Interface(), nil
583624
}
584625

626+
// parseErrorReplacement parses a replacement error from its serialized form.
627+
// If the original error was a rich gRPC status error (indicated by the
628+
// StatusTypeNameStatus type name), it will parse the replacement as a
629+
// google.rpc.Status proto and reconstruct a proper gRPC status error.
630+
// Otherwise, it treats the replacement as a plain error string.
631+
func parseErrorReplacement(typeName string, serialized []byte) (error, error) {
632+
// If the original error was a rich gRPC status, parse the replacement
633+
// as a Status proto and reconstruct the error with the proper gRPC
634+
// code and details.
635+
if typeName == StatusTypeNameStatus {
636+
// Unmarshal directly into the google.rpc.Status proto type.
637+
statusProto := &spb.Status{}
638+
if err := proto.Unmarshal(serialized, statusProto); err != nil {
639+
return nil, fmt.Errorf("cannot parse status proto: %w",
640+
err)
641+
}
642+
643+
// Convert the proto back to a gRPC status error.
644+
st := status.FromProto(statusProto)
645+
return st.Err(), nil
646+
}
647+
648+
// For plain error strings, just create a new error from the bytes.
649+
return errors.New(string(serialized)), nil
650+
}
651+
585652
// replaceProtoMsg replaces the given target message with the content of the
586653
// replacement message.
587654
func replaceProtoMsg(target interface{}, replacement interface{}) error {

rpcperms/middleware_handler_test.go

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
package rpcperms
22

33
import (
4+
"context"
45
"encoding/json"
6+
"errors"
57
"testing"
68

79
"github.com/lightningnetwork/lnd/lnrpc"
810
"github.com/stretchr/testify/require"
11+
"google.golang.org/grpc/codes"
12+
"google.golang.org/grpc/status"
13+
"google.golang.org/protobuf/proto"
914
)
1015

1116
// TestReplaceProtoMsg makes sure the proto message replacement works as
@@ -88,3 +93,133 @@ func jsonEqual(t *testing.T, expected, actual interface{}) {
8893

8994
require.JSONEq(t, string(expectedJSON), string(actualJSON))
9095
}
96+
97+
// TestParseErrorReplacement tests that parseErrorReplacement correctly parses
98+
// both plain error strings and rich gRPC status errors.
99+
func TestParseErrorReplacement(t *testing.T) {
100+
testCases := []struct {
101+
name string
102+
typeName string
103+
serialized []byte
104+
expectedErrMsg string
105+
expectParseErr bool
106+
}{{
107+
name: "plain error string",
108+
typeName: StatusTypeNameError,
109+
serialized: []byte("this is a plain error"),
110+
expectedErrMsg: "this is a plain error",
111+
}, {
112+
name: "empty error string",
113+
typeName: StatusTypeNameError,
114+
serialized: []byte(""),
115+
expectedErrMsg: "",
116+
}, {
117+
name: "invalid status proto",
118+
typeName: StatusTypeNameStatus,
119+
serialized: []byte("not a valid proto"),
120+
expectParseErr: true,
121+
}}
122+
123+
for _, tc := range testCases {
124+
t.Run(tc.name, func(tt *testing.T) {
125+
resultErr, parseErr := parseErrorReplacement(
126+
tc.typeName, tc.serialized,
127+
)
128+
129+
if tc.expectParseErr {
130+
require.Error(tt, parseErr)
131+
return
132+
}
133+
134+
require.NoError(tt, parseErr)
135+
require.Equal(tt, tc.expectedErrMsg, resultErr.Error())
136+
})
137+
}
138+
}
139+
140+
// TestParseErrorReplacementWithStatus tests that parseErrorReplacement
141+
// correctly handles gRPC status errors with proper error codes.
142+
func TestParseErrorReplacementWithStatus(t *testing.T) {
143+
// Create a gRPC status error with a specific code and message.
144+
st := status.New(codes.NotFound, "resource not found")
145+
statusProto := st.Proto()
146+
147+
// Serialize the status proto.
148+
serialized, err := proto.Marshal(statusProto)
149+
require.NoError(t, err)
150+
151+
// Parse it back.
152+
resultErr, parseErr := parseErrorReplacement(
153+
StatusTypeNameStatus, serialized,
154+
)
155+
require.NoError(t, parseErr)
156+
require.Error(t, resultErr)
157+
158+
// Verify we can extract the status back.
159+
resultStatus, ok := status.FromError(resultErr)
160+
require.True(t, ok)
161+
require.Equal(t, codes.NotFound, resultStatus.Code())
162+
require.Equal(t, "resource not found", resultStatus.Message())
163+
}
164+
165+
// TestNewMessageInterceptionRequestWithStatusError tests that
166+
// NewMessageInterceptionRequest correctly serializes gRPC status errors
167+
// as google.rpc.Status protos instead of plain error strings.
168+
func TestNewMessageInterceptionRequestWithStatusError(t *testing.T) {
169+
testCases := []struct {
170+
name string
171+
err error
172+
expectedTypeName string
173+
isStatusError bool
174+
}{{
175+
name: "plain error",
176+
err: errors.New("this is a plain error"),
177+
expectedTypeName: StatusTypeNameError,
178+
isStatusError: false,
179+
}, {
180+
name: "gRPC status error",
181+
err: status.Error(codes.NotFound, "resource not found"),
182+
expectedTypeName: StatusTypeNameStatus,
183+
isStatusError: true,
184+
}, {
185+
name: "gRPC status error with different code",
186+
err: status.Error(codes.PermissionDenied, "access denied"),
187+
expectedTypeName: StatusTypeNameStatus,
188+
isStatusError: true,
189+
}}
190+
191+
for _, tc := range testCases {
192+
t.Run(tc.name, func(tt *testing.T) {
193+
ctx := context.Background()
194+
req, err := NewMessageInterceptionRequest(
195+
ctx, TypeResponse, false, "/test/Method",
196+
tc.err,
197+
)
198+
require.NoError(tt, err)
199+
require.True(tt, req.IsError)
200+
require.Equal(tt, tc.expectedTypeName, req.ProtoTypeName)
201+
202+
if tc.isStatusError {
203+
// Verify we can parse the status back.
204+
resultErr, parseErr := parseErrorReplacement(
205+
req.ProtoTypeName, req.ProtoSerialized,
206+
)
207+
require.NoError(tt, parseErr)
208+
209+
// Verify the error code is preserved.
210+
resultStatus, ok := status.FromError(resultErr)
211+
require.True(tt, ok)
212+
213+
originalStatus, _ := status.FromError(tc.err)
214+
require.Equal(
215+
tt, originalStatus.Code(),
216+
resultStatus.Code(),
217+
)
218+
require.Equal(
219+
tt, originalStatus.Message(),
220+
resultStatus.Message(),
221+
)
222+
}
223+
})
224+
}
225+
}

0 commit comments

Comments
 (0)