Skip to content

Commit 50971c1

Browse files
author
ffffwh
committed
refactor base.ShowCreateTable
1 parent 334d23e commit 50971c1

File tree

3 files changed

+18
-63
lines changed

3 files changed

+18
-63
lines changed

drivers/mysql/mysql/base/utils.go

+8-17
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,10 @@ func ParseBinlogCoordinatesFromRow(row *sql.Row) (r *common.BinlogCoordinatesX,
7676

7777
// GetTableColumns reads column list from given table
7878
func GetTableColumns(db usql.QueryAble, databaseName, tableName string) (*common.ColumnList, error) {
79-
query := fmt.Sprintf(`show columns from %s.%s`,
80-
umconf.EscapeName(databaseName),
81-
umconf.EscapeName(tableName),
82-
)
79+
databaseNameEscaped := umconf.EscapeName(databaseName)
80+
tableNameEscaped := umconf.EscapeName(tableName)
81+
82+
query := fmt.Sprintf(`show columns from %s.%s`, databaseNameEscaped, tableNameEscaped)
8383
columns := []umconf.Column{}
8484
err := usql.QueryRowsMap(db, query, func(rowMap usql.RowMap) error {
8585
aColumn := umconf.Column{
@@ -97,11 +97,9 @@ func GetTableColumns(db usql.QueryAble, databaseName, tableName string) (*common
9797
return nil, err
9898
}
9999
if len(columns) == 0 {
100-
return nil, fmt.Errorf("Found 0 columns on %s.%s. Bailing out",
101-
umconf.EscapeName(databaseName),
102-
umconf.EscapeName(tableName),
103-
)
100+
return nil, fmt.Errorf("found 0 columns on %s.%s", databaseName, tableName, )
104101
}
102+
105103
return common.NewColumnList(columns), nil
106104
}
107105

@@ -129,18 +127,11 @@ func GetSomeSysVars(db usql.QueryAble, logger g.LoggerType) (r struct {
129127
return r
130128
}
131129

132-
func ShowCreateTable(db *gosql.DB, databaseName, tableName string, dropTableIfExists bool, addUse bool) (statement []string, err error) {
130+
func ShowCreateTable(db usql.QueryAble, databaseName, tableName string) (statement string, err error) {
133131
var dummy, createTableStatement string
134132
query := fmt.Sprintf(`show create table %s.%s`, umconf.EscapeName(databaseName), umconf.EscapeName(tableName))
135133
err = db.QueryRow(query).Scan(&dummy, &createTableStatement)
136-
if addUse {
137-
statement = append(statement, fmt.Sprintf("USE %s", umconf.EscapeName(databaseName)))
138-
}
139-
if dropTableIfExists {
140-
statement = append(statement, fmt.Sprintf("DROP TABLE IF EXISTS %s", umconf.EscapeName(tableName)))
141-
}
142-
statement = append(statement, createTableStatement)
143-
return statement, err
134+
return createTableStatement, err
144135
}
145136

146137
func ShowCreateView(db *gosql.DB, databaseName, tableName string, dropTableIfExists bool) (createTableStatement string, err error) {

drivers/mysql/mysql/base/utils_test.go

-36
Original file line numberDiff line numberDiff line change
@@ -204,42 +204,6 @@ func TestApplyColumnTypes(t *testing.T) {
204204
}
205205
}
206206

207-
func TestShowCreateTable(t *testing.T) {
208-
type args struct {
209-
db *gosql.DB
210-
databaseName string
211-
tableName string
212-
dropTableIfExists bool
213-
addUse bool
214-
}
215-
tests := []struct {
216-
name string
217-
args args
218-
wantCreateTableStatement string
219-
wantErr bool
220-
}{
221-
// TODO: Add test cases.
222-
}
223-
for _, tt := range tests {
224-
t.Run(tt.name, func(t *testing.T) {
225-
gotCreateTableStatement, err := ShowCreateTable(tt.args.db, tt.args.databaseName, tt.args.tableName, tt.args.dropTableIfExists, tt.args.addUse)
226-
if (err != nil) != tt.wantErr {
227-
t.Errorf("ShowCreateTable() error = %v, wantErr %v", err, tt.wantErr)
228-
return
229-
}
230-
exist := false
231-
for _, createTableStatement := range gotCreateTableStatement {
232-
if createTableStatement == tt.wantCreateTableStatement {
233-
exist = true
234-
}
235-
}
236-
if !exist {
237-
t.Errorf("ShowCreateTable() = %v, want %v", gotCreateTableStatement, tt.wantCreateTableStatement)
238-
}
239-
})
240-
}
241-
}
242-
243207
func Test_stringInterval(t *testing.T) {
244208
type args struct {
245209
intervals gomysql.IntervalSlice

drivers/mysql/mysql/extractor.go

+10-10
Original file line numberDiff line numberDiff line change
@@ -764,12 +764,11 @@ func (e *Extractor) getSchemaTablesAndMeta() error {
764764
continue
765765
}
766766

767-
stmts, err := base.ShowCreateTable(e.db, db.TableSchema, tb.TableName, false, false)
767+
stmt, err := base.ShowCreateTable(e.db, db.TableSchema, tb.TableName)
768768
if err != nil {
769769
e.logger.Error("error at ShowCreateTable.", "err", err)
770770
return err
771771
}
772-
stmt := stmts[0]
773772
ast, err := sqle.ParseCreateTableStmt("mysql", stmt)
774773
if err != nil {
775774
e.logger.Error("error at ParseCreateTableStmt.", "err", err)
@@ -1354,18 +1353,19 @@ func (e *Extractor) mysqlDump() error {
13541353
return err
13551354
}*/
13561355
} else if strings.ToLower(tb.TableSchema) != "mysql" {
1357-
tbSQL, err = base.ShowCreateTable(e.singletonDB, tb.TableSchema, tb.TableName, e.mysqlContext.DropTableIfExists, true)
1358-
for num, sql := range tbSQL {
1359-
if db.TableSchemaRename != "" && strings.Contains(sql, fmt.Sprintf("USE %s", mysqlconfig.EscapeName(tb.TableSchema))) {
1360-
tbSQL[num] = strings.Replace(sql, tb.TableSchema, db.TableSchemaRename, 1)
1361-
}
1362-
if tb.TableRename != "" && (strings.Contains(sql, fmt.Sprintf("DROP TABLE IF EXISTS %s", mysqlconfig.EscapeName(tb.TableName))) || strings.Contains(sql, "CREATE TABLE")) {
1363-
tbSQL[num] = strings.Replace(sql, mysqlconfig.EscapeName(tb.TableName), tb.TableRename, 1)
1364-
}
1356+
targetSchemaEscaped := mysqlconfig.EscapeName(g.StringElse(db.TableSchemaRename, tb.TableSchema))
1357+
targetTableEscaped := mysqlconfig.EscapeName(g.StringElse(tb.TableRename, tb.TableName))
1358+
tbSQL = append(tbSQL, fmt.Sprintf("USE %s", targetSchemaEscaped))
1359+
if e.mysqlContext.DropTableIfExists {
1360+
tbSQL = append(tbSQL, fmt.Sprintf("DROP TABLE IF EXISTS %s", targetTableEscaped))
13651361
}
1362+
ctStmt, err := base.ShowCreateTable(e.singletonDB, tb.TableSchema, tb.TableName)
13661363
if err != nil {
13671364
return err
13681365
}
1366+
// TODO do not use string replace
1367+
ctStmt = strings.Replace(ctStmt, mysqlconfig.EscapeName(tb.TableName), targetTableEscaped, 1)
1368+
tbSQL = append(tbSQL, ctStmt)
13691369
}
13701370
entry := &common.DumpEntry{
13711371
TbSQL: tbSQL,

0 commit comments

Comments
 (0)