Skip to content

Commit

Permalink
Merge pull request #352 from nathanAjacobs/ReadOnlySequenceDeserializ…
Browse files Browse the repository at this point in the history
…ationFix

Fix for generic struct deserialization when using ReadOnlySequence
  • Loading branch information
neuecc authored Feb 6, 2025
2 parents c42d4a0 + 1cc9c4e commit baf6c26
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 1 deletion.
45 changes: 45 additions & 0 deletions src/MemoryPack.Core/MemoryPackSerializer.Deserialize.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,51 @@ public static int Deserialize<
#endif
T>(in ReadOnlySequence<byte> buffer, ref T? value, MemoryPackSerializerOptions? options = default)
{
if (!RuntimeHelpers.IsReferenceOrContainsReferences<T>())
{
int sizeOfT = Unsafe.SizeOf<T>();
if (buffer.Length < sizeOfT)
{
MemoryPackSerializationException.ThrowInvalidRange(Unsafe.SizeOf<T>(), (int)buffer.Length);
}

ReadOnlySequence<byte> sliced = buffer.Slice(0, sizeOfT);

if (sliced.IsSingleSegment)
{
value = Unsafe.ReadUnaligned<T>(ref MemoryMarshal.GetReference(sliced.FirstSpan));
return sizeOfT;
}
else
{
// We can't read directly from ReadOnlySequence<byte> to T, so we copy to a temp array.
// if less than 512 bytes, use stackalloc, otherwise use MemoryPool<byte>
byte[]? tempArray = null;

Span<byte> tempSpan = sizeOfT <= 512 ? stackalloc byte[sizeOfT] : default;

try
{
if (sizeOfT > 512)
{
tempArray = ArrayPool<byte>.Shared.Rent(sizeOfT);
tempSpan = tempArray;
}

sliced.CopyTo(tempSpan);
value = Unsafe.ReadUnaligned<T>(ref MemoryMarshal.GetReference(tempSpan));
return sizeOfT;
}
finally
{
if (tempArray is not null)
{
ArrayPool<byte>.Shared.Return(tempArray);
}
}
}
}

var state = threadStaticReaderOptionalState;
if (state == null)
{
Expand Down
104 changes: 103 additions & 1 deletion tests/MemoryPack.Tests/DeserializeTest.cs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Dynamic;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading.Tasks;

namespace MemoryPack.Tests;

public class DeserializeTest
public partial class DeserializeTest
{
[Fact]
public async Task StreamTest()
Expand All @@ -30,6 +32,106 @@ public async Task StreamTest()
result.Should().Equal(expected);
}

[Fact]
public void GenericValueStructTest()
{
GenericStruct<int> value = new() { Id = 75, Value = 23 };

RunMultiSegmentTest(value);
}

[Fact]
public void LargeGenericValueStructTest()
{
GenericStruct<PrePaddedInt> value = new() { Id = 75, Value = new PrePaddedInt() { Value = 23 } };

RunMultiSegmentTest(value);
}

[Fact]
public void GenericReferenceStructTest()
{
GenericStruct<string> value = new GenericStruct<string>() { Id = 75, Value = "Hello World!" };

RunMultiSegmentTest(value);
}

[Fact]
public void LargeGenericReferenceStructTest()
{
GenericStruct<PrePaddedString> value = new() { Id = 75, Value = new PrePaddedString() { Value = "Hello World!" } };

RunMultiSegmentTest(value);
}

private void RunMultiSegmentTest<T>(T value)
{
byte[] bytes = MemoryPackSerializer.Serialize(value);

byte[] firstHalf = new byte[bytes.Length / 2];
Array.Copy(bytes, 0, firstHalf, 0, firstHalf.Length);

int secondHalfLength = bytes.Length / 2;
if (bytes.Length % 2 != 0)
{
secondHalfLength++;
}

byte[] secondHalf = new byte[secondHalfLength];

Array.Copy(bytes, firstHalf.Length, secondHalf, 0, secondHalfLength);

ReadOnlySequence<byte> sequence = ReadOnlySequenceBuilder.Create(firstHalf, secondHalf);

T? result = MemoryPackSerializer.Deserialize<T>(sequence);
result.Should().Be(value);
}

[MemoryPackable]
public partial struct GenericStruct<T>
{
public int Id;
public T Value;

public override string ToString()
{
return $"{Id}, {Value}";
}
}

[StructLayout(LayoutKind.Explicit, Size = 516)]
struct PrePaddedInt
{
[FieldOffset(512)]
public int Value;
}

[MemoryPackable]
private partial class PrePaddedString : IEquatable<PrePaddedString>
{
private PrePaddedInt _padding;
public string Value { get; set; } = "";

public bool Equals(PrePaddedString? other)
{
if (other is null)
return false;

return Value.Equals(other.Value);
}

public override bool Equals(object? obj)
{
if (obj is PrePaddedString other)
return Equals(other);
return false;
}

public override int GetHashCode()
{
return Value.GetHashCode();
}
}

class RandomStream : Stream
{
Expand Down

0 comments on commit baf6c26

Please sign in to comment.