Skip to content

Commit

Permalink
Enable glue support with docdb (#1804)
Browse files Browse the repository at this point in the history
  • Loading branch information
ejeffrli authored Mar 7, 2024
1 parent c0e1d44 commit ba27e15
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory;
import com.amazonaws.services.athena.AmazonAthena;
import com.amazonaws.services.glue.AWSGlue;
import com.amazonaws.services.glue.model.Database;
import com.amazonaws.services.glue.model.Table;
import com.amazonaws.services.secretsmanager.AWSSecretsManager;
import com.google.common.base.Strings;
Expand All @@ -54,7 +55,9 @@
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

Expand Down Expand Up @@ -90,6 +93,8 @@ public class DocDBMetadataHandler
// A Glue table may carry no parameters at all; treat a missing parameter map as "flag absent" instead of
// throwing a NullPointerException from inside the filter (mirrors the null guard used by DB_FILTER below).
private static final TableFilter TABLE_FILTER = (Table table) -> (table.getParameters() != null && table.getParameters().containsKey(DOCDB_METADATA_FLAG));
// The number of documents to scan when attempting to infer schema from a DocDB collection.
private static final int SCHEMA_INFERRENCE_NUM_DOCS = 10;
// Used to filter out Glue databases which lack the docdb-metadata-flag in the URI.
// A database without a location URI is rejected rather than dereferenced.
private static final DatabaseFilter DB_FILTER = (Database database) -> (database.getLocationUri() != null && database.getLocationUri().contains(DOCDB_METADATA_FLAG));

private final AWSGlue glue;
private final DocDBConnectionFactory connectionFactory;
Expand Down Expand Up @@ -144,8 +149,18 @@ private String getConnStr(MetadataRequest request)
* @see GlueMetadataHandler
*/
@Override
public ListSchemasResponse doListSchemaNames(BlockAllocator blockAllocator, ListSchemasRequest request)
public ListSchemasResponse doListSchemaNames(BlockAllocator blockAllocator, ListSchemasRequest request) throws Exception
{
Set<String> combinedSchemas = new LinkedHashSet<>();
if (glue != null) {
try {
combinedSchemas.addAll(super.doListSchemaNames(blockAllocator, request, DB_FILTER).getSchemas());
}
catch (RuntimeException e) {
logger.warn("doListSchemaNames: Unable to retrieve schemas from AWSGlue.", e);
}
}

List<String> schemas = new ArrayList<>();
MongoClient client = getOrCreateConn(request);
try (MongoCursor<String> itr = client.listDatabaseNames().iterator()) {
Expand All @@ -156,9 +171,9 @@ public ListSchemasResponse doListSchemaNames(BlockAllocator blockAllocator, List
schemas.add(schema);
}
}

return new ListSchemasResponse(request.getCatalogName(), schemas);
combinedSchemas.addAll(schemas);
}
return new ListSchemasResponse(request.getCatalogName(), combinedSchemas);
}

/**
Expand All @@ -168,11 +183,23 @@ public ListSchemasResponse doListSchemaNames(BlockAllocator blockAllocator, List
* @see GlueMetadataHandler
*/
@Override
public ListTablesResponse doListTables(BlockAllocator blockAllocator, ListTablesRequest request)
public ListTablesResponse doListTables(BlockAllocator blockAllocator, ListTablesRequest request) throws Exception
{
logger.info("Getting tables in {}", request.getSchemaName());
Set<TableName> combinedTables = new LinkedHashSet<>();
String token = request.getNextToken();
if (token == null && glue != null) {
try {
combinedTables.addAll(super.doListTables(blockAllocator, new ListTablesRequest(request.getIdentity(), request.getQueryId(),
request.getCatalogName(), request.getSchemaName(), null, UNLIMITED_PAGE_SIZE_VALUE), TABLE_FILTER).getTables());
}
catch (RuntimeException e) {
logger.warn("doListTables: Unable to retrieve tables from AWSGlue in database/schema {}", request.getSchemaName(), e);
}
}

MongoClient client = getOrCreateConn(request);
Stream<String> tableNames = doListTablesWithCommand(client, request);

int startToken = request.getNextToken() != null ? Integer.parseInt(request.getNextToken()) : 0;
int pageSize = request.getPageSize();
String nextToken = null;
Expand All @@ -184,8 +211,9 @@ public ListTablesResponse doListTables(BlockAllocator blockAllocator, ListTables
}

List<TableName> paginatedTables = tableNames.map(tableName -> new TableName(request.getSchemaName(), tableName)).collect(Collectors.toList());
combinedTables.addAll(paginatedTables);
logger.info("doListTables returned {} tables. Next token is {}", paginatedTables.size(), nextToken);
return new ListTablesResponse(request.getCatalogName(), paginatedTables, nextToken);
return new ListTablesResponse(request.getCatalogName(), new ArrayList<>(combinedTables), nextToken);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ public void tearDown()
}

@Test
public void doListSchemaNames()
public void doListSchemaNames() throws Exception
{
List<String> schemaNames = new ArrayList<>();
schemaNames.add("schema1");
Expand All @@ -146,7 +146,7 @@ public void doListSchemaNames()
}

@Test
public void doListTables()
public void doListTables() throws Exception
{
List<String> tableNames = new ArrayList<>();
tableNames.add("table1");
Expand Down

0 comments on commit ba27e15

Please sign in to comment.