Skip to content

Commit

Permalink
Merge pull request go-gorm#2721 from rubensayshi/isforeignkeyrace
Browse files Browse the repository at this point in the history
fix a race condition on IsForeignKey that is being detected by -race
  • Loading branch information
emirb authored Oct 28, 2019
2 parents 4bd5638 + 8420e32 commit 2586a05
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 2 deletions.
19 changes: 17 additions & 2 deletions model_struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
return defaultTableName
}

// lock for mutating global cached model metadata
var structsLock sync.Mutex

// global cache of model metadata
var modelStructsMap sync.Map

// ModelStruct model definition
Expand Down Expand Up @@ -419,8 +423,12 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
for idx, foreignKey := range foreignKeys {
if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil {
// source foreign keys
// mark field as foreignkey, use global lock to avoid race
structsLock.Lock()
foreignField.IsForeignKey = true
structsLock.Unlock()

// association foreign keys
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName)

Expand Down Expand Up @@ -523,8 +531,12 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
for idx, foreignKey := range foreignKeys {
if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil {
// mark field as foreignkey, use global lock to avoid race
structsLock.Lock()
foreignField.IsForeignKey = true
// source foreign keys
structsLock.Unlock()

// association foreign keys
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name)
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName)

Expand Down Expand Up @@ -582,7 +594,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
for idx, foreignKey := range foreignKeys {
if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil {
// mark field as foreignkey, use global lock to avoid race
structsLock.Lock()
foreignField.IsForeignKey = true
structsLock.Unlock()

// association foreign keys
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
Expand Down
93 changes: 93 additions & 0 deletions model_struct_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package gorm_test

import (
"sync"
"testing"

"github.com/jinzhu/gorm"
)

type ModelA struct {
gorm.Model
Name string

ModelCs []ModelC `gorm:"foreignkey:OtherAID"`
}

type ModelB struct {
gorm.Model
Name string

ModelCs []ModelC `gorm:"foreignkey:OtherBID"`
}

type ModelC struct {
gorm.Model
Name string

OtherAID uint64
OtherA *ModelA `gorm:"foreignkey:OtherAID"`
OtherBID uint64
OtherB *ModelB `gorm:"foreignkey:OtherBID"`
}

// This test will try to cause a race condition on the model's foreignkey metadata
func TestModelStructRaceSameModel(t *testing.T) {
// use a WaitGroup to execute as much in-sync as possible
// it's more likely to hit a race condition than without
n := 32
start := sync.WaitGroup{}
start.Add(n)

// use another WaitGroup to know when the test is done
done := sync.WaitGroup{}
done.Add(n)

for i := 0; i < n; i++ {
go func() {
start.Wait()

// call GetStructFields, this had a race condition before we fixed it
DB.NewScope(&ModelA{}).GetStructFields()

done.Done()
}()

start.Done()
}

done.Wait()
}

// This test will try to cause a race condition on the model's foreignkey metadata
func TestModelStructRaceDifferentModel(t *testing.T) {
// use a WaitGroup to execute as much in-sync as possible
// it's more likely to hit a race condition than without
n := 32
start := sync.WaitGroup{}
start.Add(n)

// use another WaitGroup to know when the test is done
done := sync.WaitGroup{}
done.Add(n)

for i := 0; i < n; i++ {
i := i
go func() {
start.Wait()

// call GetStructFields, this had a race condition before we fixed it
if i%2 == 0 {
DB.NewScope(&ModelA{}).GetStructFields()
} else {
DB.NewScope(&ModelB{}).GetStructFields()
}

done.Done()
}()

start.Done()
}

done.Wait()
}

0 comments on commit 2586a05

Please sign in to comment.