From 9e6c56d406ef49dc8ed650980446566cc59df0a6 Mon Sep 17 00:00:00 2001 From: Henry Date: Fri, 14 Dec 2018 17:06:45 +0100 Subject: [PATCH] cleanup Trigger and save passed Value after successful transition --- state_change_log.go | 10 +++- transition.go | 134 +++++++++++++++++++++++++------------------- transition_test.go | 33 ++++++----- 3 files changed, 98 insertions(+), 79 deletions(-) diff --git a/state_change_log.go b/state_change_log.go index e58f461..ae16e16 100644 --- a/state_change_log.go +++ b/state_change_log.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/jinzhu/gorm" + "github.com/pkg/errors" "github.com/qor/admin" "github.com/qor/audited" "github.com/qor/qor/resource" @@ -37,15 +38,18 @@ func GenerateReferenceKey(model interface{}, db *gorm.DB) string { } // GetStateChangeLogs get state change logs -func GetStateChangeLogs(model interface{}, db *gorm.DB) []StateChangeLog { +func GetStateChangeLogs(model interface{}, db *gorm.DB) ([]StateChangeLog, error) { var ( changelogs []StateChangeLog scope = db.NewScope(model) ) - db.Where("refer_table = ? AND refer_id = ?", scope.TableName(), GenerateReferenceKey(model, db)).Find(&changelogs) + err := db.Where("refer_table = ? AND refer_id = ?", scope.TableName(), GenerateReferenceKey(model, db)).Find(&changelogs).Error + if err != nil { + return nil, errors.Wrap(err, "GetStateChangeLogs: sql query failed") + } - return changelogs + return changelogs, nil } // ConfigureQorResource used to configure transition for qor admin diff --git a/transition.go b/transition.go index 8fd0d81..cdb829d 100644 --- a/transition.go +++ b/transition.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/jinzhu/gorm" + "github.com/pkg/errors" "github.com/qor/admin" "github.com/qor/qor/resource" "github.com/qor/roles" @@ -68,10 +69,15 @@ func (sm *StateMachine) Event(name string) *Event { } // Trigger trigger an event -func (sm *StateMachine) Trigger(name string, value Stater, tx *gorm.DB, notes ...string) error { +func (sm *StateMachine) Trigger(name string, value interface{}, tx *gorm.DB, notes ...string) error { + stater, ok := value.(Stater) + if !ok { + return fmt.Errorf("triggerd: passed value does not implement Stater. T:%T", value) + } + var ( newTx *gorm.DB - stateWas = value.GetState() + stateWas = stater.GetState() ) if tx != nil { @@ -80,81 +86,91 @@ func (sm *StateMachine) Trigger(name string, value Stater, tx *gorm.DB, notes .. if stateWas == "" { stateWas = sm.initialState - value.SetState(sm.initialState) + stater.SetState(sm.initialState) + } + + event, ok := sm.events[name] + if !ok { + return fmt.Errorf("trigger: failed to perform event %s from state %s", name, stateWas) } - if event := sm.events[name]; event != nil { - var matchedTransitions []*EventTransition - for _, transition := range event.transitions { - var validFrom = len(transition.froms) == 0 - if len(transition.froms) > 0 { - for _, from := range transition.froms { - if from == stateWas { - validFrom = true - } + var matchedTransitions []*EventTransition + for _, transition := range event.transitions { + var validFrom = len(transition.froms) == 0 + if len(transition.froms) > 0 { + for _, from := range transition.froms { + if from == stateWas { + validFrom = true } } + } - if validFrom { - matchedTransitions = append(matchedTransitions, transition) - } + if validFrom { + matchedTransitions = append(matchedTransitions, transition) } + } - if len(matchedTransitions) == 1 { - transition := matchedTransitions[0] + if n := len(matchedTransitions); n != 1 { + return fmt.Errorf("trigger: did not find one transtion but %d", n) + } - // State: exit - if state, ok := sm.states[stateWas]; ok { - for _, exit := range state.exits { - if err := exit(value, newTx); err != nil { - return err - } - } - } + transition := matchedTransitions[0] - // Transition: before - for _, before := range transition.befores { - if err := before(value, newTx); err != nil { - return err - } + // State: exit + if state, ok := sm.states[stateWas]; ok { + for _, exit := range state.exits { + if err := exit(value, newTx); err != nil { + return err } + } + } - value.SetState(transition.to) + // Transition: before + for _, before := range transition.befores { + if err := before(value, newTx); err != nil { + return err + } + } - // State: enter - if state, ok := sm.states[transition.to]; ok { - for _, enter := range state.enters { - if err := enter(value, newTx); err != nil { - value.SetState(stateWas) - return err - } - } - } + stater.SetState(transition.to) - // Transition: after - for _, after := range transition.afters { - if err := after(value, newTx); err != nil { - value.SetState(stateWas) - return err - } + // State: enter + if state, ok := sm.states[transition.to]; ok { + for _, enter := range state.enters { + if err := enter(value, newTx); err != nil { + stater.SetState(stateWas) + return err } + } + } - if newTx != nil { - scope := newTx.NewScope(value) - log := StateChangeLog{ - ReferTable: scope.TableName(), - ReferID: GenerateReferenceKey(value, tx), - From: stateWas, - To: transition.to, - Note: strings.Join(notes, ""), - } - return newTx.Save(&log).Error - } + // Transition: after + for _, after := range transition.afters { + if err := after(value, newTx); err != nil { + stater.SetState(stateWas) + return err + } + } - return nil + if newTx != nil { + if err := newTx.Save(value).Error; err != nil { + return errors.Wrap(err, "trigger: failed to save value") } + scope := newTx.NewScope(stater) + log := StateChangeLog{ + ReferTable: scope.TableName(), + ReferID: GenerateReferenceKey(stater, tx), + From: stateWas, + To: transition.to, + Note: strings.Join(notes, ""), + } + if err := newTx.Save(&log).Error; err != nil { + return errors.Wrap(err, "trigger: failed to save stateChangeLog") + } + return nil } - return fmt.Errorf("failed to perform event %s from state %s", name, stateWas) + + return nil } // State contains State information, including enter, exit hooks diff --git a/transition_test.go b/transition_test.go index c749287..47786c8 100644 --- a/transition_test.go +++ b/transition_test.go @@ -5,8 +5,6 @@ import ( "testing" "github.com/jinzhu/gorm" - _ "github.com/mattn/go-sqlite3" - "github.com/qor/qor/test/utils" "github.com/qor/transition" ) @@ -49,19 +47,9 @@ func getStateMachine() *transition.StateMachine { return orderStateMachine } -func CreateOrderAndExecuteTransition(transition *transition.StateMachine, event string, order *Order) error { - if err := db.Save(order).Error; err != nil { - return err - } - - if err := transition.Trigger(event, order, db); err != nil { - return err - } - return nil -} - func TestStateTransition(t *testing.T) { order := &Order{} + order.Address = t.Name() if err := getStateMachine().Trigger("checkout", order, db); err != nil { t.Errorf("should not raise any error when trigger event checkout") @@ -71,9 +59,12 @@ func TestStateTransition(t *testing.T) { t.Errorf("state doesn't changed to checkout") } - var stateChangeLogs = transition.GetStateChangeLogs(order, db) - if len(stateChangeLogs) != 1 { - t.Errorf("should get one state change log with GetStateChangeLogs") + stateChangeLogs, err := transition.GetStateChangeLogs(order, db) + if err != nil { + t.Fatal(err) + } + if n := len(stateChangeLogs); n != 1 { + t.Errorf("should get one state change log with GetStateChangeLogs got %d", n) } else { var stateChangeLog = stateChangeLogs[0] @@ -94,7 +85,9 @@ func TestMultipleTransitionWithOneEvent(t *testing.T) { cancellEvent.To("paid_cancelled").From("paid", "processed") unpaidOrder1 := &Order{} + unpaidOrder1.Address = t.Name() + ":unpaid1" if err := orderStateMachine.Trigger("cancel", unpaidOrder1, db); err != nil { + t.Error(err) t.Errorf("should not raise any error when trigger event cancel") } @@ -103,8 +96,10 @@ func TestMultipleTransitionWithOneEvent(t *testing.T) { } unpaidOrder2 := &Order{} + unpaidOrder2.Address = t.Name() + ":unpaid2" unpaidOrder2.State = "draft" if err := orderStateMachine.Trigger("cancel", unpaidOrder2, db); err != nil { + t.Error(err) t.Errorf("should not raise any error when trigger event cancel") } @@ -113,8 +108,10 @@ func TestMultipleTransitionWithOneEvent(t *testing.T) { } paidOrder := &Order{} + paidOrder.Address = t.Name() + ":paid" paidOrder.State = "paid" if err := orderStateMachine.Trigger("cancel", paidOrder, db); err != nil { + t.Error(err) t.Errorf("should not raise any error when trigger event cancel") } @@ -125,7 +122,7 @@ func TestMultipleTransitionWithOneEvent(t *testing.T) { func TestStateCallbacks(t *testing.T) { orderStateMachine := getStateMachine() - order := &Order{} + order := &Order{Address: t.Name()} address1 := "I'm an address should be set when enter checkout" address2 := "I'm an address should be set when exit checkout" @@ -138,6 +135,7 @@ func TestStateCallbacks(t *testing.T) { }) if err := orderStateMachine.Trigger("checkout", order, db); err != nil { + t.Error(err) t.Errorf("should not raise any error when trigger event checkout") } @@ -146,6 +144,7 @@ func TestStateCallbacks(t *testing.T) { } if err := orderStateMachine.Trigger("pay", order, db); err != nil { + t.Error(err) t.Errorf("should not raise any error when trigger event pay") }