diff --git a/apps/mesh/migrations/016-downstream-token-client-info.ts b/apps/mesh/migrations/016-downstream-token-client-info.ts new file mode 100644 index 000000000..8b3124ff4 --- /dev/null +++ b/apps/mesh/migrations/016-downstream-token-client-info.ts @@ -0,0 +1,53 @@ +/** + * Migration 015: Add client registration info to downstream_tokens + * + * Adds clientId and clientSecret columns to support Dynamic Client Registration + * and token refresh for downstream MCP OAuth flows. + */ + +import type { Kysely } from "kysely"; + +export async function up(db: Kysely): Promise { + // Add clientId and clientSecret for Dynamic Client Registration + await db.schema + .alterTable("downstream_tokens") + .addColumn("clientId", "text") + .execute(); + + await db.schema + .alterTable("downstream_tokens") + .addColumn("clientSecret", "text") + .execute(); + + // Add tokenEndpoint to know where to refresh + await db.schema + .alterTable("downstream_tokens") + .addColumn("tokenEndpoint", "text") + .execute(); + + // Create index for faster lookups by connectionId + userId + await db.schema + .createIndex("idx_downstream_tokens_connection_user") + .on("downstream_tokens") + .columns(["connectionId", "userId"]) + .execute(); +} + +export async function down(db: Kysely): Promise { + await db.schema.dropIndex("idx_downstream_tokens_connection_user").execute(); + + await db.schema + .alterTable("downstream_tokens") + .dropColumn("tokenEndpoint") + .execute(); + + await db.schema + .alterTable("downstream_tokens") + .dropColumn("clientSecret") + .execute(); + + await db.schema + .alterTable("downstream_tokens") + .dropColumn("clientId") + .execute(); +} diff --git a/apps/mesh/migrations/index.ts b/apps/mesh/migrations/index.ts index 47ff15a59..475088e02 100644 --- a/apps/mesh/migrations/index.ts +++ b/apps/mesh/migrations/index.ts @@ -14,6 +14,7 @@ import * as migration012gatewaytoolselectionmode from "./012-gateway-tool-select import * as migration013monitoringuseragentgateway from "./013-monitoring-user-agent-gateway.ts"; import * as migration014gatewayresourcesprompts from "./014-gateway-resources-prompts.ts"; import * as migration015monitoringproperties from "./015-monitoring-properties.ts"; +import * as migration016downstreamtokenclientinfo from "./016-downstream-token-client-info.ts"; const migrations = { "001-initial-schema": migration001initialschema, @@ -31,6 +32,7 @@ const migrations = { "013-monitoring-user-agent-gateway": migration013monitoringuseragentgateway, "014-gateway-resources-prompts": migration014gatewayresourcesprompts, "015-monitoring-properties": migration015monitoringproperties, + "016-downstream-token-client-info": migration016downstreamtokenclientinfo, } satisfies Record; export default migrations; diff --git a/apps/mesh/src/api/app.ts b/apps/mesh/src/api/app.ts index ba900b39d..a5405f312 100644 --- a/apps/mesh/src/api/app.ts +++ b/apps/mesh/src/api/app.ts @@ -22,6 +22,7 @@ import { shouldSkipMeshContext, SYSTEM_PATHS } from "./utils/paths"; import { createEventBus, type EventBus } from "../event-bus"; import { meter, prometheusExporter, tracer } from "../observability"; import authRoutes from "./routes/auth"; +import downstreamTokenRoutes from "./routes/downstream-token"; import gatewayRoutes from "./routes/gateway"; import managementRoutes from "./routes/management"; import modelsRoutes from "./routes/models"; @@ -550,6 +551,9 @@ export function createApp(options: CreateAppOptions = {}) { return c.json({ success: true }); }); + // Downstream token management routes + app.route("/api", downstreamTokenRoutes); + // ============================================================================ // 404 Handler // ============================================================================ diff --git a/apps/mesh/src/api/routes/downstream-token.test.ts b/apps/mesh/src/api/routes/downstream-token.test.ts new file mode 100644 index 000000000..30606b555 --- /dev/null +++ b/apps/mesh/src/api/routes/downstream-token.test.ts @@ -0,0 +1,95 @@ +import { describe, it, expect, beforeEach, afterEach, mock } from "bun:test"; +import { Hono } from "hono"; +import type { MeshContext } from "../../core/mesh-context"; +import { CredentialVault } from "../../encryption/credential-vault"; +import { + createDatabase, + closeDatabase, + type MeshDatabase, +} from "../../database"; +import { createTestSchema } from "../../storage/test-helpers"; +import downstreamTokenRoutes from "./downstream-token"; + +describe("Downstream Token Routes", () => { + let database: MeshDatabase; + let app: Hono<{ Variables: { meshContext: MeshContext } }>; + + beforeEach(async () => { + database = createDatabase(":memory:"); + await createTestSchema(database.db); + + const vault = new CredentialVault(CredentialVault.generateKey()); + + const ctx = { + db: database.db, + vault, + organization: { id: "org_1" }, + auth: { user: { id: "user_1" } }, + storage: { + connections: { + findById: mock(async () => ({ id: "conn_1" })), + }, + }, + } as unknown as MeshContext; + + app = new Hono(); + app.use("*", async (c, next) => { + c.set("meshContext", ctx); + await next(); + }); + app.route("/", downstreamTokenRoutes); + }); + + afterEach(async () => { + await closeDatabase(database); + mock.restore(); + }); + + it("rejects invalid tokenEndpoint", async () => { + const res = await app.request("/connections/conn_1/oauth-token", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + accessToken: "at", + tokenEndpoint: "not-a-url", + }), + }); + + expect(res.status).toBe(400); + const body = (await res.json()) as { error: string }; + expect(body.error).toBe("tokenEndpoint must be a valid URL"); + }); + + it("rejects non-http(s) tokenEndpoint", async () => { + const res = await app.request("/connections/conn_1/oauth-token", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + accessToken: "at", + tokenEndpoint: "javascript:alert(1)", + }), + }); + + expect(res.status).toBe(400); + const body = (await res.json()) as { error: string }; + expect(body.error).toBe("tokenEndpoint must be an http(s) URL"); + }); + + it("accepts http(s) tokenEndpoint", async () => { + const res = await app.request("/connections/conn_1/oauth-token", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + accessToken: "at", + refreshToken: "rt", + expiresIn: 3600, + tokenEndpoint: "https://example.com/token", + }), + }); + + expect(res.status).toBe(200); + const body = (await res.json()) as { success: boolean; expiresAt: string }; + expect(body.success).toBe(true); + expect(body.expiresAt).toBeTruthy(); + }); +}); diff --git a/apps/mesh/src/api/routes/downstream-token.ts b/apps/mesh/src/api/routes/downstream-token.ts new file mode 100644 index 000000000..07fedd841 --- /dev/null +++ b/apps/mesh/src/api/routes/downstream-token.ts @@ -0,0 +1,163 @@ +/** + * Downstream Token API Routes + * + * Handles OAuth token management for downstream MCP connections. + * Called from frontend after OAuth authentication to persist tokens. + */ + +import { Hono } from "hono"; +import type { MeshContext } from "../../core/mesh-context"; +import { + DownstreamTokenStorage, + type DownstreamTokenData, +} from "../../storage/downstream-token"; + +// Define Hono variables type +type Variables = { + meshContext: MeshContext; +}; + +const app = new Hono<{ Variables: Variables }>(); + +/** + * POST /api/connections/:connectionId/oauth-token + * + * Save OAuth tokens after authentication. + * Called from frontend after OAuth flow completes. + */ +app.post("/connections/:connectionId/oauth-token", async (c) => { + const ctx = c.get("meshContext"); + const connectionId = c.req.param("connectionId"); + + // Require authentication + const userId = ctx.auth.user?.id ?? ctx.auth.apiKey?.userId ?? null; + if (!userId) { + return c.json({ error: "Unauthorized" }, 401); + } + + // Verify connection exists and user has access + // Pass organizationId to ensure the user has access to this connection + // Connections are scoped to organizations, and ctx.storage.connections.findById + // enforces this check if organizationId is provided. + const connection = await ctx.storage.connections.findById( + connectionId, + ctx.organization?.id, + ); + if (!connection) { + return c.json({ error: "Connection not found" }, 404); + } + + // Parse request body + const body = await c.req.json<{ + accessToken: string; + refreshToken?: string | null; + expiresIn?: number | null; + scope?: string | null; + clientId?: string | null; + clientSecret?: string | null; + tokenEndpoint?: string | null; + }>(); + + if (!body.accessToken) { + return c.json({ error: "accessToken is required" }, 400); + } + + if (body.tokenEndpoint) { + let url: URL; + try { + url = new URL(body.tokenEndpoint); + } catch { + return c.json({ error: "tokenEndpoint must be a valid URL" }, 400); + } + + if (url.protocol !== "http:" && url.protocol !== "https:") { + return c.json({ error: "tokenEndpoint must be an http(s) URL" }, 400); + } + } + + // Calculate expiry time + const expiresAt = body.expiresIn + ? new Date(Date.now() + body.expiresIn * 1000) + : null; + + // Create storage instance + const tokenStorage = new DownstreamTokenStorage(ctx.db, ctx.vault); + + // Save token + const tokenData: DownstreamTokenData = { + connectionId, + userId, + accessToken: body.accessToken, + refreshToken: body.refreshToken ?? null, + scope: body.scope ?? null, + expiresAt, + clientId: body.clientId ?? null, + clientSecret: body.clientSecret ?? null, + tokenEndpoint: body.tokenEndpoint ?? null, + }; + + const token = await tokenStorage.upsert(tokenData); + + return c.json({ + success: true, + expiresAt: token.expiresAt, + }); +}); + +/** + * DELETE /api/connections/:connectionId/oauth-token + * + * Delete OAuth token for a connection. + */ +app.delete("/connections/:connectionId/oauth-token", async (c) => { + const ctx = c.get("meshContext"); + const connectionId = c.req.param("connectionId"); + + const userId = ctx.auth.user?.id ?? ctx.auth.apiKey?.userId ?? null; + if (!userId) { + return c.json({ error: "Unauthorized" }, 401); + } + + const tokenStorage = new DownstreamTokenStorage(ctx.db, ctx.vault); + await tokenStorage.delete(connectionId, userId); + + return c.json({ success: true }); +}); + +/** + * GET /api/connections/:connectionId/oauth-token/status + * + * Check if user has a valid cached token for a connection. + */ +app.get("/connections/:connectionId/oauth-token/status", async (c) => { + const ctx = c.get("meshContext"); + const connectionId = c.req.param("connectionId"); + + const userId = ctx.auth.user?.id ?? ctx.auth.apiKey?.userId ?? null; + if (!userId) { + return c.json({ error: "Unauthorized" }, 401); + } + + const tokenStorage = new DownstreamTokenStorage(ctx.db, ctx.vault); + const token = await tokenStorage.get(connectionId, userId); + + if (!token) { + return c.json({ + hasToken: false, + isExpired: true, + canRefresh: false, + }); + } + + const isExpired = tokenStorage.isExpired(token); + const canRefresh = !!token.refreshToken && !!token.tokenEndpoint; + + return c.json({ + hasToken: true, + isExpired, + canRefresh, + expiresAt: token.expiresAt, + }); +}); + +export default app; diff --git a/apps/mesh/src/api/routes/proxy.ts b/apps/mesh/src/api/routes/proxy.ts index eb1f4c252..e7cfe49dc 100644 --- a/apps/mesh/src/api/routes/proxy.ts +++ b/apps/mesh/src/api/routes/proxy.ts @@ -14,7 +14,9 @@ import { extractConnectionPermissions } from "@/auth/configuration-scopes"; import { once } from "@/common"; import { getMonitoringConfig } from "@/core/config"; +import { refreshAccessToken } from "@/oauth/token-refresh"; import { getStableStdioClient } from "@/stdio/stable-transport"; +import { DownstreamTokenStorage } from "@/storage/downstream-token"; import { ConnectionEntity, type HttpConnectionParameters, @@ -249,6 +251,7 @@ async function createMCPProxyDoNotUseDirectly( // Build request headers - reusable for both client and direct fetch // Now issues token lazily on first call + // Also handles token refresh for downstream OAuth tokens const buildRequestHeaders = async (): Promise> => { // Ensure configuration token is issued (lazy) await ensureConfigurationToken(); @@ -257,9 +260,74 @@ async function createMCPProxyDoNotUseDirectly( ...(callerConnectionId ? { "x-caller-id": callerConnectionId } : {}), }; - // Add connection token (already decrypted by storage layer) - if (connection.connection_token) { - headers["Authorization"] = `Bearer ${connection.connection_token}`; + // Try to get cached token from downstream_tokens first + // This supports OAuth token refresh for connections that use OAuth + const userId = ctx.auth.user?.id ?? ctx.auth.apiKey?.userId ?? null; + let accessToken: string | null = null; + + if (userId) { + const tokenStorage = new DownstreamTokenStorage(ctx.db, ctx.vault); + const cachedToken = await tokenStorage.get(connectionId, userId); + + if (cachedToken) { + // Check if token is expired or about to expire + if (tokenStorage.isExpired(cachedToken)) { + // Try to refresh if we have refresh capability + if (cachedToken.refreshToken && cachedToken.tokenEndpoint) { + console.log( + `[Proxy] Token expired for ${connectionId}, attempting refresh`, + ); + const refreshResult = await refreshAccessToken(cachedToken); + + if (refreshResult.success && refreshResult.accessToken) { + // Save refreshed token + await tokenStorage.upsert({ + connectionId, + userId, + accessToken: refreshResult.accessToken, + refreshToken: + refreshResult.refreshToken ?? cachedToken.refreshToken, + scope: refreshResult.scope ?? cachedToken.scope, + expiresAt: refreshResult.expiresIn + ? new Date(Date.now() + refreshResult.expiresIn * 1000) + : null, + clientId: cachedToken.clientId, + clientSecret: cachedToken.clientSecret, + tokenEndpoint: cachedToken.tokenEndpoint, + }); + + accessToken = refreshResult.accessToken; + console.log(`[Proxy] Token refreshed for ${connectionId}`); + } else { + // Refresh failed - token is invalid + // Delete the cached token so user gets prompted to re-auth + await tokenStorage.delete(connectionId, userId); + console.error( + `[Proxy] Token refresh failed for ${connectionId}: ${refreshResult.error}`, + ); + } + } else { + // Token expired but no refresh capability - delete it + await tokenStorage.delete(connectionId, userId); + console.log( + `[Proxy] Token expired without refresh capability for ${connectionId}`, + ); + } + } else { + // Token is still valid + accessToken = cachedToken.accessToken; + } + } + } + + // Fall back to connection token if no cached token + if (!accessToken && connection.connection_token) { + accessToken = connection.connection_token; + } + + // Add authorization header if we have a token + if (accessToken) { + headers["Authorization"] = `Bearer ${accessToken}`; } // Add configuration token if issued diff --git a/apps/mesh/src/oauth/token-refresh.ts b/apps/mesh/src/oauth/token-refresh.ts new file mode 100644 index 000000000..9ec8d6ea4 --- /dev/null +++ b/apps/mesh/src/oauth/token-refresh.ts @@ -0,0 +1,129 @@ +/** + * OAuth Token Refresh Utility + * + * Handles automatic token refresh for downstream MCP connections. + * Uses the refresh_token grant to obtain new access tokens. + */ + +import type { DownstreamToken } from "../storage/types"; + +/** + * Result of a token refresh attempt + */ +export interface TokenRefreshResult { + success: boolean; + accessToken?: string; + refreshToken?: string; + expiresIn?: number; + scope?: string; + error?: string; +} + +/** + * Refresh an OAuth access token using the refresh_token grant + * + * @param token - The downstream token containing refresh info + * @returns TokenRefreshResult with new tokens or error + */ +export async function refreshAccessToken( + token: DownstreamToken, +): Promise { + // Check if we have the required info for refresh + if (!token.refreshToken) { + return { + success: false, + error: "No refresh token available", + }; + } + + if (!token.tokenEndpoint) { + return { + success: false, + error: "No token endpoint available", + }; + } + + if (!token.clientId) { + return { + success: false, + error: "No client ID available", + }; + } + + try { + // Build the token request + const params = new URLSearchParams({ + grant_type: "refresh_token", + refresh_token: token.refreshToken, + client_id: token.clientId, + }); + + // Add client_secret if we have it (some servers require it) + if (token.clientSecret) { + params.set("client_secret", token.clientSecret); + } + + // Add scope if we have it + if (token.scope) { + params.set("scope", token.scope); + } + + // Make the token request + const response = await fetch(token.tokenEndpoint, { + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + Accept: "application/json", + }, + body: params.toString(), + }); + + if (!response.ok) { + const errorBody = await response.text(); + console.error( + `[TokenRefresh] Failed to refresh token: ${response.status}`, + errorBody, + ); + + // Try to parse error response + try { + const errorJson = JSON.parse(errorBody); + return { + success: false, + error: + errorJson.error_description || + errorJson.error || + `Token refresh failed: ${response.status}`, + }; + } catch { + return { + success: false, + error: `Token refresh failed: ${response.status}`, + }; + } + } + + const data = (await response.json()) as { + access_token: string; + refresh_token?: string; + expires_in?: number; + token_type?: string; + scope?: string; + }; + + return { + success: true, + accessToken: data.access_token, + // Some servers return a new refresh token, some don't + refreshToken: data.refresh_token || token.refreshToken, + expiresIn: data.expires_in, + scope: data.scope, + }; + } catch (error) { + console.error("[TokenRefresh] Error refreshing token:", error); + return { + success: false, + error: error instanceof Error ? error.message : "Token refresh failed", + }; + } +} diff --git a/apps/mesh/src/shared/utils/generate-id.ts b/apps/mesh/src/shared/utils/generate-id.ts index 498d684e3..16462d67a 100644 --- a/apps/mesh/src/shared/utils/generate-id.ts +++ b/apps/mesh/src/shared/utils/generate-id.ts @@ -1,6 +1,6 @@ import { nanoid } from "nanoid"; -type IdPrefixes = "conn" | "audit" | "log" | "gw" | "gwc"; +type IdPrefixes = "conn" | "audit" | "log" | "gw" | "gwc" | "dtok"; export function generatePrefixedId(prefix: IdPrefixes) { return `${prefix}_${nanoid()}`; diff --git a/apps/mesh/src/storage/downstream-token.test.ts b/apps/mesh/src/storage/downstream-token.test.ts new file mode 100644 index 000000000..e95a1facd --- /dev/null +++ b/apps/mesh/src/storage/downstream-token.test.ts @@ -0,0 +1,80 @@ +import { describe, it, expect, beforeAll, afterAll } from "bun:test"; +import { createDatabase, closeDatabase, type MeshDatabase } from "../database"; +import { createTestSchema } from "./test-helpers"; +import { CredentialVault } from "../encryption/credential-vault"; +import { + DownstreamTokenStorage, + type DownstreamTokenData, +} from "./downstream-token"; + +describe("DownstreamTokenStorage", () => { + let database: MeshDatabase; + let storage: DownstreamTokenStorage; + + beforeAll(async () => { + database = createDatabase(":memory:"); + await createTestSchema(database.db); + + const vault = new CredentialVault(CredentialVault.generateKey()); + storage = new DownstreamTokenStorage(database.db, vault); + }); + + afterAll(async () => { + await closeDatabase(database); + }); + + it("should fail-safe invalid expiration date as expired", async () => { + const token = { + id: "test", + connectionId: "c1", + userId: "u1", + accessToken: "at", + refreshToken: null, + scope: null, + expiresAt: "invalid-date-string", // Invalid date + createdAt: new Date().toISOString(), + updatedAt: new Date().toISOString(), + clientId: null, + clientSecret: null, + tokenEndpoint: null, + }; + + // Before fix: new Date("invalid").getTime() is NaN. NaN < Date.now() is false. + // After fix: should return true. + expect(storage.isExpired(token)).toBe(true); + }); + + it("should upsert token atomically", async () => { + const data: DownstreamTokenData = { + connectionId: "conn_atomic", + userId: "user_atomic", + accessToken: "access_1", + refreshToken: "refresh_1", + scope: "scope_1", + expiresAt: new Date(Date.now() + 3600000), + clientId: "client_1", + clientSecret: "secret_1", + tokenEndpoint: "https://example.com/token", + }; + + // First insert + const t1 = await storage.upsert(data); + expect(t1.accessToken).toBe("access_1"); + expect(t1.clientId).toBe("client_1"); + + // Update + const data2 = { ...data, accessToken: "access_2", clientId: "client_2" }; + const t2 = await storage.upsert(data2); + + expect(t2.id).toBe(t1.id); // Should update same record + expect(t2.accessToken).toBe("access_2"); + expect(t2.clientId).toBe("client_2"); + + // Check DB count + const count = await database.db + .selectFrom("downstream_tokens") + .select(database.db.fn.count("id").as("c")) + .executeTakeFirst(); + expect(Number(count?.c)).toBe(1); + }); +}); diff --git a/apps/mesh/src/storage/downstream-token.ts b/apps/mesh/src/storage/downstream-token.ts new file mode 100644 index 000000000..4e5568a05 --- /dev/null +++ b/apps/mesh/src/storage/downstream-token.ts @@ -0,0 +1,268 @@ +/** + * Downstream Token Storage Implementation + * + * Handles CRUD operations for downstream MCP OAuth tokens. + * Supports token caching and refresh for OAuth-enabled MCP connections. + */ + +import type { Kysely } from "kysely"; +import type { CredentialVault } from "../encryption/credential-vault"; +import type { Database, DownstreamToken } from "./types"; +import { generatePrefixedId } from "@/shared/utils/generate-id"; + +/** + * Data for creating/updating a downstream token + */ +export interface DownstreamTokenData { + connectionId: string; + userId: string | null; + accessToken: string; + refreshToken: string | null; + scope: string | null; + expiresAt: Date | null; + // Dynamic Client Registration info + clientId: string | null; + clientSecret: string | null; + tokenEndpoint: string | null; +} + +/** + * Port interface for downstream token storage + */ +export interface DownstreamTokenStoragePort { + /** + * Get cached token for a connection + user + */ + get( + connectionId: string, + userId: string | null, + ): Promise; + + /** + * Save or update a token + */ + upsert(data: DownstreamTokenData): Promise; + + /** + * Delete token for a connection + user + */ + delete(connectionId: string, userId: string | null): Promise; + + /** + * Delete all tokens for a connection + */ + deleteByConnection(connectionId: string): Promise; + + /** + * Check if token is expired or will expire within buffer time + */ + isExpired(token: DownstreamToken, bufferMs?: number): boolean; +} + +/** + * Downstream Token Storage Implementation + */ +export class DownstreamTokenStorage implements DownstreamTokenStoragePort { + constructor( + private db: Kysely, + private vault: CredentialVault, + ) {} + + async get( + connectionId: string, + userId: string | null, + ): Promise { + const query = this.db + .selectFrom("downstream_tokens") + .selectAll() + .where("connectionId", "=", connectionId); + + const row = await (userId + ? query.where("userId", "=", userId) + : query.where("userId", "is", null) + ).executeTakeFirst(); + + if (!row) return null; + + return this.decryptToken(row); + } + + async upsert(data: DownstreamTokenData): Promise { + const now = new Date().toISOString(); + + // Encrypt sensitive fields + const encryptedAccessToken = await this.vault.encrypt(data.accessToken); + const encryptedRefreshToken = data.refreshToken + ? await this.vault.encrypt(data.refreshToken) + : null; + const encryptedClientSecret = data.clientSecret + ? await this.vault.encrypt(data.clientSecret) + : null; + + // Use transaction to prevent race conditions during upsert + return await this.db.transaction().execute(async (trx) => { + // Check for existing token within transaction + const query = trx + .selectFrom("downstream_tokens") + .select(["id", "createdAt"]) + .where("connectionId", "=", data.connectionId); + + const existing = await (data.userId + ? query.where("userId", "=", data.userId) + : query.where("userId", "is", null) + ).executeTakeFirst(); + + if (existing) { + // Update existing token + await trx + .updateTable("downstream_tokens") + .set({ + accessToken: encryptedAccessToken, + refreshToken: encryptedRefreshToken, + scope: data.scope, + expiresAt: data.expiresAt?.toISOString() ?? null, + clientId: data.clientId, + clientSecret: encryptedClientSecret, + tokenEndpoint: data.tokenEndpoint, + updatedAt: now, + }) + .where("id", "=", existing.id) + .execute(); + + return { + id: existing.id, + connectionId: data.connectionId, + userId: data.userId, + accessToken: data.accessToken, + refreshToken: data.refreshToken, + scope: data.scope, + expiresAt: data.expiresAt, + createdAt: existing.createdAt as unknown as string, + updatedAt: now, + clientId: data.clientId, + clientSecret: data.clientSecret, + tokenEndpoint: data.tokenEndpoint, + }; + } + + // Create new token + const id = generatePrefixedId("dtok"); + + await trx + .insertInto("downstream_tokens") + .values({ + id, + connectionId: data.connectionId, + userId: data.userId, + accessToken: encryptedAccessToken, + refreshToken: encryptedRefreshToken, + scope: data.scope, + expiresAt: data.expiresAt?.toISOString() ?? null, + clientId: data.clientId, + clientSecret: encryptedClientSecret, + tokenEndpoint: data.tokenEndpoint, + createdAt: now as unknown as string, + updatedAt: now as unknown as string, + }) + .execute(); + + return { + id, + connectionId: data.connectionId, + userId: data.userId, + accessToken: data.accessToken, + refreshToken: data.refreshToken, + scope: data.scope, + expiresAt: data.expiresAt, + createdAt: now as unknown as string, + updatedAt: now as unknown as string, + clientId: data.clientId, + clientSecret: data.clientSecret, + tokenEndpoint: data.tokenEndpoint, + }; + }); + } + + async delete(connectionId: string, userId: string | null): Promise { + const query = this.db + .deleteFrom("downstream_tokens") + .where("connectionId", "=", connectionId); + + await (userId + ? query.where("userId", "=", userId) + : query.where("userId", "is", null) + ).execute(); + } + + async deleteByConnection(connectionId: string): Promise { + await this.db + .deleteFrom("downstream_tokens") + .where("connectionId", "=", connectionId) + .execute(); + } + + /** + * Check if token is expired or will expire within buffer time + * Default buffer is 5 minutes to account for clock skew and request time + */ + isExpired(token: DownstreamToken, bufferMs: number = 5 * 60 * 1000): boolean { + if (!token.expiresAt) { + // No expiry = never expires + return false; + } + + const expiresAt = + token.expiresAt instanceof Date + ? token.expiresAt + : new Date(token.expiresAt); + + const expiryTime = expiresAt.getTime(); + if (Number.isNaN(expiryTime)) { + // Fail-safe: if date is invalid, treat as expired + return true; + } + + return expiryTime - bufferMs < Date.now(); + } + + /** + * Decrypt sensitive fields from a database row + */ + private async decryptToken(row: { + id: string; + connectionId: string; + userId: string | null; + accessToken: string; + refreshToken: string | null; + scope: string | null; + expiresAt: Date | string | null; + createdAt: Date | string; + updatedAt: Date | string; + clientId: string | null; + clientSecret: string | null; + tokenEndpoint: string | null; + }): Promise { + const accessToken = await this.vault.decrypt(row.accessToken); + const refreshToken = row.refreshToken + ? await this.vault.decrypt(row.refreshToken) + : null; + const clientSecret = row.clientSecret + ? await this.vault.decrypt(row.clientSecret) + : null; + + return { + id: row.id, + connectionId: row.connectionId, + userId: row.userId, + accessToken, + refreshToken, + scope: row.scope, + expiresAt: row.expiresAt, + createdAt: row.createdAt, + updatedAt: row.updatedAt, + clientId: row.clientId, + clientSecret, + tokenEndpoint: row.tokenEndpoint, + }; + } +} diff --git a/apps/mesh/src/storage/types.ts b/apps/mesh/src/storage/types.ts index fae142451..0e8790cd8 100644 --- a/apps/mesh/src/storage/types.ts +++ b/apps/mesh/src/storage/types.ts @@ -268,9 +268,13 @@ export interface DownstreamTokenTable { accessToken: string; // Encrypted refreshToken: string | null; // Encrypted scope: string | null; - expiresAt: ColumnType | null; + expiresAt: ColumnType | null; createdAt: ColumnType; updatedAt: ColumnType; + // Dynamic Client Registration info (for token refresh) + clientId: string | null; + clientSecret: string | null; // Encrypted + tokenEndpoint: string | null; } // ============================================================================ @@ -333,6 +337,10 @@ export interface DownstreamToken { expiresAt: Date | string | null; createdAt: Date | string; updatedAt: Date | string; + // Dynamic Client Registration info (for token refresh) + clientId: string | null; + clientSecret: string | null; + tokenEndpoint: string | null; } // ============================================================================ diff --git a/apps/mesh/src/web/components/details/connection/settings-tab/index.tsx b/apps/mesh/src/web/components/details/connection/settings-tab/index.tsx index d73813368..e5d63c621 100644 --- a/apps/mesh/src/web/components/details/connection/settings-tab/index.tsx +++ b/apps/mesh/src/web/components/details/connection/settings-tab/index.tsx @@ -502,7 +502,7 @@ function SettingsTabContentImpl(props: SettingsTabContentImplProps) { }; const handleAuthenticate = async () => { - const { token, error } = await authenticateMcp({ + const { token, tokenInfo, error } = await authenticateMcp({ connectionId: connection.id, }); if (error || !token) { @@ -510,10 +510,49 @@ function SettingsTabContentImpl(props: SettingsTabContentImplProps) { return; } - await connectionActions.update.mutateAsync({ - id: connection.id, - data: { connection_token: token }, - }); + // Save token via new API (supports refresh tokens) + if (tokenInfo) { + try { + const response = await fetch( + `/api/connections/${connection.id}/oauth-token`, + { + method: "POST", + headers: { "Content-Type": "application/json" }, + credentials: "include", + body: JSON.stringify({ + accessToken: tokenInfo.accessToken, + refreshToken: tokenInfo.refreshToken, + expiresIn: tokenInfo.expiresIn, + scope: tokenInfo.scope, + clientId: tokenInfo.clientId, + clientSecret: tokenInfo.clientSecret, + tokenEndpoint: tokenInfo.tokenEndpoint, + }), + }, + ); + if (!response.ok) { + console.error("Failed to save OAuth token:", await response.text()); + // Fall back to connection_token update + await connectionActions.update.mutateAsync({ + id: connection.id, + data: { connection_token: token }, + }); + } + } catch (err) { + console.error("Error saving OAuth token:", err); + // Fall back to connection_token update + await connectionActions.update.mutateAsync({ + id: connection.id, + data: { connection_token: token }, + }); + } + } else { + // No tokenInfo, fall back to legacy behavior + await connectionActions.update.mutateAsync({ + id: connection.id, + data: { connection_token: token }, + }); + } // Invalidate auth status query to trigger UI refresh const mcpProxyUrl = new URL( diff --git a/apps/mesh/src/web/lib/mcp-oauth.ts b/apps/mesh/src/web/lib/mcp-oauth.ts index 407c09ab9..d508ca775 100644 --- a/apps/mesh/src/web/lib/mcp-oauth.ts +++ b/apps/mesh/src/web/lib/mcp-oauth.ts @@ -170,14 +170,40 @@ class McpOAuthProvider implements OAuthClientProvider { } } +/** + * Full OAuth token info for persistence + */ +export interface OAuthTokenInfo { + accessToken: string; + refreshToken: string | null; + expiresIn: number | null; + scope: string | null; + // Dynamic Client Registration info + clientId: string | null; + clientSecret: string | null; + tokenEndpoint: string | null; +} + /** * Result from authenticateMcp */ export interface AuthenticateMcpResult { token: string | null; + /** Full token info for persistence (includes refresh token) */ + tokenInfo: OAuthTokenInfo | null; error: string | null; } +/** + * Extended token result with all info needed for persistence + */ +interface FullTokenResult { + tokens: OAuthTokens; + clientId: string | null; + clientSecret: string | null; + tokenEndpoint: string | null; +} + /** * Authenticate with an MCP server using OAuth * @param serverUrl - Full MCP server URL to authenticate with @@ -206,142 +232,179 @@ export async function authenticateMcp(params: { try { // Wait for OAuth callback message from popup and handle token exchange // Uses both postMessage (primary) and localStorage (fallback for when opener is lost) - const oauthCompletePromise = new Promise((resolve, reject) => { - const timeout = params.timeout || 120000; - let timeoutId: ReturnType; - let resolved = false; - // Use the OAuth state as the storage key - it's already unique per flow - // and will be available to the callback page via URL params - const oauthState = provider.state(); - const storageKey = `${OAUTH_CALLBACK_STORAGE_KEY}${oauthState}`; - - const cleanup = () => { - if (resolved) return; - resolved = true; - window.removeEventListener("message", handleMessage); - window.removeEventListener("storage", handleStorageEvent); - clearTimeout(timeoutId); - // Clean up storage key - try { - localStorage.removeItem(storageKey); - } catch { - // Ignore storage errors - } - }; - - const processCallback = async (data: { - success: boolean; - code?: string; - state?: string; - error?: string; - }) => { - if (resolved) return; + const oauthCompletePromise = new Promise( + (resolve, reject) => { + const timeout = params.timeout || 120000; + let timeoutId: ReturnType; + let resolved = false; + // Use the OAuth state as the storage key - it's already unique per flow + // and will be available to the callback page via URL params + const oauthState = provider.state(); + const storageKey = `${OAUTH_CALLBACK_STORAGE_KEY}${oauthState}`; + + const cleanup = () => { + if (resolved) return; + resolved = true; + window.removeEventListener("message", handleMessage); + window.removeEventListener("storage", handleStorageEvent); + clearTimeout(timeoutId); + // Clean up storage key + try { + localStorage.removeItem(storageKey); + } catch { + // Ignore storage errors + } + }; - if (!data.success) { - cleanup(); - reject(new Error(data.error || "OAuth authentication failed")); - return; - } + const processCallback = async (data: { + success: boolean; + code?: string; + state?: string; + error?: string; + }) => { + if (resolved) return; - const { code, state } = data; + if (!data.success) { + cleanup(); + reject(new Error(data.error || "OAuth authentication failed")); + return; + } - if (!code) { - cleanup(); - reject(new Error("Missing authorization code")); - return; - } + const { code, state } = data; - // Verify state matches - const storedState = provider.getStoredState(); - if (storedState !== state) { - cleanup(); - reject(new Error("OAuth state mismatch - possible CSRF attack")); - return; - } - - try { - // Do token exchange in parent window (we have provider in memory) - const resourceMetadata = - await discoverOAuthProtectedResourceMetadata(serverUrl); - const authServerUrl = - resourceMetadata?.authorization_servers?.[0] || serverUrl; - const authServerMetadata = - await discoverAuthorizationServerMetadata(authServerUrl); - - const clientInfo = provider.clientInformation(); - if (!clientInfo) { + if (!code) { cleanup(); - reject(new Error("Client information not found")); + reject(new Error("Missing authorization code")); return; } - const codeVerifier = provider.codeVerifier(); + // Verify state matches + const storedState = provider.getStoredState(); + if (storedState !== state) { + cleanup(); + reject(new Error("OAuth state mismatch - possible CSRF attack")); + return; + } - const tokens = await exchangeAuthorization(authServerUrl, { - metadata: authServerMetadata, - clientInformation: clientInfo, - authorizationCode: code, - codeVerifier, - redirectUri: provider.redirectUrl, - resource: new URL(serverUrl), - }); + try { + // Do token exchange in parent window (we have provider in memory) + const resourceMetadata = + await discoverOAuthProtectedResourceMetadata(serverUrl); + const authServerUrl = + resourceMetadata?.authorization_servers?.[0] || serverUrl; + const authServerMetadata = + await discoverAuthorizationServerMetadata(authServerUrl); + + const clientInfo = provider.clientInformation(); + if (!clientInfo) { + cleanup(); + reject(new Error("Client information not found")); + return; + } + + const codeVerifier = provider.codeVerifier(); + + const tokens = await exchangeAuthorization(authServerUrl, { + metadata: authServerMetadata, + clientInformation: clientInfo, + authorizationCode: code, + codeVerifier, + redirectUri: provider.redirectUrl, + resource: new URL(serverUrl), + }); - cleanup(); - resolve(tokens); - } catch (err) { - cleanup(); - reject(err); - } - }; + cleanup(); - // Primary: Listen for postMessage from popup - const handleMessage = async (event: MessageEvent) => { - if (event.origin !== window.location.origin) return; - if (event.data?.type === "mcp:oauth:callback") { - await processCallback(event.data); - } - }; + // Resolve with full result including client info for token refresh + resolve({ + tokens, + clientId: clientInfo.client_id ?? null, + clientSecret: + "client_secret" in clientInfo + ? (clientInfo.client_secret as string) + : null, + tokenEndpoint: authServerMetadata?.token_endpoint ?? null, + }); + } catch (err) { + cleanup(); + reject(err); + } + }; - // Fallback: Listen for localStorage events (when window.opener is lost) - const handleStorageEvent = async (event: StorageEvent) => { - if (event.key !== storageKey || !event.newValue) return; - try { - const data = JSON.parse(event.newValue); - await processCallback(data); - } catch { - // Ignore parse errors - } - }; + // Primary: Listen for postMessage from popup + const handleMessage = async (event: MessageEvent) => { + if (event.origin !== window.location.origin) return; + if (event.data?.type === "mcp:oauth:callback") { + await processCallback(event.data); + } + }; + + // Fallback: Listen for localStorage events (when window.opener is lost) + const handleStorageEvent = async (event: StorageEvent) => { + if (event.key !== storageKey || !event.newValue) return; + try { + const data = JSON.parse(event.newValue); + await processCallback(data); + } catch { + // Ignore parse errors + } + }; - window.addEventListener("message", handleMessage); - window.addEventListener("storage", handleStorageEvent); + window.addEventListener("message", handleMessage); + window.addEventListener("storage", handleStorageEvent); - timeoutId = setTimeout(() => { - cleanup(); - reject(new Error("OAuth authentication timeout")); - }, timeout); - }); + timeoutId = setTimeout(() => { + cleanup(); + reject(new Error("OAuth authentication timeout")); + }, timeout); + }, + ); // Start the auth flow const result: AuthResult = await auth(provider, { serverUrl }); if (result === "REDIRECT") { - const tokens = await oauthCompletePromise; + const fullResult = await oauthCompletePromise; return { - token: tokens.access_token, + token: fullResult.tokens.access_token, + tokenInfo: { + accessToken: fullResult.tokens.access_token, + refreshToken: fullResult.tokens.refresh_token ?? null, + expiresIn: fullResult.tokens.expires_in ?? null, + scope: fullResult.tokens.scope ?? null, + clientId: fullResult.clientId, + clientSecret: fullResult.clientSecret, + tokenEndpoint: fullResult.tokenEndpoint, + }, error: null, }; } // If we got here without redirect, check for tokens const tokens = provider.tokens(); + const clientInfo = provider.clientInformation(); return { token: tokens?.access_token || null, + tokenInfo: tokens + ? { + accessToken: tokens.access_token, + refreshToken: tokens.refresh_token ?? null, + expiresIn: tokens.expires_in ?? null, + scope: tokens.scope ?? null, + clientId: clientInfo?.client_id ?? null, + clientSecret: + clientInfo && "client_secret" in clientInfo + ? (clientInfo.client_secret as string) + : null, + tokenEndpoint: null, // Would need to be passed through + } + : null, error: null, }; } catch (error) { return { token: null, + tokenInfo: null, error: error instanceof Error ? error.message : String(error), }; } finally { diff --git a/apps/mesh/src/web/utils/constants.ts b/apps/mesh/src/web/utils/constants.ts index bc8cbcff5..2fc35e60f 100644 --- a/apps/mesh/src/web/utils/constants.ts +++ b/apps/mesh/src/web/utils/constants.ts @@ -24,5 +24,5 @@ export type JsonSchema = { */ export const BaseCollectionJsonSchema: JsonSchema = z.toJSONSchema( BaseCollectionEntitySchema, - { target: "draft-07" }, + { target: "draft-7" }, ) as JsonSchema;