Skip to content

[SPARK-52444][SQL][CONNECT] Add support for Variant/Char/Varchar Literal #51215

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 69 additions & 63 deletions python/pyspark/sql/connect/proto/expressions_pb2.py

Large diffs are not rendered by default.

116 changes: 116 additions & 0 deletions python/pyspark/sql/connect/proto/expressions_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,95 @@ class Expression(google.protobuf.message.Message):
| None
): ...

class Variant(google.protobuf.message.Message):
    """Binary representation of a semi-structured value (Spark VariantVal).
    The format follows Spark's internal VariantVal encoding:
    - See org.apache.spark.unsafe.types.VariantVal for details.
    """

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    VALUE_FIELD_NUMBER: builtins.int
    METADATA_FIELD_NUMBER: builtins.int
    value: builtins.bytes
    """Encodes the value's type and data (without field names)."""
    metadata: builtins.bytes
    """Metadata contains version identifier and field name information."""
    # NOTE: both fields are plain (non-optional) proto3 scalars, which is why —
    # unlike Char/Varchar below — no HasField/WhichOneof overloads are generated.
    def __init__(
        self,
        *,
        value: builtins.bytes = ...,
        metadata: builtins.bytes = ...,
    ) -> None: ...
    def ClearField(
        self,
        field_name: typing_extensions.Literal["metadata", b"metadata", "value", b"value"],
    ) -> None: ...

class Char(google.protobuf.message.Message):
    """Literal for a fixed-length CHAR string value."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    VALUE_FIELD_NUMBER: builtins.int
    LENGTH_FIELD_NUMBER: builtins.int
    value: builtins.str
    length: builtins.int
    """The fixed length for this Char type.
    - If omitted, uses the actual length of `value`.
    - If provided but smaller than `value`'s length, an error will be thrown.
    - If provided and larger than `value`'s length, the `value` will be right-padded with spaces.
    """
    def __init__(
        self,
        *,
        value: builtins.str = ...,
        length: builtins.int | None = ...,
    ) -> None: ...
    # `_length` is the synthetic oneof the protobuf generator emits for the
    # proto3 `optional int32 length` field; HasField/WhichOneof distinguish
    # "length not set" from "length == 0".
    def HasField(
        self,
        field_name: typing_extensions.Literal["_length", b"_length", "length", b"length"],
    ) -> builtins.bool: ...
    def ClearField(
        self,
        field_name: typing_extensions.Literal[
            "_length", b"_length", "length", b"length", "value", b"value"
        ],
    ) -> None: ...
    def WhichOneof(
        self, oneof_group: typing_extensions.Literal["_length", b"_length"]
    ) -> typing_extensions.Literal["length"] | None: ...

class Varchar(google.protobuf.message.Message):
    """Literal for a variable-length VARCHAR string value (never padded)."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    VALUE_FIELD_NUMBER: builtins.int
    LENGTH_FIELD_NUMBER: builtins.int
    value: builtins.str
    length: builtins.int
    """Specifies the maximum length for this Varchar type.
    - If omitted, uses the actual length of `value`.
    - If provided but smaller than `value`'s length, an error will be thrown.
    - If provided and larger than `value`'s length, stores exact value without padding.
    """
    def __init__(
        self,
        *,
        value: builtins.str = ...,
        length: builtins.int | None = ...,
    ) -> None: ...
    # `_length` is the synthetic oneof the protobuf generator emits for the
    # proto3 `optional int32 length` field; HasField/WhichOneof distinguish
    # "length not set" from "length == 0".
    def HasField(
        self,
        field_name: typing_extensions.Literal["_length", b"_length", "length", b"length"],
    ) -> builtins.bool: ...
    def ClearField(
        self,
        field_name: typing_extensions.Literal[
            "_length", b"_length", "length", b"length", "value", b"value"
        ],
    ) -> None: ...
    def WhichOneof(
        self, oneof_group: typing_extensions.Literal["_length", b"_length"]
    ) -> typing_extensions.Literal["length"] | None: ...

NULL_FIELD_NUMBER: builtins.int
BINARY_FIELD_NUMBER: builtins.int
BOOLEAN_FIELD_NUMBER: builtins.int
Expand All @@ -675,6 +764,9 @@ class Expression(google.protobuf.message.Message):
MAP_FIELD_NUMBER: builtins.int
STRUCT_FIELD_NUMBER: builtins.int
SPECIALIZED_ARRAY_FIELD_NUMBER: builtins.int
VARIANT_FIELD_NUMBER: builtins.int
CHAR_FIELD_NUMBER: builtins.int
VARCHAR_FIELD_NUMBER: builtins.int
@property
def null(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
binary: builtins.bytes
Expand Down Expand Up @@ -706,6 +798,12 @@ class Expression(google.protobuf.message.Message):
def struct(self) -> global___Expression.Literal.Struct: ...
@property
def specialized_array(self) -> global___Expression.Literal.SpecializedArray: ...
@property
def variant(self) -> global___Expression.Literal.Variant: ...
@property
def char(self) -> global___Expression.Literal.Char: ...
@property
def varchar(self) -> global___Expression.Literal.Varchar: ...
def __init__(
self,
*,
Expand All @@ -730,6 +828,9 @@ class Expression(google.protobuf.message.Message):
map: global___Expression.Literal.Map | None = ...,
struct: global___Expression.Literal.Struct | None = ...,
specialized_array: global___Expression.Literal.SpecializedArray | None = ...,
variant: global___Expression.Literal.Variant | None = ...,
char: global___Expression.Literal.Char | None = ...,
varchar: global___Expression.Literal.Varchar | None = ...,
) -> None: ...
def HasField(
self,
Expand All @@ -744,6 +845,8 @@ class Expression(google.protobuf.message.Message):
b"byte",
"calendar_interval",
b"calendar_interval",
"char",
b"char",
"date",
b"date",
"day_time_interval",
Expand Down Expand Up @@ -776,6 +879,10 @@ class Expression(google.protobuf.message.Message):
b"timestamp",
"timestamp_ntz",
b"timestamp_ntz",
"varchar",
b"varchar",
"variant",
b"variant",
"year_month_interval",
b"year_month_interval",
],
Expand All @@ -793,6 +900,8 @@ class Expression(google.protobuf.message.Message):
b"byte",
"calendar_interval",
b"calendar_interval",
"char",
b"char",
"date",
b"date",
"day_time_interval",
Expand Down Expand Up @@ -825,6 +934,10 @@ class Expression(google.protobuf.message.Message):
b"timestamp",
"timestamp_ntz",
b"timestamp_ntz",
"varchar",
b"varchar",
"variant",
b"variant",
"year_month_interval",
b"year_month_interval",
],
Expand Down Expand Up @@ -854,6 +967,9 @@ class Expression(google.protobuf.message.Message):
"map",
"struct",
"specialized_array",
"variant",
"char",
"varchar",
]
| None
): ...
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.protobuf.{functions => pbFn}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal}
import org.apache.spark.util.SparkFileUtils

// scalastyle:off
Expand Down Expand Up @@ -3319,7 +3319,8 @@ class PlanGenerationTestSuite
fn.lit(java.sql.Date.valueOf("2023-02-23")),
fn.lit(java.time.Duration.ofSeconds(200L)),
fn.lit(java.time.Period.ofDays(100)),
fn.lit(new CalendarInterval(2, 20, 100L)))
fn.lit(new CalendarInterval(2, 20, 100L)),
fn.lit(new VariantVal(Array[Byte](1), Array[Byte](1))))
}

test("function lit array") {
Expand Down Expand Up @@ -3390,6 +3391,7 @@ class PlanGenerationTestSuite
fn.typedLit(java.time.Duration.ofSeconds(200L)),
fn.typedLit(java.time.Period.ofDays(100)),
fn.typedLit(new CalendarInterval(2, 20, 100L)),
fn.typedLit(new VariantVal(Array[Byte](1), Array[Byte](1))),

// Handle parameterized scala types e.g.: List, Seq and Map.
fn.typedLit(Some(1)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,9 @@ message Expression {
Struct struct = 24;

SpecializedArray specialized_array = 25;
Variant variant = 26;
Char char = 27;
Varchar varchar = 28;
}

message Decimal {
Expand Down Expand Up @@ -240,6 +243,35 @@ message Expression {
Strings strings = 6;
}
}

// Binary representation of a semi-structured value (Spark VariantVal).
// The format follows Spark's internal VariantVal encoding:
// - See org.apache.spark.unsafe.types.VariantVal for details.
message Variant {
  // Encodes the value's type and data (without field names).
  // The byte layout is Spark's Variant binary encoding; see
  // common/variant/README.md in the Spark repository and
  // org.apache.spark.unsafe.types.VariantVal for the format specification.
  bytes value = 1;

  // Metadata contains a version identifier and field name information,
  // in the same Variant binary encoding.
  bytes metadata = 2;
}

message Char {
  string value = 1;
  // The fixed length for this Char type. Only needed when the intended
  // CHAR(n) type length differs from the actual length of `value`:
  // - If omitted, uses the actual length of `value`.
  // - If provided but smaller than `value`'s length, an error will be thrown.
  // - If provided and larger than `value`'s length, the `value` will be
  //   right-padded with spaces.
  optional int32 length = 2;
}

message Varchar {
  string value = 1;
  // Specifies the maximum length for this Varchar type.
  // - If omitted, uses the actual length of `value`.
  // - If provided but smaller than `value`'s length, an error will be thrown.
  // - If provided and larger than `value`'s length, stores exact value without padding
  //   (unlike Char, the value is never right-padded with spaces).
  optional int32 length = 2;
}
}

// An unresolved attribute that is not explicitly bound to a specific column, but the column
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils, SparkIntervalUtils}
import org.apache.spark.sql.connect.common.DataTypeProtoConverter._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal}
import org.apache.spark.util.SparkClassUtils

object LiteralValueProtoConverter {
Expand Down Expand Up @@ -99,6 +99,11 @@ object LiteralValueProtoConverter {
case v: Array[_] => builder.setArray(arrayBuilder(v))
case v: CalendarInterval =>
builder.setCalendarInterval(calendarIntervalBuilder(v.months, v.days, v.microseconds))
case v: VariantVal =>
builder.setVariant(
builder.getVariantBuilder
.setValue(ByteString.copyFrom(v.getValue))
.setMetadata(ByteString.copyFrom(v.getMetadata)))
case null => builder.setNull(ProtoDataTypes.NullType)
case _ => throw new UnsupportedOperationException(s"literal $literal not supported (yet).")
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [id#0L, id#0L, true AS true#0, 68 AS 68#0, 9872 AS 9872#0, -8726532 AS -8726532#0, 7834609328726532 AS 7834609328726532#0L, 2.718281828459045 AS 2.718281828459045#0, -0.8 AS -0.8#0, 89.97620 AS 89.97620#0, 89889.7667231 AS 89889.7667231#0, connect! AS connect!#0, T AS T#0, ABCDEFGHIJ AS ABCDEFGHIJ#0, 0x78797A7B7C7D7E7F808182838485868788898A8B8C8D8E AS X'78797A7B7C7D7E7F808182838485868788898A8B8C8D8E'#0, 0x0806 AS X'0806'#0, [8,6] AS ARRAY(8, 6)#0, null AS NULL#0, 2020-10-10 AS DATE '2020-10-10'#0, 8.997620 AS 8.997620#0, 2023-02-23 04:31:59.808 AS TIMESTAMP '2023-02-23 04:31:59.808'#0, 1969-12-31 16:00:12.345 AS TIMESTAMP '1969-12-31 16:00:12.345'#0, 2023-02-23 20:36:00 AS TIMESTAMP_NTZ '2023-02-23 20:36:00'#0, 2023-02-23 AS DATE '2023-02-23'#0, INTERVAL '0 00:03:20' DAY TO SECOND AS INTERVAL '0 00:03:20' DAY TO SECOND#0, ... 2 more fields]
Project [id#0L, id#0L, true AS true#0, 68 AS 68#0, 9872 AS 9872#0, -8726532 AS -8726532#0, 7834609328726532 AS 7834609328726532#0L, 2.718281828459045 AS 2.718281828459045#0, -0.8 AS -0.8#0, 89.97620 AS 89.97620#0, 89889.7667231 AS 89889.7667231#0, connect! AS connect!#0, T AS T#0, ABCDEFGHIJ AS ABCDEFGHIJ#0, 0x78797A7B7C7D7E7F808182838485868788898A8B8C8D8E AS X'78797A7B7C7D7E7F808182838485868788898A8B8C8D8E'#0, 0x0806 AS X'0806'#0, [8,6] AS ARRAY(8, 6)#0, null AS NULL#0, 2020-10-10 AS DATE '2020-10-10'#0, 8.997620 AS 8.997620#0, 2023-02-23 04:31:59.808 AS TIMESTAMP '2023-02-23 04:31:59.808'#0, 1969-12-31 16:00:12.345 AS TIMESTAMP '1969-12-31 16:00:12.345'#0, 2023-02-23 20:36:00 AS TIMESTAMP_NTZ '2023-02-23 20:36:00'#0, 2023-02-23 AS DATE '2023-02-23'#0, INTERVAL '0 00:03:20' DAY TO SECOND AS INTERVAL '0 00:03:20' DAY TO SECOND#0, ... 3 more fields]
+- LocalRelation <empty>, [id#0L, a#0, b#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [id#0L, id#0L, 1 AS 1#0, null AS NULL#0, true AS true#0, 68 AS 68#0, 9872 AS 9872#0, -8726532 AS -8726532#0, 7834609328726532 AS 7834609328726532#0L, 2.718281828459045 AS 2.718281828459045#0, -0.8 AS -0.8#0, 89.97620 AS 89.97620#0, 89889.7667231 AS 89889.7667231#0, connect! AS connect!#0, T AS T#0, ABCDEFGHIJ AS ABCDEFGHIJ#0, 0x78797A7B7C7D7E7F808182838485868788898A8B8C8D8E AS X'78797A7B7C7D7E7F808182838485868788898A8B8C8D8E'#0, 0x0806 AS X'0806'#0, [8,6] AS ARRAY(8, 6)#0, null AS NULL#0, 2020-10-10 AS DATE '2020-10-10'#0, 8.997620 AS 8.997620#0, 2023-02-23 04:31:59.808 AS TIMESTAMP '2023-02-23 04:31:59.808'#0, 1969-12-31 16:00:12.345 AS TIMESTAMP '1969-12-31 16:00:12.345'#0, 2023-02-23 20:36:00 AS TIMESTAMP_NTZ '2023-02-23 20:36:00'#0, ... 18 more fields]
Project [id#0L, id#0L, 1 AS 1#0, null AS NULL#0, true AS true#0, 68 AS 68#0, 9872 AS 9872#0, -8726532 AS -8726532#0, 7834609328726532 AS 7834609328726532#0L, 2.718281828459045 AS 2.718281828459045#0, -0.8 AS -0.8#0, 89.97620 AS 89.97620#0, 89889.7667231 AS 89889.7667231#0, connect! AS connect!#0, T AS T#0, ABCDEFGHIJ AS ABCDEFGHIJ#0, 0x78797A7B7C7D7E7F808182838485868788898A8B8C8D8E AS X'78797A7B7C7D7E7F808182838485868788898A8B8C8D8E'#0, 0x0806 AS X'0806'#0, [8,6] AS ARRAY(8, 6)#0, null AS NULL#0, 2020-10-10 AS DATE '2020-10-10'#0, 8.997620 AS 8.997620#0, 2023-02-23 04:31:59.808 AS TIMESTAMP '2023-02-23 04:31:59.808'#0, 1969-12-31 16:00:12.345 AS TIMESTAMP '1969-12-31 16:00:12.345'#0, 2023-02-23 20:36:00 AS TIMESTAMP_NTZ '2023-02-23 20:36:00'#0, ... 19 more fields]
+- LocalRelation <empty>, [id#0L, a#0, b#0]
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,30 @@
}
}
}
}, {
"literal": {
"variant": {
"value": "AQ==",
"metadata": "AQ=="
}
},
"common": {
"origin": {
"jvmOrigin": {
"stackTrace": [{
"classLoaderName": "app",
"declaringClass": "org.apache.spark.sql.functions$",
"methodName": "lit",
"fileName": "functions.scala"
}, {
"classLoaderName": "app",
"declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
"methodName": "~~trimmed~anonfun~~",
"fileName": "PlanGenerationTestSuite.scala"
}]
}
}
}
}]
}
}
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,30 @@
}
}
}
}, {
"literal": {
"variant": {
"value": "AQ==",
"metadata": "AQ=="
}
},
"common": {
"origin": {
"jvmOrigin": {
"stackTrace": [{
"classLoaderName": "app",
"declaringClass": "org.apache.spark.sql.functions$",
"methodName": "typedLit",
"fileName": "functions.scala"
}, {
"classLoaderName": "app",
"declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
"methodName": "~~trimmed~anonfun~~",
"fileName": "PlanGenerationTestSuite.scala"
}]
}
}
}
}, {
"literal": {
"integer": 1
Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters}
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, LiteralValueProtoConverter}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}

object LiteralExpressionProtoConverter {

Expand Down Expand Up @@ -75,6 +75,27 @@ object LiteralExpressionProtoConverter {
case proto.Expression.Literal.LiteralTypeCase.STRING =>
expressions.Literal(UTF8String.fromString(lit.getString), StringType)

case proto.Expression.Literal.LiteralTypeCase.CHAR =>
var length = lit.getChar.getValue.length
if (lit.getChar.hasLength) {
length = lit.getChar.getLength
}
expressions.Literal(UTF8String.fromString(lit.getChar.getValue), CharType(length))

case proto.Expression.Literal.LiteralTypeCase.VARCHAR =>
var length = lit.getVarchar.getValue.length
if (lit.getVarchar.hasLength) {
length = lit.getVarchar.getLength
}
expressions.Literal(UTF8String.fromString(lit.getVarchar.getValue), VarcharType(length))

case proto.Expression.Literal.LiteralTypeCase.VARIANT =>
expressions.Literal(
new VariantVal(
lit.getVariant.getValue.toByteArray,
lit.getVariant.getMetadata.toByteArray),
VariantType)

case proto.Expression.Literal.LiteralTypeCase.DATE =>
expressions.Literal(lit.getDate, DateType)

Expand Down
Loading