diff --git a/awswrangler/athena/_read.py b/awswrangler/athena/_read.py index 35632f854..483467ab7 100644 --- a/awswrangler/athena/_read.py +++ b/awswrangler/athena/_read.py @@ -337,7 +337,8 @@ def _resolve_query_without_cache_ctas( wait=True, athena_query_wait_polling_delay=athena_query_wait_polling_delay, boto3_session=boto3_session, - execution_params=execution_params, + params=execution_params, + paramstyle="qmark", ) fully_qualified_name: str = f'"{ctas_query_info["ctas_database"]}"."{ctas_query_info["ctas_table"]}"' ctas_query_metadata = cast(_QueryMetadata, ctas_query_info["ctas_query_metadata"]) diff --git a/awswrangler/athena/_utils.py b/awswrangler/athena/_utils.py index f13090fa8..d2f10db8f 100644 --- a/awswrangler/athena/_utils.py +++ b/awswrangler/athena/_utils.py @@ -649,6 +649,8 @@ def create_ctas_table( wait: bool = False, athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY, execution_params: list[str] | None = None, + params: dict[str, Any] | list[str] | None = None, + paramstyle: Literal["qmark", "named"] = "named", boto3_session: boto3.Session | None = None, ) -> dict[str, str | _QueryMetadata]: """Create a new table populated with the results of a SELECT query. @@ -701,6 +703,17 @@ def create_ctas_table( Whether to wait for the query to finish and return a dictionary with the Query metadata. athena_query_wait_polling_delay: float, default: 0.25 seconds Interval in seconds for how often the function will check if the Athena query has completed. + execution_params: List[str], optional [DEPRECATED] + A list of values for the parameters that are used in the SQL query. + This parameter is on a deprecation path. + Use ``params`` and ``paramstyle`` instead. + params: Dict[str, Any] | List[str], optional + Dictionary or list of parameters to pass to execute method. + The syntax used to pass parameters depends on the configuration of ``paramstyle``. + paramstyle: str, optional + The syntax style to use for the parameters. 
+ Supported values are ``named`` and ``qmark``. + The default is ``named``. boto3_session: boto3.Session, optional Boto3 Session. The default boto3 session is used if boto3_session is None. @@ -752,6 +765,20 @@ def create_ctas_table( if ctas_database is None: raise exceptions.InvalidArgumentCombination("Either ctas_database or database must be defined.") + # Substitute execution_params with params + if execution_params: + if params: + raise exceptions.InvalidArgumentCombination("`execution_params` and `params` are mutually exclusive.") + + params = execution_params + paramstyle = "qmark" + raise DeprecationWarning( + '`execution_params` is being deprecated. Use `params` and `paramstyle="qmark"` instead.' + ) + + # Substitute query parameters if applicable + sql, execution_params = _apply_formatter(sql, params, paramstyle) + fully_qualified_name = f'"{ctas_database}"."{ctas_table}"' wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup) diff --git a/tests/unit/test_athena.py b/tests/unit/test_athena.py index b0b51e5cb..eda53c0ab 100644 --- a/tests/unit/test_athena.py +++ b/tests/unit/test_athena.py @@ -222,6 +222,94 @@ def test_athena_create_ctas(path, glue_table, glue_table2, glue_database, glue_c ensure_athena_ctas_table(ctas_query_info=ctas_query_info, boto3_session=boto3_session) +def test_athena_create_ctas_with_named_params(path, glue_table, glue_database, glue_ctas_database): + wr.s3.to_parquet( + df=get_df_list(), + path=path, + index=False, + dataset=True, + mode="overwrite", + database=glue_database, + table=glue_table, + ) + + wr.athena.create_ctas_table( + sql=f"SELECT * FROM {glue_table} WHERE par1 = :par1", + database=glue_database, + ctas_database=glue_ctas_database, + params={"par1": "b"}, + paramstyle="named", + wait=True, + ) + + +def test_athena_create_ctas_with_qmark_params(path, glue_table, glue_database, glue_ctas_database): + wr.s3.to_parquet( + df=get_df_list(), + path=path, + index=False, + 
dataset=True, + mode="overwrite", + database=glue_database, + table=glue_table, + ) + + wr.athena.create_ctas_table( + sql=f"SELECT * FROM {glue_table} WHERE par1 = ?", + database=glue_database, + ctas_database=glue_ctas_database, + params=["b"], + paramstyle="qmark", + wait=True, + ) + + +def test_athena_create_ctas_with_execution_params_deprecation_warning( + path, glue_table, glue_database, glue_ctas_database +): + wr.s3.to_parquet( + df=get_df_list(), + path=path, + index=False, + dataset=True, + mode="overwrite", + database=glue_database, + table=glue_table, + ) + + with pytest.raises(DeprecationWarning): + wr.athena.create_ctas_table( + sql=f"SELECT * FROM {glue_table} WHERE par1 = ?", + database=glue_database, + ctas_database=glue_ctas_database, + execution_params=["b"], + wait=True, + ) + + +def test_athena_create_ctas_with_params_and_execution_params_error(path, glue_table, glue_database, glue_ctas_database): + wr.s3.to_parquet( + df=get_df_list(), + path=path, + index=False, + dataset=True, + mode="overwrite", + database=glue_database, + table=glue_table, + ) + + with pytest.raises(wr.exceptions.InvalidArgumentCombination): + wr.athena.create_ctas_table( + sql=f"SELECT * FROM {glue_table} WHERE par1 = ?", + database=glue_database, + ctas_database=glue_ctas_database, + execution_params=["b"], + params=["b"], + paramstyle="qmark", + wait=True, + ) + + def test_athena(path, glue_database, glue_table, kms_key, workgroup0, workgroup1): wr.s3.to_parquet( df=get_df(),