From 05d4a5d573522ee44a7af352cc9c924ea2cf14ca Mon Sep 17 00:00:00 2001 From: Adrian Cole Date: Mon, 28 Apr 2025 12:40:36 +1200 Subject: [PATCH 1/2] Adds test to show where trace IDs are encoded Signed-off-by: Adrian Cole --- .../DiagnosticTests.cs | 50 +++++++++++++++++-- 1 file changed, 46 insertions(+), 4 deletions(-) diff --git a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs index b6355f79..d94eef7b 100644 --- a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs +++ b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs @@ -4,6 +4,8 @@ using OpenTelemetry.Trace; using System.Diagnostics; using System.IO.Pipelines; +using System.Text; +using System.Text.Json; namespace ModelContextProtocol.Tests; @@ -14,6 +16,7 @@ public class DiagnosticTests public async Task Session_TracksActivities() { var activities = new List(); + var clientToServerLog = new List(); using (var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder() .AddSource("Experimental.ModelContextProtocol") @@ -28,7 +31,7 @@ await RunConnected(async (client, server) => var tool = tools.First(t => t.Name == "DoubleValue"); await tool.InvokeAsync(new() { ["amount"] = 42 }, TestContext.Current.CancellationToken); - }); + }, clientToServerLog); } Assert.NotEmpty(activities); @@ -64,6 +67,15 @@ await RunConnected(async (client, server) => Assert.Equal(clientListToolsCall.SpanId, serverListToolsCall.ParentSpanId); Assert.Equal(clientListToolsCall.TraceId, serverListToolsCall.TraceId); + + // Validate that the client trace context encoded to request.params.traceparent + using var listToolsJson = JsonDocument.Parse(clientToServerLog.First(s => s.Contains("\"method\":\"tools/list\""))); + var traceparent = listToolsJson.RootElement.GetProperty("params").GetProperty("traceparent").GetString(); + Assert.NotNull(traceparent); + + var parts = traceparent.Split('-'); + Assert.Equal(clientListToolsCall.TraceId.ToString(), parts[1]); + Assert.Equal(clientListToolsCall.SpanId.ToString(), parts[2]); } [Fact] @@ -80,7 +92,7 @@ await RunConnected(async (client, server) => { await client.CallToolAsync("Throw", cancellationToken: TestContext.Current.CancellationToken); await Assert.ThrowsAsync(() => client.CallToolAsync("does-not-exist", cancellationToken: TestContext.Current.CancellationToken)); - }); + }, new List()); } Assert.NotEmpty(activities); @@ -120,11 +132,12 @@ await RunConnected(async (client, server) => Assert.Equal("-32602", doesNotExistToolClient.Tags.Single(t => t.Key == "rpc.jsonrpc.error_code").Value); } - private static async Task RunConnected(Func action) + private static async Task RunConnected(Func action, List clientToServerLog) { Pipe clientToServerPipe = new(), serverToClientPipe = new(); StreamServerTransport serverTransport = new(clientToServerPipe.Reader.AsStream(), serverToClientPipe.Writer.AsStream()); - StreamClientTransport clientTransport = new(clientToServerPipe.Writer.AsStream(), serverToClientPipe.Reader.AsStream()); + StreamClientTransport clientTransport = new(new LoggingStream( + clientToServerPipe.Writer.AsStream(), clientToServerLog.Add), serverToClientPipe.Reader.AsStream()); Task serverTask; @@ -155,3 +168,32 @@ private static async Task RunConnected(Func action await serverTask; } } + +public class LoggingStream : Stream +{ + private readonly Stream _innerStream; + private readonly Action _logAction; + + public LoggingStream(Stream innerStream, Action logAction) + { + _innerStream = innerStream ?? throw new ArgumentNullException(nameof(innerStream)); + _logAction = logAction ?? throw new ArgumentNullException(nameof(logAction)); + } + + public override void Write(byte[] buffer, int offset, int count) + { + var data = Encoding.UTF8.GetString(buffer, offset, count); + _logAction(data); + _innerStream.Write(buffer, offset, count); + } + + public override bool CanRead => _innerStream.CanRead; + public override bool CanSeek => _innerStream.CanSeek; + public override bool CanWrite => _innerStream.CanWrite; + public override long Length => _innerStream.Length; + public override long Position { get => _innerStream.Position; set => _innerStream.Position = value; } + public override void Flush() => _innerStream.Flush(); + public override int Read(byte[] buffer, int offset, int count) => _innerStream.Read(buffer, offset, count); + public override long Seek(long offset, SeekOrigin origin) => _innerStream.Seek(offset, origin); + public override void SetLength(long value) => _innerStream.SetLength(value); +} From 1d6585098ad460c009517852f562a82f1857dead Mon Sep 17 00:00:00 2001 From: Adrian Cole Date: Mon, 28 Apr 2025 13:59:24 +1200 Subject: [PATCH 2/2] update impl to params._meta Signed-off-by: Adrian Cole --- src/ModelContextProtocol/Diagnostics.cs | 24 +++++++++---------- .../DiagnosticTests.cs | 10 +++----- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/src/ModelContextProtocol/Diagnostics.cs b/src/ModelContextProtocol/Diagnostics.cs index 52d82e20..d4d0ef57 100644 --- a/src/ModelContextProtocol/Diagnostics.cs +++ b/src/ModelContextProtocol/Diagnostics.cs @@ -50,22 +50,19 @@ private static void ExtractContext(object? message, string fieldName, out string fieldValues = null; fieldValue = null; - JsonNode? parameters = null; + JsonNode? meta = null; switch (message) { case JsonRpcRequest request: - parameters = request.Params; + meta = request.Params?["_meta"]; break; case JsonRpcNotification notification: - parameters = notification.Params; - break; - - default: + meta = notification.Params?["_meta"]; break; } - if (parameters?[fieldName] is JsonValue value && value.GetValueKind() == JsonValueKind.String) + if (meta?[fieldName] is JsonValue value && value.GetValueKind() == JsonValueKind.String) { fieldValue = value.GetValue(); } @@ -89,14 +86,17 @@ private static void InjectContext(object? message, string key, string value) case JsonRpcNotification notification: parameters = notification.Params; break; - - default: - break; } - if (parameters is JsonObject jsonObject && jsonObject[key] == null) + // Replace any params._meta with the current value + if (parameters is JsonObject jsonObject) { - jsonObject[key] = value; + if (jsonObject["_meta"] is not JsonObject meta) + { + meta = new JsonObject(); + jsonObject["_meta"] = meta; + } + meta[key] = value; } } diff --git a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs index d94eef7b..3fe69275 100644 --- a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs +++ b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs @@ -68,14 +68,10 @@ await RunConnected(async (client, server) => Assert.Equal(clientListToolsCall.SpanId, serverListToolsCall.ParentSpanId); Assert.Equal(clientListToolsCall.TraceId, serverListToolsCall.TraceId); - // Validate that the client trace context encoded to request.params.traceparent + // Validate that the client trace context encoded to request.params._meta[traceparent] using var listToolsJson = JsonDocument.Parse(clientToServerLog.First(s => s.Contains("\"method\":\"tools/list\""))); - var traceparent = listToolsJson.RootElement.GetProperty("params").GetProperty("traceparent").GetString(); - Assert.NotNull(traceparent); - - var parts = traceparent.Split('-'); - Assert.Equal(clientListToolsCall.TraceId.ToString(), parts[1]); - Assert.Equal(clientListToolsCall.SpanId.ToString(), parts[2]); + var metaJson = listToolsJson.RootElement.GetProperty("params").GetProperty("_meta").GetRawText(); + Assert.Equal($$"""{"traceparent":"00-{{clientListToolsCall.TraceId}}-{{clientListToolsCall.SpanId}}-01"}""", metaJson); } [Fact]