diff --git a/src/Microsoft.AspNetCore.OData/Microsoft.AspNetCore.OData.csproj b/src/Microsoft.AspNetCore.OData/Microsoft.AspNetCore.OData.csproj index 507f2d72..7b06a384 100644 --- a/src/Microsoft.AspNetCore.OData/Microsoft.AspNetCore.OData.csproj +++ b/src/Microsoft.AspNetCore.OData/Microsoft.AspNetCore.OData.csproj @@ -1,4 +1,4 @@ - + net8.0 @@ -35,6 +35,7 @@ all runtime; build; native; contentfiles; analyzers; buildtransitive + diff --git a/src/Microsoft.AspNetCore.OData/Query/Container/TruncatedAsyncEnumerableOfT.cs b/src/Microsoft.AspNetCore.OData/Query/Container/TruncatedAsyncEnumerableOfT.cs new file mode 100644 index 00000000..0ce6a4fb --- /dev/null +++ b/src/Microsoft.AspNetCore.OData/Query/Container/TruncatedAsyncEnumerableOfT.cs @@ -0,0 +1,71 @@ +//----------------------------------------------------------------------------- +// +// Copyright (c) .NET Foundation and Contributors. All rights reserved. +// See License.txt in the project root for license information. +// +//------------------------------------------------------------------------------ + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.OData.Query.Container; + +public class TruncatedAsyncEnumerable : IAsyncEnumerable +{ + private readonly IAsyncEnumerable _source; + private readonly int _pageSize; + private readonly TruncationState _state; + + /// + /// Initializes a new instance of the class, which provides an + /// asynchronous enumerable that limits the number of items returned per page and tracks truncation state. + /// + /// The source asynchronous enumerable to be paginated and truncated. + /// The maximum number of items to include in each page. Must be greater than zero. + /// The truncation state object used to track whether the enumeration was truncated. + public TruncatedAsyncEnumerable(IAsyncEnumerable source, int pageSize, TruncationState state) + { + _source = source; + _pageSize = pageSize; + _state = state; + } + + /// + /// Returns an asynchronous enumerator that iterates through the items in the source collection, up to a specified page size. + /// + /// The enumerator yields items from the source collection until the specified page size is reached. + /// If the number of items exceeds the page size, the enumeration is truncated, and the state is updated to true. Otherwise, the state is updated to false. + /// A token to monitor for cancellation requests. If the token is canceled, the enumeration is stopped. + /// An asynchronous enumerator that yields items from the source collection, up to the specified page size. + public async IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + int count = 0; + await foreach (var item in _source.WithCancellation(cancellationToken)) + { + if (count < _pageSize) + { + yield return item; + count++; + } + else + { + // More items exist than pageSize, so mark as truncated and stop yielding. + _state.IsTruncated = true; + yield break; + } + } + + // If we didn't hit the limit, not truncated. + _state.IsTruncated = false; + } +} + + +/// +/// Used to track the truncation state of an async enumerable. +/// +public class TruncationState +{ + public bool IsTruncated { get; set; } +} diff --git a/src/Microsoft.AspNetCore.OData/Query/Container/TruncatedCollectionOfT.cs b/src/Microsoft.AspNetCore.OData/Query/Container/TruncatedCollectionOfT.cs index fc62d83f..59590304 100644 --- a/src/Microsoft.AspNetCore.OData/Query/Container/TruncatedCollectionOfT.cs +++ b/src/Microsoft.AspNetCore.OData/Query/Container/TruncatedCollectionOfT.cs @@ -6,8 +6,11 @@ //------------------------------------------------------------------------------ using System; +using System.Collections; using System.Collections.Generic; using System.Linq; +using System.Threading; +using System.Threading.Tasks; namespace Microsoft.AspNetCore.OData.Query.Container; @@ -15,16 +18,49 @@ namespace Microsoft.AspNetCore.OData.Query.Container; /// Represents a class that truncates a collection to a given page size. /// /// The collection element type. -public class TruncatedCollection : List, ITruncatedCollection, IEnumerable, ICountOptionCollection +public class TruncatedCollection : IReadOnlyList, ITruncatedCollection, ICountOptionCollection, IAsyncEnumerable { - // The default capacity of the list. - // https://github.com/dotnet/runtime/blob/main/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/List.cs#L23 - private const int DefaultCapacity = 4; private const int MinPageSize = 1; + private const int DefaultCapacity = 4; + + private readonly List _items; + private readonly IAsyncEnumerable _asyncSource; + + private readonly bool _isTruncated; + + private readonly TruncationState _isTruncatedState; + + /// + /// Private constructor used by static Create methods and public constructors. + /// + /// The list of items in the collection. + /// The maximum number of items per page. + /// The total number of items in the source collection, if known. + /// Indicates whether the collection is truncated. + private TruncatedCollection(List items, int pageSize, long? totalCount, bool isTruncated) + { + _items = items; + _isTruncated = isTruncated; + PageSize = pageSize; + TotalCount = totalCount; + } + + /// + /// Private constructor used by static Create methods and public constructors. + /// + /// The asynchronous source of items (pageSize + 1) in the collection. + /// The maximum number of items per page. + /// The total number of items in the source collection, if known. + /// State to indicate whether the collection is truncated. + private TruncatedCollection(IAsyncEnumerable asyncSource, int pageSize, long? totalCount, TruncationState isTruncatedState) + { + _asyncSource = asyncSource; + _isTruncatedState = isTruncatedState; + PageSize = pageSize; + TotalCount = totalCount; + } - private bool _isTruncated; - private int _pageSize; - private long? _totalCount; + #region Constructors for Backward Compatibility /// /// Initializes a new instance of the class. @@ -32,23 +68,24 @@ public class TruncatedCollection : List, ITruncatedCollection, IEnumerable /// The collection to be truncated. /// The page size. public TruncatedCollection(IEnumerable source, int pageSize) - : base(checked(pageSize + 1)) - { - var items = source.Take(Capacity); - AddRange(items); - Initialize(pageSize); - } + : this(CreateInternal(source, pageSize, totalCount: null)) { } /// /// Initializes a new instance of the class. /// /// The queryable collection to be truncated. /// The page size. - // NOTE: The queryable version calls Queryable.Take which actually gets translated to the backend query where as - // the enumerable version just enumerates and is inefficient. - public TruncatedCollection(IQueryable source, int pageSize) : this(source, pageSize, false) - { - } + /// The total count. + public TruncatedCollection(IEnumerable source, int pageSize, long? totalCount) + : this(CreateInternal(source, pageSize, totalCount)) { } + + /// + /// Initializes a new instance of the class. + /// + /// The queryable collection to be truncated. + /// The page size. + public TruncatedCollection(IQueryable source, int pageSize) + : this(CreateInternal(source, pageSize, false)) { } /// /// Initializes a new instance of the class. @@ -56,120 +93,233 @@ public TruncatedCollection(IQueryable source, int pageSize) : this(source, pa /// The queryable collection to be truncated. /// The page size. /// Flag indicating whether constants should be parameterized - // NOTE: The queryable version calls Queryable.Take which actually gets translated to the backend query where as - // the enumerable version just enumerates and is inefficient. public TruncatedCollection(IQueryable source, int pageSize, bool parameterize) - : base(checked(pageSize + 1)) + : this(CreateInternal(source, pageSize, parameterize)) { } + + /// + /// Wrapper used internally by the backward-compatible constructors. + /// + /// An instance of . + private TruncatedCollection(TruncatedCollection other) + : this(other._items, other.PageSize, other.TotalCount, other._isTruncated) { - var items = Take(source, pageSize, parameterize); - AddRange(items); - Initialize(pageSize); } + #endregion + + #region Static Create Methods + /// - /// Initializes a new instance of the class. + /// Create a truncated collection from an . /// - /// The queryable collection to be truncated. + /// The collection to be truncated. /// The page size. - /// The total count. - public TruncatedCollection(IEnumerable source, int pageSize, long? totalCount) - : base(pageSize > 0 - ? checked(pageSize + 1) - : (totalCount > 0 ? (totalCount < int.MaxValue ? (int)totalCount : int.MaxValue) : DefaultCapacity)) + /// An instance of the + public static TruncatedCollection Create(IEnumerable source, int pageSize) { - if (pageSize > 0) - { - AddRange(source.Take(Capacity)); - } - else - { - AddRange(source); - } + return CreateInternal(source, pageSize, null); + } - if (pageSize > 0) - { - Initialize(pageSize); - } + /// + /// Create a truncated collection from an . + /// + /// The collection to be truncated. + /// The page size. + /// The total count. + /// An instance of the + public static TruncatedCollection Create(IEnumerable source, int pageSize, long? totalCount) + { + return CreateInternal(source, pageSize, totalCount); + } - _totalCount = totalCount; + /// + /// Create a truncated collection from an . + /// + /// The collection to be truncated. + /// The page size. + /// An instance of the + public static TruncatedCollection Create(IQueryable source, int pageSize) + { + return CreateInternal(source, pageSize, false, null); } /// - /// Initializes a new instance of the class. + /// Create a truncated collection from an . /// - /// The queryable collection to be truncated. + /// The collection to be truncated. /// The page size. /// The total count. - // NOTE: The queryable version calls Queryable.Take which actually gets translated to the backend query where as - // the enumerable version just enumerates and is inefficient. - [Obsolete("should not be used, will be marked internal in the next major version")] - public TruncatedCollection(IQueryable source, int pageSize, long? totalCount) : this(source, pageSize, - totalCount, false) + /// An instance of the + public static TruncatedCollection Create(IQueryable source, int pageSize, long? totalCount) { + return CreateInternal(source, pageSize, false, totalCount); } /// - /// Initializes a new instance of the class. + /// Create a truncated collection from an . /// - /// The queryable collection to be truncated. + /// The collection to be truncated. /// The page size. - /// The total count. /// Flag indicating whether constants should be parameterized - // NOTE: The queryable version calls Queryable.Take which actually gets translated to the backend query where as - // the enumerable version just enumerates and is inefficient. + /// An instance of the + public static TruncatedCollection Create(IQueryable source, int pageSize, bool parameterize) + { + return CreateInternal(source, pageSize, parameterize); + } + + /// + /// Create a truncated collection from an . + /// + /// The collection to be truncated. + /// The page size. + /// The total count. Default is null. + /// Flag indicating whether constants should be parameterized + /// An instance of the [Obsolete("should not be used, will be marked internal in the next major version")] - public TruncatedCollection(IQueryable source, int pageSize, long? totalCount, bool parameterize) - : base(pageSize > 0 ? Take(source, pageSize, parameterize) : source) + public static TruncatedCollection Create(IQueryable source, int pageSize, long? totalCount, bool parameterize) { - if (pageSize > 0) - { - Initialize(pageSize); - } + return CreateInternal(source, pageSize, parameterize, totalCount); + } - _totalCount = totalCount; + /// + /// Create an async truncated collection from an . + /// + /// The AsyncEnumerable to be truncated. + /// The page size. + /// /// The total count. Default null. + /// Cancellation token for async operations. Default. + /// An instance of the + public static TruncatedCollection CreateForAsync(IAsyncEnumerable source, int pageSize, long? totalCount = null, CancellationToken cancellationToken = default) + { + return CreateInternal(source, pageSize, totalCount, cancellationToken); } - private void Initialize(int pageSize) + #endregion + + #region Core Internal (Sync/Async) + + private static TruncatedCollection CreateInternal(IEnumerable source, int pageSize, long? totalCount) { - if (pageSize < MinPageSize) + ValidateArgs(source, pageSize); + + int capacity = pageSize > 0 ? checked(pageSize + 1) : (totalCount > 0 ? (totalCount < int.MaxValue ? (int)totalCount : int.MaxValue) : DefaultCapacity); + var items = source.Take(capacity); + + var smallPossibleCount = capacity < items.Count() ? items.Count() : capacity; + var buffer = new List(smallPossibleCount); + buffer.AddRange(items); + + bool isTruncated = buffer.Count > pageSize; + if (isTruncated) { - throw Error.ArgumentMustBeGreaterThanOrEqualTo("pageSize", pageSize, MinPageSize); + buffer.RemoveAt(buffer.Count - 1); } - _pageSize = pageSize; + return new TruncatedCollection(buffer, pageSize, totalCount, isTruncated: isTruncated); + } + + private static TruncatedCollection CreateInternal(IQueryable source, int pageSize, bool parameterize = false, long? totalCount = null) + { + ValidateArgs(source, pageSize); + + int capacity = pageSize > 0 ? pageSize : (totalCount > 0 ? (totalCount < int.MaxValue ? (int)totalCount : int.MaxValue) : DefaultCapacity); + var items = Take(source, capacity, parameterize); - if (Count > pageSize) + int count = 0; + var buffer = new List(pageSize); + using IEnumerator enumerator = items.GetEnumerator(); + while (count < pageSize && enumerator.MoveNext()) { - _isTruncated = true; - RemoveAt(Count - 1); + buffer.Add(enumerator.Current); + count++; } + + return new TruncatedCollection(buffer, pageSize, totalCount, isTruncated: enumerator.MoveNext()); + } + + private static TruncatedCollection CreateInternal(IAsyncEnumerable source, int pageSize, long? totalCount, CancellationToken cancellationToken = default) + { + ValidateArgs(source, pageSize); + + int capacity = pageSize > 0 ? pageSize : (totalCount > 0 ? (totalCount < int.MaxValue ? (int)totalCount : int.MaxValue) : DefaultCapacity); + + var state = new TruncationState(); + var truncatedSource = new TruncatedAsyncEnumerable(source, capacity, state); + return new TruncatedCollection(truncatedSource, pageSize, totalCount, state); } private static IQueryable Take(IQueryable source, int pageSize, bool parameterize) { - if (source == null) + // This uses existing ExpressionHelpers from OData to apply Take(pageSize + 1) + return (IQueryable)ExpressionHelpers.Take(source, checked(pageSize + 1), typeof(T), parameterize); + } + + private static void ValidateArgs(object source, int pageSize) + { + ArgumentNullException.ThrowIfNull(source); + + if (pageSize < MinPageSize) { - throw Error.ArgumentNull("source"); + throw Error.ArgumentMustBeGreaterThanOrEqualTo("pageSize", pageSize, MinPageSize); } - - return ExpressionHelpers.Take(source, checked(pageSize + 1), typeof(T), parameterize) as IQueryable; } - /// - public int PageSize + #endregion + + /// + public int PageSize { get; } + /// + public long? TotalCount { get; } + /// + public bool IsTruncated => _isTruncatedState?.IsTruncated ?? _isTruncated; + + /// + public int Count { - get { return _pageSize; } + get + { + if (_items != null) + { + return _items.Count; + } + else if (_asyncSource != null) + { + throw Error.InvalidOperation("Count cannot be accessed synchronously for an asynchronous source. Use CountAsync instead."); + } + + return 0; + } } - /// - public bool IsTruncated + /// + public async Task CountAsync() { - get { return _isTruncated; } + if (_items != null) + { + return await Task.FromResult(_items.Count); + } + else if (_asyncSource != null) + { + return await _asyncSource.CountAsync().ConfigureAwait(false); + } + + return 0; } - /// - public long? TotalCount + /// + public T this[int index] => _items[index]; + + /// + public IEnumerator GetEnumerator() => _items?.GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() => _items?.GetEnumerator(); + + /// + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) { - get { return _totalCount; } + if (_asyncSource == null) + { + throw new InvalidOperationException("Async enumeration is not supported for sync-only instances."); + } + return _asyncSource.GetAsyncEnumerator(cancellationToken); } } diff --git a/test/Microsoft.AspNetCore.OData.E2E.Tests/IAsyncEnumerableTests/IAsyncEnumerableController.cs b/test/Microsoft.AspNetCore.OData.E2E.Tests/IAsyncEnumerableTests/IAsyncEnumerableController.cs index e73e29c1..dced0d04 100644 --- a/test/Microsoft.AspNetCore.OData.E2E.Tests/IAsyncEnumerableTests/IAsyncEnumerableController.cs +++ b/test/Microsoft.AspNetCore.OData.E2E.Tests/IAsyncEnumerableTests/IAsyncEnumerableController.cs @@ -55,7 +55,7 @@ public ActionResult> CustomersDataNew() return Ok(_context.Customers.AsAsyncEnumerable()); } - [EnableQuery] + [EnableQuery(PageSize = 2)] [HttpGet("v3/Customers")] public IActionResult SearchCustomersForV3Route([FromQuery] Variant variant = Variant.None) { diff --git a/test/Microsoft.AspNetCore.OData.Tests/Query/Container/TruncatedCollectionOfTTest.cs b/test/Microsoft.AspNetCore.OData.Tests/Query/Container/TruncatedCollectionOfTTest.cs index 2868e724..1ac59cb2 100644 --- a/test/Microsoft.AspNetCore.OData.Tests/Query/Container/TruncatedCollectionOfTTest.cs +++ b/test/Microsoft.AspNetCore.OData.Tests/Query/Container/TruncatedCollectionOfTTest.cs @@ -6,10 +6,11 @@ //------------------------------------------------------------------------------ using System; -using Microsoft.AspNetCore.OData.Query.Container; -using Microsoft.AspNetCore.OData.Tests.Commons; using System.Collections.Generic; using System.Linq; +using System.Threading.Tasks; +using Microsoft.AspNetCore.OData.Query.Container; +using Microsoft.AspNetCore.OData.Tests.Commons; using Xunit; namespace Microsoft.AspNetCore.OData.Tests.Query.Container; @@ -50,6 +51,51 @@ public void CtorTruncatedCollection_SetsProperties() Assert.Equal(new[] { 1, 2, 3 }, collection); } + [Fact] + public void TruncatedCollectionCreateForIEnumerable_SetsProperties() + { + // Arrange & Act + IEnumerable source = new[] { 1, 2, 3, 5, 7 }; + var collection = TruncatedCollection.Create(source, 3, 5); + + // Assert + Assert.Equal(3, collection.PageSize); + Assert.Equal(5, collection.TotalCount); + Assert.True(collection.IsTruncated); + Assert.Equal(3, collection.Count); + Assert.Equal(new[] { 1, 2, 3 }, collection); + } + + [Fact] + public async Task TruncatedCollectionCreateForIAsyncEnumerable_SetsProperties() + { + // Arrange & Act + IAsyncEnumerable source = new[] { 1, 2, 3, 5, 7 }.ToAsyncEnumerable(); + var collection = TruncatedCollection.CreateForAsyncSource(source, 3, 5); + + // Assert + Assert.Equal(3, collection.PageSize); + Assert.Equal(5, collection.TotalCount); + Assert.Equal(3, await collection.CountAsync()); + Assert.Equal(new[] { 1, 2, 3 }, await collection.ToArrayAsync()); + Assert.True(collection.IsTruncated); + } + + [Fact] + public void TruncatedCollectionCreateForIQueryable_SetsProperties() + { + // Arrange & Act + IQueryable source = new[] { 1, 2, 3, 5, 7 }.AsQueryable(); + var collection = TruncatedCollection.Create(source, 3, 5); + + // Assert + Assert.Equal(3, collection.PageSize); + Assert.Equal(5, collection.TotalCount); + Assert.True(collection.IsTruncated); + Assert.Equal(3, collection.Count); + Assert.Equal(new[] { 1, 2, 3 }, collection); + } + [Fact] [Obsolete] public void CtorTruncatedCollection_WithQueryable_SetsProperties()