diff --git a/src/ModelContextProtocol/Diagnostics.cs b/src/ModelContextProtocol/Diagnostics.cs index 52d82e20..d4d0ef57 100644 --- a/src/ModelContextProtocol/Diagnostics.cs +++ b/src/ModelContextProtocol/Diagnostics.cs @@ -50,22 +50,19 @@ private static void ExtractContext(object? message, string fieldName, out string fieldValues = null; fieldValue = null; - JsonNode? parameters = null; + JsonNode? meta = null; switch (message) { case JsonRpcRequest request: - parameters = request.Params; + meta = request.Params?["_meta"]; break; case JsonRpcNotification notification: - parameters = notification.Params; - break; - - default: + meta = notification.Params?["_meta"]; break; } - if (parameters?[fieldName] is JsonValue value && value.GetValueKind() == JsonValueKind.String) + if (meta?[fieldName] is JsonValue value && value.GetValueKind() == JsonValueKind.String) { fieldValue = value.GetValue(); } @@ -89,14 +86,17 @@ private static void InjectContext(object? message, string key, string value) case JsonRpcNotification notification: parameters = notification.Params; break; - - default: - break; } - if (parameters is JsonObject jsonObject && jsonObject[key] == null) + // Replace any params._meta with the current value + if (parameters is JsonObject jsonObject) { - jsonObject[key] = value; + if (jsonObject["_meta"] is not JsonObject meta) + { + meta = new JsonObject(); + jsonObject["_meta"] = meta; + } + meta[key] = value; } } diff --git a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs index b6355f79..3fe69275 100644 --- a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs +++ b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs @@ -4,6 +4,8 @@ using OpenTelemetry.Trace; using System.Diagnostics; using System.IO.Pipelines; +using System.Text; +using System.Text.Json; namespace ModelContextProtocol.Tests; @@ -14,6 +16,7 @@ public class DiagnosticTests public async Task Session_TracksActivities() { var activities = new List(); + var clientToServerLog = new List(); using (var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder() .AddSource("Experimental.ModelContextProtocol") @@ -28,7 +31,7 @@ await RunConnected(async (client, server) => var tool = tools.First(t => t.Name == "DoubleValue"); await tool.InvokeAsync(new() { ["amount"] = 42 }, TestContext.Current.CancellationToken); - }); + }, clientToServerLog); } Assert.NotEmpty(activities); @@ -64,6 +67,11 @@ await RunConnected(async (client, server) => Assert.Equal(clientListToolsCall.SpanId, serverListToolsCall.ParentSpanId); Assert.Equal(clientListToolsCall.TraceId, serverListToolsCall.TraceId); + + // Validate that the client trace context encoded to request.params._meta[traceparent] + using var listToolsJson = JsonDocument.Parse(clientToServerLog.First(s => s.Contains("\"method\":\"tools/list\""))); + var metaJson = listToolsJson.RootElement.GetProperty("params").GetProperty("_meta").GetRawText(); + Assert.Equal($$"""{"traceparent":"00-{{clientListToolsCall.TraceId}}-{{clientListToolsCall.SpanId}}-01"}""", metaJson); } [Fact] @@ -80,7 +88,7 @@ await RunConnected(async (client, server) => { await client.CallToolAsync("Throw", cancellationToken: TestContext.Current.CancellationToken); await Assert.ThrowsAsync(() => client.CallToolAsync("does-not-exist", cancellationToken: TestContext.Current.CancellationToken)); - }); + }, new List()); } Assert.NotEmpty(activities); @@ -120,11 +128,12 @@ await RunConnected(async (client, server) => Assert.Equal("-32602", doesNotExistToolClient.Tags.Single(t => t.Key == "rpc.jsonrpc.error_code").Value); } - private static async Task RunConnected(Func action) + private static async Task RunConnected(Func action, List clientToServerLog) { Pipe clientToServerPipe = new(), serverToClientPipe = new(); StreamServerTransport serverTransport = new(clientToServerPipe.Reader.AsStream(), serverToClientPipe.Writer.AsStream()); - StreamClientTransport clientTransport = new(clientToServerPipe.Writer.AsStream(), serverToClientPipe.Reader.AsStream()); + StreamClientTransport clientTransport = new(new LoggingStream( + clientToServerPipe.Writer.AsStream(), clientToServerLog.Add), serverToClientPipe.Reader.AsStream()); Task serverTask; @@ -155,3 +164,32 @@ private static async Task RunConnected(Func action await serverTask; } } + +public class LoggingStream : Stream +{ + private readonly Stream _innerStream; + private readonly Action _logAction; + + public LoggingStream(Stream innerStream, Action logAction) + { + _innerStream = innerStream ?? throw new ArgumentNullException(nameof(innerStream)); + _logAction = logAction ?? throw new ArgumentNullException(nameof(logAction)); + } + + public override void Write(byte[] buffer, int offset, int count) + { + var data = Encoding.UTF8.GetString(buffer, offset, count); + _logAction(data); + _innerStream.Write(buffer, offset, count); + } + + public override bool CanRead => _innerStream.CanRead; + public override bool CanSeek => _innerStream.CanSeek; + public override bool CanWrite => _innerStream.CanWrite; + public override long Length => _innerStream.Length; + public override long Position { get => _innerStream.Position; set => _innerStream.Position = value; } + public override void Flush() => _innerStream.Flush(); + public override int Read(byte[] buffer, int offset, int count) => _innerStream.Read(buffer, offset, count); + public override long Seek(long offset, SeekOrigin origin) => _innerStream.Seek(offset, origin); + public override void SetLength(long value) => _innerStream.SetLength(value); +}