From 30de254a543b80fdbc64deea5ef1a1065aeeb32f Mon Sep 17 00:00:00 2001 From: Larry M Jordan Date: Mon, 25 May 2020 13:23:09 -0500 Subject: [PATCH 1/3] Fix belongs_to preload issue --- preload_associations.go | 84 +++++++++++++++++++++++++++++++----- preload_associations_test.go | 25 +++++++++++ 2 files changed, 98 insertions(+), 11 deletions(-) diff --git a/preload_associations.go b/preload_associations.go index 6e8fec108..158e616ad 100644 --- a/preload_associations.go +++ b/preload_associations.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/gobuffalo/flect" + "github.com/gobuffalo/nulls" "github.com/gobuffalo/pop/v6/internal/defaults" "github.com/gobuffalo/pop/v6/logging" "github.com/jmoiron/sqlx" @@ -154,6 +155,20 @@ func (ami *AssociationMetaInfo) getDBFieldTaggedWith(value string) *reflectx.Fie return nil } +func (ami *AssociationMetaInfo) targetPrimaryID() string { + pid := ami.Field.Tag.Get("primary_id") + switch { + case pid == "": + return "id" + case ami.getDBFieldTaggedWith(pid) != nil: + return pid + case ami.GetByPath(pid) != nil: + return ami.GetByPath(pid).Field.Tag.Get("db") + default: + return "" + } +} + func (ami *AssociationMetaInfo) fkName() string { t := ami.Field.Type if t.Kind() == reflect.Slice || t.Kind() == reflect.Array { @@ -225,9 +240,20 @@ func isFieldAssociation(field reflect.StructField) bool { func preloadHasMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo) error { // 1) get all associations ids. // 1.1) In here I pick ids from model meta info directly. + idField := asoc.getDBFieldTaggedWith(asoc.targetPrimaryID()) ids := []interface{}{} mmi.Model.iterate(func(m *Model) error { - ids = append(ids, m.ID()) + if idField.Path == "ID" { + ids = append(ids, m.ID()) + return nil + } + + v, err := m.fieldByName(idField.Path) + if err != nil { + return err + } + + ids = append(ids, normalizeValue(v.Interface())) return nil }) @@ -271,8 +297,8 @@ func preloadHasMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInf for i := 0; i < slice.Elem().Len(); i++ { asocValue := slice.Elem().Index(i) valueField := reflect.Indirect(mmi.mapper.FieldByName(asocValue, foreignField.Path)) - if mmi.mapper.FieldByName(mvalue, "ID").Interface() == valueField.Interface() || - reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, "ID"), valueField) { + if mmi.mapper.FieldByName(mvalue, idField.Path).Interface() == valueField.Interface() || + reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, idField.Path), valueField) { // IMPORTANT // // FieldByName will initialize the value. It is important that this happens AFTER @@ -297,9 +323,20 @@ func preloadHasMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInf func preloadHasOne(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo) error { // 1) get all associations ids. + idField := asoc.getDBFieldTaggedWith(asoc.targetPrimaryID()) ids := []interface{}{} mmi.Model.iterate(func(m *Model) error { - ids = append(ids, m.ID()) + if idField.Path == "ID" { + ids = append(ids, m.ID()) + return nil + } + + v, err := m.fieldByName(idField.Path) + if err != nil { + return err + } + + ids = append(ids, normalizeValue(v.Interface())) return nil }) @@ -337,8 +374,8 @@ func preloadHasOne(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo mmi.iterate(func(mvalue reflect.Value) { for i := 0; i < slice.Elem().Len(); i++ { asocValue := slice.Elem().Index(i) - if mmi.mapper.FieldByName(mvalue, "ID").Interface() == mmi.mapper.FieldByName(asocValue, foreignField.Path).Interface() || - reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, "ID"), mmi.mapper.FieldByName(asocValue, foreignField.Path)) { + if mmi.mapper.FieldByName(mvalue, idField.Path).Interface() == mmi.mapper.FieldByName(asocValue, foreignField.Path).Interface() || + reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, idField.Path), mmi.mapper.FieldByName(asocValue, foreignField.Path)) { // IMPORTANT // // FieldByName will initialize the value. It is important that this happens AFTER @@ -380,7 +417,7 @@ func preloadBelongsTo(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaI } // 2) load all associations constraint by association fields ids. - fk := "id" + fk := asoc.targetPrimaryID() q := tx.Q() q.eager = false @@ -436,9 +473,20 @@ func preloadBelongsTo(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaI func preloadManyToMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo) error { // 1) get all associations ids. // 1.1) In here I pick ids from model meta info directly. + idField := asoc.getDBFieldTaggedWith(asoc.targetPrimaryID()) ids := []interface{}{} mmi.Model.iterate(func(m *Model) error { - ids = append(ids, m.ID()) + if idField.Path == "ID" { + ids = append(ids, m.ID()) + return nil + } + + v, err := m.fieldByName(idField.Path) + if err != nil { + return err + } + + ids = append(ids, normalizeValue(v.Interface())) return nil }) @@ -452,9 +500,8 @@ func preloadManyToMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMeta modelAssociationName := mmi.Model.associationName() assocFkName := asoc.fkName() - if strings.Contains(manyToManyTableName, ":") { - modelAssociationName = strings.TrimSpace(manyToManyTableName[strings.Index(manyToManyTableName, ":")+1:]) - manyToManyTableName = strings.TrimSpace(manyToManyTableName[:strings.Index(manyToManyTableName, ":")]) + if asoc.Field.Tag.Get("primary_id") != "" { + modelAssociationName = asoc.Field.Tag.Get("primary_id") } sql := fmt.Sprintf("SELECT %s, %s FROM %s WHERE %s in (?)", modelAssociationName, assocFkName, manyToManyTableName, modelAssociationName) @@ -541,3 +588,18 @@ func isFieldNilPtr(val reflect.Value, fi *reflectx.FieldInfo) bool { fieldValue := reflectx.FieldByIndexesReadOnly(val, fi.Index) return fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() } + +func normalizeValue(val interface{}) interface{} { + switch t := val.(type) { + case nulls.String: + return t.String + case nulls.Float64: + return t.Float64 + case nulls.Int: + return t.Int + case nulls.Time: + return t.Time + default: + return t + } +} diff --git a/preload_associations_test.go b/preload_associations_test.go index 3d70bad96..3f195c593 100644 --- a/preload_associations_test.go +++ b/preload_associations_test.go @@ -373,3 +373,28 @@ func Test_New_Implementation_For_HasMany_Ptr_Field(t *testing.T) { SetEagerMode(EagerDefault) }) } + +func Test_New_Implementation_For_Nplus1_BelongsTo_Primary_ID(t *testing.T) { + if PDB == nil { + t.Skip("skipping integration tests") + } + transaction(func(tx *Connection) { + a := require.New(t) + user := User{Name: nulls.NewString("Mark"), UserName: "Mark"} + a.NoError(tx.Create(&user)) + + a.NoError(tx.Create(&UserAttribute{ + UserName: "Mark", + })) + + a.NoError(tx.Create(&UserAttribute{ + UserName: "Mark", + })) + + userAttrs := []UserAttribute{} + a.NoError(tx.EagerPreload("User").All(&userAttrs)) + a.Len(userAttrs, 2) + a.Equal("Mark", userAttrs[0].UserName) + a.Equal("Mark", userAttrs[1].UserName) + }) +} From 5d9251a105809688efc20fedd5598114b333a7d8 Mon Sep 17 00:00:00 2001 From: Larry M Jordan Date: Mon, 25 May 2020 13:38:40 -0500 Subject: [PATCH 2/3] fix test --- preload_associations.go | 3 ++- preload_associations_test.go | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/preload_associations.go b/preload_associations.go index 158e616ad..ec8c89544 100644 --- a/preload_associations.go +++ b/preload_associations.go @@ -439,6 +439,7 @@ func preloadBelongsTo(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaI } // 3) iterate over every model and fill it with the assoc. + idField := mmi.getDBFieldTaggedWith(asoc.targetPrimaryID()) mmi.iterate(func(mvalue reflect.Value) { if isFieldNilPtr(mvalue, fi) { return @@ -446,7 +447,7 @@ func preloadBelongsTo(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaI for i := 0; i < slice.Elem().Len(); i++ { asocValue := slice.Elem().Index(i) fkField := reflect.Indirect(mmi.mapper.FieldByName(mvalue, fi.Path)) - field := mmi.mapper.FieldByName(asocValue, "ID") + field := mmi.mapper.FieldByName(asocValue, idField.Path) if fkField.Interface() == field.Interface() || reflect.DeepEqual(fkField, field) { // IMPORTANT // diff --git a/preload_associations_test.go b/preload_associations_test.go index 3f195c593..35c91c9bc 100644 --- a/preload_associations_test.go +++ b/preload_associations_test.go @@ -394,7 +394,7 @@ func Test_New_Implementation_For_Nplus1_BelongsTo_Primary_ID(t *testing.T) { userAttrs := []UserAttribute{} a.NoError(tx.EagerPreload("User").All(&userAttrs)) a.Len(userAttrs, 2) - a.Equal("Mark", userAttrs[0].UserName) - a.Equal("Mark", userAttrs[1].UserName) + a.Equal("Mark", userAttrs[0].User.UserName) + a.Equal("Mark", userAttrs[1].User.UserName) }) } From a41ccf91077114bafef75cc7422a03427f6fba1e Mon Sep 17 00:00:00 2001 From: Yonghwan SO Date: Mon, 19 Sep 2022 23:16:29 +0900 Subject: [PATCH 3/3] deduplicate some code and clean up --- preload_associations.go | 93 ++++++++++++++++------------------------- 1 file changed, 35 insertions(+), 58 deletions(-) diff --git a/preload_associations.go b/preload_associations.go index ec8c89544..16a9c959d 100644 --- a/preload_associations.go +++ b/preload_associations.go @@ -169,6 +169,27 @@ func (ami *AssociationMetaInfo) targetPrimaryID() string { } } +func getAllAssociationIds(idField *reflectx.FieldInfo, mmi *ModelMetaInfo) []interface{} { + ids := []interface{}{} + + mmi.Model.iterate(func(m *Model) error { + if idField.Path == "ID" { + ids = append(ids, m.ID()) + return nil + } + + v, err := m.fieldByName(idField.Path) + if err != nil { + return err + } + + ids = append(ids, normalizeValue(v.Interface())) + return nil + }) + + return ids +} + func (ami *AssociationMetaInfo) fkName() string { t := ami.Field.Type if t.Kind() == reflect.Slice || t.Kind() == reflect.Array { @@ -241,22 +262,7 @@ func preloadHasMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInf // 1) get all associations ids. // 1.1) In here I pick ids from model meta info directly. idField := asoc.getDBFieldTaggedWith(asoc.targetPrimaryID()) - ids := []interface{}{} - mmi.Model.iterate(func(m *Model) error { - if idField.Path == "ID" { - ids = append(ids, m.ID()) - return nil - } - - v, err := m.fieldByName(idField.Path) - if err != nil { - return err - } - - ids = append(ids, normalizeValue(v.Interface())) - return nil - }) - + ids := getAllAssociationIds(idField, mmi) if len(ids) == 0 { return nil } @@ -273,6 +279,7 @@ func preloadHasMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInf slice := asoc.toSlice() + // if strings.TrimSpace(asoc.Field.Tag.Get("order_by")) != "" { q.Order(asoc.Field.Tag.Get("order_by")) } @@ -306,10 +313,10 @@ func preloadHasMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInf // // This is most likely the reason for https://github.com/gobuffalo/pop/issues/139 modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name) - switch { - case modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array: + switch modelAssociationField.Kind() { + case reflect.Slice, reflect.Array: modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue)) - case modelAssociationField.Kind() == reflect.Ptr: + case reflect.Ptr: modelAssociationField.Elem().Set(reflect.Append(modelAssociationField.Elem(), asocValue)) default: modelAssociationField.Set(asocValue) @@ -324,22 +331,7 @@ func preloadHasMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInf func preloadHasOne(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo) error { // 1) get all associations ids. idField := asoc.getDBFieldTaggedWith(asoc.targetPrimaryID()) - ids := []interface{}{} - mmi.Model.iterate(func(m *Model) error { - if idField.Path == "ID" { - ids = append(ids, m.ID()) - return nil - } - - v, err := m.fieldByName(idField.Path) - if err != nil { - return err - } - - ids = append(ids, normalizeValue(v.Interface())) - return nil - }) - + ids := getAllAssociationIds(idField, mmi) if len(ids) == 0 { return nil } @@ -383,10 +375,10 @@ func preloadHasOne(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo // // This is most likely the reason for https://github.com/gobuffalo/pop/issues/139 modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name) - switch { - case modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array: + switch modelAssociationField.Kind() { + case reflect.Slice, reflect.Array: modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue)) - case modelAssociationField.Kind() == reflect.Ptr: + case reflect.Ptr: modelAssociationField.Elem().Set(asocValue) default: modelAssociationField.Set(asocValue) @@ -456,10 +448,10 @@ func preloadBelongsTo(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaI // // This is most likely the reason for https://github.com/gobuffalo/pop/issues/139 modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name) - switch { - case modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array: + switch modelAssociationField.Kind() { + case reflect.Slice, reflect.Array: modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue)) - case modelAssociationField.Kind() == reflect.Ptr: + case reflect.Ptr: modelAssociationField.Elem().Set(asocValue) default: modelAssociationField.Set(asocValue) @@ -475,22 +467,7 @@ func preloadManyToMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMeta // 1) get all associations ids. // 1.1) In here I pick ids from model meta info directly. idField := asoc.getDBFieldTaggedWith(asoc.targetPrimaryID()) - ids := []interface{}{} - mmi.Model.iterate(func(m *Model) error { - if idField.Path == "ID" { - ids = append(ids, m.ID()) - return nil - } - - v, err := m.fieldByName(idField.Path) - if err != nil { - return err - } - - ids = append(ids, normalizeValue(v.Interface())) - return nil - }) - + ids := getAllAssociationIds(idField, mmi) if len(ids) == 0 { return nil } @@ -508,13 +485,13 @@ func preloadManyToMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMeta sql := fmt.Sprintf("SELECT %s, %s FROM %s WHERE %s in (?)", modelAssociationName, assocFkName, manyToManyTableName, modelAssociationName) sql, args, _ := sqlx.In(sql, ids) sql = tx.Dialect.TranslateSQL(sql) - log(logging.SQL, sql, args...) cn, err := tx.Store.Transaction() if err != nil { return err } + log(logging.SQL, sql, args...) rows, err := cn.Queryx(sql, args...) if err != nil { return err