Skip to content

Commit

Permalink
perf: 优化创建 connecttoken 逻辑,获取手动账号的用户名 (#1311)
Browse files Browse the repository at this point in the history
* perf: 优化创建 connecttoken 逻辑,获取手动账号的用户名

* perf: 简化代码

---------

Co-authored-by: Eric <[email protected]>
  • Loading branch information
fit2bot and LeeEirc authored Mar 12, 2024
1 parent fe89922 commit eb5bbd1
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 43 deletions.
28 changes: 28 additions & 0 deletions pkg/handler/asset.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package handler
import (
"errors"
"fmt"
"io"
"strconv"
"strings"

Expand All @@ -13,6 +14,7 @@ import (
"github.com/jumpserver/koko/pkg/proxy"
"github.com/jumpserver/koko/pkg/srvconn"
"github.com/jumpserver/koko/pkg/utils"
"golang.org/x/term"
)

func (u *UserSelectHandler) retrieveRemoteAsset(reqParam model.PaginationParam) []model.PermAsset {
Expand Down Expand Up @@ -117,6 +119,23 @@ func (u *UserSelectHandler) displayAssets(searchHeader string) {
u.displayResult(searchHeader, labels, fields, fieldsSize, data)
}

func GetInputUsername(sess io.ReadWriteCloser) (username string, err error) {
vt := term.NewTerminal(sess, "username: ")
count := 0
for count < 3 {
username, err = vt.ReadLine()
if err != nil {
return "", err
}
username = strings.TrimSpace(username)
if username != "" {
return username, nil
}
count++
}
return "", errors.New("input username exceed max retry")
}

func (u *UserSelectHandler) proxyAsset(asset model.PermAsset) {
u.selectedAsset = &asset
permAssetDetail, err := u.h.jmsService.GetUserPermAssetDetailById(u.user.ID, asset.ID)
Expand Down Expand Up @@ -176,6 +195,15 @@ func (u *UserSelectHandler) proxyAsset(asset model.PermAsset) {
ConnectMethod: "ssh",
RemoteAddr: u.h.sess.RemoteAddr(),
}
if selectedAccount.IsInputUser() {
inputUsername, err1 := GetInputUsername(u.h.sess)
if err1 != nil {
logger.Errorf("Get input username err: %s", err1)
return
}
req.InputUsername = inputUsername
}

tokenInfo, err := u.h.jmsService.CreateSuperConnectToken(&req)
if err != nil {
if tokenInfo.Code == "" {
Expand Down
9 changes: 9 additions & 0 deletions pkg/handler/direct_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,15 @@ func (d *DirectHandler) Proxy(asset model.PermAsset) {
ConnectMethod: model.ProtocolSSH,
RemoteAddr: d.wrapperSess.RemoteAddr(),
}
if selectAccount.IsInputUser() {
inputUsername, err1 := GetInputUsername(d.wrapperSess)
if err1 != nil {
logger.Errorf("Get input username err: %s", err1)
return
}
req.InputUsername = inputUsername
}

tokenInfo, err := d.jmsService.CreateSuperConnectToken(&req)
if err != nil {
if tokenInfo.Code == "" {
Expand Down
4 changes: 4 additions & 0 deletions pkg/jms-sdk-go/model/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ func (a *PermAccount) IsAnonymous() bool {
return a.Username == ANONUser
}

func (a *PermAccount) IsInputUser() bool {
return a.Username == InputUser
}

const (
InputUser = "@INPUT"
DynamicUser = "@USER"
Expand Down
46 changes: 3 additions & 43 deletions pkg/proxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,36 +307,11 @@ func (s *Server) GenerateCommandItem(user, input, output string, item *ExecutedC
}
}

func (s *Server) getUsernameIfNeed() (err error) {
if s.account.Username == "" {
logger.Infof("Conn[%s] need manuel input system user username", s.UserConn.ID())
var username string
vt := term.NewTerminal(s.UserConn, "username: ")
for {
username, err = vt.ReadLine()
if err != nil {
return err
}
username = strings.TrimSpace(username)
if username != "" {
break
}
}
s.account.Username = username
logger.Infof("Conn[%s] get username from user input: %s", s.UserConn.ID(), username)
}
return
}

func (s *Server) getAuthPasswordIfNeed() (err error) {
var line string
if s.account.Secret == "" {
vt := term.NewTerminal(s.UserConn, "password: ")
if s.account.Username != "" {
line, err = vt.ReadPassword(fmt.Sprintf("%s's password: ", s.account.Username))
} else {
line, err = vt.ReadPassword("password: ")
}
line, err = vt.ReadPassword(fmt.Sprintf("%s's password: ", s.account.String()))

if err != nil {
logger.Errorf("Conn[%s] get password from user err: %s", s.UserConn.ID(), err.Error())
Expand Down Expand Up @@ -364,29 +339,14 @@ func (s *Server) checkRequiredAuth() error {
srvconn.ProtocolMongoDB,

srvconn.ProtocolMySQL, srvconn.ProtocolMariadb,
srvconn.ProtocolSQLServer, srvconn.ProtocolPostgresql:
if err := s.getUsernameIfNeed(); err != nil {
msg := utils.WrapperWarn(lang.T("Get auth username failed"))
utils.IgnoreErrWriteString(s.UserConn, msg)
return fmt.Errorf("get auth username failed: %s", err)
}
if err := s.getAuthPasswordIfNeed(); err != nil {
msg := utils.WrapperWarn(lang.T("Get auth password failed"))
utils.IgnoreErrWriteString(s.UserConn, msg)
return fmt.Errorf("get auth password failed: %s", err)
}
case srvconn.ProtocolRedis:
srvconn.ProtocolSQLServer, srvconn.ProtocolPostgresql,
srvconn.ProtocolRedis:
if err := s.getAuthPasswordIfNeed(); err != nil {
msg := utils.WrapperWarn(lang.T("Get auth password failed"))
utils.IgnoreErrWriteString(s.UserConn, msg)
return fmt.Errorf("get auth password failed: %s", err)
}
case srvconn.ProtocolSSH:
if err := s.getUsernameIfNeed(); err != nil {
msg := utils.WrapperWarn(lang.T("Get auth username failed"))
utils.IgnoreErrWriteString(s.UserConn, msg)
return err
}
if s.checkReuseSSHClient() {
if cacheConn, ok := s.getCacheSSHConn(); ok {
s.cacheSSHConnection = cacheConn
Expand Down

0 comments on commit eb5bbd1

Please sign in to comment.