From ef57d94ea1ee68e32db8513829b4abee9b97cf5c Mon Sep 17 00:00:00 2001 From: Gaetan Ingrassia Date: Mon, 13 Oct 2025 21:36:34 +0000 Subject: [PATCH] fix: mutable query detection --- .../handlers/athena/athena_query_handler.py | 98 +++++++++---------- .../utils/mutable_sql_detector.py | 68 +++++++++++++ 2 files changed, 112 insertions(+), 54 deletions(-) create mode 100644 src/aws-dataprocessing-mcp-server/awslabs/aws_dataprocessing_mcp_server/utils/mutable_sql_detector.py diff --git a/src/aws-dataprocessing-mcp-server/awslabs/aws_dataprocessing_mcp_server/handlers/athena/athena_query_handler.py b/src/aws-dataprocessing-mcp-server/awslabs/aws_dataprocessing_mcp_server/handlers/athena/athena_query_handler.py index 24c20565f6..5b8e28ccd2 100644 --- a/src/aws-dataprocessing-mcp-server/awslabs/aws_dataprocessing_mcp_server/handlers/athena/athena_query_handler.py +++ b/src/aws-dataprocessing-mcp-server/awslabs/aws_dataprocessing_mcp_server/handlers/athena/athena_query_handler.py @@ -34,6 +34,7 @@ LogLevel, log_with_request_id, ) +from awslabs.aws_dataprocessing_mcp_server.utils.mutable_sql_detector import detect_mutating_keywords from mcp.server.fastmcp import Context from mcp.types import TextContent from pydantic import Field @@ -211,22 +212,17 @@ async def manage_aws_athena_queries( f'Athena Query Handler - Tool: manage_aws_athena_queries - Operation: {operation}', ) - if not self.allow_write and operation in [ - 'start-query-execution', - ]: - error_message = ( - f'Operation {operation} for select query is only allowed without write access' - ) - log_with_request_id(ctx, LogLevel.ERROR, error_message) - - if ( - operation == 'start-query-execution' - and query_string - and ( - 'select' not in query_string.lower() - or 'create table as select' in query_string.lower() + if operation == 'start-query-execution': + if query_string is None: + raise ValueError( + 'query_string is required for start-query-execution operation' ) - ): + + matched = 
detect_mutating_keywords(query_string) + if not self.allow_write and matched: + error_message = f"Mutating operations: {matched} are not allowed when write access is disabled" + log_with_request_id(ctx, LogLevel.ERROR, error_message) + return StartQueryExecutionResponse( isError=True, content=[TextContent(type='text', text=error_message)], @@ -234,6 +230,39 @@ async def manage_aws_athena_queries( operation='start-query-execution', ) + # Prepare parameters + params = {'QueryString': query_string} + + if client_request_token is not None: + params['ClientRequestToken'] = client_request_token + + if query_execution_context is not None: + params['QueryExecutionContext'] = query_execution_context + + if result_configuration is not None: + params['ResultConfiguration'] = result_configuration + + if work_group is not None: + params['WorkGroup'] = work_group + + if execution_parameters is not None: + params['ExecutionParameters'] = execution_parameters + + if result_reuse_configuration is not None: + params['ResultReuseConfiguration'] = result_reuse_configuration + + # Start query execution + response = self.athena_client.start_query_execution(**params) + + return StartQueryExecutionResponse( + isError=False, + content=[ + TextContent(type='text', text='Successfully started query execution') + ], + query_execution_id=response.get('QueryExecutionId', ''), + operation='start-query-execution', + ) + if operation == 'batch-get-query-execution': if query_execution_ids is None: raise ValueError( @@ -366,45 +395,6 @@ async def manage_aws_athena_queries( operation='list-query-executions', ) - elif operation == 'start-query-execution': - if query_string is None: - raise ValueError( - 'query_string is required for start-query-execution operation' - ) - - # Prepare parameters - params = {'QueryString': query_string} - - if client_request_token is not None: - params['ClientRequestToken'] = client_request_token - - if query_execution_context is not None: - 
params['QueryExecutionContext'] = query_execution_context - - if result_configuration is not None: - params['ResultConfiguration'] = result_configuration - - if work_group is not None: - params['WorkGroup'] = work_group - - if execution_parameters is not None: - params['ExecutionParameters'] = execution_parameters - - if result_reuse_configuration is not None: - params['ResultReuseConfiguration'] = result_reuse_configuration - - # Start query execution - response = self.athena_client.start_query_execution(**params) - - return StartQueryExecutionResponse( - isError=False, - content=[ - TextContent(type='text', text='Successfully started query execution') - ], - query_execution_id=response.get('QueryExecutionId', ''), - operation='start-query-execution', - ) - elif operation == 'stop-query-execution': if query_execution_id is None: raise ValueError( diff --git a/src/aws-dataprocessing-mcp-server/awslabs/aws_dataprocessing_mcp_server/utils/mutable_sql_detector.py b/src/aws-dataprocessing-mcp-server/awslabs/aws_dataprocessing_mcp_server/utils/mutable_sql_detector.py new file mode 100644 index 0000000000..6b91ffe55d --- /dev/null +++ b/src/aws-dataprocessing-mcp-server/awslabs/aws_dataprocessing_mcp_server/utils/mutable_sql_detector.py @@ -0,0 +1,68 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import re


# -- Mutating keyword set for quick string matching --
# Multi-word entries are matched with any run of whitespace between the words.
# NOTE(review): engine-specific write statements (e.g. MERGE, UNLOAD, VACUUM,
# OPTIMIZE, MSCK REPAIR TABLE) are not listed here -- confirm whether they
# should be blocked as well.
# NOTE(review): 'REPLACE' also matches calls to a replace() string function;
# confirm whether that false positive is acceptable.
MUTATING_KEYWORDS = {
    'INSERT',
    'UPDATE',
    'DELETE',
    'REPLACE',
    'TRUNCATE',
    'CREATE',
    'DROP',
    'ALTER',
    'RENAME',
    'GRANT',
    'REVOKE',
    'LOAD DATA',
    'LOAD XML',
    'INSTALL PLUGIN',
    'UNINSTALL PLUGIN',
}


def _keyword_to_regex(keyword: str) -> str:
    """Return a regex fragment for *keyword* that tolerates any whitespace between words."""
    return r'\s+'.join(re.escape(word) for word in keyword.split())


# Sorted so the compiled alternation does not depend on set iteration order.
MUTATING_PATTERN = re.compile(
    r'\b(' + '|'.join(_keyword_to_regex(k) for k in sorted(MUTATING_KEYWORDS)) + r')\b',
    re.IGNORECASE,
)

# -- Regex for DDL statements --
# Anchored at the start of the text: only the leading statement is classified.
DDL_REGEX = re.compile(
    r"""
    ^\s*(
        CREATE\s+(TABLE|VIEW|INDEX|TRIGGER|PROCEDURE|FUNCTION|EVENT)|
        DROP\s+(TABLE|VIEW|INDEX|TRIGGER|PROCEDURE|FUNCTION|EVENT)|
        ALTER\s+(TABLE|VIEW|TRIGGER|PROCEDURE|FUNCTION|EVENT)|
        RENAME\s+(TABLE)|
        TRUNCATE
    )\b
    """,
    re.IGNORECASE | re.VERBOSE,
)

# -- Regex for text that is not executable SQL --
# Quoted text and comments are recognised in ONE left-to-right pass so that a
# quote inside a comment (``-- it's``) or a comment marker inside a literal
# (``'--'``) is attributed to whichever construct starts first.  Unterminated
# quotes/comments deliberately do not match, so the text after them is still
# scanned (erring on the side of reporting a keyword).
_NON_CODE_PATTERN = re.compile(
    '|'.join(
        (
            r"'(?:[^']|'')*'",  # single-quoted string literal ('' escapes a quote)
            r'"(?:[^"]|"")*"',  # double-quoted identifier/string ("" escapes a quote)
            r'`[^`]*`',  # backtick-quoted identifier
            r'--[^\n]*',  # line comment
            r'/\*.*?\*/',  # block comment (may span lines)
        )
    ),
    re.DOTALL,
)


def detect_mutating_keywords(sql: str) -> list[str]:
    """Return the mutating keywords found in the executable part of *sql*.

    Comments, string literals and quoted identifiers are blanked out before
    matching, so a keyword that only appears inside them is not reported.

    Args:
        sql: The SQL text to inspect. Empty or ``None`` input yields ``[]``.

    Returns:
        A list that starts with ``'DDL'`` when the text begins with a DDL
        statement recognised by ``DDL_REGEX``, followed by the distinct
        matched keywords from ``MUTATING_KEYWORDS`` in upper case and sorted
        order. The list is empty when nothing mutating was found.
    """
    if not sql:
        return []

    # Replace with a space (not '') so tokens on either side do not fuse,
    # e.g. ``DROP/**/TABLE`` becomes ``DROP TABLE``.
    code = _NON_CODE_PATTERN.sub(' ', sql)

    matched = []
    if DDL_REGEX.search(code):
        matched.append('DDL')

    # Deduplicate, upper-case, and collapse inner whitespace so that
    # ``load   data`` is reported as ``LOAD DATA``.
    found = {' '.join(hit.upper().split()) for hit in MUTATING_PATTERN.findall(code)}
    matched.extend(sorted(found))

    return matched