diff --git a/py/client/pydeephaven/query.py b/py/client/pydeephaven/query.py index ccabac46918..f0d2bb2b518 100644 --- a/py/client/pydeephaven/query.py +++ b/py/client/pydeephaven/query.py @@ -170,7 +170,8 @@ def tail(self, num_rows: int) -> Query: """ return super().tail(num_rows) - def natural_join(self, table: Any, on: Union[str, List[str]], joins: Union[str, List[str]] = None) -> Query: + def natural_join(self, table: Any, on: Union[str, List[str]], joins: Union[str, List[str]] = None, + type: NaturalJoinType = NaturalJoinType.ERROR_ON_DUPLICATE) -> Query: """Adds a natural-join operation to the query. Args: @@ -179,11 +180,13 @@ def natural_join(self, table: Any, on: Union[str, List[str]], joins: Union[str, i.e. "col_a = col_b" for different column names joins (Union[str, List[str]], optional): the column(s) to be added from the right table to the result table, can be renaming expressions, i.e. "new_col = col"; default is None + type (NaturalJoinType, optional): the action to be taken when duplicate right hand rows are + encountered; default is ERROR_ON_DUPLICATE Returns: self """ - return super().natural_join(table, on, joins) + return super().natural_join(table, on, joins, type) def exact_join(self, table: Any, on: Union[str, List[str]], joins: Union[str, List[str]] = None) -> Query: """Adds an exact-join operation to the query. diff --git a/py/client/pydeephaven/table.py b/py/client/pydeephaven/table.py index 58437223afb..38630614351 100644 --- a/py/client/pydeephaven/table.py +++ b/py/client/pydeephaven/table.py @@ -296,7 +296,7 @@ def natural_join(self, table: Table, on: Union[str, List[str]], joins: Union[str joins (Union[str, List[str]], optional): the column(s) to be added from the right table to the result table, can be renaming expressions, i.e. "new_col = col"; default is None, which means all the columns from the right table, excluding those specified in 'on' - joinType (NaturalJoinType, optional): the action to be taken when duplicate right hand rows are + type (NaturalJoinType, optional): the action to be taken when duplicate right hand rows are encountered; default is ERROR_ON_DUPLICATE Returns: diff --git a/py/client/tests/test_table.py b/py/client/tests/test_table.py index 6c275153e67..fe2aab969cb 100644 --- a/py/client/tests/test_table.py +++ b/py/client/tests/test_table.py @@ -149,17 +149,9 @@ def test_natural_join_output(self): # assert the values meet expectations self.assertTrue(df_1.equals(df_2)) - self.assertEqual(df_1.loc[0, "rhs_index"], 0) - self.assertEqual(df_1.loc[1, "rhs_index"], 2) - self.assertEqual(df_1.loc[2, "rhs_index"], 4) - self.assertEqual(df_1.loc[3, "rhs_index"], 6) - self.assertEqual(df_1.loc[4, "rhs_index"], 8) + self.assertEqual(list(df_1.loc[0: 4, "rhs_index"]), [0, 2, 4, 6, 8]) # the following rows have no match and should be null / NA - self.assertTrue(pd.isna(df_1.loc[5, "rhs_index"])) - self.assertTrue(pd.isna(df_1.loc[6, "rhs_index"])) - self.assertTrue(pd.isna(df_1.loc[7, "rhs_index"])) - self.assertTrue(pd.isna(df_1.loc[8, "rhs_index"])) - self.assertTrue(pd.isna(df_1.loc[9, "rhs_index"])) + self.assertTrue(all(pd.isna(df_1.loc[5:9, "rhs_index"]))) right_table_last_by = right_table_raw.last_by(by="key") @@ -173,17 +165,9 @@ def test_natural_join_output(self): # assert the values meet expectations self.assertTrue(df_1.equals(df_2)) - self.assertEqual(df_1.loc[0, "rhs_index"], 1) - self.assertEqual(df_1.loc[1, "rhs_index"], 3) - self.assertEqual(df_1.loc[2, "rhs_index"], 5) - self.assertEqual(df_1.loc[3, "rhs_index"], 7) - self.assertEqual(df_1.loc[4, "rhs_index"], 9) + self.assertEqual(list(df_1.loc[0: 4, "rhs_index"]), [1, 3, 5, 7, 9]) # the following rows have no match and should be null / NA - self.assertTrue(pd.isna(df_1.loc[5, "rhs_index"])) - self.assertTrue(pd.isna(df_1.loc[6, "rhs_index"])) - self.assertTrue(pd.isna(df_1.loc[7, "rhs_index"])) - self.assertTrue(pd.isna(df_1.loc[8, "rhs_index"])) - self.assertTrue(pd.isna(df_1.loc[9, "rhs_index"])) + self.assertTrue(all(pd.isna(df_1.loc[5:9, "rhs_index"]))) def test_exact_join(self): pa_table = csv.read_csv(self.csv_file) diff --git a/py/server/tests/test_table.py b/py/server/tests/test_table.py index f0d573f6736..54c22afd870 100644 --- a/py/server/tests/test_table.py +++ b/py/server/tests/test_table.py @@ -329,17 +329,9 @@ def test_natural_join_output(self): # assert the values meet expectations self.assertTrue(df_1.equals(df_2)) - self.assertEqual(df_1.loc[0, "rhs_index"], 0) - self.assertEqual(df_1.loc[1, "rhs_index"], 2) - self.assertEqual(df_1.loc[2, "rhs_index"], 4) - self.assertEqual(df_1.loc[3, "rhs_index"], 6) - self.assertEqual(df_1.loc[4, "rhs_index"], 8) + self.assertEqual(list(df_1.loc[0: 4, "rhs_index"]), [0, 2, 4, 6, 8]) # the following rows have no match and should be null / NA - self.assertTrue(pd.isna(df_1.loc[5, "rhs_index"])) - self.assertTrue(pd.isna(df_1.loc[6, "rhs_index"])) - self.assertTrue(pd.isna(df_1.loc[7, "rhs_index"])) - self.assertTrue(pd.isna(df_1.loc[8, "rhs_index"])) - self.assertTrue(pd.isna(df_1.loc[9, "rhs_index"])) + self.assertTrue(all(pd.isna(df_1.loc[5:9, "rhs_index"]))) right_table_last_by = right_table_raw.last_by(by="key") @@ -353,17 +345,9 @@ def test_natural_join_output(self): # assert the values meet expectations self.assertTrue(df_1.equals(df_2)) - self.assertEqual(df_1.loc[0, "rhs_index"], 1) - self.assertEqual(df_1.loc[1, "rhs_index"], 3) - self.assertEqual(df_1.loc[2, "rhs_index"], 5) - self.assertEqual(df_1.loc[3, "rhs_index"], 7) - self.assertEqual(df_1.loc[4, "rhs_index"], 9) + self.assertEqual(list(df_1.loc[0: 4, "rhs_index"]), [1, 3, 5, 7, 9]) # the following rows have no match and should be null / NA - self.assertTrue(pd.isna(df_1.loc[5, "rhs_index"])) - self.assertTrue(pd.isna(df_1.loc[6, "rhs_index"])) - self.assertTrue(pd.isna(df_1.loc[7, "rhs_index"])) - self.assertTrue(pd.isna(df_1.loc[8, "rhs_index"])) - self.assertTrue(pd.isna(df_1.loc[9, "rhs_index"])) + self.assertTrue(all(pd.isna(df_1.loc[5:9, "rhs_index"]))) def test_exact_join(self): left_table = self.test_table.drop_columns(["d", "e"]).group_by('a') diff --git a/qst/src/main/java/io/deephaven/qst/table/NaturalJoinTable.java b/qst/src/main/java/io/deephaven/qst/table/NaturalJoinTable.java index 8e7ce0c6025..669aac93935 100644 --- a/qst/src/main/java/io/deephaven/qst/table/NaturalJoinTable.java +++ b/qst/src/main/java/io/deephaven/qst/table/NaturalJoinTable.java @@ -5,6 +5,7 @@ import io.deephaven.annotations.NodeStyle; import io.deephaven.api.NaturalJoinType; +import org.immutables.value.Value; import org.immutables.value.Value.Immutable; import java.util.Collection; @@ -16,7 +17,10 @@ @NodeStyle public abstract class NaturalJoinTable extends JoinBase { - public abstract NaturalJoinType joinType(); + @Value.Default + public NaturalJoinType joinType() { + return NaturalJoinType.ERROR_ON_DUPLICATE; + } public static Builder builder() { return ImmutableNaturalJoinTable.builder();