Skip to content

Commit 0183433

Browse files
committed
Merge pull request #245 from go-sql-driver/dial
Registration of custom dial functions
2 parents 11fe4e6 + 9d8f29c commit 0183433

File tree

4 files changed

+51
-15
lines changed

4 files changed

+51
-15
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ Changes:
1616

1717
New Features:
1818

19+
- `RegisterDial` allows the usage of a custom dial function to establish the network connection
1920
- Setting the connection collation is possible with the `collation` DSN parameter. This parameter should be preferred over the `charset` parameter
2021
- Logging of critical errors is configurable with `SetLogger`
2122
- Google CloudSQL support

appengine.go

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,5 @@ import (
1616
)
1717

1818
func init() {
19-
if dials == nil {
20-
dials = make(map[string]dialFunc)
21-
}
22-
dials["cloudsql"] = func(cfg *config) (net.Conn, error) {
23-
return cloudsql.Dial(cfg.addr)
24-
}
19+
RegisterDial("cloudsql", cloudsql.Dial)
2520
}

driver.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,21 @@ import (
2626
// In general the driver is used via the database/sql package.
2727
type MySQLDriver struct{}
2828

29-
type dialFunc func(*config) (net.Conn, error)
30-
31-
var dials map[string]dialFunc
29+
// DialFunc is a function which can be used to establish the network connection.
30+
// Custom dial functions must be registered with RegisterDial
31+
type DialFunc func(addr string) (net.Conn, error)
32+
33+
var dials map[string]DialFunc
34+
35+
// RegisterDial registers a custom dial function. It can then be used by the
36+
// network address mynet(addr), where mynet is the registered new network.
37+
// addr is passed as a parameter to the dial function.
38+
func RegisterDial(net string, dial DialFunc) {
39+
if dials == nil {
40+
dials = make(map[string]DialFunc)
41+
}
42+
dials[net] = dial
43+
}
3244

3345
// Open new Connection.
3446
// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
@@ -48,7 +60,7 @@ func (d *MySQLDriver) Open(dsn string) (driver.Conn, error) {
4860

4961
// Connect to Server
5062
if dial, ok := dials[mc.cfg.net]; ok {
51-
mc.netConn, err = dial(mc.cfg)
63+
mc.netConn, err = dial(mc.cfg.addr)
5264
} else {
5365
nd := net.Dialer{Timeout: mc.cfg.timeout}
5466
mc.netConn, err = nd.Dial(mc.cfg.net, mc.cfg.addr)

driver_test.go

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ import (
2626
)
2727

2828
var (
29+
user string
30+
pass string
31+
prot string
32+
addr string
33+
dbname string
2934
dsn string
3035
netAddr string
3136
available bool
@@ -43,17 +48,18 @@ var (
4348

4449
// See https://github.com/go-sql-driver/mysql/wiki/Testing
4550
func init() {
51+
// get environment variables
4652
env := func(key, defaultValue string) string {
4753
if value := os.Getenv(key); value != "" {
4854
return value
4955
}
5056
return defaultValue
5157
}
52-
user := env("MYSQL_TEST_USER", "root")
53-
pass := env("MYSQL_TEST_PASS", "")
54-
prot := env("MYSQL_TEST_PROT", "tcp")
55-
addr := env("MYSQL_TEST_ADDR", "localhost:3306")
56-
dbname := env("MYSQL_TEST_DBNAME", "gotest")
58+
user = env("MYSQL_TEST_USER", "root")
59+
pass = env("MYSQL_TEST_PASS", "")
60+
prot = env("MYSQL_TEST_PROT", "tcp")
61+
addr = env("MYSQL_TEST_ADDR", "localhost:3306")
62+
dbname = env("MYSQL_TEST_DBNAME", "gotest")
5763
netAddr = fmt.Sprintf("%s(%s)", prot, addr)
5864
dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&strict=true", user, pass, netAddr, dbname)
5965
c, err := net.Dial(prot, addr)
@@ -1340,3 +1346,25 @@ func TestConcurrent(t *testing.T) {
13401346
dbt.Logf("Reached %d concurrent connections\r\n", succeeded)
13411347
})
13421348
}
1349+
1350+
// Tests custom dial functions
1351+
func TestCustomDial(t *testing.T) {
1352+
if !available {
1353+
t.Skipf("MySQL-Server not running on %s", netAddr)
1354+
}
1355+
1356+
// our custom dial function which justs wraps net.Dial here
1357+
RegisterDial("mydial", func(addr string) (net.Conn, error) {
1358+
return net.Dial(prot, addr)
1359+
})
1360+
1361+
db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s&strict=true", user, pass, addr, dbname))
1362+
if err != nil {
1363+
t.Fatalf("Error connecting: %s", err.Error())
1364+
}
1365+
defer db.Close()
1366+
1367+
if _, err = db.Exec("DO 1"); err != nil {
1368+
t.Fatalf("Connection failed: %s", err.Error())
1369+
}
1370+
}

0 commit comments

Comments
 (0)