Skip to content

Commit

Permalink
Additional tests for FIRST_MATCH + ADD_ONLY NaturalJoin
Browse files (browse the repository at this point in the history)
  • Loading branch information
lbooker42 committed Feb 4, 2025
1 parent 22c5825 commit 84987dd
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 44 deletions.
7 changes: 5 additions & 2 deletions py/client/pydeephaven/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ def tail(self, num_rows: int) -> Query:
"""
return super().tail(num_rows)

def natural_join(self, table: Any, on: Union[str, List[str]], joins: Union[str, List[str]] = None) -> Query:
def natural_join(self, table: Any, on: Union[str, List[str]], joins: Union[str, List[str]] = None,
type: NaturalJoinType = NaturalJoinType.ERROR_ON_DUPLICATE) -> Query:
"""Adds a natural-join operation to the query.
Args:
Expand All @@ -179,11 +180,13 @@ def natural_join(self, table: Any, on: Union[str, List[str]], joins: Union[str,
i.e. "col_a = col_b" for different column names
joins (Union[str, List[str]], optional): the column(s) to be added from the right table to the result
table, can be renaming expressions, i.e. "new_col = col"; default is None
type (NaturalJoinType, optional): the action to be taken when duplicate right hand rows are
encountered; default is ERROR_ON_DUPLICATE
Returns:
self
"""
return super().natural_join(table, on, joins)
return super().natural_join(table, on, joins, type)

def exact_join(self, table: Any, on: Union[str, List[str]], joins: Union[str, List[str]] = None) -> Query:
"""Adds an exact-join operation to the query.
Expand Down
2 changes: 1 addition & 1 deletion py/client/pydeephaven/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def natural_join(self, table: Table, on: Union[str, List[str]], joins: Union[str
joins (Union[str, List[str]], optional): the column(s) to be added from the right table to the result
table, can be renaming expressions, i.e. "new_col = col"; default is None, which means all the columns
from the right table, excluding those specified in 'on'
joinType (NaturalJoinType, optional): the action to be taken when duplicate right hand rows are
type (NaturalJoinType, optional): the action to be taken when duplicate right hand rows are
encountered; default is ERROR_ON_DUPLICATE
Returns:
Expand Down
24 changes: 4 additions & 20 deletions py/client/tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,17 +149,9 @@ def test_natural_join_output(self):
# assert the values meet expectations
self.assertTrue(df_1.equals(df_2))

self.assertEqual(df_1.loc[0, "rhs_index"], 0)
self.assertEqual(df_1.loc[1, "rhs_index"], 2)
self.assertEqual(df_1.loc[2, "rhs_index"], 4)
self.assertEqual(df_1.loc[3, "rhs_index"], 6)
self.assertEqual(df_1.loc[4, "rhs_index"], 8)
self.assertEqual(list(df_1.loc[0: 4, "rhs_index"]), [0, 2, 4, 6, 8])
# the following rows have no match and should be null / NA
self.assertTrue(pd.isna(df_1.loc[5, "rhs_index"]))
self.assertTrue(pd.isna(df_1.loc[6, "rhs_index"]))
self.assertTrue(pd.isna(df_1.loc[7, "rhs_index"]))
self.assertTrue(pd.isna(df_1.loc[8, "rhs_index"]))
self.assertTrue(pd.isna(df_1.loc[9, "rhs_index"]))
self.assertTrue(all(pd.isna(df_1.loc[5:9, "rhs_index"])))

right_table_last_by = right_table_raw.last_by(by="key")

Expand All @@ -173,17 +165,9 @@ def test_natural_join_output(self):
# assert the values meet expectations
self.assertTrue(df_1.equals(df_2))

self.assertEqual(df_1.loc[0, "rhs_index"], 1)
self.assertEqual(df_1.loc[1, "rhs_index"], 3)
self.assertEqual(df_1.loc[2, "rhs_index"], 5)
self.assertEqual(df_1.loc[3, "rhs_index"], 7)
self.assertEqual(df_1.loc[4, "rhs_index"], 9)
self.assertEqual(list(df_1.loc[0: 4, "rhs_index"]), [1, 3, 5, 7, 9])
# the following rows have no match and should be null / NA
self.assertTrue(pd.isna(df_1.loc[5, "rhs_index"]))
self.assertTrue(pd.isna(df_1.loc[6, "rhs_index"]))
self.assertTrue(pd.isna(df_1.loc[7, "rhs_index"]))
self.assertTrue(pd.isna(df_1.loc[8, "rhs_index"]))
self.assertTrue(pd.isna(df_1.loc[9, "rhs_index"]))
self.assertTrue(all(pd.isna(df_1.loc[5:9, "rhs_index"])))

def test_exact_join(self):
pa_table = csv.read_csv(self.csv_file)
Expand Down
24 changes: 4 additions & 20 deletions py/server/tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,17 +329,9 @@ def test_natural_join_output(self):
# assert the values meet expectations
self.assertTrue(df_1.equals(df_2))

self.assertEqual(df_1.loc[0, "rhs_index"], 0)
self.assertEqual(df_1.loc[1, "rhs_index"], 2)
self.assertEqual(df_1.loc[2, "rhs_index"], 4)
self.assertEqual(df_1.loc[3, "rhs_index"], 6)
self.assertEqual(df_1.loc[4, "rhs_index"], 8)
self.assertEqual(list(df_1.loc[0: 4, "rhs_index"]), [0, 2, 4, 6, 8])
# the following rows have no match and should be null / NA
self.assertTrue(pd.isna(df_1.loc[5, "rhs_index"]))
self.assertTrue(pd.isna(df_1.loc[6, "rhs_index"]))
self.assertTrue(pd.isna(df_1.loc[7, "rhs_index"]))
self.assertTrue(pd.isna(df_1.loc[8, "rhs_index"]))
self.assertTrue(pd.isna(df_1.loc[9, "rhs_index"]))
self.assertTrue(all(pd.isna(df_1.loc[5:9, "rhs_index"])))

right_table_last_by = right_table_raw.last_by(by="key")

Expand All @@ -353,17 +345,9 @@ def test_natural_join_output(self):
# assert the values meet expectations
self.assertTrue(df_1.equals(df_2))

self.assertEqual(df_1.loc[0, "rhs_index"], 1)
self.assertEqual(df_1.loc[1, "rhs_index"], 3)
self.assertEqual(df_1.loc[2, "rhs_index"], 5)
self.assertEqual(df_1.loc[3, "rhs_index"], 7)
self.assertEqual(df_1.loc[4, "rhs_index"], 9)
self.assertEqual(list(df_1.loc[0: 4, "rhs_index"]), [1, 3, 5, 7, 9])
# the following rows have no match and should be null / NA
self.assertTrue(pd.isna(df_1.loc[5, "rhs_index"]))
self.assertTrue(pd.isna(df_1.loc[6, "rhs_index"]))
self.assertTrue(pd.isna(df_1.loc[7, "rhs_index"]))
self.assertTrue(pd.isna(df_1.loc[8, "rhs_index"]))
self.assertTrue(pd.isna(df_1.loc[9, "rhs_index"]))
self.assertTrue(all(pd.isna(df_1.loc[5:9, "rhs_index"])))

def test_exact_join(self):
left_table = self.test_table.drop_columns(["d", "e"]).group_by('a')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import io.deephaven.annotations.NodeStyle;
import io.deephaven.api.NaturalJoinType;
import org.immutables.value.Value;
import org.immutables.value.Value.Immutable;

import java.util.Collection;
Expand All @@ -16,7 +17,10 @@
@NodeStyle
public abstract class NaturalJoinTable extends JoinBase {

public abstract NaturalJoinType joinType();
/**
 * The action to be taken when duplicate right-hand rows are encountered by the natural join.
 *
 * <p>Annotated with {@code @Value.Default}, so this method body supplies the value used when
 * no join type is set explicitly: {@link NaturalJoinType#ERROR_ON_DUPLICATE}.
 *
 * @return the natural-join type; {@code ERROR_ON_DUPLICATE} by default
 */
@Value.Default
public NaturalJoinType joinType() {
return NaturalJoinType.ERROR_ON_DUPLICATE;
}

public static Builder builder() {
return ImmutableNaturalJoinTable.builder();
Expand Down

0 comments on commit 84987dd

Please sign in to comment.