Skip to content

Commit

Permalink
v0.1.6
Browse files Browse the repository at this point in the history
  • Loading branch information
VHSgunzo committed Apr 13, 2024
1 parent 06189fa commit bc23b40
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 146 deletions.
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ $ ln -s /usr/local/bin/shellsrv /usr/local/bin/flatpak
$ flatpak --version
Flatpak 1.12.7
```
**Note:** you will want to store the symlink in a location visible only to the container, to avoid an infinite loop. If you are using toolbox/distrobox, this means anywhere outside your home directory. I recommend `/usr/local/bin`.

Example of file transfer to server:

Expand All @@ -101,7 +102,7 @@ Example of file transfer from server:
shellsrv cat /server/path/some_file.tar.zst > /client/path/some_file.tar.zst
# directory with zstd compression:
shellsrv sh -c "tar -I 'zstd -T0 -1' -c /server/path/some_dir 2>/dev/null"|tar --zstd -xf - -C /client/path/some_dir
shellsrv tar -I 'zstd -T0 -1' -c /server/path/some_dir|tar --zstd -xf - -C /client/path/some_dir
# or dir to archive:
shellsrv tar -I 'zstd -T0 -1' -c /server/path/some_dir > /client/path/some_dir.tar.zst
```

**Note:** you will want to store the symlink in a location visible only to the container, to avoid an infinite loop. If you are using toolbox/distrobox, this means anywhere outside your home directory. I recommend `/usr/local/bin`.
158 changes: 88 additions & 70 deletions shellsrv.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"path"
"strconv"
"strings"
"sync"
"syscall"
"time"
"unsafe"
Expand Down Expand Up @@ -237,6 +238,7 @@ func ssrv_env_vars_parse() {
}

func srv_handle(conn net.Conn, self_cpids_dir string) {
var wg sync.WaitGroup
disconnect := func(session *yamux.Session, remote string) {
session.Close()
log.Printf("[%s] [ DISCONNECTED ]", remote)
Expand All @@ -251,8 +253,6 @@ func srv_handle(conn net.Conn, self_cpids_dir string) {
defer disconnect(session, remote)
log.Printf("[%s] [ NEW CONNECTION ]", remote)

done := make(chan struct{})

envs_channel, err := session.Accept()
if err != nil {
log.Printf("[%s] environment channel accept error: %v", remote, err)
Expand All @@ -269,6 +269,7 @@ func srv_handle(conn net.Conn, self_cpids_dir string) {
last_env_num := len(envs) - 1

var stdin_channel net.Conn
var stderr_channel net.Conn
if envs[last_env_num] == "is_alloc_pty := false" {
is_alloc_pty = false
envs = envs[:last_env_num]
Expand All @@ -278,10 +279,21 @@ func srv_handle(conn net.Conn, self_cpids_dir string) {
log.Printf("[%s] stdin channel accept error: %v", remote, err)
return
}
stderr_channel, err = session.Accept()
if err != nil {
log.Printf("[%s] stderr channel accept error: %v", remote, err)
return
}
} else {
is_alloc_pty = true
}

data_channel, err := session.Accept()
if err != nil {
log.Printf("[%s] data channel accept error: %v", remote, err)
return
}

command_channel, err := session.Accept()
if err != nil {
log.Printf("[%s] command channel accept error: %v", remote, err)
Expand Down Expand Up @@ -322,7 +334,6 @@ func srv_handle(conn net.Conn, self_cpids_dir string) {
}
return
}
defer cmd_ptmx.Close()

cmd_pid := strconv.Itoa(exec_cmd.Process.Pid)
log.Printf("[%s] pid: %s", remote, cmd_pid)
Expand All @@ -340,6 +351,11 @@ func srv_handle(conn net.Conn, self_cpids_dir string) {
}
}

cp := func(dst io.Writer, src io.Reader) {
defer wg.Done()
io.Copy(dst, src)
}

if is_alloc_pty {
control_channel, err := session.Accept()
if err != nil {
Expand All @@ -364,30 +380,16 @@ func srv_handle(conn net.Conn, self_cpids_dir string) {
break
}
}
done <- struct{}{}
}()
}

data_channel, err := session.Accept()
if err != nil {
log.Printf("[%s] data channel accept error: %v", remote, err)
return
}
cp := func(dst io.Writer, src io.Reader) {
io.Copy(dst, src)
done <- struct{}{}
}

if is_alloc_pty {
wg.Add(2)
go cp(data_channel, cmd_ptmx)
go cp(cmd_ptmx, data_channel)
} else {
wg.Add(2)
go cp(data_channel, cmd_stdout)
go cp(data_channel, cmd_stderr)
go cp(stderr_channel, cmd_stderr)
}

<-done

state, err := exec_cmd.Process.Wait()
if err != nil {
log.Printf("[%s] error getting exit code: %v", remote, err)
Expand All @@ -399,8 +401,14 @@ func srv_handle(conn net.Conn, self_cpids_dir string) {
_, err = command_channel.Write([]byte(fmt.Sprint(exit_code + "\r\n")))
if err != nil {
log.Printf("[%s] error sending exit code: %v", remote, err)
return
}

if is_alloc_pty {
session.Close()
}

wg.Wait()
os.Remove(cpid)
}

Expand Down Expand Up @@ -467,6 +475,7 @@ func server(proto, socket string) {
}

func client(proto, socket string, exec_args []string) int {
var wg sync.WaitGroup
if len(exec_args) != 0 {
is_alloc_pty = !pty_blocklist[exec_args[0]]
}
Expand All @@ -487,40 +496,37 @@ func client(proto, socket string, exec_args []string) int {
}
defer session.Close()

is_stdin_piped := false
stdin_stat, err := os.Stdin.Stat()
if err != nil {
log.Fatalf("unable to stat stdin: %v", err)
}
if (stdin_stat.Mode() & os.ModeCharDevice) == 0 {
is_stdin_piped = true
stdin := int(os.Stdin.Fd())
is_stdin_term := false
if term.IsTerminal(stdin) {
is_stdin_term = true
}

stdout_stat, err := os.Stdout.Stat()
if err != nil {
log.Fatalf("unable to stat stdout: %v", err)
stdout := int(os.Stdout.Fd())
is_stdout_term := false
if term.IsTerminal(stdout) {
is_stdout_term = true
}
is_stdout_piped := false
if (stdout_stat.Mode() & os.ModeCharDevice) == 0 {
is_stdout_piped = true

stderr := int(os.Stderr.Fd())
is_stderr_term := false
if term.IsTerminal(stderr) {
is_stderr_term = true
}

var old_state *term.State
stdin := int(os.Stdin.Fd())
if term.IsTerminal(stdin) && !is_stdout_piped && !is_stdin_piped {
if is_alloc_pty {
old_state, err = term.MakeRaw(stdin)
var term_old_state *term.State
if (is_stdin_term && is_stderr_term && is_stdout_term) || (*is_pty && is_stdin_term) {
if is_alloc_pty && is_stdin_term {
term_old_state, err = term.MakeRaw(stdin)
if err != nil {
log.Fatalf("unable to make terminal raw: %v", err)
}
defer term.Restore(stdin, old_state)
defer term.Restore(stdin, term_old_state)
}
} else {
is_alloc_pty = false
}

done := make(chan struct{})

env_vars_pass := strings.Split(*env_vars, ",")
var envs string
for _, env := range env_vars_pass {
Expand All @@ -542,11 +548,21 @@ func client(proto, socket string, exec_args []string) int {
}

var stdin_channel net.Conn
var stderr_channel net.Conn
if !is_alloc_pty {
stdin_channel, err = session.Open()
if err != nil {
log.Fatalf("stdin channel open error: %v", err)
}
stderr_channel, err = session.Open()
if err != nil {
log.Fatalf("stderr channel open error: %v", err)
}
}

data_channel, err := session.Open()
if err != nil {
log.Fatalf("data channel open error: %v", err)
}

command_channel, err := session.Open()
Expand All @@ -559,6 +575,16 @@ func client(proto, socket string, exec_args []string) int {
log.Fatalf("failed to send command: %v", err)
}

pipe_stdin := func(dst io.Writer, src io.Reader) {
defer wg.Done()
io.Copy(dst, src)
stdin_channel.Close()
}
cp := func(dst io.Writer, src io.Reader) {
defer wg.Done()
io.Copy(dst, src)
}

if is_alloc_pty {
control_channel, err := session.Open()
if err != nil {
Expand All @@ -582,40 +608,21 @@ func client(proto, socket string, exec_args []string) int {
}
<-sig
}
done <- struct{}{}
}()
}

data_channel, err := session.Open()
if err != nil {
log.Fatalf("data channel open error: %v", err)
}
cp := func(dst io.Writer, src io.Reader) {
io.Copy(dst, src)
done <- struct{}{}
}
if is_stdin_piped {
go func() {
reader := bufio.NewReader(os.Stdin)
buffer := make([]byte, 1024)
for {
stdin_seek, err := reader.Read(buffer)
if err != nil && err != io.EOF {
log.Fatalf("unable to read stdin: %v", err)
}
if err == io.EOF {
break
}
_, err = stdin_channel.Write(buffer[:stdin_seek])
if err != nil {
log.Fatalf("failed to send stdin data: %v", err)
}
}
stdin_channel.Close()
}()
if !is_stdin_term {
wg.Add(1)
go pipe_stdin(stdin_channel, os.Stdin)
} else {
wg.Add(1)
go cp(data_channel, os.Stdin)
}
if !is_alloc_pty {
wg.Add(1)
go cp(os.Stderr, stderr_channel)
}
wg.Add(1)
go cp(os.Stdout, data_channel)

var exit_code = 1
Expand All @@ -632,7 +639,18 @@ func client(proto, socket string, exec_args []string) int {
log.Printf("error reading from command channel: %v", err)
}

<-done
if term_old_state != nil {
term.Restore(stdin, term_old_state)
wg.Done()
}
if is_stdin_term && ((!*is_pty && !*is_no_pty) ||
(*is_no_pty && (!is_stdout_term || !is_stderr_term)) || *is_no_pty) {
if !is_stderr_term || !is_alloc_pty {
wg.Done()
}
}

wg.Wait()
return exit_code
}

Expand Down
7 changes: 4 additions & 3 deletions tls/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ $ ln -s /usr/local/bin/shellsrv /usr/local/bin/flatpak
$ flatpak --version
Flatpak 1.12.7
```
**Note:** you will want to store the symlink in a location visible only to the container, to avoid an infinite loop. If you are using toolbox/distrobox, this means anywhere outside your home directory. I recommend `/usr/local/bin`.

Example of file transfer to server:

Expand All @@ -113,7 +114,7 @@ Example of file transfer from server:
shellsrv cat /server/path/some_file.tar.zst > /client/path/some_file.tar.zst
# directory with zstd compression:
shellsrv sh -c "tar -I 'zstd -T0 -1' -c /server/path/some_dir 2>/dev/null"|tar --zstd -xf - -C /client/path/some_dir
shellsrv tar -I 'zstd -T0 -1' -c /server/path/some_dir|tar --zstd -xf - -C /client/path/some_dir
# or dir to archive:
shellsrv tar -I 'zstd -T0 -1' -c /server/path/some_dir > /client/path/some_dir.tar.zst
```

**Note:** you will want to store the symlink in a location visible only to the container, to avoid an infinite loop. If you are using toolbox/distrobox, this means anywhere outside your home directory. I recommend `/usr/local/bin`.
Loading

0 comments on commit bc23b40

Please sign in to comment.