Skip to content

Commit bf88d9b

Browse files
fix(pinot): restrict types in dialect (#35337)
1 parent d51b35f commit bf88d9b

File tree

2 files changed

+129
-0
lines changed

2 files changed

+129
-0
lines changed

superset/sql/dialects/pinot.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424

2525
from __future__ import annotations
2626

27+
from sqlglot import exp
2728
from sqlglot.dialects.mysql import MySQL
29+
from sqlglot.tokens import TokenType
2830

2931

3032
class Pinot(MySQL):
@@ -41,3 +43,55 @@ class Tokenizer(MySQL.Tokenizer):
4143
QUOTES = ["'"] # Only single quotes for strings
4244
IDENTIFIERS = ['"', "`"] # Backticks and double quotes for identifiers
4345
STRING_ESCAPES = ["'", "\\"] # Remove double quote from string escapes
46+
KEYWORDS = {
47+
**MySQL.Tokenizer.KEYWORDS,
48+
"STRING": TokenType.TEXT,
49+
"LONG": TokenType.BIGINT,
50+
"BYTES": TokenType.VARBINARY,
51+
}
52+
53+
class Generator(MySQL.Generator):
    """
    SQL generator for the Pinot dialect.

    Extends the MySQL generator, but renders data types through
    ``TYPE_MAPPING`` so that casts come out with the type names listed
    there (``INT``, ``LONG``, ``STRING``, ``BYTES``, ...).
    """

    # Type names to emit for each sqlglot data type. Anything not overridden
    # here falls back to the entry inherited from the MySQL generator.
    TYPE_MAPPING = {
        **MySQL.Generator.TYPE_MAPPING,
        exp.DataType.Type.TINYINT: "INT",
        exp.DataType.Type.SMALLINT: "INT",
        exp.DataType.Type.INT: "INT",
        exp.DataType.Type.BIGINT: "LONG",
        exp.DataType.Type.FLOAT: "FLOAT",
        exp.DataType.Type.DOUBLE: "DOUBLE",
        exp.DataType.Type.BOOLEAN: "BOOLEAN",
        exp.DataType.Type.TIMESTAMP: "TIMESTAMP",
        exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
        exp.DataType.Type.VARCHAR: "STRING",
        exp.DataType.Type.CHAR: "STRING",
        exp.DataType.Type.TEXT: "STRING",
        exp.DataType.Type.BINARY: "BYTES",
        exp.DataType.Type.VARBINARY: "BYTES",
        exp.DataType.Type.JSON: "JSON",
    }

    # Override MySQL's CAST_MAPPING - don't convert integer or string types
    CAST_MAPPING = {
        exp.DataType.Type.LONGBLOB: exp.DataType.Type.VARBINARY,
        exp.DataType.Type.MEDIUMBLOB: exp.DataType.Type.VARBINARY,
        exp.DataType.Type.TINYBLOB: exp.DataType.Type.VARBINARY,
        exp.DataType.Type.UBIGINT: "UNSIGNED",
    }

    def datatype_sql(self, expression: exp.DataType) -> str:
        """
        Render a data type using only ``TYPE_MAPPING``.

        MySQL's VARCHAR size requirement logic is deliberately bypassed;
        every type is looked up in ``TYPE_MAPPING`` and, when absent, the
        enum member's own value is used.

        :param expression: the data type node to render
        :returns: the type name, followed by its parenthesised arguments
            (if any) and an ``UNSIGNED`` suffix for types listed in
            ``UNSIGNED_TYPE_MAPPING``
        """
        type_value = expression.this
        type_sql = (
            self.TYPE_MAPPING.get(type_value, type_value.value)
            if isinstance(type_value, exp.DataType.Type)
            else type_value
        )

        interior = self.expressions(expression, flat=True)
        nested = f"({interior})" if interior else ""

        # UNSIGNED is an attribute of the complete type, so it must follow
        # the type arguments: ``DECIMAL(10, 2) UNSIGNED`` rather than
        # ``DECIMAL UNSIGNED(10, 2)``.
        if expression.this in self.UNSIGNED_TYPE_MAPPING:
            return f"{type_sql}{nested} UNSIGNED"

        return f"{type_sql}{nested}"

tests/unit_tests/sql/dialects/pinot_tests.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,3 +346,78 @@ def test_distinct() -> None:
346346
FROM "products"
347347
""".strip()
348348
)
349+
350+
351+
def test_cast_to_string() -> None:
    """
    A CAST to STRING must be emitted as STRING and never rewritten to CHAR.
    """
    parsed = sqlglot.parse_one(
        "SELECT CAST(cohort_size AS STRING) FROM table",
        Pinot,
    )
    rendered = Pinot().generate(expression=parsed)

    assert "STRING" in rendered
    assert "CHAR" not in rendered
361+
362+
363+
def test_concat_with_cast_string() -> None:
    """
    CONCAT wrapping a CAST to STRING - regression check for the original issue.
    """
    query = """
SELECT concat(a, cast(b AS string), ' - ')
FROM "default".c"""
    rendered = Pinot().generate(expression=sqlglot.parse_one(query, Pinot))

    # The STRING type must survive in any letter case and must not be
    # rewritten to CHAR.
    assert "string" in rendered.lower()
    assert "CHAR" not in rendered
376+
377+
378+
@pytest.mark.parametrize(
    "cast_type, expected_type",
    [
        ("INT", "INT"),
        ("TINYINT", "INT"),
        ("SMALLINT", "INT"),
        ("BIGINT", "LONG"),
        ("LONG", "LONG"),
        ("FLOAT", "FLOAT"),
        ("DOUBLE", "DOUBLE"),
        ("BOOLEAN", "BOOLEAN"),
        ("TIMESTAMP", "TIMESTAMP"),
        ("STRING", "STRING"),
        ("VARCHAR", "STRING"),
        ("CHAR", "STRING"),
        ("TEXT", "STRING"),
        ("BYTES", "BYTES"),
        ("BINARY", "BYTES"),
        ("VARBINARY", "BYTES"),
        ("JSON", "JSON"),
    ],
)
def test_type_mappings(cast_type: str, expected_type: str) -> None:
    """
    Each basic type used in a CAST must be rendered with its Pinot type name.

    The check is a substring match on the generated SQL, so it only proves
    that the expected name appears somewhere in the output.
    """
    query = f"SELECT CAST(col AS {cast_type}) FROM table"  # noqa: S608
    parsed = sqlglot.parse_one(query, Pinot)
    rendered = Pinot().generate(expression=parsed)

    assert expected_type in rendered
409+
410+
411+
def test_unsigned_type() -> None:
    """
    Unsigned integer types must keep their UNSIGNED marker.

    Exercises the UNSIGNED_TYPE_MAPPING branch of ``datatype_sql``.
    """
    from sqlglot import exp

    # UBIGINT is one of the keys of UNSIGNED_TYPE_MAPPING
    unsigned_bigint = exp.DataType(this=exp.DataType.Type.UBIGINT)
    rendered = Pinot.Generator().datatype_sql(unsigned_bigint)

    assert "UNSIGNED" in rendered
    assert "BIGINT" in rendered

0 commit comments

Comments
 (0)