From 4abb2f58cc017bd561ee593f0a6c872eadca7807 Mon Sep 17 00:00:00 2001 From: Matthew Patton Date: Tue, 23 Jun 2026 13:42:16 -0400 Subject: [PATCH] refactor: Migrate ScalarSubqueryExpr to self-serialization proto pattern Move ScalarSubqueryExpr proto encode/decode into the expression itself via the try_to_proto / try_from_proto hooks, removing the centralized downcast branch in to_proto.rs and the inline construction in from_proto.rs. Because the ScalarSubqueryResults container is runtime-only shared state and not part of the wire format, from_proto.rs still fetches it from the decode context and threads it into try_from_proto. Follows the pattern established in #21929. --- .../physical-expr/src/scalar_subquery.rs | 191 ++++++++++++++++-- .../proto/src/physical_plan/from_proto.rs | 17 +- .../proto/src/physical_plan/to_proto.rs | 12 -- 3 files changed, 180 insertions(+), 40 deletions(-) diff --git a/datafusion/physical-expr/src/scalar_subquery.rs b/datafusion/physical-expr/src/scalar_subquery.rs index ea00847151e66..40702f5095ba0 100644 --- a/datafusion/physical-expr/src/scalar_subquery.rs +++ b/datafusion/physical-expr/src/scalar_subquery.rs @@ -59,19 +59,6 @@ impl ScalarSubqueryExpr { } } - pub fn data_type(&self) -> &DataType { - &self.data_type - } - - pub fn nullable(&self) -> bool { - self.nullable - } - - /// Returns the index of this subquery in the shared results container. 
- pub fn index(&self) -> SubqueryIndex { - self.index - } - pub fn results(&self) -> &ScalarSubqueryResults { &self.results } @@ -139,6 +126,64 @@ impl PhysicalExpr for ScalarSubqueryExpr { fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "(scalar subquery)") } + + #[cfg(feature = "proto")] + fn try_to_proto( + &self, + _ctx: &datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx<'_>, + ) -> Result<Option<datafusion_proto_models::protobuf::PhysicalExprNode>> { + use datafusion_proto_models::protobuf; + Ok(Some(protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarSubquery( + protobuf::PhysicalScalarSubqueryExprNode { + data_type: Some((&self.data_type).try_into()?), + nullable: self.nullable, + index: self.index.as_usize() as u32, + }, + )), + })) + } +} + +#[cfg(feature = "proto")] +impl ScalarSubqueryExpr { + /// Reconstruct a [`ScalarSubqueryExpr`] from its protobuf representation. + /// + /// Unlike other expressions, this takes a third argument: the shared + /// [`ScalarSubqueryResults`] container. That container is a runtime-only + /// `Arc` shared with the surrounding `ScalarSubqueryExec` and is not part of + /// the wire format, so it cannot be reconstructed here or carried on the + /// decode context (which lives in a crate that cannot depend on + /// `datafusion-expr`). The match arm in `from_proto.rs` fetches it from the + /// plan-level decode context and passes it in. 
+ pub fn try_from_proto( + node: &datafusion_proto_models::protobuf::PhysicalExprNode, + _ctx: &datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx<'_>, + results: &ScalarSubqueryResults, + ) -> Result<Arc<dyn PhysicalExpr>> { + use datafusion_physical_expr_common::expect_expr_variant; + use datafusion_physical_expr_common::physical_expr::proto_decode::require_proto_field; + use datafusion_proto_models::protobuf; + + let sq = expect_expr_variant!( + node, + protobuf::physical_expr_node::ExprType::ScalarSubquery, + "ScalarSubqueryExpr", + ); + let data_type = require_proto_field( + sq.data_type.as_ref(), + "ScalarSubqueryExpr", + "data_type", + )? + .try_into()?; + Ok(Arc::new(ScalarSubqueryExpr::new( + data_type, + sq.nullable, + SubqueryIndex::new(sq.index as usize), + results.clone(), + ))) + } } #[cfg(test)] @@ -238,3 +283,123 @@ mod tests { assert_ne!(e1a, e3); } } + +/// Tests for the `try_to_proto` / `try_from_proto` hooks. +#[cfg(all(test, feature = "proto"))] +mod proto_tests { + use super::*; + use crate::proto_test_util::{StubEncoder, UnreachableDecoder, column_node}; + use datafusion_common::DataFusionError; + use datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx; + use datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx; + use datafusion_proto_models::protobuf::{ + PhysicalExprNode, PhysicalScalarSubqueryExprNode, physical_expr_node, + }; + + /// Build a `ScalarSubquery` proto node directly, with control over each + /// field, so the decode error paths can be exercised independently. 
+ fn proto_scalar_subquery_node( + data_type: Option, + nullable: bool, + index: u32, + ) -> PhysicalExprNode { + PhysicalExprNode { + expr_id: None, + expr_type: Some(physical_expr_node::ExprType::ScalarSubquery( + PhysicalScalarSubqueryExprNode { + data_type, + nullable, + index, + }, + )), + } + } + + #[test] + fn round_trips_through_proto() { + // A three-slot results container so index 2 is meaningful. + let results = ScalarSubqueryResults::new(3); + let expr = ScalarSubqueryExpr::new( + DataType::Int32, + true, + SubqueryIndex::new(2), + results.clone(), + ); + + // Encode: the expression serializes itself via try_to_proto. + let encoder = StubEncoder::ok(); + let enc_ctx = PhysicalExprEncodeCtx::new(&encoder); + let node = expr + .try_to_proto(&enc_ctx) + .unwrap() + .expect("ScalarSubqueryExpr should encode to Some(node)"); + + assert!(node.expr_id.is_none()); + let sq = match &node.expr_type { + Some(physical_expr_node::ExprType::ScalarSubquery(sq)) => sq, + other => panic!("expected a ScalarSubquery node, got {other:?}"), + }; + assert!(sq.nullable); + assert_eq!(sq.index, 2); + let encoded_type: DataType = sq + .data_type + .as_ref() + .expect("data_type encoded") + .try_into() + .unwrap(); + assert_eq!(encoded_type, DataType::Int32); + + // Decode: reconstruct from the proto node, threading in the shared + // results container the surrounding exec would provide. + let decoder = UnreachableDecoder; + let schema = Schema::empty(); + let dec_ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + let decoded = + ScalarSubqueryExpr::try_from_proto(&node, &dec_ctx, &results).unwrap(); + let decoded = decoded + .downcast_ref::<ScalarSubqueryExpr>() + .expect("decoded expr should be a ScalarSubqueryExpr"); + + // data_type + nullable survive the round-trip (observed via return_field). 
+ let field = decoded.return_field(&Schema::empty()).unwrap(); + assert_eq!(field.data_type(), &DataType::Int32); + assert!(field.is_nullable()); + + // Same shared container + same index → equal to the original. + assert_eq!(decoded, &expr); + } + + #[test] + fn rejects_non_scalar_subquery_node() { + let node = column_node("a"); + let results = ScalarSubqueryResults::new(1); + let decoder = UnreachableDecoder; + let schema = Schema::empty(); + let dec_ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let err = + ScalarSubqueryExpr::try_from_proto(&node, &dec_ctx, &results).unwrap_err(); + assert!(matches!( + err, + DataFusionError::Internal(msg) + if msg.contains("PhysicalExprNode is not a ScalarSubqueryExpr") + )); + } + + #[test] + fn rejects_missing_data_type() { + let node = proto_scalar_subquery_node(None, false, 0); + let results = ScalarSubqueryResults::new(1); + let decoder = UnreachableDecoder; + let schema = Schema::empty(); + let dec_ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let err = + ScalarSubqueryExpr::try_from_proto(&node, &dec_ctx, &results).unwrap_err(); + assert!(matches!( + err, + DataFusionError::Internal(msg) + if msg.contains("ScalarSubqueryExpr is missing required field 'data_type'") + )); + } +} diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 7bd1a3dd66d01..c96933e67e220 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -40,7 +40,6 @@ use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{FunctionRegistry, TaskContext}; use datafusion_expr::WindowFunctionDefinition; use datafusion_expr::dml::InsertOp; -use datafusion_expr::execution_props::SubqueryIndex; use datafusion_physical_expr::projection::{ProjectionExpr, ProjectionExprs}; use datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr; use datafusion_physical_expr::{LexOrdering, 
PhysicalSortExpr, ScalarFunctionExpr}; @@ -336,26 +335,14 @@ pub fn parse_physical_expr_with_converter( } ExprType::LikeExpr(_) => LikeExpr::try_from_proto(proto, &decode_ctx)?, ExprType::HashExpr(_) => HashExpr::try_from_proto(proto, &decode_ctx)?, - ExprType::ScalarSubquery(sq) => { - let data_type: arrow::datatypes::DataType = sq - .data_type - .as_ref() - .ok_or_else(|| { - proto_error("Missing data_type in PhysicalScalarSubqueryExprNode") - })? - .try_into()?; + ExprType::ScalarSubquery(_) => { let results = ctx.scalar_subquery_results().ok_or_else(|| { proto_error( "ScalarSubqueryExpr can only be deserialized as part \ of a surrounding ScalarSubqueryExec", ) })?; - Arc::new(ScalarSubqueryExpr::new( - data_type, - sq.nullable, - SubqueryIndex::new(sq.index as usize), - results.clone(), - )) + ScalarSubqueryExpr::try_from_proto(proto, &decode_ctx, results)? } ExprType::DynamicFilter(_) => { DynamicFilterPhysicalExpr::try_from_proto(proto, &decode_ctx)? diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 7310c0928eee4..a92d75547bfbc 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -32,7 +32,6 @@ use datafusion_datasource_json::file_format::JsonSink; use datafusion_datasource_parquet::file_format::ParquetSink; use datafusion_expr::WindowFrame; use datafusion_physical_expr::ScalarFunctionExpr; -use datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr; use datafusion_physical_expr::window::{SlidingAggregateWindowExpr, StandardWindowExpr}; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use datafusion_physical_plan::udaf::AggregateFunctionExpr; @@ -318,17 +317,6 @@ pub fn serialize_physical_expr_with_converter( }, )), }) - } else if let Some(expr) = expr.downcast_ref::<ScalarSubqueryExpr>() { - Ok(protobuf::PhysicalExprNode { - expr_id, - expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarSubquery( - 
protobuf::PhysicalScalarSubqueryExprNode { - data_type: Some(expr.data_type().try_into()?), - nullable: expr.nullable(), - index: expr.index().as_usize() as u32, - }, - )), - }) } else { let mut buf: Vec<u8> = vec![]; match codec.try_encode_expr(value, &mut buf) {