Skip to content

Commit 254d735

Browse files
committed
Add an option handling UTF-8 decode errors in certain SSH packets
The SSH RFCs define that strings in some SSH packets must be encoded as UTF-8, and AsyncSSH enforces this by raising a ProtocolError when UTF-8 decoding fails. This commit adds a new config option to specify an error handling strategy for the following cases:
* Disconnect reason
* Debug message
* Userauth banner
* Channel open failure reason
* Channel exit signal reason
* SFTP error reason
The default is 'strict', which preserves the existing behavior of raising a ProtocolError, but selecting other options allows the invalid bytes to be removed or replaced, avoiding the exception. Thanks go to GitHub user Le-Syl21 for suggesting this!
1 parent c7bc3aa commit 254d735

File tree

6 files changed

+133
-33
lines changed

6 files changed

+133
-33
lines changed

asyncssh/channel.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,10 +1125,12 @@ class SSHClientChannel(SSHChannel, Generic[AnyStr]):
11251125
_read_datatypes = {EXTENDED_DATA_STDERR}
11261126

11271127
def __init__(self, conn: 'SSHClientConnection',
1128-
loop: asyncio.AbstractEventLoop, encoding: Optional[str],
1129-
errors: str, window: int, max_pktsize: int):
1128+
loop: asyncio.AbstractEventLoop, utf8_decode_errors: str,
1129+
encoding: Optional[str], errors: str, window: int,
1130+
max_pktsize: int):
11301131
super().__init__(conn, loop, encoding, errors, window, max_pktsize)
11311132

1133+
self._utf8_decode_errors = utf8_decode_errors
11321134
self._exit_status: Optional[int] = None
11331135
self._exit_signal: Optional[_ExitSignal] = None
11341136

@@ -1299,7 +1301,7 @@ def _process_exit_signal_request(self, packet: SSHPacket) -> bool:
12991301

13001302
try:
13011303
signal = signal_bytes.decode('ascii')
1302-
msg = msg_bytes.decode('utf-8')
1304+
msg = msg_bytes.decode('utf-8', self._utf8_decode_errors)
13031305
lang = lang_bytes.decode('ascii')
13041306
except UnicodeDecodeError:
13051307
raise ProtocolError('Invalid exit signal request') from None

asyncssh/connection.py

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -877,6 +877,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop,
877877
self._peer_addr = ''
878878
self._peer_port = 0
879879
self._tcp_keepalive = options.tcp_keepalive
880+
self._utf8_decode_errors = options.utf8_decode_errors
880881
self._owner: Optional[Union[SSHClient, SSHServer]] = None
881882
self._extra: Dict[str, object] = {}
882883

@@ -1042,6 +1043,11 @@ def logger(self) -> SSHLogger:
10421043

10431044
return self._logger
10441045

1046+
def _decode_utf8(self, msg_bytes) -> str:
1047+
"""Decode UTF-8 bytes, honoring utf8_decode_errors setting"""
1048+
1049+
return msg_bytes.decode('utf-8', self._utf8_decode_errors)
1050+
10451051
def _cleanup(self, exc: Optional[Exception]) -> None:
10461052
"""Clean up this connection"""
10471053

@@ -2193,7 +2199,7 @@ def _process_disconnect(self, _pkttype: int, _pktid: int,
21932199
packet.check_end()
21942200

21952201
try:
2196-
reason = reason_bytes.decode('utf-8')
2202+
reason = self._decode_utf8(reason_bytes)
21972203
lang = lang_bytes.decode('ascii')
21982204
except UnicodeDecodeError:
21992205
raise ProtocolError('Invalid disconnect message') from None
@@ -2236,7 +2242,7 @@ def _process_debug(self, _pkttype: int, _pktid: int,
22362242
packet.check_end()
22372243

22382244
try:
2239-
msg = msg_bytes.decode('utf-8')
2245+
msg = self._decode_utf8(msg_bytes)
22402246
lang = lang_bytes.decode('ascii')
22412247
except UnicodeDecodeError:
22422248
raise ProtocolError('Invalid debug message') from None
@@ -2638,7 +2644,7 @@ def _process_userauth_banner(self, _pkttype: int, _pktid: int,
26382644
packet.check_end()
26392645

26402646
try:
2641-
msg = msg_bytes.decode('utf-8')
2647+
msg = self._decode_utf8(msg_bytes)
26422648
lang = lang_bytes.decode('ascii')
26432649
except UnicodeDecodeError:
26442650
raise ProtocolError('Invalid userauth banner') from None
@@ -2755,7 +2761,7 @@ def _process_channel_open_failure(self, _pkttype: int, _pktid: int,
27552761
packet.check_end()
27562762

27572763
try:
2758-
reason = reason_bytes.decode('utf-8')
2764+
reason = self._decode_utf8(reason_bytes)
27592765
lang = lang_bytes.decode('ascii')
27602766
except UnicodeDecodeError:
27612767
raise ProtocolError('Invalid channel open failure') from None
@@ -4373,8 +4379,8 @@ async def create_session(self, session_factory: SSHClientSessionFactory,
43734379
window: int
43744380
max_pktsize: int
43754381

4376-
chan = SSHClientChannel(self, self._loop, encoding, errors,
4377-
window, max_pktsize)
4382+
chan = SSHClientChannel(self, self._loop, self._utf8_decode_errors,
4383+
encoding, errors, window, max_pktsize)
43784384

43794385
session = await chan.create(session_factory, command, subsystem,
43804386
new_env, request_pty, term_type, term_size,
@@ -5745,9 +5751,9 @@ async def start_sftp_client(self, env: DefTuple[Optional[Env]] = (),
57455751
env=env, send_env=send_env,
57465752
encoding=None)
57475753

5748-
return await start_sftp_client(self, self._loop, reader, writer,
5749-
path_encoding, path_errors,
5750-
sftp_version)
5754+
return await start_sftp_client(self, self._loop,
5755+
self._utf8_decode_errors, reader, writer,
5756+
path_encoding, path_errors, sftp_version)
57515757

57525758

57535759
class SSHServerConnection(SSHConnection):
@@ -7278,6 +7284,7 @@ class SSHConnectionOptions(Options, Generic[_Options]):
72787284
family: int
72797285
local_addr: HostPort
72807286
tcp_keepalive: bool
7287+
utf8_decode_errors: str
72817288
canonicalize_hostname: Union[bool, str]
72827289
canonical_domains: Sequence[str]
72837290
canonicalize_fallback_local: bool
@@ -7323,6 +7330,7 @@ def prepare(self, config: SSHConfig, # type: ignore
73237330
passphrase: Optional[BytesOrStr],
73247331
proxy_command: DefTuple[_ProxyCommand], family: DefTuple[int],
73257332
local_addr: DefTuple[HostPort], tcp_keepalive: DefTuple[bool],
7333+
utf8_decode_errors: str,
73267334
canonicalize_hostname: DefTuple[Union[bool, str]],
73277335
canonical_domains: DefTuple[Sequence[str]],
73287336
canonicalize_fallback_local: DefTuple[bool],
@@ -7387,6 +7395,8 @@ def _split_cname_patterns(
73877395
self.tcp_keepalive = cast(bool, tcp_keepalive if tcp_keepalive != ()
73887396
else config.get('TCPKeepAlive', True))
73897397

7398+
self.utf8_decode_errors = utf8_decode_errors
7399+
73907400
self.canonicalize_hostname = \
73917401
cast(Union[bool, str], canonicalize_hostname
73927402
if canonicalize_hostname != ()
@@ -7812,6 +7822,13 @@ class SSHClientConnectionOptions(SSHConnectionOptions):
78127822
:param tcp_keepalive: (optional)
78137823
Whether or not to enable keepalive probes at the TCP level to
78147824
detect broken connections, defaulting to `True`.
7825+
:param utf8_decode_errors: (optional)
7826+
Error handling strategy to apply when UTF-8 decode errors
7827+
occur in SSH protocol messages, defaulting to 'strict'
7828+
which shuts down the connection with a ProtocolError.
7829+
Choosing other strategies can allow the message parsing
7830+
to proceed with invalid bytes in the message being removed
7831+
or replaced.
78157832
:param canonicalize_hostname: (optional)
78167833
Whether or not to enable hostname canonicalization, defaulting
78177834
to `False`, in which case hostnames are passed as-is to the
@@ -7984,6 +8001,7 @@ class SSHClientConnectionOptions(SSHConnectionOptions):
79848001
:type keepalive_interval: *see* :ref:`SpecifyingTimeIntervals`
79858002
:type keepalive_count_max: `int`
79868003
:type tcp_keepalive: `bool`
8004+
:type utf8_decode_errors: `str`
79878005
:type canonicalize_hostname: `bool` or `'always'`
79888006
:type canonical_domains: `list` of `str`
79898007
:type canonicalize_fallback_local: `bool`
@@ -8069,6 +8087,7 @@ def prepare(self, # type: ignore
80698087
family: DefTuple[int] = (),
80708088
local_addr: DefTuple[HostPort] = (),
80718089
tcp_keepalive: DefTuple[bool] = (),
8090+
utf8_decode_errors: str = 'strict',
80728091
canonicalize_hostname: DefTuple[Union[bool, str]] = (),
80738092
canonical_domains: DefTuple[Sequence[str]] = (),
80748093
canonicalize_fallback_local: DefTuple[bool] = (),
@@ -8180,10 +8199,11 @@ def prepare(self, # type: ignore
81808199

81818200
super().prepare(config, client_factory or SSHClient, client_version,
81828201
host, port, tunnel, passphrase, proxy_command, family,
8183-
local_addr, tcp_keepalive, canonicalize_hostname,
8184-
canonical_domains, canonicalize_fallback_local,
8185-
canonicalize_max_dots, canonicalize_permitted_cnames,
8186-
kex_algs, encryption_algs, mac_algs, compression_algs,
8202+
local_addr, tcp_keepalive, utf8_decode_errors,
8203+
canonicalize_hostname, canonical_domains,
8204+
canonicalize_fallback_local, canonicalize_max_dots,
8205+
canonicalize_permitted_cnames, kex_algs,
8206+
encryption_algs, mac_algs, compression_algs,
81878207
signature_algs, host_based_auth, public_key_auth,
81888208
kbdint_auth, password_auth, x509_trusted_certs,
81898209
x509_trusted_cert_paths, x509_purposes, rekey_bytes,
@@ -8636,6 +8656,13 @@ class SSHServerConnectionOptions(SSHConnectionOptions):
86368656
:param tcp_keepalive: (optional)
86378657
Whether or not to enable keepalive probes at the TCP level to
86388658
detect broken connections, defaulting to `True`.
8659+
:param utf8_decode_errors: (optional)
8660+
Error handling strategy to apply when UTF-8 decode errors
8661+
occur in SSH protocol messages, defaulting to 'strict'
8662+
which shuts down the connection with a ProtocolError.
8663+
Choosing other strategies can allow the message parsing
8664+
to proceed with invalid bytes in the message being removed
8665+
or replaced.
86398666
:param canonicalize_hostname: (optional)
86408667
Whether or not to enable hostname canonicalization, defaulting
86418668
to `False`, in which case hostnames are passed as-is to the
@@ -8732,6 +8759,7 @@ class SSHServerConnectionOptions(SSHConnectionOptions):
87328759
:type keepalive_interval: *see* :ref:`SpecifyingTimeIntervals`
87338760
:type keepalive_count_max: `int`
87348761
:type tcp_keepalive: `bool`
8762+
:type utf8_decode_errors: `str`
87358763
:type canonicalize_hostname: `bool` or `'always'`
87368764
:type canonical_domains: `list` of `str`
87378765
:type canonicalize_fallback_local: `bool`
@@ -8790,6 +8818,7 @@ def prepare(self, # type: ignore
87908818
family: DefTuple[int] = (),
87918819
local_addr: DefTuple[HostPort] = (),
87928820
tcp_keepalive: DefTuple[bool] = (),
8821+
utf8_decode_errors: str = 'strict',
87938822
canonicalize_hostname: DefTuple[Union[bool, str]] = (),
87948823
canonical_domains: DefTuple[Sequence[str]] = (),
87958824
canonicalize_fallback_local: DefTuple[bool] = (),
@@ -8865,10 +8894,11 @@ def prepare(self, # type: ignore
88658894

88668895
super().prepare(config, server_factory or SSHServer, server_version,
88678896
host, port, tunnel, passphrase, proxy_command, family,
8868-
local_addr, tcp_keepalive, canonicalize_hostname,
8869-
canonical_domains, canonicalize_fallback_local,
8870-
canonicalize_max_dots, canonicalize_permitted_cnames,
8871-
kex_algs, encryption_algs, mac_algs, compression_algs,
8897+
local_addr, tcp_keepalive, utf8_decode_errors,
8898+
canonicalize_hostname, canonical_domains,
8899+
canonicalize_fallback_local, canonicalize_max_dots,
8900+
canonicalize_permitted_cnames, kex_algs,
8901+
encryption_algs, mac_algs, compression_algs,
88728902
signature_algs, host_based_auth, public_key_auth,
88738903
kbdint_auth, password_auth, x509_trusted_certs,
88748904
x509_trusted_cert_paths, x509_purposes,

asyncssh/sftp.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -959,14 +959,16 @@ class SFTPError(Error):
959959
"""
960960

961961
@staticmethod
962-
def construct(packet: SSHPacket) -> Optional['SFTPError']:
962+
def construct(packet: SSHPacket, utf8_decode_errors: str) -> \
963+
Optional['SFTPError']:
963964
"""Construct an SFTPError from an FXP_STATUS response"""
964965

965966
code = packet.get_uint32()
966967

967968
if packet:
968969
try:
969-
reason = packet.get_string().decode('utf-8')
970+
reason = packet.get_string().decode('utf-8',
971+
utf8_decode_errors)
970972
lang = packet.get_string().decode('ascii')
971973
except UnicodeDecodeError:
972974
raise SFTPBadMessage('Invalid status message') from None
@@ -2596,11 +2598,12 @@ class SFTPClientHandler(SFTPHandler):
25962598
"""An SFTP client session handler"""
25972599

25982600
def __init__(self, loop: asyncio.AbstractEventLoop,
2599-
reader: 'SSHReader[bytes]', writer: 'SSHWriter[bytes]',
2600-
sftp_version: int):
2601+
utf8_decode_errors: str, reader: 'SSHReader[bytes]',
2602+
writer: 'SSHWriter[bytes]', sftp_version: int):
26012603
super().__init__(reader, writer)
26022604

26032605
self._loop = loop
2606+
self._utf8_decode_errors = utf8_decode_errors
26042607
self._version = sftp_version
26052608
self._next_pktid = 0
26062609
self._requests: Dict[int, _RequestWaiter] = {}
@@ -2694,7 +2697,7 @@ async def _make_request(self, pkttype: Union[int, bytes],
26942697
def _process_status(self, packet: SSHPacket) -> None:
26952698
"""Process an incoming SFTP status response"""
26962699

2697-
exc = SFTPError.construct(packet)
2700+
exc = SFTPError.construct(packet, self._utf8_decode_errors)
26982701

26992702
if self._version < 6:
27002703
packet.check_end()
@@ -8209,13 +8212,15 @@ async def open(self, path: bytes, mode: str) -> SFTPServerFile:
82098212

82108213
async def start_sftp_client(conn: 'SSHClientConnection',
82118214
loop: asyncio.AbstractEventLoop,
8215+
utf8_decode_errors: str,
82128216
reader: 'SSHReader[bytes]',
82138217
writer: 'SSHWriter[bytes]',
82148218
path_encoding: Optional[str],
82158219
path_errors: str, sftp_version: int) -> SFTPClient:
82168220
"""Start an SFTP client"""
82178221

8218-
handler = SFTPClientHandler(loop, reader, writer, sftp_version)
8222+
handler = SFTPClientHandler(loop, utf8_decode_errors,
8223+
reader, writer, sftp_version)
82198224

82208225
handler.logger.info('Starting SFTP client')
82218226

tests/test_channel.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,9 @@ def _send_request(self, request, *args, want_reply=False):
145145
if args[0] == String('invalid'):
146146
args = (String(b'\xff'),) + args[1:]
147147

148+
if args[2] == String('invalid'):
149+
args = args[:2] + (String(b'\xff'),) + args[3:]
150+
148151
if args[3] == String('invalid'):
149152
args = args[:3] + (String(b'\xff'),)
150153

@@ -429,6 +432,8 @@ async def _begin_session(self, stdin, stdout, stderr):
429432
stdin.channel.exit_with_signal('INT', False, 'closed_signal')
430433
elif action == 'invalid_exit_signal':
431434
stdin.channel.exit_with_signal('invalid')
435+
elif action == 'invalid_exit_msg':
436+
stdin.channel.exit_with_signal('INT', False, 'invalid', '')
432437
elif action == 'invalid_exit_lang':
433438
stdin.channel.exit_with_signal('INT', False, '', 'invalid')
434439
elif action == 'window_after_close':
@@ -1593,6 +1598,27 @@ async def test_invalid_exit_signal(self):
15931598
chan, _ = await _create_session(conn, 'invalid_exit_signal')
15941599

15951600
await chan.wait_closed()
1601+
self.assertIsNone(chan.get_exit_signal())
1602+
1603+
@asynctest
1604+
async def test_invalid_exit_msg(self):
1605+
"""Test delivery of invalid exit signal message"""
1606+
1607+
async with self.connect() as conn:
1608+
chan, _ = await _create_session(conn, 'invalid_exit_msg')
1609+
1610+
await chan.wait_closed()
1611+
self.assertIsNone(chan.get_exit_signal())
1612+
1613+
@asynctest
1614+
async def test_invalid_exit_msg_error_handler(self):
1615+
"""Test delivery of invalid exit signal message"""
1616+
1617+
async with self.connect(utf8_decode_errors='ignore') as conn:
1618+
chan, _ = await _create_session(conn, 'invalid_exit_msg')
1619+
1620+
await chan.wait_closed()
1621+
self.assertIsNotNone(chan.get_exit_signal())
15961622

15971623
@asynctest
15981624
async def test_invalid_exit_lang(self):

tests/test_connection.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -350,16 +350,16 @@ def begin_auth(self, username):
350350
return False
351351

352352

353-
class _VersionRecordingClient(asyncssh.SSHClient):
354-
"""Client for testing custom client version"""
353+
class _AuthBannerRecordingClient(asyncssh.SSHClient):
354+
"""Client which records auth banner"""
355355

356356
def __init__(self):
357-
self.reported_version = None
357+
self.auth_banner = None
358358

359359
def auth_banner_received(self, msg, lang):
360360
"""Record the client version reported in the auth banner"""
361361

362-
self.reported_version = msg
362+
self.auth_banner = msg
363363

364364

365365
class _VersionReportingServer(Server):
@@ -2542,11 +2542,26 @@ async def start_server(cls):
25422542

25432543
@asynctest
25442544
async def test_invalid_auth_banner(self):
2545-
"""Test server sending invalid auth banner"""
2545+
"""Test server sending invalid UTF-8 in auth banner"""
25462546

25472547
with self.assertRaises(asyncssh.ProtocolError):
25482548
await self.connect()
25492549

2550+
@asynctest
2551+
async def test_invalid_auth_banner_error_handler(self):
2552+
"""Test error handler for invalid UTF-8 in auth banner"""
2553+
2554+
for errors in ('ignore', 'replace', 'backslashreplace',
2555+
'surrogateescape'):
2556+
conn, client = \
2557+
await self.create_connection(_AuthBannerRecordingClient,
2558+
utf8_decode_errors=errors)
2559+
2560+
async with conn:
2561+
self.assertEqual(client.auth_banner,
2562+
b'\xff'.decode('utf-8', errors))
2563+
2564+
25502565

25512566
class _TestExpiredServerHostCertificate(ServerTestCase):
25522567
"""Unit tests for expired server host certificate"""
@@ -2585,11 +2600,11 @@ async def _check_client_version(self, version):
25852600
"""Check custom client version"""
25862601

25872602
conn, client = \
2588-
await self.create_connection(_VersionRecordingClient,
2603+
await self.create_connection(_AuthBannerRecordingClient,
25892604
client_version=version)
25902605

25912606
async with conn:
2592-
self.assertEqual(client.reported_version, 'SSH-2.0-custom')
2607+
self.assertEqual(client.auth_banner, 'SSH-2.0-custom')
25932608

25942609
@asynctest
25952610
async def test_custom_client_version(self):

0 commit comments

Comments
 (0)