From 8be285a132f3ca0dee7270caa9e44fb57907d8a4 Mon Sep 17 00:00:00 2001 From: Arthur Schreiber Date: Sat, 14 Sep 2024 16:09:42 +0000 Subject: [PATCH] Clean up socket error handling during connection establishment. --- src/connection.ts | 283 ++++++++-------- test/unit/connection-failure-test.js | 462 +++++++++++++++++++++++++++ 2 files changed, 600 insertions(+), 145 deletions(-) create mode 100644 test/unit/connection-failure-test.js diff --git a/src/connection.ts b/src/connection.ts index 53c79b0ce..eb9f917fc 100644 --- a/src/connection.ts +++ b/src/connection.ts @@ -1964,110 +1964,154 @@ class Connection extends EventEmitter { * @private */ initialiseConnection() { - const controller = new AbortController(); - - setTimeout(() => { - const hostPostfix = this.config.options.port ? `:${this.config.options.port}` : `\\${this.config.options.instanceName}`; - // If we have routing data stored, this connection has been redirected - const server = this.routingData ? this.routingData.server : this.config.server; - const port = this.routingData ? `:${this.routingData.port}` : hostPostfix; - // Grab the target host from the connection configuration, and from a redirect message - // otherwise, leave the message empty. - const routingMessage = this.routingData ? ` (redirected from ${this.config.server}${hostPostfix})` : ''; - const message = `Failed to connect to ${server}${port}${routingMessage} in ${this.config.options.connectTimeout}ms`; - this.debug.log(message); - - controller.abort(new ConnectionError(message, 'ETIMEOUT')); - }, this.config.options.connectTimeout).unref(); - - const signal = controller.signal; - (async () => { - let port = this.config.options.port; + const timeoutController = new AbortController(); + + const connectTimer = setTimeout(() => { + const hostPostfix = this.config.options.port ? `:${this.config.options.port}` : `\\${this.config.options.instanceName}`; + // If we have routing data stored, this connection has been redirected + const server = this.routingData ? this.routingData.server : this.config.server; + const port = this.routingData ? `:${this.routingData.port}` : hostPostfix; + // Grab the target host from the connection configuration, and from a redirect message + // otherwise, leave the message empty. + const routingMessage = this.routingData ? ` (redirected from ${this.config.server}${hostPostfix})` : ''; + const message = `Failed to connect to ${server}${port}${routingMessage} in ${this.config.options.connectTimeout}ms`; + this.debug.log(message); + + timeoutController.abort(new ConnectionError(message, 'ETIMEOUT')); + }, this.config.options.connectTimeout); + + try { + let signal = timeoutController.signal; + + let port = this.config.options.port; + + if (!port) { + try { + port = await instanceLookup({ + server: this.config.server, + instanceName: this.config.options.instanceName!, + timeout: this.config.options.connectTimeout, + signal: signal + }); + } catch (err: any) { + if (signal.aborted) { + throw signal.reason; + } + + throw new ConnectionError(err.message, 'EINSTLOOKUP', { cause: err }); + } + } - if (!port) { + let socket; try { - port = await instanceLookup({ - server: this.config.server, - instanceName: this.config.options.instanceName!, - timeout: this.config.options.connectTimeout, - signal: signal - }); + socket = await this.connectOnPort(port, this.config.options.multiSubnetFailover, signal, this.config.options.connector); } catch (err: any) { if (signal.aborted) { throw signal.reason; } - throw new ConnectionError(err.message, 'EINSTLOOKUP', { cause: err }); - } - } - - let socket; - try { - socket = await this.connectOnPort(port, this.config.options.multiSubnetFailover, signal, this.config.options.connector); - } catch (err: any) { - if (signal.aborted) { - throw signal.reason; + throw this.wrapSocketError(err); } - throw this.wrapSocketError(err); - } + const controller = new AbortController(); + const onError = (err: Error) => { + controller.abort(err); + }; + const onClose = () => { + this.debug.log('connection to ' + this.config.server + ':' + this.config.options.port + ' closed'); + }; + const onEnd = () => { + this.debug.log('socket ended'); - this.socketHandlingForSendPreLogin(socket); - this.sendPreLogin(); + const error: ErrorWithCode = new Error('socket hang up'); + error.code = 'ECONNRESET'; + controller.abort(this.wrapSocketError(error)); + }; - this.transitionTo(this.STATE.SENT_PRELOGIN); - await this.performSentPrelogin(signal); + socket.once('error', onError); + socket.once('close', onClose); + socket.once('end', onEnd); - this.sendLogin7Packet(); + signal = AbortSignal.any([signal, controller.signal]); - try { - const { authentication } = this.config; - switch (authentication.type) { - case 'token-credential': - case 'azure-active-directory-password': - case 'azure-active-directory-msi-vm': - case 'azure-active-directory-msi-app-service': - case 'azure-active-directory-service-principal-secret': - case 'azure-active-directory-default': - this.transitionTo(this.STATE.SENT_LOGIN7_WITH_FEDAUTH); - this.routingData = await this.performSentLogin7WithFedAuth(signal); - break; - case 'ntlm': - this.transitionTo(this.STATE.SENT_LOGIN7_WITH_NTLM); - this.routingData = await this.performSentLogin7WithNTLMLogin(signal); - break; - default: - this.transitionTo(this.STATE.SENT_LOGIN7_WITH_STANDARD_LOGIN); - this.routingData = await this.performSentLogin7WithStandardLogin(signal); - break; - } - } catch (err: any) { - if (isTransientError(err)) { - this.debug.log('Initiating retry on transient error'); - this.transitionTo(this.STATE.TRANSIENT_FAILURE_RETRY); - this.performTransientFailureRetry(); - return; - } + try { + socket.setKeepAlive(true, KEEP_ALIVE_INITIAL_DELAY); + + this.messageIo = new MessageIO(socket, this.config.options.packetSize, this.debug); + this.messageIo.on('secure', (cleartext) => { this.emit('secure', cleartext); }); + + this.socket = socket; + + this.closed = false; + this.debug.log('connected to ' + this.config.server + ':' + this.config.options.port); + + this.sendPreLogin(); + + this.transitionTo(this.STATE.SENT_PRELOGIN); + await this.performSentPrelogin(signal); + + this.sendLogin7Packet(); + + try { + const { authentication } = this.config; + switch (authentication.type) { + case 'token-credential': + case 'azure-active-directory-password': + case 'azure-active-directory-msi-vm': + case 'azure-active-directory-msi-app-service': + case 'azure-active-directory-service-principal-secret': + case 'azure-active-directory-default': + this.transitionTo(this.STATE.SENT_LOGIN7_WITH_FEDAUTH); + this.routingData = await this.performSentLogin7WithFedAuth(signal); + break; + case 'ntlm': + this.transitionTo(this.STATE.SENT_LOGIN7_WITH_NTLM); + this.routingData = await this.performSentLogin7WithNTLMLogin(signal); + break; + default: + this.transitionTo(this.STATE.SENT_LOGIN7_WITH_STANDARD_LOGIN); + this.routingData = await this.performSentLogin7WithStandardLogin(signal); + break; + } + } catch (err: any) { + if (isTransientError(err)) { + this.debug.log('Initiating retry on transient error'); + this.transitionTo(this.STATE.TRANSIENT_FAILURE_RETRY); + this.performTransientFailureRetry(); + return; + } + + throw err; + } - throw err; - } + // If routing data is present, we need to re-route the connection + if (this.routingData) { + this.transitionTo(this.STATE.REROUTING); + this.performReRouting(); + return; + } - // If routing data is present, we need to re-route the connection - if (this.routingData) { - this.transitionTo(this.STATE.REROUTING); - this.performReRouting(); - return; - } + this.transitionTo(this.STATE.LOGGED_IN_SENDING_INITIAL_SQL); + await this.performLoggedInSendingInitialSql(signal); + } finally { + socket.removeListener('error', onError); + socket.removeListener('close', onClose); + socket.removeListener('end', onEnd); + } - this.transitionTo(this.STATE.LOGGED_IN_SENDING_INITIAL_SQL); - await this.performLoggedInSendingInitialSql(signal); + socket.on('error', this._onSocketError); + socket.on('close', this._onSocketClose); + socket.on('end', this._onSocketEnd); - this.transitionTo(this.STATE.LOGGED_IN); + this.transitionTo(this.STATE.LOGGED_IN); - process.nextTick(() => { - this.emit('connect'); - }); + process.nextTick(() => { + this.emit('connect'); + }); + } finally { + clearTimeout(connectTimer); + } })().catch((err) => { this.transitionTo(this.STATE.FINAL); @@ -2119,21 +2163,6 @@ class Connection extends EventEmitter { return new TokenStreamParser(message, this.debug, handler, this.config.options); } - socketHandlingForSendPreLogin(socket: net.Socket) { - socket.on('error', this._onSocketError); - socket.on('close', this._onSocketClose); - socket.on('end', this._onSocketEnd); - socket.setKeepAlive(true, KEEP_ALIVE_INITIAL_DELAY); - - this.messageIo = new MessageIO(socket, this.config.options.packetSize, this.debug); - this.messageIo.on('secure', (cleartext) => { this.emit('secure', cleartext); }); - - this.socket = socket; - - this.closed = false; - this.debug.log('connected to ' + this.config.server + ':' + this.config.options.port); - } - wrapWithTls(socket: net.Socket, signal: AbortSignal): Promise { signal.throwIfAborted(); @@ -3609,75 +3638,39 @@ Connection.prototype.STATE = { enter: function() { this.initialiseConnection(); }, - events: { - socketError: function() { - this.transitionTo(this.STATE.FINAL); - } - } + events: {} }, SENT_PRELOGIN: { name: 'SentPrelogin', - events: { - socketError: function() { - this.transitionTo(this.STATE.FINAL); - } - } + events: {} }, REROUTING: { name: 'ReRouting', - events: { - socketError: function() { - this.transitionTo(this.STATE.FINAL); - } - } + events: {} }, TRANSIENT_FAILURE_RETRY: { name: 'TRANSIENT_FAILURE_RETRY', - events: { - socketError: function() { - this.transitionTo(this.STATE.FINAL); - } - } + events: {} }, SENT_TLSSSLNEGOTIATION: { name: 'SentTLSSSLNegotiation', - events: { - socketError: function() { - this.transitionTo(this.STATE.FINAL); - } - } + events: {} }, SENT_LOGIN7_WITH_STANDARD_LOGIN: { name: 'SentLogin7WithStandardLogin', - events: { - socketError: function() { - this.transitionTo(this.STATE.FINAL); - } - } + events: {} }, SENT_LOGIN7_WITH_NTLM: { name: 'SentLogin7WithNTLMLogin', - events: { - socketError: function() { - this.transitionTo(this.STATE.FINAL); - } - } + events: {} }, SENT_LOGIN7_WITH_FEDAUTH: { name: 'SentLogin7WithFedauth', - events: { - socketError: function() { - this.transitionTo(this.STATE.FINAL); - } - } + events: {} }, LOGGED_IN_SENDING_INITIAL_SQL: { name: 'LoggedInSendingInitialSql', - events: { - socketError: function socketError() { - this.transitionTo(this.STATE.FINAL); - } - } + events: {} }, LOGGED_IN: { name: 'LoggedIn', diff --git a/test/unit/connection-failure-test.js b/test/unit/connection-failure-test.js new file mode 100644 index 000000000..59aedacff --- /dev/null +++ b/test/unit/connection-failure-test.js @@ -0,0 +1,462 @@ +const { assert } = require('chai'); +const net = require('net'); + +const { Connection, ConnectionError } = require('../../src/tedious'); +const IncomingMessageStream = require('../../src/incoming-message-stream'); +const OutgoingMessageStream = require('../../src/outgoing-message-stream'); +const Debug = require('../../src/debug'); +const PreloginPayload = require('../../src/prelogin-payload'); +const Message = require('../../src/message'); + +function buildLoginAckToken() { + const progname = 'Tedious SQL Server'; + + const buffer = Buffer.from([ + 0xAD, // Type + 0x00, 0x00, // Length + 0x00, // interface number - SQL + 0x74, 0x00, 0x00, 0x04, // TDS version number + Buffer.byteLength(progname, 'ucs2') / 2, ...Buffer.from(progname, 'ucs2'), // Progname + 0x00, // major + 0x00, // minor + 0x00, 0x00, // buildNum + ]); + + buffer.writeUInt16LE(buffer.length - 3, 1); + + return buffer; +} + +describe('Connection failure handling', function() { + /** + * @type {net.Server} + */ + let server; + + /** + * @type {net.Socket[]} + */ + let _connections; + + beforeEach(function(done) { + _connections = []; + server = net.createServer(); + server.listen(0, '127.0.0.1', done); + }); + + afterEach(function(done) { + _connections.forEach((connection) => { + connection.destroy(); + }); + + server.close(done); + }); + + it('should fail correctly when the connection is aborted after the prelogin message is sent', function(done) { + server.on('connection', async (connection) => { + const debug = new Debug(); + const incomingMessageStream = new IncomingMessageStream(debug); + const outgoingMessageStream = new OutgoingMessageStream(debug, { packetSize: 4 * 1024 }); + + connection.pipe(incomingMessageStream); + outgoingMessageStream.pipe(connection); + + try { + const messageIterator = incomingMessageStream[Symbol.asyncIterator](); + + // PRELOGIN + { + const { value: message } = await messageIterator.next(); + assert.strictEqual(message.type, 0x12); + + setImmediate(() => { + connection.destroy(); + }); + } + } catch (err) { + console.log(err); + } + }); + + const connection = new Connection({ + server: server.address().address, + options: { + port: server.address().port, + encrypt: false + } + }); + + connection.connect((err) => { + connection.close(); + + assert.instanceOf(err, ConnectionError); + assert.strictEqual('Connection lost - socket hang up', err.message); + + assert.instanceOf(err.cause, Error); + assert.strictEqual('socket hang up', err.cause.message); + + done(); + }); + }); + + it('should fail correctly when the connection is aborted after the prelogin response is received', function(done) { + server.on('connection', async (connection) => { + const debug = new Debug(); + const incomingMessageStream = new IncomingMessageStream(debug); + const outgoingMessageStream = new OutgoingMessageStream(debug, { packetSize: 4 * 1024 }); + + connection.pipe(incomingMessageStream); + outgoingMessageStream.pipe(connection); + + try { + const messageIterator = incomingMessageStream[Symbol.asyncIterator](); + + // PRELOGIN + { + const { value: message } = await messageIterator.next(); + assert.strictEqual(message.type, 0x12); + + const chunks = []; + for await (const data of message) { + chunks.push(data); + } + + const responsePayload = new PreloginPayload({ encrypt: false, version: { major: 1, minor: 2, build: 3, subbuild: 0 } }); + const responseMessage = new Message({ type: 0x12 }); + responseMessage.end(responsePayload.data); + outgoingMessageStream.write(responseMessage); + } + + setImmediate(() => { + connection.destroy(); + }); + } catch (err) { + console.log(err); + } + }); + + const connection = new Connection({ + server: server.address().address, + options: { + port: server.address().port, + encrypt: false + } + }); + + connection.connect((err) => { + connection.close(); + + assert.instanceOf(err, ConnectionError); + assert.strictEqual('Connection lost - socket hang up', err.message); + + assert.instanceOf(err.cause, Error); + assert.strictEqual('socket hang up', err.cause.message); + + done(); + }); + }); + + it('should fail correctly when the connection is aborted after the Login7 message is sent', function(done) { + server.on('connection', async (connection) => { + const debug = new Debug(); + const incomingMessageStream = new IncomingMessageStream(debug); + const outgoingMessageStream = new OutgoingMessageStream(debug, { packetSize: 4 * 1024 }); + + connection.pipe(incomingMessageStream); + outgoingMessageStream.pipe(connection); + + try { + const messageIterator = incomingMessageStream[Symbol.asyncIterator](); + + // PRELOGIN + { + const { value: message } = await messageIterator.next(); + assert.strictEqual(message.type, 0x12); + + const chunks = []; + for await (const data of message) { + chunks.push(data); + } + + const responsePayload = new PreloginPayload({ encrypt: false, version: { major: 1, minor: 2, build: 3, subbuild: 0 } }); + const responseMessage = new Message({ type: 0x12 }); + responseMessage.end(responsePayload.data); + outgoingMessageStream.write(responseMessage); + } + + // LOGIN7 + { + const { value: message } = await messageIterator.next(); + assert.strictEqual(message.type, 0x10); + + setImmediate(() => { + connection.destroy(); + }); + } + } catch (err) { + console.log(err); + } + }); + + const connection = new Connection({ + server: server.address().address, + options: { + port: server.address().port, + encrypt: false + } + }); + + connection.connect((err) => { + connection.close(); + + assert.instanceOf(err, ConnectionError); + assert.strictEqual('Connection lost - socket hang up', err.message); + + assert.instanceOf(err.cause, Error); + assert.strictEqual('socket hang up', err.cause.message); + + done(); + }); + }); + + it('should fail correctly when the connection is aborted after the Login7 response is received', function(done) { + server.on('connection', async (connection) => { + const debug = new Debug(); + const incomingMessageStream = new IncomingMessageStream(debug); + const outgoingMessageStream = new OutgoingMessageStream(debug, { packetSize: 4 * 1024 }); + + connection.pipe(incomingMessageStream); + outgoingMessageStream.pipe(connection); + + try { + const messageIterator = incomingMessageStream[Symbol.asyncIterator](); + + // PRELOGIN + { + const { value: message } = await messageIterator.next(); + assert.strictEqual(message.type, 0x12); + + const chunks = []; + for await (const data of message) { + chunks.push(data); + } + + const responsePayload = new PreloginPayload({ encrypt: false, version: { major: 1, minor: 2, build: 3, subbuild: 0 } }); + const responseMessage = new Message({ type: 0x12 }); + responseMessage.end(responsePayload.data); + outgoingMessageStream.write(responseMessage); + } + + // LOGIN7 + { + const { value: message } = await messageIterator.next(); + assert.strictEqual(message.type, 0x10); + + const chunks = []; + for await (const data of message) { + chunks.push(data); + } + + const responseMessage = new Message({ type: 0x04 }); + responseMessage.end(buildLoginAckToken()); + outgoingMessageStream.write(responseMessage); + } + + setImmediate(() => { + connection.destroy(); + }); + } catch (err) { + console.log(err); + } + }); + + const connection = new Connection({ + server: server.address().address, + options: { + port: server.address().port, + encrypt: false + } + }); + + connection.connect((err) => { + connection.close(); + + assert.instanceOf(err, ConnectionError); + assert.strictEqual('Connection lost - socket hang up', err.message); + + assert.instanceOf(err.cause, Error); + assert.strictEqual('socket hang up', err.cause.message); + + done(); + }); + }); + + it('should fail correctly when the connection is aborted after the initial SQL message is sent', function(done) { + server.on('connection', async (connection) => { + const debug = new Debug(); + const incomingMessageStream = new IncomingMessageStream(debug); + const outgoingMessageStream = new OutgoingMessageStream(debug, { packetSize: 4 * 1024 }); + + connection.pipe(incomingMessageStream); + outgoingMessageStream.pipe(connection); + + try { + const messageIterator = incomingMessageStream[Symbol.asyncIterator](); + + // PRELOGIN + { + const { value: message } = await messageIterator.next(); + assert.strictEqual(message.type, 0x12); + + const chunks = []; + for await (const data of message) { + chunks.push(data); + } + + const responsePayload = new PreloginPayload({ encrypt: false, version: { major: 1, minor: 2, build: 3, subbuild: 0 } }); + const responseMessage = new Message({ type: 0x12 }); + responseMessage.end(responsePayload.data); + outgoingMessageStream.write(responseMessage); + } + + // LOGIN7 + { + const { value: message } = await messageIterator.next(); + assert.strictEqual(message.type, 0x10); + + const chunks = []; + for await (const data of message) { + chunks.push(data); + } + + const responseMessage = new Message({ type: 0x04 }); + responseMessage.end(buildLoginAckToken()); + outgoingMessageStream.write(responseMessage); + } + + // SQL Batch (Initial SQL) + { + const { value: message } = await messageIterator.next(); + assert.strictEqual(message.type, 0x01); + + setImmediate(() => { + connection.destroy(); + }); + } + } catch (err) { + console.log(err); + } + }); + + const connection = new Connection({ + server: server.address().address, + options: { + port: server.address().port, + encrypt: false + } + }); + + connection.connect((err) => { + connection.close(); + + assert.instanceOf(err, ConnectionError); + assert.strictEqual('Connection lost - socket hang up', err.message); + + assert.instanceOf(err.cause, Error); + assert.strictEqual('socket hang up', err.cause.message); + + done(); + }); + }); + + it('should fail correctly when the connection is aborted after the initial SQL response is received', function(done) { + server.on('connection', async (connection) => { + const debug = new Debug(); + const incomingMessageStream = new IncomingMessageStream(debug); + const outgoingMessageStream = new OutgoingMessageStream(debug, { packetSize: 4 * 1024 }); + + connection.pipe(incomingMessageStream); + outgoingMessageStream.pipe(connection); + + try { + const messageIterator = incomingMessageStream[Symbol.asyncIterator](); + + // PRELOGIN + { + const { value: message } = await messageIterator.next(); + assert.strictEqual(message.type, 0x12); + + const chunks = []; + for await (const data of message) { + chunks.push(data); + } + + const responsePayload = new PreloginPayload({ encrypt: false, version: { major: 1, minor: 2, build: 3, subbuild: 0 } }); + const responseMessage = new Message({ type: 0x12 }); + responseMessage.end(responsePayload.data); + outgoingMessageStream.write(responseMessage); + } + + // LOGIN7 + { + const { value: message } = await messageIterator.next(); + assert.strictEqual(message.type, 0x10); + + const chunks = []; + for await (const data of message) { + chunks.push(data); + } + + const responseMessage = new Message({ type: 0x04 }); + responseMessage.end(buildLoginAckToken()); + outgoingMessageStream.write(responseMessage); + } + + // SQL Batch (Initial SQL) + { + const { value: message } = await messageIterator.next(); + assert.strictEqual(message.type, 0x01); + + const chunks = []; + for await (const data of message) { + chunks.push(data); + } + + const responseMessage = new Message({ type: 0x04 }); + responseMessage.end(); + outgoingMessageStream.write(responseMessage); + } + + setImmediate(() => { + connection.destroy(); + }); + } catch (err) { + console.log(err); + } + }); + + const connection = new Connection({ + server: server.address().address, + options: { + port: server.address().port, + encrypt: false + } + }); + + connection.connect((err) => { + assert.isUndefined(err); + + connection.on('error', (err) => { + connection.close(); + + assert.instanceOf(err, ConnectionError); + assert.strictEqual('Connection lost - socket hang up', err.message); + + assert.instanceOf(err.cause, Error); + assert.strictEqual('socket hang up', err.cause.message); + + done(); + }); + }); + }); +});