Skip to content

Commit

Permalink
Fix incorrect services load (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
dizzyfool committed Mar 23, 2021
1 parent ad40214 commit ee9c359
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 30 deletions.
47 changes: 29 additions & 18 deletions parser/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@ import (
"strings"
)

type PackageFiles struct {
PackagePath string
PackageName string

AstFiles []*ast.File
}

func filterFile(filepath string) bool {
if !strings.HasSuffix(filepath, goFileSuffix) ||
strings.HasSuffix(filepath, GenerateFileSuffix) || strings.HasSuffix(filepath, testFileSuffix) {
Expand All @@ -17,15 +24,6 @@ func filterFile(filepath string) bool {
return true
}

func GetDependencies(entryPoint string) ([]string, error) {
dir := path.Dir(entryPoint)
allGoFiles, err := getDependenciesFilenames(dir)
if err != nil {
return nil, err
}
return allGoFiles, nil
}

func getDependenciesFilenames(dir string) ([]string, error) {
goFiles := []string{}
pkgs, err := loadPackage(dir)
Expand All @@ -41,28 +39,41 @@ func getDependenciesFilenames(dir string) ([]string, error) {
return funk.UniqString(goFiles), nil
}

func GetDependenciesAstFiles(filename string) ([]*ast.File, error) {
func GetDependenciesAstFiles(filename string) ([]PackageFiles, error) {
pkgs, err := loadPackageWithSyntax(path.Dir(filename))
if err != nil {
return nil, err
}
astFiles := []*ast.File{}
pfs := []PackageFiles{}
done := map[string]bool{}
for _, pkg := range pkgs {
if _, ok := done[pkg.PkgPath]; ok {
continue
}
astFiles = append(astFiles, pkg.Syntax...)

pfs = append(pfs, PackageFiles{
PackagePath: pkg.PkgPath,
PackageName: pkg.Name,
AstFiles: pkg.Syntax,
})

done[pkg.PkgPath] = true

for _, childPack := range pkg.Imports {
if _, ok := done[childPack.PkgPath]; ok {
continue
}
astFiles = append(astFiles, childPack.Syntax...)

pfs = append(pfs, PackageFiles{
PackagePath: childPack.PkgPath,
PackageName: childPack.Name,
AstFiles: childPack.Syntax,
})

done[childPack.PkgPath] = true
}
}
return astFiles, nil
return pfs, nil
}

func goFilesFromPackage(pkg *packages.Package) []string {
Expand All @@ -71,15 +82,15 @@ func goFilesFromPackage(pkg *packages.Package) []string {
return funk.FilterString(files, filterFile)
}

func EntryPointPackageName(filename string) (string, error) {
func EntryPointPackageName(filename string) (string, string, error) {
pkgs, err := loadPackage(path.Dir(filename))
if err != nil {
return "", err
return "", "", err
}
for _, pack := range pkgs {
return pack.Name, nil
return pack.Name, pack.PkgPath, nil
}
return "", fmt.Errorf("package not found for entry point")
return "", "", fmt.Errorf("package not found for entry point")
}

func loadPackage(path string) ([]*packages.Package, error) {
Expand Down
32 changes: 20 additions & 12 deletions parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type PackageInfo struct {
EntryPoint string
Dir string
PackageName string
PackagePath string

Services []*Service

Expand Down Expand Up @@ -123,7 +124,7 @@ func NewPackageInfo(filename string) (*PackageInfo, error) {
return nil, err
}

packageName, err := EntryPointPackageName(filename)
packageName, packagePath, err := EntryPointPackageName(filename)
if err != nil {
return nil, err
}
Expand All @@ -132,6 +133,7 @@ func NewPackageInfo(filename string) (*PackageInfo, error) {
EntryPoint: filename,
Dir: dir,
PackageName: packageName,
PackagePath: packagePath,
Services: []*Service{},

Scopes: make(map[string][]*ast.Scope),
Expand All @@ -145,24 +147,30 @@ func NewPackageInfo(filename string) (*PackageInfo, error) {

// ParseFiles parse all files associated with package from original file
func (pi *PackageInfo) Parse(filename string) error {
astFiles, err := GetDependenciesAstFiles(filename)
pfs, err := GetDependenciesAstFiles(filename)
if err != nil {
return err
}

for _, astFile := range astFiles {
// collect scopes
pi.collectScopes(astFile)
// get structs for zenrpc
pi.collectServices(astFile)
// get imports
pi.collectImports(astFile)
for _, pkg := range pfs {
for _, astFile := range pkg.AstFiles {
if pkg.PackagePath == pi.PackagePath {
// get structs for zenrpc only for root package
pi.collectServices(astFile)
}
// collect scopes
pi.collectScopes(astFile)
// get imports
pi.collectImports(astFile)
}
}

// second loop: parse methods. It runs in separate loop because we need all services to be collected for this parsing
for _, f := range astFiles {
if err := pi.parseMethods(f); err != nil {
return err
for _, pkg := range pfs {
for _, f := range pkg.AstFiles {
if err := pi.parseMethods(f); err != nil {
return err
}
}
}

Expand Down

0 comments on commit ee9c359

Please sign in to comment.