diff --git a/doc/source/modules.rst b/doc/source/modules.rst index d45df916..9a702759 100644 --- a/doc/source/modules.rst +++ b/doc/source/modules.rst @@ -16,5 +16,6 @@ The **GreenplumPython** library contains 5 main modules: type group order + op pd_df config \ No newline at end of file diff --git a/doc/source/op.rst b/doc/source/op.rst new file mode 100644 index 00000000..bc65240a --- /dev/null +++ b/doc/source/op.rst @@ -0,0 +1,9 @@ +Operators and Indexing +====================== + +.. module:: greenplumpython + +.. automodule:: op + :members: + :show-inheritance: + :member-order: bysource \ No newline at end of file diff --git a/greenplumpython/__init__.py b/greenplumpython/__init__.py index 5a2e8b03..666e648e 100644 --- a/greenplumpython/__init__.py +++ b/greenplumpython/__init__.py @@ -6,4 +6,5 @@ from greenplumpython.func import create_column_function # type: ignore from greenplumpython.func import create_function # type: ignore from greenplumpython.func import aggregate_function, function +from greenplumpython.op import operator from greenplumpython.type import type_ diff --git a/greenplumpython/dataframe.py b/greenplumpython/dataframe.py index 918576b3..27034255 100644 --- a/greenplumpython/dataframe.py +++ b/greenplumpython/dataframe.py @@ -69,12 +69,14 @@ def __init__( parents: List["DataFrame"] = [], db: Optional[Database] = None, columns: Optional[Iterable[Column]] = None, + qualified_table_name: Optional[str] = None, ) -> None: # FIXME: Add doc # noqa self._query = query self._parents = parents self._name = "cte_" + uuid4().hex + self._qualified_table_name = qualified_table_name self._columns = columns self._contents: Optional[Iterable[RealDictRow]] = None if any(parents): @@ -82,6 +84,11 @@ def __init__( else: self._db = db + @property + def is_saved(self) -> bool: + """Check whether the current dataframe is saved in database.""" + return self._qualified_table_name is not None + @singledispatchmethod def _getitem(self, _) -> "DataFrame": raise NotImplementedError() @@ -337,6 +344,7 @@ def apply( func: Callable[["DataFrame"], "FunctionExpr"], expand: bool = False, column_name: Optional[str] = None, + row_id: Optional[str] = None, ) -> "DataFrame": """ Apply a dataframe function to the self :class:`~dataframe.DataFrame`. @@ -424,7 +432,11 @@ def apply( # # To fix this, we need to pass the dataframe to the resulting FunctionExpr # explicitly. 
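+        # `row_id` is forwarded to FunctionExpr.apply() so that a row identifier
+        # column can be carried alongside the function's output, keeping each
+        # output row aligned with its input row.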
-        return func(self).bind(dataframe=self).apply(expand=expand, column_name=column_name)
+        return (
+            func(self)
+            .bind(dataframe=self)
+            .apply(expand=expand, column_name=column_name, row_id=row_id)
+        )
 
     def assign(self, **new_columns: Callable[["DataFrame"], Any]) -> "DataFrame":
         """
@@ -471,6 +483,7 @@ def assign(self, **new_columns: Callable[["DataFrame"], Any]) -> "DataFrame":
         for k, f in new_columns.items():
             v: Any = f(self)
             if isinstance(v, Expr):
+                v.bind(db=self._db)
                 assert (
                     v._dataframe is None or v._dataframe == self
                 ), "Newly included columns must be based on the current dataframe"
@@ -539,9 +552,10 @@ def join(
         other: "DataFrame",
         how: Literal["", "left", "right", "outer", "inner", "cross"] = "",
         cond: Optional[Callable[["DataFrame", "DataFrame"], Expr]] = None,
-        on: Optional[Union[str, Iterable[str]]] = None,
+        on: Optional[Iterable[str]] = None,
         self_columns: Union[Dict[str, Optional[str]], Set[str]] = {"*"},
         other_columns: Union[Dict[str, Optional[str]], Set[str]] = {"*"},
+        on_columns: Union[Dict[str, Optional[str]], Set[str]] = {"*"},
     ) -> "DataFrame":
         """
         Join the current :class:`~dataframe.DataFrame` with another using the given arguments.
@@ -567,6 +581,7 @@ def join(
                 the corresponding key to avoid name conflicts. Asterisk :code:`"*"`
                 can be used as a key to indicate all columns.
             other_columns: Same as `self_columns`, but for the **other** :class:`~dataframe.DataFrame`.
+            on_columns: A :class:`dict` whose keys are the names that the merged key columns in `on` will have in the resulting joined dataframe.
 
         Note:
             When using :code:`"*"` as key in `self_columns` or `other_columns`,
@@ -583,20 +598,21 @@ def join(
             ...     age_rows, column_names=["name", "age"], db=db)
             >>> result = student.join(
             ...     student,
-            ...     on="age",
-            ...     self_columns={"*"},
-            ...     other_columns={"name": "name_2"})
+            ...     on=["age"],
+            ...     self_columns={"name": "name", "age": "age_1"},
+            ...     other_columns={"name": "name_2", "age": "age_2"})
             >>> result
             ----------------------
-             name  | age | name_2
-            -------+-----+--------
-             alice | 18  | alice
-             bob   | 19  | carol
-             bob   | 19  | bob
-             carol | 19  | carol
-             carol | 19  | bob
+             age | name  | name_2
+            -----+-------+--------
+             18  | alice | alice
+             19  | bob   | carol
+             19  | bob   | bob
+             19  | carol | carol
+             19  | carol | bob
             ----------------------
             (5 rows)
+
         """
         # FIXME : Raise Error if target columns don't exist
         assert how.upper() in [
@@ -629,15 +645,89 @@ def bind(t: DataFrame, columns: Union[Dict[str, Optional[str]], Set[str]]) -> List[str]:
             if on is not None
             else None
         )
+        # The SQL USING clause is built from the `on` argument.
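+        # When `on` is given, each pair of key columns is merged in the result:
+        # with set-typed `self_columns`/`other_columns`, the two sides are first
+        # selected with the suffixes `_l`/`_r` and then coalesced below, e.g.
+        # COALESCE("age_l", "age_r") AS "age", so that in outer joins the key
+        # values from unmatched rows of either side survive.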
sql_using_clause = f"USING ({join_column_names})" if join_column_names is not None else "" - return DataFrame( + + if on is None: + return DataFrame( + f""" + SELECT {",".join(target_list)} + FROM {self._name} {how} JOIN {other_clause} {sql_on_clause} {sql_using_clause} + """, + parents=[self, other], + ) + + def bind_using( + t: DataFrame, + columns: Union[Dict[str, Optional[str]], Set[str]], + on: Iterable[str], + suffix: str, + ) -> List[str]: + target_list: List[str] = [] + for k in columns: + col: Column = t[k] + v = columns[k] if isinstance(columns, dict) else (k + suffix) if k in on else None + target_list.append(col._serialize() + (f' AS "{v}"' if v is not None else "")) + return target_list + + self_target_list = ( + bind_using(self, self_columns, on, "_l") + if isinstance(self_columns, set) + else bind(self, self_columns) + ) + other_target_list = ( + bind_using(other_temp, other_columns, on, "_r") + if isinstance(other_columns, set) + else bind(other_temp, other_columns) + ) + target_list = self_target_list + other_target_list + + join_dataframe = DataFrame( f""" SELECT {",".join(target_list)} FROM {self._name} {how} JOIN {other_clause} {sql_on_clause} {sql_using_clause} """, parents=[self, other], ) + coalesce_target_list = [] + if not (self_columns == {} or other_columns == {}): + for k in on: + s_v = self_columns[k] if isinstance(self_columns, dict) else (k + "_l") + o_v = other_columns[k] if isinstance(other_columns, dict) else (k + "_r") + coalesce_target_list.append(f"COALESCE({s_v},{o_v}) AS {k}") + + join_df = DataFrame( + f""" + SELECT * {("," + ",".join(coalesce_target_list)) if coalesce_target_list != [] else ""} + FROM {join_dataframe._name} + """, + parents=[join_dataframe], + ) + + self_columns_set = ( + self_columns + if isinstance(self_columns, set) + else set([k if k in on else v for k, v in self_columns.items()]) + ) + other_columns_set = ( + other_columns + if isinstance(other_columns, set) + else set([k if k in on else v for k, v in other_columns.items()]) + ) + return DataFrame( + f""" + SELECT {",".join(sorted(self_columns_set | other_columns_set))} + FROM {join_df._name} + """, + parents=[join_df], + ) inner_join = partialmethod(join, how="INNER") """ @@ -707,7 +791,7 @@ def _depth_first_search(self, t: "DataFrame", visited: Set[str], lineage: List[" self._depth_first_search(i, visited, lineage) lineage.append(t) - def _build_full_query(self) -> str: + def _serialize(self) -> str: # noqa """:meta private:""" lineage = self._list_lineage() @@ -849,12 +933,12 @@ def _fetch(self, is_all: bool = True) -> Iterable[Tuple[Any]]: f"SELECT to_json({output_name})::TEXT FROM {self._name} AS {output_name}", parents=[self], ) - result = self._db._execute(to_json_dataframe._build_full_query()) - return result if result is not None else [] + result = self._db._execute(to_json_dataframe._serialize()) + return result if isinstance(result, Iterable) else [] def save_as( self, - table_name: str, + table_name: Optional[str] = None, column_names: List[str] = [], temp: bool = False, storage_params: dict[str, Any] = {}, @@ -915,53 +999,61 @@ def save_as( # build string from parameter dict, such as from {'a': 1, 'b': 2} to # 'WITH (a=1, b=2)' - storage_parameters = ( + storage_params_clause = ( f"WITH ({','.join([f'{key}={storage_params[key]}' for key in storage_params.keys()])})" ) - df_full_name = f'"{table_name}"' if schema is None else f'"{schema}"."{table_name}"' + if table_name is None: + table_name = self._name if not self.is_saved else "cte_" + uuid4().hex + 
         self._db._execute(
             f"""
-            CREATE {'TEMP' if temp else ''} TABLE {df_full_name}
+            CREATE {'TEMP' if temp else ''} TABLE {qualified_table_name}
             ({','.join(column_names)})
-            {storage_parameters if storage_params else ''}
-            AS {self._build_full_query()}
+            {storage_params_clause if storage_params else ''}
+            AS (
+                {self._serialize()}
+            )
             """,
             has_results=False,
         )
-        return DataFrame.from_table(table_name, self._db)
-
-    # TODO: Uncomment or remove this.
-    #
-    # def create_index(
-    #     self,
-    #     columns: Iterable[Union["Column", str]],
-    #     method: str = "btree",
-    #     name: Optional[str] = None,
-    # ) -> None:
-    #     if not self._in_catalog():
-    #         raise Exception("Cannot create index on dataframes not in the system catalog.")
-    #     index_name: str = name if name is not None else "idx_" + uuid4().hex
-    #     indexed_cols = ",".join([str(col) for col in columns])
-    #     assert self._db is not None
-    #     self._db._execute(
-    #         f"CREATE INDEX {index_name} ON {self.name} USING {method} ({indexed_cols})",
-    #         has_results=False,
-    #     )
-
-    def _explain(self, format: str = "TEXT") -> Iterable[Tuple[str]]:
+        return DataFrame.from_table(table_name, self._db, schema=schema)
+
+    def create_index(
+        self,
+        columns: Union[Set[str], Dict[str, str]],
+        method: str = "btree",
+        name: Optional[str] = None,
+    ) -> "DataFrame":
         """
-        Explain the GreenplumPython :class:`~dataframe.DataFrame`'s execution plan.
+        Create an index on the current dataframe for fast searching.
+
+        The current dataframe is required to be saved before creating an index.
 
         Args:
-            format: str: the format of the explain result. It can be one of "TEXT"/"XML"/"JSON"/"YAML".
+            columns: set of key columns to index, or a :class:`dict` mapping each key column to its operator class.
+            method: index access method, e.g. :code:`"btree"` (default) or :code:`"gin"`.
+            name: name of the index. A unique name is generated if not given.
 
         Returns:
-            Iterable[Tuple[str]]: The results of *EXPLAIN* query.
+            The current dataframe, with its key columns indexed.
         """
+        assert self.is_saved, "Cannot create index for unsaved dataframe."
+        assert len(columns) > 0, "Column set to be indexed cannot be empty."
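+        # Typical usage (see tests/test_index.py):
+        #     df.save_as(temp=True, column_names=["info"]).create_index({"info"}, "gin")
+        #     df.create_index(columns={"text": "text_pattern_ops"})  # values name operator classes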
+
+        index_name: str = "idx_" + uuid4().hex if name is None else name
+        keys = (
+            [f'"{name}" "{op_class}"' for name, op_class in columns.items()]
+            if isinstance(columns, dict)
+            else [f'"{name}"' for name in columns]
+        )
         assert self._db is not None
-        results = self._db._execute(f"EXPLAIN (FORMAT {format}) {self._build_full_query()}")
-        assert results is not None
-        return results
+        self._db._execute(
+            f'CREATE INDEX "{index_name}" ON {self._qualified_table_name} USING "{method}" ('
+            f'    {",".join(keys)}'
+            f")",
+            has_results=False,
+        )
+        return self
 
     def group_by(self, *column_names: str) -> DataFrameGroupingSet:
         """
@@ -1036,10 +1128,8 @@ def from_table(cls, table_name: str, db: Database, schema: Optional[str] = None):
 
             df = gp.DataFrame.from_table("pg_class", db=db)
         """
-        return DataFrame(
-            f'TABLE "{schema}"."{table_name}"' if schema is not None else f'TABLE "{table_name}"',
-            db=db,
-        )
+        qualified_name = f'"{schema}"."{table_name}"' if schema is not None else f'"{table_name}"'
+        return DataFrame(f"TABLE {qualified_name}", db=db, qualified_table_name=qualified_name)
 
     @classmethod
     def from_rows(
diff --git a/greenplumpython/db.py b/greenplumpython/db.py
index 324d5c3e..37b47a1d 100644
--- a/greenplumpython/db.py
+++ b/greenplumpython/db.py
@@ -213,6 +213,10 @@ def assign(self, **new_columns: Callable[[], Any]) -> "DataFrame":
         for k, f in new_columns.items():
             v: Any = f()
             if isinstance(v, Expr):
+                v.bind(db=self)
                 assert v._dataframe is None, "New column should not depend on any dataframe."
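+                # `v` was bound to this database above so that constant
+                # expressions, e.g. gp.operator("~")("hello", "h.*o") on
+                # literals, can be serialized without any dataframe.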
             if isinstance(v, FunctionExpr):
                 v = v.bind(db=self)
diff --git a/greenplumpython/expr.py b/greenplumpython/expr.py
index c851cab8..44f5d8aa 100644
--- a/greenplumpython/expr.py
+++ b/greenplumpython/expr.py
@@ -20,7 +20,19 @@ def __init__(  # noqa: D107
         self._dataframe = dataframe
         self._other_dataframe = other_dataframe
-        self._db = dataframe._db if dataframe is not None else None
+        self._db = dataframe._db if dataframe is not None else None  # FIXME: set it to None
+
+    def bind(
+        self,
+        dataframe: Optional["DataFrame"] = None,
+        db: Optional[Database] = None,
+    ) -> "Expr":
+        """Bind the expression to a dataframe and/or a database for serialization."""
+        if dataframe is not None:
+            self._dataframe = dataframe
+        if db is not None:
+            self._db = db
+        return self
 
     def __hash__(self) -> int:  # noqa: D105
@@ -569,9 +578,9 @@ def _init(
         if other_dataframe is not None and isinstance(right, Expr):
             other_dataframe = right._other_dataframe
         super().__init__(dataframe=dataframe, other_dataframe=other_dataframe)
-        self.operator = operator
-        self.left = left
-        self.right = right
+        self._operator = operator
+        self._left = left
+        self._right = right
 
     @overload
     def __init__(
@@ -621,9 +630,14 @@ def __init__(
 
     def _serialize(self) -> str:
         from greenplumpython.expr import _serialize
 
-        left_str = _serialize(self.left)
-        right_str = _serialize(self.right)
-        return f"({left_str} {self.operator} {right_str})"
+        # Propagate the bound database to the operands before serializing them.
+        if isinstance(self._left, Expr):
+            self._left._db = self._db
+        if isinstance(self._right, Expr):
+            self._right._db = self._db
+        left_str = _serialize(self._left)
+        right_str = _serialize(self._right)
+        return f"({left_str} {self._operator} {right_str})"
 
 
 class UnaryExpr(Expr):
@@ -639,12 +652,16 @@ def __init__(
             (right._dataframe, right._other_dataframe) if isinstance(right, Expr) else (None, None)
         )
         super().__init__(dataframe=dataframe, other_dataframe=other_dataframe)
-        self.operator = operator
-        self.right = right
+        self._operator = operator
+        self._right = right
 
     def _serialize(self) -> str:
-        right_str = str(self.right)
-        return f"{self.operator} ({right_str})"
+        from greenplumpython.expr import _serialize
+
+        if isinstance(self._right, Expr):
+            self._right._db = self._db
+        right_str = _serialize(self._right)
+        return f"({self._operator} ({right_str}))"
 
 
 class InExpr(Expr):
@@ -671,6 +688,7 @@ def _serialize(self) -> str:
         # when combining with `~` (bitwise not) operator.
         container_name: str = "cte_" + uuid4().hex
         if isinstance(self._container, Expr) and self._other_dataframe is not None:
+            self._container._db = self._db
             return (
                 f"(EXISTS (SELECT FROM {self._other_dataframe._name}"
                 f" WHERE ({self._container._serialize()} = {self._item._serialize()})))"
diff --git a/greenplumpython/func.py b/greenplumpython/func.py
index 92f531a0..c5b59cef 100644
--- a/greenplumpython/func.py
+++ b/greenplumpython/func.py
@@ -83,6 +83,9 @@ def bind(
 
     def _serialize(self) -> str:
         # noqa D400
         """:meta private:"""
+        assert self._db is not None, "Database is required to create function."
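+        # The wrapped UDF is created lazily in the bound database on first
+        # serialization, which is why a database must be bound at this point.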
self._function._create_in_db(self._db) distinct = "DISTINCT" if self._distinct else "" for arg in self._args: @@ -95,7 +96,9 @@ def _serialize(self) -> str: ) return f"{self._function._qualified_name_str}({distinct} {args_string})" - def apply(self, expand: bool = False, column_name: Optional[str] = None) -> DataFrame: + def apply( + self, expand: bool = False, column_name: Optional[str] = None, row_id: Optional[str] = None + ) -> DataFrame: # noqa D400 """ :meta private: @@ -125,7 +128,7 @@ def apply(self, expand: bool = False, column_name: Optional[str] = None) -> Data orig_func_dataframe = DataFrame( " ".join( [ - f"SELECT {str(self)} {'AS ' + column_name if column_name is not None else ''}", + f"SELECT {(row_id + ',') if row_id is not None else ''} {str(self)} {'AS ' + column_name if column_name is not None else ''}", ("," + ",".join(grouping_cols)) if (grouping_cols is not None) else "", from_clause, group_by_clause, @@ -163,7 +166,7 @@ def apply(self, expand: bool = False, column_name: Optional[str] = None) -> Data ) return DataFrame( - f"SELECT {str(results)} FROM {orig_func_dataframe._name}", + f"SELECT {(row_id + ',') if row_id is not None and expand else ''} {str(results)} FROM {orig_func_dataframe._name}", db=self._db, parents=[orig_func_dataframe], ) @@ -308,6 +311,7 @@ def _create_in_db(self, db: Database) -> None: func_src: str = inspect.getsource(self._wrapped_func) else: func_src: str = dill.source.getsource(self._wrapped_func) + assert isinstance(func_src, str) func_ast: ast.FunctionDef = ast.parse(dedent(func_src)).body[0] # TODO: Lambda expressions are NOT supported since inspect.signature() # does not work as expected. diff --git a/greenplumpython/group.py b/greenplumpython/group.py index b061a81c..0882022e 100644 --- a/greenplumpython/group.py +++ b/greenplumpython/group.py @@ -158,10 +158,11 @@ def assign(self, **new_columns: Callable[["DataFrame"], Any]) -> "DataFrame": targets: List[str] = self._flatten() for k, f in new_columns.items(): v: Any = f(self._dataframe).bind(group_by=self) - if isinstance(v, Expr) and not ( - v._dataframe is None or v._dataframe == self._dataframe - ): - raise Exception("Newly included columns must be based on the current dataframe") + if isinstance(v, Expr): + assert ( + v._dataframe is None or v._dataframe == self._dataframe + ), "Newly included columns must be based on the current dataframe" + v.bind(db=self._dataframe._db) targets.append(f"{_serialize(v)} AS {k}") return DataFrame( f"SELECT {','.join(targets)} FROM {self._dataframe._name} {self._clause()}", diff --git a/greenplumpython/op.py b/greenplumpython/op.py new file mode 100644 index 00000000..7f33337e --- /dev/null +++ b/greenplumpython/op.py @@ -0,0 +1,31 @@ +""" +Indexing is essential for fast data searching in database. 
Unlike pandas, an index in database does not change how data is stored or
+presented; it only makes searching faster. Since an index works only when the
+search condition uses operators it supports, this module provides access to
+the operators predefined in database.
+"""
+from typing import Any, Optional, Union
+
+from greenplumpython.expr import BinaryExpr, UnaryExpr
+
+
+class Operator:
+    """Represents an operator predefined in database, e.g. :code:`+` or :code:`@>`."""
+
+    def __init__(self, name: str, schema: Optional[str] = None) -> None:
+        self._name = name
+        self._schema = schema
+
+    @property
+    def qualified_name(self) -> str:
+        """Return the syntax for referring to the operator in a SQL expression."""
+        if self._schema is not None:
+            return f'OPERATOR("{self._schema}".{self._name})'
+        else:
+            return f"OPERATOR({self._name})"
+
+    def __call__(self, *args: Any) -> Union[UnaryExpr, BinaryExpr]:
+        """Apply the operator to one operand (unary) or two operands (binary)."""
+        if len(args) == 1:
+            return UnaryExpr(self.qualified_name, args[0])
+        if len(args) == 2:
+            return BinaryExpr(self.qualified_name, args[0], args[1])
+        else:
+            raise Exception("An operator takes exactly 1 or 2 operands.")
+
+
+def operator(name: str, schema: Optional[str] = None) -> Operator:
+    """
+    Get access to an operator predefined in database.
+
+    Args:
+        name: name of the operator.
+        schema: schema name of the operator in database.
+
+    Returns:
+        The predefined operator as an :class:`~op.Operator` object.
+    """
+    return Operator(name, schema)
diff --git a/greenplumpython/pandas/dataframe.py b/greenplumpython/pandas/dataframe.py
index 629a984c..3b959352 100644
--- a/greenplumpython/pandas/dataframe.py
+++ b/greenplumpython/pandas/dataframe.py
@@ -95,7 +95,7 @@ def to_sql(
         assert index is False, "DataFrame in GreenplumPython.pandas does not have an index column"
         table_name = f'"{name}"' if schema is None else f'"{schema}"."{name}"'
         database = db.Database(uri=con)
-        query = self._dataframe._build_full_query()
+        query = self._dataframe._serialize()
         if if_exists == "append":
             rowcount = database._execute(
                 f"""
diff --git a/greenplumpython/type.py b/greenplumpython/type.py
index 0a1696ef..dff11241 100644
--- a/greenplumpython/type.py
+++ b/greenplumpython/type.py
@@ -1,10 +1,13 @@
 # noqa: D100
-from typing import Any, Dict, List, Optional, Set, Tuple, get_type_hints
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, get_type_hints
 from uuid import uuid4
 
 from greenplumpython.db import Database
 from greenplumpython.expr import Expr, _serialize
 
+if TYPE_CHECKING:
+    from greenplumpython.dataframe import DataFrame
+
 
 class TypeCast(Expr):
     """
@@ -28,12 +31,7 @@ class TypeCast(Expr):
     (2 rows)
     """
 
-    def __init__(
-        self,
-        obj: object,
-        type_name: str,
-        schema: Optional[str] = None,
-    ) -> None:
+    def __init__(self, obj: object, qualified_type_name: str) -> None:
         # noqa: D205 D400
         """
         Args:
@@ -43,27 +41,45 @@
         dataframe = obj._dataframe if isinstance(obj, Expr) else None
         super().__init__(dataframe)
         self._obj = obj
-        self._type_name = type_name
-        self._schema = schema
-        self._qualified_name_str = (
-            f'"{self._type_name}"'
-            if self._schema is None
-            else f'"{self._schema}"."{self._type_name}"'
-        )
+        self._qualified_type_name = qualified_type_name
 
     def _serialize(self) -> str:
+        if isinstance(self._obj, Expr):
+            self._obj._db = self._db
         obj_str = _serialize(self._obj)
-        return f"({obj_str}::{self._qualified_name_str})"
-
-    @property
-    def _qualified_name(self) -> Tuple[Optional[str], str]:
-        """
-        Return the schema name and name of :class:`~type.TypeCast`.
+        return f"({obj_str}::{self._qualified_type_name})"
 
-        Returns:
-            Tuple[str, str]: schema name and :class:`~type.TypeCast`'s name.
- """ - return self._schema, self._type_name + def bind( + self, + dataframe: Optional["DataFrame"] = None, + db: Optional[Database] = None, + column_name: str = None, + ) -> "Expr": + if isinstance(self._obj, Expr): + self._obj.bind( + dataframe=dataframe, + db=db, + ) + return self + + def apply( + self, expand: bool = False, column_name: Optional[str] = None, row_id: Optional[str] = None + ) -> "DataFrame": + from greenplumpython.dataframe import DataFrame + + if expand and column_name is None: + column_name = "func_" + uuid4().hex + return DataFrame( + f""" + SELECT {(row_id + ',') if row_id is not None else ''} + {self._serialize()} + {'AS ' + column_name if column_name is not None else ''} + {('FROM ' + self._obj._dataframe._name) if isinstance(self._obj, Expr) and self._obj._dataframe is not None else ""} + """, + db=self._db, + parents=[self._obj._dataframe], + ) class Type: @@ -84,16 +99,23 @@ class Type: """ def __init__( - self, name: str, annotation: Optional[type] = None, schema: Optional[str] = None + self, + name: str, + annotation: Optional[type] = None, + schema: Optional[str] = None, + modifier: Optional[int] = None, ) -> None: # noqa: D107 self._name = name self._annotation = annotation self._created_in_dbs: Optional[Set[Database]] = set() if annotation is not None else None self._schema = schema - self._qualified_name_str = ( - f'"{self._name}"' if self._schema is None else f'"{self._schema}"."{self._name}"' - ) + self._modifier = modifier + self._qualified_name_str = f'"{self._name}"' + if self._schema is not None: + self._qualified_name_str = f'"{self._schema}".' + self._qualified_name_str + if self._modifier is not None: + self._qualified_name_str += f"({self._modifier})" # -- Creation of a composite type in Greenplum corresponding to the class_type given def _create_in_db(self, db: Database): @@ -138,7 +160,7 @@ def __call__(self, obj: Any) -> TypeCast: - Any :class:`Expr` consisting of adaptable Python objects and :class:`Column`s of a :class:`DataFrame`. """ - return TypeCast(obj, self._name, self._schema) + return TypeCast(obj, self._qualified_name_str) @property def _qualified_name(self) -> Tuple[Optional[str], str]: @@ -162,7 +184,7 @@ def _qualified_name(self) -> Tuple[Optional[str], str]: } -def type_(name: str, schema: Optional[str] = None) -> Type: +def type_(name: str, schema: Optional[str] = None, modifier: Optional[int] = None) -> Type: """ Get access to a type predefined in database. @@ -173,7 +195,7 @@ def type_(name: str, schema: Optional[str] = None) -> Type: Returns: The predefined type as a :class:`~type.Type` object. 
""" - return Type(name, schema=schema) + return Type(name, schema=schema, modifier=modifier) def to_pg_type( diff --git a/tests/test_dataframe.py b/tests/test_dataframe.py index 7794f65a..0f8ab91c 100644 --- a/tests/test_dataframe.py +++ b/tests/test_dataframe.py @@ -63,7 +63,7 @@ def test_dataframe_getitem_slice_off_limit(db: gp.Database, t: gp.DataFrame): def test_dataframe_getitem_slice_off_limit(db: gp.Database, t: gp.DataFrame): - query = t[:]._build_full_query() + query = t[:]._serialize() assert len(list(t[:])) == 10 assert "LIMIT" in query diff --git a/tests/test_index.py b/tests/test_index.py new file mode 100644 index 00000000..e86e7cab --- /dev/null +++ b/tests/test_index.py @@ -0,0 +1,79 @@ +from typing import List + +import pytest + +import greenplumpython as gp +from tests import db + + +def test_op_on_consts(db: gp.Database): + regex_match = gp.operator("~") + result = db.assign(is_matched=lambda: regex_match("hello", "h.*o")) + assert len(list(result)) == 1 and next(iter(result))["is_matched"] + + +def using_index_scan(results: gp.DataFrame, db: gp.Database) -> bool: + for row in db._execute(f"EXPLAIN {results._serialize()}"): + if "Index Scan" in row["QUERY PLAN"] or "Index Only Scan" in row["QUERY PLAN"]: + return True + return False + + +def test_index(db: gp.Database): + import dataclasses + import json + + @dataclasses.dataclass + class Student: + name: str + courses: List[str] + + john = Student("john", ["math", "english"]) + jsonb = gp.type_("jsonb") + rows = [(jsonb(json.dumps(john.__dict__)),)] + student = ( + db.create_dataframe(rows=rows, column_names=["info"]) + .save_as(temp=True, column_names=["info"]) + .create_index({"info"}, "gin") + ) + + db._execute("SET enable_seqscan TO False", has_results=False) + json_contains = gp.operator("@>") + assert using_index_scan( + student[lambda t: json_contains(t["info"], json.dumps({"name": "john"}))], db + ) + + +def test_index_opclass(db: gp.Database): + df = ( + db.create_dataframe(columns={"text": ["hello", "world"]}) + .save_as(temp=True, column_names=["text"]) + .create_index(columns={"text": "text_pattern_ops"}) + ) + db._execute("SET enable_seqscan TO off", has_results=False) + db._execute( + """ + do $$ + begin + set optimizer to off; + exception when others then + end; + $$; + """, + has_results=False, + ) + assert using_index_scan(df[lambda t: t["text"] > "hello"], db) + + +def test_op_with_schema(db: gp.Database): + my_add = gp.operator("+") + result = db.assign(add=lambda: my_add(1, 2)) + for row in result: + assert row["add"] == 3 + qualified_add = gp.operator("+", "pg_catalog") + result = db.assign(add=lambda: qualified_add(1, 2)) + for row in result: + assert row["add"] == 3 + + +# FIXME : Add test for unary operator diff --git a/tests/test_join.py b/tests/test_join.py index f8ab5a50..13db37d9 100644 --- a/tests/test_join.py +++ b/tests/test_join.py @@ -103,8 +103,8 @@ def test_join_same_column_using(db: gp.Database): rows = [(1,), (2,), (3,)] t1 = db.create_dataframe(rows=rows, column_names=["id"]) t2 = db.create_dataframe(rows=rows, column_names=["id"]) - ret = t1.join(t2, on=["id"], self_columns={"id": "t1_id"}, other_columns={"id": "t2_id"}) - assert sorted(next(iter(ret)).keys()) == sorted(["t1_id", "t2_id"]) + ret = t1.join(t2, on=["id"], self_columns={"id"}, other_columns={"id"}) + assert sorted(next(iter(ret)).keys()) == ["id"] def test_join_same_column_names(db: gp.Database): @@ -121,21 +121,21 @@ def test_join_on_multi_columns(db: gp.Database): rows = [(1, 1), (2, 1), (3, 1)] t1 = 
db.create_dataframe(rows=rows, column_names=["id", "n"]) t2 = db.create_dataframe(rows=rows, column_names=["id", "n"]) - ret = t1.join(t2, on=["id", "n"], other_columns={}) - print(ret) + ret = t1.join(t2, on=["id", "n"], self_columns={"id", "n"}, other_columns={"id", "n"}) + assert sorted(next(iter(ret)).keys()) == sorted(["id", "n"]) def test_dataframe_inner_join(db: gp.Database, zoo_1: gp.DataFrame, zoo_2: gp.DataFrame): ret: gp.DataFrame = zoo_1.join( zoo_2, on=["animal"], - self_columns={"animal": "zoo1_animal", "id": "zoo1_id"}, - other_columns={"animal": "zoo2_animal", "id": "zoo2_id"}, + self_columns={"animal": "animal_l", "id": "id_zoo1"}, + other_columns={"animal": "animal_r", "id": "id_zoo2"}, ) assert len(list(ret)) == 2 + assert sorted(next(iter(ret)).keys()) == sorted(["animal", "id_zoo1", "id_zoo2"]) for row in ret: - assert row["zoo1_animal"] == row["zoo2_animal"] - assert row["zoo1_animal"] == "Lion" or row["zoo1_animal"] == "Tiger" + assert row["animal"] == "Lion" or row["animal"] == "Tiger" def test_dataframe_left_join(db: gp.Database, zoo_1: gp.DataFrame, zoo_2: gp.DataFrame): @@ -147,10 +147,7 @@ def test_dataframe_left_join(db: gp.Database, zoo_1: gp.DataFrame, zoo_2: gp.Dat ) assert len(list(ret)) == 4 for row in ret: - if row["zoo1_animal"] == "Lion" or row["zoo1_animal"] == "Tiger": - assert row["zoo1_animal"] == row["zoo2_animal"] - else: - assert row["zoo2_animal"] is None + if row["animal"] != "Lion" and row["animal"] != "Tiger": assert row["zoo2_id"] is None @@ -163,10 +160,7 @@ def test_dataframe_right_join(db: gp.Database, zoo_1: gp.DataFrame, zoo_2: gp.Da ) assert len(list(ret)) == 4 for row in ret: - if row["zoo2_animal"] == "Lion" or row["zoo2_animal"] == "Tiger": - assert row["zoo1_animal"] == row["zoo2_animal"] - else: - assert row["zoo1_animal"] is None + if row["animal"] != "Lion" and row["animal"] != "Tiger": assert row["zoo1_id"] is None @@ -179,17 +173,42 @@ def test_dataframe_full_join(db: gp.Database, zoo_1: gp.DataFrame, zoo_2: gp.Dat ) assert len(list(ret)) == 6 for row in ret: - if row["zoo2_animal"] == "Lion" or row["zoo2_animal"] == "Tiger": - assert row["zoo1_animal"] == row["zoo2_animal"] - else: - assert (row["zoo1_animal"] is None and row["zoo2_animal"] is not None) or ( - row["zoo1_animal"] is not None and row["zoo2_animal"] is None - ) + if row["animal"] != "Lion" and row["animal"] != "Tiger": assert (row["zoo1_id"] is None and row["zoo2_id"] is not None) or ( row["zoo1_id"] is not None and row["zoo2_id"] is None ) +def test_dataframe_full_join_with_empty(db: gp.Database): + # fmt: off + rows1 = [(1, 100,), (2, 200,), (3, 300,), (4, 400,)] + rows2 = [(3, 300, 3000,), (4, 400, 4000,), (5, 500, 5000,), (6, 600, 6000)] + # fmt: on + l_df = db.create_dataframe(rows=rows1, column_names=["a", "b"]) + r_df = db.create_dataframe(rows=rows2, column_names=["a", "b", "c"]) + ret = l_df.full_join( + r_df, + self_columns={"a", "b"}, + other_columns={"a", "b", "c"}, + on=["a", "b"], + ).order_by("a")[:] + assert len(list(ret)) == 6 + expected = ( + "----------------\n" + " a | b | c \n" + "---+-----+------\n" + " 1 | 100 | \n" + " 2 | 200 | \n" + " 3 | 300 | 3000 \n" + " 4 | 400 | 4000 \n" + " 5 | 500 | 5000 \n" + " 6 | 600 | 6000 \n" + "----------------\n" + "(6 rows)\n" + ) + assert str(ret) == expected + + def test_join_natural(db: gp.Database): # fmt: off rows1 = [("Smart Phone", 1,), ("Laptop", 2,), ("DataFramet", 3,)] @@ -202,8 +221,8 @@ def test_join_natural(db: gp.Database): ret = categories.join( products, on=["category_id"], - 
self_columns={"category_name", "category_id"}, - other_columns={"product_name"}, + self_columns={"category_id", "category_name"}, + other_columns={"category_id", "product_name"}, ) assert len(list(ret)) == 6 assert sorted(next(iter(ret)).keys()) == sorted( @@ -246,8 +265,6 @@ def test_dataframe_self_join(db: gp.Database, zoo_1: gp.DataFrame): other_columns={"animal": "zoo2_animal", "id": "zoo2_id"}, ) assert len(list(ret)) == 4 - for row in ret: - assert row["zoo1_animal"] == row["zoo2_animal"] def test_dataframe_self_join_cond(db: gp.Database, zoo_1: gp.DataFrame): @@ -271,20 +288,17 @@ def test_dataframe_join_save(db: gp.Database, zoo_1: gp.DataFrame): ) t_join.save_as( "dataframe_join", - column_names=["zoo1_animal", "zoo1_id", "zoo2_animal", "zoo2_id"], + column_names=["animal", "zoo1_id", "zoo2_id"], temp=True, ) t_join_reload = gp.DataFrame.from_table("dataframe_join", db=db) assert sorted(next(iter(t_join_reload)).keys()) == sorted( [ - "zoo1_animal", + "animal", "zoo1_id", - "zoo2_animal", "zoo2_id", ] ) - for row in t_join_reload: - assert row["zoo1_animal"] == row["zoo2_animal"] def test_dataframe_join_ine(db: gp.Database): @@ -307,11 +321,19 @@ def test_dataframe_multiple_self_join(db: gp.Database, zoo_1: gp.DataFrame): ) ret = t_join.join( zoo_1, - cond=lambda s, o: s["zoo1_animal"] == o["animal"], + on=["animal"], + self_columns={"animal", "zoo1_id", "zoo2_id"}, + other_columns={"animal", "id"}, ) assert len(list(ret)) == 4 - for row in ret: - assert row["zoo2_animal"] == row["animal"] + assert sorted(next(iter(ret)).keys()) == sorted( + [ + "animal", + "id", + "zoo1_id", + "zoo2_id", + ] + ) # This test case is to guarantee that the CTEs are generated in the reversed diff --git a/tests/test_schema.py b/tests/test_schema.py index 863a5307..89f1c43d 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -47,7 +47,8 @@ def test_schema_self_join_on(db: gp.Database, t: gp.DataFrame): ret: gp.DataFrame = t.join( t, on=["id"], - other_columns={"id": "id_1"}, + self_columns={"id"}, + other_columns={"id"}, ) assert len(list(ret)) == 10 diff --git a/tests/test_type.py b/tests/test_type.py index 7f336dce..f9dd5827 100644 --- a/tests/test_type.py +++ b/tests/test_type.py @@ -21,6 +21,22 @@ def test_type_cast(db: gp.Database): assert row["complex"] == {"i": 2, "r": 1} +def test_type_cast_func_result(db: gp.Database): + float8 = gp.type_("float8") + rows = [(i, i) for i in range(10)] + df = db.create_dataframe(rows=rows, column_names=["a", "b"]) + + @gp.create_function + def func(a: int, b: int) -> int: + return a + b + + results = df.apply( + lambda t: float8(func(t["a"], t["b"])), + column_name="float8", + ) + assert sorted([row["float8"] for row in results]) == list(range(0, 20, 2)) + + def test_type_create(db: gp.Database): @dataclasses.dataclass class Person: