From acf8c66e91ace257bfe8c90a8753fb7eecf9bb65 Mon Sep 17 00:00:00 2001 From: Kirill Sizov Date: Fri, 10 Jan 2025 12:19:57 +0300 Subject: [PATCH] Support for mutable structures as BufferWriter. --- .../MemoryPackSerializer.Serialize.cs | 20 ++++--- .../SerializerStructBufferWriterTest.cs | 59 +++++++++++++++++++ 2 files changed, 70 insertions(+), 9 deletions(-) create mode 100644 tests/MemoryPack.Tests/SerializerStructBufferWriterTest.cs diff --git a/src/MemoryPack.Core/MemoryPackSerializer.Serialize.cs b/src/MemoryPack.Core/MemoryPackSerializer.Serialize.cs index 13fbf81..73f1d3f 100644 --- a/src/MemoryPack.Core/MemoryPackSerializer.Serialize.cs +++ b/src/MemoryPack.Core/MemoryPackSerializer.Serialize.cs @@ -92,11 +92,12 @@ public static unsafe void Serialize(in TBufferWriter bufferWri where TBufferWriter : class, IBufferWriter #endif { + ref var bufferWriterRef = ref Unsafe.AsRef(in bufferWriter); if (!RuntimeHelpers.IsReferenceOrContainsReferences()) { - var buffer = bufferWriter.GetSpan(Unsafe.SizeOf()); + var buffer = bufferWriterRef.GetSpan(Unsafe.SizeOf()); Unsafe.WriteUnaligned(ref MemoryMarshal.GetReference(buffer), value); - bufferWriter.Advance(Unsafe.SizeOf()); + bufferWriterRef.Advance(Unsafe.SizeOf()); return; } #if NET7_0_OR_GREATER @@ -105,9 +106,9 @@ public static unsafe void Serialize(in TBufferWriter bufferWri { if (value == null) { - var span = bufferWriter.GetSpan(4); + var span = bufferWriterRef.GetSpan(4); MemoryPackCode.NullCollectionData.CopyTo(span); - bufferWriter.Advance(4); + bufferWriterRef.Advance(4); return; } @@ -115,19 +116,20 @@ public static unsafe void Serialize(in TBufferWriter bufferWri var length = srcArray.Length; if (length == 0) { - var span = bufferWriter.GetSpan(4); + var span = bufferWriterRef.GetSpan(4); MemoryPackCode.ZeroCollectionData.CopyTo(span); - bufferWriter.Advance(4); + bufferWriterRef.Advance(4); return; } + var dataSize = elementSize * length; - var destSpan = bufferWriter.GetSpan(dataSize + 4); + var destSpan = bufferWriterRef.GetSpan(dataSize + 4); ref var head = ref MemoryMarshal.GetReference(destSpan); Unsafe.WriteUnaligned(ref head, length); Unsafe.CopyBlockUnaligned(ref Unsafe.Add(ref head, 4), ref MemoryMarshal.GetArrayDataReference(srcArray), (uint)dataSize); - bufferWriter.Advance(dataSize + 4); + bufferWriterRef.Advance(dataSize + 4); return; } #endif @@ -141,7 +143,7 @@ public static unsafe void Serialize(in TBufferWriter bufferWri try { - var writer = new MemoryPackWriter(ref Unsafe.AsRef(in bufferWriter), state); + var writer = new MemoryPackWriter(ref bufferWriterRef, state); Serialize(ref writer, value); } finally diff --git a/tests/MemoryPack.Tests/SerializerStructBufferWriterTest.cs b/tests/MemoryPack.Tests/SerializerStructBufferWriterTest.cs new file mode 100644 index 0000000..191da0a --- /dev/null +++ b/tests/MemoryPack.Tests/SerializerStructBufferWriterTest.cs @@ -0,0 +1,59 @@ +using System; +using System.Buffers; +using MemoryPack.Tests.Models; + +namespace MemoryPack.Tests; + +public class SerializerStructBufferWriterTest +{ + [Fact] + public void Serialize_ShouldSupportStructAsBufferWriter_WhenValueIsNotReferenceAndNotContainsReferences() + { + var writer = new TestBufferWriter(); + MemoryPackSerializer.Serialize(writer, 16); + Assert.Equal(4, writer.WrittenSize); + } + + [Fact] + public void Serialize_ShouldSupportStructAsBufferWriter_WhenValueIsUnmanagedSZArray() + { + var writer = new TestBufferWriter(); + MemoryPackSerializer.Serialize(writer, new UnmanagedStruct[] { new() { X = 1, Y = 2, Z = 3 } }); + Assert.Equal(16, writer.WrittenSize); + } + + [Fact] + public void Serialize_ShouldSupportStructAsBufferWriter_WhenFormatterRequired() + { + var writer = new TestBufferWriter(); + MemoryPackSerializer.Serialize(writer, new TestData(1)); + Assert.Equal(5, writer.WrittenSize); + } +} + +[MemoryPackable] +public partial record TestData(int A); + +public struct TestBufferWriter : IBufferWriter +{ + public int WrittenSize = 0; + + public TestBufferWriter() + { + } + + public void Advance(int count) + { + WrittenSize += count; + } + + public Memory GetMemory(int sizeHint = 0) + { + throw new InvalidOperationException(); + } + + public Span GetSpan(int sizeHint = 0) + { + return new byte[sizeHint]; + } +}