Skip to content

Commit 17a9427

Browse files
Backfills trace injection test and migrates from request.params to request.params._meta (#360)
* Adds test to show where trace IDs are encoded Signed-off-by: Adrian Cole <[email protected]> * update impl to params._meta Signed-off-by: Adrian Cole <[email protected]> --------- Signed-off-by: Adrian Cole <[email protected]>
1 parent 2eeb61f commit 17a9427

File tree

2 files changed

+54
-16
lines changed

2 files changed

+54
-16
lines changed

src/ModelContextProtocol/Diagnostics.cs

+12-12
Original file line numberDiff line numberDiff line change
@@ -50,22 +50,19 @@ private static void ExtractContext(object? message, string fieldName, out string
5050
fieldValues = null;
5151
fieldValue = null;
5252

53-
JsonNode? parameters = null;
53+
JsonNode? meta = null;
5454
switch (message)
5555
{
5656
case JsonRpcRequest request:
57-
parameters = request.Params;
57+
meta = request.Params?["_meta"];
5858
break;
5959

6060
case JsonRpcNotification notification:
61-
parameters = notification.Params;
62-
break;
63-
64-
default:
61+
meta = notification.Params?["_meta"];
6562
break;
6663
}
6764

68-
if (parameters?[fieldName] is JsonValue value && value.GetValueKind() == JsonValueKind.String)
65+
if (meta?[fieldName] is JsonValue value && value.GetValueKind() == JsonValueKind.String)
6966
{
7067
fieldValue = value.GetValue<string>();
7168
}
@@ -89,14 +86,17 @@ private static void InjectContext(object? message, string key, string value)
8986
case JsonRpcNotification notification:
9087
parameters = notification.Params;
9188
break;
92-
93-
default:
94-
break;
9589
}
9690

97-
if (parameters is JsonObject jsonObject && jsonObject[key] == null)
91+
// Replace any params._meta with the current value
92+
if (parameters is JsonObject jsonObject)
9893
{
99-
jsonObject[key] = value;
94+
if (jsonObject["_meta"] is not JsonObject meta)
95+
{
96+
meta = new JsonObject();
97+
jsonObject["_meta"] = meta;
98+
}
99+
meta[key] = value;
100100
}
101101
}
102102

tests/ModelContextProtocol.Tests/DiagnosticTests.cs

+42-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
using OpenTelemetry.Trace;
55
using System.Diagnostics;
66
using System.IO.Pipelines;
7+
using System.Text;
8+
using System.Text.Json;
79

810
namespace ModelContextProtocol.Tests;
911

@@ -14,6 +16,7 @@ public class DiagnosticTests
1416
public async Task Session_TracksActivities()
1517
{
1618
var activities = new List<Activity>();
19+
var clientToServerLog = new List<string>();
1720

1821
using (var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder()
1922
.AddSource("Experimental.ModelContextProtocol")
@@ -28,7 +31,7 @@ await RunConnected(async (client, server) =>
2831

2932
var tool = tools.First(t => t.Name == "DoubleValue");
3033
await tool.InvokeAsync(new() { ["amount"] = 42 }, TestContext.Current.CancellationToken);
31-
});
34+
}, clientToServerLog);
3235
}
3336

3437
Assert.NotEmpty(activities);
@@ -64,6 +67,11 @@ await RunConnected(async (client, server) =>
6467

6568
Assert.Equal(clientListToolsCall.SpanId, serverListToolsCall.ParentSpanId);
6669
Assert.Equal(clientListToolsCall.TraceId, serverListToolsCall.TraceId);
70+
71+
// Validate that the client trace context encoded to request.params._meta[traceparent]
72+
using var listToolsJson = JsonDocument.Parse(clientToServerLog.First(s => s.Contains("\"method\":\"tools/list\"")));
73+
var metaJson = listToolsJson.RootElement.GetProperty("params").GetProperty("_meta").GetRawText();
74+
Assert.Equal($$"""{"traceparent":"00-{{clientListToolsCall.TraceId}}-{{clientListToolsCall.SpanId}}-01"}""", metaJson);
6775
}
6876

6977
[Fact]
@@ -80,7 +88,7 @@ await RunConnected(async (client, server) =>
8088
{
8189
await client.CallToolAsync("Throw", cancellationToken: TestContext.Current.CancellationToken);
8290
await Assert.ThrowsAsync<McpException>(() => client.CallToolAsync("does-not-exist", cancellationToken: TestContext.Current.CancellationToken));
83-
});
91+
}, new List<string>());
8492
}
8593

8694
Assert.NotEmpty(activities);
@@ -120,11 +128,12 @@ await RunConnected(async (client, server) =>
120128
Assert.Equal("-32602", doesNotExistToolClient.Tags.Single(t => t.Key == "rpc.jsonrpc.error_code").Value);
121129
}
122130

123-
private static async Task RunConnected(Func<IMcpClient, IMcpServer, Task> action)
131+
private static async Task RunConnected(Func<IMcpClient, IMcpServer, Task> action, List<string> clientToServerLog)
124132
{
125133
Pipe clientToServerPipe = new(), serverToClientPipe = new();
126134
StreamServerTransport serverTransport = new(clientToServerPipe.Reader.AsStream(), serverToClientPipe.Writer.AsStream());
127-
StreamClientTransport clientTransport = new(clientToServerPipe.Writer.AsStream(), serverToClientPipe.Reader.AsStream());
135+
StreamClientTransport clientTransport = new(new LoggingStream(
136+
clientToServerPipe.Writer.AsStream(), clientToServerLog.Add), serverToClientPipe.Reader.AsStream());
128137

129138
Task serverTask;
130139

@@ -155,3 +164,32 @@ private static async Task RunConnected(Func<IMcpClient, IMcpServer, Task> action
155164
await serverTask;
156165
}
157166
}
167+
168+
public class LoggingStream : Stream
169+
{
170+
private readonly Stream _innerStream;
171+
private readonly Action<string> _logAction;
172+
173+
public LoggingStream(Stream innerStream, Action<string> logAction)
174+
{
175+
_innerStream = innerStream ?? throw new ArgumentNullException(nameof(innerStream));
176+
_logAction = logAction ?? throw new ArgumentNullException(nameof(logAction));
177+
}
178+
179+
public override void Write(byte[] buffer, int offset, int count)
180+
{
181+
var data = Encoding.UTF8.GetString(buffer, offset, count);
182+
_logAction(data);
183+
_innerStream.Write(buffer, offset, count);
184+
}
185+
186+
public override bool CanRead => _innerStream.CanRead;
187+
public override bool CanSeek => _innerStream.CanSeek;
188+
public override bool CanWrite => _innerStream.CanWrite;
189+
public override long Length => _innerStream.Length;
190+
public override long Position { get => _innerStream.Position; set => _innerStream.Position = value; }
191+
public override void Flush() => _innerStream.Flush();
192+
public override int Read(byte[] buffer, int offset, int count) => _innerStream.Read(buffer, offset, count);
193+
public override long Seek(long offset, SeekOrigin origin) => _innerStream.Seek(offset, origin);
194+
public override void SetLength(long value) => _innerStream.SetLength(value);
195+
}

0 commit comments

Comments
 (0)