Skip to content

Commit 0f7d182

Browse files
committed
Support for session cookies and PKCS8 private keys
If set in response headers, allow setting a cookie named "session" to enable session persistence in load balancer routing. Also: added support for PKCS8 format private keys. Also fixed a case where multiple goroutines running parallel requests to an older server might result in unsupported protocol errors.
1 parent 72e9f6b commit 0f7d182

File tree

5 files changed

+95
-37
lines changed

5 files changed

+95
-37
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@ All notable changes to this project will be documented in this file.
33

44
The format is based on [Keep a Changelog](http://keepachangelog.com/).
55

6+
## Unreleased
7+
8+
### Added
9+
- Support for session persistence. If a Set-Cookie HTTP header is present the SDK will now set a Cookie header using the requested session value.
10+
- Support for PKCS8 format private keys.
11+
612
## 1.3.0 - 2022-02-24
713

814
### Added

nosqldb/auth/iam/helpers.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,13 @@ func PrivateKeyFromBytesWithPassword(pemData, password []byte) (key *rsa.Private
4343
}
4444

4545
key, e = x509.ParsePKCS1PrivateKey(decrypted)
46+
if e != nil {
47+
e = nil
48+
parseResult, e := x509.ParsePKCS8PrivateKey(decrypted)
49+
if e == nil {
50+
key = parseResult.(*rsa.PrivateKey)
51+
}
52+
}
4653

4754
} else {
4855
e = fmt.Errorf("PEM data was not found in buffer")

nosqldb/bad_protocol_test.go

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,8 @@ func seekPos(lengths []int, fieldOff int) (off int) {
162162
}
163163

164164
func (suite *BadProtocolTestSuite) doBadProtoTest(req nosqldb.Request, data []byte, desc string, expectErrCode nosqlerr.ErrorCode) {
165-
_, err := suite.bpTestClient.DoExecute(context.Background(), req, data)
165+
serialVerUsed := suite.bpTestClient.GetSerialVersion()
166+
_, err := suite.bpTestClient.DoExecute(context.Background(), req, data, serialVerUsed)
166167
switch expectErrCode {
167168
case nosqlerr.NoError:
168169
suite.NoErrorf(err, "%q should have succeeded, got error %v", desc, err)
@@ -173,7 +174,8 @@ func (suite *BadProtocolTestSuite) doBadProtoTest(req nosqldb.Request, data []by
173174
}
174175

175176
func (suite *BadProtocolTestSuite) doBadProtoTest2(req nosqldb.Request, data []byte, desc string, expectErrCode1 nosqlerr.ErrorCode, expectErrCode2 nosqlerr.ErrorCode) {
176-
_, err := suite.bpTestClient.DoExecute(context.Background(), req, data)
177+
serialVerUsed := suite.bpTestClient.GetSerialVersion()
178+
_, err := suite.bpTestClient.DoExecute(context.Background(), req, data, serialVerUsed)
177179
suite.Truef((nosqlerr.Is(err, expectErrCode1) || nosqlerr.Is(err, expectErrCode2)),
178180
"%q failed, got error %v, want error %s or %s", desc, err, expectErrCode1, expectErrCode2)
179181
}
@@ -185,7 +187,7 @@ func (suite *BadProtocolTestSuite) TestBadGetRequest() {
185187
Key: suite.key,
186188
}
187189

188-
data, err := suite.bpTestClient.ProcessRequest(req)
190+
data, _, err := suite.bpTestClient.ProcessRequest(req)
189191
suite.Require().NoError(err)
190192
origData := make([]byte, len(data))
191193
copy(origData, data)
@@ -259,7 +261,7 @@ func (suite *BadProtocolTestSuite) TestBadGetIndexesRequest() {
259261
IndexName: suite.index,
260262
}
261263

262-
data, err := suite.bpTestClient.ProcessRequest(req)
264+
data, _, err := suite.bpTestClient.ProcessRequest(req)
263265
suite.Require().NoError(err)
264266
origData := make([]byte, len(data))
265267
copy(origData, data)
@@ -300,7 +302,7 @@ func (suite *BadProtocolTestSuite) TestBadGetTableRequest() {
300302
TableName: suite.table,
301303
}
302304

303-
data, err := suite.bpTestClient.ProcessRequest(req)
305+
data, _, err := suite.bpTestClient.ProcessRequest(req)
304306
suite.Require().NoError(err)
305307
origData := make([]byte, len(data))
306308
copy(origData, data)
@@ -336,7 +338,7 @@ func (suite *BadProtocolTestSuite) TestBadListTablesRequest() {
336338
Namespace: ns,
337339
}
338340

339-
data, err := suite.bpTestClient.ProcessRequest(req)
341+
data, _, err := suite.bpTestClient.ProcessRequest(req)
340342
suite.Require().NoError(err)
341343
origData := make([]byte, len(data))
342344
copy(origData, data)
@@ -387,7 +389,7 @@ func (suite *BadProtocolTestSuite) TestBadListTablesRequest() {
387389
func (suite *BadProtocolTestSuite) TestBadPrepareRequest() {
388390
stmt := "select * from " + suite.table
389391
req := &nosqldb.PrepareRequest{Statement: stmt}
390-
data, err := suite.bpTestClient.ProcessRequest(req)
392+
data, _, err := suite.bpTestClient.ProcessRequest(req)
391393
suite.Require().NoError(err)
392394
origData := make([]byte, len(data))
393395
copy(origData, data)
@@ -495,7 +497,7 @@ func (suite *BadProtocolTestSuite) TestBadQueryRequest() {
495497
2, // VariableValue: INT_TYPE + packed int
496498
}
497499

498-
data, err := suite.bpTestClient.ProcessRequest(req)
500+
data, _, err := suite.bpTestClient.ProcessRequest(req)
499501
suite.Require().NoError(err)
500502
origData := make([]byte, len(data))
501503
copy(origData, data)
@@ -591,7 +593,7 @@ func (suite *BadProtocolTestSuite) TestBadPutRequest() {
591593
ttlLen, // TTL: value(packed long) + unit(byte)
592594
}
593595

594-
data, err := suite.bpTestClient.ProcessRequest(req)
596+
data, _, err := suite.bpTestClient.ProcessRequest(req)
595597
suite.Require().NoError(err)
596598
origData := make([]byte, len(data))
597599
copy(origData, data)
@@ -673,7 +675,7 @@ func (suite *BadProtocolTestSuite) TestBadDeleteRequest() {
673675
0, // MatchVersion: bytes
674676
}
675677

676-
data, err := suite.bpTestClient.ProcessRequest(req)
678+
data, _, err := suite.bpTestClient.ProcessRequest(req)
677679
suite.Require().NoError(err)
678680
origData := make([]byte, len(data))
679681
copy(origData, data)
@@ -730,7 +732,7 @@ func (suite *BadProtocolTestSuite) TestBadWriteMultipleRequest() {
730732
0, // Sub requests: the size does not matter for this test.
731733
}
732734

733-
data, err := suite.bpTestClient.ProcessRequest(req)
735+
data, _, err := suite.bpTestClient.ProcessRequest(req)
734736
suite.Require().NoError(err)
735737
origData := make([]byte, len(data))
736738
copy(origData, data)
@@ -790,7 +792,7 @@ func (suite *BadProtocolTestSuite) TestBadMultiDeleteRequest() {
790792
21, // ContinuationKey: byte array
791793
}
792794

793-
data, err := suite.bpTestClient.ProcessRequest(req)
795+
data, _, err := suite.bpTestClient.ProcessRequest(req)
794796
suite.Require().NoError(err)
795797
origData := make([]byte, len(data))
796798
copy(origData, data)
@@ -866,7 +868,7 @@ func (suite *BadProtocolTestSuite) TestBadTableRequest() {
866868
1, // HasTableName: boolean
867869
}
868870

869-
data, err := suite.bpTestClient.ProcessRequest(req)
871+
data, _, err := suite.bpTestClient.ProcessRequest(req)
870872
suite.Require().NoError(err)
871873
suite.AddToTables(newTable)
872874
origData := make([]byte, len(data))
@@ -926,7 +928,7 @@ func (suite *BadProtocolTestSuite) TestBadSystemRequest() {
926928
stmtLen, // Statement: string
927929
}
928930

929-
data, err := suite.bpTestClient.ProcessRequest(req)
931+
data, _, err := suite.bpTestClient.ProcessRequest(req)
930932
suite.Require().NoError(err)
931933
origData := make([]byte, len(data))
932934
copy(origData, data)
@@ -976,7 +978,7 @@ func (suite *BadProtocolTestSuite) TestBadSystemStatusRequest() {
976978
stmtLen, // Statement: string
977979
}
978980

979-
data, err := suite.bpTestClient.ProcessRequest(req)
981+
data, _, err := suite.bpTestClient.ProcessRequest(req)
980982
suite.Require().NoError(err)
981983
origData := make([]byte, len(data))
982984
copy(origData, data)

nosqldb/client.go

Lines changed: 62 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,12 @@ type Client struct {
8686

8787
// for managing one-time messaging
8888
oneTimeMessages map[string]struct{}
89+
90+
// sessionStr represents a session cookie to use, if non-nil
91+
sessionStr string
92+
93+
// for generic locking
94+
lockMux sync.Mutex
8995
}
9096

9197
var (
@@ -97,6 +103,8 @@ var (
97103
const (
98104
// LimiterRefreshNanos is used to update table limits once every 10 minutes
99105
LimiterRefreshNanos int64 = 600 * 1000 * 1000 * 1000
106+
// SessionCookieField is used to check for persistent session cookies
107+
SessionCookieField string = "session="
100108
)
101109

102110
// NewClient creates a Client instance with the specified Config.
@@ -794,9 +802,9 @@ func (c *Client) nextRequestID() int32 {
794802
// processRequest processes the specified request before it is sent to server.
795803
// This method applies default configurations such as timeout and consistency
796804
// values for the request if they are not specified for the request.
797-
func (c *Client) processRequest(req Request) (data []byte, err error) {
805+
func (c *Client) processRequest(req Request) (data []byte, serialVerUsed int16, err error) {
798806
if req == nil {
799-
return nil, errNilRequest
807+
return nil, 0, errNilRequest
800808
}
801809

802810
// Set default values for the request with the global request configurations
@@ -806,17 +814,17 @@ func (c *Client) processRequest(req Request) (data []byte, err error) {
806814

807815
// Validates the request, returns immediately if validation fails.
808816
if err = req.validate(); err != nil {
809-
return nil, err
817+
return nil, 0, err
810818
}
811819

812-
data, err = c.serializeRequest(req)
820+
data, serialVerUsed, err = c.serializeRequest(req)
813821
if err != nil || !c.isCloud {
814822
return
815823
}
816824

817825
// check request size for cloud
818826
if err = checkRequestSizeLimit(req, len(data)); err != nil {
819-
return nil, err
827+
return nil, 0, err
820828
}
821829

822830
return
@@ -831,15 +839,15 @@ func (c *Client) execute(req Request) (Result, error) {
831839
}
832840

833841
func (c *Client) executeWithContext(ctx context.Context, req Request) (Result, error) {
834-
data, err := c.processRequest(req)
842+
data, serialVerUsed, err := c.processRequest(req)
835843
if err != nil {
836844
return nil, err
837845
}
838846

839-
return c.doExecute(ctx, req, data)
847+
return c.doExecute(ctx, req, data, serialVerUsed)
840848
}
841849

842-
func (c *Client) doExecute(ctx context.Context, req Request, data []byte) (result Result, err error) {
850+
func (c *Client) doExecute(ctx context.Context, req Request, data []byte, serialVerUsed int16) (result Result, err error) {
843851
if req == nil {
844852
return nil, errNilRequest
845853
}
@@ -962,11 +970,11 @@ func (c *Client) doExecute(ctx context.Context, req Request, data []byte) (resul
962970
}
963971

964972
if nosqlerr.Is(err, nosqlerr.UnsupportedProtocol) {
965-
if c.decrementSerialVersion() == false {
973+
if c.decrementSerialVersion(serialVerUsed) == false {
966974
return nil, err
967975
}
968976
// if serial version mismatch, we must re-serialize the request
969-
data, err = c.serializeRequest(req)
977+
data, serialVerUsed, err = c.serializeRequest(req)
970978
if err != nil {
971979
return nil, err
972980
}
@@ -1048,14 +1056,19 @@ func (c *Client) doExecute(ctx context.Context, req Request, data []byte) (resul
10481056
httpReq.Header.Set("Authorization", authStr)
10491057
}
10501058

1059+
// Allow for session persistence, if available
1060+
if c.sessionStr != "" {
1061+
httpReq.Header.Set("Cookie", c.sessionStr)
1062+
}
1063+
10511064
err = c.signHTTPRequest(httpReq)
10521065
if err != nil {
10531066
return nil, err
10541067
}
10551068

10561069
// warn if using features not implemented at the connected server
10571070
// currently cloud does not support Durability
1058-
if c.serialVersion < 3 || c.isCloud {
1071+
if serialVerUsed < 3 || c.isCloud {
10591072
needMsg := false
10601073
if pReq, ok := req.(*PutRequest); ok && pReq.Durability.IsSet() {
10611074
needMsg = true
@@ -1073,7 +1086,7 @@ func (c *Client) doExecute(ctx context.Context, req Request, data []byte) (resul
10731086
}
10741087

10751088
// OnDemand is not available in V2
1076-
if c.serialVersion < 3 {
1089+
if serialVerUsed < 3 {
10771090
if tReq, ok := req.(*TableRequest); ok && tReq.TableLimits != nil {
10781091
if tReq.TableLimits.CapacityMode == types.OnDemand {
10791092
c.oneTimeMessage("The requested feature is not supported " +
@@ -1334,17 +1347,18 @@ func (c *Client) signHTTPRequest(httpReq *http.Request) error {
13341347
// serializeRequest serializes the specified request into a slice of bytes that
13351348
// will be sent to the server. The serial version is always written followed by
13361349
// the actual request payload.
1337-
func (c *Client) serializeRequest(req Request) (data []byte, err error) {
1350+
func (c *Client) serializeRequest(req Request) (data []byte, serialVerUsed int16, err error) {
13381351
wr := binary.NewWriter()
1339-
if _, err = wr.WriteSerialVersion(c.serialVersion); err != nil {
1340-
return
1352+
serialVerUsed = c.serialVersion
1353+
if _, err = wr.WriteSerialVersion(serialVerUsed); err != nil {
1354+
return nil, 0, err
13411355
}
13421356

1343-
if err = req.serialize(wr, c.serialVersion); err != nil {
1344-
return
1357+
if err = req.serialize(wr, serialVerUsed); err != nil {
1358+
return nil, 0, err
13451359
}
13461360

1347-
return wr.Bytes(), nil
1361+
return wr.Bytes(), serialVerUsed, nil
13481362
}
13491363

13501364
// processResponse processes the http response returned from server.
@@ -1360,6 +1374,7 @@ func (c *Client) processResponse(httpResp *http.Response, req Request) (Result,
13601374
}
13611375

13621376
if httpResp.StatusCode == http.StatusOK {
1377+
c.setSessionCookie(httpResp.Header)
13631378
return c.processOKResponse(data, req)
13641379
}
13651380

@@ -1402,6 +1417,29 @@ func (c *Client) processOKResponse(data []byte, req Request) (Result, error) {
14021417
return nil, wrapResponseErrors(int(code), msg)
14031418
}
14041419

1420+
// setSessionCookie sets a persistent session cookie value to use for
1421+
// following requests, if present in the response header.
1422+
func (c *Client) setSessionCookie(header http.Header) {
1423+
if header == nil {
1424+
return
1425+
}
1426+
// NOTE: this code assumes there will always be at most
1427+
// one Set-Cookie header in the response. If the load balancer
1428+
// settings change, or the proxy changes to add Set-Cookie
1429+
// headers, this code may need to be changed to look for
1430+
// multiple Set-Cookie headers.
1431+
v := header.Get("Set-Cookie")
1432+
if strings.HasPrefix(v, SessionCookieField) == false {
1433+
return
1434+
}
1435+
c.lockMux.Lock()
1436+
defer c.lockMux.Unlock()
1437+
c.sessionStr = strings.Split(v, ";")[0]
1438+
c.logger.LogWithFn(logger.Fine, func() string {
1439+
return fmt.Sprintf("Set session cookie to \"%s\"", c.sessionStr)
1440+
})
1441+
}
1442+
14051443
// processNotOKResponse processes the http response whose status code is not 200.
14061444
func (c *Client) processNotOKResponse(data []byte, statusCode int) error {
14071445
if statusCode == http.StatusBadRequest && len(data) > 0 {
@@ -1502,7 +1540,12 @@ func (c *Client) VerifyConnection() error {
15021540
// decrementSerialVersion attempts to reduce the serial version used for
15031541
// communicating with the server. If the version is already at its lowest
15041542
// value, it will not be decremented and false will be returned.
1505-
func (c *Client) decrementSerialVersion() bool {
1543+
func (c *Client) decrementSerialVersion(serialVerUsed int16) bool {
1544+
c.lockMux.Lock()
1545+
defer c.lockMux.Unlock()
1546+
if c.serialVersion != serialVerUsed {
1547+
return true
1548+
}
15061549
if c.serialVersion > 2 {
15071550
c.serialVersion--
15081551
return true

nosqldb/export_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@ func (c *Client) SetResponseHandler(fn HandleResponse) {
2020
c.handleResponse = fn
2121
}
2222

23-
func (c *Client) ProcessRequest(req Request) (data []byte, err error) {
23+
func (c *Client) ProcessRequest(req Request) (data []byte, serialVerUsed int16, err error) {
2424
return c.processRequest(req)
2525
}
2626

27-
func (c *Client) DoExecute(ctx context.Context, req Request, data []byte) (Result, error) {
28-
return c.doExecute(ctx, req, data)
27+
func (c *Client) DoExecute(ctx context.Context, req Request, data []byte, serialVerUsed int16) (Result, error) {
28+
return c.doExecute(ctx, req, data, serialVerUsed)
2929
}
3030

3131
func (p *PreparedStatement) GetStatement() []byte {

0 commit comments

Comments
 (0)