Extended QPT to athena-neptune for property graph (#1886)
VenkatasivareddyTR authored Apr 12, 2024
1 parent 69a64b3 commit 6fd2923
Showing 5 changed files with 219 additions and 31 deletions.
NeptuneConnection.java
@@ -19,6 +19,8 @@
*/
package com.amazonaws.athena.connectors.neptune;

import com.amazonaws.athena.connectors.neptune.propertygraph.NeptuneGremlinConnection;
import com.amazonaws.athena.connectors.neptune.rdf.NeptuneSparqlConnection;
import org.apache.tinkerpop.gremlin.driver.Client;
import org.apache.tinkerpop.gremlin.driver.Cluster;
import org.apache.tinkerpop.gremlin.driver.SigV4WebSocketChannelizer;
@@ -52,6 +54,29 @@ protected NeptuneConnection(String neptuneEndpoint, String neptunePort, boolean
this.enabledIAM = enabledIAM;
this.region = region;
}

public static NeptuneConnection createConnection(java.util.Map<String, String> configOptions)
{
Enums.GraphType graphType = Enums.GraphType.PROPERTYGRAPH;
if (configOptions.get(Constants.CFG_GRAPH_TYPE) != null) {
graphType = Enums.GraphType.valueOf(configOptions.get(Constants.CFG_GRAPH_TYPE).toUpperCase());
}

        switch (graphType) {
case PROPERTYGRAPH:
return new NeptuneGremlinConnection(configOptions.get(Constants.CFG_ENDPOINT),
configOptions.get(Constants.CFG_PORT), Boolean.parseBoolean(configOptions.get(Constants.CFG_IAM)),
configOptions.get(Constants.CFG_REGION));

case RDF:
return new NeptuneSparqlConnection(configOptions.get(Constants.CFG_ENDPOINT),
configOptions.get(Constants.CFG_PORT), Boolean.parseBoolean(configOptions.get(Constants.CFG_IAM)),
configOptions.get(Constants.CFG_REGION));

default:
throw new IllegalArgumentException("Unsupported graphType: " + graphType);
}
}

public String getNeptuneEndpoint()
{
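For orientation, a minimal sketch of how the new factory is driven. Only the Constants keys, the graph-type handling, and createConnection itself come from this change; the endpoint, port, and region values are placeholders:

    // Hypothetical configuration; keys are the connector's Constants, values are illustrative.
    java.util.Map<String, String> configOptions = new java.util.HashMap<>();
    configOptions.put(Constants.CFG_ENDPOINT, "my-cluster.cluster-abc.us-east-1.neptune.amazonaws.com");
    configOptions.put(Constants.CFG_PORT, "8182");
    configOptions.put(Constants.CFG_IAM, "true");
    configOptions.put(Constants.CFG_REGION, "us-east-1");
    configOptions.put(Constants.CFG_GRAPH_TYPE, "propertygraph"); // parsed case-insensitively; PROPERTYGRAPH when absent

    // PROPERTYGRAPH yields a NeptuneGremlinConnection, RDF a NeptuneSparqlConnection,
    // and anything else now fails fast with IllegalArgumentException.
    NeptuneConnection connection = NeptuneConnection.createConnection(configOptions);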
NeptuneMetadataHandler.java
@@ -22,10 +22,13 @@
import com.amazonaws.athena.connector.lambda.QueryStatusChecker;
import com.amazonaws.athena.connector.lambda.data.BlockAllocator;
import com.amazonaws.athena.connector.lambda.data.BlockWriter;
import com.amazonaws.athena.connector.lambda.data.SchemaBuilder;
import com.amazonaws.athena.connector.lambda.domain.Split;
import com.amazonaws.athena.connector.lambda.domain.TableName;
import com.amazonaws.athena.connector.lambda.domain.spill.SpillLocation;
import com.amazonaws.athena.connector.lambda.handlers.GlueMetadataHandler;
import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesRequest;
import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesResponse;
import com.amazonaws.athena.connector.lambda.metadata.GetSplitsRequest;
import com.amazonaws.athena.connector.lambda.metadata.GetSplitsResponse;
import com.amazonaws.athena.connector.lambda.metadata.GetTableLayoutRequest;
@@ -36,22 +39,30 @@
import com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest;
import com.amazonaws.athena.connector.lambda.metadata.ListTablesResponse;
import com.amazonaws.athena.connector.lambda.metadata.glue.GlueFieldLexer;
import com.amazonaws.athena.connector.lambda.metadata.optimizations.OptimizationSubType;
import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory;
import com.amazonaws.athena.connectors.neptune.propertygraph.PropertyGraphHandler;
import com.amazonaws.athena.connectors.neptune.qpt.NeptuneQueryPassthrough;
import com.amazonaws.services.athena.AmazonAthena;
import com.amazonaws.services.glue.AWSGlue;
import com.amazonaws.services.glue.model.GetTablesRequest;
import com.amazonaws.services.glue.model.GetTablesResult;
import com.amazonaws.services.glue.model.Table;
import com.amazonaws.services.secretsmanager.AWSSecretsManager;
import com.google.common.collect.ImmutableMap;
import org.apache.arrow.util.VisibleForTesting;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.tinkerpop.gremlin.driver.Client;
import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversal;
import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversalSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static java.util.Objects.requireNonNull;
@@ -80,6 +91,9 @@ public class NeptuneMetadataHandler extends GlueMetadataHandler
private final AWSGlue glue;
private final String glueDBName;

private NeptuneConnection neptuneConnection = null;
private final NeptuneQueryPassthrough queryPassthrough = new NeptuneQueryPassthrough();

public NeptuneMetadataHandler(java.util.Map<String, String> configOptions)
{
super(SOURCE_TYPE, configOptions);
@@ -88,6 +102,7 @@ public NeptuneMetadataHandler(java.util.Map<String, String> configOptions)
// check is enforcing this connector's contract.
requireNonNull(this.glue);
this.glueDBName = configOptions.get("glue_database_name");
this.neptuneConnection = NeptuneConnection.createConnection(configOptions);
}

@VisibleForTesting
@@ -105,6 +120,15 @@ protected NeptuneMetadataHandler(
this.glue = glue;
this.glueDBName = configOptions.get("glue_database_name");
}

@Override
public GetDataSourceCapabilitiesResponse doGetDataSourceCapabilities(BlockAllocator allocator, GetDataSourceCapabilitiesRequest request)
{
ImmutableMap.Builder<String, List<OptimizationSubType>> capabilities = ImmutableMap.builder();
queryPassthrough.addQueryPassthroughCapabilityIfEnabled(capabilities, configOptions);

return new GetDataSourceCapabilitiesResponse(request.getCatalogName(), capabilities.build());
}
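Once this capability is advertised, Athena can hand passthrough invocations — in Athena SQL these take a form along the lines of SELECT * FROM TABLE(system.query(DATABASE => ..., COLLECTION => ..., QUERY => ...)), matching the argument names NeptuneQueryPassthrough defines — to the schema and record paths below. A sketch of the argument map the handlers then see, with illustrative values:

    // Hypothetical passthrough arguments; the constant names come from NeptuneQueryPassthrough,
    // the values are placeholders.
    Map<String, String> qptArguments = new HashMap<>();
    qptArguments.put(NeptuneQueryPassthrough.DATABASE, "graph-database");  // Glue database name
    qptArguments.put(NeptuneQueryPassthrough.COLLECTION, "airport");       // Glue table name
    qptArguments.put(NeptuneQueryPassthrough.QUERY, "g.V().hasLabel('airport').valueMap()");
    queryPassthrough.verify(qptArguments); // expected to reject incomplete or malformed argument sets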

/**
* Since the entire Neptune cluster is considered as a single graph database,
@@ -243,4 +267,60 @@ protected Field convertField(String name, String glueType)
{
return GlueFieldLexer.lex(name, glueType);
}

@Override
public GetTableResponse doGetQueryPassthroughSchema(BlockAllocator allocator, GetTableRequest request) throws Exception
{
Map<String, String> qptArguments = request.getQueryPassthroughArguments();
queryPassthrough.verify(qptArguments);
String schemaName = qptArguments.get(NeptuneQueryPassthrough.DATABASE);
String tableName = qptArguments.get(NeptuneQueryPassthrough.COLLECTION);
TableName tableNameObj = new TableName(schemaName, tableName);
request = new GetTableRequest(request.getIdentity(), request.getQueryId(),
request.getCatalogName(), tableNameObj, request.getQueryPassthroughArguments());

GetTableResponse getTableResponse = doGetTable(allocator, request);
List<Field> fields = getTableResponse.getSchema().getFields();
Client client = neptuneConnection.getNeptuneClientConnection();
GraphTraversalSource graphTraversalSource = neptuneConnection.getTraversalSource(client);
String gremlinQuery = qptArguments.get(NeptuneQueryPassthrough.QUERY);
gremlinQuery = gremlinQuery.concat(".limit(1)");
logger.info("NeptuneMetadataHandler doGetQueryPassthroughSchema gremlinQuery with limit: " + gremlinQuery);
Object object = new PropertyGraphHandler(neptuneConnection).getResponseFromGremlinQuery(graphTraversalSource, gremlinQuery);
GraphTraversal graphTraversalForSchema = (GraphTraversal) object;
Map graphTraversalObj = null;
Schema schema;
while (graphTraversalForSchema.hasNext()) {
graphTraversalObj = (Map) graphTraversalForSchema.next();
}

        //If the probe returned no rows, or the query projects every column known to Glue,
        //reuse the schema from the Glue table response as-is; otherwise build a narrower
        //schema from the query result's column names.
        if (graphTraversalObj == null || graphTraversalObj.size() == fields.size()) {
            schema = getTableResponse.getSchema();
        }
else {
SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder();
            //Build the schema from the Gremlin result's column names plus the matching
            //fields from the Glue response; needed only when a subset of columns is selected.
graphTraversalObj.forEach((columnName, columnValue) -> buildSchema(columnName.toString(), fields, schemaBuilder));
Map<String, String> metaData = getTableResponse.getSchema().getCustomMetadata();
            for (Map.Entry<String, String> entry : metaData.entrySet()) {
                schemaBuilder.addMetadata(entry.getKey(), entry.getValue());
            }
schema = schemaBuilder.build();
}

return new GetTableResponse(request.getCatalogName(), tableNameObj, schema);
}
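The .limit(1) suffix keeps the schema probe cheap: one result row is enough to learn which columns the user's traversal projects, and the row's width then decides between reusing the Glue schema and rebuilding a narrower one. A toy illustration with placeholder values:

    // Hypothetical probe outcome for a query projecting two of four Glue columns.
    Map<String, Object> probeRow = new HashMap<>();
    probeRow.put("code", "SEA");
    probeRow.put("city", "Seattle");
    int glueFieldCount = 4; // e.g. code, city, country, elev

    boolean reuseGlueSchema = probeRow.size() == glueFieldCount; // false -> rebuild from result columns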

private void buildSchema(String columnName, List<Field> fields, SchemaBuilder schemaBuilder)
{
for (Field field : fields) {
if (field.getName().equals(columnName)) {
schemaBuilder.addField(field);
break;
}
}
}
}
NeptuneRecordHandler.java
@@ -24,9 +24,7 @@
import com.amazonaws.athena.connector.lambda.handlers.RecordHandler;
import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest;
import com.amazonaws.athena.connectors.neptune.Enums.GraphType;
import com.amazonaws.athena.connectors.neptune.propertygraph.NeptuneGremlinConnection;
import com.amazonaws.athena.connectors.neptune.propertygraph.PropertyGraphHandler;
import com.amazonaws.athena.connectors.neptune.rdf.NeptuneSparqlConnection;
import com.amazonaws.athena.connectors.neptune.rdf.RDFHandler;
import com.amazonaws.services.athena.AmazonAthena;
import com.amazonaws.services.athena.AmazonAthenaClientBuilder;
@@ -64,34 +62,13 @@ public class NeptuneRecordHandler extends RecordHandler
private NeptuneConnection neptuneConnection = null;
private java.util.Map<String, String> configOptions = null;

private static NeptuneConnection createConnection(java.util.Map<String, String> configOptions)
{
GraphType graphType = GraphType.PROPERTYGRAPH;
if (configOptions.get(Constants.CFG_GRAPH_TYPE) != null) {
graphType = GraphType.valueOf(configOptions.get(Constants.CFG_GRAPH_TYPE).toUpperCase());
}

switch(graphType){
case PROPERTYGRAPH:
return new NeptuneGremlinConnection(configOptions.get(Constants.CFG_ENDPOINT),
configOptions.get(Constants.CFG_PORT), Boolean.parseBoolean(configOptions.get(Constants.CFG_IAM)),
configOptions.get(Constants.CFG_REGION));

case RDF:
return new NeptuneSparqlConnection(configOptions.get(Constants.CFG_ENDPOINT),
configOptions.get(Constants.CFG_PORT), Boolean.parseBoolean(configOptions.get(Constants.CFG_IAM)),
configOptions.get(Constants.CFG_REGION));
}
return null;
}

public NeptuneRecordHandler(java.util.Map<String, String> configOptions)
{
this(
AmazonS3ClientBuilder.defaultClient(),
AWSSecretsManagerClientBuilder.defaultClient(),
AmazonAthenaClientBuilder.defaultClient(),
createConnection(configOptions),
NeptuneConnection.createConnection(configOptions),
configOptions);
}

PropertyGraphHandler.java
@@ -30,16 +30,22 @@
import com.amazonaws.athena.connectors.neptune.propertygraph.rowwriters.CustomSchemaRowWriter;
import com.amazonaws.athena.connectors.neptune.propertygraph.rowwriters.EdgeRowWriter;
import com.amazonaws.athena.connectors.neptune.propertygraph.rowwriters.VertexRowWriter;
import com.amazonaws.athena.connectors.neptune.qpt.NeptuneQueryPassthrough;
import org.apache.arrow.util.VisibleForTesting;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.tinkerpop.gremlin.driver.Client;
import org.apache.tinkerpop.gremlin.driver.Result;
import org.apache.tinkerpop.gremlin.groovy.jsr223.GremlinGroovyScriptEngine;
import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversal;
import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversalSource;
import org.apache.tinkerpop.gremlin.process.traversal.step.util.WithOptions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.script.Bindings;
import javax.script.ScriptEngine;
import javax.script.ScriptException;

import java.util.Iterator;
import java.util.Map;

@@ -67,6 +73,7 @@ public class PropertyGraphHandler
*/

private final NeptuneConnection neptuneConnection;
private final NeptuneQueryPassthrough queryPassthrough = new NeptuneQueryPassthrough();

@VisibleForTesting
public PropertyGraphHandler(NeptuneConnection neptuneConnection)
@@ -105,7 +112,20 @@ public void executeQuery(
Client client = neptuneConnection.getNeptuneClientConnection();
GraphTraversalSource graphTraversalSource = neptuneConnection.getTraversalSource(client);
GraphTraversal graphTraversal = null;
String labelName = recordsRequest.getTableName().getTableName();
String labelName;
String gremlinQuery = null;
if (recordsRequest.getConstraints().isQueryPassThrough()) {
Map<String, String> qptArguments = recordsRequest.getConstraints().getQueryPassthroughArguments();
queryPassthrough.verify(qptArguments);
labelName = qptArguments.get(NeptuneQueryPassthrough.COLLECTION);
gremlinQuery = qptArguments.get(NeptuneQueryPassthrough.QUERY);

Object object = getResponseFromGremlinQuery(graphTraversalSource, gremlinQuery);
graphTraversal = (GraphTraversal) object;
}
else {
labelName = recordsRequest.getTableName().getTableName();
}
GeneratedRowWriter.RowWriterBuilder builder = GeneratedRowWriter.newBuilder(recordsRequest.getConstraints());
String type = recordsRequest.getSchema().getCustomMetadata().get(Constants.SCHEMA_COMPONENT_TYPE);
String glabel = recordsRequest.getSchema().getCustomMetadata().get(Constants.SCHEMA_GLABEL);
@@ -121,8 +141,10 @@
if (tableSchemaMetaType != null) {
switch (tableSchemaMetaType) {
case VERTEX:
graphTraversal = graphTraversalSource.V().hasLabel(labelName);
graphTraversal = graphTraversal.valueMap().with(WithOptions.tokens);
if (!recordsRequest.getConstraints().isQueryPassThrough()) {
graphTraversal = graphTraversalSource.V().hasLabel(labelName);
graphTraversal = graphTraversal.valueMap().with(WithOptions.tokens);
}

for (final Field nextField : recordsRequest.getSchema().getFields()) {
VertexRowWriter.writeRowTemplate(builder, nextField, configOptions);
@@ -133,8 +155,10 @@
break;

case EDGE:
graphTraversal = graphTraversalSource.E().hasLabel(labelName);
graphTraversal = graphTraversal.elementMap();
if (!recordsRequest.getConstraints().isQueryPassThrough()) {
graphTraversal = graphTraversalSource.E().hasLabel(labelName);
graphTraversal = graphTraversal.elementMap();
}

for (final Field nextField : recordsRequest.getSchema().getFields()) {
EdgeRowWriter.writeRowTemplate(builder, nextField, configOptions);
@@ -144,8 +168,15 @@

break;

case VIEW:
String query = recordsRequest.getSchema().getCustomMetadata().get(Constants.SCHEMA_QUERY);
case VIEW:
String query;
if (recordsRequest.getConstraints().isQueryPassThrough()) {
query = gremlinQuery;
}
else {
query = recordsRequest.getSchema().getCustomMetadata().get(Constants.SCHEMA_QUERY);
}

Iterator<Result> resultIterator = client.submit(query).iterator();

for (final Field nextField : recordsRequest.getSchema().getFields()) {
@@ -166,6 +197,14 @@
}
}

public Object getResponseFromGremlinQuery(GraphTraversalSource graphTraversalSource, String gremlinQuery) throws ScriptException
{
ScriptEngine engine = new GremlinGroovyScriptEngine();
Bindings bindings = engine.createBindings();
bindings.put("g", graphTraversalSource);
return engine.eval(gremlinQuery, bindings);
}
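getResponseFromGremlinQuery evaluates the passthrough string with TinkerPop's Groovy script engine, binding the live traversal source to the variable g — the same convention the Gremlin console uses, so user queries can be written exactly as they would be there. A standalone sketch; the query text and the iteration are illustrative:

    ScriptEngine engine = new GremlinGroovyScriptEngine();
    Bindings bindings = engine.createBindings();
    bindings.put("g", graphTraversalSource); // the bound name must match the one used in the query text

    Object result = engine.eval("g.V().hasLabel('airport').valueMap().limit(3)", bindings);
    if (result instanceof GraphTraversal) {
        GraphTraversal traversal = (GraphTraversal) result;
        while (traversal.hasNext()) {
            System.out.println(traversal.next()); // one Map of property values per vertex
        }
    }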

private void parseNodeOrEdge(final QueryStatusChecker queryStatusChecker, final BlockSpiller spiller, long numRows,
GraphTraversal graphTraversal, GeneratedRowWriter.RowWriterBuilder builder)
{
NeptuneQueryPassthrough.java (fifth changed file; presumably the new passthrough signature class — its diff was not rendered on the captured page)
