diff --git a/src/Message.ts b/src/Message.ts index 6e0b1bb..fbb5e7f 100644 --- a/src/Message.ts +++ b/src/Message.ts @@ -3,8 +3,10 @@ export type TransportRequest = { type: 'request', slotName: string, id: string, export type TransportResponse = { type: 'response', slotName: string, id: string, data: any } export type TransportError = { type: 'error', slotName: string, id: string, message: string, stack?: string } export type TransportRegistrationMessage = { type: 'handler_registered', slotName: string } +export type TransportUnregistrationMessage = { type: 'handler_unregistered', slotName: string } export type TransportMessage = TransportRegistrationMessage + | TransportUnregistrationMessage | TransportRequest | TransportResponse | TransportError diff --git a/src/Slot.ts b/src/Slot.ts index 90769c7..137c1e5 100644 --- a/src/Slot.ts +++ b/src/Slot.ts @@ -102,6 +102,10 @@ export function connectSlot(slotName: string, transports: Trans // Return the unsubscription function return () => { + + // Unregister remote handler with all of our remote transports + transports.forEach(t => t.unregisterHandler(slotName, handler)) + const ix = handlers.indexOf(handler) if (ix !== -1) { handlers.splice(ix, 1) diff --git a/src/Transport.ts b/src/Transport.ts index edc276f..5a5793e 100644 --- a/src/Transport.ts +++ b/src/Transport.ts @@ -2,6 +2,7 @@ import { Handler, callHandlers } from './Handler' import { Channel } from './Channel' import { TransportRegistrationMessage, + TransportUnregistrationMessage, TransportError, TransportRequest, TransportResponse, @@ -74,6 +75,8 @@ export class Transport { return this._responseReceived(message) case 'handler_registered': return this._registerRemoteHandler(message) + case 'handler_unregistered': + return this._unregisterRemoteHandler(message.slotName) case 'error': return this._errorReceived(message) default: @@ -92,7 +95,7 @@ export class Transport { this._channelReady = false // When the far end disconnects, remove all the handlers it had set - this._unregisterHandlers() + this._unregisterAllRemoteHandlers() this._rejectAllPendingRequests(new Error(`${ERRORS.REMOTE_CONNECTION_CLOSED}`)) }) @@ -212,15 +215,19 @@ export class Transport { addHandler(remoteHandler) } - private _unregisterHandlers(): void { + private _unregisterRemoteHandler(slotName: string): void { + const unregisterRemoteHandler = this._remoteHandlerDeletionCallbacks[slotName] + const remoteHandler = this._remoteHandlers[slotName] + if (remoteHandler && unregisterRemoteHandler) { + unregisterRemoteHandler(remoteHandler) + delete this._remoteHandlers[slotName] + } + } + + private _unregisterAllRemoteHandlers(): void { Object.keys(this._remoteHandlerDeletionCallbacks) .forEach(slotName => { - const unregisterRemoteHandler = this._remoteHandlerDeletionCallbacks[slotName] - const remoteHandler = this._remoteHandlers[slotName] - if (remoteHandler && unregisterRemoteHandler) { - unregisterRemoteHandler(remoteHandler) - delete this._remoteHandlers[slotName] - } + this._unregisterRemoteHandler(slotName) }) } @@ -264,14 +271,45 @@ export class Transport { this._localHandlers[slotName] = [] } this._localHandlers[slotName].push(handler) - const registrationMessage: TransportRegistrationMessage = { - type: 'handler_registered', - slotName - } - this._localHandlerRegistrations.push(registrationMessage) - if (this._channelReady) { - this._channel.send(registrationMessage) + /** + * We notify the far end when adding the first handler only, as they + * only need to know if at least one handler is connected. + */ + if (this._localHandlers[slotName].length === 1) { + const registrationMessage: TransportRegistrationMessage = { + type: 'handler_registered', + slotName + } + this._localHandlerRegistrations.push(registrationMessage) + if (this._channelReady) { + this._channel.send(registrationMessage) + } } } + /** + * Called when a local handler is unregistered, to send a `handler_unregistered` + * message to the far end. + */ + public unregisterHandler(slotName: string, handler: Handler): void { + if (this._localHandlers[slotName]) { + const ix = this._localHandlers[slotName].indexOf(handler) + if (ix > -1) { + this._localHandlers[slotName].splice(ix, 1) + /** + * We notify the far end when removing the last handler only, as they + * only need to know if at least one handler is connected. + */ + if (this._localHandlers[slotName].length === 0) { + const unregistrationMessage: TransportUnregistrationMessage = { + type: 'handler_unregistered', + slotName + } + if (this._channelReady) { + this._channel.send(unregistrationMessage) + } + } + } + } + } } diff --git a/test/Transport.test.ts b/test/Transport.test.ts index 2f52350..ee884ac 100644 --- a/test/Transport.test.ts +++ b/test/Transport.test.ts @@ -4,6 +4,7 @@ import { Transport } from './../src/Transport' import { TransportMessage } from './../src/Message' import { TestChannel } from './TestChannel' import * as sinon from 'sinon' +import { SinonSpy } from 'sinon' import { createEventBus } from '../src/Events' import { slot } from '../src/Slot' @@ -24,19 +25,25 @@ describe('Transport', () => { }) - context('handler registration and requests', () => { + context('handler registration, requests and unregistration', () => { - const channel = new TestChannel() - const transport = new Transport(channel) - const handlers = { - buildCelery: sinon.spy(() => ({ color: 'blue' })), - getCarrotStock: sinon.spy() - } + let channel: TestChannel + let transport: Transport + let slots: { [slotName: string]: SinonSpy[] } - it('should send a handler_registered message when a local handler is registered', () => { + beforeEach(() => { + slots = { + buildCelery: [sinon.spy(() => ({ color: 'blue' }))], + getCarrotStock: [sinon.spy(), sinon.spy()] + } + channel = new TestChannel() + transport = new Transport(channel) channel.callConnected() - Object.keys(handlers).forEach(slotName => { - transport.registerHandler(slotName, handlers[slotName]) + }) + + it('should send a handler_registered message for each slot when a local handler is registered', () => { + Object.keys(slots).forEach(slotName => { + transport.registerHandler(slotName, slots[slotName][0]) channel.sendSpy.calledWith({ type: 'handler_registered', slotName @@ -44,10 +51,25 @@ describe('Transport', () => { }) }) + it('should not send a handler_registered message when an additional local handler is registered', () => { + const slotName = 'getCarrotStock' + transport.registerHandler(slotName, slots[slotName][0]) + transport.registerHandler(slotName, slots[slotName][1]) + channel.sendSpy + .withArgs({ type: 'handler_registered', slotName }) + .calledOnce // should have been called exactly once + .should.be.True() + }) + + it('should call the appropriate handler when a request is received', async () => { + const slotName = 'buildCelery' - const handler = handlers[slotName] - handler.called.should.be.False() + const handler = slots[slotName][0] + + // Register handler on slot + transport.registerHandler(slotName, handler) + const request: TransportMessage = { type: 'request', slotName, @@ -57,6 +79,7 @@ describe('Transport', () => { constitution: 'strong' } } + channel.fakeReceive(request) await Promise.resolve() // yield to ts-event-bus internals @@ -71,22 +94,84 @@ describe('Transport', () => { }) }) - context('adding and using a remote handler', () => { + it('should send a handler_unregistered message when the last local handler is unregistered', () => { + + const slotName = 'buildCelery' + + // Register one handler on slot + transport.registerHandler(slotName, slots[slotName][0]) + + // Unregister it + transport.unregisterHandler(slotName, slots[slotName][0]) + + channel.sendSpy.calledWith({ + type: 'handler_unregistered', + slotName + }).should.be.True() + }) + + it('should not call the unregistered handler when a request is received', async () => { + + const slotName = 'buildCelery' + const handler = slots[slotName][0] + + // Register one handler on slot + transport.registerHandler(slotName, handler) + + // Unregister it + transport.unregisterHandler(slotName, slots[slotName][0]) + + const request: TransportMessage = { + type: 'request', + slotName, + id: '5', + data: { + height: 5, + constitution: 'strong' + } + } + channel.fakeReceive(request) + await Promise.resolve() // yield to ts-event-bus internals + handler.called.should.be.False() + }) + + it('should not send a handler_unregistered message when an additional local handler is unregistered', () => { const slotName = 'getCarrotStock' - const addLocalHandler = sinon.spy() + + // Register two handlers on slot + transport.registerHandler(slotName, slots[slotName][0]) + transport.registerHandler(slotName, slots[slotName][1]) + + // Unregister one handler only + transport.unregisterHandler(slotName, slots[slotName][0]) + + channel.sendSpy.calledWith({ + type: 'handler_unregistered', + slotName + }).should.be.False() + }) + + context('adding, using and removing a remote handler', () => { + + const slotName = 'getCarrotStock' + + let addLocalHandler: SinonSpy + let removeLocalHandler: SinonSpy let localHandler: (...args: any[]) => Promise - it('should add a local handler when a remote handler registration is received', () => { + beforeEach(() => { + addLocalHandler = sinon.spy() + removeLocalHandler = sinon.spy() transport.onRemoteHandlerRegistered(slotName, addLocalHandler) - channel.fakeReceive({ - type: 'handler_registered', - slotName - }) - addLocalHandler.called.should.be.True() + channel.fakeReceive({ type: 'handler_registered', slotName }) localHandler = addLocalHandler.lastCall.args[0] }) + it('should add a local handler when a remote handler registration is received', () => { + addLocalHandler.called.should.be.True() + }) + it('should resolve a local pending request when a response is received', () => { const requestData = { carrotType: 'red' } const pendingPromise = localHandler(requestData) @@ -122,6 +207,15 @@ describe('Transport', () => { }) .catch(err => `${err}`.should.eql('Error: all out of blue on getCarrotStock')) }) + + it('should remove a local handler when a remote handler unregistration is received', () => { + transport.onRemoteHandlerUnregistered(slotName, removeLocalHandler) + channel.fakeReceive({ + type: 'handler_unregistered', + slotName + }) + removeLocalHandler.called.should.be.True() + }) }) })