Skip to content

Commit

Permalink
Query Passthrough for Redis Connector (#1920)
Browse files Browse the repository at this point in the history
  • Loading branch information
aimethed authored Apr 26, 2024
1 parent 2296ee8 commit 5bff45d
Show file tree
Hide file tree
Showing 5 changed files with 251 additions and 4 deletions.
20 changes: 20 additions & 0 deletions athena-redis/athena-redis.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,22 @@ Parameters:
Description: "(Optional) An IAM policy ARN to use as the PermissionsBoundary for the created Lambda function's execution role"
Default: ''
Type: String
QPTConnectionEndpoint:
Description: "(Optional) The hostname:port:password of the Redis server that contains data for this table optionally using SecretsManager (e.g. ${secret_name}). Used for Query Pass Through queries only."
Default: ''
Type: String
QPTConnectionSSL:
Description: "(Optional) When True, creates a Redis connection that uses SSL/TLS. Used for Query Pass Through queries only."
Default: 'false'
Type: String
QPTConnectionCluster:
Description: "(Optional) When True, enables support for clustered Redis instances. Used for Query Pass Through queries only."
Default: 'false'
Type: String
QPTConnectionDBNumber:
Description: "(Optional) Set this number (for example 1, 2, or 3) to read from a non-default Redis database. Used for Query Pass Through queries only."
Default: 0
Type: Number
Conditions:
HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ]
Resources:
Expand All @@ -60,6 +76,10 @@ Resources:
disable_spill_encryption: !Ref DisableSpillEncryption
spill_bucket: !Ref SpillBucket
spill_prefix: !Ref SpillPrefix
qpt_endpoint: !Ref QPTConnectionEndpoint
qpt_ssl: !Ref QPTConnectionSSL
qpt_cluster: !Ref QPTConnectionCluster
qpt_db_number: !Ref QPTConnectionDBNumber
FunctionName: !Ref AthenaCatalogName
Handler: "com.amazonaws.athena.connectors.redis.RedisCompositeHandler"
CodeUri: "./target/athena-redis-2022.47.1.jar"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import com.amazonaws.athena.connector.lambda.domain.Split;
import com.amazonaws.athena.connector.lambda.domain.spill.SpillLocation;
import com.amazonaws.athena.connector.lambda.handlers.GlueMetadataHandler;
import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesRequest;
import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesResponse;
import com.amazonaws.athena.connector.lambda.metadata.GetSplitsRequest;
import com.amazonaws.athena.connector.lambda.metadata.GetSplitsResponse;
import com.amazonaws.athena.connector.lambda.metadata.GetTableLayoutRequest;
Expand All @@ -39,15 +41,18 @@
import com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest;
import com.amazonaws.athena.connector.lambda.metadata.ListTablesResponse;
import com.amazonaws.athena.connector.lambda.metadata.glue.DefaultGlueType;
import com.amazonaws.athena.connector.lambda.metadata.optimizations.OptimizationSubType;
import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory;
import com.amazonaws.athena.connectors.redis.lettuce.RedisCommandsWrapper;
import com.amazonaws.athena.connectors.redis.lettuce.RedisConnectionFactory;
import com.amazonaws.athena.connectors.redis.lettuce.RedisConnectionWrapper;
import com.amazonaws.athena.connectors.redis.qpt.RedisQueryPassthrough;
import com.amazonaws.services.athena.AmazonAthena;
import com.amazonaws.services.glue.AWSGlue;
import com.amazonaws.services.glue.model.Database;
import com.amazonaws.services.glue.model.Table;
import com.amazonaws.services.secretsmanager.AWSSecretsManager;
import com.google.common.collect.ImmutableMap;
import io.lettuce.core.KeyScanCursor;
import io.lettuce.core.Range;
import io.lettuce.core.ScanArgs;
Expand All @@ -62,6 +67,7 @@

import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

Expand Down Expand Up @@ -93,6 +99,14 @@ public class RedisMetadataHandler
protected static final String SPLIT_START_INDEX = "start-index";
protected static final String SPLIT_END_INDEX = "end-index";

//The name of the column added for Query Passthrough requests
protected static final String QPT_COLUMN_NAME = "_value_";
//The names of the lambda environment variables used for QPT connections
protected static final String QPT_ENDPOINT_ENV_VAR = "qpt_endpoint";
protected static final String QPT_SSL_ENV_VAR = "qpt_ssl";
protected static final String QPT_CLUSTER_ENV_VAR = "qpt_cluster";
protected static final String QPT_DB_NUMBER_ENV_VAR = "qpt_db_number";

//Defines the table property name used to set the Redis Key Type for the table. (e.g. prefix, zset)
protected static final String KEY_TYPE = "redis-key-type";
//Defines the table property name used to set the Redis value type for the table. (e.g. literal, zset, hash)
Expand Down Expand Up @@ -125,6 +139,8 @@ public class RedisMetadataHandler
private final AWSGlue awsGlue;
private final RedisConnectionFactory redisConnectionFactory;

private final RedisQueryPassthrough queryPassthrough = new RedisQueryPassthrough();

public RedisMetadataHandler(java.util.Map<String, String> configOptions)
{
super(SOURCE_TYPE, configOptions);
Expand All @@ -148,6 +164,15 @@ protected RedisMetadataHandler(
this.awsGlue = awsGlue;
this.redisConnectionFactory = redisConnectionFactory;
}

/**
 * Reports this connector's optional capabilities. Query Passthrough is the only
 * capability offered, and it is advertised only when enabled via configuration.
 *
 * @param allocator unused here; part of the handler contract
 * @param request the capabilities request identifying the catalog
 * @return the capabilities (possibly empty) registered for this catalog
 */
@Override
public GetDataSourceCapabilitiesResponse doGetDataSourceCapabilities(BlockAllocator allocator, GetDataSourceCapabilitiesRequest request)
{
    ImmutableMap.Builder<String, List<OptimizationSubType>> capabilityBuilder = ImmutableMap.builder();
    // Only populates the map when QPT is enabled in configOptions.
    queryPassthrough.addQueryPassthroughCapabilityIfEnabled(capabilityBuilder, configOptions);
    return new GetDataSourceCapabilitiesResponse(request.getCatalogName(), capabilityBuilder.build());
}

/**
* Used to obtain a Redis client connection for the provided endpoint.
Expand Down Expand Up @@ -208,6 +233,27 @@ public GetTableResponse doGetTable(BlockAllocator blockAllocator, GetTableReques
return new GetTableResponse(response.getCatalogName(), response.getTableName(), schemaBuilder.build());
}

/**
 * Builds the schema used for Query Passthrough reads: a single VARCHAR column
 * ({@code QPT_COLUMN_NAME}) holding the flattened script result. Connection
 * details are sourced from the Lambda environment (the qpt_* config options)
 * rather than from a Glue table.
 *
 * @param allocator unused here; part of the handler contract
 * @param request the GetTable request, which must carry passthrough arguments
 * @return the single-column QPT schema with connection metadata attached
 * @throws IllegalArgumentException if the request is not a passthrough request
 */
@Override
public GetTableResponse doGetQueryPassthroughSchema(BlockAllocator allocator, GetTableRequest request) throws Exception
{
    if (!request.isQueryPassthrough()) {
        // Fix: the original concatenated an SLF4J-style "{}" placeholder into the
        // message ("No Query passed through [{}]" + request), which rendered as
        // a literal "[{}]" followed by the request text.
        throw new IllegalArgumentException("No Query passed through [" + request + "]");
    }
    Map<String, String> queryPassthroughArgs = request.getQueryPassthroughArguments();
    queryPassthrough.verify(queryPassthroughArgs);

    SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder();
    schemaBuilder.addField(QPT_COLUMN_NAME, Types.MinorType.VARCHAR.getType());
    schemaBuilder.addMetadata(VALUE_TYPE_TABLE_PROP, "literal");
    schemaBuilder.addMetadata(KEY_PREFIX_TABLE_PROP, "*");
    // Connection properties are read from the lambda environment variables set by
    // the QPTConnection* template parameters.
    schemaBuilder.addMetadata(REDIS_SSL_FLAG, configOptions.get(QPT_SSL_ENV_VAR));
    schemaBuilder.addMetadata(REDIS_ENDPOINT_PROP, configOptions.get(QPT_ENDPOINT_ENV_VAR));
    schemaBuilder.addMetadata(REDIS_CLUSTER_FLAG, configOptions.get(QPT_CLUSTER_ENV_VAR));
    schemaBuilder.addMetadata(REDIS_DB_NUMBER, configOptions.get(QPT_DB_NUMBER_ENV_VAR));

    return new GetTableResponse(request.getCatalogName(), request.getTableName(), schemaBuilder.build());
}

@Override
public void enhancePartitionSchema(SchemaBuilder partitionSchemaBuilder, GetTableLayoutRequest request)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.amazonaws.athena.connectors.redis.lettuce.RedisCommandsWrapper;
import com.amazonaws.athena.connectors.redis.lettuce.RedisConnectionFactory;
import com.amazonaws.athena.connectors.redis.lettuce.RedisConnectionWrapper;
import com.amazonaws.athena.connectors.redis.qpt.RedisQueryPassthrough;
import com.amazonaws.services.athena.AmazonAthena;
import com.amazonaws.services.athena.AmazonAthenaClientBuilder;
import com.amazonaws.services.s3.AmazonS3;
Expand All @@ -39,6 +40,7 @@
import io.lettuce.core.ScanCursor;
import io.lettuce.core.ScoredValue;
import io.lettuce.core.ScoredValueScanCursor;
import io.lettuce.core.ScriptOutputType;
import org.apache.arrow.util.VisibleForTesting;
import org.apache.arrow.vector.types.pojo.Field;
import org.slf4j.Logger;
Expand All @@ -54,6 +56,7 @@
import static com.amazonaws.athena.connectors.redis.RedisMetadataHandler.KEY_COLUMN_NAME;
import static com.amazonaws.athena.connectors.redis.RedisMetadataHandler.KEY_PREFIX_TABLE_PROP;
import static com.amazonaws.athena.connectors.redis.RedisMetadataHandler.KEY_TYPE;
import static com.amazonaws.athena.connectors.redis.RedisMetadataHandler.QPT_COLUMN_NAME;
import static com.amazonaws.athena.connectors.redis.RedisMetadataHandler.REDIS_CLUSTER_FLAG;
import static com.amazonaws.athena.connectors.redis.RedisMetadataHandler.REDIS_DB_NUMBER;
import static com.amazonaws.athena.connectors.redis.RedisMetadataHandler.REDIS_ENDPOINT_PROP;
Expand Down Expand Up @@ -87,6 +90,8 @@ public class RedisRecordHandler
private final RedisConnectionFactory redisConnectionFactory;
private final AmazonS3 amazonS3;

private final RedisQueryPassthrough queryPassthrough = new RedisQueryPassthrough();

public RedisRecordHandler(java.util.Map<String, String> configOptions)
{
this(
Expand Down Expand Up @@ -131,20 +136,81 @@ private RedisConnectionWrapper<String, String> getOrCreateClient(String rawEndpo
*/
@Override
protected void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recordsRequest, QueryStatusChecker queryStatusChecker)
{
    // Dispatch: passthrough requests run a Lua script; everything else scans keys.
    boolean isPassthrough = recordsRequest.getConstraints().isQueryPassThrough();
    if (isPassthrough) {
        handleQueryPassthrough(spiller, recordsRequest, queryStatusChecker);
        return;
    }
    handleStandardQuery(spiller, recordsRequest, queryStatusChecker);
}

/**
 * Given a recordsRequest, creates (or reuses) the Redis connection described by
 * the request's split properties.
 *
 * @param recordsRequest the recordsRequest to create a connection from
 * @return the synchronous command wrapper for the resulting connection
 */
private RedisCommandsWrapper<String, String> getSyncCommands(ReadRecordsRequest recordsRequest)
{
    Split split = recordsRequest.getSplit();
    boolean sslEnabled = Boolean.parseBoolean(split.getProperty(REDIS_SSL_FLAG));
    boolean isCluster = Boolean.parseBoolean(split.getProperty(REDIS_CLUSTER_FLAG));
    String dbNumber = split.getProperty(REDIS_DB_NUMBER);
    // Fix: removed dead locals (keyCursor, valueType, fieldList) that were copied
    // from handleStandardQuery but never used here.

    RedisConnectionWrapper<String, String> connection = getOrCreateClient(split.getProperty(REDIS_ENDPOINT_PROP),
            sslEnabled, isCluster, dbNumber);
    return connection.sync();
}

/**
 * readWithConstraint case for when the query involves Query Passthrough: runs the
 * supplied Lua script read-only against Redis and spills the flattened results.
 *
 * @see RecordHandler
 */
private void handleQueryPassthrough(BlockSpiller spiller,
        ReadRecordsRequest recordsRequest,
        QueryStatusChecker queryStatusChecker)
{
    Map<String, String> queryPassthroughArgs = recordsRequest.getConstraints().getQueryPassthroughArguments();
    queryPassthrough.verify(queryPassthroughArgs);

    RedisCommandsWrapper<String, String> syncCommands = getSyncCommands(recordsRequest);

    String script = queryPassthroughArgs.get(RedisQueryPassthrough.SCRIPT);
    // Fix: use an explicit charset — the no-arg getBytes() uses the platform default,
    // which can corrupt non-ASCII scripts depending on the host configuration.
    byte[] scriptBytes = script.getBytes(java.nio.charset.StandardCharsets.UTF_8);

    String[] keysArray = parseBracketedList(queryPassthroughArgs.get(RedisQueryPassthrough.KEYS));
    String[] argvArray = parseBracketedList(queryPassthroughArgs.get(RedisQueryPassthrough.ARGV));

    List<Object> result = syncCommands.evalReadOnly(scriptBytes, ScriptOutputType.MULTI, keysArray, argvArray);
    loadSingleColumn(result, spiller, queryStatusChecker);
}

/**
 * Converts a string formatted as "[value1, value2, ...]" into an array of strings.
 * An empty input yields an empty array. (Extracted to remove the duplicated
 * KEYS/ARGV parsing in handleQueryPassthrough.)
 */
private static String[] parseBracketedList(String listString)
{
    if (listString.isEmpty()) {
        return new String[0];
    }
    return listString.substring(1, listString.length() - 1).split(",\\s*");
}

/**
* readWithConstraint case for when the query does not involve Query Passthrough
* @see RecordHandler
*/
private void handleStandardQuery(BlockSpiller spiller,
ReadRecordsRequest recordsRequest,
QueryStatusChecker queryStatusChecker)
{
Split split = recordsRequest.getSplit();
ScanCursor keyCursor = null;
ValueType valueType = ValueType.fromId(split.getProperty(VALUE_TYPE_TABLE_PROP));
List<Field> fieldList = recordsRequest.getSchema().getFields().stream()
.filter((Field next) -> !KEY_COLUMN_NAME.equals(next.getName())).collect(Collectors.toList());
RedisCommandsWrapper<String, String> syncCommands = getSyncCommands(recordsRequest);
do {
Set<String> keys = new HashSet<>();
//Load all the keys associated with this split
Expand Down Expand Up @@ -205,6 +271,41 @@ private ScanCursor loadKeys(RedisCommandsWrapper<String, String> syncCommands, S
}
}

/**
 * Spills each script result as one row in the single QPT VARCHAR column,
 * flattening nested lists into a comma-separated string.
 *
 * @param values the (possibly nested) values returned by the eval call
 * @param spiller used to write the result rows
 * @param queryStatusChecker consulted so we stop once the query is no longer running
 */
private void loadSingleColumn(List<Object> values, BlockSpiller spiller, QueryStatusChecker queryStatusChecker)
{
    for (Object value : values) {
        // Fix: the stream/forEach version's "return" only skipped the current
        // element, so iteration (and flattening work) continued after the query
        // stopped. Checking first and returning stops all remaining work; no
        // additional rows were written either way.
        if (!queryStatusChecker.isQueryRunning()) {
            return;
        }
        StringBuilder builder = new StringBuilder();
        flattenRow(value, builder);
        spiller.writeRows((Block block, int row) -> {
            boolean literalMatched = block.offerValue(QPT_COLUMN_NAME, row, builder.toString());
            return literalMatched ? 1 : 0;
        });
    }
}

/**
 * Redis eval calls return an object that is either a String or a List of such
 * objects (possibly nested). This recursively flattens the value into
 * {@code builder} as a comma-separated string. Null values contribute nothing.
 */
@SuppressWarnings("unchecked")
private void flattenRow(Object value, StringBuilder builder)
{
    if (value instanceof String) {
        // Separate values already in the builder with ", ".
        if (builder.length() > 0) {
            builder.append(", ");
        }
        builder.append((String) value);
    }
    else if (value != null) {
        // Non-null, non-String values are nested lists; recurse into each element.
        for (Object element : (List<Object>) value) {
            flattenRow(element, builder);
        }
    }
}

private void loadLiteralRow(RedisCommandsWrapper<String, String> syncCommands, String keyString, BlockSpiller spiller, List<Field> fieldList)
{
spiller.writeRows((Block block, int row) -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.lettuce.core.ScanArgs;
import io.lettuce.core.ScanCursor;
import io.lettuce.core.ScoredValueScanCursor;
import io.lettuce.core.ScriptOutputType;
import io.lettuce.core.api.sync.RedisCommands;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import org.apache.arrow.util.VisibleForTesting;
Expand Down Expand Up @@ -113,6 +114,16 @@ public ScoredValueScanCursor<V> zscan(K var1, ScanCursor var2)
}
}

/**
 * Runs a read-only Lua script (EVAL_RO), delegating to the cluster or
 * standalone command set depending on the connection type.
 */
public <T> T evalReadOnly(byte[] script, ScriptOutputType type, K[] keys, V... values)
{
    return isCluster
            ? clusterCommands.evalReadOnly(script, type, keys, values)
            : standaloneCommands.evalReadOnly(script, type, keys, values);
}

@VisibleForTesting
public String hmset(K var1, Map<K, V> var2)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*-
* #%L
* athena-redis
* %%
* Copyright (C) 2019 - 2024 Amazon Web Services
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
package com.amazonaws.athena.connectors.redis.qpt;

import com.amazonaws.athena.connector.lambda.metadata.optimizations.querypassthrough.QueryPassthroughSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Arrays;
import java.util.List;

/**
 * Defines the Query Passthrough signature for the Redis connector: the
 * system.script(SCRIPT, KEYS, ARGV) function used to pass Lua scripts through
 * to Redis.
 */
public final class RedisQueryPassthrough implements QueryPassthroughSignature
{
    // Constant value representing the name of the query.
    public static final String NAME = "script";

    // Constant value representing the domain of the query.
    public static final String SCHEMA = "system";

    // List of arguments for the query, statically initialized as it always contains the same value.
    public static final String SCRIPT = "SCRIPT";
    public static final String KEYS = "KEYS";
    public static final String ARGV = "ARGV";

    // Fix: wrapped unmodifiable — Arrays.asList alone still permits set(), so a
    // caller could otherwise mutate this shared public static list.
    public static final List<String> ARGUMENTS =
            java.util.Collections.unmodifiableList(Arrays.asList(SCRIPT, KEYS, ARGV));

    private static final Logger LOGGER = LoggerFactory.getLogger(RedisQueryPassthrough.class);

    @Override
    public String getFunctionSchema()
    {
        return SCHEMA;
    }

    @Override
    public String getFunctionName()
    {
        return NAME;
    }

    @Override
    public List<String> getFunctionArguments()
    {
        return ARGUMENTS;
    }

    @Override
    public Logger getLogger()
    {
        return LOGGER;
    }
}

0 comments on commit 5bff45d

Please sign in to comment.