Extended QPT to athena-hbase (#1876)
VenkatasivareddyTR authored Apr 12, 2024
1 parent e3fd3d1 commit 69a64b3
Showing 3 changed files with 192 additions and 7 deletions.
HbaseMetadataHandler.java
@@ -24,7 +24,10 @@
import com.amazonaws.athena.connector.lambda.data.BlockWriter;
import com.amazonaws.athena.connector.lambda.data.SchemaBuilder;
import com.amazonaws.athena.connector.lambda.domain.Split;
import com.amazonaws.athena.connector.lambda.domain.spill.SpillLocation;
import com.amazonaws.athena.connector.lambda.handlers.GlueMetadataHandler;
import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesRequest;
import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesResponse;
import com.amazonaws.athena.connector.lambda.metadata.GetSplitsRequest;
import com.amazonaws.athena.connector.lambda.metadata.GetSplitsResponse;
import com.amazonaws.athena.connector.lambda.metadata.GetTableLayoutRequest;
@@ -36,13 +39,16 @@
import com.amazonaws.athena.connector.lambda.metadata.ListTablesResponse;
import com.amazonaws.athena.connector.lambda.metadata.MetadataRequest;
import com.amazonaws.athena.connector.lambda.metadata.glue.GlueFieldLexer;
import com.amazonaws.athena.connector.lambda.metadata.optimizations.OptimizationSubType;
import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory;
import com.amazonaws.athena.connectors.hbase.connection.HBaseConnection;
import com.amazonaws.athena.connectors.hbase.connection.HbaseConnectionFactory;
import com.amazonaws.athena.connectors.hbase.qpt.HbaseQueryPassthrough;
import com.amazonaws.services.athena.AmazonAthena;
import com.amazonaws.services.glue.AWSGlue;
import com.amazonaws.services.glue.model.Table;
import com.amazonaws.services.secretsmanager.AWSSecretsManager;
import com.google.common.collect.ImmutableMap;
import org.apache.arrow.util.VisibleForTesting;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.Field;
@@ -103,6 +109,8 @@ public class HbaseMetadataHandler
private final AWSGlue awsGlue;
private final HbaseConnectionFactory connectionFactory;

private final HbaseQueryPassthrough queryPassthrough = new HbaseQueryPassthrough();

public HbaseMetadataHandler(java.util.Map<String, String> configOptions)
{
super(SOURCE_TYPE, configOptions);
@@ -126,6 +134,15 @@ protected HbaseMetadataHandler(
this.connectionFactory = connectionFactory;
}

@Override
public GetDataSourceCapabilitiesResponse doGetDataSourceCapabilities(BlockAllocator allocator, GetDataSourceCapabilitiesRequest request)
{
ImmutableMap.Builder<String, List<OptimizationSubType>> capabilities = ImmutableMap.builder();
queryPassthrough.addQueryPassthroughCapabilityIfEnabled(capabilities, configOptions);

return new GetDataSourceCapabilitiesResponse(request.getCatalogName(), capabilities.build());
}

private HBaseConnection getOrCreateConn(MetadataRequest request)
{
String endpoint = resolveSecrets(getConnStr(request));
@@ -210,9 +227,18 @@ public GetTableResponse doGetTable(BlockAllocator blockAllocator, GetTableReques
request.getTableName().getTableName(),
ex);
}
String schemaName = request.getTableName().getSchemaName();
String tableName = request.getTableName().getTableName();
com.amazonaws.athena.connector.lambda.domain.TableName tableNameObj = new com.amazonaws.athena.connector.lambda.domain.TableName(schemaName, tableName);

return getTableResponse(request, origSchema, tableNameObj);
}

private GetTableResponse getTableResponse(GetTableRequest request, Schema origSchema,
com.amazonaws.athena.connector.lambda.domain.TableName tableName)
{
if (origSchema == null) {
origSchema = HbaseSchemaUtils.inferSchema(getOrCreateConn(request), request.getTableName(), NUM_ROWS_TO_SCAN);
origSchema = HbaseSchemaUtils.inferSchema(getOrCreateConn(request), tableName, NUM_ROWS_TO_SCAN);
}

SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder();
@@ -253,6 +279,11 @@ public void getPartitions(BlockWriter blockWriter, GetTableLayoutRequest request
public GetSplitsResponse doGetSplits(BlockAllocator blockAllocator, GetSplitsRequest request)
throws IOException
{
if (request.getConstraints().isQueryPassThrough()) {
logger.info("doGetSplits: QPT enabled");
return setupQueryPassthroughSplit(request);
}

Set<Split> splits = new HashSet<>();

//We can read each region in parallel
@@ -278,4 +309,34 @@ protected Field convertField(String name, String glueType)
{
return GlueFieldLexer.lex(name, glueType);
}

@Override
public GetTableResponse doGetQueryPassthroughSchema(BlockAllocator allocator, GetTableRequest request) throws Exception
{
queryPassthrough.verify(request.getQueryPassthroughArguments());
String schemaName = request.getQueryPassthroughArguments().get(HbaseQueryPassthrough.DATABASE);
String tableName = request.getQueryPassthroughArguments().get(HbaseQueryPassthrough.COLLECTION);

com.amazonaws.athena.connector.lambda.domain.TableName tableNameObj = new com.amazonaws.athena.connector.lambda.domain.TableName(schemaName, tableName);

return getTableResponse(request, null, tableNameObj);
}

/**
* Helper function that provides a single partition for Query Pass-Through.
*/
protected GetSplitsResponse setupQueryPassthroughSplit(GetSplitsRequest request)
{
//Every split must have a unique location if we wish to spill to avoid failures
SpillLocation spillLocation = makeSpillLocation(request);

//Since this is a QPT query, we return a single fixed split.
Map<String, String> qptArguments = request.getConstraints().getQueryPassthroughArguments();
return new GetSplitsResponse(request.getCatalogName(),
Split.newBuilder(spillLocation, makeEncryptionKey())
.add(HBASE_CONN_STR, getConnStr(request))
.applyProperties(qptArguments)
.build());
}
}
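
For illustration, a minimal sketch (not part of the commit) of how the DATABASE and COLLECTION passthrough arguments resolve into the Athena TableName that doGetQueryPassthroughSchema hands to the shared getTableResponse helper; the class name and sample values below are hypothetical.

import com.amazonaws.athena.connector.lambda.domain.TableName;

public class QptTableNameExample
{
    public static void main(String[] args)
    {
        //Hypothetical values of the DATABASE and COLLECTION query passthrough arguments.
        String schemaName = "default";
        String tableName = "orders";

        //doGetQueryPassthroughSchema builds the same TableName before delegating to getTableResponse,
        //which falls back to HbaseSchemaUtils.inferSchema when no Glue schema is available.
        TableName tableNameObj = new TableName(schemaName, tableName);
        System.out.println(tableNameObj.getSchemaName() + ":" + tableNameObj.getTableName());
    }
}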
HbaseRecordHandler.java
@@ -23,12 +23,14 @@
import com.amazonaws.athena.connector.lambda.data.Block;
import com.amazonaws.athena.connector.lambda.data.BlockSpiller;
import com.amazonaws.athena.connector.lambda.domain.Split;
import com.amazonaws.athena.connector.lambda.domain.TableName;
import com.amazonaws.athena.connector.lambda.domain.predicate.Constraints;
import com.amazonaws.athena.connector.lambda.domain.predicate.ValueSet;
import com.amazonaws.athena.connector.lambda.handlers.RecordHandler;
import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest;
import com.amazonaws.athena.connectors.hbase.connection.HBaseConnection;
import com.amazonaws.athena.connectors.hbase.connection.HbaseConnectionFactory;
import com.amazonaws.athena.connectors.hbase.qpt.HbaseQueryPassthrough;
import com.amazonaws.services.athena.AmazonAthena;
import com.amazonaws.services.athena.AmazonAthenaClientBuilder;
import com.amazonaws.services.s3.AmazonS3;
@@ -40,19 +42,22 @@
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.hbase.client.Result;
import org.apache.hadoop.hbase.client.ResultScanner;
import org.apache.hadoop.hbase.client.Scan;
import org.apache.hadoop.hbase.filter.BinaryComparator;
import org.apache.hadoop.hbase.filter.CompareFilter;
import org.apache.hadoop.hbase.filter.Filter;
import org.apache.hadoop.hbase.filter.ParseFilter;
import org.apache.hadoop.hbase.filter.RowFilter;
import org.apache.hadoop.hbase.filter.SingleColumnValueFilter;
import org.apache.hadoop.hbase.util.Bytes;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.nio.charset.CharacterCodingException;
import java.util.Map;

import static com.amazonaws.athena.connectors.hbase.HbaseMetadataHandler.END_KEY_FIELD;
@@ -81,6 +86,8 @@ public class HbaseRecordHandler
private final AmazonS3 amazonS3;
private final HbaseConnectionFactory connectionFactory;

private final HbaseQueryPassthrough queryPassthrough = new HbaseQueryPassthrough();

public HbaseRecordHandler(java.util.Map<String, String> configOptions)
{
this(
@@ -119,18 +126,43 @@ protected void readWithConstraint(BlockSpiller blockSpiller, ReadRecordsRequest
String conStr = split.getProperty(HBASE_CONN_STR);
boolean isNative = projection.getCustomMetadata().get(HBASE_NATIVE_STORAGE_FLAG) != null;

//setup the scan so that we only read the key range associated with the region represented by our Split.
Scan scan = new Scan(split.getProperty(START_KEY_FIELD).getBytes(), split.getProperty(END_KEY_FIELD).getBytes());

//attempts to push down a partial predicate using HBase Filters
scan.setFilter(pushdownPredicate(isNative, request.getConstraints()));
String schemaName;
String tableName;
Scan scan;
if (request.getConstraints().isQueryPassThrough()) {
Map<String, String> qptArguments = request.getConstraints().getQueryPassthroughArguments();
queryPassthrough.verify(qptArguments);
schemaName = qptArguments.get(HbaseQueryPassthrough.DATABASE);
tableName = qptArguments.get(HbaseQueryPassthrough.COLLECTION);
String filter = qptArguments.get(HbaseQueryPassthrough.FILTER);
scan = new Scan();
//An empty filter is allowed in an HBase scan, so there is no else branch.
//If the filter is empty, we scan all records in the table.
if (!StringUtils.isEmpty(filter)) {
try {
scan.setFilter(new ParseFilter().parseFilterString(filter));
}
catch (CharacterCodingException e) {
throw new RuntimeException("Can't parse filter argument as HBase scan filter: ", e);
}
}
}
else {
//setup the scan so that we only read the key range associated with the region represented by our Split.
scan = new Scan(split.getProperty(START_KEY_FIELD).getBytes(), split.getProperty(END_KEY_FIELD).getBytes());
//attempts to push down a partial predicate using HBase Filters
scan.setFilter(pushdownPredicate(isNative, request.getConstraints()));
schemaName = request.getTableName().getSchemaName();
tableName = request.getTableName().getTableName();
}
TableName tableNameObj = new TableName(schemaName, tableName);

//setup the projection so we only pull columns/families that we need
for (Field next : request.getSchema().getFields()) {
addToProjection(scan, next);
}

getOrCreateConn(conStr).scanTable(HbaseSchemaUtils.getQualifiedTable(request.getTableName()),
getOrCreateConn(conStr).scanTable(HbaseSchemaUtils.getQualifiedTable(tableNameObj),
scan,
(ResultScanner scanner) -> scanFilterProject(scanner, request, blockSpiller, queryStatusChecker));
}
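As a reference for the QPT branch above, here is a minimal sketch (not part of the commit) that turns a FILTER argument into an HBase Scan using the client's ParseFilter, assuming the HBase client and commons-lang3 are on the classpath; the class name and the sample filter string are illustrative only.

import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.hbase.client.Scan;
import org.apache.hadoop.hbase.filter.ParseFilter;

import java.nio.charset.CharacterCodingException;

public class QptScanExample
{
    //Builds a Scan from an optional filter expressed in HBase's filter language.
    static Scan buildQptScan(String filter)
    {
        Scan scan = new Scan();
        //An empty FILTER is allowed: the scan then reads every row in the table.
        if (!StringUtils.isEmpty(filter)) {
            try {
                scan.setFilter(new ParseFilter().parseFilterString(filter));
            }
            catch (CharacterCodingException e) {
                throw new RuntimeException("Can't parse filter argument as HBase scan filter: ", e);
            }
        }
        return scan;
    }

    public static void main(String[] args)
    {
        //Hypothetical filter string; PrefixFilter is part of HBase's filter language.
        Scan scan = buildQptScan("PrefixFilter('row-2024')");
        System.out.println("Filter applied: " + (scan.getFilter() != null));
    }
}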
HbaseQueryPassthrough.java
@@ -0,0 +1,92 @@
/*-
* #%L
* athena-hbase
* %%
* Copyright (C) 2019 - 2024 Amazon Web Services
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
package com.amazonaws.athena.connectors.hbase.qpt;

import com.amazonaws.athena.connector.lambda.metadata.optimizations.querypassthrough.QueryPassthroughSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Arrays;
import java.util.List;
import java.util.Map;

public final class HbaseQueryPassthrough implements QueryPassthroughSignature
{
private static final Logger LOGGER = LoggerFactory.getLogger(HbaseQueryPassthrough.class);

// Constant value representing the name of the query.
public static final String NAME = "query";

// Constant value representing the domain of the query.
public static final String SCHEMA_NAME = "system";

// Names of the arguments accepted by the query function, defined as constants since they never change.
public static final String DATABASE = "DATABASE";
public static final String COLLECTION = "COLLECTION";
public static final String FILTER = "FILTER";

@Override
public String getFunctionSchema()
{
return SCHEMA_NAME;
}

@Override
public String getFunctionName()
{
return NAME;
}

@Override
public List<String> getFunctionArguments()
{
return Arrays.asList(DATABASE, COLLECTION, FILTER);
}

@Override
public Logger getLogger()
{
return LOGGER;
}
/**
* Verifies that the arguments returned by the engine are the same as those the connector defined.
* In the parent class verify function, all argument values are mandatory. This override removes the
* check that rejects an empty HBase FILTER argument value, so an empty FILTER no longer triggers an exception.
* @param engineQptArguments the query passthrough arguments supplied by the engine
* @throws IllegalArgumentException if the function signature does not match or a required argument is missing
*/
@Override
public void verify(Map<String, String> engineQptArguments)
throws IllegalArgumentException
{
//High-level Query Passthrough Function Argument verification
//First verifies that these arguments belong to this specific function
if (!verifyFunctionSignature(engineQptArguments)) {
throw new IllegalArgumentException("Function Signature doesn't match implementation's");
}
//Ensure the arguments received from the engine are the ones defined by the connector
for (String argument : this.getFunctionArguments()) {
if (!engineQptArguments.containsKey(argument)) {
throw new IllegalArgumentException("Missing Query Passthrough Argument: " + argument);
}
}
}
}
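
To make the relaxed FILTER rule concrete, a small sketch (not part of the commit) that assembles the engine-argument map the connector expects and repeats the key-presence check performed in verify(); the sample values are hypothetical, and the real verify() additionally calls verifyFunctionSignature() on the engine-supplied arguments.

import java.util.HashMap;
import java.util.Map;

public class QptArgumentsExample
{
    public static void main(String[] args)
    {
        //Keys match the constants declared by HbaseQueryPassthrough: DATABASE, COLLECTION, FILTER.
        Map<String, String> engineArguments = new HashMap<>();
        engineArguments.put("DATABASE", "default");   //hypothetical HBase namespace
        engineArguments.put("COLLECTION", "orders");  //hypothetical HBase table
        engineArguments.put("FILTER", "");            //may be empty: the connector then scans the whole table

        //verify() requires each key to be present; only the FILTER value is allowed to be blank.
        for (String argument : new String[] {"DATABASE", "COLLECTION", "FILTER"}) {
            if (!engineArguments.containsKey(argument)) {
                throw new IllegalArgumentException("Missing Query Passthrough Argument: " + argument);
            }
        }
        System.out.println("All query passthrough arguments are present.");
    }
}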
