diff --git a/sockets/unix_socket.go b/sockets/unix_socket.go index e7591e6e..d8f28ee5 100644 --- a/sockets/unix_socket.go +++ b/sockets/unix_socket.go @@ -86,8 +86,18 @@ func WithChmod(mask os.FileMode) SockOption { // this should only be for a short duration, it may affect other processes that // create files/directories during that period. func NewUnixSocketWithOpts(path string, opts ...SockOption) (net.Listener, error) { + // Using syscall.Unlink(), not os.Remove() to prevent deleting the socket if it's in use if err := syscall.Unlink(path); err != nil && !os.IsNotExist(err) { - return nil, err + if err != syscall.EISDIR { + // On Linux, attempting to remove a directory returns syscall.EISDIR, + // in which case we try to remove the directory. MacOS does not return + // this error, so we'll return immediately, see: + // https://github.com/golang/go/blob/6b420169d798c7ebe733487b56ea5c3fa4aab5ce/src/os/file_unix.go#L300-L311 + return nil, err + } + if err := syscall.Rmdir(path); err != nil { + return nil, err + } } // net.Listen does not allow for permissions to be set. As a result, when diff --git a/sockets/unix_socket_test.go b/sockets/unix_socket_test.go index 8957efd3..8d589559 100644 --- a/sockets/unix_socket_test.go +++ b/sockets/unix_socket_test.go @@ -4,8 +4,11 @@ package sockets import ( "fmt" + "io/ioutil" "net" "os" + "path" + "runtime" "syscall" "testing" ) @@ -75,3 +78,48 @@ func TestUnixSocketWithOpts(t *testing.T) { } runTest(t, path, l, echoStr) } + +func TestUnixSocketConflictDirectory(t *testing.T) { + tmpDir, err := ioutil.TempDir("", t.Name()) + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpDir) + + t.Run("conflicting directory", func(t *testing.T) { + if runtime.GOOS == "darwin" { + t.Skip("not supported on macOS") + } + path := path.Join(tmpDir, "test.sock") + + // Create a conflicting directory at the socket location + err = os.MkdirAll(path, 0700) + if err != nil { + t.Fatal(err) + } + + l, err := NewUnixSocketWithOpts(path) + if err != nil { + t.Fatal(err) + } + defer l.Close() + runTest(t, path, l, "hello") + }) + + t.Run("conflicting file", func(t *testing.T) { + // Create a conflicting file at the socket location + path := path.Join(tmpDir, "test2.sock") + f, err := os.Create(path) + if err != nil { + t.Fatal(err) + } + f.Close() + + l, err := NewUnixSocketWithOpts(path) + if err != nil { + t.Fatal(err) + } + defer l.Close() + runTest(t, path, l, "hello") + }) +}