Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
LogLevel,
log_with_request_id,
)
from awslabs.aws_dataprocessing_mcp_server.utils.mutable_sql_detector import detect_mutating_keywords
from mcp.server.fastmcp import Context
from mcp.types import TextContent
from pydantic import Field
Expand Down Expand Up @@ -211,29 +212,57 @@ async def manage_aws_athena_queries(
f'Athena Query Handler - Tool: manage_aws_athena_queries - Operation: {operation}',
)

if not self.allow_write and operation in [
'start-query-execution',
]:
error_message = (
f'Operation {operation} for select query is only allowed without write access'
)
log_with_request_id(ctx, LogLevel.ERROR, error_message)

if (
operation == 'start-query-execution'
and query_string
and (
'select' not in query_string.lower()
or 'create table as select' in query_string.lower()
if operation == 'start-query-execution':
if query_string is None:
raise ValueError(
'query_string is required for start-query-execution operation'
)
):

matched = detect_mutating_keywords(query_string)
if not self.allow_write and matched:
error_message = f"Mutating operations: {matched} are not allowed when write access is disabled"
log_with_request_id(ctx, LogLevel.ERROR, error_message)

return StartQueryExecutionResponse(
isError=True,
content=[TextContent(type='text', text=error_message)],
query_execution_id='',
operation='start-query-execution',
)

# Prepare parameters
params = {'QueryString': query_string}

if client_request_token is not None:
params['ClientRequestToken'] = client_request_token

if query_execution_context is not None:
params['QueryExecutionContext'] = query_execution_context

if result_configuration is not None:
params['ResultConfiguration'] = result_configuration

if work_group is not None:
params['WorkGroup'] = work_group

if execution_parameters is not None:
params['ExecutionParameters'] = execution_parameters

if result_reuse_configuration is not None:
params['ResultReuseConfiguration'] = result_reuse_configuration

# Start query execution
response = self.athena_client.start_query_execution(**params)

return StartQueryExecutionResponse(
isError=False,
content=[
TextContent(type='text', text='Successfully started query execution')
],
query_execution_id=response.get('QueryExecutionId', ''),
operation='start-query-execution',
)

if operation == 'batch-get-query-execution':
if query_execution_ids is None:
raise ValueError(
Expand Down Expand Up @@ -366,45 +395,6 @@ async def manage_aws_athena_queries(
operation='list-query-executions',
)

elif operation == 'start-query-execution':
if query_string is None:
raise ValueError(
'query_string is required for start-query-execution operation'
)

# Prepare parameters
params = {'QueryString': query_string}

if client_request_token is not None:
params['ClientRequestToken'] = client_request_token

if query_execution_context is not None:
params['QueryExecutionContext'] = query_execution_context

if result_configuration is not None:
params['ResultConfiguration'] = result_configuration

if work_group is not None:
params['WorkGroup'] = work_group

if execution_parameters is not None:
params['ExecutionParameters'] = execution_parameters

if result_reuse_configuration is not None:
params['ResultReuseConfiguration'] = result_reuse_configuration

# Start query execution
response = self.athena_client.start_query_execution(**params)

return StartQueryExecutionResponse(
isError=False,
content=[
TextContent(type='text', text='Successfully started query execution')
],
query_execution_id=response.get('QueryExecutionId', ''),
operation='start-query-execution',
)

elif operation == 'stop-query-execution':
if query_execution_id is None:
raise ValueError(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re


# -- Mutating keyword set for quick string matching --
MUTATING_KEYWORDS = {
'INSERT',
'UPDATE',
'DELETE',
'REPLACE',
'TRUNCATE',
'CREATE',
'DROP',
'ALTER',
'RENAME',
'GRANT',
'REVOKE',
'LOAD DATA',
'LOAD XML',
'INSTALL PLUGIN',
'UNINSTALL PLUGIN',
}

MUTATING_PATTERN = re.compile(
r'(?i)\b(' + '|'.join(re.escape(k) for k in MUTATING_KEYWORDS) + r')\b'
)

# -- Regex for DDL statements --
DDL_REGEX = re.compile(
r"""
^\s*(
CREATE\s+(TABLE|VIEW|INDEX|TRIGGER|PROCEDURE|FUNCTION|EVENT)|
DROP\s+(TABLE|VIEW|INDEX|TRIGGER|PROCEDURE|FUNCTION|EVENT)|
ALTER\s+(TABLE|VIEW|TRIGGER|PROCEDURE|FUNCTION|EVENT)|
RENAME\s+(TABLE)|
TRUNCATE
)\b
""",
re.IGNORECASE | re.VERBOSE,
)

def detect_mutating_keywords(sql: str) -> list[str]:
"""Return a list of mutating keywords found in the SQL (excluding comments)."""
matched = []

if DDL_REGEX.search(sql):
matched.append('DDL')

# Match individual keywords from MUTATING_KEYWORDS
keyword_matches = MUTATING_PATTERN.findall(sql)
if keyword_matches:
# Deduplicate and normalize casing
matched.extend(sorted({k.upper() for k in keyword_matches}))

return matched
Loading