Skip to content

Commit 9a9d5be

Browse files
committed
Fix belongs_to preload issue
1 parent 4049687 commit 9a9d5be

File tree

2 files changed

+104
-14
lines changed

2 files changed

+104
-14
lines changed

preload_associations.go

Lines changed: 79 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"strings"
88

99
"github.com/gobuffalo/flect"
10+
"github.com/gobuffalo/nulls"
1011
"github.com/gobuffalo/pop/v5/internal/defaults"
1112
"github.com/gobuffalo/pop/v5/logging"
1213
"github.com/jmoiron/sqlx"
@@ -153,6 +154,20 @@ func (ami *AssociationMetaInfo) getDBFieldTaggedWith(value string) *reflectx.Fie
153154
return nil
154155
}
155156

157+
func (ami *AssociationMetaInfo) targetPrimaryID() string {
158+
pid := ami.Field.Tag.Get("primary_id")
159+
switch {
160+
case pid == "":
161+
return "id"
162+
case ami.getDBFieldTaggedWith(pid) != nil:
163+
return pid
164+
case ami.GetByPath(pid) != nil:
165+
return ami.GetByPath(pid).Field.Tag.Get("db")
166+
default:
167+
return ""
168+
}
169+
}
170+
156171
func (ami *AssociationMetaInfo) fkName() string {
157172
t := ami.Field.Type
158173
if t.Kind() == reflect.Slice || t.Kind() == reflect.Array {
@@ -224,9 +239,20 @@ func isFieldAssociation(field reflect.StructField) bool {
224239
func preloadHasMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo) error {
225240
// 1) get all associations ids.
226241
// 1.1) In here I pick ids from model meta info directly.
242+
idField := asoc.getDBFieldTaggedWith(asoc.targetPrimaryID())
227243
ids := []interface{}{}
228244
mmi.Model.iterate(func(m *Model) error {
229-
ids = append(ids, m.ID())
245+
if idField.Path == "ID" {
246+
ids = append(ids, m.ID())
247+
return nil
248+
}
249+
250+
v, err := m.fieldByName(idField.Path)
251+
if err != nil {
252+
return err
253+
}
254+
255+
ids = append(ids, normalizeValue(v.Interface()))
230256
return nil
231257
})
232258

@@ -268,8 +294,8 @@ func preloadHasMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInf
268294
modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name)
269295
for i := 0; i < slice.Elem().Len(); i++ {
270296
asocValue := slice.Elem().Index(i)
271-
if mmi.mapper.FieldByName(mvalue, "ID").Interface() == mmi.mapper.FieldByName(asocValue, foreignField.Path).Interface() ||
272-
reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, "ID"), mmi.mapper.FieldByName(asocValue, foreignField.Path)) {
297+
if mmi.mapper.FieldByName(mvalue, idField.Path).Interface() == mmi.mapper.FieldByName(asocValue, foreignField.Path).Interface() ||
298+
reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, idField.Path), mmi.mapper.FieldByName(asocValue, foreignField.Path)) {
273299

274300
switch {
275301
case modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array:
@@ -288,9 +314,20 @@ func preloadHasMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInf
288314

289315
func preloadHasOne(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo) error {
290316
// 1) get all associations ids.
317+
idField := asoc.getDBFieldTaggedWith(asoc.targetPrimaryID())
291318
ids := []interface{}{}
292319
mmi.Model.iterate(func(m *Model) error {
293-
ids = append(ids, m.ID())
320+
if idField.Path == "ID" {
321+
ids = append(ids, m.ID())
322+
return nil
323+
}
324+
325+
v, err := m.fieldByName(idField.Path)
326+
if err != nil {
327+
return err
328+
}
329+
330+
ids = append(ids, normalizeValue(v.Interface()))
294331
return nil
295332
})
296333

@@ -327,13 +364,16 @@ func preloadHasOne(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo
327364
modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name)
328365
for i := 0; i < slice.Elem().Len(); i++ {
329366
asocValue := slice.Elem().Index(i)
330-
if mmi.mapper.FieldByName(mvalue, "ID").Interface() == mmi.mapper.FieldByName(asocValue, foreignField.Path).Interface() ||
331-
reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, "ID"), mmi.mapper.FieldByName(asocValue, foreignField.Path)) {
332-
if modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array {
367+
if mmi.mapper.FieldByName(mvalue, idField.Path).Interface() == mmi.mapper.FieldByName(asocValue, foreignField.Path).Interface() ||
368+
reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, idField.Path), mmi.mapper.FieldByName(asocValue, foreignField.Path)) {
369+
switch {
370+
case modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array:
333371
modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue))
334-
continue
372+
case modelAssociationField.Kind() == reflect.Ptr:
373+
modelAssociationField.Elem().Set(asocValue)
374+
default:
375+
modelAssociationField.Set(asocValue)
335376
}
336-
modelAssociationField.Set(asocValue)
337377
}
338378
}
339379
})
@@ -358,7 +398,7 @@ func preloadBelongsTo(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaI
358398
}
359399

360400
// 2) load all associations constraint by association fields ids.
361-
fk := "id"
401+
fk := asoc.targetPrimaryID()
362402

363403
q := tx.Q()
364404
q.eager = false
@@ -403,9 +443,20 @@ func preloadBelongsTo(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaI
403443
func preloadManyToMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo) error {
404444
// 1) get all associations ids.
405445
// 1.1) In here I pick ids from model meta info directly.
446+
idField := asoc.getDBFieldTaggedWith(asoc.targetPrimaryID())
406447
ids := []interface{}{}
407448
mmi.Model.iterate(func(m *Model) error {
408-
ids = append(ids, m.ID())
449+
if idField.Path == "ID" {
450+
ids = append(ids, m.ID())
451+
return nil
452+
}
453+
454+
v, err := m.fieldByName(idField.Path)
455+
if err != nil {
456+
return err
457+
}
458+
459+
ids = append(ids, normalizeValue(v.Interface()))
409460
return nil
410461
})
411462

@@ -419,9 +470,8 @@ func preloadManyToMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMeta
419470
modelAssociationName := mmi.Model.associationName()
420471
assocFkName := asoc.fkName()
421472

422-
if strings.Contains(manyToManyTableName, ":") {
423-
modelAssociationName = strings.TrimSpace(manyToManyTableName[strings.Index(manyToManyTableName, ":")+1:])
424-
manyToManyTableName = strings.TrimSpace(manyToManyTableName[:strings.Index(manyToManyTableName, ":")])
473+
if asoc.Field.Tag.Get("primary_id") != "" {
474+
modelAssociationName = asoc.Field.Tag.Get("primary_id")
425475
}
426476

427477
if tx.TX != nil {
@@ -490,3 +540,18 @@ func preloadManyToMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMeta
490540
}
491541
return nil
492542
}
543+
544+
func normalizeValue(val interface{}) interface{} {
545+
switch t := val.(type) {
546+
case nulls.String:
547+
return t.String
548+
case nulls.Float64:
549+
return t.Float64
550+
case nulls.Int:
551+
return t.Int
552+
case nulls.Time:
553+
return t.Time
554+
default:
555+
return t
556+
}
557+
}

preload_associations_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,3 +203,28 @@ func Test_New_Implementation_For_Nplus1_BelongsTo_Not_Underscore(t *testing.T) {
203203
SetEagerMode(EagerDefault)
204204
})
205205
}
206+
207+
func Test_New_Implementation_For_Nplus1_BelongsTo_Primary_ID(t *testing.T) {
208+
if PDB == nil {
209+
t.Skip("skipping integration tests")
210+
}
211+
transaction(func(tx *Connection) {
212+
a := require.New(t)
213+
user := User{Name: nulls.NewString("Mark"), UserName: "Mark"}
214+
a.NoError(tx.Create(&user))
215+
216+
a.NoError(tx.Create(&UserAttribute{
217+
UserName: "Mark",
218+
}))
219+
220+
a.NoError(tx.Create(&UserAttribute{
221+
UserName: "Mark",
222+
}))
223+
224+
userAttrs := []UserAttribute{}
225+
a.NoError(tx.EagerPreload("User").All(&userAttrs))
226+
a.Len(userAttrs, 2)
227+
a.Equal("Mark", userAttrs[0].UserName)
228+
a.Equal("Mark", userAttrs[1].UserName)
229+
})
230+
}

0 commit comments

Comments
 (0)