Skip to content

Commit 548eea1

Browse files
authored
Handle gRPC requests as tainted sources (#43)
1 parent b6a406b commit 548eea1

File tree

8 files changed

+107
-2
lines changed

8 files changed

+107
-2
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ sources reach the given sinks.
2323
cg, _ := callgraph.New(mainFn, buildSSA.SrcFuncs...)
2424

2525
sources := taint.NewSources(
26-
"*net/http.Request",
26+
"*net/http.Request",
27+
"google.golang.org/protobuf/proto.Message", // gRPC request types
2728
)
2829

2930
sinks := taint.NewSinks(

check.go

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package taint
22

33
import (
4+
"go/types"
5+
46
"golang.org/x/tools/go/callgraph"
57
"golang.org/x/tools/go/ssa"
68

@@ -176,11 +178,18 @@ func checkSSAValue(path callgraphutil.Path, sources Sources, v ssa.Value, visite
176178
// (just one step?) to identify what actual value the caller used.
177179
case *ssa.Parameter:
178180
// Check if the parameter's type is a source.
179-
paramTypeStr := value.Type().String()
181+
paramType := value.Type()
182+
paramTypeStr := paramType.String()
180183
if src, ok := sources.includes(paramTypeStr); ok {
181184
return true, src, value
182185
}
183186

187+
// Check if the parameter type implements proto.Message when the
188+
// caller provided it as a potential source.
189+
if ok, src := protoMessageSource(sources, paramType); ok {
190+
return true, src, value
191+
}
192+
184193
// Check the parameter's referrers.
185194
refs := value.Referrers()
186195
if refs != nil {
@@ -379,6 +388,9 @@ func checkSSAValue(path callgraphutil.Path, sources Sources, v ssa.Value, visite
379388
if src, ok := sources.includes(value.X.Type().String()); ok {
380389
return true, src, value
381390
}
391+
if ok, src := protoMessageSource(sources, value.X.Type()); ok {
392+
return true, src, value
393+
}
382394

383395
tainted, src, tv := checkSSAValue(path, sources, value.X, visited)
384396
if tainted {
@@ -597,3 +609,42 @@ func checkSSAInstruction(path callgraphutil.Path, sources Sources, i ssa.Instruc
597609
}
598610
return false, "", nil
599611
}
612+
613+
// protoMessageSource checks if the given type implements proto.Message when that
614+
// type is present in the provided sources list. It returns true with the source
615+
// string if so.
616+
func protoMessageSource(sources Sources, t types.Type) (bool, string) {
617+
if src, ok := sources.includes("google.golang.org/protobuf/proto.Message"); ok {
618+
if hasProtoMessageMethod(t) {
619+
return true, src
620+
}
621+
}
622+
return false, ""
623+
}
624+
625+
// hasProtoMessageMethod reports if the given type implements a ProtoMessage method
626+
// with no parameters and no results, which is used to identify protobuf message
627+
// types commonly used with gRPC services.
628+
func hasProtoMessageMethod(t types.Type) bool {
629+
if ptr, ok := t.(*types.Pointer); ok {
630+
t = ptr.Elem()
631+
}
632+
633+
named, ok := t.(*types.Named)
634+
if !ok {
635+
return false
636+
}
637+
638+
for i := 0; i < named.NumMethods(); i++ {
639+
m := named.Method(i)
640+
if m.Name() != "ProtoMessage" {
641+
continue
642+
}
643+
if sig, ok := m.Type().(*types.Signature); ok {
644+
if sig.Params().Len() == 0 && sig.Results().Len() == 0 {
645+
return true
646+
}
647+
}
648+
}
649+
return false
650+
}

log/injection/injection.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414

1515
var userControlledValues = taint.NewSources(
1616
"*net/http.Request",
17+
"google.golang.org/protobuf/proto.Message",
1718
)
1819

1920
var injectableLogFunctions = taint.NewSinks(

log/injection/injection_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,7 @@ func TestK(t *testing.T) {
5555
func TestL(t *testing.T) {
5656
analysistest.Run(t, testdata, Analyzer, "l")
5757
}
58+
59+
func TestGRPC(t *testing.T) {
60+
analysistest.Run(t, testdata, Analyzer, "grpc")
61+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"log"
6+
)
7+
8+
type Request struct{ Msg string }
9+
10+
func (*Request) ProtoMessage() {}
11+
12+
type Server struct{}
13+
14+
func (s *Server) Handle(ctx context.Context, req *Request) {
15+
log.Println(req.Msg) // want "potential log injection"
16+
}
17+
18+
func main() {
19+
srv := &Server{}
20+
srv.Handle(context.Background(), &Request{})
21+
}

sql/injection/injection.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ var userControlledValues = taint.NewSources(
3939
//
4040
// Types (and fields)
4141
"*net/http.Request",
42+
"google.golang.org/protobuf/proto.Message",
4243
//
4344
// "google.golang.org/grpc/metadata.MD", ?
4445
//

sql/injection/injection_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,7 @@ func TestU(t *testing.T) {
101101
func TestV(t *testing.T) {
102102
analysistest.Run(t, testdata, Analyzer, "v")
103103
}
104+
105+
func TestGRPC(t *testing.T) {
106+
analysistest.Run(t, testdata, Analyzer, "grpc")
107+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
)
7+
8+
type Request struct{ Query string }
9+
10+
func (*Request) ProtoMessage() {}
11+
12+
type Server struct{}
13+
14+
func (s *Server) Handle(ctx context.Context, db *sql.DB, req *Request) {
15+
db.Query(req.Query) // want "potential sql injection"
16+
}
17+
18+
func main() {
19+
db, _ := sql.Open("sqlite3", ":memory:")
20+
srv := &Server{}
21+
srv.Handle(context.Background(), db, &Request{})
22+
}

0 commit comments

Comments
 (0)