@@ -3,6 +3,7 @@ package callgraphutil
33import (
44 "bytes"
55 "fmt"
6+ "go/token"
67 "go/types"
78
89 "golang.org/x/tools/go/callgraph"
@@ -48,19 +49,34 @@ func NewGraph(root *ssa.Function, srcFns ...*ssa.Function) (*callgraph.Graph, er
4849
4950 allFns := ssautil .AllFunctions (root .Prog )
5051
51- for _ , srcFn := range srcFns {
52- // debug("adding src function %d/%d: %v\n", i+1, len(srcFns), srcFn)
52+ visited := make (map [* ssa.Function ]bool )
5353
54- err := AddFunction (g , srcFn , allFns )
55- if err != nil {
56- return g , fmt .Errorf ("failed to add src function %v: %w" , srcFn , err )
54+ var walkFn func (fn * ssa.Function ) error
55+ walkFn = func (fn * ssa.Function ) error {
56+ if visited [fn ] {
57+ return nil
58+ }
59+ visited [fn ] = true
60+
61+ if err := AddFunction (g , fn , allFns ); err != nil {
62+ return fmt .Errorf ("failed to add function %v: %w" , fn , err )
5763 }
5864
59- for _ , block := range srcFn .DomPreorder () {
65+ for _ , block := range fn .DomPreorder () {
6066 for _ , instr := range block .Instrs {
61- checkBlockInstruction (root , allFns , g , srcFn , instr )
67+ if err := checkBlockInstruction (root , allFns , g , fn , instr , walkFn ); err != nil {
68+ return err
69+ }
6270 }
6371 }
72+
73+ return nil
74+ }
75+
76+ for _ , srcFn := range srcFns {
77+ if err := walkFn (srcFn ); err != nil {
78+ return g , err
79+ }
6480 }
6581
6682 return g , nil
@@ -69,7 +85,7 @@ func NewGraph(root *ssa.Function, srcFns ...*ssa.Function) (*callgraph.Graph, er
6985// checkBlockInstruction checks the given instruction for any function calls, adding
7086// edges to the call graph as needed and recursively adding any new functions to the graph
7187// that are discovered during the process (typically via interface methods).
72- func checkBlockInstruction (root * ssa.Function , allFns map [* ssa.Function ]bool , g * callgraph.Graph , fn * ssa.Function , instr ssa.Instruction ) error {
88+ func checkBlockInstruction (root * ssa.Function , allFns map [* ssa.Function ]bool , g * callgraph.Graph , fn * ssa.Function , instr ssa.Instruction , walkFn func ( * ssa. Function ) error ) error {
7389 // debug("\tcheckBlockInstruction: %v\n", instr)
7490 switch instrt := instr .(type ) {
7591 case * ssa.Call :
@@ -159,6 +175,15 @@ func checkBlockInstruction(root *ssa.Function, allFns map[*ssa.Function]bool, g
159175 case * ssa.Function :
160176 instrCall = calltFn
161177 }
178+ case * ssa.UnOp :
179+ if callt .Op == token .MUL {
180+ switch fa := callt .X .(type ) {
181+ case * ssa.FieldAddr :
182+ instrCall = findFunctionInField (fa , allFns )
183+ case * ssa.Field :
184+ instrCall = findFunctionInFieldValue (fa , allFns )
185+ }
186+ }
162187 case * ssa.Parameter :
163188 // This is likely a method call, so we need to
164189 // get the function from the method receiver which
@@ -204,6 +229,10 @@ func checkBlockInstruction(root *ssa.Function, allFns map[*ssa.Function]bool, g
204229 return fmt .Errorf ("failed to add function %v from block instr: %w" , instrCall , err )
205230 }
206231
232+ if err := walkFn (instrCall ); err != nil {
233+ return err
234+ }
235+
207236 // attempt to link function arguments that are functions
208237 for a := 0 ; a < len (instrt .Call .Args ); a ++ {
209238 arg := instrt .Call .Args [a ]
@@ -306,3 +335,61 @@ func AddFunction(cg *callgraph.Graph, target *ssa.Function, allFns map[*ssa.Func
306335
307336 return nil
308337}
338+
339+ // findFunctionInField scans all functions for assignments to the provided
340+ // struct field address and returns the first discovered function value.
341+ func findFunctionInField (fieldAddr * ssa.FieldAddr , allFns map [* ssa.Function ]bool ) * ssa.Function {
342+ idx := fieldAddr .Field
343+ structType := fieldAddr .X .Type ()
344+
345+ for fn := range allFns {
346+ for _ , blk := range fn .Blocks {
347+ for _ , ins := range blk .Instrs {
348+ if store , ok := ins .(* ssa.Store ); ok {
349+ if fa , ok := store .Addr .(* ssa.FieldAddr ); ok {
350+ if fa .Field == idx && types .Identical (fa .X .Type (), structType ) {
351+ switch v := store .Val .(type ) {
352+ case * ssa.Function :
353+ return v
354+ case * ssa.MakeClosure :
355+ if f , ok := v .Fn .(* ssa.Function ); ok {
356+ return f
357+ }
358+ }
359+ }
360+ }
361+ }
362+ }
363+ }
364+ }
365+ return nil
366+ }
367+
368+ // findFunctionInFieldValue searches for function assignments made to the struct
369+ // field represented by the given Field value.
370+ func findFunctionInFieldValue (field * ssa.Field , allFns map [* ssa.Function ]bool ) * ssa.Function {
371+ idx := field .Field
372+ structType := field .X .Type ()
373+
374+ for fn := range allFns {
375+ for _ , blk := range fn .Blocks {
376+ for _ , ins := range blk .Instrs {
377+ if store , ok := ins .(* ssa.Store ); ok {
378+ if fa , ok := store .Addr .(* ssa.FieldAddr ); ok {
379+ if fa .Field == idx && types .Identical (fa .X .Type (), structType ) {
380+ switch v := store .Val .(type ) {
381+ case * ssa.Function :
382+ return v
383+ case * ssa.MakeClosure :
384+ if f , ok := v .Fn .(* ssa.Function ); ok {
385+ return f
386+ }
387+ }
388+ }
389+ }
390+ }
391+ }
392+ }
393+ }
394+ return nil
395+ }
0 commit comments