diff --git a/examples/plugin/callgraph/main.go b/examples/plugin/callgraph/main.go new file mode 100644 index 0000000..605dc13 --- /dev/null +++ b/examples/plugin/callgraph/main.go @@ -0,0 +1,88 @@ +package main + +import ( + "context" + "flag" + "fmt" + + "github.com/safedep/code/core" + "github.com/safedep/code/fs" + "github.com/safedep/code/lang" + "github.com/safedep/code/parser" + "github.com/safedep/code/plugin" + "github.com/safedep/code/plugin/callgraph" + "github.com/safedep/dry/log" +) + +var ( + dirToWalk string + language string +) + +func init() { + log.InitZapLogger("walker", "dev") + + flag.StringVar(&dirToWalk, "dir", "", "Directory to walk") + flag.StringVar(&language, "lang", "python", "Language to use for parsing files") + + flag.Parse() +} + +func main() { + if dirToWalk == "" { + flag.Usage() + return + } + + err := run() + if err != nil { + panic(err) + } +} + +func run() error { + fileSystem, err := fs.NewLocalFileSystem(fs.LocalFileSystemConfig{ + AppDirectories: []string{dirToWalk}, + }) + + if err != nil { + return fmt.Errorf("failed to create local filesystem: %w", err) + } + + language, err := lang.GetLanguage(language) + if err != nil { + return fmt.Errorf("failed to get language: %w", err) + } + + walker, err := fs.NewSourceWalker(fs.SourceWalkerConfig{}, language) + if err != nil { + return fmt.Errorf("failed to create source walker: %w", err) + } + + treeWalker, err := parser.NewWalkingParser(walker, language) + if err != nil { + return fmt.Errorf("failed to create tree walker: %w", err) + } + + // consume callgraph + var callgraphCallback callgraph.CallgraphCallback = func(cg *callgraph.CallGraph) error { + cg.PrintCallGraph() + + fmt.Println("DFS Traversal:") + for _, node := range cg.DFS() { + fmt.Println(node) + } + fmt.Println() + return nil + } + + pluginExecutor, err := plugin.NewTreeWalkPluginExecutor(treeWalker, []core.Plugin{ + callgraph.NewCallGraphPlugin(callgraphCallback), + }) + + if err != nil { + return fmt.Errorf("failed to create plugin executor: %w", err) + } + + return pluginExecutor.Execute(context.Background(), fileSystem) +} diff --git a/examples/plugin/callgraph/test.py b/examples/plugin/callgraph/test.py new file mode 100644 index 0000000..371b657 --- /dev/null +++ b/examples/plugin/callgraph/test.py @@ -0,0 +1,69 @@ +import base64 +from utils import printinit, printenc, printdec, printf2 + +# Node must be generated, but shouldn't be part of DFS +class EncodingUnused: + def __init__(self): + printinit("Initialized unused") + pass + + def applyUnused(self, msg, func): + return func(msg) + +class Encoding: + def __init__(self): + printinit("Initialized") + pass + + def apply(self, msg, func): + return func(msg) + + # Unused + def apply2(self, msg, func): + return func(msg) + +encoder = Encoding() +encoded = encoder.apply("Hello, World!".encode('utf-8'), base64.b64encode) +printenc(encoded) +decoded = encoder.apply(encoded, base64.b64decode) +printdec(decoded) + + +def f1(value): + f2(value) + +def f2(value): + printf2(value) + if value == 0: + return + f1(value-1) + pass + +def multiply(a, b): + return a * b + +f1(multiply(2, 3)) + +def foo(): + print("foo") + pass + +def bar(): + print("bar") + pass + +def baz(): + print("baz") + pass +def useless(): + print("useless") + baz() + pass + +xyz = foo + +print("GG") + +xyz = bar + +xyz() \ No newline at end of file diff --git a/examples/plugin/callgraph/testClass.py b/examples/plugin/callgraph/testClass.py new file mode 100644 index 0000000..30aa12a --- /dev/null +++ b/examples/plugin/callgraph/testClass.py @@ -0,0 +1,30 @@ +import base64 +from utils import printinit, printenc, printdec as pdec + +class Encoding: + def __init__(self): + printinit("Initialized") + pass + + def apply(self, msg, func): + return func(msg) + + # Unused + def apply2(self, msg, func): + return func(msg) + +def getenc(): + return "encoded" + +encoder = Encoding() +encoded = encoder.apply("Hello, World!".encode('utf-8'), base64.b64encode) +printenc(encoded) +decoded = encoder.apply(getenc(), base64.b64decode) +pdec(decoded) + +class NoConstructorClass: + def show(): + print("NoConstructorClass") + pass +ncc = NoConstructorClass() +ncc.show() \ No newline at end of file diff --git a/examples/plugin/callgraph/testFnAssignment.py b/examples/plugin/callgraph/testFnAssignment.py new file mode 100644 index 0000000..e88698a --- /dev/null +++ b/examples/plugin/callgraph/testFnAssignment.py @@ -0,0 +1,42 @@ +import pprint +from xyzprintmodule import xyzprint, xyzprint2, xyzprint3 as pxyz3 + +customprintxyz = pxyz3 +customprintxyz = xyzprint2 +customprintxyz("GG") + + +def foo(): + pprint.pprint("foo") + pass + +def bar(): + xyzprint("bar") + pass + +# unused +def baz(): + xyzprint2("baz") + pass + +xyz = foo +print("GG") + +xyz = bar + +xyz() # current analysis will simulate both foo() & bar() calls + +def nestParent(): + def nestChild(): + xyzprint("nestChild") + def nestGrandChild(): + xyzprint2("nestGrandChild") + pass + nestGrandChild() + nestChild() +nestParent() + +def useless(): + print("useless") + baz() + pass diff --git a/examples/plugin/callgraph/testNestedFn.py b/examples/plugin/callgraph/testNestedFn.py new file mode 100644 index 0000000..2947c09 --- /dev/null +++ b/examples/plugin/callgraph/testNestedFn.py @@ -0,0 +1,38 @@ +import pprint +from xyzprintmodule import xyzprint, xyzprint2 +from os import listdir as listdirfn, chmod + +def outerfn1(): + chmod("outerfn1") + pass +def outerfn2(): + listdirfn("outerfn2") + pass + +def nestParent(): + def parentScopedFn(): + xyzprint("parentScopedFn") + + def nestChild(): + xyzprint("nestChild") + outerfn1() + + def childScopedFn(): + xyzprint("childScopedFn") + + def nestGrandChildUseless(): + xyzprint2("nestGrandChildUseless") + + def nestGrandChild(): + pprint.pp("nestGrandChild") + parentScopedFn() + outerfn2() + childScopedFn() + + nestGrandChild() + + outerfn1() + nestChild() + +nestParent() + diff --git a/examples/plugin/callgraph/testObjAssignment.py b/examples/plugin/callgraph/testObjAssignment.py new file mode 100644 index 0000000..472ca2e --- /dev/null +++ b/examples/plugin/callgraph/testObjAssignment.py @@ -0,0 +1,31 @@ +import pprint +from xyz import printxyz as pxyz, printxyz2, printxyz3 +from os import listdir as listdirfn, chmod + +class ClassA: + def __init__(self): + pxyz("init") + + def method1(self): + printxyz2("GG") + + +class ClassB: + def __init__(self): + pxyz("init") + + def method1(self): + printxyz2("GG") + + def methodUnique(self): + printxyz3("GG") + pprint.pp("GG") + + +def main(): + x = ClassA() + x = ClassB() + y = x + y.method1() + y.methodUnique() +main() \ No newline at end of file diff --git a/examples/plugin/callgraph/testScopes.py b/examples/plugin/callgraph/testScopes.py new file mode 100644 index 0000000..4255a8d --- /dev/null +++ b/examples/plugin/callgraph/testScopes.py @@ -0,0 +1,24 @@ +from pprint import pprint +from xyzprintmodule import xyzprint as xprint, xyzprint, xyzprint2, xyzprint3 + +def fn1(): + xprint("very outer fn1") +fn1() + +def fn2(): + def fn1(): + xyzprint("fn1 inside fn2") + fn1() + + def fn3(): + def fn4(): + def fn1(): + xyzprint3("fn1 inside fn4 inside fn3") + xyzprint2("fn4 inside fn3") + fn1() # must call fn1 inside fn4 + fn1() # must call fn1 inside fn2 + fn4() + fn3() + +fn2() + \ No newline at end of file diff --git a/plugin/callgraph/assignment.go b/plugin/callgraph/assignment.go new file mode 100644 index 0000000..cd9884f --- /dev/null +++ b/plugin/callgraph/assignment.go @@ -0,0 +1,19 @@ +package callgraph + +type AssignmentGraph struct { + Assignments map[string][]string // Map of identifier to possible namespaces or other identifiers +} + +func NewAssignmentGraph() *AssignmentGraph { + return &AssignmentGraph{Assignments: make(map[string][]string)} +} + +// Add an assignment +func (ag *AssignmentGraph) AddAssignment(identifier string, target string) { + ag.Assignments[identifier] = append(ag.Assignments[identifier], target) +} + +// Resolve an identifier to its targets +func (ag *AssignmentGraph) Resolve(identifier string) []string { + return ag.Assignments[identifier] +} diff --git a/plugin/callgraph/callgraph.go b/plugin/callgraph/callgraph.go new file mode 100644 index 0000000..b2c6d7d --- /dev/null +++ b/plugin/callgraph/callgraph.go @@ -0,0 +1,310 @@ +package callgraph + +import ( + "context" + "fmt" + "strings" + + "github.com/safedep/code/core" + "github.com/safedep/code/core/ast" + "github.com/safedep/code/pkg/helpers" + "github.com/safedep/dry/log" + sitter "github.com/smacker/go-tree-sitter" +) + +type CallgraphCallback func(*CallGraph) error + +type callgraphPlugin struct { + // Callback function which is called with the callgraph + callgraphCallback CallgraphCallback +} + +// Verify contract +var _ core.TreePlugin = (*callgraphPlugin)(nil) + +func NewCallGraphPlugin(callgraphCallback CallgraphCallback) *callgraphPlugin { + return &callgraphPlugin{ + callgraphCallback: callgraphCallback, + } +} + +func (p *callgraphPlugin) Name() string { + return "CallgraphPlugin" +} + +var supportedLanguages = []core.LanguageCode{core.LanguageCodePython} + +func (p *callgraphPlugin) SupportedLanguages() []core.LanguageCode { + return supportedLanguages +} + +func (p *callgraphPlugin) AnalyzeTree(ctx context.Context, tree core.ParseTree) error { + lang, err := tree.Language() + if err != nil { + return fmt.Errorf("failed to get language: %w", err) + } + + file, err := tree.File() + if err != nil { + return fmt.Errorf("failed to get file: %w", err) + } + + log.Debugf("callgraph - Analyzing tree for language: %s, file: %s\n", + lang.Meta().Code, file.Name()) + + treeData, err := tree.Data() + if err != nil { + return fmt.Errorf("failed to get tree data: %w", err) + } + + cg, err := buildCallGraph(tree, lang, treeData, file.Name()) + + if err != nil { + return fmt.Errorf("failed to build call graph: %w", err) + } + + return p.callgraphCallback(cg) +} + +// buildCallGraph builds a call graph from the syntax tree +func buildCallGraph(tree core.ParseTree, lang core.Language, treeData *[]byte, filePath string) (*CallGraph, error) { + astRootNode := tree.Tree().RootNode() + + imports, err := lang.Resolvers().ResolveImports(tree) + if err != nil { + return nil, fmt.Errorf("failed to resolve imports: %w", err) + } + + // Required to map identifiers to imported modules as assignments + importedIdentifierNamespaces := parseImportedIdentifierNamespaces(imports) + + fmt.Println() + fmt.Println("Imported identifier => namespace:") + for identifier, namespace := range importedIdentifierNamespaces { + fmt.Printf(" %s => %s\n", identifier, namespace) + } + fmt.Println() + + callGraph := NewCallGraph(filePath, importedIdentifierNamespaces) + + // Add root node to the call graph + callGraph.Nodes[filePath] = newGraphNode(filePath) + + traverseTree(astRootNode, treeData, callGraph, filePath, filePath, "", false) + fmt.Println() + + return callGraph, nil +} + +func traverseTree(node *sitter.Node, treeData *[]byte, callGraph *CallGraph, filePath string, currentNamespace string, classNamespace string, insideClass bool) { + if node == nil { + return + } + fmt.Println("Traverse ", node.Type(), "with content:", node.Content(*treeData), "with namespace:", currentNamespace) + + switch node.Type() { + case "function_definition": + nameNode := node.ChildByFieldName("name") + if nameNode != nil { + funcName := nameNode.Content(*treeData) + if insideClass { + currentNamespace = classNamespace + namespaceSeparator + funcName + } else { + currentNamespace = currentNamespace + namespaceSeparator + funcName + } + + if _, exists := callGraph.Nodes[currentNamespace]; !exists { + callGraph.Nodes[currentNamespace] = newGraphNode(currentNamespace) + //@TODO - Class Constructor edge must be language agnostic, (__init__ ) must be obtained from lang. For java, it would be classname itself, etc + if insideClass && funcName == "__init__" { + callGraph.AddEdge(classNamespace, currentNamespace) + } + } + } + case "class_definition": + // Handle class definitions + className := node.ChildByFieldName("name").Content(*treeData) + classNamespace = currentNamespace + namespaceSeparator + className + insideClass = true + case "assignment": + leftNode := node.ChildByFieldName("left") + rightNode := node.ChildByFieldName("right") + if leftNode != nil && rightNode != nil { + leftVar := leftNode.Content(*treeData) + rightTargets := resolveTargets(rightNode, *treeData, currentNamespace, callGraph) + for _, rightTarget := range rightTargets { + callGraph.assignments.AddAssignment(leftVar, rightTarget) + } + } + case "attribute": + // processing a xyz.attr for xyz.attr() call + if node.Parent().Type() == "call" { + baseNode := node.ChildByFieldName("object") + attributeNode := node.ChildByFieldName("attribute") + if baseNode != nil && attributeNode != nil { + // Resolve base object using the assignment graph at different scopes + fmt.Printf("Try resolving target call %s.%s at %s\n", baseNode.Content(*treeData), attributeNode.Content(*treeData), currentNamespace) + baseTargets := resolveTargets(baseNode, *treeData, currentNamespace, callGraph) + fmt.Printf("Processing fn call as %s.%s() on base targets : %v \n", baseNode.Content(*treeData), attributeNode.Content(*treeData), baseTargets) + + for _, baseTarget := range baseTargets { + fmt.Printf("Processing fn call as %s.%s() on base as %s \n", baseNode.Content(*treeData), attributeNode.Content(*treeData), baseTarget) + attributeName := attributeNode.Content(*treeData) + targetNamespace := baseTarget + namespaceSeparator + attributeName + _, existed := callGraph.Nodes[targetNamespace] + fmt.Printf("Attr %s resolved to %s, exists: %t\n", attributeName, targetNamespace, existed) + + // Check if resolved target exists in the call graph + if _, exists := callGraph.Nodes[targetNamespace]; exists && targetNamespace != "" { + fmt.Println("Add attr edge from", currentNamespace, "to", targetNamespace) + callGraph.AddEdge(currentNamespace, targetNamespace) + } else if _, exists := callGraph.importedIdentifierNamespaces[baseTarget]; exists { + fmt.Println("Add attr edge from", currentNamespace, "to module namespace", baseTarget+namespaceSeparator+attributeName) + callGraph.AddEdge(currentNamespace, baseTarget+namespaceSeparator+attributeName) + } + } + } + } + case "call": + fmt.Println("Traverse 'call' with content:", node.Content(*treeData), "with namespace:", currentNamespace, "insideClass:", insideClass, "node type:", node.Type()) + + // Extract call target + targetNode := node.ChildByFieldName("function") + if targetNode != nil { + callTarget := targetNode.Content(*treeData) + + // Search for the call target node at different scopes in the graph + // eg. namespace - nestNestedFn.py//nestParent//nestChild, callTarget - outerfn1 + // try searching for outerfn1 in graph with all scope levels + // eg. search nestNestedFn.py//nestParent//nestChild//outerfn1 + // then nestNestedFn.py//nestParent//outerfn1 then nestNestedFn.py//outerfn1 and so on + // if not found, then use currentNamespace to build it + // like, nestNestedFn.py//nestParent//nestChild//outerfn1 + + targetNamespaces := []string{} + for i := strings.Count(currentNamespace, namespaceSeparator) + 1; i >= 0; i-- { + searchNamespace := strings.Join(strings.Split(currentNamespace, namespaceSeparator)[:i], namespaceSeparator) + namespaceSeparator + callTarget + if i == 0 { + searchNamespace = callTarget + } + fmt.Printf("searching %s in scoped - %s\n", callTarget, searchNamespace) + // check in graph + if _, exists := callGraph.Nodes[searchNamespace]; exists { + targetNamespaces = append(targetNamespaces, searchNamespace) + break + } + } + if len(targetNamespaces) == 0 { + // If namespace not found in available scopes in the graph, try to resolve it from imported namespaces + if assignedNamespaces := callGraph.assignments.Resolve(callTarget); len(assignedNamespaces) > 0 { + fmt.Println("Resolve imported target for", callTarget, ":", assignedNamespaces) + targetNamespaces = assignedNamespaces + } + // else { + // // @TODO - rethink this + // if insideClass { + // targetNamespaces = []string{classNamespace + namespaceSeparator + callTarget} + // } else { + // targetNamespaces = []string{currentNamespace + namespaceSeparator + callTarget} + // } + // } + } + + // Add edge for function call + for _, targetNamespace := range targetNamespaces { + fmt.Println("Adding edge from", currentNamespace, "to", targetNamespace) + callGraph.AddEdge(currentNamespace, targetNamespace) + } + } + } + + // Recursively analyze children + for i := 0; i < int(node.ChildCount()); i++ { + traverseTree(node.Child(i), treeData, callGraph, filePath, currentNamespace, classNamespace, insideClass) + } +} + +func resolveTargets( + node *sitter.Node, + treeData []byte, + currentNamespace string, + callGraph *CallGraph, +) []string { + if node == nil { + return []string{} + } + fmt.Printf("Resolve targets for for %s type on %s with namespace %s\n", node.Type(), node.Content(treeData), currentNamespace) + + // Handle variable names directly + if node.Type() == "identifier" { + identifier := node.Content(treeData) + // Check if the identifier maps to something in the assignment graph + resolvedTargets := callGraph.assignments.Resolve(identifier) + if len(resolvedTargets) > 0 { + return resolvedTargets + } + // Fallback: return the identifier in the current namespace + return []string{currentNamespace + namespaceSeparator + identifier} + } + + // Handle calls, e.g., ClassA() -> resolve to ClassA namespace + if node.Type() == "call" { + functionNode := node.ChildByFieldName("function") + if functionNode != nil { + functionName := functionNode.Content(treeData) + // Check if the function is a class in the current or parrent scopes in callgraph graph + for i := strings.Count(currentNamespace, namespaceSeparator); i >= 0; i-- { + searchNamespace := strings.Join(strings.Split(currentNamespace, namespaceSeparator)[:i], namespaceSeparator) + namespaceSeparator + functionName + if i == 0 { + searchNamespace = functionName + } + if _, exists := callGraph.Nodes[searchNamespace]; exists { + return []string{searchNamespace} + } + } + return []string{currentNamespace + namespaceSeparator + functionName} + } + } + + // Handle member access, e.g., obj.attr + if node.Type() == "attribute" { + baseNode := node.ChildByFieldName("object") + attributeNode := node.ChildByFieldName("attribute") + if baseNode != nil && attributeNode != nil { + var baseTarget []string = resolveTargets(baseNode, treeData, currentNamespace, callGraph) + attributeName := attributeNode.Content(treeData) + var resolvedTargets []string + for _, base := range baseTarget { + resolvedTargets = append(resolvedTargets, base+namespaceSeparator+attributeName) + } + return resolvedTargets + } + } + + // Handle other expressions as fallbacks (e.g., literals, complex expressions) + return []string{node.Content(treeData)} +} + +// Fetches namespaces for imported identifiers +// eg. import pprint is parsed as: +// pprint -> pprint +// eg. from os import listdir as listdirfn, chmod is parsed as: +// listdirfn -> os//listdir +// chmod -> os//chmod +func parseImportedIdentifierNamespaces(imports []*ast.ImportNode) map[string]string { + importedIdentifierNamespaces := make(map[string]string) + for _, imp := range imports { + if imp.IsWildcardImport() { + continue + } + itemNamespace := imp.ModuleItem() + if itemNamespace == "" { + itemNamespace = imp.ModuleName() + } else { + itemNamespace = imp.ModuleName() + namespaceSeparator + itemNamespace + } + identifierKey := helpers.GetFirstNonEmptyString(imp.ModuleAlias(), imp.ModuleItem(), imp.ModuleName()) + importedIdentifierNamespaces[identifierKey] = itemNamespace + } + return importedIdentifierNamespaces +} diff --git a/plugin/callgraph/config.go b/plugin/callgraph/config.go new file mode 100644 index 0000000..c893f33 --- /dev/null +++ b/plugin/callgraph/config.go @@ -0,0 +1,12 @@ +package callgraph + +// TS nodes Ignored when parsing AST +// eg. comment is useless, imports are already resolved +var ignoredTypesList = []string{"comment"} +var ignoredTypes = make(map[string]bool) + +func init() { + for _, ignoredType := range ignoredTypesList { + ignoredTypes[ignoredType] = true + } +} diff --git a/plugin/callgraph/fixtures/.gitkeep b/plugin/callgraph/fixtures/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/plugin/callgraph/fixtures/test.py b/plugin/callgraph/fixtures/test.py new file mode 100644 index 0000000..371b657 --- /dev/null +++ b/plugin/callgraph/fixtures/test.py @@ -0,0 +1,69 @@ +import base64 +from utils import printinit, printenc, printdec, printf2 + +# Node must be generated, but shouldn't be part of DFS +class EncodingUnused: + def __init__(self): + printinit("Initialized unused") + pass + + def applyUnused(self, msg, func): + return func(msg) + +class Encoding: + def __init__(self): + printinit("Initialized") + pass + + def apply(self, msg, func): + return func(msg) + + # Unused + def apply2(self, msg, func): + return func(msg) + +encoder = Encoding() +encoded = encoder.apply("Hello, World!".encode('utf-8'), base64.b64encode) +printenc(encoded) +decoded = encoder.apply(encoded, base64.b64decode) +printdec(decoded) + + +def f1(value): + f2(value) + +def f2(value): + printf2(value) + if value == 0: + return + f1(value-1) + pass + +def multiply(a, b): + return a * b + +f1(multiply(2, 3)) + +def foo(): + print("foo") + pass + +def bar(): + print("bar") + pass + +def baz(): + print("baz") + pass +def useless(): + print("useless") + baz() + pass + +xyz = foo + +print("GG") + +xyz = bar + +xyz() \ No newline at end of file diff --git a/plugin/callgraph/graph.go b/plugin/callgraph/graph.go new file mode 100644 index 0000000..dd1000f --- /dev/null +++ b/plugin/callgraph/graph.go @@ -0,0 +1,80 @@ +package callgraph + +import ( + "fmt" + "strings" +) + +const namespaceSeparator = "//" + +// graphNode represents a single node in the call graph +type graphNode struct { + Namespace string + CallsTo []string +} + +func newGraphNode(namespace string) *graphNode { + return &graphNode{ + Namespace: namespace, + CallsTo: []string{}, + } +} + +type CallGraph struct { + FileName string + Nodes map[string]*graphNode + assignments AssignmentGraph + importedIdentifierNamespaces map[string]string +} + +func NewCallGraph(fileName string, importedIdentifierNamespaces map[string]string) *CallGraph { + cg := &CallGraph{FileName: fileName, Nodes: make(map[string]*graphNode), assignments: *NewAssignmentGraph(), importedIdentifierNamespaces: importedIdentifierNamespaces} + for identifier, namespace := range importedIdentifierNamespaces { + cg.assignments.AddAssignment(identifier, namespace) + } + return cg +} + +// AddEdge adds an edge from one function to another +func (cg *CallGraph) AddEdge(caller, callee string) { + if _, exists := cg.Nodes[caller]; !exists { + cg.Nodes[caller] = newGraphNode(caller) + } + if _, exists := cg.Nodes[callee]; !exists { + cg.Nodes[callee] = newGraphNode(callee) + } + cg.Nodes[caller].CallsTo = append(cg.Nodes[caller].CallsTo, callee) +} + +func (cg *CallGraph) PrintCallGraph() { + fmt.Println("Call Graph:") + for caller, node := range cg.Nodes { + fmt.Printf(" %s (calls)=> %v\n", caller, node.CallsTo) + } + fmt.Println() +} + +func (cg *CallGraph) DFS() []string { + visited := make(map[string]bool) + var dfsResult []string + cg.dfsUtil(cg.FileName, visited, &dfsResult, 0) + return dfsResult +} + +func (cg *CallGraph) dfsUtil(startNode string, visited map[string]bool, result *[]string, depth int) { + fmt.Println("DFS Util:", startNode) + if visited[startNode] { + // append that not going inside this on prev level + *result = append(*result, fmt.Sprintf("%s Stopped at %s", strings.Repeat("|", depth), startNode)) + return + } + + // Mark the current node as visited and add it to the result + visited[startNode] = true + *result = append(*result, fmt.Sprintf("%s %s", strings.Repeat(">", depth), startNode)) + + // Recursively visit all the nodes called by the current node + for _, callee := range cg.Nodes[startNode].CallsTo { + cg.dfsUtil(callee, visited, result, depth+1) + } +} diff --git a/plugin/callgraph/utils.go b/plugin/callgraph/utils.go new file mode 100644 index 0000000..ebe0871 --- /dev/null +++ b/plugin/callgraph/utils.go @@ -0,0 +1,12 @@ +package callgraph + +import "strings" + +// @TODO - Refactor this for a language agnostic approach + +// GetBaseModuleName returns the base module name from the given module name. +// eg. for "os.path", the base module name is "os" +func GetBaseModuleName(moduleName string) string { + parts := strings.Split(moduleName, ".") + return parts[0] +}