Skip to content

[IMP] peek into the JWT to get the channel uuid #21

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 56 additions & 26 deletions src/services/auth.js
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ import { AuthenticationError } from "#src/utils/errors.js";
* @property {string} [jti] - JWT ID
*/

/**
* @typedef {Object} JWTData
* @property {JWTHeader} header - The JWT header
* @property {JWTClaims} claims - The JWT claims
* @property {Buffer} signature - The JWT signature
* @property {string} signedData - The signed data (header + claims)
*/

let jwtKey;
const logger = new Logger("AUTH");
const ALGORITHM = {
Expand Down Expand Up @@ -84,7 +92,7 @@ function base64Decode(str) {
* Signs and creates a JsonWebToken
*
* @param {JWTClaims} claims - The claims to include in the token
* @param {WithImplicitCoercion<string>} [key] - Optional key, defaults to the configured jwtKey
* @param {WithImplicitCoercion<string> | Buffer} [key] - Optional key, defaults to the configured jwtKey
* @param {Object} [options]
* @param {string} [options.algorithm] - The algorithm to use, defaults to HS256
* @returns {string} - The signed JsonWebToken
Expand Down Expand Up @@ -144,31 +152,53 @@ function safeEqual(a, b) {
* @throws {AuthenticationError}
*/
export function verify(jsonWebToken, key = jwtKey) {
const keyBuffer = Buffer.isBuffer(key) ? key : Buffer.from(key, "base64");
let parsedJWT;
try {
parsedJWT = parseJwt(jsonWebToken);
} catch {
throw new AuthenticationError("Invalid JWT format");
}
const { header, claims, signature, signedData } = parsedJWT;
const expectedSignature = ALGORITHM_FUNCTIONS[header.alg]?.(signedData, keyBuffer);
if (!expectedSignature) {
throw new AuthenticationError(`Unsupported algorithm: ${header.alg}`);
}
if (!safeEqual(signature, expectedSignature)) {
throw new AuthenticationError("Invalid signature");
}
// `exp`, `iat` and `nbf` are in seconds (`NumericDate` per RFC7519)
const now = Math.floor(Date.now() / 1000);
if (claims.exp && claims.exp < now) {
throw new AuthenticationError("Token expired");
}
if (claims.nbf && claims.nbf > now) {
throw new AuthenticationError("Token not valid yet");
const jwt = new JsonWebToken(jsonWebToken);
return jwt.verify(key);
}

export class JsonWebToken {
/**
* @type {JWTData}
*/
unsafe;
/**
* @param {string} jsonWebToken
*/
constructor(jsonWebToken) {
let payload;
try {
payload = parseJwt(jsonWebToken);
} catch {
throw new AuthenticationError("Malformed JWT");
}
this.unsafe = payload;
}
if (claims.iat && claims.iat > now + 60) {
throw new AuthenticationError("Token issued in the future");

/**
* @param {WithImplicitCoercion<string>} [key] buffer/b64 str
* @return {JWTClaims}
*/
verify(key = jwtKey) {
const { header, claims, signature, signedData } = this.unsafe;
const keyBuffer = Buffer.isBuffer(key) ? key : Buffer.from(key, "base64");
const expectedSignature = ALGORITHM_FUNCTIONS[header.alg]?.(signedData, keyBuffer);
if (!expectedSignature) {
throw new AuthenticationError(`Unsupported algorithm: ${header.alg}`);
}
if (!safeEqual(signature, expectedSignature)) {
throw new AuthenticationError("Invalid signature");
}
// `exp`, `iat` and `nbf` are in seconds (`NumericDate` per RFC7519)
const now = Math.floor(Date.now() / 1000);
if (claims.exp && claims.exp < now) {
throw new AuthenticationError("Token expired");
}
if (claims.nbf && claims.nbf > now) {
throw new AuthenticationError("Token not valid yet");
}
if (claims.iat && claims.iat > now + 60) {
throw new AuthenticationError("Token issued in the future");
}
return claims;
}
return claims;
}
23 changes: 7 additions & 16 deletions src/services/ws.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ import { Logger, extractRequestInfo } from "#src/utils/utils.js";
import { AuthenticationError, OvercrowdedError } from "#src/utils/errors.js";
import { SESSION_CLOSE_CODE } from "#src/models/session.js";
import { Channel } from "#src/models/channel.js";
import { verify } from "#src/services/auth.js";
import { JsonWebToken } from "#src/services/auth.js";

/**
* @typedef Credentials
* @property {string} channelUUID
* @property {string} channelUUID deprecated, this is obtained from the jwt
* @property {string} jwt
*/

Expand Down Expand Up @@ -53,7 +53,6 @@ export async function start(options) {
/** @type {Credentials | String} can be a string (the jwt) for backwards compatibility with version 1.1 and earlier */
const credentials = JSON.parse(message);
const session = connect(webSocket, {
channelUUID: credentials?.channelUUID,
jwt: credentials.jwt || credentials,
});
session.remote = remoteAddress;
Expand Down Expand Up @@ -102,22 +101,14 @@ export function close() {
* @param {import("ws").WebSocket} webSocket
* @param {Credentials}
*/
function connect(webSocket, { channelUUID, jwt }) {
let channel = Channel.records.get(channelUUID);
const authResult = verify(jwt, channel?.key);
const { sfu_channel_uuid, session_id, ice_servers } = authResult;
if (!channelUUID && sfu_channel_uuid) {
// Cases where the channelUUID is not provided in the credentials for backwards compatibility with version 1.1 and earlier.
channel = Channel.records.get(sfu_channel_uuid);
if (channel.key) {
throw new AuthenticationError(
"A channel with a key can only be accessed by providing a channelUUID in the credentials"
);
}
}
function connect(webSocket, { jwt }) {
const token = new JsonWebToken(jwt);
const channel = Channel.records.get(token.unsafe.claims.sfu_channel_uuid);
if (!channel) {
throw new AuthenticationError(`Channel does not exist`);
}
const authResult = token.verify(channel.key);
const { session_id, ice_servers } = authResult;
if (!session_id) {
throw new AuthenticationError("Malformed JWT payload");
}
Expand Down
2 changes: 1 addition & 1 deletion tests/network.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ describe("Full network", () => {
expect(closeEvent.code).toBe(SESSION_CLOSE_CODE.P_TIMEOUT);
});
test("A client can broadcast arbitrary messages to other clients on a channel that does not have webRTC", async () => {
const channelUUID = await network.getChannelUUID(false);
const channelUUID = await network.getChannelUUID({ useWebRtc: false });
const user1 = await network.connect(channelUUID, 1);
const user2 = await network.connect(channelUUID, 2);
const sender = await network.connect(channelUUID, 3);
Expand Down
13 changes: 13 additions & 0 deletions tests/security.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,17 @@ describe("Security", () => {
const [event] = await once(websocket, "close");
expect(event).toBe(WS_CLOSE_CODE.TIMEOUT);
});
test("cannot use the default jwt key to access a keyed channel", async () => {
const channelUUID = await network.getChannelUUID({ key: "channel-specific-key" });
const channel = Channel.records.get(channelUUID);
await expect(network.connect(channelUUID, 3)).rejects.toThrow();
expect(channel.sessions.size).toBe(0);
});
test("can join a keyed channel with the appropriate key", async () => {
const key = "channel-specific-key";
const channelUUID = await network.getChannelUUID({ key });
const channel = Channel.records.get(channelUUID);
await network.connect(channelUUID, 4, { key });
expect(channel.sessions.size).toBe(1);
});
});
27 changes: 17 additions & 10 deletions tests/utils/network.js
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ import { Channel } from "#src/models/channel.js";
const HMAC_B64_KEY = "u6bsUQEWrHdKIuYplirRnbBmLbrKV5PxKG7DtA71mng=";
const HMAC_KEY = Buffer.from(HMAC_B64_KEY, "base64");

export function makeJwt(data) {
return auth.sign(data, HMAC_KEY, { algorithm: "HS256" });
export function makeJwt(data, { key = HMAC_KEY } = {}) {
return auth.sign(data, key, { algorithm: "HS256" });
}

/**
Expand Down Expand Up @@ -44,13 +44,15 @@ export class LocalNetwork {
}

/**
* @param {boolean} [useWebRtc]
* @param {Object} [param0]
* @param {boolean} [useWebRtc=true]
* @param {string} [key] the channel-specific key
* @returns {Promise<string>}
*/
async getChannelUUID(useWebRtc = true) {
async getChannelUUID({ useWebRtc = true, key = HMAC_B64_KEY } = {}) {
const jwt = this.makeJwt({
iss: `http://${this.hostname}:${this.port}/`,
key: HMAC_B64_KEY,
key,
});
const response = await fetch(
`http://${this.hostname}:${this.port}/v${http.API_VERSION}/channel?webRTC=${useWebRtc}`,
Expand All @@ -70,10 +72,12 @@ export class LocalNetwork {
*
* @param {string} channelUUID
* @param {number} sessionId
* @param {Object} [options]
* @param {string} [options.key] the key to use to authenticate the session (this should be the key of the channel)
* @returns { Promise<{ session: import("#src/models/session.js").Session, sfuClient: import("#src/client.js").SfuClient }>}
* @throws {Error} if the client is closed before being authenticated
*/
async connect(channelUUID, sessionId) {
async connect(channelUUID, sessionId, { key = HMAC_KEY } = {}) {
const sfuClient = new SfuClient();
this._sfuClients.push(sfuClient);
sfuClient._createDevice = () => {
Expand All @@ -100,10 +104,13 @@ export class LocalNetwork {
});
sfuClient.connect(
`ws://${this.hostname}:${this.port}`,
this.makeJwt({
sfu_channel_uuid: channelUUID,
session_id: sessionId,
}),
this.makeJwt(
{
sfu_channel_uuid: channelUUID,
session_id: sessionId,
},
{ key }
),
{ channelUUID }
);
const channel = Channel.records.get(channelUUID);
Expand Down