77 "strings"
88
99 "github.com/gobuffalo/flect"
10+ "github.com/gobuffalo/nulls"
1011 "github.com/gobuffalo/pop/v5/internal/defaults"
1112 "github.com/gobuffalo/pop/v5/logging"
1213 "github.com/jmoiron/sqlx"
@@ -153,6 +154,20 @@ func (ami *AssociationMetaInfo) getDBFieldTaggedWith(value string) *reflectx.Fie
153154 return nil
154155}
155156
157+ func (ami * AssociationMetaInfo ) targetPrimaryID () string {
158+ pid := ami .Field .Tag .Get ("primary_id" )
159+ switch {
160+ case pid == "" :
161+ return "id"
162+ case ami .getDBFieldTaggedWith (pid ) != nil :
163+ return pid
164+ case ami .GetByPath (pid ) != nil :
165+ return ami .GetByPath (pid ).Field .Tag .Get ("db" )
166+ default :
167+ return ""
168+ }
169+ }
170+
156171func (ami * AssociationMetaInfo ) fkName () string {
157172 t := ami .Field .Type
158173 if t .Kind () == reflect .Slice || t .Kind () == reflect .Array {
@@ -224,9 +239,20 @@ func isFieldAssociation(field reflect.StructField) bool {
224239func preloadHasMany (tx * Connection , asoc * AssociationMetaInfo , mmi * ModelMetaInfo ) error {
225240 // 1) get all associations ids.
226241 // 1.1) In here I pick ids from model meta info directly.
242+ idField := asoc .getDBFieldTaggedWith (asoc .targetPrimaryID ())
227243 ids := []interface {}{}
228244 mmi .Model .iterate (func (m * Model ) error {
229- ids = append (ids , m .ID ())
245+ if idField .Path == "ID" {
246+ ids = append (ids , m .ID ())
247+ return nil
248+ }
249+
250+ v , err := m .fieldByName (idField .Path )
251+ if err != nil {
252+ return err
253+ }
254+
255+ ids = append (ids , normalizeValue (v .Interface ()))
230256 return nil
231257 })
232258
@@ -268,8 +294,8 @@ func preloadHasMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInf
268294 modelAssociationField := mmi .mapper .FieldByName (mvalue , asoc .Name )
269295 for i := 0 ; i < slice .Elem ().Len (); i ++ {
270296 asocValue := slice .Elem ().Index (i )
271- if mmi .mapper .FieldByName (mvalue , "ID" ).Interface () == mmi .mapper .FieldByName (asocValue , foreignField .Path ).Interface () ||
272- reflect .DeepEqual (mmi .mapper .FieldByName (mvalue , "ID" ), mmi .mapper .FieldByName (asocValue , foreignField .Path )) {
297+ if mmi .mapper .FieldByName (mvalue , idField . Path ).Interface () == mmi .mapper .FieldByName (asocValue , foreignField .Path ).Interface () ||
298+ reflect .DeepEqual (mmi .mapper .FieldByName (mvalue , idField . Path ), mmi .mapper .FieldByName (asocValue , foreignField .Path )) {
273299
274300 switch {
275301 case modelAssociationField .Kind () == reflect .Slice || modelAssociationField .Kind () == reflect .Array :
@@ -288,9 +314,20 @@ func preloadHasMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInf
288314
289315func preloadHasOne (tx * Connection , asoc * AssociationMetaInfo , mmi * ModelMetaInfo ) error {
290316 // 1) get all associations ids.
317+ idField := asoc .getDBFieldTaggedWith (asoc .targetPrimaryID ())
291318 ids := []interface {}{}
292319 mmi .Model .iterate (func (m * Model ) error {
293- ids = append (ids , m .ID ())
320+ if idField .Path == "ID" {
321+ ids = append (ids , m .ID ())
322+ return nil
323+ }
324+
325+ v , err := m .fieldByName (idField .Path )
326+ if err != nil {
327+ return err
328+ }
329+
330+ ids = append (ids , normalizeValue (v .Interface ()))
294331 return nil
295332 })
296333
@@ -327,13 +364,16 @@ func preloadHasOne(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo
327364 modelAssociationField := mmi .mapper .FieldByName (mvalue , asoc .Name )
328365 for i := 0 ; i < slice .Elem ().Len (); i ++ {
329366 asocValue := slice .Elem ().Index (i )
330- if mmi .mapper .FieldByName (mvalue , "ID" ).Interface () == mmi .mapper .FieldByName (asocValue , foreignField .Path ).Interface () ||
331- reflect .DeepEqual (mmi .mapper .FieldByName (mvalue , "ID" ), mmi .mapper .FieldByName (asocValue , foreignField .Path )) {
332- if modelAssociationField .Kind () == reflect .Slice || modelAssociationField .Kind () == reflect .Array {
367+ if mmi .mapper .FieldByName (mvalue , idField .Path ).Interface () == mmi .mapper .FieldByName (asocValue , foreignField .Path ).Interface () ||
368+ reflect .DeepEqual (mmi .mapper .FieldByName (mvalue , idField .Path ), mmi .mapper .FieldByName (asocValue , foreignField .Path )) {
369+ switch {
370+ case modelAssociationField .Kind () == reflect .Slice || modelAssociationField .Kind () == reflect .Array :
333371 modelAssociationField .Set (reflect .Append (modelAssociationField , asocValue ))
334- continue
372+ case modelAssociationField .Kind () == reflect .Ptr :
373+ modelAssociationField .Elem ().Set (asocValue )
374+ default :
375+ modelAssociationField .Set (asocValue )
335376 }
336- modelAssociationField .Set (asocValue )
337377 }
338378 }
339379 })
@@ -358,7 +398,7 @@ func preloadBelongsTo(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaI
358398 }
359399
360400 // 2) load all associations constraint by association fields ids.
361- fk := "id"
401+ fk := asoc . targetPrimaryID ()
362402
363403 q := tx .Q ()
364404 q .eager = false
@@ -403,9 +443,20 @@ func preloadBelongsTo(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaI
403443func preloadManyToMany (tx * Connection , asoc * AssociationMetaInfo , mmi * ModelMetaInfo ) error {
404444 // 1) get all associations ids.
405445 // 1.1) In here I pick ids from model meta info directly.
446+ idField := asoc .getDBFieldTaggedWith (asoc .targetPrimaryID ())
406447 ids := []interface {}{}
407448 mmi .Model .iterate (func (m * Model ) error {
408- ids = append (ids , m .ID ())
449+ if idField .Path == "ID" {
450+ ids = append (ids , m .ID ())
451+ return nil
452+ }
453+
454+ v , err := m .fieldByName (idField .Path )
455+ if err != nil {
456+ return err
457+ }
458+
459+ ids = append (ids , normalizeValue (v .Interface ()))
409460 return nil
410461 })
411462
@@ -419,9 +470,8 @@ func preloadManyToMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMeta
419470 modelAssociationName := mmi .Model .associationName ()
420471 assocFkName := asoc .fkName ()
421472
422- if strings .Contains (manyToManyTableName , ":" ) {
423- modelAssociationName = strings .TrimSpace (manyToManyTableName [strings .Index (manyToManyTableName , ":" )+ 1 :])
424- manyToManyTableName = strings .TrimSpace (manyToManyTableName [:strings .Index (manyToManyTableName , ":" )])
473+ if asoc .Field .Tag .Get ("primary_id" ) != "" {
474+ modelAssociationName = asoc .Field .Tag .Get ("primary_id" )
425475 }
426476
427477 if tx .TX != nil {
@@ -490,3 +540,18 @@ func preloadManyToMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMeta
490540 }
491541 return nil
492542}
543+
544+ func normalizeValue (val interface {}) interface {} {
545+ switch t := val .(type ) {
546+ case nulls.String :
547+ return t .String
548+ case nulls.Float64 :
549+ return t .Float64
550+ case nulls.Int :
551+ return t .Int
552+ case nulls.Time :
553+ return t .Time
554+ default :
555+ return t
556+ }
557+ }
0 commit comments