Skip to content

Commit bce3c98

Browse files
feat: add distributed integration test
1 parent b97e18e commit bce3c98

File tree

10 files changed

+1249
-24
lines changed

10 files changed

+1249
-24
lines changed

Cargo.lock

Lines changed: 477 additions & 11 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

examples/otlp.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ async fn main() -> Result<()> {
8080
#[instrument(level = "info")]
8181
async fn run_otlp_example() -> Result<()> {
8282
// Initialize the DataFusion session context.
83-
let ctx = init_session(true, true, 5, true).await?;
83+
let ctx = init_session(true, true, 5, true, false).await?;
8484

8585
// Run the SQL query with tracing enabled.
8686
run_traced_query(&ctx, QUERY_NAME).await?;

integration-utils/Cargo.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,19 @@ rust-version = { workspace = true }
3131
workspace = true
3232

3333
[dependencies]
34+
arrow = "56.1"
35+
arrow-flight = "56.1"
36+
async-trait = "0.1"
3437
datafusion = { workspace = true, features = ["parquet", "nested_expressions"] }
38+
datafusion-distributed = { git = "https://github.com/datafusion-contrib/datafusion-distributed", branch = "main" }
3539
datafusion-tracing = { workspace = true }
40+
futures = { workspace = true }
41+
hyper-util = "0.1"
3642
instrumented-object-store = { workspace = true }
3743
object_store = { version = "0.12.1", default-features = false }
44+
tokio = { workspace = true, features = ["full"] }
45+
tokio-stream = "0.1"
46+
tonic = { version = "0.13", features = ["transport"] }
47+
tower = "0.5"
3848
tracing = { workspace = true }
3949
url = { version = "2.5" }
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
SELECT
2+
count(*),
3+
"MinTemp"
4+
FROM weather
5+
GROUP BY "MinTemp"
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
//
18+
// This product includes software developed at Datadog (https://www.datadoghq.com/) Copyright 2025 Datadog, Inc.
19+
20+
use arrow_flight::flight_service_client::FlightServiceClient;
21+
use arrow_flight::flight_service_server::FlightServiceServer;
22+
use async_trait::async_trait;
23+
use datafusion::common::DataFusionError;
24+
use datafusion::execution::SessionStateBuilder;
25+
use datafusion_distributed::{
26+
ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelResolver, DistributedExt,
27+
DistributedSessionBuilderContext,
28+
};
29+
use hyper_util::rt::TokioIo;
30+
use tonic::transport::{Endpoint, Server};
31+
32+
const DUMMY_URL: &str = "http://localhost:50051";
33+
34+
/// [ChannelResolver] implementation that returns gRPC clients backed by an in-memory
35+
/// tokio duplex rather than a TCP connection.
36+
#[derive(Clone)]
37+
pub(crate) struct InMemoryChannelResolver {
38+
channel: FlightServiceClient<BoxCloneSyncChannel>,
39+
}
40+
41+
impl InMemoryChannelResolver {
42+
pub fn new() -> Self {
43+
let (client, server) = tokio::io::duplex(1024 * 1024);
44+
45+
let mut client = Some(client);
46+
let channel = Endpoint::try_from(DUMMY_URL)
47+
.expect(
48+
"Invalid dummy URL for building an endpoint. This should never happen",
49+
)
50+
.connect_with_connector_lazy(tower::service_fn(move |_| {
51+
let client = client
52+
.take()
53+
.expect("Client taken twice. This should never happen");
54+
async move { Ok::<_, std::io::Error>(TokioIo::new(client)) }
55+
}));
56+
57+
let this = Self {
58+
channel: FlightServiceClient::new(BoxCloneSyncChannel::new(channel)),
59+
};
60+
let this_clone = this.clone();
61+
62+
let endpoint =
63+
ArrowFlightEndpoint::try_new(move |ctx: DistributedSessionBuilderContext| {
64+
let this = this.clone();
65+
async move {
66+
let builder = SessionStateBuilder::new()
67+
.with_default_features()
68+
.with_distributed_channel_resolver(this)
69+
.with_runtime_env(ctx.runtime_env.clone());
70+
Ok(builder.build())
71+
}
72+
})
73+
.unwrap();
74+
75+
tokio::spawn(async move {
76+
Server::builder()
77+
.add_service(FlightServiceServer::new(endpoint))
78+
.serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
79+
.await
80+
});
81+
82+
this_clone
83+
}
84+
}
85+
86+
#[async_trait]
87+
impl ChannelResolver for InMemoryChannelResolver {
88+
fn get_urls(&self) -> Result<Vec<url::Url>, DataFusionError> {
89+
Ok(vec![url::Url::parse(DUMMY_URL).unwrap()])
90+
}
91+
92+
async fn get_flight_client_for_url(
93+
&self,
94+
_: &url::Url,
95+
) -> Result<FlightServiceClient<BoxCloneSyncChannel>, DataFusionError> {
96+
Ok(self.channel.clone())
97+
}
98+
}

integration-utils/src/lib.rs

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,17 @@ use datafusion::{
6161
error::Result, execution::SessionStateBuilder,
6262
physical_optimizer::PhysicalOptimizerRule, prelude::*,
6363
};
64+
use datafusion_distributed::{DistributedExt, DistributedPhysicalOptimizerRule};
6465
use datafusion_tracing::{
6566
InstrumentationOptions, instrument_with_info_spans, pretty_format_compact_batch,
6667
};
6768
use instrumented_object_store::instrument_object_store;
6869
use tracing::{field, info, instrument};
6970
use url::Url;
7071

72+
mod channel_resolver;
73+
use channel_resolver::InMemoryChannelResolver;
74+
7175
/// Executes the SQL query with instrumentation enabled, providing detailed tracing output.
7276
#[instrument(level = "info", skip(ctx))]
7377
pub async fn run_traced_query(ctx: &SessionContext, query_name: &str) -> Result<()> {
@@ -95,17 +99,23 @@ pub async fn init_session(
9599
record_metrics: bool,
96100
preview_limit: usize,
97101
compact_preview: bool,
102+
distributed: bool,
98103
) -> Result<SessionContext> {
99104
// Configure the session state with instrumentation for query execution.
100-
let session_state = SessionStateBuilder::new()
105+
let mut session_state_builder = SessionStateBuilder::new()
101106
.with_default_features()
102-
.with_config(SessionConfig::default().with_target_partitions(8)) // Enforce target partitions to ensure consistent test results regardless of the number of CPU cores.
103-
.with_physical_optimizer_rule(create_instrumentation_rule(
104-
record_metrics,
105-
preview_limit,
106-
compact_preview,
107-
))
108-
.build();
107+
.with_config(SessionConfig::default().with_target_partitions(8)); // Enforce target partitions to ensure consistent test results regardless of the number of CPU cores.
108+
if distributed {
109+
session_state_builder = session_state_builder
110+
.with_distributed_channel_resolver(InMemoryChannelResolver::new())
111+
.with_physical_optimizer_rule(Arc::new(
112+
DistributedPhysicalOptimizerRule::new(),
113+
));
114+
}
115+
session_state_builder = session_state_builder.with_physical_optimizer_rule(
116+
create_instrumentation_rule(record_metrics, preview_limit, compact_preview),
117+
);
118+
let session_state = session_state_builder.build();
109119

110120
let ctx = SessionContext::new_with_state(session_state);
111121

@@ -153,16 +163,24 @@ pub fn create_instrumentation_rule(
153163
)
154164
}
155165

156-
/// Returns the path to the directory containing the TPCH Parquet tables.
157-
pub fn tpch_tables_dir() -> PathBuf {
166+
/// Returns the path to the directory containing the Parquet tables.
167+
pub fn data_dir() -> PathBuf {
158168
PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("data")
159169
}
160170

161171
/// Registers all TPCH Parquet tables required for executing the queries.
162172
#[instrument(level = "info", skip(ctx))]
163173
async fn register_tpch_tables(ctx: &SessionContext) -> Result<()> {
164174
// Construct the path to the directory containing Parquet data.
165-
let data_dir = tpch_tables_dir();
175+
let data_dir = data_dir();
176+
177+
// Register the weather table.
178+
ctx.register_parquet(
179+
"weather",
180+
data_dir.join("weather").to_string_lossy(),
181+
ParquetReadOptions::default(),
182+
)
183+
.await?;
166184

167185
// Generate and register each table from Parquet files.
168186
// This includes all standard TPCH tables so examples/tests can rely on them.

tests/integration_tests.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ struct QueryTestCase<'a> {
5656
ignored_preview_spans: &'a [usize],
5757
/// Whether to ignore the full trace in assertions.
5858
ignore_full_trace: bool,
59+
/// Whether to run the test in distributed mode.
60+
distributed: bool,
5961
}
6062

6163
impl<'a> QueryTestCase<'a> {
@@ -95,6 +97,11 @@ impl<'a> QueryTestCase<'a> {
9597
self.ignore_full_trace = true;
9698
self
9799
}
100+
101+
fn distributed(mut self) -> Self {
102+
self.distributed = true;
103+
self
104+
}
98105
}
99106

100107
#[tokio::test(flavor = "multi_thread", worker_threads = 8)]
@@ -196,6 +203,20 @@ async fn test_topk_lineitem() -> Result<()> {
196203
execute_test_case("10_topk_lineitem", &QueryTestCase::new("topk_lineitem")).await
197204
}
198205

206+
#[tokio::test(flavor = "multi_thread", worker_threads = 8)]
207+
async fn test_weather() -> Result<()> {
208+
execute_test_case("11_weather", &QueryTestCase::new("weather")).await
209+
}
210+
211+
#[tokio::test(flavor = "multi_thread", worker_threads = 8)]
212+
async fn test_distributed_weather() -> Result<()> {
213+
execute_test_case(
214+
"12_distributed_weather",
215+
&QueryTestCase::new("weather").distributed(),
216+
)
217+
.await
218+
}
219+
199220
/// Executes the provided [`QueryTestCase`], setting up tracing and verifying
200221
/// log output according to its parameters.
201222
async fn execute_test_case(test_name: &str, test_case: &QueryTestCase<'_>) -> Result<()> {
@@ -208,6 +229,7 @@ async fn execute_test_case(test_name: &str, test_case: &QueryTestCase<'_>) -> Re
208229
test_case.should_record_metrics,
209230
test_case.row_limit,
210231
test_case.use_compact_preview,
232+
test_case.distributed,
211233
)
212234
.await?;
213235

0 commit comments

Comments
 (0)