Skip to content

Commit

Permalink
Fix has_many associations with embedded structs (jinzhu#2)
Browse files Browse the repository at this point in the history
* Fix has_many association in embedded struct

* Remove need for association_foreignkey with embedded has_many.

* Fix more calls to getForeignField

* Rename test models for consistency
  • Loading branch information
kmcclive authored Jul 1, 2020
1 parent 7cc9c13 commit 1851ee5
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 9 deletions.
24 changes: 15 additions & 9 deletions model_struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,10 @@ func getForeignField(column string, fields []*StructField) *StructField {

// GetModelStruct get value's model struct, relationships based on struct and tag definition
func (scope *Scope) GetModelStruct() *ModelStruct {
return scope.getModelStruct(scope, make([]*StructField, 0))
}

func (scope *Scope) getModelStruct(rootScope *Scope, allFields []*StructField) *ModelStruct {
var modelStruct ModelStruct
// Scope value can't be nil
if scope.Value == nil {
Expand Down Expand Up @@ -237,7 +241,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
field.IsNormal = true
} else if _, ok := field.TagSettingsGet("EMBEDDED"); ok || fieldStruct.Anonymous {
// is embedded struct
for _, subField := range scope.New(fieldValue).GetModelStruct().StructFields {
for _, subField := range scope.New(fieldValue).getModelStruct(rootScope, allFields).StructFields {
subField = subField.clone()
subField.Names = append([]string{fieldStruct.Name}, subField.Names...)
if prefix, ok := field.TagSettingsGet("EMBEDDED_PREFIX"); ok {
Expand All @@ -261,6 +265,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
}

modelStruct.StructFields = append(modelStruct.StructFields, subField)
allFields = append(allFields, subField)
}
continue
} else {
Expand Down Expand Up @@ -394,7 +399,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
} else {
// generate foreign keys from defined association foreign keys
for _, scopeFieldName := range associationForeignKeys {
if foreignField := getForeignField(scopeFieldName, modelStruct.StructFields); foreignField != nil {
if foreignField := getForeignField(scopeFieldName, allFields); foreignField != nil {
foreignKeys = append(foreignKeys, associationType+foreignField.Name)
associationForeignKeys = append(associationForeignKeys, foreignField.Name)
}
Expand All @@ -406,13 +411,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
for _, foreignKey := range foreignKeys {
if strings.HasPrefix(foreignKey, associationType) {
associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
if foreignField := getForeignField(associationForeignKey, allFields); foreignField != nil {
associationForeignKeys = append(associationForeignKeys, associationForeignKey)
}
}
}
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
associationForeignKeys = []string{scope.PrimaryKey()}
associationForeignKeys = []string{rootScope.PrimaryKey()}
}
} else if len(foreignKeys) != len(associationForeignKeys) {
scope.Err(errors.New("invalid foreign keys, should have same length"))
Expand All @@ -422,7 +427,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {

for idx, foreignKey := range foreignKeys {
if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil {
if associationField := getForeignField(associationForeignKeys[idx], allFields); associationField != nil {
// mark field as foreignkey, use global lock to avoid race
structsLock.Lock()
foreignField.IsForeignKey = true
Expand Down Expand Up @@ -502,7 +507,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
} else {
// generate foreign keys form association foreign keys
for _, associationForeignKey := range tagAssociationForeignKeys {
if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
if foreignField := getForeignField(associationForeignKey, allFields); foreignField != nil {
foreignKeys = append(foreignKeys, associationType+foreignField.Name)
associationForeignKeys = append(associationForeignKeys, foreignField.Name)
}
Expand All @@ -514,13 +519,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
for _, foreignKey := range foreignKeys {
if strings.HasPrefix(foreignKey, associationType) {
associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
if foreignField := getForeignField(associationForeignKey, allFields); foreignField != nil {
associationForeignKeys = append(associationForeignKeys, associationForeignKey)
}
}
}
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
associationForeignKeys = []string{scope.PrimaryKey()}
associationForeignKeys = []string{rootScope.PrimaryKey()}
}
} else if len(foreignKeys) != len(associationForeignKeys) {
scope.Err(errors.New("invalid foreign keys, should have same length"))
Expand All @@ -530,7 +535,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {

for idx, foreignKey := range foreignKeys {
if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil {
if scopeField := getForeignField(associationForeignKeys[idx], allFields); scopeField != nil {
// mark field as foreignkey, use global lock to avoid race
structsLock.Lock()
foreignField.IsForeignKey = true
Expand Down Expand Up @@ -630,6 +635,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
}

modelStruct.StructFields = append(modelStruct.StructFields, field)
allFields = append(allFields, field)
}
}

Expand Down
47 changes: 47 additions & 0 deletions model_struct_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,22 @@ type ModelC struct {
OtherB *ModelB `gorm:"foreignkey:OtherBID"`
}

type RequestModel struct {
Name string
Children []ChildModel `gorm:"foreignkey:ParentID"`
}

type ChildModel struct {
ID string
ParentID string
Name string
}

type ResponseModel struct {
gorm.Model
RequestModel
}

// This test will try to cause a race condition on the model's foreignkey metadata
func TestModelStructRaceSameModel(t *testing.T) {
// use a WaitGroup to execute as much in-sync as possible
Expand Down Expand Up @@ -91,3 +107,34 @@ func TestModelStructRaceDifferentModel(t *testing.T) {

done.Wait()
}

func TestModelStructEmbeddedHasMany(t *testing.T) {
fields := DB.NewScope(&ResponseModel{}).GetStructFields()

var childrenField *gorm.StructField

for i := 0; i < len(fields); i++ {
field := fields[i]

if field != nil && field.Name == "Children" {
childrenField = field
}
}

if childrenField == nil {
t.Error("childrenField should not be nil")
return
}

if childrenField.Relationship == nil {
t.Error("childrenField.Relation should not be nil")
return
}

expected := "has_many"
actual := childrenField.Relationship.Kind

if actual != expected {
t.Errorf("childrenField.Relationship.Kind should be %v, but was %v", expected, actual)
}
}

0 comments on commit 1851ee5

Please sign in to comment.