diff --git a/sdk/core/azure_core/src/lib.rs b/sdk/core/azure_core/src/lib.rs index b0dd236578..719021f226 100644 --- a/sdk/core/azure_core/src/lib.rs +++ b/sdk/core/azure_core/src/lib.rs @@ -51,8 +51,8 @@ pub use typespec_client_core::xml; pub use typespec_client_core::{ base64, date, http::{ - headers::Header, new_http_client, AppendToUrlQuery, Body, Context, Continuable, HttpClient, - Method, Pageable, Request, RequestContent, StatusCode, Url, + headers::Header, new_http_client, AppendToUrlQuery, Body, Context, HttpClient, Method, + Pager, Request, RequestContent, StatusCode, Url, }, json, parsing, sleep::{self, sleep}, diff --git a/sdk/cosmos/azure_data_cosmos/examples/cosmos/query.rs b/sdk/cosmos/azure_data_cosmos/examples/cosmos/query.rs index c084392624..43383d1a76 100644 --- a/sdk/cosmos/azure_data_cosmos/examples/cosmos/query.rs +++ b/sdk/cosmos/azure_data_cosmos/examples/cosmos/query.rs @@ -34,10 +34,8 @@ impl QueryCommand { container_client.query_items::(&self.query, pk, None)?; while let Some(page) = items_pager.next().await { - let response = page?; + let response = page?.deserialize_body().await?; println!("Results Page"); - println!(" Query Metrics: {:?}", response.query_metrics); - println!(" Index Metrics: {:?}", response.index_metrics); println!(" Items:"); for item in response.items { println!(" * {:#?}", item); diff --git a/sdk/cosmos/azure_data_cosmos/src/clients/container_client.rs b/sdk/cosmos/azure_data_cosmos/src/clients/container_client.rs index e2af9f2848..c9f008e5bd 100644 --- a/sdk/cosmos/azure_data_cosmos/src/clients/container_client.rs +++ b/sdk/cosmos/azure_data_cosmos/src/clients/container_client.rs @@ -10,8 +10,9 @@ use crate::{ ItemOptions, PartitionKey, Query, QueryPartitionStrategy, }; -use azure_core::{Context, Request, Response}; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use azure_core::{Context, Pager, Request, Response}; +use serde::{de::DeserializeOwned, Serialize}; +use typespec_client_core::http::PagerResult; use url::Url; #[cfg(doc)] @@ -44,13 +45,13 @@ pub trait ContainerClientMethods { async fn read( &self, options: Option, - ) -> azure_core::Result>; + ) -> azure_core::Result>; /// Creates a new item in the container. /// /// # Arguments /// * `partition_key` - The partition key of the new item. - /// * `item` - The item to create. The type must implement [`Serialize`] and [`Deserialize`] + /// * `item` - The item to create. The type must implement [`Serialize`] and [`Deserialize`](serde::Deserialize) /// * `options` - Optional parameters for the request /// /// # Examples @@ -87,14 +88,14 @@ pub trait ContainerClientMethods { partition_key: impl Into, item: T, options: Option, - ) -> azure_core::Result>>; + ) -> azure_core::Result>>; /// Replaces an existing item in the container. /// /// # Arguments /// * `partition_key` - The partition key of the item to replace. /// * `item_id` - The id of the item to replace. - /// * `item` - The item to create. The type must implement [`Serialize`] and [`Deserialize`] + /// * `item` - The item to create. The type must implement [`Serialize`] and [`Deserialize`](serde::Deserialize) /// * `options` - Optional parameters for the request /// /// # Examples @@ -132,7 +133,7 @@ pub trait ContainerClientMethods { item_id: impl AsRef, item: T, options: Option, - ) -> azure_core::Result>>; + ) -> azure_core::Result>>; /// Creates or replaces an item in the container. /// @@ -141,7 +142,7 @@ pub trait ContainerClientMethods { /// /// # Arguments /// * `partition_key` - The partition key of the item to create or replace. - /// * `item` - The item to create. The type must implement [`Serialize`] and [`Deserialize`] + /// * `item` - The item to create. The type must implement [`Serialize`] and [`Deserialize`](serde::Deserialize) /// * `options` - Optional parameters for the request /// /// # Examples @@ -178,7 +179,7 @@ pub trait ContainerClientMethods { partition_key: impl Into, item: T, options: Option, - ) -> azure_core::Result>>; + ) -> azure_core::Result>>; /// Reads a specific item from the container. /// @@ -216,7 +217,7 @@ pub trait ContainerClientMethods { partition_key: impl Into, item_id: impl AsRef, options: Option, - ) -> azure_core::Result>>; + ) -> azure_core::Result>>; /// Deletes an item from the container. /// @@ -243,7 +244,7 @@ pub trait ContainerClientMethods { partition_key: impl Into, item_id: impl AsRef, options: Option, - ) -> azure_core::Result; + ) -> azure_core::Result; /// Executes a single-partition query against items in the container. /// @@ -304,7 +305,7 @@ pub trait ContainerClientMethods { query: impl Into, partition_key: impl Into, options: Option, - ) -> azure_core::Result, azure_core::Error>>; + ) -> azure_core::Result>>; } /// A client for working with a specific container in a Cosmos DB account. @@ -333,7 +334,7 @@ impl ContainerClientMethods for ContainerClient { #[allow(unused_variables)] // REASON: This is a documented public API so prefixing with '_' is undesirable. options: Option, - ) -> azure_core::Result> { + ) -> azure_core::Result> { let mut req = Request::new(self.container_url.clone(), azure_core::Method::Get); self.pipeline .send(Context::new(), &mut req, ResourceType::Containers) @@ -348,7 +349,7 @@ impl ContainerClientMethods for ContainerClient { #[allow(unused_variables)] // REASON: This is a documented public API so prefixing with '_' is undesirable. options: Option, - ) -> azure_core::Result>> { + ) -> azure_core::Result>> { let url = self.container_url.with_path_segments(["docs"]); let mut req = Request::new(url, azure_core::Method::Post); req.insert_headers(&partition_key.into())?; @@ -367,7 +368,7 @@ impl ContainerClientMethods for ContainerClient { #[allow(unused_variables)] // REASON: This is a documented public API so prefixing with '_' is undesirable. options: Option, - ) -> azure_core::Result>> { + ) -> azure_core::Result>> { let url = self .container_url .with_path_segments(["docs", item_id.as_ref()]); @@ -387,7 +388,7 @@ impl ContainerClientMethods for ContainerClient { #[allow(unused_variables)] // REASON: This is a documented public API so prefixing with '_' is undesirable. options: Option, - ) -> azure_core::Result>> { + ) -> azure_core::Result>> { let url = self.container_url.with_path_segments(["docs"]); let mut req = Request::new(url, azure_core::Method::Post); req.insert_header(constants::IS_UPSERT, "true"); @@ -406,7 +407,7 @@ impl ContainerClientMethods for ContainerClient { #[allow(unused_variables)] // REASON: This is a documented public API so prefixing with '_' is undesirable. options: Option, - ) -> azure_core::Result>> { + ) -> azure_core::Result>> { let url = self .container_url .with_path_segments(["docs", item_id.as_ref()]); @@ -425,7 +426,7 @@ impl ContainerClientMethods for ContainerClient { #[allow(unused_variables)] // REASON: This is a documented public API so prefixing with '_' is undesirable. options: Option, - ) -> azure_core::Result { + ) -> azure_core::Result { let url = self .container_url .with_path_segments(["docs", item_id.as_ref()]); @@ -444,26 +445,9 @@ impl ContainerClientMethods for ContainerClient { #[allow(unused_variables)] // REASON: This is a documented public API so prefixing with '_' is undesirable. options: Option, - ) -> azure_core::Result, azure_core::Error>> { - // Represents the raw response model from the server. - // We'll use this to deserialize the response body and then convert it to a more user-friendly model. - #[derive(Deserialize)] - struct QueryResponseModel { - #[serde(rename = "Documents")] - documents: Vec, - } - - // We have to manually implement Model, because the derive macro doesn't support auto-inferring type and lifetime bounds. - // See https://github.com/Azure/azure-sdk-for-rust/issues/1803 - impl azure_core::Model for QueryResponseModel { - async fn from_response_body( - body: azure_core::ResponseBody, - ) -> typespec_client_core::Result { - body.json().await - } - } - - let url = self.container_url.with_path_segments(["docs"]); + ) -> azure_core::Result>> { + let mut url = self.container_url.clone(); + url.append_path_segments(["docs"]); let mut base_req = Request::new(url, azure_core::Method::Post); base_req.insert_header(constants::QUERY, "True"); @@ -477,7 +461,7 @@ impl ContainerClientMethods for ContainerClient { // We have to double-clone here. // First we clone the pipeline to pass it in to the closure let pipeline = self.pipeline.clone(); - Ok(azure_core::Pageable::new(move |continuation| { + Ok(Pager::from_callback(move |continuation| { // Then we have to clone it again to pass it in to the async block. // This is because Pageable can't borrow any data, it has to own it all. // That's probably good, because it means a Pageable can outlive the client that produced it, but it requires some extra cloning. @@ -488,29 +472,13 @@ impl ContainerClientMethods for ContainerClient { req.insert_header(constants::CONTINUATION, continuation); } - let resp: Response> = pipeline + let response = pipeline .send(Context::new(), &mut req, ResourceType::Items) .await?; - - let query_metrics = resp - .headers() - .get_optional_string(&constants::QUERY_METRICS); - let index_metrics = resp - .headers() - .get_optional_string(&constants::INDEX_METRICS); - let continuation_token = - resp.headers().get_optional_string(&constants::CONTINUATION); - - let query_response: QueryResponseModel = resp.deserialize_body().await?; - - let query_results = QueryResults { - items: query_response.documents, - query_metrics, - index_metrics, - continuation_token, - }; - - Ok(query_results) + Ok(PagerResult::from_response_header( + response, + &constants::CONTINUATION, + )) } })) } diff --git a/sdk/cosmos/azure_data_cosmos/src/lib.rs b/sdk/cosmos/azure_data_cosmos/src/lib.rs index 50639d00fd..945847fabe 100644 --- a/sdk/cosmos/azure_data_cosmos/src/lib.rs +++ b/sdk/cosmos/azure_data_cosmos/src/lib.rs @@ -11,7 +11,7 @@ #![cfg_attr(docsrs, feature(doc_cfg_hide))] pub mod clients; -pub(crate) mod constants; +pub mod constants; mod options; mod partition_key; pub(crate) mod pipeline; diff --git a/sdk/cosmos/azure_data_cosmos/src/models/mod.rs b/sdk/cosmos/azure_data_cosmos/src/models/mod.rs index 0c55691749..63e191e6fc 100644 --- a/sdk/cosmos/azure_data_cosmos/src/models/mod.rs +++ b/sdk/cosmos/azure_data_cosmos/src/models/mod.rs @@ -3,8 +3,8 @@ //! Model types sent to and received from the Cosmos DB API. -use azure_core::{date::OffsetDateTime, Continuable, Model}; -use serde::{Deserialize, Deserializer}; +use azure_core::{date::OffsetDateTime, Model}; +use serde::{de::DeserializeOwned, Deserialize, Deserializer}; #[cfg(doc)] use crate::{ @@ -37,19 +37,17 @@ where /// A page of query results, where each item is a document of type `T`. #[non_exhaustive] -#[derive(Clone, Default, Debug)] +#[derive(Clone, Default, Debug, Deserialize)] pub struct QueryResults { + #[serde(rename = "Documents")] pub items: Vec, - pub query_metrics: Option, - pub index_metrics: Option, - pub continuation_token: Option, } -impl Continuable for QueryResults { - type Continuation = String; - - fn continuation(&self) -> Option { - self.continuation_token.clone() +impl azure_core::Model for QueryResults { + async fn from_response_body( + body: azure_core::ResponseBody, + ) -> typespec_client_core::Result { + body.json().await } } diff --git a/sdk/typespec/typespec_client_core/src/http/mod.rs b/sdk/typespec/typespec_client_core/src/http/mod.rs index 2e906b06e4..cf8ea37c94 100644 --- a/sdk/typespec/typespec_client_core/src/http/mod.rs +++ b/sdk/typespec/typespec_client_core/src/http/mod.rs @@ -8,7 +8,7 @@ mod context; pub mod headers; mod models; mod options; -mod pageable; +mod pager; mod pipeline; pub mod policies; pub mod request; @@ -19,7 +19,7 @@ pub use context::*; pub use headers::Header; pub use models::*; pub use options::*; -pub use pageable::*; +pub use pager::*; pub use pipeline::*; pub use request::{Body, Request, RequestContent}; pub use response::{Model, Response}; diff --git a/sdk/typespec/typespec_client_core/src/http/pageable.rs b/sdk/typespec/typespec_client_core/src/http/pageable.rs deleted file mode 100644 index 398648d36d..0000000000 --- a/sdk/typespec/typespec_client_core/src/http/pageable.rs +++ /dev/null @@ -1,118 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -use futures::{stream::unfold, Stream}; - -/// Helper macro for unwrapping `Result`s into the right types that `futures::stream::unfold` expects. -macro_rules! r#try { - ($expr:expr $(,)?) => { - match $expr { - ::std::result::Result::Ok(val) => val, - ::std::result::Result::Err(err) => { - return Some((Err(err.into()), State::Done)); - } - } - }; -} - -/// Helper macro for declaring the `Pageable` and `Continuable` types which easily allows -/// for conditionally compiling with a `Send` constraint or not. -macro_rules! declare { - ($($extra:tt)*) => { - // The use of a module here is a hack to get around the fact that `pin_project` - // generates a method `project_ref` which is never used and generates a warning. - // The module allows us to declare that `dead_code` is allowed but only for - // the `Pageable` type. - mod pageable { - #![allow(dead_code)] - - /// A pageable stream that yields items of type `T` - /// - /// Internally uses a specific continuation header to - /// make repeated requests to the service yielding a new page each time. - #[pin_project::pin_project] - // This is to suppress the unused `project_ref` warning - pub struct Pageable { - #[pin] - pub(crate) stream: ::std::pin::Pin> $($extra)*>>, - } - } - pub use pageable::Pageable; - - impl Pageable - where - T: Continuable, - { - pub fn new( - make_request: impl Fn(Option) -> F + Clone $($extra)* + 'static, - ) -> Self - where - F: ::std::future::Future> $($extra)* + 'static, - { - let stream = unfold(State::Init, move |state: State| { - let make_request = make_request.clone(); - async move { - let response = match state { - State::Init => { - let request = make_request(None); - r#try!(request.await) - } - State::Continuation(token) => { - let request = make_request(Some(token)); - r#try!(request.await) - } - State::Done => { - return None; - } - }; - - let next_state = response - .continuation() - .map_or(State::Done, State::Continuation); - - Some((Ok(response), next_state)) - } - }); - Self { - stream: Box::pin(stream), - } - } - } - - /// A type that can yield an optional continuation token - pub trait Continuable { - type Continuation: 'static $($extra)*; - fn continuation(&self) -> Option; - } - }; -} - -#[cfg(not(target_arch = "wasm32"))] -declare!(+ Send); -#[cfg(target_arch = "wasm32")] -declare!(); - -impl Stream for Pageable { - type Item = Result; - - fn poll_next( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let this = self.project(); - this.stream.poll_next(cx) - } -} - -impl std::fmt::Debug for Pageable { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Pageable").finish_non_exhaustive() - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -enum State { - Init, - Continuation(T), - Done, -} diff --git a/sdk/typespec/typespec_client_core/src/http/pager.rs b/sdk/typespec/typespec_client_core/src/http/pager.rs new file mode 100644 index 0000000000..e21ff34467 --- /dev/null +++ b/sdk/typespec/typespec_client_core/src/http/pager.rs @@ -0,0 +1,286 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +use std::{future::Future, pin::Pin}; + +use futures::{stream::unfold, Stream}; +use typespec::Error; + +use crate::http::{headers::HeaderName, Response}; + +pub enum PagerResult { + Continue { + response: Response, + continuation: C, + }, + Complete { + response: Response, + }, +} + +impl PagerResult { + /// Creates a [`PagerResult`] from the provided response, extracting the continuation value from the provided header. + /// + /// If the provided response has a header with the matching name, this returns [`PagerResult::Continue`], using the value from the header as the continuation. + /// If the provided response does not have a header with the matching name, this returns [`PagerResult::Complete`]. + pub fn from_response_header(response: Response, header_name: &HeaderName) -> Self { + match response.headers().get_optional_string(header_name) { + Some(continuation) => PagerResult::Continue { + response, + continuation, + }, + None => PagerResult::Complete { response }, + } + } +} + +/// Represents a paginated result across multiple requests. +#[pin_project::pin_project] +pub struct Pager { + #[pin] + #[cfg(not(target_arch = "wasm32"))] + stream: Pin, Error>> + Send>>, + + #[pin] + #[cfg(target_arch = "wasm32")] + stream: Pin, Error>>>>, +} + +impl Pager { + /// Creates a [`Pager`] from a callback that will be called repeatedly to request each page. + /// + /// This method expect a callback that accepts a single `Option` parameter, and returns a `(Response, Option)` tuple, asynchronously. + /// The `C` type parameter is the type of the continuation. It may be any [`Send`]able type. + /// The result will be an asynchronous stream of [`Result>`](typespec::Result>) values. + /// + /// The first time your callback is called, it will be called with [`Option::None`], indicating no continuation value is present. + /// Your callback must return one of: + /// * `Ok((response, Some(continuation)))` - The request succeeded, and return a response `response` and a continuation value `continuation`. The response will be yielded to the stream and the callback will be called again immediately with `Some(continuation)`. + /// * `Ok((response, None))` - The request succeeded, and there are no more pages. The response will be yielded to the stream, the stream will end, and the callback will not be called again. + /// * `Err(..)` - The request failed. The error will be yielded to the stream, the stream will end, and the callback will not be called again. + /// + /// ## Examples + /// + /// ```rust,no_run + /// # use typespec_client_core::http::{Context, Pager, PagerResult, Pipeline, Request, Response, Method, headers::HeaderName}; + /// # let pipeline: Pipeline = panic!("Not a runnable example"); + /// # struct MyModel; + /// let url = "https://example.com/my_paginated_api".parse().unwrap(); + /// let mut base_req = Request::new(url, Method::Get); + /// let pager = Pager::from_callback(move |continuation| { + /// // The callback must be 'static, so you have to clone and move any values you want to use. + /// let pipeline = pipeline.clone(); + /// let mut req = base_req.clone(); + /// async move { + /// if let Some(continuation) = continuation { + /// req.insert_header("x-continuation", continuation); + /// } + /// let resp: Response = pipeline + /// .send(&Context::new(), &mut req) + /// .await?; + /// Ok(PagerResult::from_response_header(resp, &HeaderName::from_static("x-next-continuation"))) + /// } + /// }); + /// ``` + pub fn from_callback< + // This is a bit gnarly, but the only thing that differs between the WASM/non-WASM configs is the presence of Send bounds. + #[cfg(not(target_arch = "wasm32"))] C: Send + 'static, + #[cfg(not(target_arch = "wasm32"))] F: Fn(Option) -> Fut + Send + 'static, + #[cfg(not(target_arch = "wasm32"))] Fut: Future, typespec::Error>> + Send + 'static, + #[cfg(target_arch = "wasm32")] C: 'static, + #[cfg(target_arch = "wasm32")] F: Fn(Option) -> Fut + 'static, + #[cfg(target_arch = "wasm32")] Fut: Future, typespec::Error>> + 'static, + >( + make_request: F, + ) -> Self { + let stream = unfold( + // We flow the `make_request` callback through the state value so that we can avoid cloning. + (State::Init, make_request), + |(state, make_request)| async move { + let result = match state { + State::Init => make_request(None).await, + State::Continuation(c) => make_request(Some(c)).await, + State::Done => return None, + }; + let (response, next_state) = match result { + Err(e) => return Some((Err(e), (State::Done, make_request))), + Ok(PagerResult::Continue { + response, + continuation, + }) => (Ok(response), State::Continuation(continuation)), + Ok(PagerResult::Complete { response }) => (Ok(response), State::Done), + }; + + // Flow 'make_request' through to avoid cloning + Some((response, (next_state, make_request))) + }, + ); + Self { + stream: Box::pin(stream), + } + } +} + +impl futures::Stream for Pager { + type Item = Result, Error>; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.project().stream.poll_next(cx) + } +} + +impl std::fmt::Debug for Pager { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Pager").finish_non_exhaustive() + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum State { + Init, + Continuation(T), + Done, +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use futures::StreamExt; + use serde::Deserialize; + use typespec_derive::Model; + + use crate::http::{ + headers::{HeaderName, HeaderValue}, + Pager, PagerResult, Response, StatusCode, + }; + + #[tokio::test] + pub async fn standard_pagination() { + #[derive(Model, Deserialize, Debug, PartialEq, Eq)] + #[typespec(crate = "crate")] + struct Page { + pub page: usize, + } + + let pager: Pager = Pager::from_callback(|continuation| async move { + match continuation { + None => Ok(PagerResult::Continue { + response: Response::from_bytes( + StatusCode::Ok, + HashMap::from([( + HeaderName::from_static("x-test-header"), + HeaderValue::from_static("page-1"), + )]) + .into(), + r#"{"page":1}"#, + ), + continuation: "1", + }), + Some("1") => Ok(PagerResult::Continue { + response: Response::from_bytes( + StatusCode::Ok, + HashMap::from([( + HeaderName::from_static("x-test-header"), + HeaderValue::from_static("page-2"), + )]) + .into(), + r#"{"page":2}"#, + ), + continuation: "2", + }), + Some("2") => Ok(PagerResult::Complete { + response: Response::from_bytes( + StatusCode::Ok, + HashMap::from([( + HeaderName::from_static("x-test-header"), + HeaderValue::from_static("page-3"), + )]) + .into(), + r#"{"page":3}"#, + ), + }), + _ => { + panic!("Unexpected continuation value") + } + } + }); + let pages: Vec<(String, Page)> = pager + .then(|r| async move { + let r = r.unwrap(); + let header = r + .headers() + .get_optional_string(&HeaderName::from_static("x-test-header")) + .unwrap(); + let body = r.deserialize_body().await.unwrap(); + (header, body) + }) + .collect() + .await; + assert_eq!( + &[ + ("page-1".to_string(), Page { page: 1 }), + ("page-2".to_string(), Page { page: 2 }), + ("page-3".to_string(), Page { page: 3 }), + ], + pages.as_slice() + ) + } + + #[tokio::test] + pub async fn error_stops_pagination() { + #[derive(Model, Deserialize, Debug, PartialEq, Eq)] + #[typespec(crate = "crate")] + struct Page { + pub page: usize, + } + + let pager: Pager = Pager::from_callback(|continuation| async move { + match continuation { + None => Ok(PagerResult::Continue { + response: Response::from_bytes( + StatusCode::Ok, + HashMap::from([( + HeaderName::from_static("x-test-header"), + HeaderValue::from_static("page-1"), + )]) + .into(), + r#"{"page":1}"#, + ), + continuation: "1", + }), + Some("1") => Err(typespec::Error::message( + typespec::error::ErrorKind::Other, + "yon request didst fail", + )), + _ => { + panic!("Unexpected continuation value") + } + } + }); + let pages: Vec> = pager + .then(|r| async move { + let r = r?; + let header = r + .headers() + .get_optional_string(&HeaderName::from_static("x-test-header")) + .unwrap(); + let body = r.deserialize_body().await?; + Ok((header, body)) + }) + .collect() + .await; + assert_eq!(2, pages.len()); + assert_eq!( + &("page-1".to_string(), Page { page: 1 }), + pages[0].as_ref().unwrap() + ); + + let err = pages[1].as_ref().unwrap_err(); + assert_eq!(&typespec::error::ErrorKind::Other, err.kind()); + assert_eq!("yon request didst fail", format!("{}", err)); + } +}