Skip to content

Fix set explain regex #20319

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions postgres/changelog.d/20319.fixed
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow EXPLAIN on multi-statement SQL where one or more SET commands appear before another supported statement type
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import psycopg2

from datadog_checks.base.utils.db.sql import compute_sql_signature
from datadog_checks.base.utils.tracking import tracked_method
from datadog_checks.postgres.cursor import CommenterDictCursor

Expand Down Expand Up @@ -73,7 +72,7 @@ def __init__(self, check, config, explain_function):
self._explain_function = explain_function

@tracked_method(agent_check_getter=agent_check_getter)
def explain_statement(self, dbname, statement, obfuscated_statement):
def explain_statement(self, dbname, statement, obfuscated_statement, query_signature):
if self._check.version < V12:
# if pg version < 12, skip explaining parameterized queries because
# plan_cache_mode is not supported
Expand All @@ -85,7 +84,6 @@ def explain_statement(self, dbname, statement, obfuscated_statement):
return None, DBExplainError.parameterized_query, '{}'.format(type(e))
self._set_plan_cache_mode(dbname)

query_signature = compute_sql_signature(obfuscated_statement)
try:
self._create_prepared_statement(dbname, statement, obfuscated_statement, query_signature)
except psycopg2.errors.IndeterminateDatatype as e:
Expand Down
19 changes: 13 additions & 6 deletions postgres/datadog_checks/postgres/statement_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from datadog_checks.base.utils.tracking import tracked_method
from datadog_checks.postgres.explain_parameterized_queries import ExplainParameterizedQueries

from .util import DatabaseConfigurationError, DBExplainError, warning_with_tags
from .util import DatabaseConfigurationError, DBExplainError, trim_leading_set_stmts, warning_with_tags
from .version_utils import V9_6, V10

# according to https://unicodebook.readthedocs.io/unicode_encodings.html, the max supported size of a UTF-8 encoded
Expand Down Expand Up @@ -749,15 +749,22 @@ def _run_and_track_explain(self, dbname, statement, obfuscated_statement, query_
@tracked_method(agent_check_getter=agent_check_getter)
def _run_explain_safe(self, dbname, statement, obfuscated_statement, query_signature):
# type: (str, str, str, str) -> Tuple[Optional[Dict], Optional[DBExplainError], Optional[str]]

orig_statement = statement

# remove leading SET statements from our SQL
if obfuscated_statement[:3].lower() == "set":
statement = trim_leading_set_stmts(statement)
obfuscated_statement = trim_leading_set_stmts(obfuscated_statement)

if not self._can_explain_statement(obfuscated_statement):
return None, DBExplainError.no_plans_possible, None

track_activity_query_size = self._get_track_activity_query_size()

if (
self._get_truncation_state(track_activity_query_size, statement, query_signature)
== StatementTruncationState.truncated
):
# truncation check is on the original query, not the trimmed version
stmt_trunc = self._get_truncation_state(track_activity_query_size, orig_statement, query_signature)
if stmt_trunc == StatementTruncationState.truncated:
return (
None,
DBExplainError.query_truncated,
Expand All @@ -779,7 +786,7 @@ def _run_explain_safe(self, dbname, statement, obfuscated_statement, query_signa
if self._explain_parameterized_queries._is_parameterized_query(statement):
if is_affirmative(self._config.statement_samples_config.get('explain_parameterized_queries', True)):
return self._explain_parameterized_queries.explain_statement(
dbname, statement, obfuscated_statement
dbname, statement, obfuscated_statement, query_signature
)
e = psycopg2.errors.UndefinedParameter("Unable to explain parameterized query")
self._log.debug(
Expand Down
35 changes: 35 additions & 0 deletions postgres/datadog_checks/postgres/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# (C) Datadog, Inc. 2019-present
# All rights reserved
# Licensed under Simplified BSD License (see LICENSE)
import re
import string
from enum import Enum
from typing import Any, List, Tuple # noqa: F401
Expand Down Expand Up @@ -130,6 +131,40 @@ def get_list_chunks(lst, n):
yield lst[i : i + n]


SET_TRIM_PATTERN = re.compile(
r"""
^
(?:
\s*
# match one leading comment
(?:
/\*
.*
\*/
\s*
)?

# match leading SET commands
SET\b
(?:
[^';] | # keywords, integer literals, etc.
'[^']*' # single-quoted strings
)+
;
)+
""",
flags=(re.I | re.X),
)


# Expects one or more SQL statements in a string. If the string
# begins with any SET statements, they are removed and the rest
# of the string is returned. Otherwise, the string is returned
# as it was received.
def trim_leading_set_stmts(sql):
return SET_TRIM_PATTERN.sub('', sql, 1).lstrip()


fmt = PartialFormatter()

AWS_RDS_HOSTNAME_SUFFIX = ".rds.amazonaws.com"
Expand Down
23 changes: 18 additions & 5 deletions postgres/tests/test_explain_parameterized_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def test_explain_parameterized_queries(integration_check, dbm_instance, query, e
if check.version < V12:
return

plan_dict, explain_err_code, err = check.statement_samples._run_and_track_explain(DB_NAME, query, query, query)
plan_dict, explain_err_code, err = check.statement_samples._run_and_track_explain(
DB_NAME, query, query, "7231596c8b5536d1"
)
assert plan_dict is not None
assert explain_err_code == expected_explain_err_code
assert err is None
Expand Down Expand Up @@ -111,7 +113,10 @@ def test_explain_parameterized_queries_version_below_12(integration_check, dbm_i
return

plan_dict, explain_err_code, err = check.statement_samples._run_and_track_explain(
DB_NAME, "SELECT * FROM pg_settings WHERE name = $1", "SELECT * FROM pg_settings WHERE name = $1", ""
DB_NAME,
"SELECT * FROM pg_settings WHERE name = $1",
"SELECT * FROM pg_settings WHERE name = $1",
"7231596c8b5536d1",
)
assert plan_dict is None
assert explain_err_code == DBExplainError.parameterized_query
Expand All @@ -133,7 +138,10 @@ def test_explain_parameterized_queries_create_prepared_statement_exception(integ
side_effect=psycopg2.errors.DatabaseError("unexpected exception"),
):
plan_dict, explain_err_code, err = check.statement_samples._run_and_track_explain(
DB_NAME, "SELECT * FROM pg_settings WHERE name = $1", "SELECT * FROM pg_settings WHERE name = $1", ""
DB_NAME,
"SELECT * FROM pg_settings WHERE name = $1",
"SELECT * FROM pg_settings WHERE name = $1",
"7231596c8b5536d1",
)
assert plan_dict is None
assert explain_err_code == DBExplainError.failed_to_explain_with_prepared_statement
Expand All @@ -155,7 +163,9 @@ def test_explain_parameterized_queries_explain_prepared_statement_exception(inte
side_effect=psycopg2.errors.DatabaseError("unexpected exception"),
):
query = "SELECT * FROM pg_settings WHERE name = $1"
plan_dict, explain_err_code, err = check.statement_samples._run_and_track_explain(DB_NAME, query, query, "")
plan_dict, explain_err_code, err = check.statement_samples._run_and_track_explain(
DB_NAME, query, query, "7231596c8b5536d1"
)
assert plan_dict is None
assert explain_err_code == DBExplainError.failed_to_explain_with_prepared_statement
assert err is not None
Expand Down Expand Up @@ -184,7 +194,10 @@ def test_explain_parameterized_queries_explain_prepared_statement_no_plan_return
return_value=None,
):
plan_dict, explain_err_code, err = check.statement_samples._run_and_track_explain(
DB_NAME, "SELECT * FROM pg_settings WHERE name = $1", "SELECT * FROM pg_settings WHERE name = $1", ""
DB_NAME,
"SELECT * FROM pg_settings WHERE name = $1",
"SELECT * FROM pg_settings WHERE name = $1",
"7231596c8b5536d1",
)
assert plan_dict is None
assert explain_err_code == DBExplainError.no_plan_returned_with_prepared_statement
Expand Down
48 changes: 45 additions & 3 deletions postgres/tests/test_statements.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,44 @@ def test_get_db_explain_setup_state(integration_check, dbm_instance, dbname, exp
failed_explain_test_repeat_count = 5


@pytest.mark.parametrize(
"query",
[
"SELECT * FROM pg_class",
"SET LOCAL datestyle TO postgres; SELECT * FROM pg_class",
],
)
def test_successful_explain(
integration_check,
dbm_instance,
aggregator,
query,
):
dbname = "datadog_test"
# Don't need metrics for this one
dbm_instance['query_metrics']['enabled'] = False
dbm_instance['query_samples']['explain_parameterized_queries'] = False
check = integration_check(dbm_instance)
check._connect()

# run check so all internal state is correctly initialized
run_one_check(check)

# clear out contents of aggregator so we measure only the metrics generated during this specific part of the test
aggregator.reset()

db_explain_error, err = check.statement_samples._get_db_explain_setup_state(dbname)
assert db_explain_error is None
assert err is None

plan, *rest = check.statement_samples._run_and_track_explain(dbname, query, query, "7231596c8b5536d1")
assert plan is not None

plan = plan['Plan']
assert plan['Node Type'] == 'Seq Scan'
assert plan['Relation Name'] == 'pg_class'


@pytest.mark.parametrize(
"query,expected_error_tag,explain_function_override,expected_fail_count,skip_on_versions",
[
Expand Down Expand Up @@ -663,7 +701,7 @@ def test_failed_explain_handling(
assert err is None

for _ in range(failed_explain_test_repeat_count):
check.statement_samples._run_and_track_explain(dbname, query, query, query)
check.statement_samples._run_and_track_explain(dbname, query, query, "7231596c8b5536d1")

expected_tags = _get_expected_tags(
check, dbm_instance, with_host=False, with_db=True, agent_hostname='stubbed.hostname'
Expand Down Expand Up @@ -1480,7 +1518,9 @@ def test_statement_run_explain_errors(
check._connect()

run_one_check(check)
_, explain_err_code, err = check.statement_samples._run_and_track_explain("datadog_test", query, query, query)
_, explain_err_code, err = check.statement_samples._run_and_track_explain(
"datadog_test", query, query, "7231596c8b5536d1"
)
run_one_check(check)

assert explain_err_code == expected_explain_err_code
Expand Down Expand Up @@ -1534,7 +1574,9 @@ def test_statement_run_explain_parameterized_queries(
return

run_one_check(check)
_, explain_err_code, err = check.statement_samples._run_and_track_explain("datadog_test", query, query, query)
_, explain_err_code, err = check.statement_samples._run_and_track_explain(
"datadog_test", query, query, "7231596c8b5536d1"
)
run_one_check(check)

assert explain_err_code == expected_explain_err_code
Expand Down
34 changes: 34 additions & 0 deletions postgres/tests/test_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,37 @@ def test_database_identifier(pg_instance, template, expected, tags):
pg_instance['tags'] = tags
check = PostgreSql('postgres', {}, [pg_instance])
assert check.database_identifier == expected


@pytest.mark.unit
@pytest.mark.parametrize(
"query,expected_trimmed_query",
[
("SELECT * FROM pg_settings WHERE name = $1", "SELECT * FROM pg_settings WHERE name = $1"),
("SELECT * FROM pg_settings; DELETE FROM pg_settings;", "SELECT * FROM pg_settings; DELETE FROM pg_settings;"),
("SET search_path TO 'my_schema', public; SELECT * FROM pg_settings", "SELECT * FROM pg_settings"),
("SET TIME ZONE 'Europe/Rome'; SELECT * FROM pg_settings", "SELECT * FROM pg_settings"),
(
"SET LOCAL request_id = 1234; SET LOCAL hostname TO 'Bob''s Laptop'; SELECT * FROM pg_settings",
"SELECT * FROM pg_settings",
),
("SET LONG;" * 1024 + "SELECT *;", "SELECT *;"),
("SET " + "'quotable'" * 1024 + "; SELECT *;", "SELECT *;"),
("SET 'l" + "o" * 1024 + "ng'; SELECT *;", "SELECT *;"),
(" /** pl/pgsql **/ SET 'comment'; SELECT *;", "SELECT *;"),
("this isn't SQL", "this isn't SQL"),
(
"SET SESSION min_wal_size = 14400; "
+ "SET LOCAL wal_buffers TO 2048; "
+ "/* testing id 1234 */ set send_abort_for_kill TO 'stderr'; "
+ "set id = case when (false) and ((((cast(null as box) ~= cast(null as box)) "
+ "or (cast(null as point) <@ cast(null as line))) or (public.my table",
"set id = case when (false) and ((((cast(null as box) ~= cast(null as box)) "
+ "or (cast(null as point) <@ cast(null as line))) or (public.my table",
),
("", ""),
],
)
def test_trim_set_stmts(query, expected_trimmed_query):
trimmed_query = util.trim_leading_set_stmts(query)
assert trimmed_query == expected_trimmed_query
Loading