diff --git a/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/NeptuneMetadataHandler.java b/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/NeptuneMetadataHandler.java
index f4cf8a8a72..d4a925a691 100644
--- a/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/NeptuneMetadataHandler.java
+++ b/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/NeptuneMetadataHandler.java
@@ -43,6 +43,7 @@
 import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory;
 import com.amazonaws.athena.connectors.neptune.propertygraph.PropertyGraphHandler;
 import com.amazonaws.athena.connectors.neptune.qpt.NeptuneQueryPassthrough;
+import com.amazonaws.athena.connectors.neptune.rdf.NeptuneSparqlConnection;
 import com.amazonaws.services.athena.AmazonAthena;
 import com.amazonaws.services.glue.AWSGlue;
 import com.amazonaws.services.glue.model.GetTablesRequest;
@@ -281,37 +282,69 @@ public GetTableResponse doGetQueryPassthroughSchema(BlockAllocator allocator, Ge
         GetTableResponse getTableResponse = doGetTable(allocator, request);
         List<Field> fields = getTableResponse.getSchema().getFields();
-        Client client = neptuneConnection.getNeptuneClientConnection();
-        GraphTraversalSource graphTraversalSource = neptuneConnection.getTraversalSource(client);
-        String gremlinQuery = qptArguments.get(NeptuneQueryPassthrough.QUERY);
-        gremlinQuery = gremlinQuery.concat(".limit(1)");
-        logger.info("NeptuneMetadataHandler doGetQueryPassthroughSchema gremlinQuery with limit: " + gremlinQuery);
-        Object object = new PropertyGraphHandler(neptuneConnection).getResponseFromGremlinQuery(graphTraversalSource, gremlinQuery);
-        GraphTraversal graphTraversalForSchema = (GraphTraversal) object;
-        Map graphTraversalObj = null;
         Schema schema;
-        while (graphTraversalForSchema.hasNext()) {
-            graphTraversalObj = (Map) graphTraversalForSchema.next();
+        Enums.GraphType graphType = Enums.GraphType.PROPERTYGRAPH;
+
+        if (configOptions.get(Constants.CFG_GRAPH_TYPE) != null) {
+            graphType = Enums.GraphType.valueOf(configOptions.get(Constants.CFG_GRAPH_TYPE).toUpperCase());
         }
-        //In case of gremlin query is fetching all columns then we can use same schema from glue
-        //otherwise we will build schema from gremlin query result column names.
-        if (graphTraversalObj != null && graphTraversalObj.size() == fields.size()) {
+        switch (graphType) {
+            case PROPERTYGRAPH:
+                Client client = neptuneConnection.getNeptuneClientConnection();
+                GraphTraversalSource graphTraversalSource = neptuneConnection.getTraversalSource(client);
+                String gremlinQuery = qptArguments.get(NeptuneQueryPassthrough.QUERY);
+                gremlinQuery = gremlinQuery.concat(".limit(1)");
+                logger.info("NeptuneMetadataHandler doGetQueryPassthroughSchema gremlinQuery with limit: " + gremlinQuery);
+                Object object = new PropertyGraphHandler(neptuneConnection).getResponseFromGremlinQuery(graphTraversalSource, gremlinQuery);
+                GraphTraversal graphTraversalForSchema = (GraphTraversal) object;
+                Map graphTraversalObj = null;
+                if (graphTraversalForSchema.hasNext()) {
+                    graphTraversalObj = (Map) graphTraversalForSchema.next();
+                }
+                schema = getSchemaFromResults(getTableResponse, graphTraversalObj, fields);
+                return new GetTableResponse(request.getCatalogName(), tableNameObj, schema);
+
+            case RDF:
+                String sparqlQuery = qptArguments.get(NeptuneQueryPassthrough.QUERY);
+                sparqlQuery = sparqlQuery.contains("limit") ?
+                        sparqlQuery : sparqlQuery.concat("\nlimit 1");
+                logger.info("NeptuneMetadataHandler doGetQueryPassthroughSchema sparql query with limit: " + sparqlQuery);
+                NeptuneSparqlConnection neptuneSparqlConnection = (NeptuneSparqlConnection) neptuneConnection;
+                neptuneSparqlConnection.runQuery(sparqlQuery);
+                String strim = getTableResponse.getSchema().getCustomMetadata().get(Constants.SCHEMA_STRIP_URI);
+                boolean trimURI = strim == null ? false : Boolean.parseBoolean(strim);
+                Map resultsMap = null;
+                if (neptuneSparqlConnection.hasNext()) {
+                    resultsMap = neptuneSparqlConnection.next(trimURI);
+                }
+                schema = getSchemaFromResults(getTableResponse, resultsMap, fields);
+                return new GetTableResponse(request.getCatalogName(), tableNameObj, schema);
+
+            default:
+                throw new IllegalArgumentException("Unsupported graphType: " + graphType);
+        }
+    }
+
+    private Schema getSchemaFromResults(GetTableResponse getTableResponse, Map resultsMap, List<Field> fields)
+    {
+        Schema schema;
+        // If the gremlin/sparql query fetches all columns, we can use the same schema from Glue;
+        // otherwise we build the schema from the gremlin/sparql query's result column names.
+        if (resultsMap != null && resultsMap.size() == fields.size()) {
             schema = getTableResponse.getSchema();
         }
         else {
             SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder();
-            //Building schema from gremlin query results and list of fields from glue response.
+            //Building schema from gremlin/sparql query results and list of fields from glue response.
             //It's require only when we are selecting limited columns.
-            graphTraversalObj.forEach((columnName, columnValue) -> buildSchema(columnName.toString(), fields, schemaBuilder));
+            resultsMap.forEach((columnName, columnValue) -> buildSchema(columnName.toString(), fields, schemaBuilder));
             Map<String, String> metaData = getTableResponse.getSchema().getCustomMetadata();
             for (Map.Entry<String, String> map : metaData.entrySet()) {
                 schemaBuilder.addMetadata(map.getKey(), map.getValue());
             }
             schema = schemaBuilder.build();
         }
-
-        return new GetTableResponse(request.getCatalogName(), tableNameObj, schema);
+        return schema;
     }
 
     private void buildSchema(String columnName, List<Field> fields, SchemaBuilder schemaBuilder)
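The change above turns doGetQueryPassthroughSchema into a dispatch on the configured graph type, but both arms share one technique: run the user's passthrough query with a row limit of 1 and infer the schema from the keys of the first result row. The following standalone sketch is illustrative only and is not part of the patch; QueryRunner and probeColumns are hypothetical names introduced for this note.

    import java.util.Collections;
    import java.util.Map;
    import java.util.Set;

    public final class SchemaProbeSketch
    {
        // Hypothetical abstraction over query execution against Neptune; not connector API.
        interface QueryRunner
        {
            Map<String, Object> firstRow(String query); // first result row, or null if empty
        }

        static Set<String> probeColumns(QueryRunner runner, String query, boolean isSparql)
        {
            // Mirrors the patch's probe: SPARQL gets "limit 1" appended unless a limit is
            // already present; Gremlin traversals get ".limit(1)" appended.
            String probe = isSparql
                    ? (query.contains("limit") ? query : query + "\nlimit 1")
                    : query + ".limit(1)";
            Map<String, Object> row = runner.firstRow(probe);
            return row == null ? Collections.emptySet() : row.keySet();
        }
    }

If the probed row has as many columns as the Glue table, the Glue schema is reused unchanged; otherwise getSchemaFromResults builds a narrower schema from just the returned column names.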
diff --git a/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/rdf/RDFHandler.java b/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/rdf/RDFHandler.java
index faa89fc648..0491c9bb51 100644
--- a/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/rdf/RDFHandler.java
+++ b/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/rdf/RDFHandler.java
@@ -26,6 +26,7 @@
 import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest;
 import com.amazonaws.athena.connectors.neptune.Constants;
 import com.amazonaws.athena.connectors.neptune.NeptuneConnection;
+import com.amazonaws.athena.connectors.neptune.qpt.NeptuneQueryPassthrough;
 import org.apache.arrow.vector.types.pojo.Field;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -58,6 +59,7 @@ public class RDFHandler
      */
     private final NeptuneSparqlConnection neptuneConnection;
 
+    private final NeptuneQueryPassthrough queryPassthrough = new NeptuneQueryPassthrough();
 
     // @VisibleForTesting
     public RDFHandler(NeptuneConnection neptuneConnection) throws Exception
@@ -97,63 +99,71 @@ public RDFHandler(NeptuneConnection neptuneConnection) throws Exception
     public void executeQuery(ReadRecordsRequest recordsRequest, final QueryStatusChecker queryStatusChecker, final BlockSpiller spiller, java.util.Map<String, String> configOptions)
             throws Exception
     {
-        // 1. Get the specified prefixes
-        Map<String, String> prefixMap = new HashMap<>();
-        StringBuilder prefixBlock = new StringBuilder("");
-        for (String k : recordsRequest.getSchema().getCustomMetadata().keySet()) {
-            if (k.startsWith(Constants.PREFIX_KEY)) {
-                String pfx = k.substring(Constants.PREFIX_LEN);
-                String val = recordsRequest.getSchema().getCustomMetadata().get(k);
-                prefixMap.put(pfx, val);
-                prefixBlock.append("PREFIX " + pfx + ": <" + val + ">\n");
-            }
-        }
-
-        // 2. Build the SPARQL query
-        String queryMode = recordsRequest.getSchema().getCustomMetadata().get(Constants.SCHEMA_QUERY_MODE);
-        if (queryMode == null) {
-            throw new RuntimeException("Mandatory: querymode");
+        StringBuilder sparql;
+        if (recordsRequest.getConstraints().isQueryPassThrough()) {
+            Map<String, String> qptArguments = recordsRequest.getConstraints().getQueryPassthroughArguments();
+            queryPassthrough.verify(qptArguments);
+            sparql = new StringBuilder(qptArguments.get(NeptuneQueryPassthrough.QUERY));
         }
-        queryMode = queryMode.toLowerCase();
-        StringBuilder sparql = new StringBuilder(prefixBlock + "\n");
-        if (queryMode.equals(Constants.QUERY_MODE_SPARQL)) {
-            String sparqlParam = recordsRequest.getSchema().getCustomMetadata().get(Constants.QUERY_MODE_SPARQL);
-            if (sparqlParam == null) {
-                throw new RuntimeException("Mandatory: sparql when querympde=sparql");
-            }
-            sparql.append("\n" + sparqlParam);
-        }
-        else if (queryMode.equals(Constants.QUERY_MODE_CLASS)) {
-            String classURI = recordsRequest.getSchema().getCustomMetadata().get(Constants.SCHEMA_CLASS_URI);
-            String predsPrefix = recordsRequest.getSchema().getCustomMetadata().get(Constants.SCHEMA_PREDS_PREFIX);
-            String subject = recordsRequest.getSchema().getCustomMetadata().get(Constants.SCHEMA_SUBJECT);
-            if (classURI == null) {
-                throw new RuntimeException("Mandatory: classuri when querymode=class");
-            }
-            if (predsPrefix == null) {
-                throw new RuntimeException("Mandatory: predsPrefix when querymode=class");
+        else {
+            // 1. Get the specified prefixes
+            Map<String, String> prefixMap = new HashMap<>();
+            StringBuilder prefixBlock = new StringBuilder("");
+            for (String k : recordsRequest.getSchema().getCustomMetadata().keySet()) {
+                if (k.startsWith(Constants.PREFIX_KEY)) {
+                    String pfx = k.substring(Constants.PREFIX_LEN);
+                    String val = recordsRequest.getSchema().getCustomMetadata().get(k);
+                    prefixMap.put(pfx, val);
+                    prefixBlock.append("PREFIX " + pfx + ": <" + val + ">\n");
+                }
             }
-            if (subject == null) {
-                throw new RuntimeException("Mandatory:subject when querymode=class");
+
+            // 2. Build the SPARQL query
+            String queryMode = recordsRequest.getSchema().getCustomMetadata().get(Constants.SCHEMA_QUERY_MODE);
+            if (queryMode == null) {
+                throw new RuntimeException("Mandatory: querymode");
             }
-            sparql.append("\nselect ");
-            for (Field prop : recordsRequest.getSchema().getFields()) {
-                sparql.append("?" + prop.getName() + " ");
+            queryMode = queryMode.toLowerCase();
+            sparql = new StringBuilder(prefixBlock + "\n");
+            if (queryMode.equals(Constants.QUERY_MODE_SPARQL)) {
+                String sparqlParam = recordsRequest.getSchema().getCustomMetadata().get(Constants.QUERY_MODE_SPARQL);
+                if (sparqlParam == null) {
+                    throw new RuntimeException("Mandatory: sparql when querymode=sparql");
+                }
+                sparql.append("\n" + sparqlParam);
             }
-            sparql.append(" WHERE {");
-            sparql.append("\n?" + subject + " a " + classURI + " . ");
-            for (Field prop : recordsRequest.getSchema().getFields()) {
-                if (!prop.getName().equals(subject)) {
-                    sparql.append("\n?" + subject + " " + predsPrefix + ":" + prop.getName() + " ?" + prop.getName()
-                            + " .");
+            else if (queryMode.equals(Constants.QUERY_MODE_CLASS)) {
+                String classURI = recordsRequest.getSchema().getCustomMetadata().get(Constants.SCHEMA_CLASS_URI);
+                String predsPrefix = recordsRequest.getSchema().getCustomMetadata().get(Constants.SCHEMA_PREDS_PREFIX);
+                String subject = recordsRequest.getSchema().getCustomMetadata().get(Constants.SCHEMA_SUBJECT);
+                if (classURI == null) {
+                    throw new RuntimeException("Mandatory: classuri when querymode=class");
                 }
+                if (predsPrefix == null) {
+                    throw new RuntimeException("Mandatory: predsPrefix when querymode=class");
+                }
+                if (subject == null) {
+                    throw new RuntimeException("Mandatory: subject when querymode=class");
+                }
+                sparql.append("\nselect ");
+                for (Field prop : recordsRequest.getSchema().getFields()) {
+                    sparql.append("?" + prop.getName() + " ");
+                }
+                sparql.append(" WHERE {");
+                sparql.append("\n?" + subject + " a " + classURI + " . ");
+                for (Field prop : recordsRequest.getSchema().getFields()) {
+                    if (!prop.getName().equals(subject)) {
+                        sparql.append("\n?" + subject + " " + predsPrefix + ":" + prop.getName() + " ?" + prop.getName()
+                                + " .");
+                    }
+                }
+                sparql.append(" }");
+            }
+            else {
+                throw new RuntimeException("Illegal RDF params");
             }
-            sparql.append(" }");
-        }
-        else {
-            throw new RuntimeException("Illegal RDF params");
-        }
-
+        }
+
+        // 3. Create the builder and add row writer extractors for each field
         GeneratedRowWriter.RowWriterBuilder builder = GeneratedRowWriter
                 .newBuilder(recordsRequest.getConstraints());
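For the non-passthrough path, the class-mode branch above assembles a SPARQL select over the table's fields. The following self-contained sketch shows that string assembly with a made-up ex: prefix and made-up field names (person, name, age); none of these values come from the patch.

    public final class ClassModeSparqlSketch
    {
        public static void main(String[] args)
        {
            // Hypothetical table metadata: subject variable, class URI, and predicate prefix.
            String subject = "person";
            String classURI = "ex:Person";
            String predsPrefix = "ex";
            String[] fields = {"person", "name", "age"};

            StringBuilder sparql = new StringBuilder("PREFIX ex: <http://example.org/>\n");
            sparql.append("\nselect ");
            for (String f : fields) {
                sparql.append("?").append(f).append(" ");
            }
            sparql.append(" WHERE {");
            sparql.append("\n?").append(subject).append(" a ").append(classURI).append(" . ");
            for (String f : fields) {
                if (!f.equals(subject)) {
                    // One triple pattern per non-subject column: ?person ex:name ?name .
                    sparql.append("\n?").append(subject).append(" ").append(predsPrefix)
                            .append(":").append(f).append(" ?").append(f).append(" .");
                }
            }
            sparql.append(" }");
            System.out.println(sparql);
        }
    }

Running it prints, roughly: select ?person ?name ?age WHERE { ?person a ex:Person . ?person ex:name ?name . ?person ex:age ?age . } — the query shape executeQuery sends to Neptune when querymode=class. In the passthrough path, by contrast, the user's SPARQL is used verbatim (apart from verification), so any prefixes must already be inlined in the submitted query.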