From 1cef0b813aa481fe2180dcd7622cff91635f6de0 Mon Sep 17 00:00:00 2001 From: Eric Date: Tue, 12 Mar 2024 17:46:17 +0800 Subject: [PATCH 1/2] =?UTF-8?q?perf:=20=E4=BC=98=E5=8C=96=E5=88=9B?= =?UTF-8?q?=E5=BB=BA=20connecttoken=20=E9=80=BB=E8=BE=91=EF=BC=8C=E8=8E=B7?= =?UTF-8?q?=E5=8F=96=E6=89=8B=E5=8A=A8=E8=B4=A6=E5=8F=B7=E7=9A=84=E7=94=A8?= =?UTF-8?q?=E6=88=B7=E5=90=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/handler/asset.go | 28 +++++++++++++++++++++++++ pkg/handler/direct_handler.go | 9 ++++++++ pkg/jms-sdk-go/model/account.go | 4 ++++ pkg/proxy/server.go | 37 +-------------------------------- 4 files changed, 42 insertions(+), 36 deletions(-) diff --git a/pkg/handler/asset.go b/pkg/handler/asset.go index 70f1d154..c400ca48 100644 --- a/pkg/handler/asset.go +++ b/pkg/handler/asset.go @@ -3,6 +3,7 @@ package handler import ( "errors" "fmt" + "io" "strconv" "strings" @@ -13,6 +14,7 @@ import ( "github.com/jumpserver/koko/pkg/proxy" "github.com/jumpserver/koko/pkg/srvconn" "github.com/jumpserver/koko/pkg/utils" + "golang.org/x/term" ) func (u *UserSelectHandler) retrieveRemoteAsset(reqParam model.PaginationParam) []model.PermAsset { @@ -117,6 +119,23 @@ func (u *UserSelectHandler) displayAssets(searchHeader string) { u.displayResult(searchHeader, labels, fields, fieldsSize, data) } +func GetInputUsername(sess io.ReadWriteCloser) (username string, err error) { + vt := term.NewTerminal(sess, "username: ") + count := 0 + for count < 3 { + username, err = vt.ReadLine() + if err != nil { + return "", err + } + username = strings.TrimSpace(username) + if username != "" { + return username, nil + } + count++ + } + return "", errors.New("input username exceed max retry") +} + func (u *UserSelectHandler) proxyAsset(asset model.PermAsset) { u.selectedAsset = &asset permAssetDetail, err := u.h.jmsService.GetUserPermAssetDetailById(u.user.ID, asset.ID) @@ -176,6 +195,15 @@ func (u *UserSelectHandler) proxyAsset(asset model.PermAsset) { ConnectMethod: "ssh", RemoteAddr: u.h.sess.RemoteAddr(), } + if selectedAccount.IsInputUser() { + inputUsername, err1 := GetInputUsername(u.h.sess) + if err1 != nil { + logger.Errorf("Get input username err: %s", err1) + return + } + req.InputUsername = inputUsername + } + tokenInfo, err := u.h.jmsService.CreateSuperConnectToken(&req) if err != nil { if tokenInfo.Code == "" { diff --git a/pkg/handler/direct_handler.go b/pkg/handler/direct_handler.go index 8f508b5b..3d66b64f 100644 --- a/pkg/handler/direct_handler.go +++ b/pkg/handler/direct_handler.go @@ -392,6 +392,15 @@ func (d *DirectHandler) Proxy(asset model.PermAsset) { ConnectMethod: model.ProtocolSSH, RemoteAddr: d.wrapperSess.RemoteAddr(), } + if selectAccount.IsInputUser() { + inputUsername, err1 := GetInputUsername(d.wrapperSess) + if err1 != nil { + logger.Errorf("Get input username err: %s", err1) + return + } + req.InputUsername = inputUsername + } + tokenInfo, err := d.jmsService.CreateSuperConnectToken(&req) if err != nil { if tokenInfo.Code == "" { diff --git a/pkg/jms-sdk-go/model/account.go b/pkg/jms-sdk-go/model/account.go index 513afcb3..0e9e893e 100644 --- a/pkg/jms-sdk-go/model/account.go +++ b/pkg/jms-sdk-go/model/account.go @@ -78,6 +78,10 @@ func (a *PermAccount) IsAnonymous() bool { return a.Username == ANONUser } +func (a *PermAccount) IsInputUser() bool { + return a.Username == InputUser +} + const ( InputUser = "@INPUT" DynamicUser = "@USER" diff --git a/pkg/proxy/server.go b/pkg/proxy/server.go index 8c2a2ece..7423f387 100644 --- a/pkg/proxy/server.go +++ b/pkg/proxy/server.go @@ -307,36 +307,11 @@ func (s *Server) GenerateCommandItem(user, input, output string, item *ExecutedC } } -func (s *Server) getUsernameIfNeed() (err error) { - if s.account.Username == "" { - logger.Infof("Conn[%s] need manuel input system user username", s.UserConn.ID()) - var username string - vt := term.NewTerminal(s.UserConn, "username: ") - for { - username, err = vt.ReadLine() - if err != nil { - return err - } - username = strings.TrimSpace(username) - if username != "" { - break - } - } - s.account.Username = username - logger.Infof("Conn[%s] get username from user input: %s", s.UserConn.ID(), username) - } - return -} - func (s *Server) getAuthPasswordIfNeed() (err error) { var line string if s.account.Secret == "" { vt := term.NewTerminal(s.UserConn, "password: ") - if s.account.Username != "" { - line, err = vt.ReadPassword(fmt.Sprintf("%s's password: ", s.account.Username)) - } else { - line, err = vt.ReadPassword("password: ") - } + line, err = vt.ReadPassword(fmt.Sprintf("%s's password: ", s.account.String())) if err != nil { logger.Errorf("Conn[%s] get password from user err: %s", s.UserConn.ID(), err.Error()) @@ -365,11 +340,6 @@ func (s *Server) checkRequiredAuth() error { srvconn.ProtocolMySQL, srvconn.ProtocolMariadb, srvconn.ProtocolSQLServer, srvconn.ProtocolPostgresql: - if err := s.getUsernameIfNeed(); err != nil { - msg := utils.WrapperWarn(lang.T("Get auth username failed")) - utils.IgnoreErrWriteString(s.UserConn, msg) - return fmt.Errorf("get auth username failed: %s", err) - } if err := s.getAuthPasswordIfNeed(); err != nil { msg := utils.WrapperWarn(lang.T("Get auth password failed")) utils.IgnoreErrWriteString(s.UserConn, msg) @@ -382,11 +352,6 @@ func (s *Server) checkRequiredAuth() error { return fmt.Errorf("get auth password failed: %s", err) } case srvconn.ProtocolSSH: - if err := s.getUsernameIfNeed(); err != nil { - msg := utils.WrapperWarn(lang.T("Get auth username failed")) - utils.IgnoreErrWriteString(s.UserConn, msg) - return err - } if s.checkReuseSSHClient() { if cacheConn, ok := s.getCacheSSHConn(); ok { s.cacheSSHConnection = cacheConn From 7c0d7252efc35371df6a805d2a8557ece4071fd1 Mon Sep 17 00:00:00 2001 From: Eric Date: Tue, 12 Mar 2024 17:53:52 +0800 Subject: [PATCH 2/2] =?UTF-8?q?perf:=20=E7=AE=80=E5=8C=96=E4=BB=A3?= =?UTF-8?q?=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/proxy/server.go | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/pkg/proxy/server.go b/pkg/proxy/server.go index 7423f387..25078d14 100644 --- a/pkg/proxy/server.go +++ b/pkg/proxy/server.go @@ -339,13 +339,8 @@ func (s *Server) checkRequiredAuth() error { srvconn.ProtocolMongoDB, srvconn.ProtocolMySQL, srvconn.ProtocolMariadb, - srvconn.ProtocolSQLServer, srvconn.ProtocolPostgresql: - if err := s.getAuthPasswordIfNeed(); err != nil { - msg := utils.WrapperWarn(lang.T("Get auth password failed")) - utils.IgnoreErrWriteString(s.UserConn, msg) - return fmt.Errorf("get auth password failed: %s", err) - } - case srvconn.ProtocolRedis: + srvconn.ProtocolSQLServer, srvconn.ProtocolPostgresql, + srvconn.ProtocolRedis: if err := s.getAuthPasswordIfNeed(); err != nil { msg := utils.WrapperWarn(lang.T("Get auth password failed")) utils.IgnoreErrWriteString(s.UserConn, msg)