Skip to content

Commit

Permalink
allow to expect prepared statement to be closed, closes #89
Browse files Browse the repository at this point in the history
  • Loading branch information
l3pp4rd committed Sep 1, 2017
1 parent c91a7f4 commit 9d03611
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 8 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ It only asserts that argument is of `time.Time` type.

## Change Log

- **2017-09-01** - it is now possible to expect that prepared statement will be closed,
using **ExpectedPrepare.WillBeClosed**.
- **2017-02-09** - implemented support for **go1.8** features. **Rows** interface was changed to struct
but contains all methods as before and should maintain backwards compatibility. **ExpectedQuery.WillReturnRows** may now
accept multiple row sets.
Expand Down
19 changes: 14 additions & 5 deletions expectations.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,13 @@ func (e *ExpectedExec) WillReturnResult(result driver.Result) *ExpectedExec {
// Returned by *Sqlmock.ExpectPrepare.
type ExpectedPrepare struct {
commonExpectation
mock *sqlmock
sqlRegex *regexp.Regexp
statement driver.Stmt
closeErr error
delay time.Duration
mock *sqlmock
sqlRegex *regexp.Regexp
statement driver.Stmt
closeErr error
mustBeClosed bool
wasClosed bool
delay time.Duration
}

// WillReturnError allows to set an error for the expected *sql.DB.Prepare or *sql.Tx.Prepare action.
Expand All @@ -278,6 +280,13 @@ func (e *ExpectedPrepare) WillDelayFor(duration time.Duration) *ExpectedPrepare
return e
}

// WillBeClosed expects this prepared statement to
// be closed.
func (e *ExpectedPrepare) WillBeClosed() *ExpectedPrepare {
e.mustBeClosed = true
return e
}

// ExpectQuery allows to expect Query() or QueryRow() on this prepared statement.
// this method is convenient in order to prevent duplicating sql query string matching.
func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery {
Expand Down
9 changes: 8 additions & 1 deletion sqlmock.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,13 @@ func (c *sqlmock) ExpectationsWereMet() error {
if !e.fulfilled() {
return fmt.Errorf("there is a remaining expectation which was not matched: %s", e)
}

// for expected prepared statement check whether it was closed if expected
if prep, ok := e.(*ExpectedPrepare); ok {
if prep.mustBeClosed && !prep.wasClosed {
return fmt.Errorf("expected prepared statement to be closed, but it was not: %s", prep)
}
}
}
return nil
}
Expand Down Expand Up @@ -302,7 +309,7 @@ func (c *sqlmock) Prepare(query string) (driver.Stmt, error) {
}

time.Sleep(ex.delay)
return &statement{c, query, ex.closeErr}, nil
return &statement{c, ex, query}, nil
}

func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) {
Expand Down
30 changes: 30 additions & 0 deletions sqlmock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1033,3 +1033,33 @@ func TestExpectedBeginOrder(t *testing.T) {
t.Error("an error was expected when calling close, but got none")
}
}

func TestPreparedStatementCloseExpectation(t *testing.T) {
// Open new mock database
db, mock, err := New()
if err != nil {
fmt.Println("error creating mock database")
return
}
defer db.Close()

ep := mock.ExpectPrepare("INSERT INTO ORDERS").WillBeClosed()
ep.ExpectExec().WillReturnResult(NewResult(1, 1))

stmt, err := db.Prepare("INSERT INTO ORDERS(ID, STATUS) VALUES (?, ?)")
if err != nil {
t.Fatal(err)
}

if _, err := stmt.Exec(1, "Hello"); err != nil {
t.Fatal(err)
}

if err := stmt.Close(); err != nil {
t.Fatal(err)
}

if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
5 changes: 3 additions & 2 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ import (

type statement struct {
conn *sqlmock
ex *ExpectedPrepare
query string
err error
}

func (stmt *statement) Close() error {
return stmt.err
stmt.ex.wasClosed = true
return stmt.ex.closeErr
}

func (stmt *statement) NumInput() int {
Expand Down

0 comments on commit 9d03611

Please sign in to comment.