diff --git a/src/databricks/sql/types.py b/src/databricks/sql/types.py index e188ef577..4d9f8be5f 100644 --- a/src/databricks/sql/types.py +++ b/src/databricks/sql/types.py @@ -187,6 +187,7 @@ def __contains__(self, item: Any) -> bool: # let object acts like class def __call__(self, *args: Any) -> "Row": """create new Row object""" + if len(args) > len(self): raise ValueError( "Can not create Row with fields %s, expected %d values " @@ -229,6 +230,7 @@ def __reduce__( self, ) -> Union[str, Tuple[Any, ...]]: """Returns a tuple so Python knows how to pickle Row.""" + if hasattr(self, "__fields__"): return (_create_row, (self.__fields__, tuple(self))) else: @@ -236,6 +238,7 @@ def __reduce__( def __repr__(self) -> str: """Printable representation of Row used in Python REPL.""" + if hasattr(self, "__fields__"): return "Row(%s)" % ", ".join( "%s=%r" % (k, v) for k, v in zip(self.__fields__, tuple(self)) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 4617f7de6..cb0ffd6da 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -740,7 +740,6 @@ def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table": def convert_to_assigned_datatypes_in_column_table(column_table, description): - converted_column_table = [] for i, col in enumerate(column_table): if description[i][1] == "decimal": diff --git a/tests/e2e/common/decimal_tests.py b/tests/e2e/common/decimal_tests.py index 0029f30cb..9a9403fd4 100644 --- a/tests/e2e/common/decimal_tests.py +++ b/tests/e2e/common/decimal_tests.py @@ -38,22 +38,42 @@ class DecimalTestsMixin: ), ] + @pytest.mark.parametrize( + "backend_params", + [ + {}, + { + "use_sea": True, + }, + ], + ) @pytest.mark.parametrize( "decimal, expected_value, expected_type", decimal_and_expected_results ) - def test_decimals(self, decimal, expected_value, expected_type): - with self.cursor({}) as cursor: + def test_decimals(self, decimal, expected_value, expected_type, backend_params): + with self.cursor(backend_params) as cursor: query = "SELECT CAST ({})".format(decimal) cursor.execute(query) table = cursor.fetchmany_arrow(1) assert table.field(0).type == expected_type assert table.to_pydict().popitem()[1][0] == expected_value + @pytest.mark.parametrize( + "backend_params", + [ + {}, + { + "use_sea": True, + }, + ], + ) @pytest.mark.parametrize( "decimals, expected_values, expected_type", multi_decimals_and_expected_results ) - def test_multi_decimals(self, decimals, expected_values, expected_type): - with self.cursor({}) as cursor: + def test_multi_decimals( + self, decimals, expected_values, expected_type, backend_params + ): + with self.cursor(backend_params) as cursor: union_str = " UNION ".join( ["(SELECT CAST ({}))".format(dec) for dec in decimals] ) diff --git a/tests/e2e/common/staging_ingestion_tests.py b/tests/e2e/common/staging_ingestion_tests.py index 825f830f3..8c444cf40 100644 --- a/tests/e2e/common/staging_ingestion_tests.py +++ b/tests/e2e/common/staging_ingestion_tests.py @@ -25,7 +25,16 @@ class PySQLStagingIngestionTestSuiteMixin: In addition to connection credentials (host, path, token) this suite requires an env var named staging_ingestion_user""" - def test_staging_ingestion_life_cycle(self, ingestion_user): + @pytest.mark.parametrize( + "backend_params", + [ + {}, + { + "use_sea": True, + }, + ], + ) + def test_staging_ingestion_life_cycle(self, ingestion_user, backend_params): """PUT a file into the staging location GET the file from the staging location REMOVE the file from the staging location @@ -42,7 +51,7 @@ def test_staging_ingestion_life_cycle(self, ingestion_user): fp.write(original_text) with self.connection( - extra_params={"staging_allowed_local_path": temp_path} + extra_params={"staging_allowed_local_path": temp_path, **backend_params} ) as conn: cursor = conn.cursor() @@ -54,7 +63,7 @@ def test_staging_ingestion_life_cycle(self, ingestion_user): new_fh, new_temp_path = tempfile.mkstemp() with self.connection( - extra_params={"staging_allowed_local_path": new_temp_path} + extra_params={"staging_allowed_local_path": new_temp_path, **backend_params} ) as conn: cursor = conn.cursor() query = f"GET 'stage://tmp/{ingestion_user}/tmp/11/16/file1.csv' TO '{new_temp_path}'" @@ -69,7 +78,9 @@ def test_staging_ingestion_life_cycle(self, ingestion_user): remove_query = f"REMOVE 'stage://tmp/{ingestion_user}/tmp/11/16/file1.csv'" - with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": "/", **backend_params} + ) as conn: cursor = conn.cursor() cursor.execute(remove_query) @@ -85,8 +96,17 @@ def test_staging_ingestion_life_cycle(self, ingestion_user): os.remove(temp_path) os.remove(new_temp_path) + @pytest.mark.parametrize( + "backend_params", + [ + {}, + { + "use_sea": True, + }, + ], + ) def test_staging_ingestion_put_fails_without_staging_allowed_local_path( - self, ingestion_user + self, ingestion_user, backend_params ): """PUT operations are not supported unless the connection was built with a parameter called staging_allowed_local_path @@ -102,13 +122,22 @@ def test_staging_ingestion_put_fails_without_staging_allowed_local_path( with pytest.raises( Error, match="You must provide at least one staging_allowed_local_path" ): - with self.connection() as conn: + with self.connection(extra_params=backend_params) as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" cursor.execute(query) + @pytest.mark.parametrize( + "backend_params", + [ + {}, + { + "use_sea": True, + }, + ], + ) def test_staging_ingestion_put_fails_if_localFile_not_in_staging_allowed_local_path( - self, ingestion_user + self, ingestion_user, backend_params ): fh, temp_path = tempfile.mkstemp() @@ -128,14 +157,23 @@ def test_staging_ingestion_put_fails_if_localFile_not_in_staging_allowed_local_p match="Local file operations are restricted to paths within the configured staging_allowed_local_path", ): with self.connection( - extra_params={"staging_allowed_local_path": base_path} + extra_params={"staging_allowed_local_path": base_path, **backend_params} ) as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" cursor.execute(query) + @pytest.mark.parametrize( + "backend_params", + [ + {}, + { + "use_sea": True, + }, + ], + ) def test_staging_ingestion_put_fails_if_file_exists_and_overwrite_not_set( - self, ingestion_user + self, ingestion_user, backend_params ): """PUT a file into the staging location twice. First command should succeed. Second should fail.""" @@ -148,7 +186,7 @@ def test_staging_ingestion_put_fails_if_file_exists_and_overwrite_not_set( def perform_put(): with self.connection( - extra_params={"staging_allowed_local_path": temp_path} + extra_params={"staging_allowed_local_path": temp_path, **backend_params} ) as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO 'stage://tmp/{ingestion_user}/tmp/12/15/file1.csv'" @@ -161,7 +199,7 @@ def perform_remove(): ) with self.connection( - extra_params={"staging_allowed_local_path": "/"} + extra_params={"staging_allowed_local_path": "/", **backend_params} ) as conn: cursor = conn.cursor() cursor.execute(remove_query) @@ -183,7 +221,18 @@ def perform_remove(): # Clean up after ourselves perform_remove() - def test_staging_ingestion_fails_to_modify_another_staging_user(self): + @pytest.mark.parametrize( + "backend_params", + [ + {}, + { + "use_sea": True, + }, + ], + ) + def test_staging_ingestion_fails_to_modify_another_staging_user( + self, backend_params + ): """The server should only allow modification of the staging_ingestion_user's files""" some_other_user = "mary.poppins@databricks.com" @@ -197,7 +246,7 @@ def test_staging_ingestion_fails_to_modify_another_staging_user(self): def perform_put(): with self.connection( - extra_params={"staging_allowed_local_path": temp_path} + extra_params={"staging_allowed_local_path": temp_path, **backend_params} ) as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO 'stage://tmp/{some_other_user}/tmp/12/15/file1.csv' OVERWRITE" @@ -207,14 +256,14 @@ def perform_remove(): remove_query = f"REMOVE 'stage://tmp/{some_other_user}/tmp/12/15/file1.csv'" with self.connection( - extra_params={"staging_allowed_local_path": "/"} + extra_params={"staging_allowed_local_path": "/", **backend_params} ) as conn: cursor = conn.cursor() cursor.execute(remove_query) def perform_get(): with self.connection( - extra_params={"staging_allowed_local_path": temp_path} + extra_params={"staging_allowed_local_path": temp_path, **backend_params} ) as conn: cursor = conn.cursor() query = f"GET 'stage://tmp/{some_other_user}/tmp/11/15/file1.csv' TO '{temp_path}'" @@ -232,8 +281,17 @@ def perform_get(): with pytest.raises(sql.exc.ServerOperationError, match="PERMISSION_DENIED"): perform_get() + @pytest.mark.parametrize( + "backend_params", + [ + {}, + { + "use_sea": True, + }, + ], + ) def test_staging_ingestion_put_fails_if_absolute_localFile_not_in_staging_allowed_local_path( - self, ingestion_user + self, ingestion_user, backend_params ): """ This test confirms that staging_allowed_local_path and target_file are resolved into absolute paths. @@ -250,42 +308,78 @@ def test_staging_ingestion_put_fails_if_absolute_localFile_not_in_staging_allowe match="Local file operations are restricted to paths within the configured staging_allowed_local_path", ): with self.connection( - extra_params={"staging_allowed_local_path": staging_allowed_local_path} + extra_params={ + "staging_allowed_local_path": staging_allowed_local_path, + **backend_params, + } ) as conn: cursor = conn.cursor() query = f"PUT '{target_file}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" cursor.execute(query) + @pytest.mark.parametrize( + "backend_params", + [ + {}, + { + "use_sea": True, + }, + ], + ) def test_staging_ingestion_empty_local_path_fails_to_parse_at_server( - self, ingestion_user + self, ingestion_user, backend_params ): staging_allowed_local_path = "/var/www/html" target_file = "" with pytest.raises(Error, match="EMPTY_LOCAL_FILE_IN_STAGING_ACCESS_QUERY"): with self.connection( - extra_params={"staging_allowed_local_path": staging_allowed_local_path} + extra_params={ + "staging_allowed_local_path": staging_allowed_local_path, + **backend_params, + } ) as conn: cursor = conn.cursor() query = f"PUT '{target_file}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" cursor.execute(query) + @pytest.mark.parametrize( + "backend_params", + [ + {}, + { + "use_sea": True, + }, + ], + ) def test_staging_ingestion_invalid_staging_path_fails_at_server( - self, ingestion_user + self, ingestion_user, backend_params ): staging_allowed_local_path = "/var/www/html" target_file = "index.html" with pytest.raises(Error, match="INVALID_STAGING_PATH_IN_STAGING_ACCESS_QUERY"): with self.connection( - extra_params={"staging_allowed_local_path": staging_allowed_local_path} + extra_params={ + "staging_allowed_local_path": staging_allowed_local_path, + **backend_params, + } ) as conn: cursor = conn.cursor() query = f"PUT '{target_file}' INTO 'stageRANDOMSTRINGOFCHARACTERS://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" cursor.execute(query) + @pytest.mark.parametrize( + "backend_params", + [ + {}, + { + "use_sea": True, + }, + ], + ) def test_staging_ingestion_supports_multiple_staging_allowed_local_path_values( - self, ingestion_user + self, ingestion_user, backend_params ): """staging_allowed_local_path may be either a path-like object or a list of path-like objects. @@ -331,7 +425,10 @@ def generate_file_and_path_and_queries(): ) = generate_file_and_path_and_queries() with self.connection( - extra_params={"staging_allowed_local_path": [temp_path1, temp_path2]} + extra_params={ + "staging_allowed_local_path": [temp_path1, temp_path2], + **backend_params, + } ) as conn: cursor = conn.cursor() diff --git a/tests/e2e/common/uc_volume_tests.py b/tests/e2e/common/uc_volume_tests.py index 72e2f5020..f8f8217a9 100644 --- a/tests/e2e/common/uc_volume_tests.py +++ b/tests/e2e/common/uc_volume_tests.py @@ -24,7 +24,16 @@ class PySQLUCVolumeTestSuiteMixin: In addition to connection credentials (host, path, token) this suite requires env vars named catalog and schema""" - def test_uc_volume_life_cycle(self, catalog, schema): + @pytest.mark.parametrize( + "backend_params", + [ + {}, + { + "use_sea": True, + }, + ], + ) + def test_uc_volume_life_cycle(self, catalog, schema, backend_params): """PUT a file into the UC Volume GET the file from the UC Volume REMOVE the file from the UC Volume @@ -41,7 +50,7 @@ def test_uc_volume_life_cycle(self, catalog, schema): fp.write(original_text) with self.connection( - extra_params={"staging_allowed_local_path": temp_path} + extra_params={"staging_allowed_local_path": temp_path, **backend_params} ) as conn: cursor = conn.cursor() @@ -53,7 +62,7 @@ def test_uc_volume_life_cycle(self, catalog, schema): new_fh, new_temp_path = tempfile.mkstemp() with self.connection( - extra_params={"staging_allowed_local_path": new_temp_path} + extra_params={"staging_allowed_local_path": new_temp_path, **backend_params} ) as conn: cursor = conn.cursor() query = f"GET '/Volumes/{catalog}/{schema}/e2etests/file1.csv' TO '{new_temp_path}'" @@ -68,7 +77,9 @@ def test_uc_volume_life_cycle(self, catalog, schema): remove_query = f"REMOVE '/Volumes/{catalog}/{schema}/e2etests/file1.csv'" - with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": "/", **backend_params} + ) as conn: cursor = conn.cursor() cursor.execute(remove_query) @@ -84,8 +95,17 @@ def test_uc_volume_life_cycle(self, catalog, schema): os.remove(temp_path) os.remove(new_temp_path) + @pytest.mark.parametrize( + "backend_params", + [ + {}, + { + "use_sea": True, + }, + ], + ) def test_uc_volume_put_fails_without_staging_allowed_local_path( - self, catalog, schema + self, catalog, schema, backend_params ): """PUT operations are not supported unless the connection was built with a parameter called staging_allowed_local_path @@ -101,13 +121,22 @@ def test_uc_volume_put_fails_without_staging_allowed_local_path( with pytest.raises( Error, match="You must provide at least one staging_allowed_local_path" ): - with self.connection() as conn: + with self.connection(extra_params=backend_params) as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO '/Volumes/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" cursor.execute(query) + @pytest.mark.parametrize( + "backend_params", + [ + {}, + { + "use_sea": True, + }, + ], + ) def test_uc_volume_put_fails_if_localFile_not_in_staging_allowed_local_path( - self, catalog, schema + self, catalog, schema, backend_params ): fh, temp_path = tempfile.mkstemp() @@ -127,14 +156,23 @@ def test_uc_volume_put_fails_if_localFile_not_in_staging_allowed_local_path( match="Local file operations are restricted to paths within the configured staging_allowed_local_path", ): with self.connection( - extra_params={"staging_allowed_local_path": base_path} + extra_params={"staging_allowed_local_path": base_path, **backend_params} ) as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO '/Volumes/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" cursor.execute(query) + @pytest.mark.parametrize( + "backend_params", + [ + {}, + { + "use_sea": True, + }, + ], + ) def test_uc_volume_put_fails_if_file_exists_and_overwrite_not_set( - self, catalog, schema + self, catalog, schema, backend_params ): """PUT a file into the staging location twice. First command should succeed. Second should fail.""" @@ -147,7 +185,7 @@ def test_uc_volume_put_fails_if_file_exists_and_overwrite_not_set( def perform_put(): with self.connection( - extra_params={"staging_allowed_local_path": temp_path} + extra_params={"staging_allowed_local_path": temp_path, **backend_params} ) as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO '/Volumes/{catalog}/{schema}/e2etests/file1.csv'" @@ -160,7 +198,7 @@ def perform_remove(): ) with self.connection( - extra_params={"staging_allowed_local_path": "/"} + extra_params={"staging_allowed_local_path": "/", **backend_params} ) as conn: cursor = conn.cursor() cursor.execute(remove_query) @@ -182,8 +220,17 @@ def perform_remove(): # Clean up after ourselves perform_remove() + @pytest.mark.parametrize( + "backend_params", + [ + {}, + { + "use_sea": True, + }, + ], + ) def test_uc_volume_put_fails_if_absolute_localFile_not_in_staging_allowed_local_path( - self, catalog, schema + self, catalog, schema, backend_params ): """ This test confirms that staging_allowed_local_path and target_file are resolved into absolute paths. @@ -200,38 +247,78 @@ def test_uc_volume_put_fails_if_absolute_localFile_not_in_staging_allowed_local_ match="Local file operations are restricted to paths within the configured staging_allowed_local_path", ): with self.connection( - extra_params={"staging_allowed_local_path": staging_allowed_local_path} + extra_params={ + "staging_allowed_local_path": staging_allowed_local_path, + **backend_params, + } ) as conn: cursor = conn.cursor() query = f"PUT '{target_file}' INTO '/Volumes/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" cursor.execute(query) - def test_uc_volume_empty_local_path_fails_to_parse_at_server(self, catalog, schema): + @pytest.mark.parametrize( + "backend_params", + [ + {}, + { + "use_sea": True, + }, + ], + ) + def test_uc_volume_empty_local_path_fails_to_parse_at_server( + self, catalog, schema, backend_params + ): staging_allowed_local_path = "/var/www/html" target_file = "" with pytest.raises(Error, match="EMPTY_LOCAL_FILE_IN_STAGING_ACCESS_QUERY"): with self.connection( - extra_params={"staging_allowed_local_path": staging_allowed_local_path} + extra_params={ + "staging_allowed_local_path": staging_allowed_local_path, + **backend_params, + } ) as conn: cursor = conn.cursor() query = f"PUT '{target_file}' INTO '/Volumes/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" cursor.execute(query) - def test_uc_volume_invalid_volume_path_fails_at_server(self, catalog, schema): + @pytest.mark.parametrize( + "backend_params", + [ + {}, + { + "use_sea": True, + }, + ], + ) + def test_uc_volume_invalid_volume_path_fails_at_server( + self, catalog, schema, backend_params + ): staging_allowed_local_path = "/var/www/html" target_file = "index.html" with pytest.raises(Error, match="NOT_FOUND: Catalog"): with self.connection( - extra_params={"staging_allowed_local_path": staging_allowed_local_path} + extra_params={ + "staging_allowed_local_path": staging_allowed_local_path, + **backend_params, + } ) as conn: cursor = conn.cursor() query = f"PUT '{target_file}' INTO '/Volumes/RANDOMSTRINGOFCHARACTERS/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" cursor.execute(query) + @pytest.mark.parametrize( + "backend_params", + [ + {}, + { + "use_sea": True, + }, + ], + ) def test_uc_volume_supports_multiple_staging_allowed_local_path_values( - self, catalog, schema + self, catalog, schema, backend_params ): """staging_allowed_local_path may be either a path-like object or a list of path-like objects. @@ -277,7 +364,10 @@ def generate_file_and_path_and_queries(): ) = generate_file_and_path_and_queries() with self.connection( - extra_params={"staging_allowed_local_path": [temp_path1, temp_path2]} + extra_params={ + "staging_allowed_local_path": [temp_path1, temp_path2], + **backend_params, + } ) as conn: cursor = conn.cursor() diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 3fa87b1af..469f9d5fe 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -386,8 +386,15 @@ def test_create_table_will_return_empty_result_set(self, extra_params): finally: cursor.execute("DROP TABLE IF EXISTS {}".format(table_name)) - def test_get_tables(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize( + "backend_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_get_tables(self, backend_params): + with self.cursor(extra_params=backend_params) as cursor: table_name = "table_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) table_names = [table_name + "_1", table_name + "_2"] @@ -562,8 +569,15 @@ def test_get_schemas(self): finally: cursor.execute("DROP DATABASE IF EXISTS {}".format(database_name)) - def test_get_catalogs(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize( + "backend_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_get_catalogs(self, backend_params): + with self.cursor(extra_params=backend_params) as cursor: cursor.catalogs() cursor.fetchall() catalogs_desc = cursor.description @@ -1053,9 +1067,18 @@ class HTTP429Suite(Client429ResponseMixin, PySQLPytestTestCase): class HTTP503Suite(Client503ResponseMixin, PySQLPytestTestCase): # 503Response suite gets custom error here vs PyODBC - def test_retry_disabled(self): + @pytest.mark.parametrize( + "backend_params", + [ + {}, + { + "use_sea": True, + }, + ], + ) + def test_retry_disabled(self, backend_params): self._test_retry_disabled_with_message( - "TEMPORARILY_UNAVAILABLE", OperationalError + "TEMPORARILY_UNAVAILABLE", OperationalError, backend_params ) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 19375cde3..bb9a46ad0 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -46,7 +46,7 @@ def new(cls): is_staging_operation=False, command_id=None, has_been_closed_server_side=True, - has_more_rows=True, + is_direct_results=True, lz4_compressed=True, arrow_schema_bytes=b"schema", ) @@ -105,7 +105,11 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): # Mock the execute response with controlled state mock_execute_response = Mock(spec=ExecuteResponse) - mock_execute_response.status = initial_state + + mock_execute_response.command_id = Mock(spec=CommandId) + mock_execute_response.status = ( + CommandState.SUCCEEDED if not closed else CommandState.CLOSED + ) mock_execute_response.has_been_closed_server_side = closed mock_execute_response.is_staging_operation = False mock_execute_response.command_id = Mock(spec=CommandId) diff --git a/tests/unit/test_fetches_bench.py b/tests/unit/test_fetches_bench.py index 1d485ea61..ac9648a0e 100644 --- a/tests/unit/test_fetches_bench.py +++ b/tests/unit/test_fetches_bench.py @@ -36,7 +36,7 @@ def make_dummy_result_set_from_initial_results(arrow_table): execute_response=ExecuteResponse( status=None, has_been_closed_server_side=True, - has_more_rows=False, + is_direct_results=False, description=Mock(), command_id=None, arrow_schema_bytes=arrow_table.schema, diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 0cdb43f5c..2ce99670f 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -1016,10 +1016,10 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): def test_handle_execute_response_reads_has_more_rows_in_direct_results( self, tcli_service_class, build_queue ): - for has_more_rows, resp_type in itertools.product( + for is_direct_results, resp_type in itertools.product( [True, False], self.execute_response_types ): - with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): + with self.subTest(is_direct_results=is_direct_results, resp_type=resp_type): tcli_service_instance = tcli_service_class.return_value results_mock = Mock() results_mock.startRowOffset = 0 @@ -1031,7 +1031,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( resultSetMetadata=self.metadata_resp, resultSet=ttypes.TFetchResultsResp( status=self.okay_status, - hasMoreRows=has_more_rows, + hasMoreRows=is_direct_results, results=results_mock, ), closeOperation=Mock(), @@ -1062,10 +1062,10 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( def test_handle_execute_response_reads_has_more_rows_in_result_response( self, tcli_service_class, build_queue ): - for has_more_rows, resp_type in itertools.product( + for is_direct_results, resp_type in itertools.product( [True, False], self.execute_response_types ): - with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): + with self.subTest(is_direct_results=is_direct_results, resp_type=resp_type): tcli_service_instance = tcli_service_class.return_value results_mock = MagicMock() results_mock.startRowOffset = 0 @@ -1078,7 +1078,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( fetch_results_resp = ttypes.TFetchResultsResp( status=self.okay_status, - hasMoreRows=has_more_rows, + hasMoreRows=is_direct_results, results=results_mock, resultSetMetadata=ttypes.TGetResultSetMetadataResp( resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET @@ -1112,7 +1112,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( chunk_id=0, ) - self.assertEqual(has_more_rows, has_more_rows_resp) + self.assertEqual(is_direct_results, has_more_rows_resp) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_arrow_batches_row_count_are_respected(self, tcli_service_class):