Skip to content

Commit b89f42d

Browse files
GT-292 Fix reusing same connection with different Authentication parameters (#452)
* Add test-case where same connection re-used with different Authentication params * Add more comments of reusing same connection with different Auhtentication parameters * Fix reusing same connection with different Authentication parameters passed via driver.NewClient
1 parent 875dd62 commit b89f42d

File tree

6 files changed

+118
-12
lines changed

6 files changed

+118
-12
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
## [master](https://github.com/arangodb/go-driver/tree/master) (N/A)
44
- Add support for `checksum` in Collections
5+
- Fix reusing same connection with different Authentication parameters passed via driver.NewClient
56

67
## [1.4.0](https://github.com/arangodb/go-driver/tree/v1.4.0) (2022-10-04)
78
- Add `hex` property to analyzer's properties

cluster/cluster.go

+13-4
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ type ServerConnectionBuilder func(endpoint string) (driver.Connection, error)
5252
// The given connections are existing connections to each of the servers.
5353
func NewConnection(config ConnectionConfig, connectionBuilder ServerConnectionBuilder, endpoints []string) (driver.Connection, error) {
5454
if connectionBuilder == nil {
55-
return nil, driver.WithStack(driver.InvalidArgumentError{Message: "Must a connection builder"})
55+
return nil, driver.WithStack(driver.InvalidArgumentError{Message: "Must provide a connection builder"})
5656
}
5757
if len(endpoints) == 0 {
5858
return nil, driver.WithStack(driver.InvalidArgumentError{Message: "Must provide at least 1 endpoint"})
@@ -285,7 +285,7 @@ func (c *clusterConnection) UpdateEndpoints(endpoints []string) error {
285285
return nil
286286
}
287287

288-
// Configure the authentication used for this connection.
288+
// SetAuthentication creates a copy of connection wrapper for given auth parameters.
289289
func (c *clusterConnection) SetAuthentication(auth driver.Authentication) (driver.Connection, error) {
290290
c.mutex.Lock()
291291
defer c.mutex.Unlock()
@@ -300,11 +300,20 @@ func (c *clusterConnection) SetAuthentication(auth driver.Authentication) (drive
300300
newServerConnections[i] = authConn
301301
}
302302

303-
// Save authentication
303+
// These two lines are not required for normal work but left for backward compatibility
304+
// of SetAuthentication method - it was returning self object
304305
c.auth = auth
305306
c.servers = newServerConnections
306307

307-
return c, nil
308+
return &clusterConnection{
309+
connectionBuilder: c.connectionBuilder,
310+
servers: c.servers,
311+
endpoints: c.endpoints,
312+
current: c.current,
313+
mutex: sync.RWMutex{},
314+
defaultTimeout: c.defaultTimeout,
315+
auth: c.auth,
316+
}, nil
308317
}
309318

310319
// Protocols returns all protocols used by this connection.

connection.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ import (
3030
velocypack "github.com/arangodb/go-velocypack"
3131
)
3232

33-
// Connection is a connenction to a database server using a specific protocol.
33+
// Connection is a connection to a database server using a specific protocol.
3434
type Connection interface {
3535
// NewRequest creates a new request with given method and path.
3636
NewRequest(method, path string) (Request, error)
@@ -47,7 +47,7 @@ type Connection interface {
4747
// UpdateEndpoints reconfigures the connection to use the given endpoints.
4848
UpdateEndpoints(endpoints []string) error
4949

50-
// Configure the authentication used for this connection.
50+
// SetAuthentication creates a copy of connection wrapper for given auth parameters.
5151
SetAuthentication(Authentication) (Connection, error)
5252

5353
// Protocols returns all protocols used by this connection.

http/authentication.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ func (c *authenticatedConnection) UpdateEndpoints(endpoints []string) error {
246246
return nil
247247
}
248248

249-
// Configure the authentication used for this connection.
249+
// SetAuthentication creates a copy of connection wrapper for given auth parameters.
250250
func (c *authenticatedConnection) SetAuthentication(auth driver.Authentication) (driver.Connection, error) {
251251
result, err := c.conn.SetAuthentication(auth)
252252
if err != nil {

http/connection.go

+6-5
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ func (c *httpConnection) UpdateEndpoints(endpoints []string) error {
389389
return nil
390390
}
391391

392-
// Configure the authentication used for this connection.
392+
// SetAuthentication creates a copy of connection wrapper for given auth parameters.
393393
func (c *httpConnection) SetAuthentication(auth driver.Authentication) (driver.Connection, error) {
394394
var httpAuth httpAuthentication
395395
switch auth.Type() {
@@ -471,8 +471,8 @@ func (h *RepeatConnection) UpdateEndpoints(endpoints []string) error {
471471
return h.conn.UpdateEndpoints(endpoints)
472472
}
473473

474-
// Configure the authentication used for this connection.
475-
// Returns ErrAuthenticationNotChanged in when the authentication is not changed.
474+
// SetAuthentication configure the authentication used for this connection.
475+
// Returns ErrAuthenticationNotChanged when the authentication is not changed.
476476
func (h *RepeatConnection) SetAuthentication(authentication driver.Authentication) (driver.Connection, error) {
477477
h.mutex.Lock()
478478
defer h.mutex.Unlock()
@@ -481,16 +481,17 @@ func (h *RepeatConnection) SetAuthentication(authentication driver.Authenticatio
481481
return h, ErrAuthenticationNotChanged
482482
}
483483

484-
_, err := h.conn.SetAuthentication(authentication)
484+
newConn, err := h.conn.SetAuthentication(authentication)
485485
if err != nil {
486486
return nil, driver.WithStack(err)
487487
}
488+
h.conn = newConn
488489
h.auth = authentication
489490

490491
return h, nil
491492
}
492493

493494
// Protocols returns all protocols used by this connection.
494-
func (h RepeatConnection) Protocols() driver.ProtocolSet {
495+
func (h *RepeatConnection) Protocols() driver.ProtocolSet {
495496
return h.conn.Protocols()
496497
}

test/client_test.go

+95
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,19 @@ package test
2525
import (
2626
"context"
2727
"crypto/tls"
28+
"fmt"
2829
"log"
2930
httplib "net/http"
3031
_ "net/http/pprof"
3132
"os"
33+
"runtime"
3234
"strconv"
3335
"strings"
3436
"sync"
3537
"testing"
3638
"time"
3739

40+
"github.com/pkg/errors"
3841
"github.com/stretchr/testify/assert"
3942
"github.com/stretchr/testify/require"
4043

@@ -500,3 +503,95 @@ func TestCreateClientHttpRepeatConnection(t *testing.T) {
500503
require.NoError(t, err)
501504
assert.Equal(t, 2, requestRepeat.counter)
502505
}
506+
507+
// TestClientConnectionReuse checks that reusing same connection with different auth parameters is possible using
508+
func TestClientConnectionReuse(t *testing.T) {
509+
if os.Getenv("TEST_CONNECTION") == "vst" {
510+
t.Skip("not possible with VST connections by design")
511+
return
512+
}
513+
514+
c := createClientFromEnv(t, true)
515+
ctx := context.Background()
516+
517+
prefix := t.Name()
518+
dbUsers := map[string]driver.CreateDatabaseUserOptions{
519+
prefix + "-db1": {UserName: prefix + "-user1", Password: "password1"},
520+
prefix + "-db2": {UserName: prefix + "-user2", Password: "password2"},
521+
}
522+
for dbName, userOptions := range dbUsers {
523+
ensureDatabase(ctx, c, dbName, &driver.CreateDatabaseOptions{
524+
Users: []driver.CreateDatabaseUserOptions{userOptions},
525+
Options: driver.CreateDatabaseDefaultOptions{},
526+
}, t)
527+
}
528+
529+
var wg sync.WaitGroup
530+
const clientsPerDB = 20
531+
startTime := time.Now()
532+
533+
const testDuration = time.Second * 10
534+
if testing.Verbose() {
535+
wg.Add(1)
536+
go func() {
537+
defer wg.Done()
538+
539+
for {
540+
stats, _ := c.Statistics(ctx)
541+
t.Logf("goroutine count: %d, server connections: %d", runtime.NumGoroutine(), stats.Client.HTTPConnections)
542+
if time.Now().Sub(startTime) > testDuration {
543+
break
544+
}
545+
time.Sleep(1 * time.Second)
546+
}
547+
}()
548+
}
549+
550+
conn := createConnection(t, false)
551+
for dbName, userOptions := range dbUsers {
552+
t.Logf("Starting %d goroutines for DB %s ...", clientsPerDB, dbName)
553+
for i := 0; i < clientsPerDB; i++ {
554+
wg.Add(1)
555+
go func(dbName string, userOptions driver.CreateDatabaseUserOptions, conn driver.Connection) {
556+
defer wg.Done()
557+
for {
558+
if time.Now().Sub(startTime) > testDuration {
559+
break
560+
}
561+
562+
// the test will pass only if checkDBAccess is using mutex
563+
err := checkDBAccess(ctx, conn, dbName, userOptions.UserName, userOptions.Password)
564+
require.NoError(t, err)
565+
566+
time.Sleep(10 * time.Millisecond)
567+
}
568+
}(dbName, userOptions, conn)
569+
}
570+
}
571+
wg.Wait()
572+
}
573+
574+
func checkDBAccess(ctx context.Context, conn driver.Connection, dbName, username, password string) error {
575+
client, err := driver.NewClient(driver.ClientConfig{
576+
Connection: conn,
577+
Authentication: driver.BasicAuthentication(username, password),
578+
})
579+
if err != nil {
580+
return err
581+
}
582+
583+
dbExists, err := client.DatabaseExists(ctx, dbName)
584+
if err != nil {
585+
return errors.Wrapf(err, "DatabaseExists failed")
586+
}
587+
if !dbExists {
588+
return fmt.Errorf("db %s must exist for any user", dbName)
589+
}
590+
591+
_, err = client.Database(ctx, dbName)
592+
if err != nil {
593+
return errors.Wrapf(err, "db %s must be accessible for user %s", dbName, username)
594+
}
595+
596+
return nil
597+
}

0 commit comments

Comments
 (0)