diff --git a/test/EndToEndTests/Tests/Client/Microsoft.OData.Client.E2E.Tests/BatchRequestTests/Server/BatchRequestTestsController.cs b/test/EndToEndTests/Tests/Client/Microsoft.OData.Client.E2E.Tests/BatchRequestTests/Server/BatchRequestTestsController.cs new file mode 100644 index 0000000000..2bae264cbe --- /dev/null +++ b/test/EndToEndTests/Tests/Client/Microsoft.OData.Client.E2E.Tests/BatchRequestTests/Server/BatchRequestTestsController.cs @@ -0,0 +1,135 @@ +//--------------------------------------------------------------------- +// +// Copyright (C) Microsoft Corporation. All rights reserved. See License.txt in the project root for license information. +// +//--------------------------------------------------------------------- + +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.OData.Deltas; +using Microsoft.AspNetCore.OData.Query; +using Microsoft.AspNetCore.OData.Routing.Controllers; +using Microsoft.OData.Client.E2E.Tests.Common.Server.Default; + +namespace Microsoft.OData.Client.E2E.Tests.BatchRequestTests.Server; + +public class BatchRequestTestsController : ODataController +{ + private static DefaultDataSource _dataSource; + + [EnableQuery] + [HttpGet("odata/Accounts")] + public IActionResult GetAccounts() + { + var result = _dataSource.Accounts; + + return Ok(result); + } + + [EnableQuery] + [HttpGet("odata/Accounts({key})")] + public IActionResult GetAccount([FromRoute] int key) + { + var result = _dataSource.Accounts?.SingleOrDefault(a => a.AccountID == key); + + if (result == null) + { + return NotFound(); + } + + return Ok(result); + } + + [EnableQuery] + [HttpGet("odata/Accounts({key})/MyPaymentInstruments")] + public IActionResult GetAccountMyPaymentInstruments([FromRoute] int key) + { + var result = _dataSource.Accounts?.SingleOrDefault(a => a.AccountID == key); + + if (result == null) + { + return NotFound(); + } + + return Ok(result.MyPaymentInstruments); + } + + [EnableQuery] + [HttpGet("odata/Accounts({accountKey})/MyPaymentInstruments({myPaymentInstrumentKey})/BillingStatements({billingStatementKey})")] + public IActionResult GetAccountBillingStatement([FromRoute] int accountKey, [FromRoute] int myPaymentInstrumentKey, [FromRoute] int billingStatementKey) + { + var account = _dataSource.Accounts?.SingleOrDefault(a => a.AccountID == accountKey); + var myPaymentInstrument = account?.MyPaymentInstruments?.SingleOrDefault(a => a.PaymentInstrumentID == myPaymentInstrumentKey); + var billingStatement = myPaymentInstrument?.BillingStatements?.SingleOrDefault(a => a.StatementID == billingStatementKey); + if (billingStatement == null) + { + return NotFound(); + } + + return Ok(billingStatement); + } + + [HttpPatch("odata/Accounts({key})")] + public IActionResult AddAccount([FromRoute] int key, [FromBody] Delta delta) + { + var account = _dataSource.Accounts?.SingleOrDefault(a => a.AccountID == key); + if (account == null) + { + return NotFound(); + } + + var updated = delta.Patch(account); + return Ok(updated); + } + + [HttpPost("odata/Accounts")] + public IActionResult AddAccount([FromBody] Account account) + { + _dataSource.Accounts?.Add(account); + + return Created(account); + } + + [HttpPost("odata/Accounts({key})/MyPaymentInstruments")] + public IActionResult GetAccountBillingStatement([FromRoute] int key, [FromBody] PaymentInstrument paymentInstrument) + { + var account = _dataSource.Accounts?.SingleOrDefault(a => a.AccountID == key); + if (account == null) + { + return NotFound(); + } + + account.MyPaymentInstruments ??= []; + account.MyPaymentInstruments.Add(paymentInstrument); + + return Created(paymentInstrument); + } + + [HttpPost("odata/Accounts({accountKey})/MyPaymentInstruments({myPaymentInstrumentKey})/BillingStatements")] + public IActionResult GetAccountBillingStatement([FromRoute] int accountKey, [FromRoute] int myPaymentInstrumentKey, [FromBody] Statement statement) + { + var account = _dataSource.Accounts?.SingleOrDefault(a => a.AccountID == accountKey); + if (account == null) + { + return NotFound(); + } + + var paymentInstrument = account.MyPaymentInstruments?.SingleOrDefault(a => a.PaymentInstrumentID == myPaymentInstrumentKey); + if (paymentInstrument == null) + { + return NotFound(); + } + + paymentInstrument.BillingStatements ??= []; + paymentInstrument.BillingStatements.Add(statement); + + return Created(statement); + } + + [HttpPost("odata/batchrequests/Default.ResetDefaultDataSource")] + public IActionResult ResetDefaultDataSource() + { + _dataSource = DefaultDataSource.CreateInstance(); + + return Ok(); + } +} diff --git a/test/EndToEndTests/Tests/Client/Microsoft.OData.Client.E2E.Tests/BatchRequestTests/Tests/BatchRequestClientTests.cs b/test/EndToEndTests/Tests/Client/Microsoft.OData.Client.E2E.Tests/BatchRequestTests/Tests/BatchRequestClientTests.cs new file mode 100644 index 0000000000..6ab4fe4cab --- /dev/null +++ b/test/EndToEndTests/Tests/Client/Microsoft.OData.Client.E2E.Tests/BatchRequestTests/Tests/BatchRequestClientTests.cs @@ -0,0 +1,203 @@ +//--------------------------------------------------------------------- +// +// Copyright (C) Microsoft Corporation. All rights reserved. See License.txt in the project root for license information. +// +//--------------------------------------------------------------------- + +using Microsoft.AspNetCore.OData; +using Microsoft.AspNetCore.OData.Batch; +using Microsoft.AspNetCore.OData.Routing.Controllers; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.OData.Client.E2E.TestCommon; +using Microsoft.OData.Client.E2E.Tests.BatchRequestTests.Server; +using Microsoft.OData.Client.E2E.Tests.Common.Client.Default.Default; +using Microsoft.OData.Client.E2E.Tests.Common.Server.Default; +using Xunit; +using Account = Microsoft.OData.Client.E2E.Tests.Common.Client.Default.Account; +using AccountInfo = Microsoft.OData.Client.E2E.Tests.Common.Client.Default.AccountInfo; +using PaymentInstrument = Microsoft.OData.Client.E2E.Tests.Common.Client.Default.PaymentInstrument; +using Statement = Microsoft.OData.Client.E2E.Tests.Common.Client.Default.Statement; + +namespace Microsoft.OData.Client.E2E.Tests.BatchRequestTests.Tests; + +public class BatchRequestClientTests : EndToEndTestBase +{ + private readonly Uri _baseUri; + private readonly Container _context; + + public class TestsStartup : TestStartupBase + { + public override void ConfigureServices(IServiceCollection services) + { + services.ConfigureControllers(typeof(BatchRequestTestsController), typeof(MetadataController)); + + services.AddControllers().AddOData(opt => + { + opt.EnableQueryFeatures().AddRouteComponents("odata", DefaultEdmModel.GetEdmModel(), new DefaultODataBatchHandler()); + opt.RouteOptions.EnableNonParenthesisForEmptyParameterFunction = true; + }); + } + } + + public BatchRequestClientTests(TestWebApplicationFactory fixture) : base(fixture) + { + if (Client.BaseAddress == null) + { + throw new ArgumentNullException(nameof(Client.BaseAddress), "Base address cannot be null"); + } + + _baseUri = new Uri(Client.BaseAddress, "odata/"); + + _context = new Container(_baseUri) + { + HttpClientFactory = HttpClientFactory + }; + + ResetDefaultDataSource(); + } + + [Fact] + public async Task AddAndModifiedBatchRequestsTest() + { + // Arrange + + // POST Requests + var account = new Account + { + AccountID = 110, + CountryRegion = "US" + }; + _context.AddToAccounts(account); + + var paymentInstrument = new PaymentInstrument + { + PaymentInstrumentID = 102910, + FriendlyName = "102 batch new PI", + CreatedDate = new DateTimeOffset(new DateTime(2013, 12, 29, 11, 11, 57)) + }; + _context.AddRelatedObject(account, "MyPaymentInstruments", paymentInstrument); + + var billingStatement = new Statement + { + StatementID = 102910010, + TransactionDescription = "Digital goods: PC", + Amount = 1000 + }; + _context.AddRelatedObject(paymentInstrument, "BillingStatements", billingStatement); + + // PATCH Request + var accountToUpdate = _context.Accounts.Where(a => a.AccountID == 107).Single(); + Assert.NotNull(accountToUpdate); + var now = DateTimeOffset.Now; + accountToUpdate.UpdatedTime = now; + accountToUpdate.AccountInfo = new AccountInfo + { + FirstName = "John", + LastName = "Doe" + }; + _context.UpdateObject(accountToUpdate); + + // Act + var response = await _context.SaveChangesAsync(SaveChangesOptions.BatchWithSingleChangeset | SaveChangesOptions.UseRelativeUri); + + // Assert + Assert.Equal(4, response.Count()); + + var changeResponses = response.Cast(); + + Assert.All(changeResponses, changeResponse => Assert.True(changeResponse.StatusCode == 200 || changeResponse.StatusCode == 201)); + + var accountCreatedResponse = changeResponses.FirstOrDefault(r => r.StatusCode == 201 && r.Descriptor is EntityDescriptor descriptor && descriptor.Entity is Account); + var accountUpdatedResponse = changeResponses.FirstOrDefault(r => r.StatusCode == 200 && r.Descriptor is EntityDescriptor descriptor && descriptor.Entity is Account); + var paymentInstrumentCreatedResponse = changeResponses.FirstOrDefault(r => r.Descriptor is EntityDescriptor descriptor && descriptor.Entity is PaymentInstrument); + var statementCreatedResponse = changeResponses.FirstOrDefault(r => r.Descriptor is EntityDescriptor descriptor && descriptor.Entity is Statement); + + // Account 110 is created + Assert.NotNull(accountCreatedResponse); + var accountCreated = Assert.IsType((accountCreatedResponse.Descriptor as EntityDescriptor)?.Entity); + Assert.NotNull(accountCreated); + Assert.Equal(110, accountCreated.AccountID); + Assert.Equal("US", accountCreated.CountryRegion); + Assert.Null(accountCreated.AccountInfo); + + // Account 107 is updated + Assert.NotNull(accountUpdatedResponse); + var accountUpdated = Assert.IsType((accountUpdatedResponse.Descriptor as EntityDescriptor)?.Entity); + Assert.NotNull(accountUpdated); + Assert.Equal(107, accountUpdated.AccountID); + Assert.Equal("FR", accountUpdated.CountryRegion); + Assert.Equal(now, accountUpdated.UpdatedTime); + Assert.Equal("John", accountUpdated.AccountInfo.FirstName); + Assert.Equal("Doe", accountUpdated.AccountInfo.LastName); + + // PaymentInstrument 102910 is created + Assert.NotNull(paymentInstrumentCreatedResponse); + var paymentInstrumentCreated = Assert.IsType((paymentInstrumentCreatedResponse.Descriptor as EntityDescriptor)?.Entity); + Assert.NotNull(paymentInstrumentCreated); + Assert.Equal(102910, paymentInstrumentCreated.PaymentInstrumentID); + Assert.Equal("102 batch new PI", paymentInstrumentCreated.FriendlyName); + + // Statement 102910010 is created + Assert.NotNull(statementCreatedResponse); + var statementCreated = Assert.IsType((statementCreatedResponse.Descriptor as EntityDescriptor)?.Entity); + Assert.NotNull(statementCreated); + Assert.Equal(102910010, statementCreated.StatementID); + Assert.Equal(1000, statementCreated.Amount); + Assert.Equal("Digital goods: PC", statementCreated.TransactionDescription); + } + + [Fact] + public async Task QueryBatchRequestsTest() + { + // Arrange + var batchRequest = new DataServiceRequest[] + { + new DataServiceRequest(new Uri(_baseUri + "Accounts(102)")), + new DataServiceRequest(new Uri(_baseUri + "Accounts(101)/MyPaymentInstruments")), + new DataServiceRequest(new Uri(_baseUri + "Accounts(103)/MyPaymentInstruments(103901)/BillingStatements(103901001)")) + }; + + // Act + var response = await _context.ExecuteBatchAsync(SaveChangesOptions.BatchWithSingleChangeset | SaveChangesOptions.UseRelativeUri, batchRequest); + + // Assert + foreach (var operationResponse in response) + { + if (operationResponse is QueryOperationResponse paymentInstrumentResponse) + { + Assert.Equal(200, paymentInstrumentResponse.StatusCode); + var paymentInstruments = paymentInstrumentResponse.ToList(); + Assert.Equal(3, paymentInstruments.Count); + } + + else if (operationResponse is QueryOperationResponse statementResponse) + { + Assert.Equal(200, statementResponse.StatusCode); + var statements = statementResponse.ToList(); + Assert.Single(statements); + Assert.Equal(103901001, statements[0].StatementID); + Assert.Equal(100, statements[0].Amount); + Assert.Equal("Digital goods: App", statements[0].TransactionDescription); + } + + else if (operationResponse is QueryOperationResponse accountResponse) + { + Assert.Equal(200, accountResponse.StatusCode); + var accounts = accountResponse.ToList(); + Assert.Single(accounts); + Assert.Equal(102, accounts[0].AccountID); + Assert.Equal("GB", accounts[0].CountryRegion); + } + } + } + + #region Private methods + + private void ResetDefaultDataSource() + { + var actionUri = new Uri(_baseUri + "batchrequests/Default.ResetDefaultDataSource", UriKind.Absolute); + _context.Execute(actionUri, "POST"); + } + + #endregion +} diff --git a/test/EndToEndTests/Tests/Client/Microsoft.OData.Client.E2E.Tests/BatchRequestTests/Tests/BatchRequestWithRelativeUriTests.cs b/test/EndToEndTests/Tests/Client/Microsoft.OData.Client.E2E.Tests/BatchRequestTests/Tests/BatchRequestWithRelativeUriTests.cs new file mode 100644 index 0000000000..5a167e6e01 --- /dev/null +++ b/test/EndToEndTests/Tests/Client/Microsoft.OData.Client.E2E.Tests/BatchRequestTests/Tests/BatchRequestWithRelativeUriTests.cs @@ -0,0 +1,265 @@ +//--------------------------------------------------------------------- +// +// Copyright (C) Microsoft Corporation. All rights reserved. See License.txt in the project root for license information. +// +//--------------------------------------------------------------------- + +using Microsoft.AspNetCore.OData; +using Microsoft.AspNetCore.OData.Batch; +using Microsoft.AspNetCore.OData.Routing.Controllers; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.OData.Client.E2E.TestCommon; +using Microsoft.OData.Client.E2E.TestCommon.Common; +using Microsoft.OData.Client.E2E.Tests.BatchRequestTests.Server; +using Microsoft.OData.Client.E2E.Tests.Common.Client.Default.Default; +using Microsoft.OData.Client.E2E.Tests.Common.Server.Default; +using Microsoft.OData.Edm; +using Xunit; + +namespace Microsoft.OData.Client.E2E.Tests.BatchRequestTests.Tests; + +public class BatchRequestWithRelativeUriTests : EndToEndTestBase +{ + private const string NameSpacePrefix = "Microsoft.OData.Client.E2E.Tests.Common.Server.Default"; + + private readonly Uri _baseUri; + private readonly Container _context; + private readonly IEdmModel _model; + + public class TestsStartup : TestStartupBase + { + public override void ConfigureServices(IServiceCollection services) + { + services.ConfigureControllers(typeof(BatchRequestTestsController), typeof(MetadataController)); + + services.AddControllers().AddOData(opt => + opt.EnableQueryFeatures().AddRouteComponents("odata", DefaultEdmModel.GetEdmModel(), new DefaultODataBatchHandler())); + } + } + + public BatchRequestWithRelativeUriTests(TestWebApplicationFactory fixture) + : base(fixture) + { + if (Client.BaseAddress == null) + { + throw new ArgumentNullException(nameof(Client.BaseAddress), "Base address cannot be null"); + } + + _baseUri = new Uri(Client.BaseAddress, "odata/"); + + _context = new Container(_baseUri) + { + HttpClientFactory = HttpClientFactory + }; + + _model = DefaultEdmModel.GetEdmModel(); + ResetDefaultDataSource(); + } + + [Theory] + [InlineData(BatchPayloadUriOption.AbsoluteUri)] + [InlineData(BatchPayloadUriOption.RelativeUri)] + [InlineData(BatchPayloadUriOption.AbsoluteUriUsingHostHeader)] + public async Task BatchRequestWithResourcePathTest(BatchPayloadUriOption option) + { + // Arrange + var writerSettings = new ODataMessageWriterSettings + { + BaseUri = _baseUri, + EnableMessageStreamDisposal = false, // Ensure the stream is not disposed of prematurely + }; + + var accountType = _model.FindDeclaredType($"{NameSpacePrefix}.Account") as IEdmEntityType; + Assert.NotNull(accountType); + + var accountSet = _model.EntityContainer.FindEntitySet("Accounts"); + Assert.NotNull(accountSet); + + var paymentInstrumentType = _model.FindDeclaredType($"{NameSpacePrefix}.PaymentInstrument") as IEdmEntityType; + Assert.NotNull(paymentInstrumentType); + + var navProp = accountType.FindProperty("MyPaymentInstruments") as IEdmNavigationProperty; + Assert.NotNull(navProp); + + var myPaymentInstrumentSet = accountSet.FindNavigationTarget(navProp); + Assert.NotNull(myPaymentInstrumentSet); + + var requestUrl = new Uri(_baseUri.AbsoluteUri + "$batch", UriKind.Absolute); + var requestMessage = new TestHttpClientRequestMessage(requestUrl, Client); + requestMessage.SetHeader("Content-Type", "multipart/mixed;boundary=batch_01AD6766-4A45-47CC-9463-94D4591D8DA9"); + requestMessage.SetHeader("OData-Version", "4.0"); + requestMessage.Method = "POST"; + + await using (var messageWriter = new ODataMessageWriter(requestMessage, writerSettings, _model)) + { + var batchWriter = await messageWriter.CreateODataBatchWriterAsync(); + + // Batch start. + await batchWriter.WriteStartBatchAsync(); + + // A Get request. + requestUrl = new Uri(_baseUri + "Accounts(101)/MyPaymentInstruments"); + var batchOperation1 = await batchWriter.CreateOperationRequestMessageAsync("GET", requestUrl, null, option); + batchOperation1.SetHeader("Accept", "application/json;odata.metadata=full"); + // Get request ends. + + // Changeset start. + await batchWriter.WriteStartChangesetAsync(); + + // The first operation in changeset is a Create request. + requestUrl = new Uri(_baseUri + "Accounts(102)/MyPaymentInstruments"); + var batchChangesetOperation1 = await batchWriter.CreateOperationRequestMessageAsync("POST", requestUrl, "1", option); + batchChangesetOperation1.SetHeader("Content-Type", "application/json;odata.metadata=full"); + batchChangesetOperation1.SetHeader("Accept", "application/json;odata.metadata=full"); + + var paymentInstrumentEntry = new ODataResource() { TypeName = $"{NameSpacePrefix}.PaymentInstrument" }; + var paymentInstrumentEntryP1 = new ODataProperty { Name = "PaymentInstrumentID", Value = 102910 }; + var paymentInstrumentEntryP2 = new ODataProperty { Name = "FriendlyName", Value = "102 batch new PI" }; + var paymentInstrumentEntryP3 = new ODataProperty { Name = "CreatedDate", Value = new DateTimeOffset(new DateTime(2013, 12, 29, 11, 11, 57)) }; + paymentInstrumentEntry.Properties = [paymentInstrumentEntryP1, paymentInstrumentEntryP2, paymentInstrumentEntryP3]; + + await using (var entryMessageWriter = new ODataMessageWriter(batchChangesetOperation1)) + { + var odataEntryWriter = await entryMessageWriter.CreateODataResourceWriterAsync(myPaymentInstrumentSet, paymentInstrumentType); + await odataEntryWriter.WriteStartAsync(paymentInstrumentEntry); + await odataEntryWriter.WriteEndAsync(); + } + + // Changeset end. + await batchWriter.WriteEndChangesetAsync(); + + // Another Get request. + requestUrl = new Uri(_baseUri + "Accounts(103)/MyPaymentInstruments(103901)/BillingStatements(103901001)"); + var batchOperation2 = await batchWriter.CreateOperationRequestMessageAsync("GET", requestUrl, null, option); + batchOperation2.SetHeader("Accept", "application/json;odata.metadata=full"); + + // Batch end. + await batchWriter.WriteEndBatchAsync(); + } + + // Act + var responseMessage = await requestMessage.GetResponseAsync(); + + // Assert + Assert.Equal(200, responseMessage.StatusCode); + + await ProcessBatchResponseAsync(responseMessage); + } + + #region Private methods + + private async Task ProcessBatchResponseAsync(IODataResponseMessageAsync responseMessage) + { + ODataMessageReaderSettings readerSettings = new() { BaseUri = _baseUri }; + + using (var innerMessageReader = new ODataMessageReader(responseMessage, readerSettings, _model)) + { + var batchReader = await innerMessageReader.CreateODataBatchReaderAsync(); + int batchOperationId = 0; + + while (await batchReader.ReadAsync()) + { + switch (batchReader.State) + { + case ODataBatchReaderState.Initial: + break; + case ODataBatchReaderState.ChangesetStart: + break; + case ODataBatchReaderState.ChangesetEnd: + break; + case ODataBatchReaderState.Operation: + var operationResponse = await batchReader.CreateOperationResponseMessageAsync(); + + using (var operationResponseReader = new ODataMessageReader(operationResponse, readerSettings, _model)) + { + if (batchOperationId == 0) + { + // the first response message is a feed + var feedReader = await operationResponseReader.CreateODataResourceSetReaderAsync(); + + Assert.Equal(200, operationResponse.StatusCode); + + var pis = new List(); + while (await feedReader.ReadAsync()) + { + if (feedReader.State == ODataReaderState.ResourceEnd) + { + var entry = feedReader.Item as ODataResource; + Assert.NotNull(entry); + pis.Add(entry); + } + } + + Assert.Equal(ODataReaderState.Completed, feedReader.State); + + Assert.Equal(3, pis.Count); + } + else if (batchOperationId == 1) + { + // the second response message is a creation response + var entryReader = await operationResponseReader.CreateODataResourceReaderAsync(); + + Assert.Equal(201, operationResponse.StatusCode); + + var pis = new List(); + while (await entryReader.ReadAsync()) + { + if (entryReader.State == ODataReaderState.ResourceEnd) + { + var entry = entryReader.Item as ODataResource; + Assert.NotNull(entry); + pis.Add(entry); + } + } + + Assert.Equal(ODataReaderState.Completed, entryReader.State); + + Assert.Single(pis); + var paymentInstrumentIDProperty = pis[0].Properties.Single(p => p.Name == "PaymentInstrumentID") as ODataProperty; + Assert.NotNull(paymentInstrumentIDProperty); + Assert.Equal(102910, paymentInstrumentIDProperty.Value); + } + else if (batchOperationId == 2) + { + // the third response message is an entry + var entryReader = await operationResponseReader.CreateODataResourceReaderAsync(); + + Assert.Equal(200, operationResponse.StatusCode); + + var statements = new List(); + while (await entryReader.ReadAsync()) + { + + if (entryReader.State == ODataReaderState.ResourceEnd) + { + var entry = entryReader.Item as ODataResource; + Assert.NotNull(entry); + statements.Add(entry); + } + } + + Assert.Equal(ODataReaderState.Completed, entryReader.State); + + Assert.Single(statements); + var statementIDProperty = statements[0].Properties.Single(p => p.Name == "StatementID") as ODataProperty; + Assert.NotNull(statementIDProperty); + Assert.Equal(103901001, statementIDProperty.Value); + } + } + + batchOperationId++; + break; + } + } + Assert.Equal(ODataBatchReaderState.Completed, batchReader.State); + } + } + + private void ResetDefaultDataSource() + { + var actionUri = new Uri(_baseUri + "batchrequests/Default.ResetDefaultDataSource", UriKind.Absolute); + _context.Execute(actionUri, "POST"); + } + + #endregion +}