Skip to content

Commit e98ee4e

Browse files
authored
Add callgraphutil.WriteDOT (#27)
1 parent eebdf3f commit e98ee4e

File tree

2 files changed

+284
-0
lines changed

2 files changed

+284
-0
lines changed

callgraphutil/dot.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package callgraphutil
2+
3+
import (
4+
"bufio"
5+
"fmt"
6+
"io"
7+
"strings"
8+
9+
"golang.org/x/tools/go/callgraph"
10+
)
11+
12+
// WriteDOT writes the given callgraph.Graph to the given io.Writer in the
13+
// DOT format, which can be used to generate a visual representation of the
14+
// call graph using Graphviz.
15+
func WriteDOT(w io.Writer, g *callgraph.Graph) error {
16+
b := bufio.NewWriter(w)
17+
defer b.Flush()
18+
19+
b.WriteString("digraph callgraph {\n")
20+
b.WriteString("\tgraph [fontname=\"Helvetica\", overlap=false normalize=true];\n")
21+
b.WriteString("\tnode [fontname=\"Helvetica\" shape=box];\n")
22+
b.WriteString("\tedge [fontname=\"Helvetica\"];\n")
23+
24+
edges := []*callgraph.Edge{}
25+
26+
nodesByPkg := map[string][]*callgraph.Node{}
27+
28+
addPkgNode := func(n *callgraph.Node) {
29+
// TODO: fix this so there's not so many "shared" functions?
30+
//
31+
// It is a bit of a hack, but it works for now.
32+
var pkgPath string
33+
if n.Func.Pkg != nil {
34+
pkgPath = n.Func.Pkg.Pkg.Path()
35+
} else {
36+
pkgPath = "shared"
37+
}
38+
39+
// Check if the package already exists.
40+
if _, ok := nodesByPkg[pkgPath]; !ok {
41+
// If not, create it.
42+
nodesByPkg[pkgPath] = []*callgraph.Node{}
43+
}
44+
nodesByPkg[pkgPath] = append(nodesByPkg[pkgPath], n)
45+
}
46+
47+
// Check if root node exists, if so, write it.
48+
if g.Root != nil {
49+
b.WriteString(fmt.Sprintf("\troot = %d;\n", g.Root.ID))
50+
}
51+
52+
// Process nodes and edges.
53+
for _, n := range g.Nodes {
54+
// Add node to map of nodes by package.
55+
addPkgNode(n)
56+
57+
// Add edges
58+
edges = append(edges, n.Out...)
59+
}
60+
61+
// Write nodes by package.
62+
for pkg, nodes := range nodesByPkg {
63+
// Make the pkg name sugraph cluster friendly (remove dots, dashes, and slashes).
64+
clusterName := strings.Replace(pkg, ".", "_", -1)
65+
clusterName = strings.Replace(clusterName, "/", "_", -1)
66+
clusterName = strings.Replace(clusterName, "-", "_", -1)
67+
68+
// NOTE: even if we're using a subgraph cluster, it may not be
69+
// respected by all Graphviz layout engines. For example, the
70+
// "dot" engine will respect the cluster, but the "sfdp" engine
71+
// will not.
72+
b.WriteString(fmt.Sprintf("\tsubgraph cluster_%s {\n", clusterName))
73+
b.WriteString(fmt.Sprintf("\t\tlabel=%q;\n", pkg))
74+
for _, n := range nodes {
75+
b.WriteString(fmt.Sprintf("\t\t%d [label=%q];\n", n.ID, n.Func))
76+
}
77+
b.WriteString("\t}\n")
78+
}
79+
80+
// Write edges.
81+
for _, e := range edges {
82+
b.WriteString(fmt.Sprintf("\t%d -> %d;\n", e.Caller.ID, e.Callee.ID))
83+
}
84+
85+
b.WriteString("}\n")
86+
87+
return nil
88+
}

callgraphutil/dot_test.go

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
package callgraphutil_test
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"fmt"
7+
"go/ast"
8+
"go/parser"
9+
"go/token"
10+
"os"
11+
"path/filepath"
12+
"testing"
13+
14+
"github.com/go-git/go-git/v5"
15+
"github.com/picatz/taint/callgraphutil"
16+
"golang.org/x/tools/go/callgraph"
17+
"golang.org/x/tools/go/packages"
18+
"golang.org/x/tools/go/ssa"
19+
"golang.org/x/tools/go/ssa/ssautil"
20+
)
21+
22+
func cloneGitHubRepository(ctx context.Context, ownerName, repoName string) (string, string, error) {
23+
// Get the owner and repo part of the URL.
24+
ownerAndRepo := ownerName + "/" + repoName
25+
26+
// Get the directory path.
27+
dir := filepath.Join(os.TempDir(), "taint", "github", ownerAndRepo)
28+
29+
// Check if the directory exists.
30+
_, err := os.Stat(dir)
31+
if err == nil {
32+
// If the directory exists, we'll assume it's a valid repository,
33+
// and return the directory. Open the directory to
34+
repo, err := git.PlainOpen(dir)
35+
if err != nil {
36+
return dir, "", fmt.Errorf("%w", err)
37+
}
38+
39+
// Get the repository's HEAD.
40+
head, err := repo.Head()
41+
if err != nil {
42+
return dir, "", fmt.Errorf("%w", err)
43+
}
44+
45+
return dir, head.Hash().String(), nil
46+
}
47+
48+
// Clone the repository.
49+
repo, err := git.PlainCloneContext(ctx, dir, false, &git.CloneOptions{
50+
URL: fmt.Sprintf("https://github.com/%s", ownerAndRepo),
51+
Depth: 1,
52+
Tags: git.NoTags,
53+
SingleBranch: true,
54+
})
55+
if err != nil {
56+
return dir, "", fmt.Errorf("%w", err)
57+
}
58+
59+
// Get the repository's HEAD.
60+
head, err := repo.Head()
61+
if err != nil {
62+
return dir, "", fmt.Errorf("%w", err)
63+
}
64+
65+
return dir, head.Hash().String(), nil
66+
}
67+
68+
func loadPackages(ctx context.Context, dir, pattern string) ([]*packages.Package, error) {
69+
loadMode :=
70+
packages.NeedName |
71+
packages.NeedDeps |
72+
packages.NeedFiles |
73+
packages.NeedModule |
74+
packages.NeedTypes |
75+
packages.NeedImports |
76+
packages.NeedSyntax |
77+
packages.NeedTypesInfo
78+
// packages.NeedTypesSizes |
79+
// packages.NeedCompiledGoFiles |
80+
// packages.NeedExportFile |
81+
// packages.NeedEmbedPatterns
82+
83+
// parseMode := parser.ParseComments
84+
parseMode := parser.SkipObjectResolution
85+
86+
// patterns := []string{dir}
87+
patterns := []string{pattern}
88+
// patterns := []string{"all"}
89+
90+
pkgs, err := packages.Load(&packages.Config{
91+
Mode: loadMode,
92+
Context: ctx,
93+
Env: os.Environ(),
94+
Dir: dir,
95+
Tests: false,
96+
ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) {
97+
return parser.ParseFile(fset, filename, src, parseMode)
98+
},
99+
}, patterns...)
100+
if err != nil {
101+
return nil, err
102+
}
103+
104+
return pkgs, nil
105+
106+
}
107+
108+
func loadSSA(ctx context.Context, pkgs []*packages.Package) (mainFn *ssa.Function, srcFns []*ssa.Function, err error) {
109+
ssaBuildMode := ssa.InstantiateGenerics // ssa.SanityCheckFunctions | ssa.GlobalDebug
110+
111+
// Analyze the package.
112+
ssaProg, ssaPkgs := ssautil.Packages(pkgs, ssaBuildMode)
113+
114+
ssaProg.Build()
115+
116+
for _, pkg := range ssaPkgs {
117+
pkg.Build()
118+
}
119+
120+
mainPkgs := ssautil.MainPackages(ssaPkgs)
121+
122+
mainFn = mainPkgs[0].Members["main"].(*ssa.Function)
123+
124+
for _, pkg := range ssaPkgs {
125+
for _, fn := range pkg.Members {
126+
if fn.Object() == nil {
127+
continue
128+
}
129+
130+
if fn.Object().Name() == "_" {
131+
continue
132+
}
133+
134+
pkgFn := pkg.Func(fn.Object().Name())
135+
if pkgFn == nil {
136+
continue
137+
}
138+
139+
var addAnons func(f *ssa.Function)
140+
addAnons = func(f *ssa.Function) {
141+
srcFns = append(srcFns, f)
142+
for _, anon := range f.AnonFuncs {
143+
addAnons(anon)
144+
}
145+
}
146+
addAnons(pkgFn)
147+
}
148+
}
149+
150+
if mainFn == nil {
151+
err = fmt.Errorf("failed to find main function")
152+
return
153+
}
154+
155+
return
156+
}
157+
158+
func loadCallGraph(ctx context.Context, mainFn *ssa.Function, srcFns []*ssa.Function) (*callgraph.Graph, error) {
159+
cg, err := callgraphutil.NewGraph(mainFn, srcFns...)
160+
if err != nil {
161+
return nil, fmt.Errorf("failed to create new callgraph: %w", err)
162+
}
163+
164+
return cg, nil
165+
}
166+
167+
func TestWriteDOT(t *testing.T) {
168+
repo, _, err := cloneGitHubRepository(context.Background(), "picatz", "taint")
169+
if err != nil {
170+
t.Fatal(err)
171+
}
172+
173+
pkgs, err := loadPackages(context.Background(), repo, "./...")
174+
if err != nil {
175+
t.Fatal(err)
176+
}
177+
178+
mainFn, srcFns, err := loadSSA(context.Background(), pkgs)
179+
if err != nil {
180+
t.Fatal(err)
181+
}
182+
183+
cg, err := loadCallGraph(context.Background(), mainFn, srcFns)
184+
if err != nil {
185+
t.Fatal(err)
186+
}
187+
188+
output := &bytes.Buffer{}
189+
190+
err = callgraphutil.WriteDOT(output, cg)
191+
if err != nil {
192+
t.Fatal(err)
193+
}
194+
195+
fmt.Println(output.String())
196+
}

0 commit comments

Comments
 (0)