Skip to content

Commit

Permalink
fixed small bug in CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
Chrisyhjiang committed Jun 27, 2024
1 parent f76e458 commit 1b24817
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 6 deletions.
25 changes: 19 additions & 6 deletions src/pkg/cli/new.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ func InitFromSamples(ctx context.Context, dir string, names []string) error {
defer tarball.Close()
tarReader := tar.NewReader(tarball)
term.Info("Writing files to disk...")

found := make(map[string]bool)
for _, name := range names {
found[name] = false
}

for {
h, err := tarReader.Next()
if err != nil {
Expand All @@ -73,13 +79,14 @@ func InitFromSamples(ctx context.Context, dir string, names []string) error {
}

for _, name := range names {
// Create a subdirectory for each sample when there is more than one sample requested
subdir := ""
if len(names) > 1 {
subdir = name
}
prefix := fmt.Sprintf("%s-%s/samples/%s/", repo, branch, name)
if base, ok := strings.CutPrefix(h.Name, prefix); ok && len(base) > 0 {
found[name] = true
// Create a subdirectory for each sample when there is more than one sample requested
subdir := ""
if len(names) > 1 {
subdir = name
}
fmt.Println(" -", base)
path := filepath.Join(dir, subdir, base)
if h.FileInfo().IsDir() {
Expand All @@ -94,9 +101,15 @@ func InitFromSamples(ctx context.Context, dir string, names []string) error {
}
}
}

for _, name := range names {
if !found[name] {
return fmt.Errorf("sample not found")
}
}

return nil
}

func createFile(base string, h *tar.Header, tarReader *tar.Reader) error {
// Like os.Create, but with the same mode as the original file (so scripts are executable, etc.)
file, err := os.OpenFile(base, os.O_RDWR|os.O_CREATE|os.O_EXCL, h.FileInfo().Mode())
Expand Down
16 changes: 16 additions & 0 deletions src/pkg/cli/new_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package cli

import (
"context"
"testing"
)

func TestInitFromSamples(t *testing.T) {
err := InitFromSamples(context.Background(), t.TempDir(), []string{"nonexisting"})
if err == nil {
t.Fatal("Expected test to fail")
}
if err.Error() != "sample not found" {
t.Error("Expected 'sample not found' error")
}
}

0 comments on commit 1b24817

Please sign in to comment.