diff --git a/internal/localruntime/socket.go b/internal/localruntime/socket.go index 7bf99cf3..bb04a391 100644 --- a/internal/localruntime/socket.go +++ b/internal/localruntime/socket.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "path/filepath" + "syscall" ) func DefaultSocketPath() string { @@ -14,5 +15,29 @@ func DefaultSocketPath() string { } func EnsureSocketDir(socketPath string) error { - return os.MkdirAll(filepath.Dir(socketPath), 0o700) + dir := filepath.Dir(socketPath) + if err := os.MkdirAll(dir, 0o700); err != nil { + return err + } + + info, err := os.Lstat(dir) + if err != nil { + return err + } + if info.Mode()&os.ModeSymlink != 0 { + return fmt.Errorf("socket directory %q must not be a symlink", dir) + } + if !info.IsDir() { + return fmt.Errorf("socket directory %q is not a directory", dir) + } + + stat, ok := info.Sys().(*syscall.Stat_t) + if !ok { + return fmt.Errorf("socket directory %q owner could not be verified", dir) + } + if int(stat.Uid) != os.Getuid() { + return fmt.Errorf("socket directory %q must be owned by uid %d", dir, os.Getuid()) + } + + return os.Chmod(dir, 0o700) } diff --git a/internal/localruntime/socket_test.go b/internal/localruntime/socket_test.go new file mode 100644 index 00000000..ce306139 --- /dev/null +++ b/internal/localruntime/socket_test.go @@ -0,0 +1,84 @@ +package localruntime + +import ( + "os" + "path/filepath" + "strings" + "syscall" + "testing" +) + +func TestEnsureSocketDirCreatesPrivateDirectory(t *testing.T) { + t.Parallel() + + root := t.TempDir() + socketPath := filepath.Join(root, "guard", "kontext.sock") + if err := EnsureSocketDir(socketPath); err != nil { + t.Fatalf("EnsureSocketDir() error = %v", err) + } + + info, err := os.Stat(filepath.Dir(socketPath)) + if err != nil { + t.Fatalf("Stat() error = %v", err) + } + if got := info.Mode().Perm(); got != 0o700 { + t.Fatalf("socket dir mode = %o, want 700", got) + } + + stat, ok := info.Sys().(*syscall.Stat_t) + if !ok { + t.Fatal("socket dir stat missing uid") + } + if got := int(stat.Uid); got != os.Getuid() { + t.Fatalf("socket dir owner uid = %d, want %d", got, os.Getuid()) + } +} + +func TestEnsureSocketDirTightensExistingDirectoryPermissions(t *testing.T) { + t.Parallel() + + root := t.TempDir() + dir := filepath.Join(root, "guard") + if err := os.MkdirAll(dir, 0o777); err != nil { + t.Fatalf("MkdirAll() error = %v", err) + } + if err := os.Chmod(dir, 0o777); err != nil { + t.Fatalf("Chmod() setup error = %v", err) + } + + socketPath := filepath.Join(dir, "kontext.sock") + if err := EnsureSocketDir(socketPath); err != nil { + t.Fatalf("EnsureSocketDir() error = %v", err) + } + + info, err := os.Stat(dir) + if err != nil { + t.Fatalf("Stat() error = %v", err) + } + if got := info.Mode().Perm(); got != 0o700 { + t.Fatalf("socket dir mode = %o, want 700", got) + } +} + +func TestEnsureSocketDirRejectsSymlinkDirectory(t *testing.T) { + t.Parallel() + + root := t.TempDir() + target := filepath.Join(root, "target") + if err := os.MkdirAll(target, 0o700); err != nil { + t.Fatalf("MkdirAll() error = %v", err) + } + + link := filepath.Join(root, "guard") + if err := os.Symlink(target, link); err != nil { + t.Fatalf("Symlink() error = %v", err) + } + + err := EnsureSocketDir(filepath.Join(link, "kontext.sock")) + if err == nil { + t.Fatal("EnsureSocketDir() error = nil, want symlink rejection") + } + if !strings.Contains(err.Error(), "must not be a symlink") { + t.Fatalf("EnsureSocketDir() error = %v, want symlink rejection", err) + } +}