Extended QPT to Neptune RDF (#1933)
VenkatasivareddyTR authored Apr 26, 2024
1 parent 44c1398 commit 2296ee8
Showing 2 changed files with 110 additions and 67 deletions.
NeptuneMetadataHandler.java

@@ -43,6 +43,7 @@
 import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory;
 import com.amazonaws.athena.connectors.neptune.propertygraph.PropertyGraphHandler;
 import com.amazonaws.athena.connectors.neptune.qpt.NeptuneQueryPassthrough;
+import com.amazonaws.athena.connectors.neptune.rdf.NeptuneSparqlConnection;
 import com.amazonaws.services.athena.AmazonAthena;
 import com.amazonaws.services.glue.AWSGlue;
 import com.amazonaws.services.glue.model.GetTablesRequest;
@@ -281,37 +282,69 @@ public GetTableResponse doGetQueryPassthroughSchema(BlockAllocator allocator, GetTableRequest request)
 
         GetTableResponse getTableResponse = doGetTable(allocator, request);
         List<Field> fields = getTableResponse.getSchema().getFields();
-        Client client = neptuneConnection.getNeptuneClientConnection();
-        GraphTraversalSource graphTraversalSource = neptuneConnection.getTraversalSource(client);
-        String gremlinQuery = qptArguments.get(NeptuneQueryPassthrough.QUERY);
-        gremlinQuery = gremlinQuery.concat(".limit(1)");
-        logger.info("NeptuneMetadataHandler doGetQueryPassthroughSchema gremlinQuery with limit: " + gremlinQuery);
-        Object object = new PropertyGraphHandler(neptuneConnection).getResponseFromGremlinQuery(graphTraversalSource, gremlinQuery);
-        GraphTraversal graphTraversalForSchema = (GraphTraversal) object;
-        Map graphTraversalObj = null;
         Schema schema;
-        while (graphTraversalForSchema.hasNext()) {
-            graphTraversalObj = (Map) graphTraversalForSchema.next();
-        }
-
-        //In case of gremlin query is fetching all columns then we can use same schema from glue
-        //otherwise we will build schema from gremlin query result column names.
-        if (graphTraversalObj != null && graphTraversalObj.size() == fields.size()) {
-            schema = getTableResponse.getSchema();
-        }
-        else {
-            SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder();
-            //Building schema from gremlin query results and list of fields from glue response.
-            //It's require only when we are selecting limited columns.
-            graphTraversalObj.forEach((columnName, columnValue) -> buildSchema(columnName.toString(), fields, schemaBuilder));
-            Map<String, String> metaData = getTableResponse.getSchema().getCustomMetadata();
-            for (Map.Entry<String, String> map : metaData.entrySet()) {
-                schemaBuilder.addMetadata(map.getKey(), map.getValue());
-            }
-            schema = schemaBuilder.build();
-        }
-
-        return new GetTableResponse(request.getCatalogName(), tableNameObj, schema);
+        Enums.GraphType graphType = Enums.GraphType.PROPERTYGRAPH;
+
+        if (configOptions.get(Constants.CFG_GRAPH_TYPE) != null) {
+            graphType = Enums.GraphType.valueOf(configOptions.get(Constants.CFG_GRAPH_TYPE).toUpperCase());
+        }
+
+        switch (graphType) {
+            case PROPERTYGRAPH:
+                Client client = neptuneConnection.getNeptuneClientConnection();
+                GraphTraversalSource graphTraversalSource = neptuneConnection.getTraversalSource(client);
+                String gremlinQuery = qptArguments.get(NeptuneQueryPassthrough.QUERY);
+                gremlinQuery = gremlinQuery.concat(".limit(1)");
+                logger.info("NeptuneMetadataHandler doGetQueryPassthroughSchema gremlinQuery with limit: " + gremlinQuery);
+                Object object = new PropertyGraphHandler(neptuneConnection).getResponseFromGremlinQuery(graphTraversalSource, gremlinQuery);
+                GraphTraversal graphTraversalForSchema = (GraphTraversal) object;
+                Map graphTraversalObj = null;
+                if (graphTraversalForSchema.hasNext()) {
+                    graphTraversalObj = (Map) graphTraversalForSchema.next();
+                }
+                schema = getSchemaFromResults(getTableResponse, graphTraversalObj, fields);
+                return new GetTableResponse(request.getCatalogName(), tableNameObj, schema);
+
+            case RDF:
+                String sparqlQuery = qptArguments.get(NeptuneQueryPassthrough.QUERY);
+                sparqlQuery = sparqlQuery.contains("limit") ? sparqlQuery : sparqlQuery.concat("\nlimit 1");
+                logger.info("NeptuneMetadataHandler doGetQueryPassthroughSchema sparql query with limit: " + sparqlQuery);
+                NeptuneSparqlConnection neptuneSparqlConnection = (NeptuneSparqlConnection) neptuneConnection;
+                neptuneSparqlConnection.runQuery(sparqlQuery);
+                String strim = getTableResponse.getSchema().getCustomMetadata().get(Constants.SCHEMA_STRIP_URI);
+                boolean trimURI = strim == null ? false : Boolean.parseBoolean(strim);
+                Map<String, Object> resultsMap = null;
+                if (neptuneSparqlConnection.hasNext()) {
+                    resultsMap = neptuneSparqlConnection.next(trimURI);
+                }
+                schema = getSchemaFromResults(getTableResponse, resultsMap, fields);
+                return new GetTableResponse(request.getCatalogName(), tableNameObj, schema);
+
+            default:
+                throw new IllegalArgumentException("Unsupported graphType: " + graphType);
+        }
+    }
+
+    private Schema getSchemaFromResults(GetTableResponse getTableResponse, Map resultsMap, List<Field> fields)
+    {
+        Schema schema;
+        //If the gremlin/sparql query fetches all columns then we can use the same schema from glue,
+        //otherwise we will build the schema from the gremlin/sparql query result column names.
+        if (resultsMap != null && resultsMap.size() == fields.size()) {
+            schema = getTableResponse.getSchema();
+        }
+        else {
+            SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder();
+            //Building schema from gremlin/sparql query results and the list of fields from the glue response.
+            //It's required only when we are selecting limited columns.
+            resultsMap.forEach((columnName, columnValue) -> buildSchema(columnName.toString(), fields, schemaBuilder));
+            Map<String, String> metaData = getTableResponse.getSchema().getCustomMetadata();
+            for (Map.Entry<String, String> map : metaData.entrySet()) {
+                schemaBuilder.addMetadata(map.getKey(), map.getValue());
+            }
+            schema = schemaBuilder.build();
+        }
+
+        return schema;
     }
 
     private void buildSchema(String columnName, List<Field> fields, SchemaBuilder schemaBuilder)
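Both engines now follow the same schema-inference pattern: cap the passthrough query at one result (`.limit(1)` for Gremlin, a trailing `limit 1` for SPARQL), fetch that single row, and let the shared `getSchemaFromResults` helper decide whether the Glue schema still applies. A minimal sketch of that decision rule, with plain collections standing in for the SDK's `Schema`/`SchemaBuilder` types (the class and method names below are illustrative, not part of the connector):

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class SchemaFallbackSketch
{
    // Mirrors getSchemaFromResults: if the sampled row covers every Glue column,
    // reuse the Glue schema as-is; otherwise rebuild the schema from the column
    // names the passthrough query actually returned.
    static List<String> inferColumns(Map<String, Object> sampleRow, List<String> glueColumns)
    {
        if (sampleRow != null && sampleRow.size() == glueColumns.size()) {
            return glueColumns;
        }
        return new ArrayList<>(sampleRow.keySet());
    }

    public static void main(String[] args)
    {
        List<String> glue = List.of("id", "code", "city");
        System.out.println(inferColumns(Map.of("id", 1, "code", "SEA", "city", "Seattle"), glue)); // [id, code, city]
        System.out.println(inferColumns(Map.of("code", "SEA"), glue)); // [code]
    }
}
```

Like the connector code, this sketch assumes the sample row is non-null whenever the column counts differ; a passthrough query that returns zero rows would need an extra guard before this point.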

RDFHandler.java

@@ -26,6 +26,7 @@
 import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest;
 import com.amazonaws.athena.connectors.neptune.Constants;
 import com.amazonaws.athena.connectors.neptune.NeptuneConnection;
+import com.amazonaws.athena.connectors.neptune.qpt.NeptuneQueryPassthrough;
 import org.apache.arrow.vector.types.pojo.Field;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -58,6 +59,7 @@ public class RDFHandler
      */
 
     private final NeptuneSparqlConnection neptuneConnection;
+    private final NeptuneQueryPassthrough queryPassthrough = new NeptuneQueryPassthrough();
 
     // @VisibleForTesting
     public RDFHandler(NeptuneConnection neptuneConnection) throws Exception
@@ -97,63 +99,71 @@ public RDFHandler(NeptuneConnection neptuneConnection) throws Exception
     public void executeQuery(ReadRecordsRequest recordsRequest, final QueryStatusChecker queryStatusChecker,
             final BlockSpiller spiller, java.util.Map<String, String> configOptions) throws Exception
     {
-        // 1. Get the specified prefixes
-        Map<String, String> prefixMap = new HashMap<>();
-        StringBuilder prefixBlock = new StringBuilder("");
-        for (String k : recordsRequest.getSchema().getCustomMetadata().keySet()) {
-            if (k.startsWith(Constants.PREFIX_KEY)) {
-                String pfx = k.substring(Constants.PREFIX_LEN);
-                String val = recordsRequest.getSchema().getCustomMetadata().get(k);
-                prefixMap.put(pfx, val);
-                prefixBlock.append("PREFIX " + pfx + ": <" + val + ">\n");
-            }
-        }
-
-        // 2. Build the SPARQL query
-        String queryMode = recordsRequest.getSchema().getCustomMetadata().get(Constants.SCHEMA_QUERY_MODE);
-        if (queryMode == null) {
-            throw new RuntimeException("Mandatory: querymode");
-        }
-        queryMode = queryMode.toLowerCase();
-        StringBuilder sparql = new StringBuilder(prefixBlock + "\n");
-        if (queryMode.equals(Constants.QUERY_MODE_SPARQL)) {
-            String sparqlParam = recordsRequest.getSchema().getCustomMetadata().get(Constants.QUERY_MODE_SPARQL);
-            if (sparqlParam == null) {
-                throw new RuntimeException("Mandatory: sparql when querympde=sparql");
-            }
-            sparql.append("\n" + sparqlParam);
-        }
-        else if (queryMode.equals(Constants.QUERY_MODE_CLASS)) {
-            String classURI = recordsRequest.getSchema().getCustomMetadata().get(Constants.SCHEMA_CLASS_URI);
-            String predsPrefix = recordsRequest.getSchema().getCustomMetadata().get(Constants.SCHEMA_PREDS_PREFIX);
-            String subject = recordsRequest.getSchema().getCustomMetadata().get(Constants.SCHEMA_SUBJECT);
-            if (classURI == null) {
-                throw new RuntimeException("Mandatory: classuri when querymode=class");
-            }
-            if (predsPrefix == null) {
-                throw new RuntimeException("Mandatory: predsPrefix when querymode=class");
-            }
-            if (subject == null) {
-                throw new RuntimeException("Mandatory:subject when querymode=class");
-            }
-            sparql.append("\nselect ");
-            for (Field prop : recordsRequest.getSchema().getFields()) {
-                sparql.append("?" + prop.getName() + " ");
-            }
-            sparql.append(" WHERE {");
-            sparql.append("\n?" + subject + " a " + classURI + " . ");
-            for (Field prop : recordsRequest.getSchema().getFields()) {
-                if (!prop.getName().equals(subject)) {
-                    sparql.append("\n?" + subject + " " + predsPrefix + ":" + prop.getName() + " ?" + prop.getName()
-                            + " .");
-                }
-            }
-            sparql.append(" }");
-        }
-        else {
-            throw new RuntimeException("Illegal RDF params");
-        }
+        StringBuilder sparql;
+        if (recordsRequest.getConstraints().isQueryPassThrough()) {
+            Map<String, String> qptArguments = recordsRequest.getConstraints().getQueryPassthroughArguments();
+            queryPassthrough.verify(qptArguments);
+            sparql = new StringBuilder(qptArguments.get(NeptuneQueryPassthrough.QUERY));
+        }
+        else {
+            // 1. Get the specified prefixes
+            Map<String, String> prefixMap = new HashMap<>();
+            StringBuilder prefixBlock = new StringBuilder("");
+            for (String k : recordsRequest.getSchema().getCustomMetadata().keySet()) {
+                if (k.startsWith(Constants.PREFIX_KEY)) {
+                    String pfx = k.substring(Constants.PREFIX_LEN);
+                    String val = recordsRequest.getSchema().getCustomMetadata().get(k);
+                    prefixMap.put(pfx, val);
+                    prefixBlock.append("PREFIX " + pfx + ": <" + val + ">\n");
+                }
+            }
+
+            // 2. Build the SPARQL query
+            String queryMode = recordsRequest.getSchema().getCustomMetadata().get(Constants.SCHEMA_QUERY_MODE);
+            if (queryMode == null) {
+                throw new RuntimeException("Mandatory: querymode");
+            }
+            queryMode = queryMode.toLowerCase();
+            sparql = new StringBuilder(prefixBlock + "\n");
+            if (queryMode.equals(Constants.QUERY_MODE_SPARQL)) {
+                String sparqlParam = recordsRequest.getSchema().getCustomMetadata().get(Constants.QUERY_MODE_SPARQL);
+                if (sparqlParam == null) {
+                    throw new RuntimeException("Mandatory: sparql when querymode=sparql");
+                }
+                sparql.append("\n" + sparqlParam);
+            }
+            else if (queryMode.equals(Constants.QUERY_MODE_CLASS)) {
+                String classURI = recordsRequest.getSchema().getCustomMetadata().get(Constants.SCHEMA_CLASS_URI);
+                String predsPrefix = recordsRequest.getSchema().getCustomMetadata().get(Constants.SCHEMA_PREDS_PREFIX);
+                String subject = recordsRequest.getSchema().getCustomMetadata().get(Constants.SCHEMA_SUBJECT);
+                if (classURI == null) {
+                    throw new RuntimeException("Mandatory: classuri when querymode=class");
+                }
+                if (predsPrefix == null) {
+                    throw new RuntimeException("Mandatory: predsPrefix when querymode=class");
+                }
+                if (subject == null) {
+                    throw new RuntimeException("Mandatory: subject when querymode=class");
+                }
+                sparql.append("\nselect ");
+                for (Field prop : recordsRequest.getSchema().getFields()) {
+                    sparql.append("?" + prop.getName() + " ");
+                }
+                sparql.append(" WHERE {");
+                sparql.append("\n?" + subject + " a " + classURI + " . ");
+                for (Field prop : recordsRequest.getSchema().getFields()) {
+                    if (!prop.getName().equals(subject)) {
+                        sparql.append("\n?" + subject + " " + predsPrefix + ":" + prop.getName() + " ?" + prop.getName()
+                                + " .");
+                    }
+                }
+                sparql.append(" }");
+            }
+            else {
+                throw new RuntimeException("Illegal RDF params");
+            }
+        }
 
         // 3. Create the builder and add row writer extractors for each field
         GeneratedRowWriter.RowWriterBuilder builder = GeneratedRowWriter
                 .newBuilder(recordsRequest.getConstraints());
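With this change, a passthrough request takes the SPARQL string from the QPT arguments verbatim (after `queryPassthrough.verify`), and the prefix/querymode machinery runs only for regular table scans. For reference, here is a self-contained sketch of the string the `querymode=class` branch assembles, using hypothetical metadata values (a `dev` prefix mapped to `http://example.org/`, class `dev:Airport`, subject `id`) in place of the Glue custom metadata:

```java
import java.util.List;

public class ClassModeSparqlSketch
{
    public static void main(String[] args)
    {
        // Hypothetical stand-ins for the Glue custom metadata and table fields.
        String prefixBlock = "PREFIX dev: <http://example.org/>\n";
        String classURI = "dev:Airport";
        String predsPrefix = "dev";
        String subject = "id";
        List<String> fields = List.of("id", "code", "city");

        // Same assembly order as the querymode=class branch in executeQuery.
        StringBuilder sparql = new StringBuilder(prefixBlock + "\n");
        sparql.append("\nselect ");
        for (String prop : fields) {
            sparql.append("?" + prop + " ");
        }
        sparql.append(" WHERE {");
        sparql.append("\n?" + subject + " a " + classURI + " . ");
        for (String prop : fields) {
            if (!prop.equals(subject)) {
                sparql.append("\n?" + subject + " " + predsPrefix + ":" + prop + " ?" + prop + " .");
            }
        }
        sparql.append(" }");

        // Prints (blank lines included):
        // PREFIX dev: <http://example.org/>
        //
        //
        // select ?id ?code ?city  WHERE {
        // ?id a dev:Airport .
        // ?id dev:code ?code .
        // ?id dev:city ?city . }
        System.out.println(sparql);
    }
}
```

The generated query binds one variable per Glue column and one triple pattern per non-subject column, which is why the passthrough path, where the user writes the SPARQL directly, can skip all of this.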
