Skip to content

Commit

Permalink
Refact tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Jul 29, 2014
1 parent b929a08 commit 468e54f
Show file tree
Hide file tree
Showing 10 changed files with 1,461 additions and 1,557 deletions.
177 changes: 177 additions & 0 deletions callbacks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
package gorm_test

import (
"errors"

"github.com/jinzhu/gorm"

"reflect"
"testing"
)

func (s *Product) BeforeCreate() (err error) {
if s.Code == "Invalid" {
err = errors.New("invalid product")
}
s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1
return
}

func (s *Product) BeforeUpdate() (err error) {
if s.Code == "dont_update" {
err = errors.New("can't update")
}
s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1
return
}

func (s *Product) BeforeSave() (err error) {
if s.Code == "dont_save" {
err = errors.New("can't save")
}
s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1
return
}

func (s *Product) AfterFind() {
s.AfterFindCallTimes = s.AfterFindCallTimes + 1
}

func (s *Product) AfterCreate(db *gorm.DB) {
db.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1})
}

func (s *Product) AfterUpdate() {
s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1
}

func (s *Product) AfterSave() (err error) {
if s.Code == "after_save_error" {
err = errors.New("can't save")
}
s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1
return
}

func (s *Product) BeforeDelete() (err error) {
if s.Code == "dont_delete" {
err = errors.New("can't delete")
}
s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1
return
}

func (s *Product) AfterDelete() (err error) {
if s.Code == "after_delete_error" {
err = errors.New("can't delete")
}
s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1
return
}

func (s *Product) GetCallTimes() []int64 {
return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes}
}

func TestRunCallbacks(t *testing.T) {
p := Product{Code: "unique_code", Price: 100}
db.Save(&p)

if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) {
t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes())
}

db.Where("Code = ?", "unique_code").First(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) {
t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes())
}

p.Price = 200
db.Save(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) {
t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes())
}

var products []Product
db.Find(&products, "code = ?", "unique_code")
if products[0].AfterFindCallTimes != 2 {
t.Errorf("AfterFind callbacks should work with slice")
}

db.Where("Code = ?", "unique_code").First(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) {
t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes())
}

db.Delete(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) {
t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes())
}

if db.Where("Code = ?", "unique_code").First(&p).Error == nil {
t.Errorf("Can't find a deleted record")
}
}

func TestCallbacksWithErrors(t *testing.T) {
p := Product{Code: "Invalid", Price: 100}
if db.Save(&p).Error == nil {
t.Errorf("An error from before create callbacks happened when create with invalid value")
}

if db.Where("code = ?", "Invalid").First(&Product{}).Error == nil {
t.Errorf("Should not save record that have errors")
}

if db.Save(&Product{Code: "dont_save", Price: 100}).Error == nil {
t.Errorf("An error from after create callbacks happened when create with invalid value")
}

p2 := Product{Code: "update_callback", Price: 100}
db.Save(&p2)

p2.Code = "dont_update"
if db.Save(&p2).Error == nil {
t.Errorf("An error from before update callbacks happened when update with invalid value")
}

if db.Where("code = ?", "update_callback").First(&Product{}).Error != nil {
t.Errorf("Record Should not be updated due to errors happened in before update callback")
}

if db.Where("code = ?", "dont_update").First(&Product{}).Error == nil {
t.Errorf("Record Should not be updated due to errors happened in before update callback")
}

p2.Code = "dont_save"
if db.Save(&p2).Error == nil {
t.Errorf("An error from before save callbacks happened when update with invalid value")
}

p3 := Product{Code: "dont_delete", Price: 100}
db.Save(&p3)
if db.Delete(&p3).Error == nil {
t.Errorf("An error from before delete callbacks happened when delete")
}

if db.Where("Code = ?", "dont_delete").First(&p3).Error != nil {
t.Errorf("An error from before delete callbacks happened")
}

p4 := Product{Code: "after_save_error", Price: 100}
db.Save(&p4)
if err := db.First(&Product{}, "code = ?", "after_save_error").Error; err == nil {
t.Errorf("Record should be reverted if get an error in after save callback")
}

p5 := Product{Code: "after_delete_error", Price: 100}
db.Save(&p5)
if err := db.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
t.Errorf("Record should be found")
}

db.Delete(&p5)
if err := db.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
}
}
48 changes: 47 additions & 1 deletion create_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ import (
)

func TestCreate(t *testing.T) {
user := User{Name: "1", Age: 18, Birthday: time.Now(), UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}}
float := 35.03554004971999
user := User{Name: "CreateUser", Age: 18, Birthday: time.Now(), UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float}

if !db.NewRecord(user) || !db.NewRecord(&user) {
t.Error("User should be new record before create")
Expand Down Expand Up @@ -35,4 +36,49 @@ func TestCreate(t *testing.T) {
if newUser.UserNum != Num(111) {
t.Errorf("User's UserNum should be saved (custom type)")
}

if newUser.Latitude != float {
t.Errorf("Float64 should not be changed after save")
}

if user.CreatedAt.IsZero() {
t.Errorf("Should have created_at after create")
}

if newUser.CreatedAt.IsZero() {
t.Errorf("Should have created_at after create")
}

db.Model(user).Update("name", "create_user_new_name")
db.First(&user, user.Id)
if user.CreatedAt != newUser.CreatedAt {
t.Errorf("CreatedAt should not be changed after update")
}
}

func TestAnonymousScanner(t *testing.T) {
user := User{Name: "anonymous_scanner", Role: Role{Name: "admin"}}
db.Save(&user)

var user2 User
db.First(&user2, "name = ?", "anonymous_scanner")
if user2.Role.Name != "admin" {
t.Errorf("Should be able to get anonymous scanner")
}

if !user2.IsAdmin() {
t.Errorf("Should be able to get anonymous scanner")
}
}

func TestAnonymousField(t *testing.T) {
user := User{Name: "anonymous_field", Company: Company{Name: "company"}}
db.Save(&user)

var user2 User
db.First(&user2, "name = ?", "anonymous_field")
db.Model(&user2).Related(&user2.Company)
if user2.Company.Name != "company" {
t.Errorf("Should be able to get anonymous field")
}
}
Loading

0 comments on commit 468e54f

Please sign in to comment.