@@ -58,13 +58,11 @@ type Database interface {
5858 Update (table Table ) updateWithSet
5959 // Initiate a DELETE FROM statement
6060 DeleteFrom (table Table ) deleteWithTable
61+ }
6162
62- // Begin Start a new transaction and returning a Transaction object.
63- // the DDL operations using the returned Transaction object will
64- // regard as one time transaction.
65- // User must manually call Commit() or Rollback() to end the transaction,
66- // after that, more DDL operations or TCL will return error.
67- Begin () (Transaction , error )
63+ type txOrDB interface {
64+ QueryContext (ctx context.Context , query string , args ... interface {}) (* sql.Rows , error )
65+ ExecContext (ctx context.Context , query string , args ... interface {}) (sql.Result , error )
6866}
6967
7068var (
7472
7573type database struct {
7674 db * sql.DB
75+ tx * sql.Tx
7776 logger LoggerFunc
7877 dialect dialect
7978 retryPolicy func (error ) bool
@@ -186,6 +185,13 @@ func (d database) GetDB() *sql.DB {
186185 return d .db
187186}
188187
188+ func (d database ) getTxOrDB () txOrDB {
189+ if d .tx != nil {
190+ return d .tx
191+ }
192+ return d .db
193+ }
194+
189195func (d database ) Query (sqlString string ) (Cursor , error ) {
190196 return d .QueryContext (context .Background (), sqlString )
191197}
@@ -196,7 +202,7 @@ func (d database) QueryContext(ctx context.Context, sqlString string) (Cursor, e
196202 sqlStringWithCallerInfo := getCallerInfo (d , isRetry ) + sqlString
197203 rows , err := d .queryContextOnce (ctx , sqlStringWithCallerInfo , isRetry )
198204 if err != nil {
199- isRetry = d .retryPolicy != nil && d .retryPolicy (err )
205+ isRetry = d .tx == nil && d . retryPolicy != nil && d .retryPolicy (err )
200206 if isRetry {
201207 continue
202208 }
@@ -221,7 +227,7 @@ func (d database) queryContextOnce(ctx context.Context, sqlString string, retry
221227 interceptor := d .interceptor
222228 var rows * sql.Rows
223229 invoker := func (ctx context.Context , sql string ) (err error ) {
224- rows , err = d .GetDB ().QueryContext (ctx , sql )
230+ rows , err = d .getTxOrDB ().QueryContext (ctx , sql )
225231 return
226232 }
227233
@@ -258,7 +264,7 @@ func (d database) ExecuteContext(ctx context.Context, sqlString string) (sql.Res
258264
259265 var result sql.Result
260266 invoker := func (ctx context.Context , sql string ) (err error ) {
261- result , err = d .GetDB ().ExecContext (ctx , sql )
267+ result , err = d .getTxOrDB ().ExecContext (ctx , sql )
262268 return
263269 }
264270 var err error
0 commit comments