forked from jinzhu/gorm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
1,461 additions
and
1,557 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
package gorm_test | ||
|
||
import ( | ||
"errors" | ||
|
||
"github.com/jinzhu/gorm" | ||
|
||
"reflect" | ||
"testing" | ||
) | ||
|
||
func (s *Product) BeforeCreate() (err error) { | ||
if s.Code == "Invalid" { | ||
err = errors.New("invalid product") | ||
} | ||
s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1 | ||
return | ||
} | ||
|
||
func (s *Product) BeforeUpdate() (err error) { | ||
if s.Code == "dont_update" { | ||
err = errors.New("can't update") | ||
} | ||
s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1 | ||
return | ||
} | ||
|
||
func (s *Product) BeforeSave() (err error) { | ||
if s.Code == "dont_save" { | ||
err = errors.New("can't save") | ||
} | ||
s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1 | ||
return | ||
} | ||
|
||
func (s *Product) AfterFind() { | ||
s.AfterFindCallTimes = s.AfterFindCallTimes + 1 | ||
} | ||
|
||
func (s *Product) AfterCreate(db *gorm.DB) { | ||
db.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1}) | ||
} | ||
|
||
func (s *Product) AfterUpdate() { | ||
s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1 | ||
} | ||
|
||
func (s *Product) AfterSave() (err error) { | ||
if s.Code == "after_save_error" { | ||
err = errors.New("can't save") | ||
} | ||
s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1 | ||
return | ||
} | ||
|
||
func (s *Product) BeforeDelete() (err error) { | ||
if s.Code == "dont_delete" { | ||
err = errors.New("can't delete") | ||
} | ||
s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1 | ||
return | ||
} | ||
|
||
func (s *Product) AfterDelete() (err error) { | ||
if s.Code == "after_delete_error" { | ||
err = errors.New("can't delete") | ||
} | ||
s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1 | ||
return | ||
} | ||
|
||
func (s *Product) GetCallTimes() []int64 { | ||
return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes} | ||
} | ||
|
||
func TestRunCallbacks(t *testing.T) { | ||
p := Product{Code: "unique_code", Price: 100} | ||
db.Save(&p) | ||
|
||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) { | ||
t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes()) | ||
} | ||
|
||
db.Where("Code = ?", "unique_code").First(&p) | ||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) { | ||
t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes()) | ||
} | ||
|
||
p.Price = 200 | ||
db.Save(&p) | ||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) { | ||
t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes()) | ||
} | ||
|
||
var products []Product | ||
db.Find(&products, "code = ?", "unique_code") | ||
if products[0].AfterFindCallTimes != 2 { | ||
t.Errorf("AfterFind callbacks should work with slice") | ||
} | ||
|
||
db.Where("Code = ?", "unique_code").First(&p) | ||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) { | ||
t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes()) | ||
} | ||
|
||
db.Delete(&p) | ||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) { | ||
t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes()) | ||
} | ||
|
||
if db.Where("Code = ?", "unique_code").First(&p).Error == nil { | ||
t.Errorf("Can't find a deleted record") | ||
} | ||
} | ||
|
||
func TestCallbacksWithErrors(t *testing.T) { | ||
p := Product{Code: "Invalid", Price: 100} | ||
if db.Save(&p).Error == nil { | ||
t.Errorf("An error from before create callbacks happened when create with invalid value") | ||
} | ||
|
||
if db.Where("code = ?", "Invalid").First(&Product{}).Error == nil { | ||
t.Errorf("Should not save record that have errors") | ||
} | ||
|
||
if db.Save(&Product{Code: "dont_save", Price: 100}).Error == nil { | ||
t.Errorf("An error from after create callbacks happened when create with invalid value") | ||
} | ||
|
||
p2 := Product{Code: "update_callback", Price: 100} | ||
db.Save(&p2) | ||
|
||
p2.Code = "dont_update" | ||
if db.Save(&p2).Error == nil { | ||
t.Errorf("An error from before update callbacks happened when update with invalid value") | ||
} | ||
|
||
if db.Where("code = ?", "update_callback").First(&Product{}).Error != nil { | ||
t.Errorf("Record Should not be updated due to errors happened in before update callback") | ||
} | ||
|
||
if db.Where("code = ?", "dont_update").First(&Product{}).Error == nil { | ||
t.Errorf("Record Should not be updated due to errors happened in before update callback") | ||
} | ||
|
||
p2.Code = "dont_save" | ||
if db.Save(&p2).Error == nil { | ||
t.Errorf("An error from before save callbacks happened when update with invalid value") | ||
} | ||
|
||
p3 := Product{Code: "dont_delete", Price: 100} | ||
db.Save(&p3) | ||
if db.Delete(&p3).Error == nil { | ||
t.Errorf("An error from before delete callbacks happened when delete") | ||
} | ||
|
||
if db.Where("Code = ?", "dont_delete").First(&p3).Error != nil { | ||
t.Errorf("An error from before delete callbacks happened") | ||
} | ||
|
||
p4 := Product{Code: "after_save_error", Price: 100} | ||
db.Save(&p4) | ||
if err := db.First(&Product{}, "code = ?", "after_save_error").Error; err == nil { | ||
t.Errorf("Record should be reverted if get an error in after save callback") | ||
} | ||
|
||
p5 := Product{Code: "after_delete_error", Price: 100} | ||
db.Save(&p5) | ||
if err := db.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { | ||
t.Errorf("Record should be found") | ||
} | ||
|
||
db.Delete(&p5) | ||
if err := db.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { | ||
t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback") | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.