Skip to content

Commit e4df78b

Browse files
lp6moonrinormaloku
andauthored
add mcpServer and mcpRequest to context (#45)
* feat: add mcpServer and mcpRequest to context * adds testing --------- Co-authored-by: rinormaloku <[email protected]>
1 parent fbfe561 commit e4df78b

File tree

3 files changed

+100
-29
lines changed

3 files changed

+100
-29
lines changed

src/interfaces/mcp-tool.interface.ts

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
1-
import { Progress } from '@modelcontextprotocol/sdk/types.js';
1+
import {
2+
CallToolRequestSchema,
3+
GetPromptRequestSchema,
4+
Progress,
5+
ReadResourceRequestSchema,
6+
} from '@modelcontextprotocol/sdk/types.js';
7+
import { z } from 'zod';
8+
import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
29

310
export type Literal = boolean | null | number | string | undefined;
411

@@ -7,6 +14,13 @@ export type SerializableValue =
714
| SerializableValue[]
815
| { [key: string]: SerializableValue };
916

17+
export type McpRequestSchema =
18+
| typeof CallToolRequestSchema
19+
| typeof ReadResourceRequestSchema
20+
| typeof GetPromptRequestSchema;
21+
22+
export type McpRequest = z.infer<McpRequestSchema>;
23+
1024
/**
1125
* Enhanced execution context that includes user information
1226
*/
@@ -18,4 +32,6 @@ export type Context = {
1832
info: (message: string, data?: SerializableValue) => void;
1933
warn: (message: string, data?: SerializableValue) => void;
2034
};
35+
mcpServer: McpServer;
36+
mcpRequest: McpRequest;
2137
};

src/services/handlers/mcp-handler.base.ts

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,8 @@
11
import { Logger } from '@nestjs/common';
22
import { ModuleRef } from '@nestjs/core';
33
import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
4-
import { z } from 'zod';
5-
import {
6-
CallToolRequestSchema,
7-
GetPromptRequestSchema,
8-
Progress,
9-
ReadResourceRequestSchema,
10-
} from '@modelcontextprotocol/sdk/types.js';
11-
import {
12-
Context,
13-
SerializableValue,
14-
} from '../../interfaces/mcp-tool.interface';
4+
import { Progress } from '@modelcontextprotocol/sdk/types.js';
5+
import { Context, McpRequest, SerializableValue } from '../../interfaces';
156
import { McpRegistryService } from '../mcp-registry.service';
167

178
export abstract class McpHandlerBase {
@@ -27,15 +18,11 @@ export abstract class McpHandlerBase {
2718

2819
protected createContext(
2920
mcpServer: McpServer,
30-
mcpRequest: z.infer<
31-
| typeof CallToolRequestSchema
32-
| typeof ReadResourceRequestSchema
33-
| typeof GetPromptRequestSchema
34-
>,
21+
mcpRequest: McpRequest,
3522
): Context {
3623
// handless stateless traffic where notifications and progress are not supported
3724
if ((mcpServer.server.transport as any).sessionId === undefined) {
38-
return this.createStatelessContext();
25+
return this.createStatelessContext(mcpServer, mcpRequest);
3926
}
4027

4128
const progressToken = mcpRequest.params?._meta?.progressToken;
@@ -77,10 +64,15 @@ export abstract class McpHandlerBase {
7764
});
7865
},
7966
},
67+
mcpServer,
68+
mcpRequest,
8069
};
8170
}
8271

83-
protected createStatelessContext(): Context {
72+
protected createStatelessContext(
73+
mcpServer: McpServer,
74+
mcpRequest: McpRequest,
75+
): Context {
8476
const warn = (fn: string) => {
8577
this.logger.warn(`Stateless context: '${fn}' is not supported.`);
8678
};
@@ -107,6 +99,8 @@ export abstract class McpHandlerBase {
10799
warn('server report logging not supported in stateless');
108100
},
109101
},
102+
mcpServer,
103+
mcpRequest,
110104
};
111105
}
112106
}

tests/mcp-tool.e2e.spec.ts

Lines changed: 71 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@ export class GreetingTool {
4343
}),
4444
})
4545
async sayHello({ name }, context: Context) {
46+
// Validate that mcpServer and mcpRequest properties exist
47+
if (!context.mcpServer) {
48+
throw new Error('mcpServer is not defined in the context');
49+
}
50+
if (!context.mcpRequest) {
51+
throw new Error('mcpRequest is not defined in the context');
52+
}
53+
4654
const user = await this.userRepository.findByName(name);
4755
for (let i = 0; i < 5; i++) {
4856
await new Promise((resolve) => setTimeout(resolve, 50));
@@ -125,12 +133,15 @@ export class ToolRequestScoped {
125133

126134
describe('E2E: MCP ToolServer', () => {
127135
let app: INestApplication;
128-
let testPort: number;
136+
let statelessApp: INestApplication;
137+
let statefulServerPort: number;
138+
let statelessServerPort: number;
129139

130140
// Set timeout for all tests in this describe block to 15000ms
131141
jest.setTimeout(15000);
132142

133143
beforeAll(async () => {
144+
// Create stateful server (original)
134145
const moduleFixture: TestingModule = await Test.createTestingModule({
135146
imports: [
136147
McpModule.forRoot({
@@ -159,21 +170,63 @@ describe('E2E: MCP ToolServer', () => {
159170
if (!server.address()) {
160171
throw new Error('Server address not found after listen');
161172
}
162-
testPort = (server.address() as import('net').AddressInfo).port;
173+
statefulServerPort = (server.address() as import('net').AddressInfo).port;
174+
175+
// Create stateless server
176+
const statelessModuleFixture: TestingModule =
177+
await Test.createTestingModule({
178+
imports: [
179+
McpModule.forRoot({
180+
name: 'test-stateless-mcp-server',
181+
version: '0.0.1',
182+
guards: [],
183+
transport: McpTransportType.STREAMABLE_HTTP,
184+
streamableHttp: {
185+
enableJsonResponse: true,
186+
sessionIdGenerator: undefined,
187+
statelessMode: true,
188+
},
189+
}),
190+
],
191+
providers: [
192+
GreetingTool,
193+
GreetingToolRequestScoped,
194+
MockUserRepository,
195+
ToolRequestScoped,
196+
],
197+
}).compile();
198+
199+
statelessApp = statelessModuleFixture.createNestApplication();
200+
await statelessApp.listen(0);
201+
202+
const statelessServer = statelessApp.getHttpServer();
203+
if (!statelessServer.address()) {
204+
throw new Error('Stateless server address not found after listen');
205+
}
206+
statelessServerPort = (
207+
statelessServer.address() as import('net').AddressInfo
208+
).port;
163209
});
164210

165211
afterAll(async () => {
166212
await app.close();
213+
await statelessApp.close();
167214
});
168215

169216
const runClientTests = (
170217
clientType: 'http+sse' | 'streamable http' | 'stdio',
171218
clientCreator: (port: number, options?: any) => Promise<Client>,
172219
requestScopedHeaderValue: string,
220+
stateless = false,
173221
) => {
174222
describe(`using ${clientType} client (${clientCreator.name})`, () => {
223+
let port: number;
224+
225+
beforeAll(async () => {
226+
port = stateless ? statelessServerPort : statefulServerPort;
227+
});
175228
it('should list tools', async () => {
176-
const client = await clientCreator(testPort);
229+
const client = await clientCreator(port);
177230
try {
178231
const tools = await client.listTools();
179232
expect(tools.tools.length).toBeGreaterThan(0);
@@ -194,7 +247,7 @@ describe('E2E: MCP ToolServer', () => {
194247
it.each([{ tool: 'hello-world' }, { tool: 'hello-world-scoped' }])(
195248
'should call the tool $tool and receive results',
196249
async ({ tool }) => {
197-
const client = await clientCreator(testPort);
250+
const client = await clientCreator(port);
198251
try {
199252
let progressCount = 1;
200253
const result: any = await client.callTool(
@@ -209,7 +262,7 @@ describe('E2E: MCP ToolServer', () => {
209262
},
210263
);
211264

212-
if (clientType != 'stdio') {
265+
if (clientType != 'stdio' && !stateless) {
213266
// stdio has no support for progress
214267
expect(progressCount).toBe(5);
215268
}
@@ -224,7 +277,7 @@ describe('E2E: MCP ToolServer', () => {
224277
);
225278

226279
it('should call the tool get-request-scoped and receive header', async () => {
227-
const client = await clientCreator(testPort, {
280+
const client = await clientCreator(port, {
228281
requestInit: {
229282
headers: { 'any-header': requestScopedHeaderValue },
230283
},
@@ -243,7 +296,7 @@ describe('E2E: MCP ToolServer', () => {
243296
});
244297

245298
it('should reject invalid arguments for hello-world', async () => {
246-
const client = await clientCreator(testPort);
299+
const client = await clientCreator(port);
247300

248301
try {
249302
await client.callTool({
@@ -259,7 +312,7 @@ describe('E2E: MCP ToolServer', () => {
259312
});
260313

261314
it('should reject missing arguments for hello-world', async () => {
262-
const client = await clientCreator(testPort);
315+
const client = await clientCreator(port);
263316

264317
try {
265318
await client.callTool({
@@ -275,7 +328,7 @@ describe('E2E: MCP ToolServer', () => {
275328
});
276329

277330
it('should call the tool and receive an error', async () => {
278-
const client = await clientCreator(testPort);
331+
const client = await clientCreator(port);
279332
try {
280333
const result: any = await client.callTool({
281334
name: 'hello-world-error',
@@ -297,9 +350,17 @@ describe('E2E: MCP ToolServer', () => {
297350
// Run tests using the HTTP+SSE MCP client
298351
runClientTests('http+sse', createSseClient, 'any-value');
299352

300-
// Run tests using the Streamable HTTP MCP client
353+
// Run tests using the [Stateful] Streamable HTTP MCP client
301354
runClientTests('streamable http', createStreamableClient, 'streamable-value');
302355

356+
// Run tests using the [Stateless] Streamable HTTP MCP client
357+
runClientTests(
358+
'streamable http',
359+
createStreamableClient,
360+
'stateless-streamable-value',
361+
true,
362+
);
363+
303364
runClientTests(
304365
'stdio',
305366
() =>

0 commit comments

Comments
 (0)