Extended QPT to athena-timestream (#1857)

Jithendar12 authored Mar 25, 2024
1 parent 9370de9 commit 05a4da1

Showing 3 changed files with 161 additions and 9 deletions.
TimestreamMetadataHandler.java

@@ -26,6 +26,8 @@
import com.amazonaws.athena.connector.lambda.domain.Split;
import com.amazonaws.athena.connector.lambda.domain.TableName;
import com.amazonaws.athena.connector.lambda.handlers.GlueMetadataHandler;
import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesRequest;
import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesResponse;
import com.amazonaws.athena.connector.lambda.metadata.GetSplitsRequest;
import com.amazonaws.athena.connector.lambda.metadata.GetSplitsResponse;
import com.amazonaws.athena.connector.lambda.metadata.GetTableLayoutRequest;
@@ -35,14 +37,17 @@
import com.amazonaws.athena.connector.lambda.metadata.ListSchemasResponse;
import com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest;
import com.amazonaws.athena.connector.lambda.metadata.ListTablesResponse;
import com.amazonaws.athena.connector.lambda.metadata.optimizations.OptimizationSubType;
import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory;
import com.amazonaws.athena.connector.util.PaginatedRequestIterator;
import com.amazonaws.athena.connectors.timestream.qpt.TimestreamQueryPassthrough;
import com.amazonaws.athena.connectors.timestream.query.QueryFactory;
import com.amazonaws.services.athena.AmazonAthena;
import com.amazonaws.services.glue.AWSGlue;
import com.amazonaws.services.glue.model.Table;
import com.amazonaws.services.secretsmanager.AWSSecretsManager;
import com.amazonaws.services.timestreamquery.AmazonTimestreamQuery;
import com.amazonaws.services.timestreamquery.model.ColumnInfo;
import com.amazonaws.services.timestreamquery.model.Datum;
import com.amazonaws.services.timestreamquery.model.QueryRequest;
import com.amazonaws.services.timestreamquery.model.QueryResult;
@@ -51,13 +56,16 @@
import com.amazonaws.services.timestreamwrite.model.ListDatabasesRequest;
import com.amazonaws.services.timestreamwrite.model.ListDatabasesResult;
import com.amazonaws.services.timestreamwrite.model.ListTablesResult;
import com.google.common.collect.ImmutableMap;
import org.apache.arrow.util.VisibleForTesting;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

@@ -85,12 +93,15 @@ public class TimestreamMetadataHandler
private final AmazonTimestreamQuery tsQuery;
private final AmazonTimestreamWrite tsMeta;

private final TimestreamQueryPassthrough queryPassthrough;

public TimestreamMetadataHandler(java.util.Map<String, String> configOptions)
{
super(SOURCE_TYPE, configOptions);
glue = getAwsGlue();
tsQuery = TimestreamClientBuilder.buildQueryClient(SOURCE_TYPE);
tsMeta = TimestreamClientBuilder.buildWriteClient(SOURCE_TYPE);
queryPassthrough = new TimestreamQueryPassthrough();
}

@VisibleForTesting
@@ -109,6 +120,16 @@ protected TimestreamMetadataHandler(
this.glue = glue;
this.tsQuery = tsQuery;
this.tsMeta = tsMeta;
queryPassthrough = new TimestreamQueryPassthrough();
}

@Override
public GetDataSourceCapabilitiesResponse doGetDataSourceCapabilities(BlockAllocator allocator, GetDataSourceCapabilitiesRequest request)
{
ImmutableMap.Builder<String, List<OptimizationSubType>> capabilities = ImmutableMap.builder();
this.queryPassthrough.addQueryPassthroughCapabilityIfEnabled(capabilities, this.configOptions);

return new GetDataSourceCapabilitiesResponse(request.getCatalogName(), capabilities.build());
}
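Once this capability is advertised, the Athena engine can hand the connector a customer's raw Timestream SQL instead of a table scan. A minimal sketch of opting in from the Lambda side, assuming the SDK gates the capability on an "enable_query_passthrough" configuration option (the exact key is defined by the athena-federation-sdk, not by this commit, and is an assumption here):

// Hypothetical wiring sketch; verify the configuration key against your SDK version.
Map<String, String> configOptions = ImmutableMap.of("enable_query_passthrough", "true");
TimestreamMetadataHandler handler = new TimestreamMetadataHandler(configOptions);
// doGetDataSourceCapabilities now advertises the system.query pass-through signature.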

@Override
@@ -289,6 +310,28 @@ public GetTableResponse doGetTable(BlockAllocator blockAllocator, GetTableRequest request)
}
}

@Override
public GetTableResponse doGetQueryPassthroughSchema(BlockAllocator allocator, GetTableRequest request) throws Exception
{
if (!request.isQueryPassthrough()) {
throw new IllegalArgumentException("No query passed through [" + request + "]");
}

queryPassthrough.verify(request.getQueryPassthroughArguments());
String customerPassedQuery = request.getQueryPassthroughArguments().get(TimestreamQueryPassthrough.QUERY);
QueryRequest queryRequest = new QueryRequest().withQueryString(customerPassedQuery).withMaxRows(1);
// Timestream's Query API has no dry-run mode and cannot return result metadata without executing the statement, so we run the query with maxRows(1) and read the column metadata from that first page.
QueryResult queryResult = tsQuery.query(queryRequest);
List<ColumnInfo> columnInfo = queryResult.getColumnInfo();
SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder();
for (ColumnInfo column : columnInfo) {
Field nextField = TimestreamSchemaUtils.makeField(column.getName(), column.getType().getScalarType().toLowerCase());
schemaBuilder.addField(nextField);
}

return new GetTableResponse(request.getCatalogName(), request.getTableName(), schemaBuilder.build(), Collections.emptySet());
}
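The withMaxRows(1) probe is the load-bearing trick here: Timestream will not return column metadata without executing the statement, so the handler pays for a single-row page purely to learn the column names and scalar types. A standalone sketch of the same probe, assuming a configured AmazonTimestreamQuery client tsQuery and a hypothetical table mydb.mytable:

// Execute the customer's query with maxRows(1) solely to obtain ColumnInfo.
QueryRequest probe = new QueryRequest()
        .withQueryString("SELECT time, measure_value::double AS value FROM mydb.mytable LIMIT 10")
        .withMaxRows(1);
QueryResult result = tsQuery.query(probe);
for (ColumnInfo column : result.getColumnInfo()) {
    // Prints e.g. "time : TIMESTAMP" and "value : DOUBLE".
    System.out.println(column.getName() + " : " + column.getType().getScalarType());
}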

/**
* Our table doesn't support complex layouts or partitioning, so we simply make this method a no-op.
*
@@ -308,7 +351,16 @@ public GetSplitsResponse doGetSplits(BlockAllocator allocator, GetSplitsRequest request)
{
// Since we do not support connector-level parallelism for this source at the moment,
// we generate a single basic split.
Split split = Split.newBuilder(makeSpillLocation(request), makeEncryptionKey()).build();
Split split;
if (request.getConstraints().isQueryPassThrough()) {
logger.info("QPT Split Requested");
Map<String, String> qptArguments = request.getConstraints().getQueryPassthroughArguments();
split = Split.newBuilder(makeSpillLocation(request), makeEncryptionKey()).applyProperties(qptArguments).build();
}
else {
split = Split.newBuilder(makeSpillLocation(request), makeEncryptionKey()).build();
}

return new GetSplitsResponse(request.getCatalogName(), split);
}
}
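For a pass-through request the split is otherwise identical to the normal single split; the engine's arguments simply ride along as split properties. A hedged sketch of reading them back on the other side, assuming the SDK's Split.getProperty accessor (this connector's record handler actually re-derives the query from the request constraints instead, as the next file shows):

// Illustrative only: recover the customer's statement from a QPT split.
String passthroughQuery = split.getProperty(TimestreamQueryPassthrough.QUERY);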
TimestreamRecordHandler.java

@@ -37,6 +37,7 @@
import com.amazonaws.athena.connector.lambda.handlers.GlueMetadataHandler;
import com.amazonaws.athena.connector.lambda.handlers.RecordHandler;
import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest;
import com.amazonaws.athena.connectors.timestream.qpt.TimestreamQueryPassthrough;
import com.amazonaws.athena.connectors.timestream.query.QueryFactory;
import com.amazonaws.athena.connectors.timestream.query.SelectQueryBuilder;
import com.amazonaws.services.athena.AmazonAthena;
@@ -88,6 +89,7 @@ public class TimestreamRecordHandler

private final QueryFactory queryFactory = new QueryFactory();
private final AmazonTimestreamQuery tsQuery;
private final TimestreamQueryPassthrough queryPassthrough = new TimestreamQueryPassthrough();

public TimestreamRecordHandler(java.util.Map<String, String> configOptions)
{
@@ -115,14 +117,19 @@ protected TimestreamRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager
protected void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recordsRequest, QueryStatusChecker queryStatusChecker)
{
TableName tableName = recordsRequest.getTableName();

SelectQueryBuilder queryBuilder = queryFactory.createSelectQueryBuilder(GlueMetadataHandler.VIEW_METADATA_FIELD);

String query = queryBuilder.withDatabaseName(tableName.getSchemaName())
.withTableName(tableName.getTableName())
.withProjection(recordsRequest.getSchema())
.withConjucts(recordsRequest.getConstraints())
.build();
String query;
if (recordsRequest.getConstraints().isQueryPassThrough()) {
queryPassthrough.verify(recordsRequest.getConstraints().getQueryPassthroughArguments());
query = recordsRequest.getConstraints().getQueryPassthroughArguments().get(TimestreamQueryPassthrough.QUERY);
}
else {
SelectQueryBuilder queryBuilder = queryFactory.createSelectQueryBuilder(GlueMetadataHandler.VIEW_METADATA_FIELD);
query = queryBuilder.withDatabaseName(tableName.getSchemaName())
.withTableName(tableName.getTableName())
.withProjection(recordsRequest.getSchema())
.withConjucts(recordsRequest.getConstraints())
.build();
}

logger.info("readWithConstraint: query[{}]", query);

TimestreamQueryPassthrough.java (new file)

@@ -0,0 +1,93 @@
/*-
* #%L
* athena-timestream
* %%
* Copyright (C) 2019 Amazon Web Services
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
package com.amazonaws.athena.connectors.timestream.qpt;

import com.amazonaws.athena.connector.lambda.metadata.optimizations.querypassthrough.QueryPassthroughSignature;
import com.google.common.collect.ImmutableSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

public class TimestreamQueryPassthrough implements QueryPassthroughSignature
{
// Name of the pass-through function.
public static final String NAME = "query";

// Schema under which the pass-through function is registered.
public static final String SCHEMA_NAME = "system";

// Name of the function's single argument: the raw query string supplied by the customer.
public static final String QUERY = "QUERY";

// The function's argument list; statically initialized, since the signature never changes.
public static final List<String> ARGUMENTS = Arrays.asList(QUERY);

private static final Logger LOGGER = LoggerFactory.getLogger(TimestreamQueryPassthrough.class);

@Override
public String getFunctionSchema()
{
return SCHEMA_NAME;
}

@Override
public String getFunctionName()
{
return NAME;
}

@Override
public List<String> getFunctionArguments()
{
return ARGUMENTS;
}

@Override
public Logger getLogger()
{
return LOGGER;
}

@Override
public void customConnectorVerifications(Map<String, String> engineQptArguments)
{
String customerPassedQuery = engineQptArguments.get(QUERY);
String upperCaseStatement = customerPassedQuery.trim().toUpperCase(Locale.ENGLISH);

// Immediately check if the statement starts with "SELECT"
if (!upperCaseStatement.startsWith("SELECT")) {
throw new UnsupportedOperationException("Statement does not start with SELECT.");
}

// List of disallowed keywords
Set<String> disallowedKeywords = ImmutableSet.of("INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER");

// Check if the statement contains any disallowed keyword. Note that contains() is a
// conservative substring check: a SELECT that merely embeds one of these words
// (e.g. a column named "created") is rejected as well.
for (String keyword : disallowedKeywords) {
if (upperCaseStatement.contains(keyword)) {
throw new UnsupportedOperationException("Unsupported operation; only SELECT statements are allowed. Found: " + keyword);
}
}
}
}
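Taken together, the signature above registers a pass-through function addressable as system.query, whose single QUERY argument is forwarded verbatim to Timestream. A minimal sketch of the custom verification behavior, exercising customConnectorVerifications directly with hypothetical statements (using Guava's ImmutableMap):

TimestreamQueryPassthrough qpt = new TimestreamQueryPassthrough();

// Accepted: a plain read-only statement.
qpt.customConnectorVerifications(ImmutableMap.of(TimestreamQueryPassthrough.QUERY,
        "SELECT time, measure_value::double FROM mydb.mytable LIMIT 10"));

// Rejected: does not start with SELECT, so this throws UnsupportedOperationException.
qpt.customConnectorVerifications(ImmutableMap.of(TimestreamQueryPassthrough.QUERY,
        "DROP TABLE mydb.mytable"));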
