|
1 | 1 | package rpcperms |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "context" |
4 | 5 | "encoding/json" |
| 6 | + "errors" |
5 | 7 | "testing" |
6 | 8 |
|
7 | 9 | "github.com/lightningnetwork/lnd/lnrpc" |
8 | 10 | "github.com/stretchr/testify/require" |
| 11 | + "google.golang.org/grpc/codes" |
| 12 | + "google.golang.org/grpc/status" |
| 13 | + "google.golang.org/protobuf/proto" |
9 | 14 | ) |
10 | 15 |
|
11 | 16 | // TestReplaceProtoMsg makes sure the proto message replacement works as |
@@ -88,3 +93,133 @@ func jsonEqual(t *testing.T, expected, actual interface{}) { |
88 | 93 |
|
89 | 94 | require.JSONEq(t, string(expectedJSON), string(actualJSON)) |
90 | 95 | } |
| 96 | + |
| 97 | +// TestParseErrorReplacement tests that parseErrorReplacement correctly parses |
| 98 | +// both plain error strings and rich gRPC status errors. |
| 99 | +func TestParseErrorReplacement(t *testing.T) { |
| 100 | + testCases := []struct { |
| 101 | + name string |
| 102 | + typeName string |
| 103 | + serialized []byte |
| 104 | + expectedErrMsg string |
| 105 | + expectParseErr bool |
| 106 | + }{{ |
| 107 | + name: "plain error string", |
| 108 | + typeName: StatusTypeNameError, |
| 109 | + serialized: []byte("this is a plain error"), |
| 110 | + expectedErrMsg: "this is a plain error", |
| 111 | + }, { |
| 112 | + name: "empty error string", |
| 113 | + typeName: StatusTypeNameError, |
| 114 | + serialized: []byte(""), |
| 115 | + expectedErrMsg: "", |
| 116 | + }, { |
| 117 | + name: "invalid status proto", |
| 118 | + typeName: StatusTypeNameStatus, |
| 119 | + serialized: []byte("not a valid proto"), |
| 120 | + expectParseErr: true, |
| 121 | + }} |
| 122 | + |
| 123 | + for _, tc := range testCases { |
| 124 | + t.Run(tc.name, func(tt *testing.T) { |
| 125 | + resultErr, parseErr := parseErrorReplacement( |
| 126 | + tc.typeName, tc.serialized, |
| 127 | + ) |
| 128 | + |
| 129 | + if tc.expectParseErr { |
| 130 | + require.Error(tt, parseErr) |
| 131 | + return |
| 132 | + } |
| 133 | + |
| 134 | + require.NoError(tt, parseErr) |
| 135 | + require.Equal(tt, tc.expectedErrMsg, resultErr.Error()) |
| 136 | + }) |
| 137 | + } |
| 138 | +} |
| 139 | + |
| 140 | +// TestParseErrorReplacementWithStatus tests that parseErrorReplacement |
| 141 | +// correctly handles gRPC status errors with proper error codes. |
| 142 | +func TestParseErrorReplacementWithStatus(t *testing.T) { |
| 143 | + // Create a gRPC status error with a specific code and message. |
| 144 | + st := status.New(codes.NotFound, "resource not found") |
| 145 | + statusProto := st.Proto() |
| 146 | + |
| 147 | + // Serialize the status proto. |
| 148 | + serialized, err := proto.Marshal(statusProto) |
| 149 | + require.NoError(t, err) |
| 150 | + |
| 151 | + // Parse it back. |
| 152 | + resultErr, parseErr := parseErrorReplacement( |
| 153 | + StatusTypeNameStatus, serialized, |
| 154 | + ) |
| 155 | + require.NoError(t, parseErr) |
| 156 | + require.Error(t, resultErr) |
| 157 | + |
| 158 | + // Verify we can extract the status back. |
| 159 | + resultStatus, ok := status.FromError(resultErr) |
| 160 | + require.True(t, ok) |
| 161 | + require.Equal(t, codes.NotFound, resultStatus.Code()) |
| 162 | + require.Equal(t, "resource not found", resultStatus.Message()) |
| 163 | +} |
| 164 | + |
| 165 | +// TestNewMessageInterceptionRequestWithStatusError tests that |
| 166 | +// NewMessageInterceptionRequest correctly serializes gRPC status errors |
| 167 | +// as google.rpc.Status protos instead of plain error strings. |
| 168 | +func TestNewMessageInterceptionRequestWithStatusError(t *testing.T) { |
| 169 | + testCases := []struct { |
| 170 | + name string |
| 171 | + err error |
| 172 | + expectedTypeName string |
| 173 | + isStatusError bool |
| 174 | + }{{ |
| 175 | + name: "plain error", |
| 176 | + err: errors.New("this is a plain error"), |
| 177 | + expectedTypeName: StatusTypeNameError, |
| 178 | + isStatusError: false, |
| 179 | + }, { |
| 180 | + name: "gRPC status error", |
| 181 | + err: status.Error(codes.NotFound, "resource not found"), |
| 182 | + expectedTypeName: StatusTypeNameStatus, |
| 183 | + isStatusError: true, |
| 184 | + }, { |
| 185 | + name: "gRPC status error with different code", |
| 186 | + err: status.Error(codes.PermissionDenied, "access denied"), |
| 187 | + expectedTypeName: StatusTypeNameStatus, |
| 188 | + isStatusError: true, |
| 189 | + }} |
| 190 | + |
| 191 | + for _, tc := range testCases { |
| 192 | + t.Run(tc.name, func(tt *testing.T) { |
| 193 | + ctx := context.Background() |
| 194 | + req, err := NewMessageInterceptionRequest( |
| 195 | + ctx, TypeResponse, false, "/test/Method", |
| 196 | + tc.err, |
| 197 | + ) |
| 198 | + require.NoError(tt, err) |
| 199 | + require.True(tt, req.IsError) |
| 200 | + require.Equal(tt, tc.expectedTypeName, req.ProtoTypeName) |
| 201 | + |
| 202 | + if tc.isStatusError { |
| 203 | + // Verify we can parse the status back. |
| 204 | + resultErr, parseErr := parseErrorReplacement( |
| 205 | + req.ProtoTypeName, req.ProtoSerialized, |
| 206 | + ) |
| 207 | + require.NoError(tt, parseErr) |
| 208 | + |
| 209 | + // Verify the error code is preserved. |
| 210 | + resultStatus, ok := status.FromError(resultErr) |
| 211 | + require.True(tt, ok) |
| 212 | + |
| 213 | + originalStatus, _ := status.FromError(tc.err) |
| 214 | + require.Equal( |
| 215 | + tt, originalStatus.Code(), |
| 216 | + resultStatus.Code(), |
| 217 | + ) |
| 218 | + require.Equal( |
| 219 | + tt, originalStatus.Message(), |
| 220 | + resultStatus.Message(), |
| 221 | + ) |
| 222 | + } |
| 223 | + }) |
| 224 | + } |
| 225 | +} |
0 commit comments