diff --git a/src/client/cross-spawn.test.ts b/src/client/cross-spawn.test.ts index 11e81bf6..724ec706 100644 --- a/src/client/cross-spawn.test.ts +++ b/src/client/cross-spawn.test.ts @@ -1,4 +1,4 @@ -import { StdioClientTransport } from "./stdio.js"; +import { getDefaultEnvironment, StdioClientTransport } from "./stdio.js"; import spawn from "cross-spawn"; import { JSONRPCMessage } from "../types.js"; import { ChildProcess } from "node:child_process"; @@ -72,7 +72,10 @@ describe("StdioClientTransport using cross-spawn", () => { "test-command", [], expect.objectContaining({ - env: customEnv + env: { + ...getDefaultEnvironment(), + ...customEnv, + } }) ); }); diff --git a/src/client/stdio.test.ts b/src/client/stdio.test.ts index b2132446..0e92eac1 100644 --- a/src/client/stdio.test.ts +++ b/src/client/stdio.test.ts @@ -1,10 +1,24 @@ import { JSONRPCMessage } from "../types.js"; -import { StdioClientTransport, StdioServerParameters } from "./stdio.js"; +import { StdioClientTransport, StdioServerParameters, DEFAULT_INHERITED_ENV_VARS, getDefaultEnvironment } from "./stdio.js"; +import { AsyncLocalStorage } from "node:async_hooks"; const serverParameters: StdioServerParameters = { command: "/usr/bin/tee", }; +const envAsyncLocalStorage = new AsyncLocalStorage<{ env: Record }>(); + +jest.mock('cross-spawn', () => { + const originalSpawn = jest.requireActual('cross-spawn'); + return jest.fn((command, args, options) => { + const env = envAsyncLocalStorage.getStore(); + if (env) { + env.env = options.env; + } + return originalSpawn(command, args, options); + }); +}); + test("should start then close cleanly", async () => { const client = new StdioClientTransport(serverParameters); client.onerror = (error) => { @@ -60,6 +74,62 @@ test("should read messages", async () => { await client.close(); }); + +test("should properly set default environment variables in spawned process", async () => { + await envAsyncLocalStorage.run({ env: {} }, async () => { + const client = new StdioClientTransport(serverParameters); + + await client.start(); + await client.close(); + + // Get the default environment variables + const defaultEnv = getDefaultEnvironment(); + const spawnEnv = envAsyncLocalStorage.getStore()?.env; + expect(spawnEnv).toBeDefined(); + // Verify that all default environment variables are present + for (const key of DEFAULT_INHERITED_ENV_VARS) { + if (process.env[key] && !process.env[key].startsWith("()")) { + expect(spawnEnv).toHaveProperty(key); + expect(spawnEnv![key]).toBe(process.env[key]); + expect(spawnEnv![key]).toBe(defaultEnv[key]); + } + } + }); +}); + +test("should override default environment variables with custom ones", async () => { + await envAsyncLocalStorage.run({ env: {} }, async () => { + const customEnv = { + HOME: "/custom/home", + PATH: "/custom/path", + USER: "custom_user" + }; + + const client = new StdioClientTransport({ + ...serverParameters, + env: customEnv + }); + + await client.start(); + await client.close(); + + const spawnEnv = envAsyncLocalStorage.getStore()?.env; + expect(spawnEnv).toBeDefined(); + // Verify that custom environment variables override default ones + for (const [key, value] of Object.entries(customEnv)) { + expect(spawnEnv).toHaveProperty(key); + expect(spawnEnv![key]).toBe(value); + } + + // Verify that other default environment variables are still present + for (const key of DEFAULT_INHERITED_ENV_VARS) { + if (!(key in customEnv) && process.env[key] && !process.env[key].startsWith("()")) { + expect(spawnEnv).toHaveProperty(key); + expect(spawnEnv![key]).toBe(process.env[key]); + } + } + }); + test("should return child process pid", async () => { const client = new StdioClientTransport(serverParameters); @@ -67,4 +137,5 @@ test("should return child process pid", async () => { expect(client.pid).not.toBeNull(); await client.close(); expect(client.pid).toBeNull(); + }); diff --git a/src/client/stdio.ts b/src/client/stdio.ts index e9c9fa8f..62292ce1 100644 --- a/src/client/stdio.ts +++ b/src/client/stdio.ts @@ -122,7 +122,11 @@ export class StdioClientTransport implements Transport { this._serverParams.command, this._serverParams.args ?? [], { - env: this._serverParams.env ?? getDefaultEnvironment(), + // merge default env with server env because mcp server needs some env vars + env: { + ...getDefaultEnvironment(), + ...this._serverParams.env, + }, stdio: ["pipe", "pipe", this._serverParams.stderr ?? "inherit"], shell: false, signal: this._abortController.signal,