diff --git a/py/client/pydeephaven/_table_ops.py b/py/client/pydeephaven/_table_ops.py index 34f15fe7859..4aff5ae03d6 100644 --- a/py/client/pydeephaven/_table_ops.py +++ b/py/client/pydeephaven/_table_ops.py @@ -687,3 +687,26 @@ def make_grpc_request(self, result_id, source_id) -> Any: def make_grpc_request_for_batch(self, result_id, source_id) -> Any: return table_pb2.BatchTableRequest.Operation( meta_table=self.make_grpc_request(result_id=result_id, source_id=source_id)) + + +class MultijoinTablesOp(TableOp): + def __init__(self, multi_join_inputs: List["MultiJoinInput"]): + self.multi_join_inputs = multi_join_inputs + + @classmethod + def get_stub_func(cls, table_service_stub: table_pb2_grpc.TableServiceStub) -> Any: + return table_service_stub.MultiJoinTables + + def make_grpc_request(self, result_id, source_id) -> Any: + pb_inputs = [] + for mji in self.multi_join_inputs: + source_id = table_pb2.TableReference(ticket=mji.table.ticket.pb_ticket) + columns_to_match = mji.on + columns_to_add = mji.joins + pb_inputs.append(table_pb2.MultiJoinInput(source_id=source_id, columns_to_match=columns_to_match, + columns_to_add=columns_to_add)) + return table_pb2.MultiJoinTablesRequest(result_id=result_id, multi_join_inputs=pb_inputs) + + def make_grpc_request_for_batch(self, result_id, source_id) -> Any: + return table_pb2.BatchTableRequest.Operation( + multi_join=self.make_grpc_request(result_id=result_id, source_id=source_id)) \ No newline at end of file diff --git a/py/client/pydeephaven/_table_service.py b/py/client/pydeephaven/_table_service.py index e58dd5fa1ff..0fbb7034cf3 100644 --- a/py/client/pydeephaven/_table_service.py +++ b/py/client/pydeephaven/_table_service.py @@ -1,7 +1,7 @@ # # Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending # -from typing import Union, List +from typing import Union, List, Optional from pydeephaven._batch_assembler import BatchOpAssembler from pydeephaven._table_ops import TableOp @@ -38,7 +38,7 @@ def batch(self, ops: 
List[TableOp]) -> Table: except Exception as e: raise DHError("failed to finish the table batch operation.") from e - def grpc_table_op(self, table: Table, op: TableOp, table_class: type = Table) -> Union[Table, InputTable]: + def grpc_table_op(self, table: Optional[Table], op: TableOp, table_class: type = Table) -> Union[Table, InputTable]: """Makes a single gRPC Table operation call and returns a new Table.""" try: export_ticket = self.session.make_export_ticket() diff --git a/py/client/pydeephaven/session.py b/py/client/pydeephaven/session.py index fb7fbf73ed8..abf9f062476 100644 --- a/py/client/pydeephaven/session.py +++ b/py/client/pydeephaven/session.py @@ -397,7 +397,7 @@ def _connect(self): # started together don't align retries. skew = random() # Backoff schedule for retries after consecutive failures to refresh auth token - self._refresh_backoff = [ skew + 0.1, skew + 1, skew + 10 ] + self._refresh_backoff = [skew + 0.1, skew + 1, skew + 10] if self._refresh_backoff[0] > self._timeout_seconds: raise DHError(f'server configuration http.session.durationMs={session_duration} is too small.') diff --git a/py/client/pydeephaven/table.py b/py/client/pydeephaven/table.py index 9a4cd0b9194..ba5a50bbdd2 100644 --- a/py/client/pydeephaven/table.py +++ b/py/client/pydeephaven/table.py @@ -6,13 +6,13 @@ from __future__ import annotations -from typing import List, Union +from typing import List, Union, Sequence import pyarrow as pa from pydeephaven._utils import to_list -from pydeephaven._table_ops import MetaTableOp, SortDirection +from pydeephaven._table_ops import MetaTableOp, SortDirection, MultijoinTablesOp from pydeephaven.agg import Aggregation from pydeephaven.dherror import DHError from pydeephaven._table_interface import TableInterface @@ -804,3 +804,81 @@ def delete(self, table: Table) -> None: self.session.input_table_service.delete(self, table) except Exception as e: raise DHError("delete data in the InputTable failed.") from e + + +class MultiJoinTable: + 
"""A MultiJoinTable is an object that contains the result of a multi-table natural join. To retrieve the underlying + result Table, use the :attr:`.table` property. """ + + def __init__(self, table: Table): + self._table = table + + @property + def table(self) -> Table: + """Returns the Table containing the multi-table natural join output. """ + return self._table + + +class MultiJoinInput: + """A MultiJoinInput represents one input table, its key columns and the additional columns to be used in the multi-table + natural join. + """ + table: Table + on: Union[str, Sequence[str]] + joins: Union[str, Sequence[str]] = None + + def __init__(self, table: Table, on: Union[str, Sequence[str]], joins: Union[str, Sequence[str]] = None): + """Initializes a MultiJoinInput object. + + Args: + table (Table): the right table to include in the join + on (Union[str, Sequence[str]]): the column(s) to match, can be a common name or an equality expression that + matches every input table, e.g. "col_a = col_b" to rename output column names. + joins (Union[str, Sequence[str]], optional): the column(s) to be added from the table to the result + table, can be renaming expressions, e.g. "new_col = col"; default is None + """ + self.table = table + self.on = to_list(on) + self.joins = to_list(joins) + + +def multi_join(input: Union[Table, Sequence[Table], MultiJoinInput, Sequence[MultiJoinInput]], + on: Union[str, Sequence[str]] = None) -> MultiJoinTable: + """ The multi_join method creates a new table by performing a multi-table natural join on the input tables. The + result consists of the set of distinct keys from the input tables natural joined to each input table. Input + tables need not have a matching row for each key, but they may not have multiple matching rows for a given key. + + Args: + input (Union[Table, Sequence[Table], MultiJoinInput, Sequence[MultiJoinInput]]): the input objects specifying the + tables and columns to include in the join. 
+ on (Union[str, Sequence[str]], optional): the column(s) to match, can be a common name or an equality expression + that matches every input table, e.g. "col_a = col_b" to rename output column names. Note: When + MultiJoinInput objects are supplied, this parameter must be omitted. + + Returns: + MultiJoinTable: the result of the multi-table natural join operation. To access the underlying Table, use the + :attr:`~MultiJoinTable.table` property. + + Raises: + DHError + """ + if isinstance(input, Table) or (isinstance(input, Sequence) and all(isinstance(t, Table) for t in input)): + tables = to_list(input) + session = tables[0].session + if not all([t.session == session for t in tables]): + raise DHError(message="all tables must be from the same session.") + multi_join_inputs = [MultiJoinInput(table=t, on=on) for t in tables] + elif isinstance(input, MultiJoinInput) or ( + isinstance(input, Sequence) and all(isinstance(ji, MultiJoinInput) for ji in input)): + if on is not None: + raise DHError(message="on parameter is not permitted when MultiJoinInput objects are provided.") + multi_join_inputs = to_list(input) + session = multi_join_inputs[0].table.session + if not all([mji.table.session == session for mji in multi_join_inputs]): + raise DHError(message="all tables must be from the same session.") + else: + raise DHError( + message="input must be a Table, a sequence of Tables, a MultiJoinInput, or a sequence of MultiJoinInputs.") + + table_op = MultijoinTablesOp(multi_join_inputs=multi_join_inputs) + return MultiJoinTable(table=session.table_service.grpc_table_op(None, table_op, table_class=Table)) diff --git a/py/client/tests/test_multijoin.py b/py/client/tests/test_multijoin.py new file mode 100644 index 00000000000..a8ca376d169 --- /dev/null +++ b/py/client/tests/test_multijoin.py @@ -0,0 +1,127 @@ +# +# Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending +# +import unittest + +from pyarrow import csv + +from pydeephaven import DHError, Session +from 
pydeephaven.table import MultiJoinInput, multi_join +from tests.testbase import BaseTestCase + + +class MultiJoinTestCase(BaseTestCase): + def setUp(self): + super().setUp() + pa_table = csv.read_csv(self.csv_file) + self.static_tableA = self.session.import_table(pa_table).select(["a", "b", "c1=c", "d1=d", "e1=e"]) + self.static_tableB = self.static_tableA.update(["c2=c1+1", "d2=d1+2", "e2=e1+3"]).drop_columns( + ["c1", "d1", "e1"]) + self.ticking_tableA = self.session.time_table("PT00:00:00.001").update( + ["a = i", "b = i*i % 13", "c1 = i * 13 % 23", "d1 = a + b", "e1 = a - b"]).drop_columns(["Timestamp"]) + self.ticking_tableB = self.ticking_tableA.update(["c2=c1+1", "d2=d1+2", "e2=e1+3"]).drop_columns( + ["c1", "d1", "e1"]) + + def tearDown(self) -> None: + self.static_tableA = None + self.static_tableB = None + self.ticking_tableA = None + self.ticking_tableB = None + super().tearDown() + + def test_static_simple(self): + # Test with multiple input tables + mj_table = multi_join(input=[self.static_tableA, self.static_tableB], on=["a", "b"]) + + # Output table is static + self.assertFalse(mj_table.table.is_refreshing) + # Output table has same # rows as sources + self.assertEqual(mj_table.table.size, self.static_tableA.size) + self.assertEqual(mj_table.table.size, self.static_tableB.size) + + # Test with a single input table + mj_table = multi_join(self.static_tableA, ["a", "b"]) + + # Output table is static + self.assertFalse(mj_table.table.is_refreshing) + # Output table has same # rows as sources + self.assertEqual(mj_table.table.size, self.static_tableA.size) + + def test_ticking_simple(self): + # Test with multiple input tables + mj_table = multi_join(input=[self.ticking_tableA, self.ticking_tableB], on=["a", "b"]) + + # Output table is refreshing + self.assertTrue(mj_table.table.is_refreshing) + + # Test with a single input table + mj_table = multi_join(input=self.ticking_tableA, on=["a", "b"]) + + # Output table is refreshing + 
self.assertTrue(mj_table.table.is_refreshing) + + def test_static(self): + # Test with multiple input + mj_input = [ + MultiJoinInput(table=self.static_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]), + MultiJoinInput(table=self.static_tableB, on=["key1=a", "key2=b"], joins=["d2"]) + ] + mj_table = multi_join(mj_input) + + # Output table is static + self.assertFalse(mj_table.table.is_refreshing) + # Output table has same # rows as sources + self.assertEqual(mj_table.table.size, self.static_tableA.size) + self.assertEqual(mj_table.table.size, self.static_tableB.size) + + # Test with a single input + mj_table = multi_join(MultiJoinInput(table=self.static_tableA, on=["key1=a", "key2=b"], joins="c1")) + + # Output table is static + self.assertFalse(mj_table.table.is_refreshing) + # Output table has same # rows as sources + self.assertEqual(mj_table.table.size, self.static_tableA.size) + + def test_ticking(self): + # Test with multiple input + mj_input = [ + MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]), + MultiJoinInput(table=self.ticking_tableB, on=["key1=a", "key2=b"], joins=["d2"]) + ] + mj_table = multi_join(mj_input) + + # Output table is refreshing + self.assertTrue(mj_table.table.is_refreshing) + + # Test with a single input + mj_table = multi_join(input=MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins="c1")) + + # Output table is refreshing + self.assertTrue(mj_table.table.is_refreshing) + + def test_errors(self): + # Assert the exception is raised when providing MultiJoinInput and the on parameter is not None (omitted). 
+ mj_input = [ + MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]), + MultiJoinInput(table=self.ticking_tableB, on=["key1=a", "key2=b"], joins=["d2"]) + ] + with self.assertRaises(DHError) as cm: + mj_table = multi_join(mj_input, on=["key1=a", "key2=b"]) + self.assertIn("on parameter is not permitted", str(cm.exception)) + + session = Session() + t = session.time_table("PT00:00:00.001").update( + ["a = i", "b = i*i % 13", "c1 = i * 13 % 23", "d1 = a + b", "e1 = a - b"]).drop_columns(["Timestamp"]) + + # Assert the exception is raised when to-be-joined tables are not from the same session. + mj_input = [ + MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]), + MultiJoinInput(table=t, on=["key1=a", "key2=b"], joins=["d2"]) + ] + with self.assertRaises(DHError) as cm: + mj_table = multi_join(mj_input) + self.assertIn("all tables must be from the same session", str(cm.exception)) + + +if __name__ == '__main__': + unittest.main() diff --git a/py/server/deephaven/table.py b/py/server/deephaven/table.py index e02307f647e..c634d2d95af 100644 --- a/py/server/deephaven/table.py +++ b/py/server/deephaven/table.py @@ -3786,7 +3786,7 @@ def __init__(self, table: Table, on: Union[str, Sequence[str]], joins: Union[str table (Table): the right table to include in the join on (Union[str, Sequence[str]]): the column(s) to match, can be a common name or an equal expression, i.e. "col_a = col_b" for different column names - joins (Union[str, Sequence[str]], optional): the column(s) to be added from the this table to the result + joins (Union[str, Sequence[str]], optional): the column(s) to be added from the table to the result table, can be renaming expressions, i.e. 
"new_col = col"; default is None Raises: @@ -3803,13 +3803,14 @@ def __init__(self, table: Table, on: Union[str, Sequence[str]], joins: Union[str class MultiJoinTable(JObjectWrapper): """A MultiJoinTable is an object that contains the result of a multi-table natural join. To retrieve the underlying - result Table, use the table() method. """ + result Table, use the :attr:`.table` property. """ j_object_type = _JMultiJoinTable @property def j_object(self) -> jpy.JType: return self.j_multijointable + @property def table(self) -> Table: """Returns the Table containing the multi-table natural join output. """ return Table(j_table=self.j_multijointable.table()) @@ -3866,7 +3867,7 @@ def multi_join(input: Union[Table, Sequence[Table], MultiJoinInput, Sequence[Mul Returns: MultiJoinTable: the result of the multi-table natural join operation. To access the underlying Table, use the - table() method. + :attr:`~MultiJoinTable.table` property. """ return MultiJoinTable(input, on) diff --git a/py/server/tests/test_multijoin.py b/py/server/tests/test_multijoin.py index 28cf93bb505..67c5cb7a081 100644 --- a/py/server/tests/test_multijoin.py +++ b/py/server/tests/test_multijoin.py @@ -33,18 +33,18 @@ def test_static_simple(self): mj_table = multi_join(input=[self.static_tableA, self.static_tableB], on=["a","b"]) # Output table is static - self.assertFalse(mj_table.table().is_refreshing) + self.assertFalse(mj_table.table.is_refreshing) # Output table has same # rows as sources - self.assertEqual(mj_table.table().size, self.static_tableA.size) - self.assertEqual(mj_table.table().size, self.static_tableB.size) + self.assertEqual(mj_table.table.size, self.static_tableA.size) + self.assertEqual(mj_table.table.size, self.static_tableB.size) # Test with a single input table mj_table = multi_join(self.static_tableA, ["a","b"]) # Output table is static - self.assertFalse(mj_table.table().is_refreshing) + self.assertFalse(mj_table.table.is_refreshing) # Output table has same # rows as 
sources - self.assertEqual(mj_table.table().size, self.static_tableA.size) + self.assertEqual(mj_table.table.size, self.static_tableA.size) def test_ticking_simple(self): @@ -52,20 +52,20 @@ def test_ticking_simple(self): mj_table = multi_join(input=[self.ticking_tableA, self.ticking_tableB], on=["a","b"]) # Output table is refreshing - self.assertTrue(mj_table.table().is_refreshing) + self.assertTrue(mj_table.table.is_refreshing) # Output table has same # rows as sources with update_graph.exclusive_lock(self.test_update_graph): - self.assertEqual(mj_table.table().size, self.ticking_tableA.size) - self.assertEqual(mj_table.table().size, self.ticking_tableB.size) + self.assertEqual(mj_table.table.size, self.ticking_tableA.size) + self.assertEqual(mj_table.table.size, self.ticking_tableB.size) # Test with a single input table mj_table = multi_join(input=self.ticking_tableA, on=["a","b"]) # Output table is refreshing - self.assertTrue(mj_table.table().is_refreshing) + self.assertTrue(mj_table.table.is_refreshing) # Output table has same # rows as sources with update_graph.exclusive_lock(self.test_update_graph): - self.assertEqual(mj_table.table().size, self.ticking_tableA.size) + self.assertEqual(mj_table.table.size, self.ticking_tableA.size) def test_static(self): @@ -77,18 +77,18 @@ def test_static(self): mj_table = multi_join(mj_input) # Output table is static - self.assertFalse(mj_table.table().is_refreshing) + self.assertFalse(mj_table.table.is_refreshing) # Output table has same # rows as sources - self.assertEqual(mj_table.table().size, self.static_tableA.size) - self.assertEqual(mj_table.table().size, self.static_tableB.size) + self.assertEqual(mj_table.table.size, self.static_tableA.size) + self.assertEqual(mj_table.table.size, self.static_tableB.size) # Test with a single input mj_table = multi_join(MultiJoinInput(table=self.static_tableA, on=["key1=a","key2=b"], joins="c1")) # Output table is static - self.assertFalse(mj_table.table().is_refreshing) + 
self.assertFalse(mj_table.table.is_refreshing) # Output table has same # rows as sources - self.assertEqual(mj_table.table().size, self.static_tableA.size) + self.assertEqual(mj_table.table.size, self.static_tableA.size) def test_ticking(self): @@ -100,20 +100,20 @@ def test_ticking(self): mj_table = multi_join(mj_input) # Output table is refreshing - self.assertTrue(mj_table.table().is_refreshing) + self.assertTrue(mj_table.table.is_refreshing) # Output table has same # rows as sources with update_graph.exclusive_lock(self.test_update_graph): - self.assertEqual(mj_table.table().size, self.ticking_tableA.size) - self.assertEqual(mj_table.table().size, self.ticking_tableB.size) + self.assertEqual(mj_table.table.size, self.ticking_tableA.size) + self.assertEqual(mj_table.table.size, self.ticking_tableB.size) # Test with a single input mj_table = multi_join(input=MultiJoinInput(table=self.ticking_tableA, on=["key1=a","key2=b"], joins="c1")) # Output table is refreshing - self.assertTrue(mj_table.table().is_refreshing) + self.assertTrue(mj_table.table.is_refreshing) # Output table has same # rows as sources with update_graph.exclusive_lock(self.test_update_graph): - self.assertEqual(mj_table.table().size, self.ticking_tableA.size) + self.assertEqual(mj_table.table.size, self.ticking_tableA.size) def test_errors(self): # Assert the exception is raised when providing MultiJoinInput and the on parameter is not None (omitted).