diff --git a/README.md b/README.md index d97c3e4..9b6c877 100644 --- a/README.md +++ b/README.md @@ -83,6 +83,7 @@ $ ln -s /usr/local/bin/shellsrv /usr/local/bin/flatpak $ flatpak --version Flatpak 1.12.7 ``` +**Note:** you will want to store the symlink in a location visible only to the container, to avoid an infinite loop. If you are using toolbox/distrobox, this means anywhere outside your home directory. I recommend `/usr/local/bin`. Example of file transfer to server: @@ -101,7 +102,7 @@ Example of file transfer from server: shellsrv cat /server/path/some_file.tar.zst > /client/path/some_file.tar.zst # directory with zstd compression: -shellsrv sh -c "tar -I 'zstd -T0 -1' -c /server/path/some_dir 2>/dev/null"|tar --zstd -xf - -C /client/path/some_dir +shellsrv tar -I 'zstd -T0 -1' -c /server/path/some_dir|tar --zstd -xf - -C /client/path/some_dir +# or dir to archive: +shellsrv tar -I 'zstd -T0 -1' -c /server/path/some_dir > /client/path/some_dir.tar.zst ``` - -**Note:** you will want to store the symlink in a location visible only to the container, to avoid an infinite loop. If you are using toolbox/distrobox, this means anywhere outside your home directory. I recommend `/usr/local/bin`. diff --git a/shellsrv.go b/shellsrv.go index 4e5ec9d..e479aed 100644 --- a/shellsrv.go +++ b/shellsrv.go @@ -14,6 +14,7 @@ import ( "path" "strconv" "strings" + "sync" "syscall" "time" "unsafe" @@ -237,6 +238,7 @@ func ssrv_env_vars_parse() { } func srv_handle(conn net.Conn, self_cpids_dir string) { + var wg sync.WaitGroup disconnect := func(session *yamux.Session, remote string) { session.Close() log.Printf("[%s] [ DISCONNECTED ]", remote) @@ -251,8 +253,6 @@ func srv_handle(conn net.Conn, self_cpids_dir string) { defer disconnect(session, remote) log.Printf("[%s] [ NEW CONNECTION ]", remote) - done := make(chan struct{}) - envs_channel, err := session.Accept() if err != nil { log.Printf("[%s] environment channel accept error: %v", remote, err) @@ -269,6 +269,7 @@ func srv_handle(conn net.Conn, self_cpids_dir string) { last_env_num := len(envs) - 1 var stdin_channel net.Conn + var stderr_channel net.Conn if envs[last_env_num] == "is_alloc_pty := false" { is_alloc_pty = false envs = envs[:last_env_num] @@ -278,10 +279,21 @@ func srv_handle(conn net.Conn, self_cpids_dir string) { log.Printf("[%s] stdin channel accept error: %v", remote, err) return } + stderr_channel, err = session.Accept() + if err != nil { + log.Printf("[%s] stderr channel accept error: %v", remote, err) + return + } } else { is_alloc_pty = true } + data_channel, err := session.Accept() + if err != nil { + log.Printf("[%s] data channel accept error: %v", remote, err) + return + } + command_channel, err := session.Accept() if err != nil { log.Printf("[%s] command channel accept error: %v", remote, err) @@ -322,7 +334,6 @@ func srv_handle(conn net.Conn, self_cpids_dir string) { } return } - defer cmd_ptmx.Close() cmd_pid := strconv.Itoa(exec_cmd.Process.Pid) log.Printf("[%s] pid: %s", remote, cmd_pid) @@ -340,6 +351,11 @@ func srv_handle(conn net.Conn, self_cpids_dir string) { } } + cp := func(dst io.Writer, src io.Reader) { + defer wg.Done() + io.Copy(dst, src) + } + if is_alloc_pty { control_channel, err := session.Accept() if err != nil { @@ -364,30 +380,16 @@ func srv_handle(conn net.Conn, self_cpids_dir string) { break } } - done <- struct{}{} }() - } - - data_channel, err := session.Accept() - if err != nil { - log.Printf("[%s] data channel accept error: %v", remote, err) - return - } - cp := func(dst io.Writer, src io.Reader) { - io.Copy(dst, src) - done <- struct{}{} - } - - if is_alloc_pty { + wg.Add(2) go cp(data_channel, cmd_ptmx) go cp(cmd_ptmx, data_channel) } else { + wg.Add(2) go cp(data_channel, cmd_stdout) - go cp(data_channel, cmd_stderr) + go cp(stderr_channel, cmd_stderr) } - <-done - state, err := exec_cmd.Process.Wait() if err != nil { log.Printf("[%s] error getting exit code: %v", remote, err) @@ -399,8 +401,14 @@ func srv_handle(conn net.Conn, self_cpids_dir string) { _, err = command_channel.Write([]byte(fmt.Sprint(exit_code + "\r\n"))) if err != nil { log.Printf("[%s] error sending exit code: %v", remote, err) + return } + if is_alloc_pty { + session.Close() + } + + wg.Wait() os.Remove(cpid) } @@ -467,6 +475,7 @@ func server(proto, socket string) { } func client(proto, socket string, exec_args []string) int { + var wg sync.WaitGroup if len(exec_args) != 0 { is_alloc_pty = !pty_blocklist[exec_args[0]] } @@ -487,40 +496,37 @@ func client(proto, socket string, exec_args []string) int { } defer session.Close() - is_stdin_piped := false - stdin_stat, err := os.Stdin.Stat() - if err != nil { - log.Fatalf("unable to stat stdin: %v", err) - } - if (stdin_stat.Mode() & os.ModeCharDevice) == 0 { - is_stdin_piped = true + stdin := int(os.Stdin.Fd()) + is_stdin_term := false + if term.IsTerminal(stdin) { + is_stdin_term = true } - stdout_stat, err := os.Stdout.Stat() - if err != nil { - log.Fatalf("unable to stat stdout: %v", err) + stdout := int(os.Stdout.Fd()) + is_stdout_term := false + if term.IsTerminal(stdout) { + is_stdout_term = true } - is_stdout_piped := false - if (stdout_stat.Mode() & os.ModeCharDevice) == 0 { - is_stdout_piped = true + + stderr := int(os.Stderr.Fd()) + is_stderr_term := false + if term.IsTerminal(stderr) { + is_stderr_term = true } - var old_state *term.State - stdin := int(os.Stdin.Fd()) - if term.IsTerminal(stdin) && !is_stdout_piped && !is_stdin_piped { - if is_alloc_pty { - old_state, err = term.MakeRaw(stdin) + var term_old_state *term.State + if (is_stdin_term && is_stderr_term && is_stdout_term) || (*is_pty && is_stdin_term) { + if is_alloc_pty && is_stdin_term { + term_old_state, err = term.MakeRaw(stdin) if err != nil { log.Fatalf("unable to make terminal raw: %v", err) } - defer term.Restore(stdin, old_state) + defer term.Restore(stdin, term_old_state) } } else { is_alloc_pty = false } - done := make(chan struct{}) - env_vars_pass := strings.Split(*env_vars, ",") var envs string for _, env := range env_vars_pass { @@ -542,11 +548,21 @@ func client(proto, socket string, exec_args []string) int { } var stdin_channel net.Conn + var stderr_channel net.Conn if !is_alloc_pty { stdin_channel, err = session.Open() if err != nil { log.Fatalf("stdin channel open error: %v", err) } + stderr_channel, err = session.Open() + if err != nil { + log.Fatalf("stderr channel open error: %v", err) + } + } + + data_channel, err := session.Open() + if err != nil { + log.Fatalf("data channel open error: %v", err) } command_channel, err := session.Open() @@ -559,6 +575,16 @@ func client(proto, socket string, exec_args []string) int { log.Fatalf("failed to send command: %v", err) } + pipe_stdin := func(dst io.Writer, src io.Reader) { + defer wg.Done() + io.Copy(dst, src) + stdin_channel.Close() + } + cp := func(dst io.Writer, src io.Reader) { + defer wg.Done() + io.Copy(dst, src) + } + if is_alloc_pty { control_channel, err := session.Open() if err != nil { @@ -582,40 +608,21 @@ func client(proto, socket string, exec_args []string) int { } <-sig } - done <- struct{}{} }() } - data_channel, err := session.Open() - if err != nil { - log.Fatalf("data channel open error: %v", err) - } - cp := func(dst io.Writer, src io.Reader) { - io.Copy(dst, src) - done <- struct{}{} - } - if is_stdin_piped { - go func() { - reader := bufio.NewReader(os.Stdin) - buffer := make([]byte, 1024) - for { - stdin_seek, err := reader.Read(buffer) - if err != nil && err != io.EOF { - log.Fatalf("unable to read stdin: %v", err) - } - if err == io.EOF { - break - } - _, err = stdin_channel.Write(buffer[:stdin_seek]) - if err != nil { - log.Fatalf("failed to send stdin data: %v", err) - } - } - stdin_channel.Close() - }() + if !is_stdin_term { + wg.Add(1) + go pipe_stdin(stdin_channel, os.Stdin) } else { + wg.Add(1) go cp(data_channel, os.Stdin) } + if !is_alloc_pty { + wg.Add(1) + go cp(os.Stderr, stderr_channel) + } + wg.Add(1) go cp(os.Stdout, data_channel) var exit_code = 1 @@ -632,7 +639,18 @@ func client(proto, socket string, exec_args []string) int { log.Printf("error reading from command channel: %v", err) } - <-done + if term_old_state != nil { + term.Restore(stdin, term_old_state) + wg.Done() + } + if is_stdin_term && ((!*is_pty && !*is_no_pty) || + (*is_no_pty && (!is_stdout_term || !is_stderr_term)) || *is_no_pty) { + if !is_stderr_term || !is_alloc_pty { + wg.Done() + } + } + + wg.Wait() return exit_code } diff --git a/tls/README.md b/tls/README.md index 2a5f6ea..70a5ec3 100644 --- a/tls/README.md +++ b/tls/README.md @@ -95,6 +95,7 @@ $ ln -s /usr/local/bin/shellsrv /usr/local/bin/flatpak $ flatpak --version Flatpak 1.12.7 ``` +**Note:** you will want to store the symlink in a location visible only to the container, to avoid an infinite loop. If you are using toolbox/distrobox, this means anywhere outside your home directory. I recommend `/usr/local/bin`. Example of file transfer to server: @@ -113,7 +114,7 @@ Example of file transfer from server: shellsrv cat /server/path/some_file.tar.zst > /client/path/some_file.tar.zst # directory with zstd compression: -shellsrv sh -c "tar -I 'zstd -T0 -1' -c /server/path/some_dir 2>/dev/null"|tar --zstd -xf - -C /client/path/some_dir +shellsrv tar -I 'zstd -T0 -1' -c /server/path/some_dir|tar --zstd -xf - -C /client/path/some_dir +# or dir to archive: +shellsrv tar -I 'zstd -T0 -1' -c /server/path/some_dir > /client/path/some_dir.tar.zst ``` - -**Note:** you will want to store the symlink in a location visible only to the container, to avoid an infinite loop. If you are using toolbox/distrobox, this means anywhere outside your home directory. I recommend `/usr/local/bin`. diff --git a/tls/shellsrv.go b/tls/shellsrv.go index 6e6d6e8..58f8f75 100644 --- a/tls/shellsrv.go +++ b/tls/shellsrv.go @@ -15,6 +15,7 @@ import ( "path" "strconv" "strings" + "sync" "syscall" "time" "unsafe" @@ -264,6 +265,7 @@ func ssrv_env_vars_parse() { } func srv_handle(conn net.Conn, self_cpids_dir string) { + var wg sync.WaitGroup disconnect := func(session *yamux.Session, remote string) { session.Close() log.Printf("[%s] [ DISCONNECTED ]", remote) @@ -278,8 +280,6 @@ func srv_handle(conn net.Conn, self_cpids_dir string) { defer disconnect(session, remote) log.Printf("[%s] [ NEW CONNECTION ]", remote) - done := make(chan struct{}) - envs_channel, err := session.Accept() if err != nil { log.Printf("[%s] environment channel accept error: %v", remote, err) @@ -296,6 +296,7 @@ func srv_handle(conn net.Conn, self_cpids_dir string) { last_env_num := len(envs) - 1 var stdin_channel net.Conn + var stderr_channel net.Conn if envs[last_env_num] == "is_alloc_pty := false" { is_alloc_pty = false envs = envs[:last_env_num] @@ -305,10 +306,21 @@ func srv_handle(conn net.Conn, self_cpids_dir string) { log.Printf("[%s] stdin channel accept error: %v", remote, err) return } + stderr_channel, err = session.Accept() + if err != nil { + log.Printf("[%s] stderr channel accept error: %v", remote, err) + return + } } else { is_alloc_pty = true } + data_channel, err := session.Accept() + if err != nil { + log.Printf("[%s] data channel accept error: %v", remote, err) + return + } + command_channel, err := session.Accept() if err != nil { log.Printf("[%s] command channel accept error: %v", remote, err) @@ -349,7 +361,6 @@ func srv_handle(conn net.Conn, self_cpids_dir string) { } return } - defer cmd_ptmx.Close() cmd_pid := strconv.Itoa(exec_cmd.Process.Pid) log.Printf("[%s] pid: %s", remote, cmd_pid) @@ -367,6 +378,11 @@ func srv_handle(conn net.Conn, self_cpids_dir string) { } } + cp := func(dst io.Writer, src io.Reader) { + defer wg.Done() + io.Copy(dst, src) + } + if is_alloc_pty { control_channel, err := session.Accept() if err != nil { @@ -391,30 +407,16 @@ func srv_handle(conn net.Conn, self_cpids_dir string) { break } } - done <- struct{}{} }() - } - - data_channel, err := session.Accept() - if err != nil { - log.Printf("[%s] data channel accept error: %v", remote, err) - return - } - cp := func(dst io.Writer, src io.Reader) { - io.Copy(dst, src) - done <- struct{}{} - } - - if is_alloc_pty { + wg.Add(2) go cp(data_channel, cmd_ptmx) go cp(cmd_ptmx, data_channel) } else { + wg.Add(2) go cp(data_channel, cmd_stdout) - go cp(data_channel, cmd_stderr) + go cp(stderr_channel, cmd_stderr) } - <-done - state, err := exec_cmd.Process.Wait() if err != nil { log.Printf("[%s] error getting exit code: %v", remote, err) @@ -426,8 +428,14 @@ func srv_handle(conn net.Conn, self_cpids_dir string) { _, err = command_channel.Write([]byte(fmt.Sprint(exit_code + "\r\n"))) if err != nil { log.Printf("[%s] error sending exit code: %v", remote, err) + return } + if is_alloc_pty { + session.Close() + } + + wg.Wait() os.Remove(cpid) } @@ -513,6 +521,7 @@ func server(proto, socket string) { } func client(proto, socket string, exec_args []string) int { + var wg sync.WaitGroup var err error if len(exec_args) != 0 { @@ -548,40 +557,37 @@ func client(proto, socket string, exec_args []string) int { } defer session.Close() - is_stdin_piped := false - stdin_stat, err := os.Stdin.Stat() - if err != nil { - log.Fatalf("unable to stat stdin: %v", err) - } - if (stdin_stat.Mode() & os.ModeCharDevice) == 0 { - is_stdin_piped = true + stdin := int(os.Stdin.Fd()) + is_stdin_term := false + if term.IsTerminal(stdin) { + is_stdin_term = true } - stdout_stat, err := os.Stdout.Stat() - if err != nil { - log.Fatalf("unable to stat stdout: %v", err) + stdout := int(os.Stdout.Fd()) + is_stdout_term := false + if term.IsTerminal(stdout) { + is_stdout_term = true } - is_stdout_piped := false - if (stdout_stat.Mode() & os.ModeCharDevice) == 0 { - is_stdout_piped = true + + stderr := int(os.Stderr.Fd()) + is_stderr_term := false + if term.IsTerminal(stderr) { + is_stderr_term = true } - var old_state *term.State - stdin := int(os.Stdin.Fd()) - if term.IsTerminal(stdin) && !is_stdout_piped && !is_stdin_piped { - if is_alloc_pty { - old_state, err = term.MakeRaw(stdin) + var term_old_state *term.State + if (is_stdin_term && is_stderr_term && is_stdout_term) || (*is_pty && is_stdin_term) { + if is_alloc_pty && is_stdin_term { + term_old_state, err = term.MakeRaw(stdin) if err != nil { log.Fatalf("unable to make terminal raw: %v", err) } - defer term.Restore(stdin, old_state) + defer term.Restore(stdin, term_old_state) } } else { is_alloc_pty = false } - done := make(chan struct{}) - env_vars_pass := strings.Split(*env_vars, ",") var envs string for _, env := range env_vars_pass { @@ -603,11 +609,21 @@ func client(proto, socket string, exec_args []string) int { } var stdin_channel net.Conn + var stderr_channel net.Conn if !is_alloc_pty { stdin_channel, err = session.Open() if err != nil { log.Fatalf("stdin channel open error: %v", err) } + stderr_channel, err = session.Open() + if err != nil { + log.Fatalf("stderr channel open error: %v", err) + } + } + + data_channel, err := session.Open() + if err != nil { + log.Fatalf("data channel open error: %v", err) } command_channel, err := session.Open() @@ -620,6 +636,16 @@ func client(proto, socket string, exec_args []string) int { log.Fatalf("failed to send command: %v", err) } + pipe_stdin := func(dst io.Writer, src io.Reader) { + defer wg.Done() + io.Copy(dst, src) + stdin_channel.Close() + } + cp := func(dst io.Writer, src io.Reader) { + defer wg.Done() + io.Copy(dst, src) + } + if is_alloc_pty { control_channel, err := session.Open() if err != nil { @@ -643,40 +669,21 @@ func client(proto, socket string, exec_args []string) int { } <-sig } - done <- struct{}{} }() } - data_channel, err := session.Open() - if err != nil { - log.Fatalf("data channel open error: %v", err) - } - cp := func(dst io.Writer, src io.Reader) { - io.Copy(dst, src) - done <- struct{}{} - } - if is_stdin_piped { - go func() { - reader := bufio.NewReader(os.Stdin) - buffer := make([]byte, 1024) - for { - stdin_seek, err := reader.Read(buffer) - if err != nil && err != io.EOF { - log.Fatalf("unable to read stdin: %v", err) - } - if err == io.EOF { - break - } - _, err = stdin_channel.Write(buffer[:stdin_seek]) - if err != nil { - log.Fatalf("failed to send stdin data: %v", err) - } - } - stdin_channel.Close() - }() + if !is_stdin_term { + wg.Add(1) + go pipe_stdin(stdin_channel, os.Stdin) } else { + wg.Add(1) go cp(data_channel, os.Stdin) } + if !is_alloc_pty { + wg.Add(1) + go cp(os.Stderr, stderr_channel) + } + wg.Add(1) go cp(os.Stdout, data_channel) var exit_code = 1 @@ -693,7 +700,18 @@ func client(proto, socket string, exec_args []string) int { log.Printf("error reading from command channel: %v", err) } - <-done + if term_old_state != nil { + term.Restore(stdin, term_old_state) + wg.Done() + } + if is_stdin_term && ((!*is_pty && !*is_no_pty) || + (*is_no_pty && (!is_stdout_term || !is_stderr_term)) || *is_no_pty) { + if !is_stderr_term || !is_alloc_pty { + wg.Done() + } + } + + wg.Wait() return exit_code }