feat(api): implement upsert() using MERGE INTO (#11624)
base: main
```diff
@@ -423,7 +423,7 @@ def insert(
         Parameters
         ----------
         name
-            The name of the table to which data needs will be inserted
+            The name of the table to which data will be inserted
         obj
             The source data or expression to insert
         database
```
```diff
@@ -453,22 +453,30 @@ def insert(
         with self._safe_raw_sql(query):
             pass

-    def _build_insert_from_table(
+    def _get_columns_to_insert(
         self, *, target: str, source, db: str | None = None, catalog: str | None = None
     ):
-        compiler = self.compiler
-        quoted = compiler.quoted
         # Compare the columns between the target table and the object to be inserted
         # If source is a subset of target, use source columns for insert list
         # Otherwise, assume auto-generated column names and use positional ordering.
         target_cols = self.get_schema(target, catalog=catalog, database=db).keys()

-        columns = (
+        return (
             source_cols
             if (source_cols := source.schema().keys()) <= target_cols
             else target_cols
         )

+    def _build_insert_from_table(
+        self, *, target: str, source, db: str | None = None, catalog: str | None = None
+    ):
+        compiler = self.compiler
+        quoted = compiler.quoted
+
+        columns = self._get_columns_to_insert(
+            target=target, source=source, db=db, catalog=catalog
+        )
+
         query = sge.insert(
             expression=self.compile(source),
             into=sg.table(target, db=db, catalog=catalog, quoted=quoted),
```

Review thread (on the `return (` line, about falling back to positional ordering):

- Oof, I see what you mean: you are inheriting this logic from `insert`. I think it is essential that they use the same logic. But I think this is a footgun waiting to happen. I think we should make a breaking change to `.insert()` and require that `source` be a subset of `target`. What do you think of this change, @cpcloud?
- I don't think there is anywhere else in Ibis where we rely on positional ordering of columns, is there? I think we should keep that in mind: if we make this change here, then we should wipe out all other instances of relying on positional ordering.
- Asking Copilot: "Looking at the codebase, Ibis primarily relies on column names rather than positions, but there are some specific cases where positional ordering is used. Column-name-based operations (the majority): most operations use column names (…). Position-based operations (specific cases): however, there are cases where position matters (…). So the answer is: primarily name-based, but position is significant for schema equality, positional joins, and some insertion scenarios."
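To make the review discussion above concrete, here is a minimal standalone sketch of the subset-or-positional rule that `_get_columns_to_insert` factors out; the schemas are made up for this example and are not from the PR.

```python
# Illustration of the subset-or-positional rule using dict key views,
# which support set comparisons just like the diff relies on.
target_cols = {"i": "int64", "s": "string", "f": "float64"}.keys()

def columns_to_insert(source_cols):
    # If every source column also exists in the target, insert by name;
    # otherwise assume auto-generated column names and fall back to the
    # target's columns, i.e. positional ordering.
    return source_cols if source_cols <= target_cols else target_cols

print(list(columns_to_insert({"i": 0, "s": 0}.keys())))    # ['i', 's']
print(list(columns_to_insert({"c0": 0, "c1": 0}.keys())))  # ['i', 's', 'f']
```

The second call is the footgun the reviewers are discussing: a source with unrecognized column names is silently mapped onto the target's columns by position.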
```diff
@@ -526,6 +534,116 @@ def _build_insert_template
             ),
         ).sql(self.dialect)

+    def upsert(
+        self,
+        name: str,
+        /,
+        obj: pd.DataFrame | ir.Table | list | dict,
+        on: str,
+        *,
+        database: str | None = None,
+    ) -> None:
+        """Upsert data into a table.
+
+        ::: {.callout-note}
+        ## Ibis does not use the word `schema` to refer to database hierarchy.
+
+        A collection of `table` is referred to as a `database`.
+        A collection of `database` is referred to as a `catalog`.
+
+        These terms are mapped onto the corresponding features in each
+        backend (where available), regardless of whether the backend itself
+        uses the same terminology.
+        :::
+
+        Parameters
+        ----------
+        name
+            The name of the table to which data will be upserted
+        obj
+            The source data or expression to upsert
+        on
+            Column name to join on
+        database
+            Name of the attached database that the table is located in.
+
+            For backends that support multi-level table hierarchies, you can
+            pass in a dotted string path like `"catalog.database"` or a tuple of
+            strings like `("catalog", "database")`.
+        """
+        table_loc = self._to_sqlglot_table(database)
+        catalog, db = self._to_catalog_db_tuple(table_loc)
+
+        if not isinstance(obj, ir.Table):
+            obj = ibis.memtable(obj)
+
+        self._run_pre_execute_hooks(obj)
+
+        query = self._build_upsert_from_table(
+            target=name, source=obj, on=on, db=db, catalog=catalog
+        )
+
+        with self._safe_raw_sql(query):
+            pass
```
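A quick usage sketch of the new API (not part of the diff): the connection, table name, and data are hypothetical, and the backend is assumed to be one with MERGE INTO support, for example a recent DuckDB.

```python
import ibis

con = ibis.duckdb.connect()  # assumes a DuckDB version with MERGE INTO support
con.create_table(
    "employees", ibis.memtable({"id": [1, 2], "salary": [100.0, 200.0]})
)

# Row id=2 matches on the key and is updated in place; row id=3 has no
# match and is inserted.
con.upsert("employees", obj={"id": [2, 3], "salary": [250.0, 300.0]}, on="id")

print(con.table("employees").order_by("id").execute())
```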
```diff
+    def _build_upsert_from_table(
+        self,
+        *,
+        target: str,
+        source,
+        on: str,
+        db: str | None = None,
+        catalog: str | None = None,
+    ):
+        compiler = self.compiler
+        quoted = compiler.quoted
+
+        columns = self._get_columns_to_insert(
+            target=target, source=source, db=db, catalog=catalog
+        )
+
+        source_alias = util.gen_name("source")
+        target_alias = util.gen_name("target")
+        query = sge.merge(
+            sge.When(
+                matched=True,
+                then=sge.Update(
+                    expressions=[
+                        sg.column(col, quoted=quoted).eq(
+                            sg.column(col, table=source_alias, quoted=quoted)
+                        )
+                        for col in columns
+                        if col != on
+                    ]
+                ),
+            ),
+            sge.When(
+                matched=False,
+                then=sge.Insert(
+                    this=sge.Tuple(
+                        expressions=[sg.column(col, quoted=quoted) for col in columns]
+                    ),
+                    expression=sge.Tuple(
+                        expressions=[
+                            sg.column(col, table=source_alias, quoted=quoted)
+                            for col in columns
+                        ]
+                    ),
+                ),
+            ),
+            into=sg.table(target, db=db, catalog=catalog, quoted=quoted).as_(
+                sg.to_identifier(target_alias, quoted=quoted), table=True
+            ),
+            using=f"({self.compile(source)}) AS {sg.to_identifier(source_alias, quoted=quoted)}",
+            on=sge.Paren(
+                this=sg.column(on, table=target_alias, quoted=quoted).eq(
+                    sg.column(on, table=source_alias, quoted=quoted)
+                )
+            ),
+            dialect=compiler.dialect,
+        )
+        return query

     def truncate_table(self, name: str, /, *, database: str | None = None) -> None:
         """Delete all rows from a table.
```

Review thread (on the `columns = self._get_columns_to_insert(` line):

- Assuming an existing table with columns `{i: int64, s: string, f: float64}`, can you add tests for upserting objects (using condition `i=i`) with schemas …
- Then, depending on how the "should we fall back to positional ordering" decision goes, we should add tests for that too. But I think all the above tests should still be valid regardless of what we decide there.
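A sketch of one of the tests requested above, for the case where the source schema is a strict subset of the target's; the backend, setup, and assertions here are illustrative, not taken from the PR.

```python
import ibis
import pandas as pd

def test_upsert_source_subset_of_target():
    # Hypothetical setup: DuckDB standing in for any backend with MERGE
    # INTO support.
    con = ibis.duckdb.connect()
    con.create_table(
        "t", schema=ibis.schema({"i": "int64", "s": "string", "f": "float64"})
    )
    con.insert("t", pd.DataFrame({"i": [1], "s": ["a"], "f": [1.0]}))

    # The source lacks column "f": the matched row (i=1) should keep its
    # existing "f" value, and the new row (i=2) should get NULL for "f".
    con.upsert("t", obj=pd.DataFrame({"i": [1, 2], "s": ["A", "b"]}), on="i")

    result = con.table("t").order_by("i").execute()
    assert result["s"].tolist() == ["A", "b"]
    assert result["f"].iloc[0] == 1.0
```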
In the backend test suite:
```diff
@@ -25,6 +25,7 @@
 import ibis.expr.datatypes as dt
 import ibis.expr.operations as ops
 from ibis.backends.conftest import ALL_BACKENDS
+from ibis.backends.tests.conftest import NO_MERGE_SUPPORT
 from ibis.backends.tests.errors import (
     DatabricksServerOperationError,
     ExaQueryError,
```
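`NO_MERGE_SUPPORT`'s definition is not part of this diff. A plausible shape, using the `notimpl` marker this suite already applies elsewhere, purely for orientation:

```python
import pytest

# Purely illustrative; the real marker lives in ibis/backends/tests/conftest.py
# and its backend list is not shown in this PR excerpt.
NO_MERGE_SUPPORT = pytest.mark.notimpl(
    ["polars"],  # hypothetical: backends without MERGE INTO support
    reason="backend does not support MERGE INTO",
)
```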
```diff
@@ -519,6 +520,34 @@ def employee_data_2_temp_table(
     con.drop_table(temp_table_name, force=True)


+@pytest.fixture
+def test_employee_data_3():
+    import pandas as pd
+
+    df3 = pd.DataFrame(
+        {
+            "first_name": ["B", "Y", "Z"],
+            "last_name": ["A", "B", "C"],
+            "department_name": ["XX", "YY", "ZZ"],
+            "salary": [400.0, 500.0, 600.0],
+        }
+    )
+
+    return df3
+
+
+@pytest.fixture
+def employee_data_3_temp_table(
+    backend, con, test_employee_schema, test_employee_data_3
+):
+    temp_table_name = gen_name("temp_employee_data_3")
+    _create_temp_table_with_schema(
+        backend, con, temp_table_name, test_employee_schema, data=test_employee_data_3
+    )
+    yield temp_table_name
+    con.drop_table(temp_table_name, force=True)
+
+
 @pytest.mark.notimpl(["polars"], reason="`insert` method not implemented")
 def test_insert_no_overwrite_from_dataframe(
     backend, con, test_employee_data_2, employee_empty_temp_table
```
```diff
@@ -626,6 +655,53 @@ def _emp(a, b, c, d):
     assert len(con.table(employee_data_1_temp_table).execute()) == 3


+@NO_MERGE_SUPPORT
+def test_upsert_from_dataframe(
+    backend, con, employee_data_1_temp_table, test_employee_data_3
+):
+    temporary = con.table(employee_data_1_temp_table)
+    df1 = temporary.execute().set_index("first_name")
+
+    con.upsert(employee_data_1_temp_table, obj=test_employee_data_3, on="first_name")
+    result = temporary.execute()
+    df2 = test_employee_data_3.set_index("first_name")
+    expected = pd.concat([df1[~df1.index.isin(df2.index)], df2]).reset_index()
+    assert len(result) == len(expected)
+    backend.assert_frame_equal(
+        result.sort_values("first_name").reset_index(drop=True),
+        expected.sort_values("first_name").reset_index(drop=True),
+    )
+
+
+@NO_MERGE_SUPPORT
+@pytest.mark.parametrize("with_order_by", [True, False])
+def test_upsert_from_expr(
+    backend, con, employee_data_1_temp_table, employee_data_3_temp_table, with_order_by
+):
+    temporary = con.table(employee_data_1_temp_table)
+    from_table = con.table(employee_data_3_temp_table)
+    if with_order_by:
+        if backend.name() == "mssql":
+            pytest.xfail(
+                "MSSQL doesn't allow ORDER BY in subqueries, unless "
+                "TOP, OFFSET or FOR XML is also specified"
+            )
+
+        from_table = from_table.filter(ibis._.salary > 0).order_by("first_name")
+
+    df1 = temporary.execute().set_index("first_name")
+
+    con.upsert(employee_data_1_temp_table, obj=from_table, on="first_name")
+
+    result = temporary.execute()
+    df2 = from_table.execute().set_index("first_name")
+    expected = pd.concat([df1[~df1.index.isin(df2.index)], df2]).reset_index()
+    assert len(result) == len(expected)
+    backend.assert_frame_equal(
+        result.sort_values("first_name").reset_index(drop=True),
+        expected.sort_values("first_name").reset_index(drop=True),
+    )
+
+
 @pytest.mark.notimpl(
     ["polars"], raises=AttributeError, reason="`insert` method not implemented"
 )
```

Review thread (on the `con.upsert(...)` call in `test_upsert_from_expr`):

- Can you add an xfail test for …
- Is this necessary? It seems like an unconventional use of xfail. I think it would make more sense to test if we were explicitly testing whether the …
- If this was the confusion, I guess I really meant …. But I think I'm OK with not testing for this behavior if you don't want to (leaving it as undefined behavior). As long as we test all the explicitly-supported behaviors, I'm happy.
Review thread (on appending a trailing semicolon for the MSSQL backend):

- I'm honestly not sure how to better do this; I wasn't able to figure out how to add a semicolon to an expression in SQLGlot.
- This isn't ideal, but I think it's fine with me. The "cleaner" way would be to override the `_build_upsert_from_table` method, but the amount of boilerplate for that feels not worth it.
- Or actually, is it not possible to just always stick a `;` at the end of `_build_upsert_from_table()` for every backend? Or does that break on some backends?
- I don't actually know how to stick a semicolon on the end of a SQLGlot expression. 😅 If `_build_upsert_from_table()` were handling the conversion to SQL, that would have been simple enough.
- Ah, I understand. Huh, is this a bug in the upstream MSSQL engine? Can you confirm that only MERGE statements require trailing semicolons, but other SQL queries/statements do not?
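For context, T-SQL does require that a MERGE statement be terminated with a semicolon. One conceivable workaround, assuming the backend renders the expression to a SQL string itself; this is a sketch, not the PR's code, and the method name is hypothetical:

```python
# Hypothetical MSSQL-side helper: render the MERGE expression to SQL and
# append the terminating semicolon that T-SQL requires for MERGE statements.
def _build_upsert_sql(self, **kwargs) -> str:
    query = self._build_upsert_from_table(**kwargs)  # a sqlglot expression
    return query.sql(self.dialect) + ";"
```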