Skip to content

cleanup Trigger and save passed Value after successful transition #12

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions state_change_log.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"strings"

"github.com/jinzhu/gorm"
"github.com/pkg/errors"
"github.com/qor/admin"
"github.com/qor/audited"
"github.com/qor/qor/resource"
Expand Down Expand Up @@ -37,15 +38,18 @@ func GenerateReferenceKey(model interface{}, db *gorm.DB) string {
}

// GetStateChangeLogs get state change logs
func GetStateChangeLogs(model interface{}, db *gorm.DB) []StateChangeLog {
func GetStateChangeLogs(model interface{}, db *gorm.DB) ([]StateChangeLog, error) {
var (
changelogs []StateChangeLog
scope = db.NewScope(model)
)

db.Where("refer_table = ? AND refer_id = ?", scope.TableName(), GenerateReferenceKey(model, db)).Find(&changelogs)
err := db.Where("refer_table = ? AND refer_id = ?", scope.TableName(), GenerateReferenceKey(model, db)).Find(&changelogs).Error
if err != nil {
return nil, errors.Wrap(err, "GetStateChangeLogs: sql query failed")
}

return changelogs
return changelogs, nil
}

// ConfigureQorResource used to configure transition for qor admin
Expand Down
134 changes: 75 additions & 59 deletions transition.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"strings"

"github.com/jinzhu/gorm"
"github.com/pkg/errors"
"github.com/qor/admin"
"github.com/qor/qor/resource"
"github.com/qor/roles"
Expand Down Expand Up @@ -68,10 +69,15 @@ func (sm *StateMachine) Event(name string) *Event {
}

// Trigger trigger an event
func (sm *StateMachine) Trigger(name string, value Stater, tx *gorm.DB, notes ...string) error {
func (sm *StateMachine) Trigger(name string, value interface{}, tx *gorm.DB, notes ...string) error {
stater, ok := value.(Stater)
if !ok {
return fmt.Errorf("triggerd: passed value does not implement Stater. T:%T", value)
}

var (
newTx *gorm.DB
stateWas = value.GetState()
stateWas = stater.GetState()
)

if tx != nil {
Expand All @@ -80,81 +86,91 @@ func (sm *StateMachine) Trigger(name string, value Stater, tx *gorm.DB, notes ..

if stateWas == "" {
stateWas = sm.initialState
value.SetState(sm.initialState)
stater.SetState(sm.initialState)
}

event, ok := sm.events[name]
if !ok {
return fmt.Errorf("trigger: failed to perform event %s from state %s", name, stateWas)
}

if event := sm.events[name]; event != nil {
var matchedTransitions []*EventTransition
for _, transition := range event.transitions {
var validFrom = len(transition.froms) == 0
if len(transition.froms) > 0 {
for _, from := range transition.froms {
if from == stateWas {
validFrom = true
}
var matchedTransitions []*EventTransition
for _, transition := range event.transitions {
var validFrom = len(transition.froms) == 0
if len(transition.froms) > 0 {
for _, from := range transition.froms {
if from == stateWas {
validFrom = true
}
}
}

if validFrom {
matchedTransitions = append(matchedTransitions, transition)
}
if validFrom {
matchedTransitions = append(matchedTransitions, transition)
}
}

if len(matchedTransitions) == 1 {
transition := matchedTransitions[0]
if n := len(matchedTransitions); n != 1 {
return fmt.Errorf("trigger: did not find one transtion but %d", n)
}

// State: exit
if state, ok := sm.states[stateWas]; ok {
for _, exit := range state.exits {
if err := exit(value, newTx); err != nil {
return err
}
}
}
transition := matchedTransitions[0]

// Transition: before
for _, before := range transition.befores {
if err := before(value, newTx); err != nil {
return err
}
// State: exit
if state, ok := sm.states[stateWas]; ok {
for _, exit := range state.exits {
if err := exit(value, newTx); err != nil {
return err
}
}
}

value.SetState(transition.to)
// Transition: before
for _, before := range transition.befores {
if err := before(value, newTx); err != nil {
return err
}
}

// State: enter
if state, ok := sm.states[transition.to]; ok {
for _, enter := range state.enters {
if err := enter(value, newTx); err != nil {
value.SetState(stateWas)
return err
}
}
}
stater.SetState(transition.to)

// Transition: after
for _, after := range transition.afters {
if err := after(value, newTx); err != nil {
value.SetState(stateWas)
return err
}
// State: enter
if state, ok := sm.states[transition.to]; ok {
for _, enter := range state.enters {
if err := enter(value, newTx); err != nil {
stater.SetState(stateWas)
return err
}
}
}

if newTx != nil {
scope := newTx.NewScope(value)
log := StateChangeLog{
ReferTable: scope.TableName(),
ReferID: GenerateReferenceKey(value, tx),
From: stateWas,
To: transition.to,
Note: strings.Join(notes, ""),
}
return newTx.Save(&log).Error
}
// Transition: after
for _, after := range transition.afters {
if err := after(value, newTx); err != nil {
stater.SetState(stateWas)
return err
}
}

return nil
if newTx != nil {
if err := newTx.Save(value).Error; err != nil {
return errors.Wrap(err, "trigger: failed to save value")
}
scope := newTx.NewScope(stater)
log := StateChangeLog{
ReferTable: scope.TableName(),
ReferID: GenerateReferenceKey(stater, tx),
From: stateWas,
To: transition.to,
Note: strings.Join(notes, ""),
}
if err := newTx.Save(&log).Error; err != nil {
return errors.Wrap(err, "trigger: failed to save stateChangeLog")
}
return nil
}
return fmt.Errorf("failed to perform event %s from state %s", name, stateWas)

return nil
}

// State contains State information, including enter, exit hooks
Expand Down
33 changes: 16 additions & 17 deletions transition_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"testing"

"github.com/jinzhu/gorm"
_ "github.com/mattn/go-sqlite3"

"github.com/qor/qor/test/utils"
"github.com/qor/transition"
)
Expand Down Expand Up @@ -49,19 +47,9 @@ func getStateMachine() *transition.StateMachine {
return orderStateMachine
}

func CreateOrderAndExecuteTransition(transition *transition.StateMachine, event string, order *Order) error {
if err := db.Save(order).Error; err != nil {
return err
}

if err := transition.Trigger(event, order, db); err != nil {
return err
}
return nil
}

func TestStateTransition(t *testing.T) {
order := &Order{}
order.Address = t.Name()

if err := getStateMachine().Trigger("checkout", order, db); err != nil {
t.Errorf("should not raise any error when trigger event checkout")
Expand All @@ -71,9 +59,12 @@ func TestStateTransition(t *testing.T) {
t.Errorf("state doesn't changed to checkout")
}

var stateChangeLogs = transition.GetStateChangeLogs(order, db)
if len(stateChangeLogs) != 1 {
t.Errorf("should get one state change log with GetStateChangeLogs")
stateChangeLogs, err := transition.GetStateChangeLogs(order, db)
if err != nil {
t.Fatal(err)
}
if n := len(stateChangeLogs); n != 1 {
t.Errorf("should get one state change log with GetStateChangeLogs got %d", n)
} else {
var stateChangeLog = stateChangeLogs[0]

Expand All @@ -94,7 +85,9 @@ func TestMultipleTransitionWithOneEvent(t *testing.T) {
cancellEvent.To("paid_cancelled").From("paid", "processed")

unpaidOrder1 := &Order{}
unpaidOrder1.Address = t.Name() + ":unpaid1"
if err := orderStateMachine.Trigger("cancel", unpaidOrder1, db); err != nil {
t.Error(err)
t.Errorf("should not raise any error when trigger event cancel")
}

Expand All @@ -103,8 +96,10 @@ func TestMultipleTransitionWithOneEvent(t *testing.T) {
}

unpaidOrder2 := &Order{}
unpaidOrder2.Address = t.Name() + ":unpaid2"
unpaidOrder2.State = "draft"
if err := orderStateMachine.Trigger("cancel", unpaidOrder2, db); err != nil {
t.Error(err)
t.Errorf("should not raise any error when trigger event cancel")
}

Expand All @@ -113,8 +108,10 @@ func TestMultipleTransitionWithOneEvent(t *testing.T) {
}

paidOrder := &Order{}
paidOrder.Address = t.Name() + ":paid"
paidOrder.State = "paid"
if err := orderStateMachine.Trigger("cancel", paidOrder, db); err != nil {
t.Error(err)
t.Errorf("should not raise any error when trigger event cancel")
}

Expand All @@ -125,7 +122,7 @@ func TestMultipleTransitionWithOneEvent(t *testing.T) {

func TestStateCallbacks(t *testing.T) {
orderStateMachine := getStateMachine()
order := &Order{}
order := &Order{Address: t.Name()}

address1 := "I'm an address should be set when enter checkout"
address2 := "I'm an address should be set when exit checkout"
Expand All @@ -138,6 +135,7 @@ func TestStateCallbacks(t *testing.T) {
})

if err := orderStateMachine.Trigger("checkout", order, db); err != nil {
t.Error(err)
t.Errorf("should not raise any error when trigger event checkout")
}

Expand All @@ -146,6 +144,7 @@ func TestStateCallbacks(t *testing.T) {
}

if err := orderStateMachine.Trigger("pay", order, db); err != nil {
t.Error(err)
t.Errorf("should not raise any error when trigger event pay")
}

Expand Down