From 98e2bb0d0739538b85344d027f5b2253d8d8d80c Mon Sep 17 00:00:00 2001 From: Omkar Phansopkar Date: Fri, 10 Jan 2025 16:48:29 +0530 Subject: [PATCH 1/5] Basic PoC for callgraph generation (nested fn calls, class constructors etc) Signed-off-by: Omkar Phansopkar --- examples/plugin/callgraph/main.go | 88 ++++++ examples/plugin/callgraph/test.py | 69 +++++ examples/plugin/callgraph/testClass.py | 30 ++ examples/plugin/callgraph/testFnAssignment.py | 35 +++ examples/plugin/callgraph/testNestedFn.py | 38 +++ examples/plugin/callgraph/testScopes.py | 24 ++ examples/plugin/callgraph/testTmp.py | 14 + plugin/callgraph/assignment.go | 19 ++ plugin/callgraph/callgraph.go | 292 ++++++++++++++++++ plugin/callgraph/config.go | 12 + plugin/callgraph/fixtures/.gitkeep | 0 plugin/callgraph/fixtures/test.py | 69 +++++ plugin/callgraph/graph.go | 75 +++++ plugin/callgraph/utils.go | 12 + 14 files changed, 777 insertions(+) create mode 100644 examples/plugin/callgraph/main.go create mode 100644 examples/plugin/callgraph/test.py create mode 100644 examples/plugin/callgraph/testClass.py create mode 100644 examples/plugin/callgraph/testFnAssignment.py create mode 100644 examples/plugin/callgraph/testNestedFn.py create mode 100644 examples/plugin/callgraph/testScopes.py create mode 100644 examples/plugin/callgraph/testTmp.py create mode 100644 plugin/callgraph/assignment.go create mode 100644 plugin/callgraph/callgraph.go create mode 100644 plugin/callgraph/config.go create mode 100644 plugin/callgraph/fixtures/.gitkeep create mode 100644 plugin/callgraph/fixtures/test.py create mode 100644 plugin/callgraph/graph.go create mode 100644 plugin/callgraph/utils.go diff --git a/examples/plugin/callgraph/main.go b/examples/plugin/callgraph/main.go new file mode 100644 index 0000000..605dc13 --- /dev/null +++ b/examples/plugin/callgraph/main.go @@ -0,0 +1,88 @@ +package main + +import ( + "context" + "flag" + "fmt" + + "github.com/safedep/code/core" + "github.com/safedep/code/fs" + "github.com/safedep/code/lang" + "github.com/safedep/code/parser" + "github.com/safedep/code/plugin" + "github.com/safedep/code/plugin/callgraph" + "github.com/safedep/dry/log" +) + +var ( + dirToWalk string + language string +) + +func init() { + log.InitZapLogger("walker", "dev") + + flag.StringVar(&dirToWalk, "dir", "", "Directory to walk") + flag.StringVar(&language, "lang", "python", "Language to use for parsing files") + + flag.Parse() +} + +func main() { + if dirToWalk == "" { + flag.Usage() + return + } + + err := run() + if err != nil { + panic(err) + } +} + +func run() error { + fileSystem, err := fs.NewLocalFileSystem(fs.LocalFileSystemConfig{ + AppDirectories: []string{dirToWalk}, + }) + + if err != nil { + return fmt.Errorf("failed to create local filesystem: %w", err) + } + + language, err := lang.GetLanguage(language) + if err != nil { + return fmt.Errorf("failed to get language: %w", err) + } + + walker, err := fs.NewSourceWalker(fs.SourceWalkerConfig{}, language) + if err != nil { + return fmt.Errorf("failed to create source walker: %w", err) + } + + treeWalker, err := parser.NewWalkingParser(walker, language) + if err != nil { + return fmt.Errorf("failed to create tree walker: %w", err) + } + + // consume callgraph + var callgraphCallback callgraph.CallgraphCallback = func(cg *callgraph.CallGraph) error { + cg.PrintCallGraph() + + fmt.Println("DFS Traversal:") + for _, node := range cg.DFS() { + fmt.Println(node) + } + fmt.Println() + return nil + } + + pluginExecutor, err := plugin.NewTreeWalkPluginExecutor(treeWalker, []core.Plugin{ + callgraph.NewCallGraphPlugin(callgraphCallback), + }) + + if err != nil { + return fmt.Errorf("failed to create plugin executor: %w", err) + } + + return pluginExecutor.Execute(context.Background(), fileSystem) +} diff --git a/examples/plugin/callgraph/test.py b/examples/plugin/callgraph/test.py new file mode 100644 index 0000000..371b657 --- /dev/null +++ b/examples/plugin/callgraph/test.py @@ -0,0 +1,69 @@ +import base64 +from utils import printinit, printenc, printdec, printf2 + +# Node must be generated, but shouldn't be part of DFS +class EncodingUnused: + def __init__(self): + printinit("Initialized unused") + pass + + def applyUnused(self, msg, func): + return func(msg) + +class Encoding: + def __init__(self): + printinit("Initialized") + pass + + def apply(self, msg, func): + return func(msg) + + # Unused + def apply2(self, msg, func): + return func(msg) + +encoder = Encoding() +encoded = encoder.apply("Hello, World!".encode('utf-8'), base64.b64encode) +printenc(encoded) +decoded = encoder.apply(encoded, base64.b64decode) +printdec(decoded) + + +def f1(value): + f2(value) + +def f2(value): + printf2(value) + if value == 0: + return + f1(value-1) + pass + +def multiply(a, b): + return a * b + +f1(multiply(2, 3)) + +def foo(): + print("foo") + pass + +def bar(): + print("bar") + pass + +def baz(): + print("baz") + pass +def useless(): + print("useless") + baz() + pass + +xyz = foo + +print("GG") + +xyz = bar + +xyz() \ No newline at end of file diff --git a/examples/plugin/callgraph/testClass.py b/examples/plugin/callgraph/testClass.py new file mode 100644 index 0000000..30aa12a --- /dev/null +++ b/examples/plugin/callgraph/testClass.py @@ -0,0 +1,30 @@ +import base64 +from utils import printinit, printenc, printdec as pdec + +class Encoding: + def __init__(self): + printinit("Initialized") + pass + + def apply(self, msg, func): + return func(msg) + + # Unused + def apply2(self, msg, func): + return func(msg) + +def getenc(): + return "encoded" + +encoder = Encoding() +encoded = encoder.apply("Hello, World!".encode('utf-8'), base64.b64encode) +printenc(encoded) +decoded = encoder.apply(getenc(), base64.b64decode) +pdec(decoded) + +class NoConstructorClass: + def show(): + print("NoConstructorClass") + pass +ncc = NoConstructorClass() +ncc.show() \ No newline at end of file diff --git a/examples/plugin/callgraph/testFnAssignment.py b/examples/plugin/callgraph/testFnAssignment.py new file mode 100644 index 0000000..15bb8c4 --- /dev/null +++ b/examples/plugin/callgraph/testFnAssignment.py @@ -0,0 +1,35 @@ +import pprint +from xyzprintmodule import xyzprint, xyzprint2 + +def foo(): + pprint.pprint("foo") + pass + +def bar(): + xyzprint("bar") + pass + +def baz(): + xyzprint2("baz") + pass + +xyz = foo +print("GG") + +xyz = bar +xyz() + +def nestParent(): + def nestChild(): + xyzprint("nestChild") + def nestGrandChild(): + xyzprint2("nestGrandChild") + pass + nestGrandChild() + nestChild() +nestParent() + +def useless(): + print("useless") + baz() + pass diff --git a/examples/plugin/callgraph/testNestedFn.py b/examples/plugin/callgraph/testNestedFn.py new file mode 100644 index 0000000..cb0d043 --- /dev/null +++ b/examples/plugin/callgraph/testNestedFn.py @@ -0,0 +1,38 @@ +from pprint import pprint as pprintfn +from xyzprintmodule import xyzprint, xyzprint2 +from os import listdir as listdirfn, chmod + +def outerfn1(): + chmod("outerfn1") + pass +def outerfn2(): + listdirfn("outerfn2") + pass + +def nestParent(): + def parentScopedFn(): + xyzprint("parentScopedFn") + + def nestChild(): + xyzprint("nestChild") + outerfn1() + + def childScopedFn(): + xyzprint("childScopedFn") + + def nestGrandChildUseless(): + xyzprint2("nestGrandChildUseless") + + def nestGrandChild(): + xyzprint2("nestGrandChild") + parentScopedFn() + outerfn2() + childScopedFn() + + nestGrandChild() + + outerfn1() + nestChild() + +nestParent() + diff --git a/examples/plugin/callgraph/testScopes.py b/examples/plugin/callgraph/testScopes.py new file mode 100644 index 0000000..f6a5a7d --- /dev/null +++ b/examples/plugin/callgraph/testScopes.py @@ -0,0 +1,24 @@ +from pprint import pprint +from xyzprintmodule import xprint, xyzprint, xyzprint2, xyzprint3 + +def fn1(): + xprint("very outer fn1") +fn1() + +def fn2(): + def fn1(): + xyzprint("fn1 inside fn2") + fn1() + + def fn3(): + def fn4(): + def fn1(): + xyzprint3("fn1 inside fn4 inside fn3") + xyzprint2("fn4 inside fn3") + fn1() # must call fn1 inside fn4 + fn1() # must call fn1 inside fn2 + fn4() + fn3() + +fn2() + \ No newline at end of file diff --git a/examples/plugin/callgraph/testTmp.py b/examples/plugin/callgraph/testTmp.py new file mode 100644 index 0000000..f33da21 --- /dev/null +++ b/examples/plugin/callgraph/testTmp.py @@ -0,0 +1,14 @@ +from xyz import printxyz as pxyz, printxyz2 + +class ClassA: + def __init__(self): + pxyz("init") + + def method1(self): + printxyz2("GG") + +def main(): + x = ClassA() + y = x + y.method1() +main() \ No newline at end of file diff --git a/plugin/callgraph/assignment.go b/plugin/callgraph/assignment.go new file mode 100644 index 0000000..cd9884f --- /dev/null +++ b/plugin/callgraph/assignment.go @@ -0,0 +1,19 @@ +package callgraph + +type AssignmentGraph struct { + Assignments map[string][]string // Map of identifier to possible namespaces or other identifiers +} + +func NewAssignmentGraph() *AssignmentGraph { + return &AssignmentGraph{Assignments: make(map[string][]string)} +} + +// Add an assignment +func (ag *AssignmentGraph) AddAssignment(identifier string, target string) { + ag.Assignments[identifier] = append(ag.Assignments[identifier], target) +} + +// Resolve an identifier to its targets +func (ag *AssignmentGraph) Resolve(identifier string) []string { + return ag.Assignments[identifier] +} diff --git a/plugin/callgraph/callgraph.go b/plugin/callgraph/callgraph.go new file mode 100644 index 0000000..2b2ef4b --- /dev/null +++ b/plugin/callgraph/callgraph.go @@ -0,0 +1,292 @@ +package callgraph + +import ( + "context" + "fmt" + "strings" + + "github.com/safedep/code/core" + "github.com/safedep/code/pkg/helpers" + "github.com/safedep/dry/log" + sitter "github.com/smacker/go-tree-sitter" +) + +type CallgraphCallback func(*CallGraph) error + +type callgraphPlugin struct { + // Callback function which is called with the callgraph + callgraphCallback CallgraphCallback +} + +// Verify contract +var _ core.TreePlugin = (*callgraphPlugin)(nil) + +func NewCallGraphPlugin(callgraphCallback CallgraphCallback) *callgraphPlugin { + return &callgraphPlugin{ + callgraphCallback: callgraphCallback, + } +} + +func (p *callgraphPlugin) Name() string { + return "CallgraphPlugin" +} + +var supportedLanguages = []core.LanguageCode{core.LanguageCodePython} + +func (p *callgraphPlugin) SupportedLanguages() []core.LanguageCode { + return supportedLanguages +} + +func (p *callgraphPlugin) AnalyzeTree(ctx context.Context, tree core.ParseTree) error { + lang, err := tree.Language() + if err != nil { + return fmt.Errorf("failed to get language: %w", err) + } + + file, err := tree.File() + if err != nil { + return fmt.Errorf("failed to get file: %w", err) + } + + log.Debugf("callgraph - Analyzing tree for language: %s, file: %s\n", + lang.Meta().Code, file.Name()) + + treeData, err := tree.Data() + if err != nil { + return fmt.Errorf("failed to get tree data: %w", err) + } + + fmt.Println() + + // Build the call graph + cg, err := buildCallGraph(tree, lang, treeData, file.Name()) + + if err != nil { + return fmt.Errorf("failed to build call graph: %w", err) + } + + p.callgraphCallback(cg) + + return nil +} + +// buildCallGraph builds a call graph from the syntax tree +func buildCallGraph(tree core.ParseTree, lang core.Language, treeData *[]byte, filePath string) (*CallGraph, error) { + callGraph := NewCallGraph(filePath) + rootNode := tree.Tree().RootNode() + imports, err := lang.Resolvers().ResolveImports(tree) + if err != nil { + return nil, fmt.Errorf("failed to resolve imports: %w", err) + } + + identifierToModuleNamespace := make(map[string]string) + for _, imp := range imports { + if imp.IsWildcardImport() { + continue + } + itemNamespace := imp.ModuleItem() + if itemNamespace == "" { + itemNamespace = imp.ModuleName() + } else { + itemNamespace = imp.ModuleName() + namespaceSeparator + itemNamespace + } + + identifierKey := helpers.GetFirstNonEmptyString(imp.ModuleAlias(), imp.ModuleItem(), imp.ModuleName()) + identifierToModuleNamespace[identifierKey] = itemNamespace + } + fmt.Println("Imports (identifier => namespace):") + for identifier, moduleNamespace := range identifierToModuleNamespace { + fmt.Printf("%s => %s\n", identifier, moduleNamespace) + } + fmt.Println() + + // Add root node to the call graph + callGraph.Nodes[filePath] = newGraphNode(filePath) + + traverseTree(rootNode, treeData, callGraph, identifierToModuleNamespace, filePath, filePath, "", false) + fmt.Println() + + return callGraph, nil +} + +func traverseTree(node *sitter.Node, treeData *[]byte, callGraph *CallGraph, identifierToModuleNamespace map[string]string, filePath string, currentNamespace string, classNamespace string, insideClass bool) { + if node == nil { + return + } + fmt.Println("Traverse ", node.Type(), "with content:", node.Content(*treeData), "with namespace:", currentNamespace) + switch node.Type() { + case "function_definition": + nameNode := node.ChildByFieldName("name") + if nameNode != nil { + funcName := nameNode.Content(*treeData) + if insideClass { + currentNamespace = classNamespace + namespaceSeparator + funcName + } else { + currentNamespace = currentNamespace + namespaceSeparator + funcName + } + fmt.Printf("Current funcdef namespace for %s - %s\n", funcName, currentNamespace) + + if _, exists := callGraph.Nodes[currentNamespace]; !exists { + callGraph.Nodes[currentNamespace] = newGraphNode(currentNamespace) + //@TODO - Class -> __init__ edge needed ? + //@TODO - If needed, should make it language agnostic, (__init__ ) must be obtained from lang. For java, it would be classname itself, etc + if insideClass && funcName == "__init__" { + callGraph.AddEdge(classNamespace, currentNamespace) + } + } + } + case "class_definition": + // Handle class definitions + className := node.ChildByFieldName("name").Content(*treeData) + classNamespace = currentNamespace + namespaceSeparator + className + insideClass = true + case "assignment": + leftNode := node.ChildByFieldName("left") + rightNode := node.ChildByFieldName("right") + if leftNode != nil && rightNode != nil { + leftVar := leftNode.Content(*treeData) + rightTargets := resolveTargets(rightNode, *treeData, currentNamespace, identifierToModuleNamespace, callGraph) + for _, rightTarget := range rightTargets { + callGraph.Assignments.AddAssignment(leftVar, rightTarget) + fmt.Printf("Assignment: %s -> %s\n", leftVar, rightTarget) + } + } + case "attribute": + // processing a xyz.attr for xyz.attr() call + if node.Parent().Type() == "call" { + baseNode := node.ChildByFieldName("object") + attributeNode := node.ChildByFieldName("attribute") + if baseNode != nil && attributeNode != nil { + // Resolve base object using the assignment graph at different scopes + fmt.Printf("Try resolving target call %s.%s at %s\n", baseNode.Content(*treeData), attributeNode.Content(*treeData), currentNamespace) + baseTargets := resolveTargets(baseNode, *treeData, currentNamespace, identifierToModuleNamespace, callGraph) + for _, baseTarget := range baseTargets { + fmt.Printf("Processing fn call as %s.%s() on base as %s \n", baseNode.Content(*treeData), attributeNode.Content(*treeData), baseTarget) + attributeName := attributeNode.Content(*treeData) + targetNamespace := baseTarget + namespaceSeparator + attributeName + _, existed := callGraph.Nodes[targetNamespace] + fmt.Printf("Attr %s resolved to %s, exists: %t\n", node.Content(*treeData), targetNamespace, existed) + + // Check if resolved target exists in the call graph + if _, exists := callGraph.Nodes[targetNamespace]; exists && targetNamespace != "" { + callGraph.AddEdge(currentNamespace, targetNamespace) + } + } + } + } + case "call": + fmt.Println("Traverse 'call' with content:", node.Content(*treeData), "with namespace:", currentNamespace, "insideClass:", insideClass, "node type:", node.Type()) + + // Extract call target + targetNode := node.ChildByFieldName("function") + if targetNode != nil { + callTarget := targetNode.Content(*treeData) + + // Search for the call target node at different scopes in the graph + // eg. namespace - nestNestedFn.py//nestParent//nestChild, callTarget - outerfn1 + // try searching for outerfn1 in identifierToModuleNamespace with all scope levels + // eg. search nestNestedFn.py//nestParent//nestChild//outerfn1 + // then nestNestedFn.py//nestParent//outerfn1 then nestNestedFn.py//outerfn1 and so on + // if not found, then use currentNamespace to build it + // like, nestNestedFn.py//nestParent//nestChild//outerfn1 + + var found bool + var targetNamespace string + for i := strings.Count(currentNamespace, namespaceSeparator) + 1; i >= 0; i-- { + searchNamespace := strings.Join(strings.Split(currentNamespace, namespaceSeparator)[:i], namespaceSeparator) + namespaceSeparator + callTarget + if i == 0 { + searchNamespace = callTarget + } + fmt.Printf("searching %s in scoped - %s\n", callTarget, searchNamespace) + // check in graph + if _, exists := callGraph.Nodes[searchNamespace]; exists { + targetNamespace = searchNamespace + found = true + break + } + } + if !found { + // check in identifierToModuleNamespace + if moduleNamespace, exists := identifierToModuleNamespace[callTarget]; exists { + targetNamespace = moduleNamespace + } else if insideClass { + targetNamespace = classNamespace + namespaceSeparator + callTarget + } else { + targetNamespace = currentNamespace + namespaceSeparator + callTarget + } + } + + // Add edge for function call + fmt.Println("Adding edge from", currentNamespace, "to", targetNamespace) + callGraph.AddEdge(currentNamespace, targetNamespace) + } + } + + // Recursively analyze children + for i := 0; i < int(node.ChildCount()); i++ { + traverseTree(node.Child(i), treeData, callGraph, identifierToModuleNamespace, filePath, currentNamespace, classNamespace, insideClass) + } +} + +func resolveTargets( + node *sitter.Node, + treeData []byte, + currentNamespace string, + identifierToModuleNamespace map[string]string, + callGraph *CallGraph, +) []string { + if node == nil { + return []string{} + } + fmt.Printf("Resolve targets for for %s type on %s with namespace %s\n", node.Type(), node.Content(treeData), currentNamespace) + + // Handle variable names directly + if node.Type() == "identifier" { + identifier := node.Content(treeData) + // Check if the identifier maps to something in the assignment graph + resolvedTargets := callGraph.Assignments.Resolve(identifier) + fmt.Println("Resolved targets for", identifier, ":", resolvedTargets) + if len(resolvedTargets) > 0 { + return resolvedTargets + } + // Fallback: return the identifier in the current namespace + return []string{currentNamespace + namespaceSeparator + identifier} + } + + // Handle calls, e.g., ClassA() -> resolve to ClassA namespace + if node.Type() == "call" { + functionNode := node.ChildByFieldName("function") + if functionNode != nil { + functionName := functionNode.Content(treeData) + // Check if the function is a class in the current or parrent scopes in callgraph graph + for i := strings.Count(currentNamespace, namespaceSeparator); i >= 0; i-- { + searchNamespace := strings.Join(strings.Split(currentNamespace, namespaceSeparator)[:i], namespaceSeparator) + namespaceSeparator + functionName + if i == 0 { + searchNamespace = functionName + } + if _, exists := callGraph.Nodes[searchNamespace]; exists { + return []string{searchNamespace} + } + } + return []string{currentNamespace + namespaceSeparator + functionName} + } + } + + // Handle member access, e.g., obj.attr + if node.Type() == "attribute" { + baseNode := node.ChildByFieldName("object") + attributeNode := node.ChildByFieldName("attribute") + if baseNode != nil && attributeNode != nil { + var baseTarget []string = resolveTargets(baseNode, treeData, currentNamespace, identifierToModuleNamespace, callGraph) + attributeName := attributeNode.Content(treeData) + var resolvedTargets []string + for _, base := range baseTarget { + resolvedTargets = append(resolvedTargets, base+namespaceSeparator+attributeName) + } + return resolvedTargets + } + } + + // Handle other expressions as fallbacks (e.g., literals, complex expressions) + return []string{node.Content(treeData)} +} diff --git a/plugin/callgraph/config.go b/plugin/callgraph/config.go new file mode 100644 index 0000000..c893f33 --- /dev/null +++ b/plugin/callgraph/config.go @@ -0,0 +1,12 @@ +package callgraph + +// TS nodes Ignored when parsing AST +// eg. comment is useless, imports are already resolved +var ignoredTypesList = []string{"comment"} +var ignoredTypes = make(map[string]bool) + +func init() { + for _, ignoredType := range ignoredTypesList { + ignoredTypes[ignoredType] = true + } +} diff --git a/plugin/callgraph/fixtures/.gitkeep b/plugin/callgraph/fixtures/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/plugin/callgraph/fixtures/test.py b/plugin/callgraph/fixtures/test.py new file mode 100644 index 0000000..371b657 --- /dev/null +++ b/plugin/callgraph/fixtures/test.py @@ -0,0 +1,69 @@ +import base64 +from utils import printinit, printenc, printdec, printf2 + +# Node must be generated, but shouldn't be part of DFS +class EncodingUnused: + def __init__(self): + printinit("Initialized unused") + pass + + def applyUnused(self, msg, func): + return func(msg) + +class Encoding: + def __init__(self): + printinit("Initialized") + pass + + def apply(self, msg, func): + return func(msg) + + # Unused + def apply2(self, msg, func): + return func(msg) + +encoder = Encoding() +encoded = encoder.apply("Hello, World!".encode('utf-8'), base64.b64encode) +printenc(encoded) +decoded = encoder.apply(encoded, base64.b64decode) +printdec(decoded) + + +def f1(value): + f2(value) + +def f2(value): + printf2(value) + if value == 0: + return + f1(value-1) + pass + +def multiply(a, b): + return a * b + +f1(multiply(2, 3)) + +def foo(): + print("foo") + pass + +def bar(): + print("bar") + pass + +def baz(): + print("baz") + pass +def useless(): + print("useless") + baz() + pass + +xyz = foo + +print("GG") + +xyz = bar + +xyz() \ No newline at end of file diff --git a/plugin/callgraph/graph.go b/plugin/callgraph/graph.go new file mode 100644 index 0000000..04014d7 --- /dev/null +++ b/plugin/callgraph/graph.go @@ -0,0 +1,75 @@ +package callgraph + +import ( + "fmt" + "strings" +) + +const namespaceSeparator = "//" + +// graphNode represents a single node in the call graph +type graphNode struct { + Namespace string + CallsTo []string +} + +func newGraphNode(namespace string) *graphNode { + return &graphNode{ + Namespace: namespace, + CallsTo: []string{}, + } +} + +type CallGraph struct { + FileName string + Nodes map[string]*graphNode + Assignments AssignmentGraph +} + +func NewCallGraph(fileName string) *CallGraph { + return &CallGraph{FileName: fileName, Nodes: make(map[string]*graphNode), Assignments: *NewAssignmentGraph()} +} + +// AddEdge adds an edge from one function to another +func (cg *CallGraph) AddEdge(caller, callee string) { + if _, exists := cg.Nodes[caller]; !exists { + cg.Nodes[caller] = newGraphNode(caller) + } + if _, exists := cg.Nodes[callee]; !exists { + cg.Nodes[callee] = newGraphNode(callee) + } + cg.Nodes[caller].CallsTo = append(cg.Nodes[caller].CallsTo, callee) +} + +func (cg *CallGraph) PrintCallGraph() { + fmt.Println("Call Graph:") + for caller, node := range cg.Nodes { + fmt.Printf(" %s (calls)=> %v\n", caller, node.CallsTo) + } + fmt.Println() +} + +func (cg *CallGraph) DFS() []string { + visited := make(map[string]bool) + var dfsResult []string + cg.dfsUtil(cg.FileName, visited, &dfsResult, 0) + return dfsResult +} + +func (cg *CallGraph) dfsUtil(startNode string, visited map[string]bool, result *[]string, depth int) { + fmt.Println("DFS Util:", startNode) + if visited[startNode] { + // append that not going inside this on prev level + *result = append(*result, fmt.Sprintf("%s Stopped at %s", strings.Repeat("|", depth), startNode)) + return + } + + // Mark the current node as visited and add it to the result + visited[startNode] = true + *result = append(*result, fmt.Sprintf("%s %s", strings.Repeat(">", depth), startNode)) + + // Recursively visit all the nodes called by the current node + for _, callee := range cg.Nodes[startNode].CallsTo { + cg.dfsUtil(callee, visited, result, depth+1) + } +} diff --git a/plugin/callgraph/utils.go b/plugin/callgraph/utils.go new file mode 100644 index 0000000..ebe0871 --- /dev/null +++ b/plugin/callgraph/utils.go @@ -0,0 +1,12 @@ +package callgraph + +import "strings" + +// @TODO - Refactor this for a language agnostic approach + +// GetBaseModuleName returns the base module name from the given module name. +// eg. for "os.path", the base module name is "os" +func GetBaseModuleName(moduleName string) string { + parts := strings.Split(moduleName, ".") + return parts[0] +} From df477e0b459e40ae2d22bdc3a804c1d7c59d4e14 Mon Sep 17 00:00:00 2001 From: Omkar Phansopkar Date: Fri, 10 Jan 2025 16:55:35 +0530 Subject: [PATCH 2/5] Fixed lint err Signed-off-by: Omkar Phansopkar --- plugin/callgraph/callgraph.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/plugin/callgraph/callgraph.go b/plugin/callgraph/callgraph.go index 2b2ef4b..f08bc42 100644 --- a/plugin/callgraph/callgraph.go +++ b/plugin/callgraph/callgraph.go @@ -65,9 +65,7 @@ func (p *callgraphPlugin) AnalyzeTree(ctx context.Context, tree core.ParseTree) return fmt.Errorf("failed to build call graph: %w", err) } - p.callgraphCallback(cg) - - return nil + return p.callgraphCallback(cg) } // buildCallGraph builds a call graph from the syntax tree From 2ecb052651de16ec6fe222d2a64b56eba6ce78f1 Mon Sep 17 00:00:00 2001 From: Omkar Phansopkar Date: Fri, 10 Jan 2025 17:01:07 +0530 Subject: [PATCH 3/5] Updated object assignment example Signed-off-by: Omkar Phansopkar --- .../callgraph/{testTmp.py => testObjAssignment.py} | 9 +++++++++ 1 file changed, 9 insertions(+) rename examples/plugin/callgraph/{testTmp.py => testObjAssignment.py} (63%) diff --git a/examples/plugin/callgraph/testTmp.py b/examples/plugin/callgraph/testObjAssignment.py similarity index 63% rename from examples/plugin/callgraph/testTmp.py rename to examples/plugin/callgraph/testObjAssignment.py index f33da21..33a8fd8 100644 --- a/examples/plugin/callgraph/testTmp.py +++ b/examples/plugin/callgraph/testObjAssignment.py @@ -7,8 +7,17 @@ def __init__(self): def method1(self): printxyz2("GG") + +class ClassB: + def __init__(self): + pxyz("init") + + def method1(self): + printxyz2("GG") + def main(): x = ClassA() + x = ClassB() y = x y.method1() main() \ No newline at end of file From d15449efc4ddfa23696ebcaf6744b7a0241f4d3c Mon Sep 17 00:00:00 2001 From: Omkar Phansopkar Date: Fri, 10 Jan 2025 23:50:40 +0530 Subject: [PATCH 4/5] Modularised import identifier resolution Signed-off-by: Omkar Phansopkar --- examples/plugin/callgraph/testFnAssignment.py | 7 +- .../plugin/callgraph/testObjAssignment.py | 10 +- examples/plugin/callgraph/testScopes.py | 2 +- plugin/callgraph/callgraph.go | 113 +++++++++++------- plugin/callgraph/graph.go | 15 ++- 5 files changed, 93 insertions(+), 54 deletions(-) diff --git a/examples/plugin/callgraph/testFnAssignment.py b/examples/plugin/callgraph/testFnAssignment.py index 15bb8c4..087519d 100644 --- a/examples/plugin/callgraph/testFnAssignment.py +++ b/examples/plugin/callgraph/testFnAssignment.py @@ -1,5 +1,10 @@ import pprint -from xyzprintmodule import xyzprint, xyzprint2 +from xyzprintmodule import xyzprint, xyzprint2, xyzprint3 as pxyz3 + +customprintxyz = pxyz3 +customprintxyz = xyzprint2 +customprintxyz("GG") + def foo(): pprint.pprint("foo") diff --git a/examples/plugin/callgraph/testObjAssignment.py b/examples/plugin/callgraph/testObjAssignment.py index 33a8fd8..472ca2e 100644 --- a/examples/plugin/callgraph/testObjAssignment.py +++ b/examples/plugin/callgraph/testObjAssignment.py @@ -1,4 +1,6 @@ -from xyz import printxyz as pxyz, printxyz2 +import pprint +from xyz import printxyz as pxyz, printxyz2, printxyz3 +from os import listdir as listdirfn, chmod class ClassA: def __init__(self): @@ -15,9 +17,15 @@ def __init__(self): def method1(self): printxyz2("GG") + def methodUnique(self): + printxyz3("GG") + pprint.pp("GG") + + def main(): x = ClassA() x = ClassB() y = x y.method1() + y.methodUnique() main() \ No newline at end of file diff --git a/examples/plugin/callgraph/testScopes.py b/examples/plugin/callgraph/testScopes.py index f6a5a7d..4255a8d 100644 --- a/examples/plugin/callgraph/testScopes.py +++ b/examples/plugin/callgraph/testScopes.py @@ -1,5 +1,5 @@ from pprint import pprint -from xyzprintmodule import xprint, xyzprint, xyzprint2, xyzprint3 +from xyzprintmodule import xyzprint as xprint, xyzprint, xyzprint2, xyzprint3 def fn1(): xprint("very outer fn1") diff --git a/plugin/callgraph/callgraph.go b/plugin/callgraph/callgraph.go index f08bc42..256e7fa 100644 --- a/plugin/callgraph/callgraph.go +++ b/plugin/callgraph/callgraph.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/safedep/code/core" + "github.com/safedep/code/core/ast" "github.com/safedep/code/pkg/helpers" "github.com/safedep/dry/log" sitter "github.com/smacker/go-tree-sitter" @@ -56,9 +57,6 @@ func (p *callgraphPlugin) AnalyzeTree(ctx context.Context, tree core.ParseTree) return fmt.Errorf("failed to get tree data: %w", err) } - fmt.Println() - - // Build the call graph cg, err := buildCallGraph(tree, lang, treeData, file.Name()) if err != nil { @@ -70,48 +68,40 @@ func (p *callgraphPlugin) AnalyzeTree(ctx context.Context, tree core.ParseTree) // buildCallGraph builds a call graph from the syntax tree func buildCallGraph(tree core.ParseTree, lang core.Language, treeData *[]byte, filePath string) (*CallGraph, error) { - callGraph := NewCallGraph(filePath) - rootNode := tree.Tree().RootNode() + astRootNode := tree.Tree().RootNode() + imports, err := lang.Resolvers().ResolveImports(tree) if err != nil { return nil, fmt.Errorf("failed to resolve imports: %w", err) } - identifierToModuleNamespace := make(map[string]string) - for _, imp := range imports { - if imp.IsWildcardImport() { - continue - } - itemNamespace := imp.ModuleItem() - if itemNamespace == "" { - itemNamespace = imp.ModuleName() - } else { - itemNamespace = imp.ModuleName() + namespaceSeparator + itemNamespace - } + // Required to map identifiers to imported modules as assignments + importedIdentifierNamespaces := parseImportedIdentifierNamespaces(imports) - identifierKey := helpers.GetFirstNonEmptyString(imp.ModuleAlias(), imp.ModuleItem(), imp.ModuleName()) - identifierToModuleNamespace[identifierKey] = itemNamespace - } - fmt.Println("Imports (identifier => namespace):") - for identifier, moduleNamespace := range identifierToModuleNamespace { - fmt.Printf("%s => %s\n", identifier, moduleNamespace) + fmt.Println() + fmt.Println("Imported identifier => namespace:") + for identifier, namespace := range importedIdentifierNamespaces { + fmt.Printf(" %s => %s\n", identifier, namespace) } fmt.Println() + callGraph := NewCallGraph(filePath, importedIdentifierNamespaces) + // Add root node to the call graph callGraph.Nodes[filePath] = newGraphNode(filePath) - traverseTree(rootNode, treeData, callGraph, identifierToModuleNamespace, filePath, filePath, "", false) + traverseTree(astRootNode, treeData, callGraph, filePath, filePath, "", false) fmt.Println() return callGraph, nil } -func traverseTree(node *sitter.Node, treeData *[]byte, callGraph *CallGraph, identifierToModuleNamespace map[string]string, filePath string, currentNamespace string, classNamespace string, insideClass bool) { +func traverseTree(node *sitter.Node, treeData *[]byte, callGraph *CallGraph, filePath string, currentNamespace string, classNamespace string, insideClass bool) { if node == nil { return } fmt.Println("Traverse ", node.Type(), "with content:", node.Content(*treeData), "with namespace:", currentNamespace) + switch node.Type() { case "function_definition": nameNode := node.ChildByFieldName("name") @@ -143,9 +133,9 @@ func traverseTree(node *sitter.Node, treeData *[]byte, callGraph *CallGraph, ide rightNode := node.ChildByFieldName("right") if leftNode != nil && rightNode != nil { leftVar := leftNode.Content(*treeData) - rightTargets := resolveTargets(rightNode, *treeData, currentNamespace, identifierToModuleNamespace, callGraph) + rightTargets := resolveTargets(rightNode, *treeData, currentNamespace, callGraph) for _, rightTarget := range rightTargets { - callGraph.Assignments.AddAssignment(leftVar, rightTarget) + callGraph.assignments.AddAssignment(leftVar, rightTarget) fmt.Printf("Assignment: %s -> %s\n", leftVar, rightTarget) } } @@ -157,17 +147,23 @@ func traverseTree(node *sitter.Node, treeData *[]byte, callGraph *CallGraph, ide if baseNode != nil && attributeNode != nil { // Resolve base object using the assignment graph at different scopes fmt.Printf("Try resolving target call %s.%s at %s\n", baseNode.Content(*treeData), attributeNode.Content(*treeData), currentNamespace) - baseTargets := resolveTargets(baseNode, *treeData, currentNamespace, identifierToModuleNamespace, callGraph) + baseTargets := resolveTargets(baseNode, *treeData, currentNamespace, callGraph) + fmt.Printf("Processing fn call as %s.%s() on base targets : %v \n", baseNode.Content(*treeData), attributeNode.Content(*treeData), baseTargets) + for _, baseTarget := range baseTargets { fmt.Printf("Processing fn call as %s.%s() on base as %s \n", baseNode.Content(*treeData), attributeNode.Content(*treeData), baseTarget) attributeName := attributeNode.Content(*treeData) targetNamespace := baseTarget + namespaceSeparator + attributeName _, existed := callGraph.Nodes[targetNamespace] - fmt.Printf("Attr %s resolved to %s, exists: %t\n", node.Content(*treeData), targetNamespace, existed) + fmt.Printf("Attr %s resolved to %s, exists: %t\n", attributeName, targetNamespace, existed) // Check if resolved target exists in the call graph if _, exists := callGraph.Nodes[targetNamespace]; exists && targetNamespace != "" { + fmt.Println("Add attr edge from", currentNamespace, "to", targetNamespace) callGraph.AddEdge(currentNamespace, targetNamespace) + } else if _, exists := callGraph.importedIdentifierNamespaces[baseTarget]; exists { + fmt.Println("Add attr edge from", currentNamespace, "to module namespace", baseTarget+namespaceSeparator+attributeName) + callGraph.AddEdge(currentNamespace, baseTarget+namespaceSeparator+attributeName) } } } @@ -182,14 +178,13 @@ func traverseTree(node *sitter.Node, treeData *[]byte, callGraph *CallGraph, ide // Search for the call target node at different scopes in the graph // eg. namespace - nestNestedFn.py//nestParent//nestChild, callTarget - outerfn1 - // try searching for outerfn1 in identifierToModuleNamespace with all scope levels + // try searching for outerfn1 in graph with all scope levels // eg. search nestNestedFn.py//nestParent//nestChild//outerfn1 // then nestNestedFn.py//nestParent//outerfn1 then nestNestedFn.py//outerfn1 and so on // if not found, then use currentNamespace to build it // like, nestNestedFn.py//nestParent//nestChild//outerfn1 - var found bool - var targetNamespace string + targetNamespaces := []string{} for i := strings.Count(currentNamespace, namespaceSeparator) + 1; i >= 0; i-- { searchNamespace := strings.Join(strings.Split(currentNamespace, namespaceSeparator)[:i], namespaceSeparator) + namespaceSeparator + callTarget if i == 0 { @@ -198,31 +193,34 @@ func traverseTree(node *sitter.Node, treeData *[]byte, callGraph *CallGraph, ide fmt.Printf("searching %s in scoped - %s\n", callTarget, searchNamespace) // check in graph if _, exists := callGraph.Nodes[searchNamespace]; exists { - targetNamespace = searchNamespace - found = true + targetNamespaces = append(targetNamespaces, searchNamespace) break } } - if !found { - // check in identifierToModuleNamespace - if moduleNamespace, exists := identifierToModuleNamespace[callTarget]; exists { - targetNamespace = moduleNamespace - } else if insideClass { - targetNamespace = classNamespace + namespaceSeparator + callTarget + if len(targetNamespaces) == 0 { + if assignedNamespaces := callGraph.assignments.Resolve(callTarget); len(assignedNamespaces) > 0 { + fmt.Println("Resolve imported target for", callTarget, ":", assignedNamespaces) + targetNamespaces = assignedNamespaces } else { - targetNamespace = currentNamespace + namespaceSeparator + callTarget + if insideClass { + targetNamespaces = []string{classNamespace + namespaceSeparator + callTarget} + } else { + targetNamespaces = []string{currentNamespace + namespaceSeparator + callTarget} + } } } // Add edge for function call - fmt.Println("Adding edge from", currentNamespace, "to", targetNamespace) - callGraph.AddEdge(currentNamespace, targetNamespace) + for _, targetNamespace := range targetNamespaces { + fmt.Println("Adding edge from", currentNamespace, "to", targetNamespace) + callGraph.AddEdge(currentNamespace, targetNamespace) + } } } // Recursively analyze children for i := 0; i < int(node.ChildCount()); i++ { - traverseTree(node.Child(i), treeData, callGraph, identifierToModuleNamespace, filePath, currentNamespace, classNamespace, insideClass) + traverseTree(node.Child(i), treeData, callGraph, filePath, currentNamespace, classNamespace, insideClass) } } @@ -230,7 +228,6 @@ func resolveTargets( node *sitter.Node, treeData []byte, currentNamespace string, - identifierToModuleNamespace map[string]string, callGraph *CallGraph, ) []string { if node == nil { @@ -242,7 +239,7 @@ func resolveTargets( if node.Type() == "identifier" { identifier := node.Content(treeData) // Check if the identifier maps to something in the assignment graph - resolvedTargets := callGraph.Assignments.Resolve(identifier) + resolvedTargets := callGraph.assignments.Resolve(identifier) fmt.Println("Resolved targets for", identifier, ":", resolvedTargets) if len(resolvedTargets) > 0 { return resolvedTargets @@ -275,7 +272,7 @@ func resolveTargets( baseNode := node.ChildByFieldName("object") attributeNode := node.ChildByFieldName("attribute") if baseNode != nil && attributeNode != nil { - var baseTarget []string = resolveTargets(baseNode, treeData, currentNamespace, identifierToModuleNamespace, callGraph) + var baseTarget []string = resolveTargets(baseNode, treeData, currentNamespace, callGraph) attributeName := attributeNode.Content(treeData) var resolvedTargets []string for _, base := range baseTarget { @@ -288,3 +285,27 @@ func resolveTargets( // Handle other expressions as fallbacks (e.g., literals, complex expressions) return []string{node.Content(treeData)} } + +// Fetches namespaces for imported identifiers +// eg. import pprint is parsed as: +// pprint -> pprint +// eg. from os import listdir as listdirfn, chmod is parsed as: +// listdirfn -> os//listdir +// chmod -> os//chmod +func parseImportedIdentifierNamespaces(imports []*ast.ImportNode) map[string]string { + importedIdentifierNamespaces := make(map[string]string) + for _, imp := range imports { + if imp.IsWildcardImport() { + continue + } + itemNamespace := imp.ModuleItem() + if itemNamespace == "" { + itemNamespace = imp.ModuleName() + } else { + itemNamespace = imp.ModuleName() + namespaceSeparator + itemNamespace + } + identifierKey := helpers.GetFirstNonEmptyString(imp.ModuleAlias(), imp.ModuleItem(), imp.ModuleName()) + importedIdentifierNamespaces[identifierKey] = itemNamespace + } + return importedIdentifierNamespaces +} diff --git a/plugin/callgraph/graph.go b/plugin/callgraph/graph.go index 04014d7..dd1000f 100644 --- a/plugin/callgraph/graph.go +++ b/plugin/callgraph/graph.go @@ -21,13 +21,18 @@ func newGraphNode(namespace string) *graphNode { } type CallGraph struct { - FileName string - Nodes map[string]*graphNode - Assignments AssignmentGraph + FileName string + Nodes map[string]*graphNode + assignments AssignmentGraph + importedIdentifierNamespaces map[string]string } -func NewCallGraph(fileName string) *CallGraph { - return &CallGraph{FileName: fileName, Nodes: make(map[string]*graphNode), Assignments: *NewAssignmentGraph()} +func NewCallGraph(fileName string, importedIdentifierNamespaces map[string]string) *CallGraph { + cg := &CallGraph{FileName: fileName, Nodes: make(map[string]*graphNode), assignments: *NewAssignmentGraph(), importedIdentifierNamespaces: importedIdentifierNamespaces} + for identifier, namespace := range importedIdentifierNamespaces { + cg.assignments.AddAssignment(identifier, namespace) + } + return cg } // AddEdge adds an edge from one function to another From c24d46d3b3934fabe9af60fb907b463a6fc1b914 Mon Sep 17 00:00:00 2001 From: Omkar Phansopkar Date: Sat, 11 Jan 2025 00:15:23 +0530 Subject: [PATCH 5/5] Fixed redundant namespaces for assigned calls Signed-off-by: Omkar Phansopkar --- examples/plugin/callgraph/testFnAssignment.py | 4 +++- examples/plugin/callgraph/testNestedFn.py | 6 +++--- plugin/callgraph/callgraph.go | 21 +++++++++---------- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/examples/plugin/callgraph/testFnAssignment.py b/examples/plugin/callgraph/testFnAssignment.py index 087519d..e88698a 100644 --- a/examples/plugin/callgraph/testFnAssignment.py +++ b/examples/plugin/callgraph/testFnAssignment.py @@ -14,6 +14,7 @@ def bar(): xyzprint("bar") pass +# unused def baz(): xyzprint2("baz") pass @@ -22,7 +23,8 @@ def baz(): print("GG") xyz = bar -xyz() + +xyz() # current analysis will simulate both foo() & bar() calls def nestParent(): def nestChild(): diff --git a/examples/plugin/callgraph/testNestedFn.py b/examples/plugin/callgraph/testNestedFn.py index cb0d043..2947c09 100644 --- a/examples/plugin/callgraph/testNestedFn.py +++ b/examples/plugin/callgraph/testNestedFn.py @@ -1,4 +1,4 @@ -from pprint import pprint as pprintfn +import pprint from xyzprintmodule import xyzprint, xyzprint2 from os import listdir as listdirfn, chmod @@ -24,11 +24,11 @@ def nestGrandChildUseless(): xyzprint2("nestGrandChildUseless") def nestGrandChild(): - xyzprint2("nestGrandChild") + pprint.pp("nestGrandChild") parentScopedFn() outerfn2() childScopedFn() - + nestGrandChild() outerfn1() diff --git a/plugin/callgraph/callgraph.go b/plugin/callgraph/callgraph.go index 256e7fa..b2c6d7d 100644 --- a/plugin/callgraph/callgraph.go +++ b/plugin/callgraph/callgraph.go @@ -112,12 +112,10 @@ func traverseTree(node *sitter.Node, treeData *[]byte, callGraph *CallGraph, fil } else { currentNamespace = currentNamespace + namespaceSeparator + funcName } - fmt.Printf("Current funcdef namespace for %s - %s\n", funcName, currentNamespace) if _, exists := callGraph.Nodes[currentNamespace]; !exists { callGraph.Nodes[currentNamespace] = newGraphNode(currentNamespace) - //@TODO - Class -> __init__ edge needed ? - //@TODO - If needed, should make it language agnostic, (__init__ ) must be obtained from lang. For java, it would be classname itself, etc + //@TODO - Class Constructor edge must be language agnostic, (__init__ ) must be obtained from lang. For java, it would be classname itself, etc if insideClass && funcName == "__init__" { callGraph.AddEdge(classNamespace, currentNamespace) } @@ -136,7 +134,6 @@ func traverseTree(node *sitter.Node, treeData *[]byte, callGraph *CallGraph, fil rightTargets := resolveTargets(rightNode, *treeData, currentNamespace, callGraph) for _, rightTarget := range rightTargets { callGraph.assignments.AddAssignment(leftVar, rightTarget) - fmt.Printf("Assignment: %s -> %s\n", leftVar, rightTarget) } } case "attribute": @@ -198,16 +195,19 @@ func traverseTree(node *sitter.Node, treeData *[]byte, callGraph *CallGraph, fil } } if len(targetNamespaces) == 0 { + // If namespace not found in available scopes in the graph, try to resolve it from imported namespaces if assignedNamespaces := callGraph.assignments.Resolve(callTarget); len(assignedNamespaces) > 0 { fmt.Println("Resolve imported target for", callTarget, ":", assignedNamespaces) targetNamespaces = assignedNamespaces - } else { - if insideClass { - targetNamespaces = []string{classNamespace + namespaceSeparator + callTarget} - } else { - targetNamespaces = []string{currentNamespace + namespaceSeparator + callTarget} - } } + // else { + // // @TODO - rethink this + // if insideClass { + // targetNamespaces = []string{classNamespace + namespaceSeparator + callTarget} + // } else { + // targetNamespaces = []string{currentNamespace + namespaceSeparator + callTarget} + // } + // } } // Add edge for function call @@ -240,7 +240,6 @@ func resolveTargets( identifier := node.Content(treeData) // Check if the identifier maps to something in the assignment graph resolvedTargets := callGraph.assignments.Resolve(identifier) - fmt.Println("Resolved targets for", identifier, ":", resolvedTargets) if len(resolvedTargets) > 0 { return resolvedTargets }