[Bug Fix] Allow HiveCatalog to create table with TimestamptzType (#585)
HonahX authored Apr 8, 2024
1 parent 1016b19 commit 07442cc
Showing 7 changed files with 216 additions and 31 deletions.
9 changes: 9 additions & 0 deletions mkdocs/docs/configuration.md
@@ -232,6 +232,15 @@ catalog:
s3.secret-access-key: password
```

+When using Hive 2.x, make sure to set the compatibility flag:
+
+```yaml
+catalog:
+  default:
+    ...
+    hive.hive2-compatible: true
+```
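
The same flag can also be set programmatically when loading the catalog. A minimal sketch, assuming a metastore at a placeholder thrift URI (pyiceberg infers the Hive catalog type from the `thrift://` scheme):

```python
from pyiceberg.catalog import load_catalog

# The URI is a placeholder; "hive.hive2-compatible" is the flag introduced here.
catalog = load_catalog(
    "default",
    **{
        "uri": "thrift://localhost:9083",
        "hive.hive2-compatible": "true",
    },
)
```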

## Glue Catalog

Your AWS credentials can be passed directly through the Python API.
5 changes: 3 additions & 2 deletions pyiceberg/catalog/glue.py
@@ -65,6 +65,7 @@
from pyiceberg.table import (
CommitTableRequest,
CommitTableResponse,
+PropertyUtil,
Table,
update_table_metadata,
)
@@ -162,7 +163,7 @@ def primitive(self, primitive: PrimitiveType) -> str:
if isinstance(primitive, DecimalType):
return f"decimal({primitive.precision},{primitive.scale})"
if (primitive_type := type(primitive)) not in GLUE_PRIMITIVE_TYPES:
-return str(primitive_type.root)
+return str(primitive)
return GLUE_PRIMITIVE_TYPES[primitive_type]


@@ -344,7 +345,7 @@ def _update_glue_table(self, database_name: str, table_name: str, table_input: T
self.glue.update_table(
DatabaseName=database_name,
TableInput=table_input,
-SkipArchive=self.properties.get(GLUE_SKIP_ARCHIVE, GLUE_SKIP_ARCHIVE_DEFAULT),
+SkipArchive=PropertyUtil.property_as_bool(self.properties, GLUE_SKIP_ARCHIVE, GLUE_SKIP_ARCHIVE_DEFAULT),
VersionId=version_id,
)
except self.glue.exceptions.EntityNotFoundException as e:
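The switch to `PropertyUtil.property_as_bool` matters because catalog properties are plain strings, while boto3's `SkipArchive` parameter expects a boolean: a raw `properties.get(...)` hands the string straight through, and a non-empty string such as `"false"` is truthy if coerced. A minimal sketch of the difference, assuming the property key is `glue.skip-archive`:

```python
from pyiceberg.table import PropertyUtil

properties = {"glue.skip-archive": "false"}

# The raw lookup returns the string "false", not a boolean,
# which is not what boto3 expects for SkipArchive.
raw = properties.get("glue.skip-archive", True)

# property_as_bool parses the string, falling back to the default when unset.
parsed = PropertyUtil.property_as_bool(properties, "glue.skip-archive", True)

assert raw == "false" and parsed is False
```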
45 changes: 25 additions & 20 deletions pyiceberg/catalog/hive.py
@@ -74,7 +74,7 @@
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec
from pyiceberg.schema import Schema, SchemaVisitor, visit
from pyiceberg.serializers import FromInputFile
-from pyiceberg.table import CommitTableRequest, CommitTableResponse, Table, TableProperties, update_table_metadata
+from pyiceberg.table import CommitTableRequest, CommitTableResponse, PropertyUtil, Table, TableProperties, update_table_metadata
from pyiceberg.table.metadata import new_table_metadata
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
from pyiceberg.typedef import EMPTY_DICT, Identifier, Properties
@@ -95,6 +95,7 @@
StringType,
StructType,
TimestampType,
+TimestamptzType,
TimeType,
UUIDType,
)
@@ -103,25 +104,13 @@
import pyarrow as pa


-# Replace by visitor
-hive_types = {
-    BooleanType: "boolean",
-    IntegerType: "int",
-    LongType: "bigint",
-    FloatType: "float",
-    DoubleType: "double",
-    DateType: "date",
-    TimeType: "string",
-    TimestampType: "timestamp",
-    StringType: "string",
-    UUIDType: "string",
-    BinaryType: "binary",
-    FixedType: "binary",
-}

COMMENT = "comment"
OWNER = "owner"

+# If set to true, HiveCatalog will operate in Hive2 compatibility mode
+HIVE2_COMPATIBLE = "hive.hive2-compatible"
+HIVE2_COMPATIBLE_DEFAULT = False


class _HiveClient:
"""Helper class to nicely open and close the transport."""
@@ -151,10 +140,15 @@ def __exit__(
self._transport.close()


-def _construct_hive_storage_descriptor(schema: Schema, location: Optional[str]) -> StorageDescriptor:
+def _construct_hive_storage_descriptor(
+    schema: Schema, location: Optional[str], hive2_compatible: bool = False
+) -> StorageDescriptor:
ser_de_info = SerDeInfo(serializationLib="org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")
return StorageDescriptor(
-[FieldSchema(field.name, visit(field.field_type, SchemaToHiveConverter()), field.doc) for field in schema.fields],
+[
+    FieldSchema(field.name, visit(field.field_type, SchemaToHiveConverter(hive2_compatible)), field.doc)
+    for field in schema.fields
+],
location,
"org.apache.hadoop.mapred.FileInputFormat",
"org.apache.hadoop.mapred.FileOutputFormat",
@@ -199,6 +193,7 @@ def _annotate_namespace(database: HiveDatabase, properties: Properties) -> HiveD
DateType: "date",
TimeType: "string",
TimestampType: "timestamp",
+TimestamptzType: "timestamp with local time zone",
StringType: "string",
UUIDType: "string",
BinaryType: "binary",
@@ -207,6 +202,11 @@ def _annotate_namespace(database: HiveDatabase, properties: Properties) -> HiveD


class SchemaToHiveConverter(SchemaVisitor[str]):
+hive2_compatible: bool
+
+def __init__(self, hive2_compatible: bool):
+    self.hive2_compatible = hive2_compatible

def schema(self, schema: Schema, struct_result: str) -> str:
return struct_result

@@ -226,6 +226,9 @@ def map(self, map_type: MapType, key_result: str, value_result: str) -> str:
def primitive(self, primitive: PrimitiveType) -> str:
if isinstance(primitive, DecimalType):
return f"decimal({primitive.precision},{primitive.scale})"
+elif self.hive2_compatible and isinstance(primitive, TimestamptzType):
+    # Hive2 doesn't support timestamp with local time zone
+    return "timestamp"
else:
return HIVE_PRIMITIVE_TYPES[type(primitive)]
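
With the flag threaded through, the visitor renders `TimestamptzType` differently per mode. A minimal sketch, assuming the struct visitor renders `struct<name:type>` strings as in the test expectations further down:

```python
from pyiceberg.catalog.hive import SchemaToHiveConverter
from pyiceberg.schema import Schema, visit
from pyiceberg.types import NestedField, TimestamptzType

schema = Schema(NestedField(field_id=1, name="ts", field_type=TimestamptzType(), required=True))

# Default mode keeps the Hive3 type; Hive2 mode falls back to plain timestamp.
visit(schema, SchemaToHiveConverter(hive2_compatible=False))  # 'struct<ts:timestamp with local time zone>'
visit(schema, SchemaToHiveConverter(hive2_compatible=True))   # 'struct<ts:timestamp>'
```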

@@ -314,7 +317,9 @@ def create_table(
owner=properties[OWNER] if properties and OWNER in properties else getpass.getuser(),
createTime=current_time_millis // 1000,
lastAccessTime=current_time_millis // 1000,
-sd=_construct_hive_storage_descriptor(schema, location),
+sd=_construct_hive_storage_descriptor(
+    schema, location, PropertyUtil.property_as_bool(self.properties, HIVE2_COMPATIBLE, HIVE2_COMPATIBLE_DEFAULT)
+),
tableType=EXTERNAL_TABLE,
parameters=_construct_parameters(metadata_location),
)
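
With these changes, `HiveCatalog.create_table` no longer fails when the schema contains a `TimestamptzType` column. A minimal sketch against a placeholder metastore URI, assuming a reachable metastore and a resolvable warehouse location:

```python
from pyiceberg.catalog.hive import HiveCatalog
from pyiceberg.schema import Schema
from pyiceberg.types import NestedField, TimestamptzType

catalog = HiveCatalog("hive", uri="thrift://localhost:9083")
schema = Schema(NestedField(field_id=1, name="event_ts", field_type=TimestamptzType(), required=False))

# The column is registered as 'timestamp with local time zone' in the
# metastore, or as plain 'timestamp' when hive.hive2-compatible is true.
table = catalog.create_table(("default", "events"), schema=schema)
```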
6 changes: 6 additions & 0 deletions pyiceberg/table/__init__.py
@@ -251,6 +251,12 @@ def property_as_int(properties: Dict[str, str], property_name: str, default: Opt
else:
return default

+@staticmethod
+def property_as_bool(properties: Dict[str, str], property_name: str, default: bool) -> bool:
+    if value := properties.get(property_name):
+        return value.lower() == "true"
+    return default
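
The new helper parses the value case-insensitively and falls back to the default when the property is absent; note that any value other than a case-insensitive `"true"` parses as `False`. A short sketch:

```python
from pyiceberg.table import PropertyUtil

assert PropertyUtil.property_as_bool({"hive.hive2-compatible": "TRUE"}, "hive.hive2-compatible", False) is True
assert PropertyUtil.property_as_bool({"hive.hive2-compatible": "false"}, "hive.hive2-compatible", True) is False
assert PropertyUtil.property_as_bool({}, "hive.hive2-compatible", False) is False
```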


class Transaction:
_table: Table
88 changes: 79 additions & 9 deletions tests/catalog/test_hive.py
@@ -61,11 +61,24 @@
from pyiceberg.transforms import BucketTransform, IdentityTransform
from pyiceberg.typedef import UTF8
from pyiceberg.types import (
BinaryType,
BooleanType,
DateType,
DecimalType,
DoubleType,
FixedType,
FloatType,
IntegerType,
ListType,
LongType,
MapType,
NestedField,
StringType,
StructType,
TimestampType,
TimestamptzType,
TimeType,
UUIDType,
)

HIVE_CATALOG_NAME = "hive"
@@ -181,15 +194,20 @@ def test_check_number_of_namespaces(table_schema_simple: Schema) -> None:
catalog.create_table("table", schema=table_schema_simple)


+@pytest.mark.parametrize("hive2_compatible", [True, False])
@patch("time.time", MagicMock(return_value=12345))
-def test_create_table(table_schema_simple: Schema, hive_database: HiveDatabase, hive_table: HiveTable) -> None:
+def test_create_table(
+    table_schema_with_all_types: Schema, hive_database: HiveDatabase, hive_table: HiveTable, hive2_compatible: bool
+) -> None:
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
+if hive2_compatible:
+    catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL, **{"hive.hive2-compatible": "true"})

catalog._client = MagicMock()
catalog._client.__enter__().create_table.return_value = None
catalog._client.__enter__().get_table.return_value = hive_table
catalog._client.__enter__().get_database.return_value = hive_database
catalog.create_table(("default", "table"), schema=table_schema_simple, properties={"owner": "javaberg"})
catalog.create_table(("default", "table"), schema=table_schema_with_all_types, properties={"owner": "javaberg"})

called_hive_table: HiveTable = catalog._client.__enter__().create_table.call_args[0][0]
# This one is generated within the function itself, so we need to extract
@@ -207,9 +225,27 @@ def test_create_table(
retention=None,
sd=StorageDescriptor(
cols=[
FieldSchema(name="foo", type="string", comment=None),
FieldSchema(name="bar", type="int", comment=None),
FieldSchema(name="baz", type="boolean", comment=None),
FieldSchema(name='boolean', type='boolean', comment=None),
FieldSchema(name='integer', type='int', comment=None),
FieldSchema(name='long', type='bigint', comment=None),
FieldSchema(name='float', type='float', comment=None),
FieldSchema(name='double', type='double', comment=None),
FieldSchema(name='decimal', type='decimal(32,3)', comment=None),
FieldSchema(name='date', type='date', comment=None),
FieldSchema(name='time', type='string', comment=None),
FieldSchema(name='timestamp', type='timestamp', comment=None),
FieldSchema(
name='timestamptz',
type='timestamp' if hive2_compatible else 'timestamp with local time zone',
comment=None,
),
FieldSchema(name='string', type='string', comment=None),
FieldSchema(name='uuid', type='string', comment=None),
FieldSchema(name='fixed', type='binary', comment=None),
FieldSchema(name='binary', type='binary', comment=None),
FieldSchema(name='list', type='array<string>', comment=None),
FieldSchema(name='map', type='map<string,int>', comment=None),
FieldSchema(name='struct', type='struct<inner_string:string,inner_int:int>', comment=None),
],
location=f"{hive_database.locationUri}/table",
inputFormat="org.apache.hadoop.mapred.FileInputFormat",
@@ -266,12 +302,46 @@ def test_create_table(
location=metadata.location,
table_uuid=metadata.table_uuid,
last_updated_ms=metadata.last_updated_ms,
-last_column_id=3,
+last_column_id=22,
schemas=[
Schema(
NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
+NestedField(field_id=1, name='boolean', field_type=BooleanType(), required=True),
+NestedField(field_id=2, name='integer', field_type=IntegerType(), required=True),
+NestedField(field_id=3, name='long', field_type=LongType(), required=True),
+NestedField(field_id=4, name='float', field_type=FloatType(), required=True),
+NestedField(field_id=5, name='double', field_type=DoubleType(), required=True),
+NestedField(field_id=6, name='decimal', field_type=DecimalType(precision=32, scale=3), required=True),
+NestedField(field_id=7, name='date', field_type=DateType(), required=True),
+NestedField(field_id=8, name='time', field_type=TimeType(), required=True),
+NestedField(field_id=9, name='timestamp', field_type=TimestampType(), required=True),
+NestedField(field_id=10, name='timestamptz', field_type=TimestamptzType(), required=True),
+NestedField(field_id=11, name='string', field_type=StringType(), required=True),
+NestedField(field_id=12, name='uuid', field_type=UUIDType(), required=True),
+NestedField(field_id=13, name='fixed', field_type=FixedType(length=12), required=True),
+NestedField(field_id=14, name='binary', field_type=BinaryType(), required=True),
+NestedField(
+    field_id=15,
+    name='list',
+    field_type=ListType(type='list', element_id=18, element_type=StringType(), element_required=True),
+    required=True,
+),
+NestedField(
+    field_id=16,
+    name='map',
+    field_type=MapType(
+        type='map', key_id=19, key_type=StringType(), value_id=20, value_type=IntegerType(), value_required=True
+    ),
+    required=True,
+),
+NestedField(
+    field_id=17,
+    name='struct',
+    field_type=StructType(
+        NestedField(field_id=21, name='inner_string', field_type=StringType(), required=False),
+        NestedField(field_id=22, name='inner_int', field_type=IntegerType(), required=True),
+    ),
+    required=True,
+),
schema_id=0,
identifier_field_ids=[2],
)