Skip to content

Commit 31a0750

Browse files
RUST-1945 Add a with_type method to the Aggregate action (#1100)
1 parent e06744b commit 31a0750

File tree

3 files changed

+132
-14
lines changed

3 files changed

+132
-14
lines changed

src/action/aggregate.rs

+58-14
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::time::Duration;
1+
use std::{marker::PhantomData, time::Duration};
22

33
use bson::Document;
44

@@ -24,15 +24,17 @@ impl Database {
2424
/// See the documentation [here](https://www.mongodb.com/docs/manual/aggregation/) for more
2525
/// information on aggregations.
2626
///
27-
/// `await` will return d[`Result<Cursor<Document>>`] or d[`Result<SessionCursor<Document>>`] if
28-
/// a `ClientSession` is provided.
27+
/// `await` will return d[`Result<Cursor<Document>>`]. If a [`ClientSession`] was provided, the
28+
/// returned cursor will be a [`SessionCursor`]. If [`with_type`](Aggregate::with_type) was
29+
/// called, the returned cursor will be generic over the `T` specified.
2930
#[deeplink]
3031
pub fn aggregate(&self, pipeline: impl IntoIterator<Item = Document>) -> Aggregate {
3132
Aggregate {
3233
target: AggregateTargetRef::Database(self),
3334
pipeline: pipeline.into_iter().collect(),
3435
options: None,
3536
session: ImplicitSession,
37+
_phantom: PhantomData,
3638
}
3739
}
3840
}
@@ -46,15 +48,17 @@ where
4648
/// See the documentation [here](https://www.mongodb.com/docs/manual/aggregation/) for more
4749
/// information on aggregations.
4850
///
49-
/// `await` will return d[`Result<Cursor<Document>>`] or d[`Result<SessionCursor<Document>>`] if
50-
/// a [`ClientSession`] is provided.
51+
/// `await` will return d[`Result<Cursor<Document>>`]. If a [`ClientSession`] was provided, the
52+
/// returned cursor will be a [`SessionCursor`]. If [`with_type`](Aggregate::with_type) was
53+
/// called, the returned cursor will be generic over the `T` specified.
5154
#[deeplink]
5255
pub fn aggregate(&self, pipeline: impl IntoIterator<Item = Document>) -> Aggregate {
5356
Aggregate {
5457
target: AggregateTargetRef::Collection(CollRef::new(self)),
5558
pipeline: pipeline.into_iter().collect(),
5659
options: None,
5760
session: ImplicitSession,
61+
_phantom: PhantomData,
5862
}
5963
}
6064
}
@@ -66,8 +70,10 @@ impl crate::sync::Database {
6670
/// See the documentation [here](https://www.mongodb.com/docs/manual/aggregation/) for more
6771
/// information on aggregations.
6872
///
69-
/// [`run`](Aggregate::run) will return d[`Result<crate::sync::Cursor<Document>>`] or
70-
/// d[`Result<crate::sync::SessionCursor<Document>>`] if a [`ClientSession`] is provided.
73+
/// [`run`](Aggregate::run) will return d[Result<crate::sync::Cursor<Document>>`]. If a
74+
/// [`crate::sync::ClientSession`] was provided, the returned cursor will be a
75+
/// [`crate::sync::SessionCursor`]. If [`with_type`](Aggregate::with_type) was called, the
76+
/// returned cursor will be generic over the `T` specified.
7177
#[deeplink]
7278
pub fn aggregate(&self, pipeline: impl IntoIterator<Item = Document>) -> Aggregate {
7379
self.async_database.aggregate(pipeline)
@@ -84,8 +90,10 @@ where
8490
/// See the documentation [here](https://www.mongodb.com/docs/manual/aggregation/) for more
8591
/// information on aggregations.
8692
///
87-
/// [`run`](Aggregate::run) will return d[`Result<crate::sync::Cursor<Document>>`] or
88-
/// d[`Result<crate::sync::SessionCursor<Document>>`] if a `ClientSession` is provided.
93+
/// [`run`](Aggregate::run) will return d[Result<crate::sync::Cursor<Document>>`]. If a
94+
/// `crate::sync::ClientSession` was provided, the returned cursor will be a
95+
/// `crate::sync::SessionCursor`. If [`with_type`](Aggregate::with_type) was called, the
96+
/// returned cursor will be generic over the `T` specified.
8997
#[deeplink]
9098
pub fn aggregate(&self, pipeline: impl IntoIterator<Item = Document>) -> Aggregate {
9199
self.async_collection.aggregate(pipeline)
@@ -95,14 +103,15 @@ where
95103
/// Run an aggregation operation. Construct with [`Database::aggregate`] or
96104
/// [`Collection::aggregate`].
97105
#[must_use]
98-
pub struct Aggregate<'a, Session = ImplicitSession> {
106+
pub struct Aggregate<'a, Session = ImplicitSession, T = Document> {
99107
target: AggregateTargetRef<'a>,
100108
pipeline: Vec<Document>,
101109
options: Option<AggregateOptions>,
102110
session: Session,
111+
_phantom: PhantomData<T>,
103112
}
104113

105-
impl<'a, Session> Aggregate<'a, Session> {
114+
impl<'a, Session, T> Aggregate<'a, Session, T> {
106115
option_setters!(options: AggregateOptions;
107116
allow_disk_use: bool,
108117
batch_size: u32,
@@ -130,15 +139,50 @@ impl<'a> Aggregate<'a, ImplicitSession> {
130139
pipeline: self.pipeline,
131140
options: self.options,
132141
session: ExplicitSession(value.into()),
142+
_phantom: PhantomData,
133143
}
134144
}
135145
}
136146

137-
#[action_impl(sync = crate::sync::Cursor<Document>)]
138-
impl<'a> Action for Aggregate<'a, ImplicitSession> {
147+
impl<'a, Session> Aggregate<'a, Session, Document> {
148+
/// Use the provided type for the returned cursor.
149+
///
150+
/// ```rust
151+
/// # use futures_util::TryStreamExt;
152+
/// # use mongodb::{bson::Document, error::Result, Cursor, Database};
153+
/// # use serde::Deserialize;
154+
/// # async fn run() -> Result<()> {
155+
/// # let database: Database = todo!();
156+
/// # let pipeline: Vec<Document> = todo!();
157+
/// #[derive(Deserialize)]
158+
/// struct PipelineOutput {
159+
/// len: usize,
160+
/// }
161+
///
162+
/// let aggregate_cursor = database
163+
/// .aggregate(pipeline)
164+
/// .with_type::<PipelineOutput>()
165+
/// .await?;
166+
/// let aggregate_results: Vec<PipelineOutput> = aggregate_cursor.try_collect().await?;
167+
/// # Ok(())
168+
/// # }
169+
/// ```
170+
pub fn with_type<T>(self) -> Aggregate<'a, Session, T> {
171+
Aggregate {
172+
target: self.target,
173+
pipeline: self.pipeline,
174+
options: self.options,
175+
session: self.session,
176+
_phantom: PhantomData,
177+
}
178+
}
179+
}
180+
181+
#[action_impl(sync = crate::sync::Cursor<T>)]
182+
impl<'a, T> Action for Aggregate<'a, ImplicitSession, T> {
139183
type Future = AggregateFuture;
140184

141-
async fn execute(mut self) -> Result<Cursor<Document>> {
185+
async fn execute(mut self) -> Result<Cursor<T>> {
142186
resolve_options!(
143187
self.target,
144188
self.options,

src/test/coll.rs

+43
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ use crate::{
2727
results::DeleteResult,
2828
test::{get_client_options, log_uncaptured, util::TestClient, EventClient},
2929
Collection,
30+
Cursor,
3031
IndexModel,
3132
};
3233

@@ -1306,3 +1307,45 @@ async fn insert_many_document_sequences() {
13061307
let second_batch_len = second_event.command.get_array("documents").unwrap().len();
13071308
assert_eq!(first_batch_len + second_batch_len, total_docs);
13081309
}
1310+
1311+
#[tokio::test]
1312+
async fn aggregate_with_generics() {
1313+
#[derive(Serialize)]
1314+
struct A {
1315+
str: String,
1316+
}
1317+
1318+
#[derive(Deserialize)]
1319+
struct B {
1320+
len: i32,
1321+
}
1322+
1323+
let client = TestClient::new().await;
1324+
let collection = client
1325+
.database("aggregate_with_generics")
1326+
.collection::<A>("aggregate_with_generics");
1327+
1328+
let a = A {
1329+
str: "hi".to_string(),
1330+
};
1331+
let len = a.str.len();
1332+
collection.insert_one(&a).await.unwrap();
1333+
1334+
// Assert at compile-time that the default cursor returned is a Cursor<Document>
1335+
let basic_pipeline = vec![doc! { "$match": { "a": 1 } }];
1336+
let _: Cursor<Document> = collection.aggregate(basic_pipeline).await.unwrap();
1337+
1338+
// Assert that data is properly deserialized when using with_type
1339+
let project_pipeline = vec![doc! { "$project": {
1340+
"str": 1,
1341+
"len": { "$strLenBytes": "$str" }
1342+
}
1343+
}];
1344+
let cursor = collection
1345+
.aggregate(project_pipeline)
1346+
.with_type::<B>()
1347+
.await
1348+
.unwrap();
1349+
let lens: Vec<B> = cursor.try_collect().await.unwrap();
1350+
assert_eq!(lens[0].len as usize, len);
1351+
}

src/test/db.rs

+31
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::cmp::Ord;
22

33
use futures::stream::TryStreamExt;
4+
use serde::Deserialize;
45

56
use crate::{
67
action::Action,
@@ -17,6 +18,7 @@ use crate::{
1718
results::{CollectionSpecification, CollectionType},
1819
test::util::TestClient,
1920
Client,
21+
Cursor,
2022
Database,
2123
};
2224

@@ -378,3 +380,32 @@ async fn clustered_index_list_collections() {
378380
.unwrap();
379381
assert!(clustered_index_collection.options.clustered_index.is_some());
380382
}
383+
384+
#[tokio::test]
385+
async fn aggregate_with_generics() {
386+
#[derive(Deserialize)]
387+
struct A {
388+
str: String,
389+
}
390+
391+
let client = TestClient::new().await;
392+
let database = client.database("aggregate_with_generics");
393+
394+
if client.server_version_lt(5, 1) {
395+
log_uncaptured(
396+
"skipping aggregate_with_generics: $documents agg stage only available on 5.1+",
397+
);
398+
return;
399+
}
400+
401+
// The cursor returned will contain these documents
402+
let pipeline = vec![doc! { "$documents": [ { "str": "hi" } ] }];
403+
404+
// Assert at compile-time that the default cursor returned is a Cursor<Document>
405+
let _: Cursor<Document> = database.aggregate(pipeline.clone()).await.unwrap();
406+
407+
// Assert that data is properly deserialized when using with_type
408+
let mut cursor = database.aggregate(pipeline).with_type::<A>().await.unwrap();
409+
assert!(cursor.advance().await.unwrap());
410+
assert_eq!(&cursor.deserialize_current().unwrap().str, "hi");
411+
}

0 commit comments

Comments
 (0)