Skip to content

Commit d1e1289

Browse files
committed
Add recursive callgraph traversal for struct field functions
1 parent d6d87c7 commit d1e1289

File tree

3 files changed

+184
-8
lines changed

3 files changed

+184
-8
lines changed

callgraphutil/graph.go

Lines changed: 95 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package callgraphutil
33
import (
44
"bytes"
55
"fmt"
6+
"go/token"
67
"go/types"
78

89
"golang.org/x/tools/go/callgraph"
@@ -48,19 +49,34 @@ func NewGraph(root *ssa.Function, srcFns ...*ssa.Function) (*callgraph.Graph, er
4849

4950
allFns := ssautil.AllFunctions(root.Prog)
5051

51-
for _, srcFn := range srcFns {
52-
// debug("adding src function %d/%d: %v\n", i+1, len(srcFns), srcFn)
52+
visited := make(map[*ssa.Function]bool)
5353

54-
err := AddFunction(g, srcFn, allFns)
55-
if err != nil {
56-
return g, fmt.Errorf("failed to add src function %v: %w", srcFn, err)
54+
var walkFn func(fn *ssa.Function) error
55+
walkFn = func(fn *ssa.Function) error {
56+
if visited[fn] {
57+
return nil
58+
}
59+
visited[fn] = true
60+
61+
if err := AddFunction(g, fn, allFns); err != nil {
62+
return fmt.Errorf("failed to add function %v: %w", fn, err)
5763
}
5864

59-
for _, block := range srcFn.DomPreorder() {
65+
for _, block := range fn.DomPreorder() {
6066
for _, instr := range block.Instrs {
61-
checkBlockInstruction(root, allFns, g, srcFn, instr)
67+
if err := checkBlockInstruction(root, allFns, g, fn, instr, walkFn); err != nil {
68+
return err
69+
}
6270
}
6371
}
72+
73+
return nil
74+
}
75+
76+
for _, srcFn := range srcFns {
77+
if err := walkFn(srcFn); err != nil {
78+
return g, err
79+
}
6480
}
6581

6682
return g, nil
@@ -69,7 +85,7 @@ func NewGraph(root *ssa.Function, srcFns ...*ssa.Function) (*callgraph.Graph, er
6985
// checkBlockInstruction checks the given instruction for any function calls, adding
7086
// edges to the call graph as needed and recursively adding any new functions to the graph
7187
// that are discovered during the process (typically via interface methods).
72-
func checkBlockInstruction(root *ssa.Function, allFns map[*ssa.Function]bool, g *callgraph.Graph, fn *ssa.Function, instr ssa.Instruction) error {
88+
func checkBlockInstruction(root *ssa.Function, allFns map[*ssa.Function]bool, g *callgraph.Graph, fn *ssa.Function, instr ssa.Instruction, walkFn func(*ssa.Function) error) error {
7389
// debug("\tcheckBlockInstruction: %v\n", instr)
7490
switch instrt := instr.(type) {
7591
case *ssa.Call:
@@ -159,6 +175,15 @@ func checkBlockInstruction(root *ssa.Function, allFns map[*ssa.Function]bool, g
159175
case *ssa.Function:
160176
instrCall = calltFn
161177
}
178+
case *ssa.UnOp:
179+
if callt.Op == token.MUL {
180+
switch fa := callt.X.(type) {
181+
case *ssa.FieldAddr:
182+
instrCall = findFunctionInField(fa, allFns)
183+
case *ssa.Field:
184+
instrCall = findFunctionInFieldValue(fa, allFns)
185+
}
186+
}
162187
case *ssa.Parameter:
163188
// This is likely a method call, so we need to
164189
// get the function from the method receiver which
@@ -204,6 +229,10 @@ func checkBlockInstruction(root *ssa.Function, allFns map[*ssa.Function]bool, g
204229
return fmt.Errorf("failed to add function %v from block instr: %w", instrCall, err)
205230
}
206231

232+
if err := walkFn(instrCall); err != nil {
233+
return err
234+
}
235+
207236
// attempt to link function arguments that are functions
208237
for a := 0; a < len(instrt.Call.Args); a++ {
209238
arg := instrt.Call.Args[a]
@@ -306,3 +335,61 @@ func AddFunction(cg *callgraph.Graph, target *ssa.Function, allFns map[*ssa.Func
306335

307336
return nil
308337
}
338+
339+
// findFunctionInField scans all functions for assignments to the provided
340+
// struct field address and returns the first discovered function value.
341+
func findFunctionInField(fieldAddr *ssa.FieldAddr, allFns map[*ssa.Function]bool) *ssa.Function {
342+
idx := fieldAddr.Field
343+
structType := fieldAddr.X.Type()
344+
345+
for fn := range allFns {
346+
for _, blk := range fn.Blocks {
347+
for _, ins := range blk.Instrs {
348+
if store, ok := ins.(*ssa.Store); ok {
349+
if fa, ok := store.Addr.(*ssa.FieldAddr); ok {
350+
if fa.Field == idx && types.Identical(fa.X.Type(), structType) {
351+
switch v := store.Val.(type) {
352+
case *ssa.Function:
353+
return v
354+
case *ssa.MakeClosure:
355+
if f, ok := v.Fn.(*ssa.Function); ok {
356+
return f
357+
}
358+
}
359+
}
360+
}
361+
}
362+
}
363+
}
364+
}
365+
return nil
366+
}
367+
368+
// findFunctionInFieldValue searches for function assignments made to the struct
369+
// field represented by the given Field value.
370+
func findFunctionInFieldValue(field *ssa.Field, allFns map[*ssa.Function]bool) *ssa.Function {
371+
idx := field.Field
372+
structType := field.X.Type()
373+
374+
for fn := range allFns {
375+
for _, blk := range fn.Blocks {
376+
for _, ins := range blk.Instrs {
377+
if store, ok := ins.(*ssa.Store); ok {
378+
if fa, ok := store.Addr.(*ssa.FieldAddr); ok {
379+
if fa.Field == idx && types.Identical(fa.X.Type(), structType) {
380+
switch v := store.Val.(type) {
381+
case *ssa.Function:
382+
return v
383+
case *ssa.MakeClosure:
384+
if f, ok := v.Fn.(*ssa.Function); ok {
385+
return f
386+
}
387+
}
388+
}
389+
}
390+
}
391+
}
392+
}
393+
}
394+
return nil
395+
}

callgraphutil/struct_field_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package callgraphutil_test
2+
3+
import (
4+
"context"
5+
"path/filepath"
6+
"testing"
7+
8+
"github.com/picatz/taint/callgraphutil"
9+
)
10+
11+
func TestStructFieldCallGraph(t *testing.T) {
12+
dir, err := filepath.Abs(filepath.Join("testdata"))
13+
if err != nil {
14+
t.Fatal(err)
15+
}
16+
17+
pkgs, err := loadPackages(context.Background(), dir, "./...")
18+
if err != nil {
19+
t.Fatal(err)
20+
}
21+
22+
mainFn, srcFns, err := loadSSA(context.Background(), pkgs)
23+
if err != nil {
24+
t.Fatal(err)
25+
}
26+
27+
cg, err := loadCallGraph(context.Background(), mainFn, srcFns)
28+
if err != nil {
29+
t.Fatal(err)
30+
}
31+
32+
target := "github.com/picatz/taint/callgraphutil/testdata.doSomething"
33+
if paths := callgraphutil.PathsSearchCallTo(cg.Root, target); len(paths) == 0 {
34+
t.Fatalf("expected path to %s", target)
35+
}
36+
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"os"
6+
)
7+
8+
type command struct {
9+
name string
10+
run func(args []string) error
11+
}
12+
13+
type commands []*command
14+
15+
func (c commands) run(args []string) error {
16+
for i := 0; i < len(c); i++ {
17+
cmd := c[i]
18+
if cmd.name == args[0] {
19+
return cmd.run(args[1:])
20+
}
21+
}
22+
return fmt.Errorf("unknown command: %s", args[0])
23+
}
24+
25+
type cli struct {
26+
commands commands
27+
}
28+
29+
func (c *cli) run(args []string) error {
30+
return c.commands.run(args)
31+
}
32+
33+
func doSomething() error {
34+
fmt.Println("doing something")
35+
return nil
36+
}
37+
38+
func main() {
39+
c := &cli{
40+
commands{
41+
{
42+
name: "do-something",
43+
run: func(args []string) error {
44+
return doSomething()
45+
},
46+
},
47+
},
48+
}
49+
50+
if err := c.run(os.Args[1:]); err != nil {
51+
panic(err)
52+
}
53+
}

0 commit comments

Comments
 (0)