Skip to content

Commit

Permalink
bpf2go: simplify types sorting
Browse files Browse the repository at this point in the history
bpf2go orders generated types by name to ensure the output is stable.
Adjust data structures such that we can rely on implicit sorting in
text/template, removing sortTypes alltogether.

The added reservedNames hash is handy for rejecting "reserved" type names
such as "<STEM>Specs", and for the upcoming "ergonomic" enum feature
(generate short names for enum members if not yet taken).

Signed-off-by: Nick Zavaritsky <[email protected]>
  • Loading branch information
mejedi committed Dec 20, 2024
1 parent e439d37 commit ad8c53b
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 95 deletions.
69 changes: 29 additions & 40 deletions cmd/bpf2go/gen/output.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"go/build/constraint"
"go/token"
"io"
"sort"
"strings"
"text/template"
"unicode"
Expand Down Expand Up @@ -141,21 +140,38 @@ func Generate(args GenerateArgs) error {
programs[name] = args.Identifier(name)
}

typeNames := make(map[btf.Type]string)
for _, typ := range args.Types {
// NB: This also deduplicates types.
typeNames[typ] = args.Stem + args.Identifier(typ.TypeName())
tn := templateName(args.Stem)
reservedNames := map[string]struct{}{
tn.Specs(): {},
tn.MapSpecs(): {},
tn.ProgramSpecs(): {},
tn.VariableSpecs(): {},
tn.Objects(): {},
tn.Maps(): {},
tn.Programs(): {},
tn.Variables(): {},
}

// Ensure we don't have conflicting names and generate a sorted list of
// named types so that the output is stable.
types, err := sortTypes(typeNames)
if err != nil {
return err
typeByName := map[string]btf.Type{}
nameByType := map[btf.Type]string{}
for _, typ := range args.Types {
// NB: This also deduplicates types.
name := args.Stem + args.Identifier(typ.TypeName())
if _, reserved := reservedNames[name]; reserved {
return fmt.Errorf("type name %q is reserved", name)
}
if otherType, ok := typeByName[name]; ok {
if otherType == typ {
continue
}
return fmt.Errorf("type name %q is used multiple times", name)
}
typeByName[name] = typ
nameByType[typ] = name
}

gf := &btf.GoFormatter{
Names: typeNames,
Names: nameByType,
Identifier: args.Identifier,
}

Expand All @@ -168,8 +184,7 @@ func Generate(args GenerateArgs) error {
Maps map[string]string
Variables map[string]string
Programs map[string]string
Types []btf.Type
TypeNames map[btf.Type]string
Types map[string]btf.Type
File string
}{
gf,
Expand All @@ -180,8 +195,7 @@ func Generate(args GenerateArgs) error {
maps,
variables,
programs,
types,
typeNames,
typeByName,
args.ObjectFile,
}

Expand All @@ -193,31 +207,6 @@ func Generate(args GenerateArgs) error {
return internal.WriteFormatted(buf.Bytes(), args.Output)
}

// sortTypes returns a list of types sorted by their (generated) Go type name.
//
// Duplicate Go type names are rejected.
func sortTypes(typeNames map[btf.Type]string) ([]btf.Type, error) {
var types []btf.Type
var names []string
for typ, name := range typeNames {
i := sort.SearchStrings(names, name)
if i >= len(names) {
types = append(types, typ)
names = append(names, name)
continue
}

if names[i] == name {
return nil, fmt.Errorf("type name %q is used multiple times", name)
}

types = append(types[:i], append([]btf.Type{typ}, types[i:]...)...)
names = append(names[:i], append([]string{name}, names[i:]...)...)
}

return types, nil
}

func toUpperFirst(str string) string {
first, n := utf8.DecodeRuneInString(str)
return string(unicode.ToUpper(first)) + str[n:]
Expand Down
4 changes: 2 additions & 2 deletions cmd/bpf2go/gen/output.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ import (
)

{{- if .Types }}
{{- range $type := .Types }}
{{ $.TypeDeclaration (index $.TypeNames $type) $type }}
{{- range $name, $type := .Types }}
{{ $.TypeDeclaration $name $type }}

{{ end }}
{{- end }}
Expand Down
53 changes: 0 additions & 53 deletions cmd/bpf2go/gen/output_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,62 +8,9 @@ import (

"github.com/go-quicktest/qt"

"github.com/cilium/ebpf/btf"
"github.com/cilium/ebpf/cmd/bpf2go/internal"
)

func TestOrderTypes(t *testing.T) {
a := &btf.Int{}
b := &btf.Int{}
c := &btf.Int{}

for _, test := range []struct {
name string
in map[btf.Type]string
out []btf.Type
}{
{
"order",
map[btf.Type]string{
a: "foo",
b: "bar",
c: "baz",
},
[]btf.Type{b, c, a},
},
} {
t.Run(test.name, func(t *testing.T) {
result, err := sortTypes(test.in)
qt.Assert(t, qt.IsNil(err))
qt.Assert(t, qt.Equals(len(result), len(test.out)))
for i, o := range test.out {
if result[i] != o {
t.Fatalf("Index %d: expected %p got %p", i, o, result[i])
}
}
})
}

for _, test := range []struct {
name string
in map[btf.Type]string
}{
{
"duplicate names",
map[btf.Type]string{
a: "foo",
b: "foo",
},
},
} {
t.Run(test.name, func(t *testing.T) {
result, err := sortTypes(test.in)
qt.Assert(t, qt.IsNotNil(err))
qt.Assert(t, qt.IsNil(result))
})
}
}

func TestPackageImport(t *testing.T) {
var buf bytes.Buffer
err := Generate(GenerateArgs{
Expand Down

0 comments on commit ad8c53b

Please sign in to comment.