@@ -13,6 +13,7 @@ import (
1313 "fmt"
1414 "go/format"
1515 "io"
16+ "log"
1617 "os"
1718 "path/filepath"
1819 "runtime"
@@ -31,10 +32,11 @@ var (
3132 flagGoImportPrefix = flag .String ("go.importprefix" , "" , "Prefix for thrift-generated go package imports" )
3233 flagGoGenerateMethods = flag .Bool ("go.generate" , false , "Add testing/quick compatible Generate methods to enum types" )
3334 flagGoSignedBytes = flag .Bool ("go.signedbytes" , false , "Interpret Thrift byte as Go signed int8 type" )
35+ flagGoRPCContext = flag .Bool ("go.rpccontext" , false , "Add context.Context objects to rpc wrappers" )
3436)
3537
3638var (
37- goNamespaceOrder = []string {"go" , "perl" , "py" , "cpp" , "rb" , "java" }
39+ goNamespaceOrder = []string {"go" , "perl" , "py" , "cpp" , "rb" , "java" , "cpp2" }
3840)
3941
4042type ErrUnknownType string
@@ -63,6 +65,9 @@ type GoGenerator struct {
6365 Format bool
6466 Pointers bool
6567 SignedBytes bool
68+
69+ // package names imported
70+ packageNames map [string ]bool
6671}
6772
6873var goKeywords = map [string ]bool {
@@ -92,8 +97,12 @@ var goKeywords = map[string]bool{
9297 "return" : true ,
9398 "var" : true ,
9499
95- // request arguments are hardcoded to 'req'; blacklist it to prevent accidental name collisions
100+ // request arguments are hardcoded to 'req' and the response to 'res'
96101 "req" : true ,
102+ "res" : true ,
103+ // ctx is passed as the first argument, SetContext methods are generated iff flagGoRPCContext is set
104+ "ctx" : true ,
105+ "SetContext" : true ,
97106}
98107
99108var basicTypes = map [string ]bool {
@@ -200,7 +209,7 @@ func (g *GoGenerator) formatType(pkg string, thrift *parser.Thrift, typ *parser.
200209 }
201210
202211 if t := thrift .Typedefs [typ .Name ]; t != nil {
203- name := typ .Name
212+ name := camelCase ( typ .Name )
204213 if pkg != g .pkg {
205214 name = pkg + "." + name
206215 }
@@ -356,6 +365,11 @@ func (g *GoGenerator) formatValue(v interface{}, t *parser.Type) (string, error)
356365 return strconv .Quote (v2 ), nil
357366 case int :
358367 return strconv .Itoa (v2 ), nil
368+ case bool :
369+ if v2 {
370+ return "true" , nil
371+ }
372+ return "false" , nil
359373 case int64 :
360374 if t .Name == "bool" {
361375 if v2 == 0 {
@@ -386,10 +400,11 @@ func (g *GoGenerator) formatValue(v interface{}, t *parser.Type) (string, error)
386400 return buf .String (), nil
387401 case []parser.KeyValue :
388402 buf := & bytes.Buffer {}
389- buf .WriteString (g .formatType (g .pkg , g .thrift , t , 0 ))
403+ buf .WriteString (g .formatType (g .pkg , g .thrift , t , toNoPointer ))
390404 buf .WriteString ("{\n " )
391405 for _ , kv := range v2 {
392406 buf .WriteString ("\t \t " )
407+
393408 s , err := g .formatValue (kv .Key , t .KeyType )
394409 if err != nil {
395410 return "" , err
@@ -400,19 +415,30 @@ func (g *GoGenerator) formatValue(v interface{}, t *parser.Type) (string, error)
400415 if err != nil {
401416 return "" , err
402417 }
418+
419+ // struct values are pointers
420+ if t .ValueType == nil && * flagGoPointers {
421+ s += ".Ptr()"
422+ }
423+
403424 buf .WriteString (s )
404425 buf .WriteString (",\n " )
405426 }
406427 buf .WriteString ("\t }" )
407428 return buf .String (), nil
408429 case parser.Identifier :
409- parts := strings .SplitN (string (v2 ), "." , 2 )
410- if len (parts ) == 1 {
411- return camelCase (parts [0 ]), nil
430+ ident := string (v2 )
431+ idx := strings .LastIndex (ident , "." )
432+ if idx == - 1 {
433+ return camelCase (ident ), nil
434+ }
435+
436+ scope := ident [:idx ]
437+ if g .packageNames [scope ] {
438+ scope += "."
412439 }
413440
414- resolved := parts [0 ] + camelCase (parts [1 ])
415- return resolved , nil
441+ return scope + camelCase (ident [idx + 1 :]), nil
416442 }
417443 return "" , fmt .Errorf ("unsupported value type %T" , v )
418444}
@@ -516,10 +542,14 @@ func (e *%s) Generate(rand *rand.Rand, size int) reflect.Value {
516542 return nil
517543}
518544
519- func (g * GoGenerator ) writeStruct (out io.Writer , st * parser.Struct ) error {
545+ func (g * GoGenerator ) writeStruct (out io.Writer , st * parser.Struct , includeContext bool ) error {
520546 structName := camelCase (st .Name )
521547
522548 g .write (out , "\n type %s struct {\n " , structName )
549+ if includeContext {
550+ g .write (out , "\t ctx context.Context\n " )
551+ }
552+
523553 for _ , field := range st .Fields {
524554 g .write (out , "\t %s\n " , g .formatField (field ))
525555 }
@@ -532,11 +562,19 @@ func (g *GoGenerator) writeStruct(out io.Writer, st *parser.Struct) error {
532562 g .write (out , "%s\n " , g .formatFieldGetter (receiver , structName , field ))
533563 }
534564
565+ if includeContext {
566+ g .write (out , `
567+ func (%s *%s) SetContext(ctx context.Context) {
568+ %s.ctx = ctx
569+ }
570+ ` , receiver , structName , receiver )
571+ }
572+
535573 return g .write (out , "\n " )
536574}
537575
538576func (g * GoGenerator ) writeException (out io.Writer , ex * parser.Struct ) error {
539- if err := g .writeStruct (out , ex ); err != nil {
577+ if err := g .writeStruct (out , ex , false ); err != nil {
540578 return err
541579 }
542580
@@ -550,7 +588,7 @@ func (g *GoGenerator) writeException(out io.Writer, ex *parser.Struct) error {
550588 fieldVars := make ([]string , len (ex .Fields ))
551589 for i , field := range ex .Fields {
552590 fieldNames [i ] = camelCase (field .Name ) + ": %+v"
553- fieldVars [i ] = "e." + camelCase (field .Name )
591+ fieldVars [i ] = "e.Get " + camelCase (field .Name ) + "()"
554592 }
555593 g .write (out , "\t return fmt.Sprintf(\" %s{%s}\" , %s)\n " ,
556594 exName , strings .Join (fieldNames , ", " ), strings .Join (fieldVars , ", " ))
@@ -570,9 +608,14 @@ func (g *GoGenerator) writeService(out io.Writer, svc *parser.Service) error {
570608 methodNames := sortedKeys (svc .Methods )
571609 for _ , k := range methodNames {
572610 method := svc .Methods [k ]
611+ args := g .formatArguments (method .Arguments )
612+ if * flagGoRPCContext {
613+ args = "ctx context.Context, " + args
614+ }
615+
573616 g .write (out ,
574617 "\t %s(%s) %s\n " ,
575- camelCase (method .Name ), g . formatArguments ( method . Arguments ) ,
618+ camelCase (method .Name ), args ,
576619 g .formatReturnType (method .ReturnType , false ))
577620 }
578621 g .write (out , "}\n " )
@@ -590,12 +633,21 @@ func (g *GoGenerator) writeService(out io.Writer, svc *parser.Service) error {
590633 for _ , k := range methodNames {
591634 method := svc .Methods [k ]
592635 mName := camelCase (method .Name )
636+
637+ requestStructName := "InternalRPC" + svcName + camelCase (method .Name ) + "Request"
638+ responseStructName := "InternalRPC" + svcName + camelCase (method .Name ) + "Response"
639+
593640 resArg := ""
594641 if ! method .Oneway {
595- resArg = fmt .Sprintf (", res *%s%sResponse " , svcName , mName )
642+ resArg = fmt .Sprintf (", res *%s" , responseStructName )
596643 }
597- g .write (out , "\n func (s *%sServer) %s(req *%s%sRequest% s) error {\n " , svcName , mName , svcName , mName , resArg )
644+ g .write (out , "\n func (s *%sServer) %s(req *%s%s) error {\n " , svcName , mName , requestStructName , resArg )
598645 var args []string
646+
647+ if * flagGoRPCContext {
648+ args = append (args , "req.ctx" )
649+ }
650+
599651 for _ , arg := range method .Arguments {
600652 aName := camelCase (arg .Name )
601653 args = append (args , "req." + aName )
@@ -626,13 +678,16 @@ func (g *GoGenerator) writeService(out io.Writer, svc *parser.Service) error {
626678 for _ , k := range methodNames {
627679 // Request struct
628680 method := svc .Methods [k ]
629- reqStructName := svcName + camelCase (method .Name ) + "Request"
630- if err := g .writeStruct (out , & parser.Struct {Name : reqStructName , Fields : method .Arguments }); err != nil {
681+
682+ requestStructName := "InternalRPC" + svcName + camelCase (method .Name ) + "Request"
683+ responseStructName := "InternalRPC" + svcName + camelCase (method .Name ) + "Response"
684+
685+ if err := g .writeStruct (out , & parser.Struct {Name : requestStructName , Fields : method .Arguments }, * flagGoRPCContext ); err != nil {
631686 return err
632687 }
633688
634689 if method .Oneway {
635- g .write (out , "\n func (r *%s) Oneway() bool {\n \t return true\n }\n " , reqStructName )
690+ g .write (out , "\n func (r *%s) Oneway() bool {\n \t return true\n }\n " , requestStructName )
636691 } else {
637692 // Response struct
638693 args := make ([]* parser.Field , 0 , len (method .Exceptions ))
@@ -642,8 +697,8 @@ func (g *GoGenerator) writeService(out io.Writer, svc *parser.Service) error {
642697 for _ , ex := range method .Exceptions {
643698 args = append (args , ex )
644699 }
645- res := & parser.Struct {Name : svcName + camelCase ( method . Name ) + "Response" , Fields : args }
646- if err := g .writeStruct (out , res ); err != nil {
700+ res := & parser.Struct {Name : responseStructName , Fields : args }
701+ if err := g .writeStruct (out , res , false ); err != nil {
647702 return err
648703 }
649704 }
@@ -657,8 +712,12 @@ func (g *GoGenerator) writeService(out io.Writer, svc *parser.Service) error {
657712
658713 for _ , k := range methodNames {
659714 method := svc .Methods [k ]
715+
716+ requestStructName := "InternalRPC" + svcName + camelCase (method .Name ) + "Request"
717+ responseStructName := "InternalRPC" + svcName + camelCase (method .Name ) + "Response"
718+
660719 methodName := camelCase (method .Name )
661- returnType := "err error"
720+ returnType := "( err error) "
662721 if ! method .Oneway {
663722 returnType = g .formatReturnType (method .ReturnType , true )
664723 }
@@ -668,7 +727,7 @@ func (g *GoGenerator) writeService(out io.Writer, svc *parser.Service) error {
668727 returnType )
669728
670729 // Request
671- g .write (out , "\t req := &%s%sRequest {\n " , svcName , methodName )
730+ g .write (out , "\t req := &%s{\n " , requestStructName )
672731 for _ , arg := range method .Arguments {
673732 g .write (out , "\t \t %s: %s,\n " , camelCase (arg .Name ), validGoIdent (lowerCamelCase (arg .Name )))
674733 }
@@ -679,7 +738,7 @@ func (g *GoGenerator) writeService(out io.Writer, svc *parser.Service) error {
679738 // g.write(out, "\tvar res *%s%sResponse = nil\n", svcName, methodName)
680739 g .write (out , "\t var res interface{} = nil\n " )
681740 } else {
682- g .write (out , "\t res := &%s%sResponse {}\n " , svcName , methodName )
741+ g .write (out , "\t res := &%s{}\n " , responseStructName )
683742 }
684743
685744 // Call
@@ -709,6 +768,26 @@ func (g *GoGenerator) writeService(out io.Writer, svc *parser.Service) error {
709768 return nil
710769}
711770
771+ var validMapKeys = map [string ]bool {
772+ "string" : true ,
773+ "i32" : true ,
774+ "i64" : true ,
775+ "bool" : true ,
776+ "double" : true ,
777+ }
778+
779+ func (g * GoGenerator ) isValidGoType (typ * parser.Type ) bool {
780+ if typ .KeyType == nil {
781+ return true
782+ }
783+
784+ if _ , ok := g .thrift .Enums [g .resolveType (typ .KeyType )]; ok {
785+ return true
786+ }
787+
788+ return validMapKeys [g .resolveType (typ .KeyType )]
789+ }
790+
712791func (g * GoGenerator ) generateSingle (out io.Writer , thriftPath string , thrift * parser.Thrift ) {
713792 packageName := g .Packages [thriftPath ].Name
714793 g .thrift = thrift
@@ -726,6 +805,11 @@ func (g *GoGenerator) generateSingle(out io.Writer, thriftPath string, thrift *p
726805 imports = append (imports , "math/rand" , "reflect" )
727806 }
728807 }
808+
809+ if len (thrift .Services ) > 0 && * flagGoRPCContext {
810+ imports = append (imports , "golang.org/x/net/context" )
811+ }
812+
729813 if len (thrift .Includes ) > 0 {
730814 for _ , path := range thrift .Includes {
731815 pkg := g .Packages [path ].Name
@@ -760,6 +844,12 @@ func (g *GoGenerator) generateSingle(out io.Writer, thriftPath string, thrift *p
760844 if len (thrift .Constants ) > 0 {
761845 for _ , k := range sortedKeys (thrift .Constants ) {
762846 c := thrift .Constants [k ]
847+
848+ if ! g .isValidGoType (c .Type ) {
849+ log .Printf ("Skipping generation for constant %s - type is not a valid go type (%s)\n " , c .Name , g .resolveType (c .Type .KeyType ))
850+ continue
851+ }
852+
763853 v , err := g .formatValue (c .Value , c .Type )
764854 if err != nil {
765855 g .error (err )
@@ -783,7 +873,7 @@ func (g *GoGenerator) generateSingle(out io.Writer, thriftPath string, thrift *p
783873
784874 for _ , k := range sortedKeys (thrift .Structs ) {
785875 st := thrift .Structs [k ]
786- if err := g .writeStruct (out , st ); err != nil {
876+ if err := g .writeStruct (out , st , false ); err != nil {
787877 g .error (err )
788878 }
789879 }
@@ -797,7 +887,7 @@ func (g *GoGenerator) generateSingle(out io.Writer, thriftPath string, thrift *p
797887
798888 for _ , k := range sortedKeys (thrift .Unions ) {
799889 un := thrift .Unions [k ]
800- if err := g .writeStruct (out , un ); err != nil {
890+ if err := g .writeStruct (out , un , false ); err != nil {
801891 g .error (err )
802892 }
803893 }
@@ -822,7 +912,8 @@ func (g *GoGenerator) Generate(outPath string) (err error) {
822912
823913 // Generate package namespace mapping if necessary
824914 if g .Packages == nil {
825- g .Packages = make (map [string ]GoPackage )
915+ g .Packages = map [string ]GoPackage {}
916+ g .packageNames = map [string ]bool {}
826917 }
827918 for path , th := range g .ThriftFiles {
828919 if pkg , ok := g .Packages [path ]; ! ok || pkg .Name == "" {
@@ -843,6 +934,7 @@ func (g *GoGenerator) Generate(outPath string) (err error) {
843934 }
844935 pkg .Name = validIdentifier (strings .ToLower (pkg .Name ), "_" )
845936 g .Packages [path ] = pkg
937+ g .packageNames [pkg .Name ] = true
846938 }
847939 }
848940
@@ -856,6 +948,10 @@ func (g *GoGenerator) Generate(outPath string) (err error) {
856948 filename = filename [:i ]
857949 }
858950 }
951+ if strings .HasSuffix (filename , "_test" ) {
952+ filename = filename [:len (filename )- len ("_test" )]
953+ }
954+
859955 filename += ".go"
860956 pkgpath := filepath .Join (outPath , pkg .Path , pkg .Name )
861957 outfile := filepath .Join (pkgpath , filename )
0 commit comments