From 5bff45dc7e2d265ac62770b5e8208e59e0bf6ab0 Mon Sep 17 00:00:00 2001 From: Aimery Methena <159072740+aimethed@users.noreply.github.com> Date: Fri, 26 Apr 2024 10:40:01 -0400 Subject: [PATCH] Query Passthrough for Redis Connector (#1920) --- athena-redis/athena-redis.yaml | 20 ++++ .../redis/RedisMetadataHandler.java | 46 ++++++++ .../connectors/redis/RedisRecordHandler.java | 109 +++++++++++++++++- .../redis/lettuce/RedisCommandsWrapper.java | 11 ++ .../redis/qpt/RedisQueryPassthrough.java | 69 +++++++++++ 5 files changed, 251 insertions(+), 4 deletions(-) create mode 100644 athena-redis/src/main/java/com/amazonaws/athena/connectors/redis/qpt/RedisQueryPassthrough.java diff --git a/athena-redis/athena-redis.yaml b/athena-redis/athena-redis.yaml index 7d66e4e426..c8e8c6458f 100644 --- a/athena-redis/athena-redis.yaml +++ b/athena-redis/athena-redis.yaml @@ -49,6 +49,22 @@ Parameters: Description: "(Optional) An IAM policy ARN to use as the PermissionsBoundary for the created Lambda function's execution role" Default: '' Type: String + QPTConnectionEndpoint: + Description: "(Optional) The hostname:port:password of the Redis server that contains data for this table optionally using SecretsManager (e.g. ${secret_name}). Used for Query Pass Through queries only." + Default: '' + Type: String + QPTConnectionSSL: + Description: "(Optional) When True, creates a Redis connection that uses SSL/TLS. Used for Query Pass Through queries only." + Default: 'false' + Type: String + QPTConnectionCluster: + Description: "(Optional) When True, enables support for clustered Redis instances. Used for Query Pass Through queries only." + Default: 'false' + Type: String + QPTConnectionDBNumber: + Description: "(Optional) Set this number (for example 1, 2, or 3) to read from a non-default Redis database. Used for Query Pass Through queries only." 
+ Default: 0 + Type: Number Conditions: HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ] Resources: @@ -60,6 +76,10 @@ Resources: disable_spill_encryption: !Ref DisableSpillEncryption spill_bucket: !Ref SpillBucket spill_prefix: !Ref SpillPrefix + qpt_endpoint: !Ref QPTConnectionEndpoint + qpt_ssl: !Ref QPTConnectionSSL + qpt_cluster: !Ref QPTConnectionCluster + qpt_db_number: !Ref QPTConnectionDBNumber FunctionName: !Ref AthenaCatalogName Handler: "com.amazonaws.athena.connectors.redis.RedisCompositeHandler" CodeUri: "./target/athena-redis-2022.47.1.jar" diff --git a/athena-redis/src/main/java/com/amazonaws/athena/connectors/redis/RedisMetadataHandler.java b/athena-redis/src/main/java/com/amazonaws/athena/connectors/redis/RedisMetadataHandler.java index 7623dc6218..b0d6be7d06 100644 --- a/athena-redis/src/main/java/com/amazonaws/athena/connectors/redis/RedisMetadataHandler.java +++ b/athena-redis/src/main/java/com/amazonaws/athena/connectors/redis/RedisMetadataHandler.java @@ -29,6 +29,8 @@ import com.amazonaws.athena.connector.lambda.domain.Split; import com.amazonaws.athena.connector.lambda.domain.spill.SpillLocation; import com.amazonaws.athena.connector.lambda.handlers.GlueMetadataHandler; +import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesRequest; +import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesResponse; import com.amazonaws.athena.connector.lambda.metadata.GetSplitsRequest; import com.amazonaws.athena.connector.lambda.metadata.GetSplitsResponse; import com.amazonaws.athena.connector.lambda.metadata.GetTableLayoutRequest; @@ -39,15 +41,18 @@ import com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest; import com.amazonaws.athena.connector.lambda.metadata.ListTablesResponse; import com.amazonaws.athena.connector.lambda.metadata.glue.DefaultGlueType; +import com.amazonaws.athena.connector.lambda.metadata.optimizations.OptimizationSubType; import 
com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; import com.amazonaws.athena.connectors.redis.lettuce.RedisCommandsWrapper; import com.amazonaws.athena.connectors.redis.lettuce.RedisConnectionFactory; import com.amazonaws.athena.connectors.redis.lettuce.RedisConnectionWrapper; +import com.amazonaws.athena.connectors.redis.qpt.RedisQueryPassthrough; import com.amazonaws.services.athena.AmazonAthena; import com.amazonaws.services.glue.AWSGlue; import com.amazonaws.services.glue.model.Database; import com.amazonaws.services.glue.model.Table; import com.amazonaws.services.secretsmanager.AWSSecretsManager; +import com.google.common.collect.ImmutableMap; import io.lettuce.core.KeyScanCursor; import io.lettuce.core.Range; import io.lettuce.core.ScanArgs; @@ -62,6 +67,7 @@ import java.util.Arrays; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Set; @@ -93,6 +99,14 @@ public class RedisMetadataHandler protected static final String SPLIT_START_INDEX = "start-index"; protected static final String SPLIT_END_INDEX = "end-index"; + //The name of the column added for Query Passthrough requests + protected static final String QPT_COLUMN_NAME = "_value_"; + //The names of the lambda environment variables used for QPT connections + protected static final String QPT_ENDPOINT_ENV_VAR = "qpt_endpoint"; + protected static final String QPT_SSL_ENV_VAR = "qpt_ssl"; + protected static final String QPT_CLUSTER_ENV_VAR = "qpt_cluster"; + protected static final String QPT_DB_NUMBER_ENV_VAR = "qpt_db_number"; + //Defines the table property name used to set the Redis Key Type for the table. (e.g. prefix, zset) protected static final String KEY_TYPE = "redis-key-type"; //Defines the table property name used to set the Redis value type for the table. (e.g. 
liternal, zset, hash) @@ -125,6 +139,8 @@ public class RedisMetadataHandler private final AWSGlue awsGlue; private final RedisConnectionFactory redisConnectionFactory; + private final RedisQueryPassthrough queryPassthrough = new RedisQueryPassthrough(); + public RedisMetadataHandler(java.util.Map<String, String> configOptions) { super(SOURCE_TYPE, configOptions); @@ -148,6 +164,15 @@ protected RedisMetadataHandler( this.awsGlue = awsGlue; this.redisConnectionFactory = redisConnectionFactory; } + + @Override + public GetDataSourceCapabilitiesResponse doGetDataSourceCapabilities(BlockAllocator allocator, GetDataSourceCapabilitiesRequest request) + { + ImmutableMap.Builder<String, List<OptimizationSubType>> capabilities = ImmutableMap.builder(); + queryPassthrough.addQueryPassthroughCapabilityIfEnabled(capabilities, configOptions); + + return new GetDataSourceCapabilitiesResponse(request.getCatalogName(), capabilities.build()); + } /** * Used to obtain a Redis client connection for the provided endpoint. @@ -208,6 +233,27 @@ public GetTableResponse doGetTable(BlockAllocator blockAllocator, GetTableReques return new GetTableResponse(response.getCatalogName(), response.getTableName(), schemaBuilder.build()); } + @Override + public GetTableResponse doGetQueryPassthroughSchema(BlockAllocator allocator, GetTableRequest request) throws Exception + { + if (!request.isQueryPassthrough()) { + throw new IllegalArgumentException("No Query passed through [" + request + "]"); + } + Map<String, String> queryPassthroughArgs = request.getQueryPassthroughArguments(); + queryPassthrough.verify(queryPassthroughArgs); + + SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder(); + schemaBuilder.addField(QPT_COLUMN_NAME, Types.MinorType.VARCHAR.getType()); + schemaBuilder.addMetadata(VALUE_TYPE_TABLE_PROP, "literal"); + schemaBuilder.addMetadata(KEY_PREFIX_TABLE_PROP, "*"); + schemaBuilder.addMetadata(REDIS_SSL_FLAG, configOptions.get(QPT_SSL_ENV_VAR)); + schemaBuilder.addMetadata(REDIS_ENDPOINT_PROP, configOptions.get(QPT_ENDPOINT_ENV_VAR)); + 
schemaBuilder.addMetadata(REDIS_CLUSTER_FLAG, configOptions.get(QPT_CLUSTER_ENV_VAR)); + schemaBuilder.addMetadata(REDIS_DB_NUMBER, configOptions.get(QPT_DB_NUMBER_ENV_VAR)); + + return new GetTableResponse(request.getCatalogName(), request.getTableName(), schemaBuilder.build()); + } + @Override public void enhancePartitionSchema(SchemaBuilder partitionSchemaBuilder, GetTableLayoutRequest request) { diff --git a/athena-redis/src/main/java/com/amazonaws/athena/connectors/redis/RedisRecordHandler.java b/athena-redis/src/main/java/com/amazonaws/athena/connectors/redis/RedisRecordHandler.java index b64ee82a2c..dfb490a3a6 100644 --- a/athena-redis/src/main/java/com/amazonaws/athena/connectors/redis/RedisRecordHandler.java +++ b/athena-redis/src/main/java/com/amazonaws/athena/connectors/redis/RedisRecordHandler.java @@ -28,6 +28,7 @@ import com.amazonaws.athena.connectors.redis.lettuce.RedisCommandsWrapper; import com.amazonaws.athena.connectors.redis.lettuce.RedisConnectionFactory; import com.amazonaws.athena.connectors.redis.lettuce.RedisConnectionWrapper; +import com.amazonaws.athena.connectors.redis.qpt.RedisQueryPassthrough; import com.amazonaws.services.athena.AmazonAthena; import com.amazonaws.services.athena.AmazonAthenaClientBuilder; import com.amazonaws.services.s3.AmazonS3; @@ -39,6 +40,7 @@ import io.lettuce.core.ScanCursor; import io.lettuce.core.ScoredValue; import io.lettuce.core.ScoredValueScanCursor; +import io.lettuce.core.ScriptOutputType; import org.apache.arrow.util.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Field; import org.slf4j.Logger; @@ -54,6 +56,7 @@ import static com.amazonaws.athena.connectors.redis.RedisMetadataHandler.KEY_COLUMN_NAME; import static com.amazonaws.athena.connectors.redis.RedisMetadataHandler.KEY_PREFIX_TABLE_PROP; import static com.amazonaws.athena.connectors.redis.RedisMetadataHandler.KEY_TYPE; +import static com.amazonaws.athena.connectors.redis.RedisMetadataHandler.QPT_COLUMN_NAME; import static 
com.amazonaws.athena.connectors.redis.RedisMetadataHandler.REDIS_CLUSTER_FLAG; import static com.amazonaws.athena.connectors.redis.RedisMetadataHandler.REDIS_DB_NUMBER; import static com.amazonaws.athena.connectors.redis.RedisMetadataHandler.REDIS_ENDPOINT_PROP; @@ -87,6 +90,8 @@ public class RedisRecordHandler private final RedisConnectionFactory redisConnectionFactory; private final AmazonS3 amazonS3; + private final RedisQueryPassthrough queryPassthrough = new RedisQueryPassthrough(); + public RedisRecordHandler(java.util.Map configOptions) { this( @@ -131,20 +136,81 @@ private RedisConnectionWrapper getOrCreateClient(String rawEndpo */ @Override protected void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recordsRequest, QueryStatusChecker queryStatusChecker) + { + if (recordsRequest.getConstraints().isQueryPassThrough()) { + handleQueryPassthrough(spiller, recordsRequest, queryStatusChecker); + } + else { + handleStandardQuery(spiller, recordsRequest, queryStatusChecker); + } + } + + /** + * Given a recordsRequest, creates the Redis connection + * @param recordsRequest the recordsRequest to create a connection from + * @return the resulting connection object + */ + private RedisCommandsWrapper getSyncCommands(ReadRecordsRequest recordsRequest) { Split split = recordsRequest.getSplit(); - ScanCursor keyCursor = null; boolean sslEnabled = Boolean.parseBoolean(split.getProperty(REDIS_SSL_FLAG)); boolean isCluster = Boolean.parseBoolean(split.getProperty(REDIS_CLUSTER_FLAG)); String dbNumber = split.getProperty(REDIS_DB_NUMBER); - ValueType valueType = ValueType.fromId(split.getProperty(VALUE_TYPE_TABLE_PROP)); - List fieldList = recordsRequest.getSchema().getFields().stream() - .filter((Field next) -> !KEY_COLUMN_NAME.equals(next.getName())).collect(Collectors.toList()); RedisConnectionWrapper connection = getOrCreateClient(split.getProperty(REDIS_ENDPOINT_PROP), sslEnabled, isCluster, dbNumber); RedisCommandsWrapper syncCommands = 
connection.sync(); + return syncCommands; + } + + /** + * readWithConstraint case for when the query involves Query Passthrough + * @see RecordHandler + */ + private void handleQueryPassthrough(BlockSpiller spiller, + ReadRecordsRequest recordsRequest, + QueryStatusChecker queryStatusChecker) + { + Map<String, String> queryPassthroughArgs = recordsRequest.getConstraints().getQueryPassthroughArguments(); + queryPassthrough.verify(queryPassthroughArgs); + + RedisCommandsWrapper<String, String> syncCommands = getSyncCommands(recordsRequest); + String script = queryPassthroughArgs.get(RedisQueryPassthrough.SCRIPT); + byte[] scriptBytes = script.getBytes(); + + String keys = queryPassthroughArgs.get(RedisQueryPassthrough.KEYS); + String[] keysArray = new String[0]; + if (!keys.isEmpty()) { + // to convert string formatted as "[value1, value2, ...]" to array of strings + keysArray = keys.substring(1, keys.length() - 1).split(",\\s*"); + } + + String argv = queryPassthroughArgs.get(RedisQueryPassthrough.ARGV); + String[] argvArray = new String[0]; + if (!argv.isEmpty()) { + // to convert string formatted as "[value1, value2, ...]" to array of strings + argvArray = argv.substring(1, argv.length() - 1).split(",\\s*"); + } + + List<Object> result = syncCommands.evalReadOnly(scriptBytes, ScriptOutputType.MULTI, keysArray, argvArray); + loadSingleColumn(result, spiller, queryStatusChecker); + } + + /** + * readWithConstraint case for when the query does not involve Query Passthrough + * @see RecordHandler + */ + private void handleStandardQuery(BlockSpiller spiller, + ReadRecordsRequest recordsRequest, + QueryStatusChecker queryStatusChecker) + { + Split split = recordsRequest.getSplit(); + ScanCursor keyCursor = null; + ValueType valueType = ValueType.fromId(split.getProperty(VALUE_TYPE_TABLE_PROP)); + List<Field> fieldList = recordsRequest.getSchema().getFields().stream() + .filter((Field next) -> !KEY_COLUMN_NAME.equals(next.getName())).collect(Collectors.toList()); + RedisCommandsWrapper<String, String> syncCommands = 
getSyncCommands(recordsRequest); do { Set<String> keys = new HashSet<>(); //Load all the keys associated with this split @@ -205,6 +271,41 @@ private ScanCursor loadKeys(RedisCommandsWrapper<String, String> syncCommands, S } } + private void loadSingleColumn(List<Object> values, BlockSpiller spiller, QueryStatusChecker queryStatusChecker) + { + values.stream().forEach((Object value) -> { + StringBuilder builder = new StringBuilder(); + flattenRow(value, builder); + if (!queryStatusChecker.isQueryRunning()) { + return; + } + spiller.writeRows((Block block, int row) -> { + boolean literalMatched = block.offerValue(QPT_COLUMN_NAME, row, builder.toString()); + return literalMatched ? 1 : 0; + }); + }); + } + + /** + * Redis eval calls return an object of type T, which is either a String or List. + * This flattens it into a String using a StringBuilder. + */ + private void flattenRow(Object value, StringBuilder builder) + { + if (value == null) { + return; + } + if (value instanceof String) { + if (builder.length() != 0) { + builder.append(", "); + } + builder.append(value); + } + else { + ((List<?>) value).forEach((Object subValue) -> flattenRow(subValue, builder)); + } + } + private void loadLiteralRow(RedisCommandsWrapper<String, String> syncCommands, String keyString, BlockSpiller spiller, List<Field> fieldList) { spiller.writeRows((Block block, int row) -> { diff --git a/athena-redis/src/main/java/com/amazonaws/athena/connectors/redis/lettuce/RedisCommandsWrapper.java b/athena-redis/src/main/java/com/amazonaws/athena/connectors/redis/lettuce/RedisCommandsWrapper.java index 664c9d9017..a0487a9dcf 100644 --- a/athena-redis/src/main/java/com/amazonaws/athena/connectors/redis/lettuce/RedisCommandsWrapper.java +++ b/athena-redis/src/main/java/com/amazonaws/athena/connectors/redis/lettuce/RedisCommandsWrapper.java @@ -24,6 +24,7 @@ import io.lettuce.core.ScanArgs; import io.lettuce.core.ScanCursor; import io.lettuce.core.ScoredValueScanCursor; +import io.lettuce.core.ScriptOutputType; import 
io.lettuce.core.api.sync.RedisCommands; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import org.apache.arrow.util.VisibleForTesting; @@ -113,6 +114,16 @@ public ScoredValueScanCursor<V> zscan(K var1, ScanCursor var2) } } + public <T> T evalReadOnly(byte[] script, ScriptOutputType type, K[] keys, V... values) + { + if (isCluster) { + return clusterCommands.evalReadOnly(script, type, keys, values); + } + else { + return standaloneCommands.evalReadOnly(script, type, keys, values); + } + } + @VisibleForTesting public String hmset(K var1, Map<K, V> var2) { diff --git a/athena-redis/src/main/java/com/amazonaws/athena/connectors/redis/qpt/RedisQueryPassthrough.java b/athena-redis/src/main/java/com/amazonaws/athena/connectors/redis/qpt/RedisQueryPassthrough.java new file mode 100644 index 0000000000..1ba495a970 --- /dev/null +++ b/athena-redis/src/main/java/com/amazonaws/athena/connectors/redis/qpt/RedisQueryPassthrough.java @@ -0,0 +1,69 @@ +/*- + * #%L + * athena-redis + * %% + * Copyright (C) 2019 - 2024 Amazon Web Services + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * #L% + */ +package com.amazonaws.athena.connectors.redis.qpt; + +import com.amazonaws.athena.connector.lambda.metadata.optimizations.querypassthrough.QueryPassthroughSignature; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Arrays; +import java.util.List; + +public final class RedisQueryPassthrough implements QueryPassthroughSignature +{ + // Constant value representing the name of the query. + public static final String NAME = "script"; + + // Constant value representing the domain of the query. + public static final String SCHEMA = "system"; + + // List of arguments for the query, statically initialized as it always contains the same value. + public static final String SCRIPT = "SCRIPT"; + public static final String KEYS = "KEYS"; + public static final String ARGV = "ARGV"; + + public static final List<String> ARGUMENTS = Arrays.asList(SCRIPT, KEYS, ARGV); + + private static final Logger LOGGER = LoggerFactory.getLogger(RedisQueryPassthrough.class); + + @Override + public String getFunctionSchema() + { + return SCHEMA; + } + + @Override + public String getFunctionName() + { + return NAME; + } + + @Override + public List<String> getFunctionArguments() + { + return ARGUMENTS; + } + + @Override + public Logger getLogger() + { + return LOGGER; + } +}