Skip to content

Commit 4f4c1a1

Browse files
author
Matt Jones
committed
Sync with internal development.
The biggest changes: - Add boolean constant parsing. - Add context.Context support to rpc code. Minor changes encountered while parsing a large thrift corpus: - Harden code generator against '_test'-suffixed thrift schemas. - Harden code generator against name collisions ('res'). - Rename internal RPC wrapper structs. Better error handling: - Log runtime errors encountered by rpc codec serialization. - Clear buffers more aggressively.
1 parent b029ac8 commit 4f4c1a1

File tree

9 files changed

+736
-579
lines changed

9 files changed

+736
-579
lines changed

generator/go.go

Lines changed: 122 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"fmt"
1414
"go/format"
1515
"io"
16+
"log"
1617
"os"
1718
"path/filepath"
1819
"runtime"
@@ -31,10 +32,11 @@ var (
3132
flagGoImportPrefix = flag.String("go.importprefix", "", "Prefix for thrift-generated go package imports")
3233
flagGoGenerateMethods = flag.Bool("go.generate", false, "Add testing/quick compatible Generate methods to enum types")
3334
flagGoSignedBytes = flag.Bool("go.signedbytes", false, "Interpret Thrift byte as Go signed int8 type")
35+
flagGoRPCContext = flag.Bool("go.rpccontext", false, "Add context.Context objects to rpc wrappers")
3436
)
3537

3638
var (
37-
goNamespaceOrder = []string{"go", "perl", "py", "cpp", "rb", "java"}
39+
goNamespaceOrder = []string{"go", "perl", "py", "cpp", "rb", "java", "cpp2"}
3840
)
3941

4042
type ErrUnknownType string
@@ -63,6 +65,9 @@ type GoGenerator struct {
6365
Format bool
6466
Pointers bool
6567
SignedBytes bool
68+
69+
// package names imported
70+
packageNames map[string]bool
6671
}
6772

6873
var goKeywords = map[string]bool{
@@ -92,8 +97,12 @@ var goKeywords = map[string]bool{
9297
"return": true,
9398
"var": true,
9499

95-
// request arguments are hardcoded to 'req'; blacklist it to prevent accidental name collisions
100+
// request arguments are hardcoded to 'req' and the response to 'res'
96101
"req": true,
102+
"res": true,
103+
// ctx is passed as the first argument, SetContext methods are generated iff flagGoRPCContext is set
104+
"ctx": true,
105+
"SetContext": true,
97106
}
98107

99108
var basicTypes = map[string]bool{
@@ -200,7 +209,7 @@ func (g *GoGenerator) formatType(pkg string, thrift *parser.Thrift, typ *parser.
200209
}
201210

202211
if t := thrift.Typedefs[typ.Name]; t != nil {
203-
name := typ.Name
212+
name := camelCase(typ.Name)
204213
if pkg != g.pkg {
205214
name = pkg + "." + name
206215
}
@@ -356,6 +365,11 @@ func (g *GoGenerator) formatValue(v interface{}, t *parser.Type) (string, error)
356365
return strconv.Quote(v2), nil
357366
case int:
358367
return strconv.Itoa(v2), nil
368+
case bool:
369+
if v2 {
370+
return "true", nil
371+
}
372+
return "false", nil
359373
case int64:
360374
if t.Name == "bool" {
361375
if v2 == 0 {
@@ -386,10 +400,11 @@ func (g *GoGenerator) formatValue(v interface{}, t *parser.Type) (string, error)
386400
return buf.String(), nil
387401
case []parser.KeyValue:
388402
buf := &bytes.Buffer{}
389-
buf.WriteString(g.formatType(g.pkg, g.thrift, t, 0))
403+
buf.WriteString(g.formatType(g.pkg, g.thrift, t, toNoPointer))
390404
buf.WriteString("{\n")
391405
for _, kv := range v2 {
392406
buf.WriteString("\t\t")
407+
393408
s, err := g.formatValue(kv.Key, t.KeyType)
394409
if err != nil {
395410
return "", err
@@ -400,19 +415,30 @@ func (g *GoGenerator) formatValue(v interface{}, t *parser.Type) (string, error)
400415
if err != nil {
401416
return "", err
402417
}
418+
419+
// struct values are pointers
420+
if t.ValueType == nil && *flagGoPointers {
421+
s += ".Ptr()"
422+
}
423+
403424
buf.WriteString(s)
404425
buf.WriteString(",\n")
405426
}
406427
buf.WriteString("\t}")
407428
return buf.String(), nil
408429
case parser.Identifier:
409-
parts := strings.SplitN(string(v2), ".", 2)
410-
if len(parts) == 1 {
411-
return camelCase(parts[0]), nil
430+
ident := string(v2)
431+
idx := strings.LastIndex(ident, ".")
432+
if idx == -1 {
433+
return camelCase(ident), nil
434+
}
435+
436+
scope := ident[:idx]
437+
if g.packageNames[scope] {
438+
scope += "."
412439
}
413440

414-
resolved := parts[0] + camelCase(parts[1])
415-
return resolved, nil
441+
return scope + camelCase(ident[idx+1:]), nil
416442
}
417443
return "", fmt.Errorf("unsupported value type %T", v)
418444
}
@@ -516,10 +542,14 @@ func (e *%s) Generate(rand *rand.Rand, size int) reflect.Value {
516542
return nil
517543
}
518544

519-
func (g *GoGenerator) writeStruct(out io.Writer, st *parser.Struct) error {
545+
func (g *GoGenerator) writeStruct(out io.Writer, st *parser.Struct, includeContext bool) error {
520546
structName := camelCase(st.Name)
521547

522548
g.write(out, "\ntype %s struct {\n", structName)
549+
if includeContext {
550+
g.write(out, "\tctx context.Context\n")
551+
}
552+
523553
for _, field := range st.Fields {
524554
g.write(out, "\t%s\n", g.formatField(field))
525555
}
@@ -532,11 +562,19 @@ func (g *GoGenerator) writeStruct(out io.Writer, st *parser.Struct) error {
532562
g.write(out, "%s\n", g.formatFieldGetter(receiver, structName, field))
533563
}
534564

565+
if includeContext {
566+
g.write(out, `
567+
func (%s *%s) SetContext(ctx context.Context) {
568+
%s.ctx = ctx
569+
}
570+
`, receiver, structName, receiver)
571+
}
572+
535573
return g.write(out, "\n")
536574
}
537575

538576
func (g *GoGenerator) writeException(out io.Writer, ex *parser.Struct) error {
539-
if err := g.writeStruct(out, ex); err != nil {
577+
if err := g.writeStruct(out, ex, false); err != nil {
540578
return err
541579
}
542580

@@ -550,7 +588,7 @@ func (g *GoGenerator) writeException(out io.Writer, ex *parser.Struct) error {
550588
fieldVars := make([]string, len(ex.Fields))
551589
for i, field := range ex.Fields {
552590
fieldNames[i] = camelCase(field.Name) + ": %+v"
553-
fieldVars[i] = "e." + camelCase(field.Name)
591+
fieldVars[i] = "e.Get" + camelCase(field.Name) + "()"
554592
}
555593
g.write(out, "\treturn fmt.Sprintf(\"%s{%s}\", %s)\n",
556594
exName, strings.Join(fieldNames, ", "), strings.Join(fieldVars, ", "))
@@ -570,9 +608,14 @@ func (g *GoGenerator) writeService(out io.Writer, svc *parser.Service) error {
570608
methodNames := sortedKeys(svc.Methods)
571609
for _, k := range methodNames {
572610
method := svc.Methods[k]
611+
args := g.formatArguments(method.Arguments)
612+
if *flagGoRPCContext {
613+
args = "ctx context.Context, " + args
614+
}
615+
573616
g.write(out,
574617
"\t%s(%s) %s\n",
575-
camelCase(method.Name), g.formatArguments(method.Arguments),
618+
camelCase(method.Name), args,
576619
g.formatReturnType(method.ReturnType, false))
577620
}
578621
g.write(out, "}\n")
@@ -590,12 +633,21 @@ func (g *GoGenerator) writeService(out io.Writer, svc *parser.Service) error {
590633
for _, k := range methodNames {
591634
method := svc.Methods[k]
592635
mName := camelCase(method.Name)
636+
637+
requestStructName := "InternalRPC" + svcName + camelCase(method.Name) + "Request"
638+
responseStructName := "InternalRPC" + svcName + camelCase(method.Name) + "Response"
639+
593640
resArg := ""
594641
if !method.Oneway {
595-
resArg = fmt.Sprintf(", res *%s%sResponse", svcName, mName)
642+
resArg = fmt.Sprintf(", res *%s", responseStructName)
596643
}
597-
g.write(out, "\nfunc (s *%sServer) %s(req *%s%sRequest%s) error {\n", svcName, mName, svcName, mName, resArg)
644+
g.write(out, "\nfunc (s *%sServer) %s(req *%s%s) error {\n", svcName, mName, requestStructName, resArg)
598645
var args []string
646+
647+
if *flagGoRPCContext {
648+
args = append(args, "req.ctx")
649+
}
650+
599651
for _, arg := range method.Arguments {
600652
aName := camelCase(arg.Name)
601653
args = append(args, "req."+aName)
@@ -626,13 +678,16 @@ func (g *GoGenerator) writeService(out io.Writer, svc *parser.Service) error {
626678
for _, k := range methodNames {
627679
// Request struct
628680
method := svc.Methods[k]
629-
reqStructName := svcName + camelCase(method.Name) + "Request"
630-
if err := g.writeStruct(out, &parser.Struct{Name: reqStructName, Fields: method.Arguments}); err != nil {
681+
682+
requestStructName := "InternalRPC" + svcName + camelCase(method.Name) + "Request"
683+
responseStructName := "InternalRPC" + svcName + camelCase(method.Name) + "Response"
684+
685+
if err := g.writeStruct(out, &parser.Struct{Name: requestStructName, Fields: method.Arguments}, *flagGoRPCContext); err != nil {
631686
return err
632687
}
633688

634689
if method.Oneway {
635-
g.write(out, "\nfunc (r *%s) Oneway() bool {\n\treturn true\n}\n", reqStructName)
690+
g.write(out, "\nfunc (r *%s) Oneway() bool {\n\treturn true\n}\n", requestStructName)
636691
} else {
637692
// Response struct
638693
args := make([]*parser.Field, 0, len(method.Exceptions))
@@ -642,8 +697,8 @@ func (g *GoGenerator) writeService(out io.Writer, svc *parser.Service) error {
642697
for _, ex := range method.Exceptions {
643698
args = append(args, ex)
644699
}
645-
res := &parser.Struct{Name: svcName + camelCase(method.Name) + "Response", Fields: args}
646-
if err := g.writeStruct(out, res); err != nil {
700+
res := &parser.Struct{Name: responseStructName, Fields: args}
701+
if err := g.writeStruct(out, res, false); err != nil {
647702
return err
648703
}
649704
}
@@ -657,8 +712,12 @@ func (g *GoGenerator) writeService(out io.Writer, svc *parser.Service) error {
657712

658713
for _, k := range methodNames {
659714
method := svc.Methods[k]
715+
716+
requestStructName := "InternalRPC" + svcName + camelCase(method.Name) + "Request"
717+
responseStructName := "InternalRPC" + svcName + camelCase(method.Name) + "Response"
718+
660719
methodName := camelCase(method.Name)
661-
returnType := "err error"
720+
returnType := "(err error)"
662721
if !method.Oneway {
663722
returnType = g.formatReturnType(method.ReturnType, true)
664723
}
@@ -668,7 +727,7 @@ func (g *GoGenerator) writeService(out io.Writer, svc *parser.Service) error {
668727
returnType)
669728

670729
// Request
671-
g.write(out, "\treq := &%s%sRequest{\n", svcName, methodName)
730+
g.write(out, "\treq := &%s{\n", requestStructName)
672731
for _, arg := range method.Arguments {
673732
g.write(out, "\t\t%s: %s,\n", camelCase(arg.Name), validGoIdent(lowerCamelCase(arg.Name)))
674733
}
@@ -679,7 +738,7 @@ func (g *GoGenerator) writeService(out io.Writer, svc *parser.Service) error {
679738
// g.write(out, "\tvar res *%s%sResponse = nil\n", svcName, methodName)
680739
g.write(out, "\tvar res interface{} = nil\n")
681740
} else {
682-
g.write(out, "\tres := &%s%sResponse{}\n", svcName, methodName)
741+
g.write(out, "\tres := &%s{}\n", responseStructName)
683742
}
684743

685744
// Call
@@ -709,6 +768,26 @@ func (g *GoGenerator) writeService(out io.Writer, svc *parser.Service) error {
709768
return nil
710769
}
711770

771+
var validMapKeys = map[string]bool{
772+
"string": true,
773+
"i32": true,
774+
"i64": true,
775+
"bool": true,
776+
"double": true,
777+
}
778+
779+
func (g *GoGenerator) isValidGoType(typ *parser.Type) bool {
780+
if typ.KeyType == nil {
781+
return true
782+
}
783+
784+
if _, ok := g.thrift.Enums[g.resolveType(typ.KeyType)]; ok {
785+
return true
786+
}
787+
788+
return validMapKeys[g.resolveType(typ.KeyType)]
789+
}
790+
712791
func (g *GoGenerator) generateSingle(out io.Writer, thriftPath string, thrift *parser.Thrift) {
713792
packageName := g.Packages[thriftPath].Name
714793
g.thrift = thrift
@@ -726,6 +805,11 @@ func (g *GoGenerator) generateSingle(out io.Writer, thriftPath string, thrift *p
726805
imports = append(imports, "math/rand", "reflect")
727806
}
728807
}
808+
809+
if len(thrift.Services) > 0 && *flagGoRPCContext {
810+
imports = append(imports, "golang.org/x/net/context")
811+
}
812+
729813
if len(thrift.Includes) > 0 {
730814
for _, path := range thrift.Includes {
731815
pkg := g.Packages[path].Name
@@ -760,6 +844,12 @@ func (g *GoGenerator) generateSingle(out io.Writer, thriftPath string, thrift *p
760844
if len(thrift.Constants) > 0 {
761845
for _, k := range sortedKeys(thrift.Constants) {
762846
c := thrift.Constants[k]
847+
848+
if !g.isValidGoType(c.Type) {
849+
log.Printf("Skipping generation for constant %s - type is not a valid go type (%s)\n", c.Name, g.resolveType(c.Type.KeyType))
850+
continue
851+
}
852+
763853
v, err := g.formatValue(c.Value, c.Type)
764854
if err != nil {
765855
g.error(err)
@@ -783,7 +873,7 @@ func (g *GoGenerator) generateSingle(out io.Writer, thriftPath string, thrift *p
783873

784874
for _, k := range sortedKeys(thrift.Structs) {
785875
st := thrift.Structs[k]
786-
if err := g.writeStruct(out, st); err != nil {
876+
if err := g.writeStruct(out, st, false); err != nil {
787877
g.error(err)
788878
}
789879
}
@@ -797,7 +887,7 @@ func (g *GoGenerator) generateSingle(out io.Writer, thriftPath string, thrift *p
797887

798888
for _, k := range sortedKeys(thrift.Unions) {
799889
un := thrift.Unions[k]
800-
if err := g.writeStruct(out, un); err != nil {
890+
if err := g.writeStruct(out, un, false); err != nil {
801891
g.error(err)
802892
}
803893
}
@@ -822,7 +912,8 @@ func (g *GoGenerator) Generate(outPath string) (err error) {
822912

823913
// Generate package namespace mapping if necessary
824914
if g.Packages == nil {
825-
g.Packages = make(map[string]GoPackage)
915+
g.Packages = map[string]GoPackage{}
916+
g.packageNames = map[string]bool{}
826917
}
827918
for path, th := range g.ThriftFiles {
828919
if pkg, ok := g.Packages[path]; !ok || pkg.Name == "" {
@@ -843,6 +934,7 @@ func (g *GoGenerator) Generate(outPath string) (err error) {
843934
}
844935
pkg.Name = validIdentifier(strings.ToLower(pkg.Name), "_")
845936
g.Packages[path] = pkg
937+
g.packageNames[pkg.Name] = true
846938
}
847939
}
848940

@@ -856,6 +948,10 @@ func (g *GoGenerator) Generate(outPath string) (err error) {
856948
filename = filename[:i]
857949
}
858950
}
951+
if strings.HasSuffix(filename, "_test") {
952+
filename = filename[:len(filename)-len("_test")]
953+
}
954+
859955
filename += ".go"
860956
pkgpath := filepath.Join(outPath, pkg.Path, pkg.Name)
861957
outfile := filepath.Join(pkgpath, filename)

0 commit comments

Comments
 (0)