diff --git a/server/src/browser-management/controller.ts b/server/src/browser-management/controller.ts index ff5b57347..7c5271bf1 100644 --- a/server/src/browser-management/controller.ts +++ b/server/src/browser-management/controller.ts @@ -5,7 +5,7 @@ import { Socket } from "socket.io"; import { uuid } from 'uuidv4'; -import { createSocketConnection, createSocketConnectionForRun, registerBrowserUserContext } from "../socket-connection/connection"; +import { createSocketConnection, createSocketConnectionForRun } from "../socket-connection/connection"; import { io, browserPool } from "../server"; import { RemoteBrowser } from "./classes/RemoteBrowser"; import { RemoteBrowserOptions } from "../types"; @@ -24,6 +24,7 @@ export const initializeRemoteBrowserForRecording = (userId: string): string => { const id = getActiveBrowserIdByState(userId, "recording") || uuid(); createSocketConnection( io.of(id), + userId, async (socket: Socket) => { // browser is already active const activeId = getActiveBrowserIdByState(userId, "recording"); @@ -55,11 +56,8 @@ export const initializeRemoteBrowserForRecording = (userId: string): string => { export const createRemoteBrowserForRun = (userId: string): string => { const id = uuid(); - registerBrowserUserContext(id, userId); - logger.log('debug', `Created new browser for run: ${id} for user: ${userId}`); - createSocketConnectionForRun( - io.of(`/${id}`), + io.of(id), async (socket: Socket) => { try { const browserSession = new RemoteBrowser(socket, userId, id); diff --git a/server/src/browser-management/inputHandlers.ts b/server/src/browser-management/inputHandlers.ts index 4e6d5d190..9b6551a5f 100644 --- a/server/src/browser-management/inputHandlers.ts +++ b/server/src/browser-management/inputHandlers.ts @@ -4,9 +4,6 @@ * These functions are called by the client through socket communication. */ import { Socket } from 'socket.io'; -import { IncomingMessage } from 'http'; -import { JwtPayload } from 'jsonwebtoken'; - import logger from "../logger"; import { Coordinates, ScrollDeltas, KeyboardInput, DatePickerEventData } from '../types'; import { browserPool } from "../server"; @@ -15,14 +12,6 @@ import { Page } from "playwright"; import { throttle } from "../../../src/helpers/inputHelpers"; import { CustomActions } from "../../../src/shared/types"; -interface AuthenticatedIncomingMessage extends IncomingMessage { - user?: JwtPayload | string; -} - -interface AuthenticatedSocket extends Socket { - request: AuthenticatedIncomingMessage; -} - /** * A wrapper function for handling user input. * This function gets the active browser instance from the browser pool @@ -42,20 +31,9 @@ const handleWrapper = async ( page: Page, args?: any ) => Promise, - args?: any, - socket?: AuthenticatedSocket, + userId: string, + args?: any ) => { - if (!socket || !socket.request || !socket.request.user || typeof socket.request.user === 'string') { - logger.log('warn', `User not authenticated or invalid JWT payload`); - return; - } - - const userId = socket.request.user.id; - if (!userId) { - logger.log('warn', `User ID is missing in JWT payload`); - return; - } - const id = browserPool.getActiveBrowserId(userId, "recording"); if (id) { const activeBrowser = browserPool.getRemoteBrowser(id); @@ -93,9 +71,9 @@ interface CustomActionEventData { * @param customActionEventData The custom action event data * @category HelperFunctions */ -const onGenerateAction = async (socket: AuthenticatedSocket, customActionEventData: CustomActionEventData) => { +const onGenerateAction = async (customActionEventData: CustomActionEventData, userId: string) => { logger.log('debug', `Generating ${customActionEventData.action} action emitted from client`); - await handleWrapper(handleGenerateAction, customActionEventData, socket); + await handleWrapper(handleGenerateAction, userId, customActionEventData); } /** @@ -117,9 +95,9 @@ const handleGenerateAction = * @param coordinates - coordinates of the mouse click * @category HelperFunctions */ -const onMousedown = async (socket: AuthenticatedSocket, coordinates: Coordinates) => { +const onMousedown = async (coordinates: Coordinates, userId: string) => { logger.log('debug', 'Handling mousedown event emitted from client'); - await handleWrapper(handleMousedown, coordinates, socket); + await handleWrapper(handleMousedown, userId, coordinates); } /** @@ -168,9 +146,9 @@ const handleMousedown = async (generator: WorkflowGenerator, page: Page, { x, y * @param scrollDeltas - the scroll deltas of the wheel event * @category HelperFunctions */ -const onWheel = async (socket: AuthenticatedSocket, scrollDeltas: ScrollDeltas) => { +const onWheel = async (scrollDeltas: ScrollDeltas, userId: string) => { logger.log('debug', 'Handling scroll event emitted from client'); - await handleWrapper(handleWheel, scrollDeltas, socket); + await handleWrapper(handleWheel, userId, scrollDeltas); }; /** @@ -206,9 +184,9 @@ const handleWheel = async (generator: WorkflowGenerator, page: Page, { deltaX, d * @param coordinates - the coordinates of the mousemove event * @category HelperFunctions */ -const onMousemove = async (socket: AuthenticatedSocket, coordinates: Coordinates) => { +const onMousemove = async (coordinates: Coordinates, userId: string) => { logger.log('debug', 'Handling mousemove event emitted from client'); - await handleWrapper(handleMousemove, coordinates, socket); + await handleWrapper(handleMousemove, userId, coordinates); } /** @@ -247,9 +225,9 @@ const handleMousemove = async (generator: WorkflowGenerator, page: Page, { x, y * @param keyboardInput - the keyboard input of the keydown event * @category HelperFunctions */ -const onKeydown = async (socket: AuthenticatedSocket, keyboardInput: KeyboardInput) => { +const onKeydown = async (keyboardInput: KeyboardInput, userId: string) => { logger.log('debug', 'Handling keydown event emitted from client'); - await handleWrapper(handleKeydown, keyboardInput, socket); + await handleWrapper(handleKeydown, userId, keyboardInput); } /** @@ -286,9 +264,9 @@ const handleDateSelection = async (generator: WorkflowGenerator, page: Page, dat * @param data - the data of the date selection event * @category HelperFunctions */ -const onDateSelection = async (socket: AuthenticatedSocket, data: DatePickerEventData) => { +const onDateSelection = async (data: DatePickerEventData, userId: string) => { logger.log('debug', 'Handling date selection event emitted from client'); - await handleWrapper(handleDateSelection, data, socket); + await handleWrapper(handleDateSelection, userId, data); } /** @@ -309,9 +287,9 @@ const handleDropdownSelection = async (generator: WorkflowGenerator, page: Page, * @param data - the data of the dropdown selection event * @category HelperFunctions */ -const onDropdownSelection = async (socket: AuthenticatedSocket, data: { selector: string, value: string }) => { +const onDropdownSelection = async (data: { selector: string, value: string }, userId: string) => { logger.log('debug', 'Handling dropdown selection event emitted from client'); - await handleWrapper(handleDropdownSelection, data, socket); + await handleWrapper(handleDropdownSelection, userId, data); } /** @@ -332,9 +310,9 @@ const handleTimeSelection = async (generator: WorkflowGenerator, page: Page, dat * @param data - the data of the time selection event * @category HelperFunctions */ -const onTimeSelection = async (socket: AuthenticatedSocket, data: { selector: string, value: string }) => { +const onTimeSelection = async (data: { selector: string, value: string }, userId: string) => { logger.log('debug', 'Handling time selection event emitted from client'); - await handleWrapper(handleTimeSelection, data, socket); + await handleWrapper(handleTimeSelection, userId, data); } /** @@ -355,9 +333,9 @@ const handleDateTimeLocalSelection = async (generator: WorkflowGenerator, page: * @param data - the data of the datetime-local selection event * @category HelperFunctions */ -const onDateTimeLocalSelection = async (socket: AuthenticatedSocket, data: { selector: string, value: string }) => { +const onDateTimeLocalSelection = async (data: { selector: string, value: string }, userId: string) => { logger.log('debug', 'Handling datetime-local selection event emitted from client'); - await handleWrapper(handleDateTimeLocalSelection, data, socket); + await handleWrapper(handleDateTimeLocalSelection, userId, data); } /** @@ -366,9 +344,9 @@ const onDateTimeLocalSelection = async (socket: AuthenticatedSocket, data: { sel * @param keyboardInput - the keyboard input of the keyup event * @category HelperFunctions */ -const onKeyup = async (socket: AuthenticatedSocket, keyboardInput: KeyboardInput) => { +const onKeyup = async (keyboardInput: KeyboardInput, userId: string) => { logger.log('debug', 'Handling keyup event emitted from client'); - await handleWrapper(handleKeyup, keyboardInput, socket); + await handleWrapper(handleKeyup, userId, keyboardInput); } /** @@ -391,9 +369,9 @@ const handleKeyup = async (generator: WorkflowGenerator, page: Page, key: string * @param url - the new url of the page * @category HelperFunctions */ -const onChangeUrl = async (socket: AuthenticatedSocket, url: string) => { +const onChangeUrl = async (url: string, userId: string) => { logger.log('debug', 'Handling change url event emitted from client'); - await handleWrapper(handleChangeUrl, url, socket); + await handleWrapper(handleChangeUrl, userId, url); } /** @@ -424,9 +402,9 @@ const handleChangeUrl = async (generator: WorkflowGenerator, page: Page, url: st * @param socket The socket connection * @category HelperFunctions */ -const onRefresh = async (socket: AuthenticatedSocket) => { +const onRefresh = async (userId: string) => { logger.log('debug', 'Handling refresh event emitted from client'); - await handleWrapper(handleRefresh, undefined, socket); + await handleWrapper(handleRefresh, userId, undefined); } /** @@ -446,9 +424,9 @@ const handleRefresh = async (generator: WorkflowGenerator, page: Page) => { * @param socket The socket connection * @category HelperFunctions */ -const onGoBack = async (socket: AuthenticatedSocket) => { +const onGoBack = async (userId: string) => { logger.log('debug', 'Handling go back event emitted from client'); - await handleWrapper(handleGoBack, undefined, socket); + await handleWrapper(handleGoBack, userId, undefined); } /** @@ -469,9 +447,9 @@ const handleGoBack = async (generator: WorkflowGenerator, page: Page) => { * @param socket The socket connection * @category HelperFunctions */ -const onGoForward = async (socket: AuthenticatedSocket) => { +const onGoForward = async (userId: string) => { logger.log('debug', 'Handling go forward event emitted from client'); - await handleWrapper(handleGoForward, undefined, socket); + await handleWrapper(handleGoForward, userId, undefined); } /** @@ -499,25 +477,22 @@ const handleGoForward = async (generator: WorkflowGenerator, page: Page) => { * @returns void * @category BrowserManagement */ -const registerInputHandlers = (socket: Socket) => { - // Cast to our authenticated socket type - const authSocket = socket as AuthenticatedSocket; - +const registerInputHandlers = (socket: Socket, userId: string) => { // Register handlers with the socket - socket.on("input:mousedown", (data) => onMousedown(authSocket, data)); - socket.on("input:wheel", (data) => onWheel(authSocket, data)); - socket.on("input:mousemove", (data) => onMousemove(authSocket, data)); - socket.on("input:keydown", (data) => onKeydown(authSocket, data)); - socket.on("input:keyup", (data) => onKeyup(authSocket, data)); - socket.on("input:url", (data) => onChangeUrl(authSocket, data)); - socket.on("input:refresh", () => onRefresh(authSocket)); - socket.on("input:back", () => onGoBack(authSocket)); - socket.on("input:forward", () => onGoForward(authSocket)); - socket.on("input:date", (data) => onDateSelection(authSocket, data)); - socket.on("input:dropdown", (data) => onDropdownSelection(authSocket, data)); - socket.on("input:time", (data) => onTimeSelection(authSocket, data)); - socket.on("input:datetime-local", (data) => onDateTimeLocalSelection(authSocket, data)); - socket.on("action", (data) => onGenerateAction(authSocket, data)); + socket.on("input:mousedown", (data) => onMousedown(data, userId)); + socket.on("input:wheel", (data) => onWheel(data, userId)); + socket.on("input:mousemove", (data) => onMousemove(data, userId)); + socket.on("input:keydown", (data) => onKeydown(data, userId)); + socket.on("input:keyup", (data) => onKeyup(data, userId)); + socket.on("input:url", (data) => onChangeUrl(data, userId)); + socket.on("input:refresh", () => onRefresh(userId)); + socket.on("input:back", () => onGoBack(userId)); + socket.on("input:forward", () => onGoForward(userId)); + socket.on("input:date", (data) => onDateSelection(data, userId)); + socket.on("input:dropdown", (data) => onDropdownSelection(data, userId)); + socket.on("input:time", (data) => onTimeSelection(data, userId)); + socket.on("input:datetime-local", (data) => onDateTimeLocalSelection(data, userId)); + socket.on("action", (data) => onGenerateAction(data, userId)); }; export default registerInputHandlers; diff --git a/server/src/socket-connection/connection.ts b/server/src/socket-connection/connection.ts index a7f7565d2..109e50cb6 100644 --- a/server/src/socket-connection/connection.ts +++ b/server/src/socket-connection/connection.ts @@ -1,98 +1,7 @@ import { Namespace, Socket } from 'socket.io'; -import { IncomingMessage } from 'http'; -import { verify, JwtPayload } from 'jsonwebtoken'; import logger from "../logger"; import registerInputHandlers from '../browser-management/inputHandlers'; -interface AuthenticatedIncomingMessage extends IncomingMessage { - user?: JwtPayload | string; -} - -interface AuthenticatedSocket extends Socket { - request: AuthenticatedIncomingMessage; -} - -declare global { - var userContextMap: Map; -} - -if (!global.userContextMap) { - global.userContextMap = new Map(); -} - -/** - * Register browser-user association in the global context map - */ -export function registerBrowserUserContext(browserId: string, userId: string) { - if (!global.userContextMap) { - global.userContextMap = new Map(); - } - global.userContextMap.set(browserId, userId); - logger.log('debug', `Registered browser-user association: ${browserId} -> ${userId}`); -} - -/** - * Socket.io middleware for authentication - * This is a socket.io specific auth handler that doesn't rely on Express middleware - */ -const socketAuthMiddleware = (socket: Socket, next: (err?: Error) => void) => { - // Extract browserId from namespace - const namespace = socket.nsp.name; - const browserId = namespace.slice(1); - - // Check if this browser is in our context map - if (global.userContextMap && global.userContextMap.has(browserId)) { - const userId = global.userContextMap.get(browserId); - logger.log('debug', `Found browser in context map: ${browserId} -> ${userId}`); - - const authSocket = socket as AuthenticatedSocket; - authSocket.request.user = { id: userId }; - return next(); - } - - const cookies = socket.handshake.headers.cookie; - if (!cookies) { - logger.log('debug', `No cookies found in socket handshake for ${browserId}`); - return next(new Error('Authentication required')); - } - - const tokenMatch = cookies.split(';').find(c => c.trim().startsWith('token=')); - if (!tokenMatch) { - logger.log('debug', `No token cookie found in socket handshake for ${browserId}`); - return next(new Error('Authentication required')); - } - - const token = tokenMatch.split('=')[1]; - if (!token) { - logger.log('debug', `Empty token value in cookie for ${browserId}`); - return next(new Error('Authentication required')); - } - - const secret = process.env.JWT_SECRET; - if (!secret) { - logger.error('JWT_SECRET environment variable is not defined'); - return next(new Error('Server configuration error')); - } - - verify(token, secret, (err: any, user: any) => { - if (err) { - logger.log('warn', `JWT verification error: ${err.message}`); - return next(new Error('Authentication failed')); - } - - // Normalize payload key - if (user.userId && !user.id) { - user.id = user.userId; - delete user.userId; - } - - // Attach user to socket request - const authSocket = socket as AuthenticatedSocket; - authSocket.request.user = user; - next(); - }); -}; - /** * Opens a websocket canal for duplex data transfer and registers all handlers for this data for the recording session. * Uses socket.io dynamic namespaces for multiplexing the traffic from different running remote browser instances. @@ -102,13 +11,12 @@ const socketAuthMiddleware = (socket: Socket, next: (err?: Error) => void) => { */ export const createSocketConnection = ( io: Namespace, + userId: string, callback: (socket: Socket) => void, ) => { - io.use(socketAuthMiddleware); - const onConnection = async (socket: Socket) => { logger.log('info', "Client connected " + socket.id); - registerInputHandlers(socket); + registerInputHandlers(socket, userId); socket.on('disconnect', () => logger.log('info', "Client disconnected " + socket.id)); callback(socket); } @@ -127,8 +35,6 @@ export const createSocketConnectionForRun = ( io: Namespace, callback: (socket: Socket) => void, ) => { - io.use(socketAuthMiddleware); - const onConnection = async (socket: Socket) => { logger.log('info', "Client connected " + socket.id); socket.on('disconnect', () => logger.log('info', "Client disconnected " + socket.id));