Skip to content

feat: remove authenticated socket logic #598

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions server/src/browser-management/controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import { Socket } from "socket.io";
import { uuid } from 'uuidv4';

import { createSocketConnection, createSocketConnectionForRun, registerBrowserUserContext } from "../socket-connection/connection";
import { createSocketConnection, createSocketConnectionForRun } from "../socket-connection/connection";
import { io, browserPool } from "../server";
import { RemoteBrowser } from "./classes/RemoteBrowser";
import { RemoteBrowserOptions } from "../types";
Expand All @@ -24,6 +24,7 @@ export const initializeRemoteBrowserForRecording = (userId: string): string => {
const id = getActiveBrowserIdByState(userId, "recording") || uuid();
createSocketConnection(
io.of(id),
userId,
async (socket: Socket) => {
// browser is already active
const activeId = getActiveBrowserIdByState(userId, "recording");
Expand Down Expand Up @@ -55,11 +56,8 @@ export const initializeRemoteBrowserForRecording = (userId: string): string => {
export const createRemoteBrowserForRun = (userId: string): string => {
const id = uuid();

registerBrowserUserContext(id, userId);
logger.log('debug', `Created new browser for run: ${id} for user: ${userId}`);

createSocketConnectionForRun(
io.of(`/${id}`),
io.of(id),
async (socket: Socket) => {
try {
const browserSession = new RemoteBrowser(socket, userId, id);
Expand Down
115 changes: 45 additions & 70 deletions server/src/browser-management/inputHandlers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
* These functions are called by the client through socket communication.
*/
import { Socket } from 'socket.io';
import { IncomingMessage } from 'http';
import { JwtPayload } from 'jsonwebtoken';

import logger from "../logger";
import { Coordinates, ScrollDeltas, KeyboardInput, DatePickerEventData } from '../types';
import { browserPool } from "../server";
Expand All @@ -15,14 +12,6 @@ import { Page } from "playwright";
import { throttle } from "../../../src/helpers/inputHelpers";
import { CustomActions } from "../../../src/shared/types";

interface AuthenticatedIncomingMessage extends IncomingMessage {
user?: JwtPayload | string;
}

interface AuthenticatedSocket extends Socket {
request: AuthenticatedIncomingMessage;
}

/**
* A wrapper function for handling user input.
* This function gets the active browser instance from the browser pool
Expand All @@ -42,20 +31,9 @@ const handleWrapper = async (
page: Page,
args?: any
) => Promise<void>,
args?: any,
socket?: AuthenticatedSocket,
userId: string,
args?: any
) => {
if (!socket || !socket.request || !socket.request.user || typeof socket.request.user === 'string') {
logger.log('warn', `User not authenticated or invalid JWT payload`);
return;
}

const userId = socket.request.user.id;
if (!userId) {
logger.log('warn', `User ID is missing in JWT payload`);
return;
}

const id = browserPool.getActiveBrowserId(userId, "recording");
if (id) {
const activeBrowser = browserPool.getRemoteBrowser(id);
Expand Down Expand Up @@ -93,9 +71,9 @@ interface CustomActionEventData {
* @param customActionEventData The custom action event data
* @category HelperFunctions
*/
const onGenerateAction = async (socket: AuthenticatedSocket, customActionEventData: CustomActionEventData) => {
const onGenerateAction = async (customActionEventData: CustomActionEventData, userId: string) => {
logger.log('debug', `Generating ${customActionEventData.action} action emitted from client`);
await handleWrapper(handleGenerateAction, customActionEventData, socket);
await handleWrapper(handleGenerateAction, userId, customActionEventData);
}

/**
Expand All @@ -117,9 +95,9 @@ const handleGenerateAction =
* @param coordinates - coordinates of the mouse click
* @category HelperFunctions
*/
const onMousedown = async (socket: AuthenticatedSocket, coordinates: Coordinates) => {
const onMousedown = async (coordinates: Coordinates, userId: string) => {
logger.log('debug', 'Handling mousedown event emitted from client');
await handleWrapper(handleMousedown, coordinates, socket);
await handleWrapper(handleMousedown, userId, coordinates);
}

/**
Expand Down Expand Up @@ -168,9 +146,9 @@ const handleMousedown = async (generator: WorkflowGenerator, page: Page, { x, y
* @param scrollDeltas - the scroll deltas of the wheel event
* @category HelperFunctions
*/
const onWheel = async (socket: AuthenticatedSocket, scrollDeltas: ScrollDeltas) => {
const onWheel = async (scrollDeltas: ScrollDeltas, userId: string) => {
logger.log('debug', 'Handling scroll event emitted from client');
await handleWrapper(handleWheel, scrollDeltas, socket);
await handleWrapper(handleWheel, userId, scrollDeltas);
};

/**
Expand Down Expand Up @@ -206,9 +184,9 @@ const handleWheel = async (generator: WorkflowGenerator, page: Page, { deltaX, d
* @param coordinates - the coordinates of the mousemove event
* @category HelperFunctions
*/
const onMousemove = async (socket: AuthenticatedSocket, coordinates: Coordinates) => {
const onMousemove = async (coordinates: Coordinates, userId: string) => {
logger.log('debug', 'Handling mousemove event emitted from client');
await handleWrapper(handleMousemove, coordinates, socket);
await handleWrapper(handleMousemove, userId, coordinates);
}

/**
Expand Down Expand Up @@ -247,9 +225,9 @@ const handleMousemove = async (generator: WorkflowGenerator, page: Page, { x, y
* @param keyboardInput - the keyboard input of the keydown event
* @category HelperFunctions
*/
const onKeydown = async (socket: AuthenticatedSocket, keyboardInput: KeyboardInput) => {
const onKeydown = async (keyboardInput: KeyboardInput, userId: string) => {
logger.log('debug', 'Handling keydown event emitted from client');
await handleWrapper(handleKeydown, keyboardInput, socket);
await handleWrapper(handleKeydown, userId, keyboardInput);
}

/**
Expand Down Expand Up @@ -286,9 +264,9 @@ const handleDateSelection = async (generator: WorkflowGenerator, page: Page, dat
* @param data - the data of the date selection event
* @category HelperFunctions
*/
const onDateSelection = async (socket: AuthenticatedSocket, data: DatePickerEventData) => {
const onDateSelection = async (data: DatePickerEventData, userId: string) => {
logger.log('debug', 'Handling date selection event emitted from client');
await handleWrapper(handleDateSelection, data, socket);
await handleWrapper(handleDateSelection, userId, data);
}

/**
Expand All @@ -309,9 +287,9 @@ const handleDropdownSelection = async (generator: WorkflowGenerator, page: Page,
* @param data - the data of the dropdown selection event
* @category HelperFunctions
*/
const onDropdownSelection = async (socket: AuthenticatedSocket, data: { selector: string, value: string }) => {
const onDropdownSelection = async (data: { selector: string, value: string }, userId: string) => {
logger.log('debug', 'Handling dropdown selection event emitted from client');
await handleWrapper(handleDropdownSelection, data, socket);
await handleWrapper(handleDropdownSelection, userId, data);
}

/**
Expand All @@ -332,9 +310,9 @@ const handleTimeSelection = async (generator: WorkflowGenerator, page: Page, dat
* @param data - the data of the time selection event
* @category HelperFunctions
*/
const onTimeSelection = async (socket: AuthenticatedSocket, data: { selector: string, value: string }) => {
const onTimeSelection = async (data: { selector: string, value: string }, userId: string) => {
logger.log('debug', 'Handling time selection event emitted from client');
await handleWrapper(handleTimeSelection, data, socket);
await handleWrapper(handleTimeSelection, userId, data);
}

/**
Expand All @@ -355,9 +333,9 @@ const handleDateTimeLocalSelection = async (generator: WorkflowGenerator, page:
* @param data - the data of the datetime-local selection event
* @category HelperFunctions
*/
const onDateTimeLocalSelection = async (socket: AuthenticatedSocket, data: { selector: string, value: string }) => {
const onDateTimeLocalSelection = async (data: { selector: string, value: string }, userId: string) => {
logger.log('debug', 'Handling datetime-local selection event emitted from client');
await handleWrapper(handleDateTimeLocalSelection, data, socket);
await handleWrapper(handleDateTimeLocalSelection, userId, data);
}

/**
Expand All @@ -366,9 +344,9 @@ const onDateTimeLocalSelection = async (socket: AuthenticatedSocket, data: { sel
* @param keyboardInput - the keyboard input of the keyup event
* @category HelperFunctions
*/
const onKeyup = async (socket: AuthenticatedSocket, keyboardInput: KeyboardInput) => {
const onKeyup = async (keyboardInput: KeyboardInput, userId: string) => {
logger.log('debug', 'Handling keyup event emitted from client');
await handleWrapper(handleKeyup, keyboardInput, socket);
await handleWrapper(handleKeyup, userId, keyboardInput);
}

/**
Expand All @@ -391,9 +369,9 @@ const handleKeyup = async (generator: WorkflowGenerator, page: Page, key: string
* @param url - the new url of the page
* @category HelperFunctions
*/
const onChangeUrl = async (socket: AuthenticatedSocket, url: string) => {
const onChangeUrl = async (url: string, userId: string) => {
logger.log('debug', 'Handling change url event emitted from client');
await handleWrapper(handleChangeUrl, url, socket);
await handleWrapper(handleChangeUrl, userId, url);
}

/**
Expand Down Expand Up @@ -424,9 +402,9 @@ const handleChangeUrl = async (generator: WorkflowGenerator, page: Page, url: st
* @param socket The socket connection
* @category HelperFunctions
*/
const onRefresh = async (socket: AuthenticatedSocket) => {
const onRefresh = async (userId: string) => {
logger.log('debug', 'Handling refresh event emitted from client');
await handleWrapper(handleRefresh, undefined, socket);
await handleWrapper(handleRefresh, userId, undefined);
}

/**
Expand All @@ -446,9 +424,9 @@ const handleRefresh = async (generator: WorkflowGenerator, page: Page) => {
* @param socket The socket connection
* @category HelperFunctions
*/
const onGoBack = async (socket: AuthenticatedSocket) => {
const onGoBack = async (userId: string) => {
logger.log('debug', 'Handling go back event emitted from client');
await handleWrapper(handleGoBack, undefined, socket);
await handleWrapper(handleGoBack, userId, undefined);
}

/**
Expand All @@ -469,9 +447,9 @@ const handleGoBack = async (generator: WorkflowGenerator, page: Page) => {
* @param socket The socket connection
* @category HelperFunctions
*/
const onGoForward = async (socket: AuthenticatedSocket) => {
const onGoForward = async (userId: string) => {
logger.log('debug', 'Handling go forward event emitted from client');
await handleWrapper(handleGoForward, undefined, socket);
await handleWrapper(handleGoForward, userId, undefined);
}

/**
Expand Down Expand Up @@ -499,25 +477,22 @@ const handleGoForward = async (generator: WorkflowGenerator, page: Page) => {
* @returns void
* @category BrowserManagement
*/
const registerInputHandlers = (socket: Socket) => {
// Cast to our authenticated socket type
const authSocket = socket as AuthenticatedSocket;

const registerInputHandlers = (socket: Socket, userId: string) => {
// Register handlers with the socket
socket.on("input:mousedown", (data) => onMousedown(authSocket, data));
socket.on("input:wheel", (data) => onWheel(authSocket, data));
socket.on("input:mousemove", (data) => onMousemove(authSocket, data));
socket.on("input:keydown", (data) => onKeydown(authSocket, data));
socket.on("input:keyup", (data) => onKeyup(authSocket, data));
socket.on("input:url", (data) => onChangeUrl(authSocket, data));
socket.on("input:refresh", () => onRefresh(authSocket));
socket.on("input:back", () => onGoBack(authSocket));
socket.on("input:forward", () => onGoForward(authSocket));
socket.on("input:date", (data) => onDateSelection(authSocket, data));
socket.on("input:dropdown", (data) => onDropdownSelection(authSocket, data));
socket.on("input:time", (data) => onTimeSelection(authSocket, data));
socket.on("input:datetime-local", (data) => onDateTimeLocalSelection(authSocket, data));
socket.on("action", (data) => onGenerateAction(authSocket, data));
socket.on("input:mousedown", (data) => onMousedown(data, userId));
socket.on("input:wheel", (data) => onWheel(data, userId));
socket.on("input:mousemove", (data) => onMousemove(data, userId));
socket.on("input:keydown", (data) => onKeydown(data, userId));
socket.on("input:keyup", (data) => onKeyup(data, userId));
socket.on("input:url", (data) => onChangeUrl(data, userId));
socket.on("input:refresh", () => onRefresh(userId));
socket.on("input:back", () => onGoBack(userId));
socket.on("input:forward", () => onGoForward(userId));
socket.on("input:date", (data) => onDateSelection(data, userId));
socket.on("input:dropdown", (data) => onDropdownSelection(data, userId));
socket.on("input:time", (data) => onTimeSelection(data, userId));
socket.on("input:datetime-local", (data) => onDateTimeLocalSelection(data, userId));
socket.on("action", (data) => onGenerateAction(data, userId));
};

export default registerInputHandlers;
Loading