From 1c360fdef28a977318555a37aee5bb2b3176e6e4 Mon Sep 17 00:00:00 2001 From: Piotr Rzysko Date: Mon, 23 Dec 2024 10:19:31 +0100 Subject: [PATCH 1/5] Add SPI for returning connectors' sensitive properties The SPI will be used by the engine to redact security-sensitive information in statements that manage catalogs. It has been added at the connector factory level, rather than the connector level, to allow more flexibility in retrieving properties. In some cases, we want to perform redacting before a connector is initiated. For example, when we create a new catalog by issuing the CREATE CATALOG statement. --- .../java/io/trino/spi/connector/ConnectorFactory.java | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorFactory.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorFactory.java index 83f1eb68f01a..31c80364dae8 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorFactory.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorFactory.java @@ -16,6 +16,7 @@ import com.google.errorprone.annotations.CheckReturnValue; import java.util.Map; +import java.util.Set; public interface ConnectorFactory { @@ -23,4 +24,14 @@ public interface ConnectorFactory @CheckReturnValue Connector create(String catalogName, Map config, ConnectorContext context); + + /** + * Extracts property names from the provided set that may include security-sensitive + * values. The engine will use these names to mask the corresponding property values, + * preventing the leakage of sensitive information. + */ + default Set getRedactablePropertyNames(Set propertyNames) + { + return propertyNames; + } } From 8018b456a501fe61260feae4dd114c03da9120b2 Mon Sep 17 00:00:00 2001 From: Piotr Rzysko Date: Mon, 23 Dec 2024 10:21:04 +0100 Subject: [PATCH 2/5] Expose redactable properties for PostgreSQL connector Exposed properties fall into one of the following categories: they are either explicitly marked as security-sensitive or are unknown. The connector assumes that unknown properties might be misspelled security-sensitive properties. The purpose of the included test is to identify security-sensitive properties that may be used by the connector. It uses the output generated by the maven-dependency-plugin, configured in the connector's pom.xml file. This output contains the connector's runtime classpath, which is then scanned to identify all property names annotated with @Config. Scanning the classpath ensures that all configuration classes are included, even those used conditionally. --- .../base/config/ConfigPropertyMetadata.java | 39 ++++++ .../plugin/jdbc/JdbcConnectorFactory.java | 50 +++++++- .../java/io/trino/plugin/jdbc/JdbcPlugin.java | 13 +- plugin/trino-postgresql/pom.xml | 30 +++++ .../plugin/postgresql/PostgreSqlPlugin.java | 3 +- .../postgresql/TestPostgreSqlPlugin.java | 112 ++++++++++++++++++ 6 files changed, 244 insertions(+), 3 deletions(-) create mode 100644 lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/config/ConfigPropertyMetadata.java diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/config/ConfigPropertyMetadata.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/config/ConfigPropertyMetadata.java new file mode 100644 index 000000000000..6a2d5bdeafc2 --- /dev/null +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/config/ConfigPropertyMetadata.java @@ -0,0 +1,39 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.trino.plugin.base.config; + +import io.airlift.configuration.ConfigurationMetadata; + +import java.util.Set; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.airlift.configuration.ConfigurationMetadata.getConfigurationMetadata; +import static java.util.Objects.requireNonNull; + +public record ConfigPropertyMetadata(String name, boolean sensitive) +{ + public ConfigPropertyMetadata + { + requireNonNull(name, "name is null"); + } + + public static Set getConfigProperties(Class configClass) + { + ConfigurationMetadata configurationMetadata = getConfigurationMetadata(configClass); + return configurationMetadata.getAttributes().values().stream() + .map(attribute -> new ConfigPropertyMetadata(attribute.getInjectionPoint().getProperty(), attribute.isSecuritySensitive())) + .collect(toImmutableSet()); + } +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcConnectorFactory.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcConnectorFactory.java index f7adb6c17a2a..7c891406544e 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcConnectorFactory.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcConnectorFactory.java @@ -13,10 +13,21 @@ */ package io.trino.plugin.jdbc; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Sets; import com.google.inject.Injector; import com.google.inject.Module; import io.airlift.bootstrap.Bootstrap; +import io.airlift.bootstrap.BootstrapConfig; import io.opentelemetry.api.OpenTelemetry; +import io.trino.plugin.base.config.ConfigPropertyMetadata; +import io.trino.plugin.base.mapping.MappingConfig; +import io.trino.plugin.jdbc.credential.CredentialConfig; +import io.trino.plugin.jdbc.credential.CredentialProviderTypeConfig; +import io.trino.plugin.jdbc.credential.ExtraCredentialConfig; +import io.trino.plugin.jdbc.credential.file.ConfigFileBasedCredentialProviderConfig; +import io.trino.plugin.jdbc.credential.keystore.KeyStoreBasedCredentialProviderConfig; +import io.trino.plugin.jdbc.logging.FormatBasedRemoteQueryModifierConfig; import io.trino.spi.NodeManager; import io.trino.spi.VersionEmbedder; import io.trino.spi.catalog.CatalogName; @@ -26,11 +37,15 @@ import io.trino.spi.type.TypeManager; import java.util.Map; +import java.util.Set; import java.util.function.Supplier; +import java.util.stream.Stream; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Strings.isNullOrEmpty; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.plugin.base.Versions.checkStrictSpiVersionMatch; +import static io.trino.plugin.base.config.ConfigPropertyMetadata.getConfigProperties; import static java.util.Objects.requireNonNull; public class JdbcConnectorFactory @@ -38,12 +53,39 @@ public class JdbcConnectorFactory { private final String name; private final Supplier module; + private final Set nonRedactablePropertyNames; - public JdbcConnectorFactory(String name, Supplier module) + public JdbcConnectorFactory(String name, Supplier module, Set additionalProperties) { checkArgument(!isNullOrEmpty(name), "name is null or empty"); this.name = name; this.module = requireNonNull(module, "module is null"); + Set> configClasses = ImmutableSet.>builder() + .add(BaseJdbcConfig.class) + .add(CredentialConfig.class) + .add(JdbcStatisticsConfig.class) + .add(JdbcWriteConfig.class) + .add(QueryConfig.class) + .add(RemoteQueryCancellationConfig.class) + .add(TypeHandlingJdbcConfig.class) + .add(JdbcMetadataConfig.class) + .add(JdbcJoinPushdownConfig.class) + .add(DecimalConfig.class) + .add(JdbcDynamicFilteringConfig.class) + .add(KeyStoreBasedCredentialProviderConfig.class) + .add(FormatBasedRemoteQueryModifierConfig.class) + .add(ConfigFileBasedCredentialProviderConfig.class) + .add(CredentialProviderTypeConfig.class) + .add(ExtraCredentialConfig.class) + .add(MappingConfig.class) + .add(BootstrapConfig.class) + .build(); + this.nonRedactablePropertyNames = Stream.concat( + configClasses.stream().flatMap(clazz -> getConfigProperties(clazz).stream()), + requireNonNull(additionalProperties, "additionalProperties is null").stream()) + .filter(property -> !property.sensitive()) + .map(ConfigPropertyMetadata::name) + .collect(toImmutableSet()); } @Override @@ -74,4 +116,10 @@ public Connector create(String catalogName, Map requiredConfig, return injector.getInstance(JdbcConnector.class); } + + @Override + public Set getRedactablePropertyNames(Set propertyNames) + { + return Sets.difference(propertyNames, nonRedactablePropertyNames); + } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcPlugin.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcPlugin.java index 5fe4aa80b012..1cad7213cac7 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcPlugin.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcPlugin.java @@ -14,11 +14,14 @@ package io.trino.plugin.jdbc; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import com.google.inject.Module; +import io.trino.plugin.base.config.ConfigPropertyMetadata; import io.trino.plugin.jdbc.credential.CredentialProviderModule; import io.trino.spi.Plugin; import io.trino.spi.connector.ConnectorFactory; +import java.util.Set; import java.util.function.Supplier; import static com.google.common.base.Preconditions.checkArgument; @@ -31,12 +34,19 @@ public class JdbcPlugin { private final String name; private final Supplier module; + private final Set additionalProperties; public JdbcPlugin(String name, Supplier module) + { + this(name, module, ImmutableSet.of()); + } + + public JdbcPlugin(String name, Supplier module, Set additionalProperties) { checkArgument(!isNullOrEmpty(name), "name is null or empty"); this.name = name; this.module = requireNonNull(module, "module is null"); + this.additionalProperties = ImmutableSet.copyOf(requireNonNull(additionalProperties, "additionalProperties is null")); } @Override @@ -47,6 +57,7 @@ public Iterable getConnectorFactories() () -> combine( new CredentialProviderModule(), new ExtraCredentialsBasedIdentityCacheMappingModule(), - module.get()))); + module.get()), + additionalProperties)); } } diff --git a/plugin/trino-postgresql/pom.xml b/plugin/trino-postgresql/pom.xml index d7105256f301..036da48aa73a 100644 --- a/plugin/trino-postgresql/pom.xml +++ b/plugin/trino-postgresql/pom.xml @@ -158,6 +158,13 @@ test + + io.github.classgraph + classgraph + 4.8.174 + test + + io.trino trino-base-jdbc @@ -276,4 +283,27 @@ test + + + + + org.apache.maven.plugins + maven-dependency-plugin + + + print-runtime-dependencies + + build-classpath + + generate-sources + + false + runtime + ${project.build.testOutputDirectory}/runtime-dependencies.txt + + + + + + diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlPlugin.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlPlugin.java index 4c860eaf3593..a497f7441e08 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlPlugin.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlPlugin.java @@ -16,12 +16,13 @@ import io.trino.plugin.jdbc.JdbcPlugin; import static io.airlift.configuration.ConfigurationAwareModule.combine; +import static io.trino.plugin.base.config.ConfigPropertyMetadata.getConfigProperties; public class PostgreSqlPlugin extends JdbcPlugin { public PostgreSqlPlugin() { - super("postgresql", () -> combine(new PostgreSqlClientModule(), new PostgreSqlConnectionFactoryModule())); + super("postgresql", () -> combine(new PostgreSqlClientModule(), new PostgreSqlConnectionFactoryModule()), getConfigProperties(PostgreSqlConfig.class)); } } diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlPlugin.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlPlugin.java index 6e288e7501ca..04931d3a82f7 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlPlugin.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlPlugin.java @@ -14,11 +14,30 @@ package io.trino.plugin.postgresql; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.airlift.configuration.Config; +import io.airlift.configuration.ConfigSecuritySensitive; +import io.github.classgraph.AnnotationInfo; +import io.github.classgraph.AnnotationParameterValueList; +import io.github.classgraph.ClassGraph; +import io.github.classgraph.ScanResult; +import io.trino.plugin.base.config.ConfigPropertyMetadata; import io.trino.spi.Plugin; import io.trino.spi.connector.ConnectorFactory; import org.junit.jupiter.api.Test; +import java.io.File; +import java.io.IOException; +import java.net.URISyntaxException; +import java.net.URL; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Set; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.getOnlyElement; +import static org.assertj.core.api.Assertions.assertThat; public class TestPostgreSqlPlugin { @@ -34,4 +53,97 @@ public void testCreateConnector() "bootstrap.quiet", "true"), new TestingPostgreSqlConnectorContext()).shutdown(); } + + @Test + void testUnknownPropertiesAreRedactable() + { + Plugin plugin = new PostgreSqlPlugin(); + ConnectorFactory connectorFactory = getOnlyElement(plugin.getConnectorFactories()); + Set unknownProperties = ImmutableSet.of("unknown"); + + Set redactableProperties = connectorFactory.getRedactablePropertyNames(unknownProperties); + + assertThat(redactableProperties).isEqualTo(unknownProperties); + } + + @Test + void testSecuritySensitivePropertiesAreRedactable() + throws Exception + { + // The purpose of this test is to help identify security-sensitive properties that + // may be used by the connector. These properties are detected by scanning the + // plugin's runtime classpath and collecting all property names annotated with + // @ConfigSecuritySensitive. The scan includes all configuration classes, whether + // they are always used, conditionally used, or never used. This approach has both + // advantages and disadvantages. + // + // One advantage is that we don't need to rely on the plugin's configuration to + // retrieve properties that are used conditionally. However, this method may also + // capture properties that are not used at all but are pulled into the classpath + // by dependencies. With that in mind, if this test fails, it likely indicates that + // either a property needs to be added to the connector's security-sensitive + // property names list, or it should be added to the excluded properties list below. + Set excludedClasses = ImmutableSet.of( + "io.trino.plugin.base.ldap.LdapClientConfig", + "io.airlift.http.client.HttpClientConfig", + "io.airlift.node.NodeConfig", + "io.airlift.log.LoggingConfiguration", + "io.trino.plugin.base.security.FileBasedAccessControlConfig", + "io.airlift.configuration.secrets.SecretsPluginConfig", + "io.trino.plugin.base.jmx.ObjectNameGeneratorConfig"); + Plugin plugin = new PostgreSqlPlugin(); + ConnectorFactory connectorFactory = getOnlyElement(plugin.getConnectorFactories()); + + Set propertiesFoundInClasspath = findPropertiesInRuntimeClasspath(excludedClasses); + Set allPropertyNames = propertiesFoundInClasspath.stream() + .map(ConfigPropertyMetadata::name) + .collect(toImmutableSet()); + Set expectedProperties = propertiesFoundInClasspath.stream() + .filter(ConfigPropertyMetadata::sensitive) + .map(ConfigPropertyMetadata::name) + .collect(toImmutableSet()); + + Set redactableProperties = connectorFactory.getRedactablePropertyNames(allPropertyNames); + + assertThat(redactableProperties).isEqualTo(expectedProperties); + } + + private static Set findPropertiesInRuntimeClasspath(Set excludedClassNames) + throws URISyntaxException, IOException + { + try (ScanResult scanResult = new ClassGraph() + .overrideClasspath(buildRuntimeClasspath()) + .enableAllInfo() + .scan()) { + return scanResult.getClassesWithMethodAnnotation(Config.class).stream() + .filter(classInfo -> !excludedClassNames.contains(classInfo.getName())) + .flatMap(classInfo -> classInfo.getMethodInfo().stream()) + .filter(methodInfo -> methodInfo.hasAnnotation(Config.class)) + .map(methodInfo -> { + boolean sensitive = methodInfo.hasAnnotation(ConfigSecuritySensitive.class); + AnnotationInfo annotationInfo = methodInfo.getAnnotationInfo(Config.class); + checkState(annotationInfo != null, "Missing @Config annotation for %s", methodInfo); + AnnotationParameterValueList parameterValues = annotationInfo.getParameterValues(); + checkState(parameterValues.size() == 1, "Expected exactly one parameter for %s", annotationInfo); + String propertyName = (String) parameterValues.getFirst().getValue(); + return new ConfigPropertyMetadata(propertyName, sensitive); + }) + .collect(toImmutableSet()); + } + } + + private static String buildRuntimeClasspath() + throws URISyntaxException, IOException + { + // This file is generated by the maven-dependency-plugin, which is configured in the connector's pom.xml file. + String runtimeDependenciesFile = "runtime-dependencies.txt"; + URL runtimeDependenciesUrl = TestPostgreSqlPlugin.class.getClassLoader().getResource(runtimeDependenciesFile); + checkState(runtimeDependenciesUrl != null, "Missing %s file", runtimeDependenciesUrl); + String runtimeDependenciesClasspath = Files.readString(Path.of(runtimeDependenciesUrl.toURI())); + + File classDirectory = new File(new File(runtimeDependenciesUrl.toURI()).getParentFile().getParentFile(), "classes/"); + checkState(classDirectory.exists(), "Missing %s directory", classDirectory.getAbsolutePath()); + + return "%s:%s".formatted(runtimeDependenciesClasspath, classDirectory.getAbsolutePath()); + } } From 8b7148de98a87fa6a165de9995eb55e4ae70eefb Mon Sep 17 00:00:00 2001 From: Piotr Rzysko Date: Mon, 23 Dec 2024 11:08:18 +0100 Subject: [PATCH 3/5] Introduce redaction for sensitive statements This commit introduces redacting of security-sensitive information in statements containing connector properties, specifically: * CREATE CATALOG * EXPLAIN CREATE CATALOG * EXPLAIN ANALYZE CREATE CATALOG The current approach is as follows: * For syntactically valid statements, only properties containing sensitive information are masked. * If a valid query references a nonexistent connector, all properties are masked. * If a query fails before or during parsing, the entire query is masked The redacted form is created in DispatchManager and is propagated to all places that create QueryInfo and BasicQueryInfo. Before this change, QueryInfo/BasicQueryInfo stored the raw query text received from the end user. From now on, the text will be altered for the cases listed above. --- .../main/java/io/trino/FeaturesConfig.java | 14 ++ .../io/trino/connector/CatalogFactory.java | 4 + .../connector/DefaultCatalogFactory.java | 13 ++ .../trino/connector/LazyCatalogFactory.java | 7 + .../io/trino/dispatcher/DispatchManager.java | 20 +- .../io/trino/server/CoordinatorModule.java | 4 + .../trino/sql/SensitiveStatementRedactor.java | 131 ++++++++++++ ...estSqlTaskManagerRaceWithCatalogPrune.java | 6 + .../sql/analyzer/TestFeaturesConfig.java | 7 +- .../java/io/trino/sql/tree/CreateCatalog.java | 12 ++ .../execution/TestEventListenerBasic.java | 16 +- .../TestRedactSensitiveStatements.java | 201 ++++++++++++++++++ 12 files changed, 428 insertions(+), 7 deletions(-) create mode 100644 core/trino-main/src/main/java/io/trino/sql/SensitiveStatementRedactor.java create mode 100644 testing/trino-tests/src/test/java/io/trino/execution/TestRedactSensitiveStatements.java diff --git a/core/trino-main/src/main/java/io/trino/FeaturesConfig.java b/core/trino-main/src/main/java/io/trino/FeaturesConfig.java index 707b3300e8ca..449f0971307a 100644 --- a/core/trino-main/src/main/java/io/trino/FeaturesConfig.java +++ b/core/trino-main/src/main/java/io/trino/FeaturesConfig.java @@ -122,6 +122,8 @@ public class FeaturesConfig private boolean faultTolerantExecutionExchangeEncryptionEnabled = true; + private boolean statementRedactingEnabled = true; + public enum DataIntegrityVerification { NONE, @@ -514,6 +516,18 @@ public FeaturesConfig setFaultTolerantExecutionExchangeEncryptionEnabled(boolean return this; } + public boolean isStatementRedactingEnabled() + { + return statementRedactingEnabled; + } + + @Config("statement-redacting-enabled") + public FeaturesConfig setStatementRedactingEnabled(boolean statementRedactingEnabled) + { + this.statementRedactingEnabled = statementRedactingEnabled; + return this; + } + public void applyFaultTolerantExecutionDefaults() { exchangeCompressionCodec = LZ4; diff --git a/core/trino-main/src/main/java/io/trino/connector/CatalogFactory.java b/core/trino-main/src/main/java/io/trino/connector/CatalogFactory.java index e91151f5d8c5..692df69a44f8 100644 --- a/core/trino-main/src/main/java/io/trino/connector/CatalogFactory.java +++ b/core/trino-main/src/main/java/io/trino/connector/CatalogFactory.java @@ -20,6 +20,8 @@ import io.trino.spi.connector.ConnectorFactory; import io.trino.spi.connector.ConnectorName; +import java.util.Set; + @ThreadSafe public interface CatalogFactory { @@ -28,4 +30,6 @@ public interface CatalogFactory CatalogConnector createCatalog(CatalogProperties catalogProperties); CatalogConnector createCatalog(CatalogHandle catalogHandle, ConnectorName connectorName, Connector connector); + + Set getRedactablePropertyNames(ConnectorName connectorName, Set propertyNames); } diff --git a/core/trino-main/src/main/java/io/trino/connector/DefaultCatalogFactory.java b/core/trino-main/src/main/java/io/trino/connector/DefaultCatalogFactory.java index b9759c83fced..50b6bebe638f 100644 --- a/core/trino-main/src/main/java/io/trino/connector/DefaultCatalogFactory.java +++ b/core/trino-main/src/main/java/io/trino/connector/DefaultCatalogFactory.java @@ -45,6 +45,7 @@ import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -144,6 +145,18 @@ public CatalogConnector createCatalog(CatalogHandle catalogHandle, ConnectorName return createCatalog(catalogHandle, connectorName, connector, Optional.empty()); } + @Override + public Set getRedactablePropertyNames(ConnectorName connectorName, Set propertyNames) + { + ConnectorFactory connectorFactory = connectorFactories.get(connectorName); + if (connectorFactory == null) { + // If someone tries to use a non-existent connector, we assume they + // misspelled the name and, for safety, we redact all the properties. + return propertyNames; + } + return connectorFactory.getRedactablePropertyNames(propertyNames); + } + private CatalogConnector createCatalog(CatalogHandle catalogHandle, ConnectorName connectorName, Connector connector, Optional catalogProperties) { Tracer tracer = createTracer(catalogHandle); diff --git a/core/trino-main/src/main/java/io/trino/connector/LazyCatalogFactory.java b/core/trino-main/src/main/java/io/trino/connector/LazyCatalogFactory.java index 09513d9669f9..9356e6a98348 100644 --- a/core/trino-main/src/main/java/io/trino/connector/LazyCatalogFactory.java +++ b/core/trino-main/src/main/java/io/trino/connector/LazyCatalogFactory.java @@ -19,6 +19,7 @@ import io.trino.spi.connector.ConnectorFactory; import io.trino.spi.connector.ConnectorName; +import java.util.Set; import java.util.concurrent.atomic.AtomicReference; import static com.google.common.base.Preconditions.checkState; @@ -51,6 +52,12 @@ public CatalogConnector createCatalog(CatalogHandle catalogHandle, ConnectorName return getDelegate().createCatalog(catalogHandle, connectorName, connector); } + @Override + public Set getRedactablePropertyNames(ConnectorName connectorName, Set propertyNames) + { + return getDelegate().getRedactablePropertyNames(connectorName, propertyNames); + } + private CatalogFactory getDelegate() { CatalogFactory catalogFactory = delegate.get(); diff --git a/core/trino-main/src/main/java/io/trino/dispatcher/DispatchManager.java b/core/trino-main/src/main/java/io/trino/dispatcher/DispatchManager.java index f49114735147..cf20c560b2e7 100644 --- a/core/trino-main/src/main/java/io/trino/dispatcher/DispatchManager.java +++ b/core/trino-main/src/main/java/io/trino/dispatcher/DispatchManager.java @@ -44,6 +44,7 @@ import io.trino.spi.TrinoException; import io.trino.spi.resourcegroups.SelectionContext; import io.trino.spi.resourcegroups.SelectionCriteria; +import io.trino.sql.SensitiveStatementRedactor; import jakarta.annotation.PostConstruct; import jakarta.annotation.PreDestroy; import org.weakref.jmx.Flatten; @@ -84,6 +85,7 @@ public class DispatchManager private final SessionPropertyDefaults sessionPropertyDefaults; private final SessionPropertyManager sessionPropertyManager; private final Tracer tracer; + private final SensitiveStatementRedactor sensitiveStatementRedactor; private final int maxQueryLength; @@ -107,6 +109,7 @@ public DispatchManager( SessionPropertyDefaults sessionPropertyDefaults, SessionPropertyManager sessionPropertyManager, Tracer tracer, + SensitiveStatementRedactor sensitiveStatementRedactor, QueryManagerConfig queryManagerConfig, DispatchExecutor dispatchExecutor, QueryMonitor queryMonitor) @@ -121,6 +124,7 @@ public DispatchManager( this.sessionPropertyDefaults = requireNonNull(sessionPropertyDefaults, "sessionPropertyDefaults is null"); this.sessionPropertyManager = sessionPropertyManager; this.tracer = requireNonNull(tracer, "tracer is null"); + this.sensitiveStatementRedactor = requireNonNull(sensitiveStatementRedactor, "sensitiveStatementRedactor is null"); this.maxQueryLength = queryManagerConfig.getMaxQueryLength(); @@ -207,6 +211,7 @@ private void createQueryInternal(QueryId queryId, Span querySpan, Slug slug, { Session session = null; PreparedQuery preparedQuery = null; + String redactedQuery = null; try { if (query.length() > maxQueryLength) { int queryLength = query.length(); @@ -223,6 +228,9 @@ private void createQueryInternal(QueryId queryId, Span querySpan, Slug slug, // prepare query preparedQuery = queryPreparer.prepareQuery(session, query); + // redact security-sensitive information that query may contain + redactedQuery = sensitiveStatementRedactor.redact(query, preparedQuery.getStatement()); + // select resource group Optional queryType = getQueryType(preparedQuery.getStatement()).map(Enum::name); SelectionContext selectionContext = resourceGroupManager.selectGroup(new SelectionCriteria( @@ -240,7 +248,7 @@ private void createQueryInternal(QueryId queryId, Span querySpan, Slug slug, DispatchQuery dispatchQuery = dispatchQueryFactory.createDispatchQuery( session, sessionContext.getTransactionId(), - query, + redactedQuery, preparedQuery, slug, selectionContext.getResourceGroupId()); @@ -266,8 +274,16 @@ private void createQueryInternal(QueryId queryId, Span querySpan, Slug slug, .setSource(sessionContext.getSource().orElse(null)) .build(); } + if (redactedQuery == null) { + redactedQuery = sensitiveStatementRedactor.redact(query); + } Optional preparedSql = Optional.ofNullable(preparedQuery).flatMap(PreparedQuery::getPrepareSql); - DispatchQuery failedDispatchQuery = failedDispatchQueryFactory.createFailedDispatchQuery(session, query, preparedSql, Optional.empty(), throwable); + DispatchQuery failedDispatchQuery = failedDispatchQueryFactory.createFailedDispatchQuery( + session, + redactedQuery, + preparedSql, + Optional.empty(), + throwable); queryCreated(failedDispatchQuery); // maintain proper order of calls such that EventListener has access to QueryInfo // - add query to tracker diff --git a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java index 2c1afc074bd4..6759385904ec 100644 --- a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java +++ b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java @@ -113,6 +113,7 @@ import io.trino.server.ui.WorkerResource; import io.trino.spi.VersionEmbedder; import io.trino.sql.PlannerContext; +import io.trino.sql.SensitiveStatementRedactor; import io.trino.sql.analyzer.AnalyzerFactory; import io.trino.sql.analyzer.QueryExplainerFactory; import io.trino.sql.planner.OptimizerStatsMBeanExporter; @@ -304,6 +305,9 @@ List getCompositeOutputDataSizeEstimatorDelegateFac rewriteBinder.addBinding().to(ShowStatsRewrite.class).in(Scopes.SINGLETON); rewriteBinder.addBinding().to(ExplainRewrite.class).in(Scopes.SINGLETON); + // security-sensitive statement redactor + binder.bind(SensitiveStatementRedactor.class).in(Scopes.SINGLETON); + // planner binder.bind(PlanFragmenter.class).in(Scopes.SINGLETON); binder.bind(PlanOptimizersFactory.class).to(PlanOptimizers.class).in(Scopes.SINGLETON); diff --git a/core/trino-main/src/main/java/io/trino/sql/SensitiveStatementRedactor.java b/core/trino-main/src/main/java/io/trino/sql/SensitiveStatementRedactor.java new file mode 100644 index 000000000000..3498374caa90 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/SensitiveStatementRedactor.java @@ -0,0 +1,131 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql; + +import com.google.inject.Inject; +import io.trino.FeaturesConfig; +import io.trino.connector.CatalogFactory; +import io.trino.spi.connector.ConnectorName; +import io.trino.sql.tree.AstVisitor; +import io.trino.sql.tree.CreateCatalog; +import io.trino.sql.tree.Explain; +import io.trino.sql.tree.ExplainAnalyze; +import io.trino.sql.tree.Identifier; +import io.trino.sql.tree.Node; +import io.trino.sql.tree.Property; +import io.trino.sql.tree.Statement; +import io.trino.sql.tree.StringLiteral; + +import java.util.List; +import java.util.Set; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; + +public class SensitiveStatementRedactor +{ + public static final String REDACTED_VALUE = "***"; + + private final boolean enabled; + private final CatalogFactory catalogFactory; + + @Inject + public SensitiveStatementRedactor(FeaturesConfig config, CatalogFactory catalogFactory) + { + this.enabled = config.isStatementRedactingEnabled(); + this.catalogFactory = catalogFactory; + } + + public String redact(String rawQuery, Statement statement) + { + if (enabled) { + RedactingVisitor visitor = new RedactingVisitor(); + Node redactedStatement = visitor.process(statement); + if (visitor.isRedacted()) { + return SqlFormatter.formatSql(redactedStatement); + } + } + return rawQuery; + } + + public String redact(String rawQuery) + { + if (enabled) { + return REDACTED_VALUE; + } + return rawQuery; + } + + private class RedactingVisitor + extends AstVisitor + { + private boolean redacted; + + public boolean isRedacted() + { + return redacted; + } + + @Override + protected Node visitNode(Node node, Void context) + { + return node; + } + + @Override + protected Node visitExplain(Explain explain, Void context) + { + Statement statement = (Statement) process(explain.getStatement()); + return new Explain(explain.getLocation().orElseThrow(), statement, explain.getOptions()); + } + + @Override + protected Node visitExplainAnalyze(ExplainAnalyze explainAnalyze, Void context) + { + Statement statement = (Statement) process(explainAnalyze.getStatement()); + return new ExplainAnalyze(explainAnalyze.getLocation().orElseThrow(), statement, explainAnalyze.isVerbose()); + } + + @Override + protected Node visitCreateCatalog(CreateCatalog createCatalog, Void context) + { + ConnectorName connectorName = new ConnectorName(createCatalog.getConnectorName().getValue()); + List redactedProperties = redact(connectorName, createCatalog.getProperties()); + return createCatalog.withProperties(redactedProperties); + } + + private List redact(ConnectorName connectorName, List properties) + { + redacted = true; + Set propertyNames = properties.stream() + .map(Property::getName) + .map(Identifier::getValue) + .collect(toImmutableSet()); + Set redactableProperties = catalogFactory.getRedactablePropertyNames(connectorName, propertyNames); + return redactProperties(properties, redactableProperties); + } + + private List redactProperties(List properties, Set redactableProperties) + { + return properties.stream() + .map(property -> { + if (redactableProperties.contains(property.getName().getValue())) { + return new Property(property.getName(), new StringLiteral(REDACTED_VALUE)); + } + return property; + }) + .collect(toImmutableList()); + } + } +} diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManagerRaceWithCatalogPrune.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManagerRaceWithCatalogPrune.java index cc5681b4d435..aae92dcdedbe 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManagerRaceWithCatalogPrune.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManagerRaceWithCatalogPrune.java @@ -140,6 +140,12 @@ public CatalogConnector createCatalog(CatalogHandle catalogHandle, ConnectorName { throw new UnsupportedOperationException("Only implement what is needed by worker catalog manager"); } + + @Override + public Set getRedactablePropertyNames(ConnectorName connectorName, Set propertyNames) + { + throw new UnsupportedOperationException("Only implement what is needed by worker catalog manager"); + } }; private static final TaskExecutor NOOP_TASK_EXECUTOR = new TaskExecutor() { @Override diff --git a/core/trino-main/src/test/java/io/trino/sql/analyzer/TestFeaturesConfig.java b/core/trino-main/src/test/java/io/trino/sql/analyzer/TestFeaturesConfig.java index ca8b7134d1cb..ea362768c2bf 100644 --- a/core/trino-main/src/test/java/io/trino/sql/analyzer/TestFeaturesConfig.java +++ b/core/trino-main/src/test/java/io/trino/sql/analyzer/TestFeaturesConfig.java @@ -66,7 +66,8 @@ public void testDefaults() .setHideInaccessibleColumns(false) .setForceSpillingJoin(false) .setColumnarFilterEvaluationEnabled(true) - .setFaultTolerantExecutionExchangeEncryptionEnabled(true)); + .setFaultTolerantExecutionExchangeEncryptionEnabled(true) + .setStatementRedactingEnabled(true)); } @Test @@ -101,6 +102,7 @@ public void testExplicitPropertyMappings() .put("force-spilling-join-operator", "true") .put("experimental.columnar-filter-evaluation.enabled", "false") .put("fault-tolerant-execution-exchange-encryption-enabled", "false") + .put("statement-redacting-enabled", "false") .buildOrThrow(); FeaturesConfig expected = new FeaturesConfig() @@ -131,7 +133,8 @@ public void testExplicitPropertyMappings() .setHideInaccessibleColumns(true) .setForceSpillingJoin(true) .setColumnarFilterEvaluationEnabled(false) - .setFaultTolerantExecutionExchangeEncryptionEnabled(false); + .setFaultTolerantExecutionExchangeEncryptionEnabled(false) + .setStatementRedactingEnabled(false); assertFullMapping(properties, expected); } } diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/CreateCatalog.java b/core/trino-parser/src/main/java/io/trino/sql/tree/CreateCatalog.java index d443a07a831e..6b032a374537 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/CreateCatalog.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/CreateCatalog.java @@ -50,6 +50,18 @@ public CreateCatalog( this.comment = requireNonNull(comment, "comment is null"); } + public CreateCatalog withProperties(List properties) + { + return new CreateCatalog( + getLocation().orElseThrow(), + catalogName, + notExists, + connectorName, + properties, + principal, + comment); + } + public Identifier getCatalogName() { return catalogName; diff --git a/testing/trino-tests/src/test/java/io/trino/execution/TestEventListenerBasic.java b/testing/trino-tests/src/test/java/io/trino/execution/TestEventListenerBasic.java index 3088d39af799..d539831de660 100644 --- a/testing/trino-tests/src/test/java/io/trino/execution/TestEventListenerBasic.java +++ b/testing/trino-tests/src/test/java/io/trino/execution/TestEventListenerBasic.java @@ -316,7 +316,11 @@ public void testAnalysisFailure() public void testParseError() throws Exception { - assertFailedQuery("You shall not parse!", "line 1:1: mismatched input 'You'. Expecting: 'ALTER', 'ANALYZE', 'CALL', 'COMMENT', 'COMMIT', 'CREATE', 'DEALLOCATE', 'DELETE', 'DENY', 'DESC', 'DESCRIBE', 'DROP', 'EXECUTE', 'EXPLAIN', 'GRANT', 'INSERT', 'MERGE', 'PREPARE', 'REFRESH', 'RESET', 'REVOKE', 'ROLLBACK', 'SET', 'SHOW', 'START', 'TRUNCATE', 'UPDATE', 'USE', 'WITH', "); + assertFailedQuery( + getSession(), + "You shall not parse!", + "line 1:1: mismatched input 'You'. Expecting: 'ALTER', 'ANALYZE', 'CALL', 'COMMENT', 'COMMIT', 'CREATE', 'DEALLOCATE', 'DELETE', 'DENY', 'DESC', 'DESCRIBE', 'DROP', 'EXECUTE', 'EXPLAIN', 'GRANT', 'INSERT', 'MERGE', 'PREPARE', 'REFRESH', 'RESET', 'REVOKE', 'ROLLBACK', 'SET', 'SHOW', 'START', 'TRUNCATE', 'UPDATE', 'USE', 'WITH', ", + "***"); } @Test @@ -395,16 +399,22 @@ private Optional findQueryId(String queryPattern) private void assertFailedQuery(@Language("SQL") String sql, String expectedFailure) throws Exception { - assertFailedQuery(getSession(), sql, expectedFailure); + assertFailedQuery(getSession(), sql, expectedFailure, sql); } private void assertFailedQuery(Session session, @Language("SQL") String sql, String expectedFailure) throws Exception + { + assertFailedQuery(session, sql, expectedFailure, sql); + } + + private void assertFailedQuery(Session session, @Language("SQL") String sql, String expectedFailure, String redactedQuery) + throws Exception { QueryEvents queryEvents = queries.runQueryAndWaitForEvents(sql, session, Optional.of(expectedFailure)).getQueryEvents(); QueryCompletedEvent queryCompletedEvent = queryEvents.getQueryCompletedEvent(); - assertThat(queryCompletedEvent.getMetadata().getQuery()).isEqualTo(sql); + assertThat(queryCompletedEvent.getMetadata().getQuery()).isEqualTo(redactedQuery); QueryFailureInfo failureInfo = queryCompletedEvent.getFailureInfo() .orElseThrow(() -> new AssertionError("Expected query event to be failed")); diff --git a/testing/trino-tests/src/test/java/io/trino/execution/TestRedactSensitiveStatements.java b/testing/trino-tests/src/test/java/io/trino/execution/TestRedactSensitiveStatements.java new file mode 100644 index 000000000000..3384cdbf4048 --- /dev/null +++ b/testing/trino-tests/src/test/java/io/trino/execution/TestRedactSensitiveStatements.java @@ -0,0 +1,201 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution; + +import com.google.common.collect.ImmutableMap; +import io.trino.Session; +import io.trino.execution.EventsCollector.QueryEvents; +import io.trino.execution.TestEventListenerPlugin.TestingEventListenerPlugin; +import io.trino.plugin.jdbc.JdbcPlugin; +import io.trino.plugin.jdbc.TestingH2JdbcModule; +import io.trino.plugin.tpch.TpchPlugin; +import io.trino.spi.eventlistener.QueryCompletedEvent; +import io.trino.spi.eventlistener.QueryCreatedEvent; +import io.trino.testing.AbstractTestQueryFramework; +import io.trino.testing.DistributedQueryRunner; +import io.trino.testing.QueryRunner; +import org.intellij.lang.annotations.Language; +import org.junit.jupiter.api.Test; + +import java.util.Optional; + +import static io.trino.plugin.jdbc.TestingH2JdbcModule.createH2ConnectionUrl; +import static io.trino.plugin.tpch.TpchMetadata.TINY_SCHEMA_NAME; +import static io.trino.testing.TestingNames.randomNameSuffix; +import static io.trino.testing.TestingSession.testSessionBuilder; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestRedactSensitiveStatements + extends AbstractTestQueryFramework +{ + private final EventsCollector generatedEvents = new EventsCollector(); + private EventsAwaitingQueries queries; + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + Session session = testSessionBuilder() + .setCatalog("tpch") + .setSchema(TINY_SCHEMA_NAME) + .build(); + + QueryRunner queryRunner = DistributedQueryRunner.builder(session) + .setWorkerCount(0) + .build(); + queryRunner.installPlugin(new TestingEventListenerPlugin(generatedEvents)); + queryRunner.installPlugin(new JdbcPlugin("jdbc", TestingH2JdbcModule::new)); + queryRunner.installPlugin(new TpchPlugin()); + queryRunner.createCatalog("tpch", "tpch", ImmutableMap.of()); + + queries = new EventsAwaitingQueries(generatedEvents, queryRunner); + + return queryRunner; + } + + @Test + void testSyntacticallyInvalidStatementThatMayBeSensitive() + throws Exception + { + String catalog = "catalog_" + randomNameSuffix(); + QueryEvents queryEvents = runQueryAndWaitForEvents(""" + CREATE CCATALOG %s USING jdbc + WITH ( + "connection-user" = 'bob', + "connection-password" = '1234' + )""".formatted(catalog), ".*mismatched input 'CCATALOG'.*"); + + assertRedactedQuery("***", queryEvents); + } + + @Test + void testSyntacticallyInvalidStatementThatShouldNotBeSensitive() + throws Exception + { + QueryEvents queryEvents = runQueryAndWaitForEvents("SELECTT * FROM nation WHERE name = '\"password\" = ''1234'''", ".*mismatched input 'SELECTT'.*"); + + assertRedactedQuery("***", queryEvents); + } + + @Test + void testStatementThatShouldNotBeSensitive() + throws Exception + { + QueryEvents queryEvents = runQueryAndWaitForEvents("SELECT * FROM nation WHERE name = '\"password\" = ''1234'''"); + + assertRedactedQuery("SELECT * FROM nation WHERE name = '\"password\" = ''1234'''", queryEvents); + } + + @Test + void testUnsupportedStatement() + throws Exception + { + String catalog = "catalog_" + randomNameSuffix(); + QueryEvents queryEvents = runQueryAndWaitForEvents(""" + EXPLAIN ANALYZE CREATE CATALOG %s USING jdbc + WITH ( + "user" = 'bob', + "password" = '1234' + )""".formatted(catalog), "EXPLAIN ANALYZE doesn't support statement type: CreateCatalog"); + + assertRedactedQuery("***", queryEvents); + } + + @Test + void testCreateCatalog() + throws Exception + { + String catalog = "catalog_" + randomNameSuffix(); + String connectionUrl = createH2ConnectionUrl(); + QueryEvents queryEvents = runQueryAndWaitForEvents(""" + CREATE CATALOG %s USING jdbc + WITH ( + "connection-url" = '%s', + "connection-user" = 'bob', + "connection-password" = '1234' + )""".formatted(catalog, connectionUrl)); + + assertRedactedQuery(""" + CREATE CATALOG %s USING jdbc + WITH ( + "connection-url" = '%s', + "connection-user" = 'bob', + "connection-password" = '***' + )""".formatted(catalog, connectionUrl), queryEvents); + } + + @Test + void testCreateCatalogForNonExistentConnector() + throws Exception + { + String catalog = "catalog_" + randomNameSuffix(); + QueryEvents queryEvents = runQueryAndWaitForEvents(""" + CREATE CATALOG %s USING nonexistent_connector + WITH ( + "user" = 'bob', + "password" = '1234' + )""".formatted(catalog), "No factory for connector 'nonexistent_connector'.*"); + + assertRedactedQuery(""" + CREATE CATALOG %s USING nonexistent_connector + WITH ( + "user" = '***', + "password" = '***' + )""".formatted(catalog), queryEvents); + } + + @Test + void testExplainCreateCatalog() + throws Exception + { + String catalog = "catalog_" + randomNameSuffix(); + String connectionUrl = createH2ConnectionUrl(); + QueryEvents queryEvents = runQueryAndWaitForEvents(""" + EXPLAIN CREATE CATALOG %s USING jdbc + WITH ( + "connection-url" = '%s', + "connection-user" = 'bob', + "connection-password" = '1234' + )""".formatted(catalog, connectionUrl)); + + assertRedactedQuery(""" + EXPLAIN\s + CREATE CATALOG %s USING jdbc + WITH ( + "connection-url" = '%s', + "connection-user" = 'bob', + "connection-password" = '***' + )""".formatted(catalog, connectionUrl), queryEvents); + } + + private static void assertRedactedQuery(String expectedQuery, QueryEvents queryEvents) + { + QueryCreatedEvent queryCreatedEvent = queryEvents.getQueryCreatedEvent(); + assertThat(queryCreatedEvent.getMetadata().getQuery()).isEqualTo(expectedQuery); + QueryCompletedEvent queryCompletedEvent = queryEvents.getQueryCompletedEvent(); + assertThat(queryCompletedEvent.getMetadata().getQuery()).isEqualTo(expectedQuery); + } + + private QueryEvents runQueryAndWaitForEvents(@Language("SQL") String sql) + throws Exception + { + return queries.runQueryAndWaitForEvents(sql, getSession()).getQueryEvents(); + } + + private QueryEvents runQueryAndWaitForEvents(@Language("SQL") String sql, String expectedExceptionRegEx) + throws Exception + { + return queries.runQueryAndWaitForEvents(sql, getSession(), Optional.of(expectedExceptionRegEx)).getQueryEvents(); + } +} From 042afdf44fccf33d2c23452e7644898053418562 Mon Sep 17 00:00:00 2001 From: Piotr Rzysko Date: Mon, 23 Dec 2024 11:11:03 +0100 Subject: [PATCH 4/5] Ensure queries in system.runtime.queries are redacted --- .../trino/connector/MockConnectorFactory.java | 18 ++++++++++++++ .../runtime/TestSystemRuntimeConnector.java | 24 +++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/core/trino-main/src/test/java/io/trino/connector/MockConnectorFactory.java b/core/trino-main/src/test/java/io/trino/connector/MockConnectorFactory.java index 7a0c8e090360..0a29cfc707bb 100644 --- a/core/trino-main/src/test/java/io/trino/connector/MockConnectorFactory.java +++ b/core/trino-main/src/test/java/io/trino/connector/MockConnectorFactory.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Sets; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.AggregationApplicationResult; import io.trino.spi.connector.CatalogSchemaTableName; @@ -140,6 +141,7 @@ public class MockConnectorFactory private final Supplier>> columnProperties; private final Optional partitioningProvider; private final Function tableFunctionSplitsSources; + private final Set redactablePropertyNames; // access control private final ListRoleGrants roleGrants; @@ -196,6 +198,7 @@ private MockConnectorFactory( Supplier>> tableProperties, Supplier>> columnProperties, Optional partitioningProvider, + Set redactablePropertyNames, ListRoleGrants roleGrants, Optional accessControl, boolean allowMissingColumnsOnInsert, @@ -244,6 +247,7 @@ private MockConnectorFactory( this.tableProperties = requireNonNull(tableProperties, "tableProperties is null"); this.columnProperties = requireNonNull(columnProperties, "columnProperties is null"); this.partitioningProvider = requireNonNull(partitioningProvider, "partitioningProvider is null"); + this.redactablePropertyNames = requireNonNull(redactablePropertyNames, "redactablePropertyNames is null"); this.roleGrants = requireNonNull(roleGrants, "roleGrants is null"); this.accessControl = requireNonNull(accessControl, "accessControl is null"); this.data = requireNonNull(data, "data is null"); @@ -325,6 +329,12 @@ public Connector create(String catalogName, Map config, Connecto allowSplittingReadIntoMultipleSubQueries); } + @Override + public Set getRedactablePropertyNames(Set propertyNames) + { + return Sets.intersection(redactablePropertyNames, propertyNames); + } + public static MockConnectorFactory create() { return builder().build(); @@ -465,6 +475,7 @@ public static final class Builder private Supplier>> columnProperties = ImmutableList::of; private Optional partitioningProvider = Optional.empty(); private Function tableFunctionSplitsSources = handle -> null; + private Set redactablePropertyNames = ImmutableSet.of(); // access control private boolean provideAccessControl; @@ -844,6 +855,12 @@ public Builder withAllowSplittingReadIntoMultipleSubQueries(boolean allowSplitti return this; } + public Builder withRedactablePropertyNames(Set redactablePropertyNames) + { + this.redactablePropertyNames = redactablePropertyNames; + return this; + } + public MockConnectorFactory build() { Optional accessControl = Optional.empty(); @@ -895,6 +912,7 @@ public MockConnectorFactory build() tableProperties, columnProperties, partitioningProvider, + redactablePropertyNames, roleGrants, accessControl, allowMissingColumnsOnInsert, diff --git a/testing/trino-tests/src/test/java/io/trino/connector/system/runtime/TestSystemRuntimeConnector.java b/testing/trino-tests/src/test/java/io/trino/connector/system/runtime/TestSystemRuntimeConnector.java index e5a386dbacd3..26046e059a47 100644 --- a/testing/trino-tests/src/test/java/io/trino/connector/system/runtime/TestSystemRuntimeConnector.java +++ b/testing/trino-tests/src/test/java/io/trino/connector/system/runtime/TestSystemRuntimeConnector.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.util.concurrent.SettableFuture; import io.airlift.units.Duration; @@ -48,6 +49,7 @@ import static io.airlift.concurrent.Threads.threadsNamed; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.testing.TestingNames.randomNameSuffix; import static io.trino.testing.TestingSession.testSessionBuilder; import static io.trino.testing.assertions.Assert.assertEventually; import static java.lang.String.format; @@ -89,6 +91,7 @@ public Iterable getConnectorFactories() .withGetViews((session, schemaTablePrefix) -> ImmutableMap.of()) .withListTables((session, s) -> ImmutableList.of("test_table")) .withGetColumns(tableName -> getColumns.apply(tableName)) + .withRedactablePropertyNames(ImmutableSet.of("password")) .build(); return ImmutableList.of(connectorFactory); } @@ -299,6 +302,27 @@ public void testTasksTable() getQueryRunner().execute("SELECT * FROM system.runtime.tasks"); } + @Test + public void testRedactedRuntimeQueries() + { + String catalog = "catalog_" + randomNameSuffix(); + getQueryRunner().execute(""" + CREATE CATALOG %s USING mock + WITH ( + "user" = 'bob', + "password" = '1234' + )""".formatted(catalog)); + + assertQuery( + format("SELECT query FROM system.runtime.queries WHERE query LIKE '%%%s%%' AND query NOT LIKE '%%system.runtime.queries%%'", catalog), + """ + VALUES 'CREATE CATALOG %s USING mock + WITH ( + "user" = ''bob'', + "password" = ''***'' + )'""".formatted(catalog)); + } + private static void run(int repetitions, double successRate, Runnable test) { AssertionError lastError = null; From ed595a1cab681ce44f03cfbe1015e45cc87032b3 Mon Sep 17 00:00:00 2001 From: Piotr Rzysko Date: Mon, 23 Dec 2024 11:11:28 +0100 Subject: [PATCH 5/5] Ensure queries returned via REST API are redacted @JsonConstructor for TrimmedBasicQueryInfo was introduced to facilitate the deserialization of server responses in tests. --- .../server/ui/TrimmedBasicQueryInfo.java | 40 +++++ .../io/trino/server/TestQueryResource.java | 48 +++++ .../server/TestQueryStateInfoResource.java | 58 +++++- .../TestResourceGroupStateInfoResource.java | 133 ++++++++++++++ .../trino/server/ui/TestUiQueryResource.java | 166 ++++++++++++++++++ 5 files changed, 441 insertions(+), 4 deletions(-) create mode 100644 core/trino-main/src/test/java/io/trino/server/TestResourceGroupStateInfoResource.java create mode 100644 core/trino-main/src/test/java/io/trino/server/ui/TestUiQueryResource.java diff --git a/core/trino-main/src/main/java/io/trino/server/ui/TrimmedBasicQueryInfo.java b/core/trino-main/src/main/java/io/trino/server/ui/TrimmedBasicQueryInfo.java index ab2cff5e9300..2982740949e5 100644 --- a/core/trino-main/src/main/java/io/trino/server/ui/TrimmedBasicQueryInfo.java +++ b/core/trino-main/src/main/java/io/trino/server/ui/TrimmedBasicQueryInfo.java @@ -13,6 +13,7 @@ */ package io.trino.server.ui; +import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.errorprone.annotations.Immutable; import io.trino.execution.QueryState; @@ -54,6 +55,45 @@ public class TrimmedBasicQueryInfo private final Optional queryType; private final RetryPolicy retryPolicy; + @JsonCreator + public TrimmedBasicQueryInfo( + @JsonProperty("queryId") QueryId queryId, + @JsonProperty("sessionUser") String sessionUser, + @JsonProperty("sessionPrincipal") Optional sessionPrincipal, + @JsonProperty("sessionSource") Optional sessionSource, + @JsonProperty("resourceGroupId") Optional resourceGroupId, + @JsonProperty("queryDataEncoding") Optional queryDataEncoding, + @JsonProperty("state") QueryState state, + @JsonProperty("scheduled") boolean scheduled, + @JsonProperty("self") URI self, + @JsonProperty("queryTextPreview") String queryTextPreview, + @JsonProperty("updateType") Optional updateType, + @JsonProperty("preparedQuery") Optional preparedQuery, + @JsonProperty("queryStats") BasicQueryStats queryStats, + @JsonProperty("errorType") Optional errorType, + @JsonProperty("errorCode") Optional errorCode, + @JsonProperty("queryType") Optional queryType, + @JsonProperty("retryPolicy") RetryPolicy retryPolicy) + { + this.queryId = requireNonNull(queryId, "queryId is null"); + this.sessionUser = requireNonNull(sessionUser, "sessionUser is null"); + this.sessionPrincipal = requireNonNull(sessionPrincipal, "sessionPrincipal is null"); + this.sessionSource = requireNonNull(sessionSource, "sessionSource is null"); + this.resourceGroupId = requireNonNull(resourceGroupId, "resourceGroupId is null"); + this.queryDataEncoding = requireNonNull(queryDataEncoding, "queryDataEncoding is null"); + this.state = requireNonNull(state, "state is null"); + this.scheduled = scheduled; + this.self = requireNonNull(self, "self is null"); + this.queryTextPreview = requireNonNull(queryTextPreview, "queryTextPreview is null"); + this.updateType = requireNonNull(updateType, "updateType is null"); + this.preparedQuery = requireNonNull(preparedQuery, "preparedQuery is null"); + this.queryStats = requireNonNull(queryStats, "queryStats is null"); + this.errorType = requireNonNull(errorType, "errorType is null"); + this.errorCode = requireNonNull(errorCode, "errorCode is null"); + this.queryType = requireNonNull(queryType, "queryType is null"); + this.retryPolicy = requireNonNull(retryPolicy, "retryPolicy is null"); + } + public TrimmedBasicQueryInfo(BasicQueryInfo queryInfo) { this.queryId = requireNonNull(queryInfo.getQueryId(), "queryId is null"); diff --git a/core/trino-main/src/test/java/io/trino/server/TestQueryResource.java b/core/trino-main/src/test/java/io/trino/server/TestQueryResource.java index d837b0624b07..f3452f8254d7 100644 --- a/core/trino-main/src/test/java/io/trino/server/TestQueryResource.java +++ b/core/trino-main/src/test/java/io/trino/server/TestQueryResource.java @@ -13,6 +13,7 @@ */ package io.trino.server; +import com.google.common.collect.ImmutableSet; import com.google.inject.Key; import io.airlift.http.client.HttpClient; import io.airlift.http.client.HttpUriBuilder; @@ -29,6 +30,8 @@ import io.trino.client.QueryDataClientJacksonModule; import io.trino.client.QueryResults; import io.trino.client.ResultRowsDecoder; +import io.trino.connector.MockConnectorFactory; +import io.trino.connector.MockConnectorPlugin; import io.trino.execution.QueryInfo; import io.trino.plugin.tpch.TpchPlugin; import io.trino.server.testing.TestingTrinoServer; @@ -65,6 +68,7 @@ import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.KILL_QUERY; import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.VIEW_QUERY; import static io.trino.testing.TestingAccessControlManager.privilege; +import static io.trino.testing.TestingNames.randomNameSuffix; import static java.nio.charset.StandardCharsets.UTF_8; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -94,6 +98,9 @@ public void setup() { client = new JettyHttpClient(); server = TestingTrinoServer.create(); + server.installPlugin(new MockConnectorPlugin(MockConnectorFactory.builder() + .withRedactablePropertyNames(ImmutableSet.of("password")) + .build())); server.installPlugin(new TpchPlugin()); server.createCatalog("tpch", "tpch"); } @@ -226,6 +233,47 @@ public void testGetQueryInfoExecutionFailure() assertThat(info.getFailureInfo().getErrorCode()).isEqualTo(DIVISION_BY_ZERO.toErrorCode()); } + @Test + public void testGetQueryInfosWithRedactedSecrets() + { + String catalog = "catalog_" + randomNameSuffix(); + runToCompletion(""" + CREATE CATALOG %s USING mock + WITH ( + "user" = 'bob', + "password" = '1234' + )""".formatted(catalog)); + + List infos = getQueryInfos("/v1/query"); + assertThat(infos.size()).isEqualTo(1); + assertThat(infos.getFirst().getQuery()).isEqualTo(""" + CREATE CATALOG %s USING mock + WITH ( + "user" = 'bob', + "password" = '***' + )""".formatted(catalog)); + } + + @Test + public void testGetQueryInfoWithRedactedSecrets() + { + String catalog = "catalog_" + randomNameSuffix(); + String queryId = runToCompletion(""" + CREATE CATALOG %s USING mock + WITH ( + "user" = 'bob', + "password" = '1234' + )""".formatted(catalog)); + + QueryInfo queryInfo = getQueryInfo(queryId); + assertThat(queryInfo.getQuery()).isEqualTo(""" + CREATE CATALOG %s USING mock + WITH ( + "user" = 'bob', + "password" = '***' + )""".formatted(catalog)); + } + @Test public void testCancel() { diff --git a/core/trino-main/src/test/java/io/trino/server/TestQueryStateInfoResource.java b/core/trino-main/src/test/java/io/trino/server/TestQueryStateInfoResource.java index 09bc72e3b61f..584d56f579ec 100644 --- a/core/trino-main/src/test/java/io/trino/server/TestQueryStateInfoResource.java +++ b/core/trino-main/src/test/java/io/trino/server/TestQueryStateInfoResource.java @@ -13,6 +13,7 @@ */ package io.trino.server; +import com.google.common.collect.ImmutableSet; import com.google.common.io.Closer; import io.airlift.http.client.HttpClient; import io.airlift.http.client.Request; @@ -23,6 +24,8 @@ import io.airlift.json.ObjectMapperProvider; import io.airlift.units.Duration; import io.trino.client.QueryResults; +import io.trino.connector.MockConnectorFactory; +import io.trino.connector.MockConnectorPlugin; import io.trino.plugin.tpch.TpchPlugin; import io.trino.server.protocol.spooling.QueryDataJacksonModule; import io.trino.server.testing.TestingTrinoServer; @@ -47,6 +50,7 @@ import static io.airlift.json.JsonCodec.listJsonCodec; import static io.trino.client.ProtocolHeaders.TRINO_HEADERS; import static io.trino.execution.QueryState.FAILED; +import static io.trino.execution.QueryState.FINISHING; import static io.trino.execution.QueryState.RUNNING; import static io.trino.server.TestQueryResource.BASIC_QUERY_INFO_CODEC; import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.VIEW_QUERY; @@ -71,11 +75,15 @@ public class TestQueryStateInfoResource private TestingTrinoServer server; private HttpClient client; private QueryResults queryResults; + private QueryResults createCatalogResults; @BeforeAll public void setUp() { server = TestingTrinoServer.create(); + server.installPlugin(new MockConnectorPlugin(MockConnectorFactory.builder() + .withRedactablePropertyNames(ImmutableSet.of("password")) + .build())); server.installPlugin(new TpchPlugin()); server.createCatalog("tpch", "tpch"); client = new JettyHttpClient(); @@ -96,6 +104,19 @@ public void setUp() QueryResults queryResults2 = client.execute(request2, createJsonResponseHandler(QUERY_RESULTS_JSON_CODEC)); client.execute(prepareGet().setUri(queryResults2.getNextUri()).build(), createJsonResponseHandler(QUERY_RESULTS_JSON_CODEC)); + Request createCatalogRequest = preparePost() + .setUri(uriBuilderFrom(server.getBaseUrl()).replacePath("/v1/statement").build()) + .setBodyGenerator(createStaticBodyGenerator(""" + CREATE CATALOG test_catalog USING mock + WITH ( + "user" = 'bob', + "password" = '1234' + )""", UTF_8)) + .setHeader(TRINO_HEADERS.requestUser(), "catalogCreator") + .build(); + createCatalogResults = client.execute(createCatalogRequest, createJsonResponseHandler(jsonCodec(QueryResults.class))); + client.execute(prepareGet().setUri(createCatalogResults.getNextUri()).build(), createJsonResponseHandler(QUERY_RESULTS_JSON_CODEC)); + // queries are started in the background, so they may not all be immediately visible long start = System.nanoTime(); while (Duration.nanosSince(start).compareTo(new Duration(5, MINUTES)) < 0) { @@ -105,8 +126,8 @@ public void setUp() .setHeader(TRINO_HEADERS.requestUser(), "unknown") .build(), createJsonResponseHandler(BASIC_QUERY_INFO_CODEC)); - if (queryInfos.size() == 2) { - if (queryInfos.stream().allMatch(info -> info.getState() == RUNNING)) { + if (queryInfos.size() == 3) { + if (queryInfos.stream().allMatch(info -> info.getState() == RUNNING || info.getState() == FINISHING)) { break; } @@ -143,7 +164,12 @@ public void testGetAllQueryStateInfos() .build(), createJsonResponseHandler(listJsonCodec(QueryStateInfo.class))); - assertThat(infos).hasSize(2); + assertThat(infos.size()).isEqualTo(3); + QueryStateInfo createCatalogInfo = infos.stream() + .filter(info -> info.getQueryId().getId().equals(createCatalogResults.getId())) + .findFirst() + .orElse(null); + assertCreateCatalogQueryIsRedacted(createCatalogInfo); } @Test @@ -185,6 +211,19 @@ public void testGetQueryStateInfo() assertThat(info).isNotNull(); } + @Test + public void testGetQueryStateInfoWithRedactedSecrets() + { + QueryStateInfo info = client.execute( + prepareGet() + .setUri(server.resolve("/v1/queryState/" + createCatalogResults.getId())) + .setHeader(TRINO_HEADERS.requestUser(), "unknown") + .build(), + createJsonResponseHandler(jsonCodec(QueryStateInfo.class))); + + assertCreateCatalogQueryIsRedacted(info); + } + @Test public void testGetAllQueryStateInfosDenied() { @@ -194,7 +233,7 @@ public void testGetAllQueryStateInfosDenied() .setHeader(TRINO_HEADERS.requestUser(), "any-other-user") .build(), createJsonResponseHandler(listJsonCodec(QueryStateInfo.class))); - assertThat(infos).hasSize(2); + assertThat(infos).hasSize(3); testGetAllQueryStateInfosDenied("user1", 1); testGetAllQueryStateInfosDenied("any-other-user", 0); @@ -249,4 +288,15 @@ public void testGetQueryStateInfoNo() .isInstanceOf(UnexpectedResponseException.class) .hasMessageMatching("Expected response code .*, but was 404"); } + + private static void assertCreateCatalogQueryIsRedacted(QueryStateInfo info) + { + assertThat(info).isNotNull(); + assertThat(info.getQuery()).isEqualTo(""" + CREATE CATALOG test_catalog USING mock + WITH ( + "user" = 'bob', + "password" = '***' + )"""); + } } diff --git a/core/trino-main/src/test/java/io/trino/server/TestResourceGroupStateInfoResource.java b/core/trino-main/src/test/java/io/trino/server/TestResourceGroupStateInfoResource.java new file mode 100644 index 000000000000..ae9cf42487a2 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/server/TestResourceGroupStateInfoResource.java @@ -0,0 +1,133 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.server; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.airlift.http.client.HttpClient; +import io.airlift.http.client.Request; +import io.airlift.http.client.jetty.JettyHttpClient; +import io.trino.client.QueryResults; +import io.trino.connector.MockConnectorFactory; +import io.trino.connector.MockConnectorPlugin; +import io.trino.server.testing.TestingTrinoServer; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; + +import java.net.URI; +import java.util.List; +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkState; +import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom; +import static io.airlift.http.client.JsonResponseHandler.createJsonResponseHandler; +import static io.airlift.http.client.Request.Builder.prepareGet; +import static io.airlift.http.client.Request.Builder.preparePost; +import static io.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator; +import static io.airlift.json.JsonCodec.jsonCodec; +import static io.airlift.testing.Closeables.closeAll; +import static io.trino.client.ProtocolHeaders.TRINO_HEADERS; +import static io.trino.testing.TestingNames.randomNameSuffix; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; + +@TestInstance(PER_CLASS) +@Execution(CONCURRENT) +final class TestResourceGroupStateInfoResource +{ + private TestingTrinoServer server; + private HttpClient client; + + @BeforeAll + public void setup() + { + client = new JettyHttpClient(); + server = TestingTrinoServer.builder() + .setProperties(ImmutableMap.builder() + .put("web-ui.authentication.type", "fixed") + .put("web-ui.user", "test-user") + .buildOrThrow()) + .build(); + server.installPlugin(new MockConnectorPlugin(MockConnectorFactory.builder() + .withRedactablePropertyNames(ImmutableSet.of("password")) + .build())); + } + + @AfterAll + public void teardown() + throws Exception + { + closeAll(server, client); + server = null; + client = null; + } + + @Test + void testGetResourceGroupInfoWithRedactedSecrets() + { + String catalog = "catalog_" + randomNameSuffix(); + startQuery(""" + CREATE CATALOG %s USING mock + WITH ( + "user" = 'bob', + "password" = '1234' + )""".formatted(catalog)); + + ResourceGroupInfo resourceGroupInfo = getResourceGroupInfo("global"); + Optional> queryStateInfos = resourceGroupInfo.runningQueries(); + assertThat(queryStateInfos.isPresent()).isTrue(); + List queryStates = queryStateInfos.get(); + assertThat(queryStates.size()).isEqualTo(1); + assertThat(queryStates.getFirst().getQuery()).isEqualTo(""" + CREATE CATALOG %s USING mock + WITH ( + "user" = 'bob', + "password" = '***' + )""".formatted(catalog)); + } + + private void startQuery(String sql) + { + Request request = preparePost() + .setUri(server.resolve("/v1/statement")) + .setBodyGenerator(createStaticBodyGenerator(sql, UTF_8)) + .setHeader(TRINO_HEADERS.requestUser(), "unknown") + .build(); + QueryResults queryResults = client.execute(request, createJsonResponseHandler(jsonCodec(QueryResults.class))); + checkState(queryResults.getNextUri() != null && queryResults.getNextUri().toString().contains("/v1/statement/queued/"), "nextUri should point to /v1/statement/queued/"); + request = prepareGet() + .setHeader(TRINO_HEADERS.requestUser(), "unknown") + .setUri(queryResults.getNextUri()) + .build(); + client.execute(request, createJsonResponseHandler(jsonCodec(QueryResults.class))); + } + + private ResourceGroupInfo getResourceGroupInfo(String resourceGroupId) + { + URI uri = uriBuilderFrom(server.getBaseUrl()) + .replacePath("/v1/resourceGroupState") + .appendPath(resourceGroupId) + .build(); + Request request = prepareGet() + .setUri(uri) + .setHeader(TRINO_HEADERS.requestUser(), "unknown") + .build(); + return client.execute(request, createJsonResponseHandler(jsonCodec(ResourceGroupInfo.class))); + } +} diff --git a/core/trino-main/src/test/java/io/trino/server/ui/TestUiQueryResource.java b/core/trino-main/src/test/java/io/trino/server/ui/TestUiQueryResource.java new file mode 100644 index 000000000000..55cf83418e70 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/server/ui/TestUiQueryResource.java @@ -0,0 +1,166 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.server.ui; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.inject.Key; +import io.airlift.http.client.HttpClient; +import io.airlift.http.client.Request; +import io.airlift.http.client.jetty.JettyHttpClient; +import io.airlift.json.JsonCodec; +import io.airlift.json.JsonCodecFactory; +import io.trino.client.QueryResults; +import io.trino.connector.MockConnectorFactory; +import io.trino.connector.MockConnectorPlugin; +import io.trino.execution.QueryInfo; +import io.trino.server.testing.TestingTrinoServer; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; + +import java.net.URI; +import java.util.List; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom; +import static io.airlift.http.client.JsonResponseHandler.createJsonResponseHandler; +import static io.airlift.http.client.Request.Builder.prepareGet; +import static io.airlift.http.client.Request.Builder.preparePost; +import static io.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator; +import static io.airlift.json.JsonCodec.jsonCodec; +import static io.airlift.json.JsonCodec.listJsonCodec; +import static io.airlift.testing.Closeables.closeAll; +import static io.trino.client.ProtocolHeaders.TRINO_HEADERS; +import static io.trino.testing.TestingNames.randomNameSuffix; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; + +@TestInstance(PER_CLASS) +@Execution(CONCURRENT) +final class TestUiQueryResource +{ + private TestingTrinoServer server; + private HttpClient client; + + @BeforeAll + public void setup() + { + client = new JettyHttpClient(); + server = TestingTrinoServer.builder() + .setProperties(ImmutableMap.builder() + .put("web-ui.authentication.type", "fixed") + .put("web-ui.user", "test-user") + .buildOrThrow()) + .build(); + server.installPlugin(new MockConnectorPlugin(MockConnectorFactory.builder() + .withRedactablePropertyNames(ImmutableSet.of("password")) + .build())); + } + + @AfterAll + public void teardown() + throws Exception + { + closeAll(server, client); + server = null; + client = null; + } + + @Test + void testGetQueryInfosWithRedactedSecrets() + { + String catalog = "catalog_" + randomNameSuffix(); + String queryId = runToCompletion(""" + CREATE CATALOG %s USING mock + WITH ( + "user" = 'bob', + "password" = '1234' + )""".formatted(catalog)); + + List infos = getQueryInfos().stream() + .filter(info -> info.getQueryId().getId().equals(queryId)) + .collect(toImmutableList()); + assertThat(infos.size()).isEqualTo(1); + assertThat(infos.getFirst().getQueryTextPreview()).isEqualTo(""" + CREATE CATALOG %s USING mock + WITH ( + "user" = 'bob', + "password" = '***' + )""".formatted(catalog)); + } + + @Test + void testGetQueryInfoWithRedactedSecrets() + { + String catalog = "catalog_" + randomNameSuffix(); + String queryId = runToCompletion(""" + CREATE CATALOG %s USING mock + WITH ( + "user" = 'bob', + "password" = '1234' + )""".formatted(catalog)); + + QueryInfo queryInfo = getQueryInfo(queryId); + assertThat(queryInfo.getQuery()).isEqualTo(""" + CREATE CATALOG %s USING mock + WITH ( + "user" = 'bob', + "password" = '***' + )""".formatted(catalog)); + } + + private String runToCompletion(String sql) + { + Request request = preparePost() + .setHeader(TRINO_HEADERS.requestUser(), "unknown") + .setUri(server.getBaseUrl().resolve("/v1/statement")) + .setBodyGenerator(createStaticBodyGenerator(sql, UTF_8)) + .build(); + QueryResults queryResults = client.execute(request, createJsonResponseHandler(jsonCodec(QueryResults.class))); + while (queryResults.getNextUri() != null) { + request = prepareGet() + .setHeader(TRINO_HEADERS.requestUser(), "unknown") + .setUri(queryResults.getNextUri()) + .build(); + queryResults = client.execute(request, createJsonResponseHandler(jsonCodec(QueryResults.class))); + } + return queryResults.getId(); + } + + private List getQueryInfos() + { + Request request = prepareGet() + .setUri(server.resolve("/ui/api/query")) + .build(); + return client.execute(request, createJsonResponseHandler(listJsonCodec(TrimmedBasicQueryInfo.class))); + } + + private QueryInfo getQueryInfo(String queryId) + { + URI uri = uriBuilderFrom(server.getBaseUrl()) + .replacePath("/ui/api/query") + .appendPath(queryId) + .build(); + Request request = prepareGet() + .setUri(uri) + .build(); + JsonCodec codec = server.getInstance(Key.get(JsonCodecFactory.class)).jsonCodec(QueryInfo.class); + return client.execute(request, createJsonResponseHandler(codec)); + } +}