|
1 | 1 | package taint |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "go/types" |
| 5 | + |
4 | 6 | "golang.org/x/tools/go/callgraph" |
5 | 7 | "golang.org/x/tools/go/ssa" |
6 | 8 |
|
@@ -176,11 +178,18 @@ func checkSSAValue(path callgraphutil.Path, sources Sources, v ssa.Value, visite |
176 | 178 | // (just one step?) to identify what actual value the caller used. |
177 | 179 | case *ssa.Parameter: |
178 | 180 | // Check if the parameter's type is a source. |
179 | | - paramTypeStr := value.Type().String() |
| 181 | + paramType := value.Type() |
| 182 | + paramTypeStr := paramType.String() |
180 | 183 | if src, ok := sources.includes(paramTypeStr); ok { |
181 | 184 | return true, src, value |
182 | 185 | } |
183 | 186 |
|
| 187 | + // Check if the parameter type implements proto.Message when the |
| 188 | + // caller provided it as a potential source. |
| 189 | + if ok, src := protoMessageSource(sources, paramType); ok { |
| 190 | + return true, src, value |
| 191 | + } |
| 192 | + |
184 | 193 | // Check the parameter's referrers. |
185 | 194 | refs := value.Referrers() |
186 | 195 | if refs != nil { |
@@ -379,6 +388,9 @@ func checkSSAValue(path callgraphutil.Path, sources Sources, v ssa.Value, visite |
379 | 388 | if src, ok := sources.includes(value.X.Type().String()); ok { |
380 | 389 | return true, src, value |
381 | 390 | } |
| 391 | + if ok, src := protoMessageSource(sources, value.X.Type()); ok { |
| 392 | + return true, src, value |
| 393 | + } |
382 | 394 |
|
383 | 395 | tainted, src, tv := checkSSAValue(path, sources, value.X, visited) |
384 | 396 | if tainted { |
@@ -597,3 +609,42 @@ func checkSSAInstruction(path callgraphutil.Path, sources Sources, i ssa.Instruc |
597 | 609 | } |
598 | 610 | return false, "", nil |
599 | 611 | } |
| 612 | + |
| 613 | +// protoMessageSource checks if the given type implements proto.Message when that |
| 614 | +// type is present in the provided sources list. It returns true with the source |
| 615 | +// string if so. |
| 616 | +func protoMessageSource(sources Sources, t types.Type) (bool, string) { |
| 617 | + if src, ok := sources.includes("google.golang.org/protobuf/proto.Message"); ok { |
| 618 | + if hasProtoMessageMethod(t) { |
| 619 | + return true, src |
| 620 | + } |
| 621 | + } |
| 622 | + return false, "" |
| 623 | +} |
| 624 | + |
| 625 | +// hasProtoMessageMethod reports if the given type implements a ProtoMessage method |
| 626 | +// with no parameters and no results, which is used to identify protobuf message |
| 627 | +// types commonly used with gRPC services. |
| 628 | +func hasProtoMessageMethod(t types.Type) bool { |
| 629 | + if ptr, ok := t.(*types.Pointer); ok { |
| 630 | + t = ptr.Elem() |
| 631 | + } |
| 632 | + |
| 633 | + named, ok := t.(*types.Named) |
| 634 | + if !ok { |
| 635 | + return false |
| 636 | + } |
| 637 | + |
| 638 | + for i := 0; i < named.NumMethods(); i++ { |
| 639 | + m := named.Method(i) |
| 640 | + if m.Name() != "ProtoMessage" { |
| 641 | + continue |
| 642 | + } |
| 643 | + if sig, ok := m.Type().(*types.Signature); ok { |
| 644 | + if sig.Params().Len() == 0 && sig.Results().Len() == 0 { |
| 645 | + return true |
| 646 | + } |
| 647 | + } |
| 648 | + } |
| 649 | + return false |
| 650 | +} |
0 commit comments