diff --git a/parser/helpers.go b/parser/helpers.go index 248c836..aa5000b 100644 --- a/parser/helpers.go +++ b/parser/helpers.go @@ -9,6 +9,13 @@ import ( "strings" ) +type PackageFiles struct { + PackagePath string + PackageName string + + AstFiles []*ast.File +} + func filterFile(filepath string) bool { if !strings.HasSuffix(filepath, goFileSuffix) || strings.HasSuffix(filepath, GenerateFileSuffix) || strings.HasSuffix(filepath, testFileSuffix) { @@ -17,15 +24,6 @@ func filterFile(filepath string) bool { return true } -func GetDependencies(entryPoint string) ([]string, error) { - dir := path.Dir(entryPoint) - allGoFiles, err := getDependenciesFilenames(dir) - if err != nil { - return nil, err - } - return allGoFiles, nil -} - func getDependenciesFilenames(dir string) ([]string, error) { goFiles := []string{} pkgs, err := loadPackage(dir) @@ -41,28 +39,41 @@ func getDependenciesFilenames(dir string) ([]string, error) { return funk.UniqString(goFiles), nil } -func GetDependenciesAstFiles(filename string) ([]*ast.File, error) { +func GetDependenciesAstFiles(filename string) ([]PackageFiles, error) { pkgs, err := loadPackageWithSyntax(path.Dir(filename)) if err != nil { return nil, err } - astFiles := []*ast.File{} + pfs := []PackageFiles{} done := map[string]bool{} for _, pkg := range pkgs { if _, ok := done[pkg.PkgPath]; ok { continue } - astFiles = append(astFiles, pkg.Syntax...) + + pfs = append(pfs, PackageFiles{ + PackagePath: pkg.PkgPath, + PackageName: pkg.Name, + AstFiles: pkg.Syntax, + }) + done[pkg.PkgPath] = true + for _, childPack := range pkg.Imports { if _, ok := done[childPack.PkgPath]; ok { continue } - astFiles = append(astFiles, childPack.Syntax...) + + pfs = append(pfs, PackageFiles{ + PackagePath: childPack.PkgPath, + PackageName: childPack.Name, + AstFiles: childPack.Syntax, + }) + done[childPack.PkgPath] = true } } - return astFiles, nil + return pfs, nil } func goFilesFromPackage(pkg *packages.Package) []string { @@ -71,15 +82,15 @@ func goFilesFromPackage(pkg *packages.Package) []string { return funk.FilterString(files, filterFile) } -func EntryPointPackageName(filename string) (string, error) { +func EntryPointPackageName(filename string) (string, string, error) { pkgs, err := loadPackage(path.Dir(filename)) if err != nil { - return "", err + return "", "", err } for _, pack := range pkgs { - return pack.Name, nil + return pack.Name, pack.PkgPath, nil } - return "", fmt.Errorf("package not found for entry point") + return "", "", fmt.Errorf("package not found for entry point") } func loadPackage(path string) ([]*packages.Package, error) { diff --git a/parser/parser.go b/parser/parser.go index 782a8d9..0163b20 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -28,6 +28,7 @@ type PackageInfo struct { EntryPoint string Dir string PackageName string + PackagePath string Services []*Service @@ -123,7 +124,7 @@ func NewPackageInfo(filename string) (*PackageInfo, error) { return nil, err } - packageName, err := EntryPointPackageName(filename) + packageName, packagePath, err := EntryPointPackageName(filename) if err != nil { return nil, err } @@ -132,6 +133,7 @@ func NewPackageInfo(filename string) (*PackageInfo, error) { EntryPoint: filename, Dir: dir, PackageName: packageName, + PackagePath: packagePath, Services: []*Service{}, Scopes: make(map[string][]*ast.Scope), @@ -145,24 +147,30 @@ func NewPackageInfo(filename string) (*PackageInfo, error) { // ParseFiles parse all files associated with package from original file func (pi *PackageInfo) Parse(filename string) error { - astFiles, err := GetDependenciesAstFiles(filename) + pfs, err := GetDependenciesAstFiles(filename) if err != nil { return err } - for _, astFile := range astFiles { - // collect scopes - pi.collectScopes(astFile) - // get structs for zenrpc - pi.collectServices(astFile) - // get imports - pi.collectImports(astFile) + for _, pkg := range pfs { + for _, astFile := range pkg.AstFiles { + if pkg.PackagePath == pi.PackagePath { + // get structs for zenrpc only for root package + pi.collectServices(astFile) + } + // collect scopes + pi.collectScopes(astFile) + // get imports + pi.collectImports(astFile) + } } // second loop: parse methods. It runs in separate loop because we need all services to be collected for this parsing - for _, f := range astFiles { - if err := pi.parseMethods(f); err != nil { - return err + for _, pkg := range pfs { + for _, f := range pkg.AstFiles { + if err := pi.parseMethods(f); err != nil { + return err + } } }