create_table with a PyArrow Schema (#305)
sungwy authored Jan 30, 2024
1 parent a3e3683 commit 02e6430
Showing 17 changed files with 417 additions and 139 deletions.
19 changes: 19 additions & 0 deletions mkdocs/docs/api.md
@@ -146,6 +146,25 @@ catalog.create_table(
)
```

To create a table using a PyArrow schema:

```python
import pyarrow as pa

schema = pa.schema(
    [
        pa.field("foo", pa.string(), nullable=True),
        pa.field("bar", pa.int32(), nullable=False),
        pa.field("baz", pa.bool_(), nullable=True),
    ]
)

catalog.create_table(
    identifier="docs_example.bids",
    schema=schema,
)
```
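
When a PyArrow schema is passed, it is converted to an Iceberg schema and fresh field IDs are assigned while the new table metadata is written. A minimal sketch of inspecting the result; the sequential IDs shown in the comment are an assumption about the assignment order, not captured output:

```python
table = catalog.load_table("docs_example.bids")

# Fresh field IDs are assigned in order when the table metadata is
# created; the rendering below is an assumption, not real output.
print(table.schema())
# table {
#   1: foo: optional string
#   2: bar: required int
#   3: baz: optional boolean
# }
```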

## Load a table

### Catalog table
22 changes: 21 additions & 1 deletion pyiceberg/catalog/__init__.py
@@ -24,6 +24,7 @@
from dataclasses import dataclass
from enum import Enum
from typing import (
TYPE_CHECKING,
Callable,
Dict,
List,
@@ -56,6 +57,9 @@
)
from pyiceberg.utils.config import Config, merge_config

if TYPE_CHECKING:
import pyarrow as pa

logger = logging.getLogger(__name__)

_ENV_CONFIG = Config()
@@ -288,7 +292,7 @@ def _load_file_io(self, properties: Properties = EMPTY_DICT, location: Optional[
def create_table(
self,
identifier: Union[str, Identifier],
schema: Schema,
schema: Union[Schema, "pa.Schema"],
location: Optional[str] = None,
partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
sort_order: SortOrder = UNSORTED_SORT_ORDER,
@@ -512,6 +516,22 @@ def _check_for_overlap(removals: Optional[Set[str]], updates: Properties) -> None:
if overlap:
raise ValueError(f"Updates and deletes have an overlap: {overlap}")

@staticmethod
def _convert_schema_if_needed(schema: Union[Schema, "pa.Schema"]) -> Schema:
    if isinstance(schema, Schema):
        return schema
    try:
        import pyarrow as pa

        from pyiceberg.io.pyarrow import _ConvertToIcebergWithoutIDs, visit_pyarrow

        if isinstance(schema, pa.Schema):
            schema: Schema = visit_pyarrow(schema, _ConvertToIcebergWithoutIDs())  # type: ignore
            return schema
    except ModuleNotFoundError:
        pass
    raise ValueError(f"{type(schema)=}, but it must be pyiceberg.schema.Schema or pyarrow.Schema")
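
A minimal sketch of what this helper enables at the `create_table` call site, assuming pyarrow is installed; the catalog name and table identifier are illustrative:

```python
import pyarrow as pa

from pyiceberg.catalog import load_catalog

# Illustrative catalog/namespace names, not part of this commit.
catalog = load_catalog("default")

arrow_schema = pa.schema([pa.field("id", pa.int64(), nullable=False)])

# create_table accepts either schema type; a PyArrow schema is normalized
# through Catalog._convert_schema_if_needed before metadata is built.
table = catalog.create_table("docs_example.events", schema=arrow_schema)
```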

def _resolve_table_location(self, location: Optional[str], database_name: str, table_name: str) -> str:
if not location:
return self._get_default_warehouse_location(database_name, table_name)
8 changes: 7 additions & 1 deletion pyiceberg/catalog/dynamodb.py
@@ -17,6 +17,7 @@
import uuid
from time import time
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
@@ -57,6 +58,9 @@
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
from pyiceberg.typedef import EMPTY_DICT

if TYPE_CHECKING:
import pyarrow as pa

DYNAMODB_CLIENT = "dynamodb"

DYNAMODB_COL_IDENTIFIER = "identifier"
@@ -127,7 +131,7 @@ def _dynamodb_table_exists(self) -> bool:
def create_table(
self,
identifier: Union[str, Identifier],
schema: Schema,
schema: Union[Schema, "pa.Schema"],
location: Optional[str] = None,
partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
sort_order: SortOrder = UNSORTED_SORT_ORDER,
@@ -152,6 +156,8 @@ def create_table(
ValueError: If the identifier is invalid, or no path is given to store metadata.
"""
schema: Schema = self._convert_schema_if_needed(schema) # type: ignore

database_name, table_name = self.identifier_to_database_and_table(identifier)

location = self._resolve_table_location(location, database_name, table_name)
8 changes: 7 additions & 1 deletion pyiceberg/catalog/glue.py
@@ -17,6 +17,7 @@


from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
@@ -88,6 +89,9 @@
UUIDType,
)

if TYPE_CHECKING:
import pyarrow as pa

# If Glue should skip archiving an old table version when creating a new version in a commit. By
# default, Glue archives all old table versions after an UpdateTable call, but Glue has a default
# max number of archived table versions (can be increased). So for streaming use case with lots
@@ -329,7 +333,7 @@ def _get_glue_table(self, database_name: str, table_name: str) -> TableTypeDef:
def create_table(
self,
identifier: Union[str, Identifier],
schema: Schema,
schema: Union[Schema, "pa.Schema"],
location: Optional[str] = None,
partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
sort_order: SortOrder = UNSORTED_SORT_ORDER,
@@ -354,6 +358,8 @@ def create_table(
ValueError: If the identifier is invalid, or no path is given to store metadata.
"""
schema: Schema = self._convert_schema_if_needed(schema) # type: ignore

database_name, table_name = self.identifier_to_database_and_table(identifier)

location = self._resolve_table_location(location, database_name, table_name)
9 changes: 8 additions & 1 deletion pyiceberg/catalog/hive.py
@@ -18,6 +18,7 @@
import time
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
@@ -91,6 +92,10 @@
UUIDType,
)

if TYPE_CHECKING:
import pyarrow as pa


# Replace by visitor
hive_types = {
BooleanType: "boolean",
@@ -250,7 +255,7 @@ def _convert_hive_into_iceberg(self, table: HiveTable, io: FileIO) -> Table:
def create_table(
self,
identifier: Union[str, Identifier],
schema: Schema,
schema: Union[Schema, "pa.Schema"],
location: Optional[str] = None,
partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
sort_order: SortOrder = UNSORTED_SORT_ORDER,
@@ -273,6 +278,8 @@ def create_table(
AlreadyExistsError: If a table with the name already exists.
ValueError: If the identifier is invalid.
"""
schema: Schema = self._convert_schema_if_needed(schema) # type: ignore

properties = {**DEFAULT_PROPERTIES, **properties}
database_name, table_name = self.identifier_to_database_and_table(identifier)
current_time_millis = int(time.time() * 1000)
6 changes: 5 additions & 1 deletion pyiceberg/catalog/noop.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from typing import (
TYPE_CHECKING,
List,
Optional,
Set,
@@ -33,12 +34,15 @@
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER
from pyiceberg.typedef import EMPTY_DICT, Identifier, Properties

if TYPE_CHECKING:
import pyarrow as pa


class NoopCatalog(Catalog):
def create_table(
self,
identifier: Union[str, Identifier],
schema: Schema,
schema: Union[Schema, "pa.Schema"],
location: Optional[str] = None,
partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
sort_order: SortOrder = UNSORTED_SORT_ORDER,
8 changes: 7 additions & 1 deletion pyiceberg/catalog/rest.py
@@ -16,6 +16,7 @@
# under the License.
from json import JSONDecodeError
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
@@ -68,6 +69,9 @@
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
from pyiceberg.typedef import EMPTY_DICT, UTF8, IcebergBaseModel

if TYPE_CHECKING:
import pyarrow as pa

ICEBERG_REST_SPEC_VERSION = "0.14.1"


@@ -437,12 +441,14 @@ def _response_to_table(self, identifier_tuple: Tuple[str, ...], table_response:
def create_table(
self,
identifier: Union[str, Identifier],
schema: Schema,
schema: Union[Schema, "pa.Schema"],
location: Optional[str] = None,
partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
sort_order: SortOrder = UNSORTED_SORT_ORDER,
properties: Properties = EMPTY_DICT,
) -> Table:
schema: Schema = self._convert_schema_if_needed(schema) # type: ignore

namespace_and_table = self._split_identifier_for_path(identifier)
request = CreateTableRequest(
name=namespace_and_table["table"],
8 changes: 7 additions & 1 deletion pyiceberg/catalog/sql.py
@@ -16,6 +16,7 @@
# under the License.

from typing import (
TYPE_CHECKING,
List,
Optional,
Set,
@@ -65,6 +66,9 @@
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
from pyiceberg.typedef import EMPTY_DICT

if TYPE_CHECKING:
import pyarrow as pa


class SqlCatalogBaseTable(MappedAsDataclass, DeclarativeBase):
pass
@@ -140,7 +144,7 @@ def _convert_orm_to_iceberg(self, orm_table: IcebergTables) -> Table:
def create_table(
self,
identifier: Union[str, Identifier],
schema: Schema,
schema: Union[Schema, "pa.Schema"],
location: Optional[str] = None,
partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
sort_order: SortOrder = UNSORTED_SORT_ORDER,
@@ -165,6 +169,8 @@ def create_table(
ValueError: If the identifier is invalid, or no path is given to store metadata.
"""
schema: Schema = self._convert_schema_if_needed(schema) # type: ignore

database_name, table_name = self.identifier_to_database_and_table(identifier)
if not self._namespace_exists(database_name):
raise NoSuchNamespaceError(f"Namespace does not exist: {database_name}")
25 changes: 20 additions & 5 deletions pyiceberg/io/pyarrow.py
@@ -26,6 +26,7 @@
from __future__ import annotations

import concurrent.futures
import itertools
import logging
import os
import re
@@ -34,7 +35,6 @@
from dataclasses import dataclass
from enum import Enum
from functools import lru_cache, singledispatch
from itertools import chain
from typing import (
TYPE_CHECKING,
Any,
@@ -637,7 +637,7 @@ def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], rows:
if len(positional_deletes) == 1:
all_chunks = positional_deletes[0]
else:
all_chunks = pa.chunked_array(chain(*[arr.chunks for arr in positional_deletes]))
all_chunks = pa.chunked_array(itertools.chain(*[arr.chunks for arr in positional_deletes]))
return np.setdiff1d(np.arange(rows), all_chunks, assume_unique=False)


@@ -912,6 +912,21 @@ def after_map_value(self, element: pa.Field) -> None:
self._field_names.pop()


class _ConvertToIcebergWithoutIDs(_ConvertToIceberg):
    """
    Converts a PyArrow schema to an Iceberg schema, with every field ID set to -1.

    A schema generated by this visitor should always be used in conjunction with
    the `new_table_metadata` function, which assigns fresh field IDs in order. It
    is currently used only when creating a new Iceberg table from a PyArrow schema.
    """

    def _field_id(self, field: pa.Field) -> int:
        return -1
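
A minimal sketch of using this visitor directly, mirroring the call in `Catalog._convert_schema_if_needed`; pyarrow is assumed to be installed:

```python
import pyarrow as pa

from pyiceberg.io.pyarrow import _ConvertToIcebergWithoutIDs, visit_pyarrow

arrow_schema = pa.schema([pa.field("foo", pa.string(), nullable=True)])

# Every field in the resulting Iceberg schema carries field_id == -1;
# new_table_metadata later reassigns real IDs in order.
iceberg_schema = visit_pyarrow(arrow_schema, _ConvertToIcebergWithoutIDs())
```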


def _task_to_table(
fs: FileSystem,
task: FileScanTask,
@@ -999,7 +1014,7 @@ def _task_to_table(

def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
deletes_per_file: Dict[str, List[ChunkedArray]] = {}
unique_deletes = set(chain.from_iterable([task.delete_files for task in tasks]))
unique_deletes = set(itertools.chain.from_iterable([task.delete_files for task in tasks]))
if len(unique_deletes) > 0:
executor = ExecutorFactory.get_or_create()
deletes_per_files: Iterator[Dict[str, ChunkedArray]] = executor.map(
@@ -1421,7 +1436,7 @@ def schema(self, schema: Schema, struct_result: Callable[[], List[StatisticsCollector]]) -> List[StatisticsCollector]:
def struct(
self, struct: StructType, field_results: List[Callable[[], List[StatisticsCollector]]]
) -> List[StatisticsCollector]:
return list(chain(*[result() for result in field_results]))
return list(itertools.chain(*[result() for result in field_results]))

def field(self, field: NestedField, field_result: Callable[[], List[StatisticsCollector]]) -> List[StatisticsCollector]:
self._field_id = field.field_id
@@ -1513,7 +1528,7 @@ def schema(self, schema: Schema, struct_result: Callable[[], List[ID2ParquetPath]]) -> List[ID2ParquetPath]:
return struct_result()

def struct(self, struct: StructType, field_results: List[Callable[[], List[ID2ParquetPath]]]) -> List[ID2ParquetPath]:
return list(chain(*[result() for result in field_results]))
return list(itertools.chain(*[result() for result in field_results]))

def field(self, field: NestedField, field_result: Callable[[], List[ID2ParquetPath]]) -> List[ID2ParquetPath]:
self._field_id = field.field_id
