Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions crates/lance-graph-python/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,9 @@ impl CypherQuery {
/// ----------
/// datasets : dict
/// Dictionary mapping table names to Lance datasets
/// dialect : str, optional
/// SQL dialect to use. One of "default", "spark", "postgresql" (alias "postgres"), "mysql", "sqlite".
/// Defaults to "default" (generic DataFusion SQL).
///
/// Returns
/// -------
Expand All @@ -504,7 +507,29 @@ impl CypherQuery {
/// ------
/// RuntimeError
/// If SQL generation fails
fn to_sql(&self, py: Python, datasets: &Bound<'_, PyDict>) -> PyResult<String> {
/// ValueError
/// If an invalid dialect is specified
#[pyo3(signature = (datasets, dialect=None))]
fn to_sql(
&self,
py: Python,
datasets: &Bound<'_, PyDict>,
dialect: Option<&str>,
Comment thread
yuchen-pipi marked this conversation as resolved.
Outdated
) -> PyResult<String> {
let sql_dialect = match dialect {
None | Some("default") => None,
Some("spark") => Some(lance_graph::SqlDialect::Spark),
Some("postgresql") | Some("postgres") => Some(lance_graph::SqlDialect::PostgreSql),
Some("mysql") => Some(lance_graph::SqlDialect::MySql),
Some("sqlite") => Some(lance_graph::SqlDialect::Sqlite),
Some(other) => {
return Err(PyValueError::new_err(format!(
"Unknown SQL dialect: '{}'. Valid options: 'default', 'spark', 'postgresql', 'mysql', 'sqlite'",
other
)));
}
};

// Convert datasets to Arrow RecordBatch map
let arrow_datasets = python_datasets_to_batches(datasets)?;

Expand All @@ -513,7 +538,7 @@ impl CypherQuery {

// Execute via runtime
let sql = RT
.block_on(Some(py), inner_query.to_sql(arrow_datasets))?
.block_on(Some(py), inner_query.to_sql(arrow_datasets, sql_dialect))?
.map_err(graph_error_to_pyerr)?;

Ok(sql)
Expand Down
2 changes: 2 additions & 0 deletions crates/lance-graph/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
pub mod parser;
pub mod query;
pub mod semantic;
pub mod spark_dialect;
pub mod sql_catalog;
pub mod sql_query;
pub mod table_readers;
Expand All @@ -65,9 +66,10 @@
DataSourceFormat, SchemaInfo, TableInfo, TableReader, TableType,
};
#[cfg(feature = "unity-catalog")]
pub use lance_graph_catalog::{UnityCatalogConfig, UnityCatalogProvider};

Check warning on line 69 in crates/lance-graph/src/lib.rs

View workflow job for this annotation

GitHub Actions / format

Diff in /home/runner/work/lance-graph/lance-graph/crates/lance-graph/src/lib.rs
pub use lance_vector_search::VectorSearch;
pub use query::{CypherQuery, ExecutionStrategy};
pub use spark_dialect::{SqlDialect, SparkDialect};
pub use sql_query::SqlQuery;
#[cfg(feature = "delta")]
pub use table_readers::DeltaTableReader;
Expand Down
29 changes: 19 additions & 10 deletions crates/lance-graph/src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,10 +280,10 @@
self.explain_internal(Arc::new(catalog), ctx).await
}

/// Convert the Cypher query to a DataFusion SQL string
/// Convert the Cypher query to a SQL string in the specified dialect.
///
/// This method generates a SQL string that corresponds to the DataFusion logical plan
/// derived from the Cypher query. It uses the `datafusion-sql` unparser.
/// derived from the Cypher query, using the specified SQL dialect for unparsing.
///
/// **WARNING**: This method is experimental and the generated SQL dialect may change.
///
Expand All @@ -293,16 +293,20 @@
///
/// # Arguments
/// * `datasets` - HashMap of table name to RecordBatch (nodes and relationships)
/// * `dialect` - The SQL dialect to use for generating the output SQL.
/// Defaults to `SqlDialect::Default` (generic DataFusion SQL).
/// Use `SqlDialect::Spark` for Spark SQL, `SqlDialect::PostgreSql`, etc.
///
/// # Returns
/// A SQL string representing the query
/// A SQL string representing the query in the specified dialect
pub async fn to_sql(
&self,
datasets: HashMap<String, arrow::record_batch::RecordBatch>,
dialect: Option<crate::spark_dialect::SqlDialect>,
) -> Result<String> {
use datafusion_sql::unparser::plan_to_sql;
use std::sync::Arc;

let dialect = dialect.unwrap_or_default();
let _config = self.require_config()?;

// Build catalog and context from datasets using the helper
Expand All @@ -323,11 +327,16 @@
location: snafu::Location::new(file!(), line!(), column!()),
})?;

// Unparse to SQL
let sql_ast = plan_to_sql(&optimized_plan).map_err(|e| GraphError::PlanError {
message: format!("Failed to unparse plan to SQL: {}", e),
location: snafu::Location::new(file!(), line!(), column!()),
})?;
// Unparse to SQL using the specified dialect

Check warning on line 330 in crates/lance-graph/src/query.rs

View workflow job for this annotation

GitHub Actions / format

Diff in /home/runner/work/lance-graph/lance-graph/crates/lance-graph/src/query.rs
let dialect_unparser = dialect.unparser();
let unparser = dialect_unparser.as_unparser();
let sql_ast =
unparser
.plan_to_sql(&optimized_plan)
.map_err(|e| GraphError::PlanError {
message: format!("Failed to unparse plan to SQL: {}", e),
location: snafu::Location::new(file!(), line!(), column!()),
})?;

Ok(sql_ast.to_string())
}
Expand Down Expand Up @@ -1852,7 +1861,7 @@
.unwrap()
.with_config(cfg);

let sql = query.to_sql(datasets).await.unwrap();
let sql = query.to_sql(datasets, None).await.unwrap();
println!("Generated SQL: {}", sql);

assert!(sql.contains("SELECT"));
Expand Down
218 changes: 218 additions & 0 deletions crates/lance-graph/src/spark_dialect.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
// SPDX-License-Identifier: Apache-2.0
Comment thread
yuchen-pipi marked this conversation as resolved.
// SPDX-FileCopyrightText: Copyright The Lance Authors

//! SQL dialect support for the DataFusion unparser.
//!
//! This module provides a [`SqlDialect`] enum for selecting which SQL dialect
//! to use when unparsing DataFusion logical plans to SQL strings, and includes
//! a [`SparkDialect`] implementation for Spark SQL.
//!
//! Key Spark SQL differences from standard SQL:
//! - Backtick (`` ` ``) identifier quoting
//! - `EXTRACT(field FROM expr)` for date field extraction
//! - `STRING` type for casting (not `VARCHAR`)
//! - `BIGINT`/`INT` for integer types
//! - `TIMESTAMP` for all timestamp types (no timezone info in cast)
//! - `LENGTH()` instead of `CHARACTER_LENGTH()`
//! - Subqueries in FROM require aliases

use std::sync::Arc;

use arrow::datatypes::TimeUnit;

Check warning on line 21 in crates/lance-graph/src/spark_dialect.rs

View workflow job for this annotation

GitHub Actions / format

Diff in /home/runner/work/lance-graph/lance-graph/crates/lance-graph/src/spark_dialect.rs
use datafusion_common::Result;
use datafusion_expr::Expr;
use datafusion_sql::unparser::dialect::{
CharacterLengthStyle, DateFieldExtractStyle, DefaultDialect, Dialect, IntervalStyle,
MySqlDialect, PostgreSqlDialect, SqliteDialect,
};

Check warning on line 27 in crates/lance-graph/src/spark_dialect.rs

View workflow job for this annotation

GitHub Actions / format

Diff in /home/runner/work/lance-graph/lance-graph/crates/lance-graph/src/spark_dialect.rs
use datafusion_sql::unparser::Unparser;
use datafusion_sql::sqlparser::ast::{self, Ident, ObjectName, TimezoneInfo};

/// SQL dialect to use when generating SQL from Cypher queries.
///
/// The enum is a plain, copyable selector; [`SqlDialect::Default`] is the
/// value produced by `SqlDialect::default()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SqlDialect {
    /// Generic SQL (DataFusion default dialect)
    #[default]
    Default,
    /// Spark SQL dialect (backtick quoting, STRING type, EXTRACT, etc.)
    Spark,
    /// PostgreSQL dialect
    PostgreSql,
    /// MySQL dialect
    MySql,
    /// SQLite dialect
    Sqlite,
}

impl SqlDialect {
    /// Build the [`DialectUnparser`] holding the concrete DataFusion dialect
    /// that corresponds to this variant.
    ///
    /// Call [`DialectUnparser::as_unparser`] on the returned value to obtain
    /// the `Unparser` that borrows the dialect.
    pub fn unparser(&self) -> DialectUnparser {
        match *self {
            Self::Default => DialectUnparser::Default(DefaultDialect {}),
            Self::Spark => DialectUnparser::Spark(SparkDialect),
            Self::PostgreSql => DialectUnparser::PostgreSql(PostgreSqlDialect {}),
            Self::MySql => DialectUnparser::MySql(MySqlDialect {}),
            Self::Sqlite => DialectUnparser::Sqlite(SqliteDialect {}),
        }
    }
}

/// Wrapper to hold the concrete dialect type and provide an `Unparser` reference.
///
/// Each variant owns the dialect value for one [`SqlDialect`] choice;
/// [`DialectUnparser::as_unparser`] borrows that value to build the `Unparser`.
pub enum DialectUnparser {
    /// Holds DataFusion's `DefaultDialect`.
    Default(DefaultDialect),
    /// Holds this crate's [`SparkDialect`].
    Spark(SparkDialect),
    /// Holds DataFusion's `PostgreSqlDialect`.
    PostgreSql(PostgreSqlDialect),
    /// Holds DataFusion's `MySqlDialect`.
    MySql(MySqlDialect),
    /// Holds DataFusion's `SqliteDialect`.
    Sqlite(SqliteDialect),
}

impl DialectUnparser {
pub fn as_unparser(&self) -> Unparser<'_> {
match self {
DialectUnparser::Default(d) => Unparser::new(d),
DialectUnparser::Spark(d) => Unparser::new(d),
DialectUnparser::PostgreSql(d) => Unparser::new(d),
DialectUnparser::MySql(d) => Unparser::new(d),
DialectUnparser::Sqlite(d) => Unparser::new(d),
}
}
}

/// A Spark SQL dialect for unparsing DataFusion logical plans to Spark-compatible SQL.
///
/// This is a field-less unit type, so it derives the common value traits and
/// can be copied, compared, debug-printed and default-constructed freely.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct SparkDialect;

impl Dialect for SparkDialect {
fn identifier_quote_style(&self, _identifier: &str) -> Option<char> {
Some('`')
}

fn supports_nulls_first_in_sort(&self) -> bool {
true
}

fn use_timestamp_for_date64(&self) -> bool {
true
}

fn interval_style(&self) -> IntervalStyle {
IntervalStyle::SQLStandard
}

fn float64_ast_dtype(&self) -> ast::DataType {
ast::DataType::Double(ast::ExactNumberInfo::None)
}

Check warning on line 103 in crates/lance-graph/src/spark_dialect.rs

View workflow job for this annotation

GitHub Actions / format

Diff in /home/runner/work/lance-graph/lance-graph/crates/lance-graph/src/spark_dialect.rs

fn utf8_cast_dtype(&self) -> ast::DataType {
ast::DataType::Custom(
ObjectName::from(vec![Ident::new("STRING")]),
vec![],
)
}

fn large_utf8_cast_dtype(&self) -> ast::DataType {
ast::DataType::Custom(

Check warning on line 113 in crates/lance-graph/src/spark_dialect.rs

View workflow job for this annotation

GitHub Actions / format

Diff in /home/runner/work/lance-graph/lance-graph/crates/lance-graph/src/spark_dialect.rs
ObjectName::from(vec![Ident::new("STRING")]),
vec![],
)
}

fn date_field_extract_style(&self) -> DateFieldExtractStyle {
DateFieldExtractStyle::Extract
}

fn character_length_style(&self) -> CharacterLengthStyle {
CharacterLengthStyle::Length
}

fn int64_cast_dtype(&self) -> ast::DataType {
ast::DataType::BigInt(None)
}

fn int32_cast_dtype(&self) -> ast::DataType {
ast::DataType::Int(None)

Check warning on line 132 in crates/lance-graph/src/spark_dialect.rs

View workflow job for this annotation

GitHub Actions / format

Diff in /home/runner/work/lance-graph/lance-graph/crates/lance-graph/src/spark_dialect.rs
}

fn timestamp_cast_dtype(
&self,
_time_unit: &TimeUnit,
_tz: &Option<Arc<str>>,
) -> ast::DataType {
ast::DataType::Timestamp(None, TimezoneInfo::None)
}

fn date32_cast_dtype(&self) -> ast::DataType {
ast::DataType::Date
}

fn supports_column_alias_in_table_alias(&self) -> bool {
true
}

fn requires_derived_table_alias(&self) -> bool {
true
}

fn full_qualified_col(&self) -> bool {
false
}

fn unnest_as_table_factor(&self) -> bool {
false
}

fn scalar_function_to_sql_overrides(
&self,
_unparser: &Unparser,
_func_name: &str,
_args: &[Expr],
) -> Result<Option<ast::Expr>> {
// character_length -> length is handled by CharacterLengthStyle::Length
// Additional Spark-specific function mappings can be added here as needed
Ok(None)
}
}

#[cfg(test)]
mod tests {
    //! Unit tests pinning the Spark-specific choices made by [`SparkDialect`].

    use super::*;

    /// Identifiers are always quoted with backticks.
    #[test]
    fn test_spark_dialect_identifier_quoting() {
        let dialect = SparkDialect;
        assert_eq!(dialect.identifier_quote_style("table_name"), Some('`'));
        assert_eq!(dialect.identifier_quote_style("column"), Some('`'));
    }

    /// Cast target types use the Spark spellings.
    #[test]
    fn test_spark_dialect_type_mappings() {
        let dialect = SparkDialect;
        assert!(matches!(
            dialect.utf8_cast_dtype(),
            ast::DataType::Custom(..)
        ));
        assert!(matches!(
            dialect.int64_cast_dtype(),
            ast::DataType::BigInt(None)
        ));
        assert!(matches!(
            dialect.int32_cast_dtype(),
            ast::DataType::Int(None)
        ));
        assert!(matches!(dialect.date32_cast_dtype(), ast::DataType::Date));
    }

    /// Derived tables must carry an alias.
    #[test]
    fn test_spark_dialect_requires_derived_table_alias() {
        let dialect = SparkDialect;
        assert!(dialect.requires_derived_table_alias());
    }

    /// Date fields are extracted with `EXTRACT`.
    #[test]
    fn test_spark_dialect_extract_style() {
        let dialect = SparkDialect;
        assert!(matches!(
            dialect.date_field_extract_style(),
            DateFieldExtractStyle::Extract
        ));
    }

    /// Character length is rendered with the `Length` style.
    #[test]
    fn test_spark_dialect_character_length_style() {
        let dialect = SparkDialect;
        assert!(matches!(
            dialect.character_length_style(),
            CharacterLengthStyle::Length
        ));
    }
}
Loading
Loading