diff --git a/athena-vertica/athena-vertica.yaml b/athena-vertica/athena-vertica.yaml index 551c612311..7914cb652e 100644 --- a/athena-vertica/athena-vertica.yaml +++ b/athena-vertica/athena-vertica.yaml @@ -55,7 +55,7 @@ Parameters: VerticaConnectionString: Description: 'The Vertica connection details to use by default if not catalog specific connection is defined and optionally using SecretsManager (e.g. ${secret_name}).' Type: String - Default: "jdbc:vertica://:/?user=${vertica-username}&password=${vertica-password}" + Default: "vertica://jdbc:vertica://:/?user=${vertica-username}&password=${vertica-password}" PermissionsBoundaryARN: Description: "(Optional) An IAM policy ARN to use as the PermissionsBoundary for the created Lambda function's execution role" Default: '' @@ -79,7 +79,7 @@ Resources: spill_bucket: !Ref SpillBucket spill_prefix: !Ref SpillPrefix export_bucket: !Ref VerticaExportBucket - default_vertica: !Ref VerticaConnectionString + default: !Ref VerticaConnectionString FunctionName: !Sub "${AthenaCatalogName}" Handler: "com.amazonaws.athena.connectors.vertica.VerticaCompositeHandler" diff --git a/athena-vertica/pom.xml b/athena-vertica/pom.xml index a29cfad470..b5b086e449 100644 --- a/athena-vertica/pom.xml +++ b/athena-vertica/pom.xml @@ -70,7 +70,6 @@ com.amazonaws athena-jdbc 2022.47.1 - compile diff --git a/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaConnectionFactory.java b/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaConnectionFactory.java deleted file mode 100644 index 5f9532d201..0000000000 --- a/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaConnectionFactory.java +++ /dev/null @@ -1,98 +0,0 @@ -/*- - * #%L - * athena-vertica - * %% - * Copyright (C) 2019 - 2020 Amazon Web Services - * %% - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * #L% - */ -package com.amazonaws.athena.connectors.vertica; - -import org.apache.arrow.util.VisibleForTesting; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.sql.Connection; -import java.sql.DriverManager; -import java.sql.SQLException; -import java.util.HashMap; -import java.util.Map; - -/** - * Creates and Caches JDBC Connection Instances, using the connection string as the cache key. - * - * @Note Connection String format is expected to be like: - * jdbc:vertica://>:5433/VerticaDB?user&password> - */ - -public class VerticaConnectionFactory -{ - private static final Logger logger = LoggerFactory.getLogger(VerticaConnectionFactory.class); - private final Map clientCache = new HashMap<>(); - private final int CONNECTION_TIMEOUT = 60; - - /** - * Used to get an existing, pooled, connection or to create a new connection - * for the given connection string. - * - * @param connStr Vertica connection details - * @return A JDBC connection if the connection succeeded, else the function will throw. - */ - - public synchronized Connection getOrCreateConn(String connStr) { - Connection result = null; - try { - Class.forName("com.vertica.jdbc.Driver").newInstance(); - result = clientCache.get(connStr); - - if (result == null || !connectionTest(result)) { - logger.info("Unable to reuse an existing connection, creating a new one."); - result = DriverManager.getConnection(connStr); - clientCache.put(connStr, result); - } - } - catch (Exception e) - { - throw new RuntimeException("Error in initializing Vertica JDBC Driver: " + e.getMessage(), e); - } - - return result; - } - - /** - * Runs a 'quick' test on the connection and then returns it if it passes. - */ - private boolean connectionTest(Connection conn) throws SQLException { - try { - return conn.isValid(CONNECTION_TIMEOUT); - } - catch (RuntimeException | SQLException ex) { - logger.warn("getOrCreateConn: Exception while testing existing connection.", ex); - } - return false; - } - - /** - * Injects a connection into the client cache. - * - * @param conStr The connection string (aka the cache key) - * @param conn The connection to inject into the client cache, most often a Mock used in testing. - */ - @VisibleForTesting - protected synchronized void addConnection(String conStr, Connection conn) - { - clientCache.put(conStr, conn); - } - -} diff --git a/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaConstants.java b/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaConstants.java new file mode 100644 index 0000000000..dbcf85e8a5 --- /dev/null +++ b/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaConstants.java @@ -0,0 +1,29 @@ +/*- + * #%L + * athena-vertica + * %% + * Copyright (C) 2019 - 2024 Amazon Web Services + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +package com.amazonaws.athena.connectors.vertica; + +public final class VerticaConstants +{ + public static final String VERTICA_NAME = "vertica"; + public static final String VERTICA_DRIVER_CLASS = "com.vertica.jdbc.Driver"; + public static final int VERTICA_DEFAULT_PORT = 5433; + + private VerticaConstants() {} +} diff --git a/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaMetadataHandler.java b/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaMetadataHandler.java index a0156f2819..ee40632659 100644 --- a/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaMetadataHandler.java +++ b/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaMetadataHandler.java @@ -29,7 +29,6 @@ import com.amazonaws.athena.connector.lambda.domain.Split; import com.amazonaws.athena.connector.lambda.domain.TableName; import com.amazonaws.athena.connector.lambda.domain.predicate.Constraints; -import com.amazonaws.athena.connector.lambda.handlers.MetadataHandler; import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesRequest; import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesResponse; import com.amazonaws.athena.connector.lambda.metadata.GetSplitsRequest; @@ -39,20 +38,21 @@ import com.amazonaws.athena.connector.lambda.metadata.GetTableResponse; import com.amazonaws.athena.connector.lambda.metadata.ListSchemasRequest; import com.amazonaws.athena.connector.lambda.metadata.ListSchemasResponse; -import com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest; -import com.amazonaws.athena.connector.lambda.metadata.ListTablesResponse; -import com.amazonaws.athena.connector.lambda.metadata.MetadataRequest; import com.amazonaws.athena.connector.lambda.metadata.optimizations.OptimizationSubType; -import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; +import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; +import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionInfo; +import com.amazonaws.athena.connectors.jdbc.connection.GenericJdbcConnectionFactory; +import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; +import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; +import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; +import com.amazonaws.athena.connectors.jdbc.qpt.JdbcQueryPassthrough; import com.amazonaws.athena.connectors.vertica.query.QueryFactory; import com.amazonaws.athena.connectors.vertica.query.VerticaExportQueryBuilder; -import com.amazonaws.services.athena.AmazonAthena; import com.amazonaws.services.s3.AmazonS3; import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.amazonaws.services.s3.model.ListObjectsRequest; import com.amazonaws.services.s3.model.ObjectListing; import com.amazonaws.services.s3.model.S3ObjectSummary; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.collect.ImmutableMap; import org.apache.arrow.util.VisibleForTesting; import org.apache.arrow.vector.complex.reader.FieldReader; @@ -77,11 +77,14 @@ import java.util.Set; import java.util.UUID; +import static com.amazonaws.athena.connectors.vertica.VerticaConstants.VERTICA_DEFAULT_PORT; +import static com.amazonaws.athena.connectors.vertica.VerticaConstants.VERTICA_DRIVER_CLASS; +import static com.amazonaws.athena.connectors.vertica.VerticaConstants.VERTICA_NAME; import static com.amazonaws.athena.connectors.vertica.VerticaSchemaUtils.convertToArrowType; public class VerticaMetadataHandler - extends MetadataHandler + extends JdbcMetadataHandler { private static final Logger logger = LoggerFactory.getLogger(VerticaMetadataHandler.class); @@ -90,62 +93,39 @@ public class VerticaMetadataHandler * used to aid in debugging. Athena will use this name in conjunction with your catalog id * to correlate relevant query errors. */ - private static final String SOURCE_TYPE = "vertica"; - protected static final String VERTICA_CONN_STR = "conn_str"; - private static final String DEFAULT_VERTICA = "default_vertica"; - private static final String TABLE_NAME = "TABLE_NAME"; - private static final String TABLE_SCHEMA = "TABLE_SCHEM"; - private static final String[] TABLE_TYPES = {"TABLE"}; + static final Map JDBC_PROPERTIES = ImmutableMap.of("databaseTerm", "SCHEMA"); private static final String EXPORT_BUCKET_KEY = "export_bucket"; private static final String EMPTY_STRING = StringUtils.EMPTY; - private final VerticaConnectionFactory connectionFactory; + private static final String TABLE_SCHEMA = "TABLE_SCHEM"; + private static final String[] TABLE_TYPES = {"TABLE"}; private final QueryFactory queryFactory = new QueryFactory(); private final VerticaSchemaUtils verticaSchemaUtils; private AmazonS3 amazonS3; - private final VerticaQueryPassthrough queryPassthrough = new VerticaQueryPassthrough(); + private final JdbcQueryPassthrough queryPassthrough = new JdbcQueryPassthrough(); + public VerticaMetadataHandler(Map configOptions) { - super(SOURCE_TYPE, configOptions); - amazonS3 = AmazonS3ClientBuilder.defaultClient(); - connectionFactory = new VerticaConnectionFactory(); - verticaSchemaUtils = new VerticaSchemaUtils(); + this(JDBCUtil.getSingleDatabaseConfigFromEnv(VERTICA_NAME, configOptions), configOptions); } - @VisibleForTesting - protected VerticaMetadataHandler( - EncryptionKeyFactory keyFactory, - VerticaConnectionFactory connectionFactory, - AWSSecretsManager awsSecretsManager, - AmazonAthena athena, - String spillBucket, - String spillPrefix, - VerticaSchemaUtils verticaSchemaUtils, - AmazonS3 amazonS3, - Map configOptions) + public VerticaMetadataHandler(DatabaseConnectionConfig databaseConnectionConfig, Map configOptions) { - super(keyFactory, awsSecretsManager, athena, SOURCE_TYPE, spillBucket, spillPrefix, configOptions); - this.connectionFactory = connectionFactory; - this.verticaSchemaUtils = verticaSchemaUtils; - this.amazonS3 = amazonS3; + this(databaseConnectionConfig, new GenericJdbcConnectionFactory(databaseConnectionConfig, JDBC_PROPERTIES, new DatabaseConnectionInfo(VERTICA_DRIVER_CLASS, VERTICA_DEFAULT_PORT)), configOptions); } - private Connection getConnection(MetadataRequest request) { - String endpoint = resolveSecrets(getConnStr(request)); - return connectionFactory.getOrCreateConn(endpoint); - + public VerticaMetadataHandler(DatabaseConnectionConfig databaseConnectionConfig, JdbcConnectionFactory jdbcConnectionFactory, Map configOptions) + { + super(databaseConnectionConfig, jdbcConnectionFactory, configOptions); + amazonS3 = AmazonS3ClientBuilder.defaultClient(); + verticaSchemaUtils = new VerticaSchemaUtils(); } - - private String getConnStr(MetadataRequest request) + @VisibleForTesting + public VerticaMetadataHandler(DatabaseConnectionConfig databaseConnectionConfig, JdbcConnectionFactory jdbcConnectionFactory, Map configOptions, AmazonS3 amazonS3, VerticaSchemaUtils verticaSchemaUtils) { - String conStr = configOptions.get(request.getCatalogName()); - if (conStr == null) { - logger.info("getConnStr: No environment variable found for catalog {} , using default {}", - request.getCatalogName(), DEFAULT_VERTICA); - conStr = configOptions.get(DEFAULT_VERTICA); - } - logger.info("exit getConnStr in VerticaMetadataHandler with conStr as {}",conStr); - return conStr; + super(databaseConnectionConfig, jdbcConnectionFactory, configOptions); + this.amazonS3 = amazonS3; + this.verticaSchemaUtils = verticaSchemaUtils; } /** @@ -158,11 +138,11 @@ private String getConnStr(MetadataRequest request) */ @Override public ListSchemasResponse doListSchemaNames(BlockAllocator allocator, ListSchemasRequest request) - throws SQLException + throws Exception { - logger.info("doListSchemaNames: " + request.getCatalogName()); + logger.info("doListSchemaNames: {}", request.getCatalogName()); List schemas = new ArrayList<>(); - try (Connection client = getConnection(request)) { + try (Connection client = getJdbcConnectionFactory().getConnection(getCredentialProvider())) { DatabaseMetaData dbMetadata = client.getMetaData(); ResultSet rs = dbMetadata.getTables(null, null, null, TABLE_TYPES); @@ -177,30 +157,6 @@ public ListSchemasResponse doListSchemaNames(BlockAllocator allocator, ListSchem return new ListSchemasResponse(request.getCatalogName(), schemas); } - /** - * Used to get the list of tables that this source contains. - * - * @param allocator Tool for creating and managing Apache Arrow Blocks. - * @param request Provides details on who made the request and which Athena catalog and database they are querying. - * @return A ListTablesResponse which primarily contains a List enumerating the tables in this - * catalog, database tuple. It also contains the catalog name corresponding the Athena catalog that was queried. - */ - @Override - public ListTablesResponse doListTables(BlockAllocator allocator, ListTablesRequest request) - throws SQLException - { - logger.info("doListTables: " + request); - List tables = new ArrayList<>(); - try (Connection client = getConnection(request)) { - DatabaseMetaData dbMetadata = client.getMetaData(); - ResultSet table = dbMetadata.getTables(null, request.getSchemaName(),null, TABLE_TYPES); - while (table.next()){ - tables.add(new TableName(table.getString(TABLE_SCHEMA), table.getString(TABLE_NAME))); - } - } - return new ListTablesResponse(request.getCatalogName(), tables, null); - - } protected ArrowType getArrayArrowTypeFromTypeName(String typeName, int precision, int scale) { // Default ARRAY type is VARCHAR. @@ -225,9 +181,9 @@ public GetTableResponse doGetQueryPassthroughSchema(final BlockAllocator blockAl } queryPassthrough.verify(getTableRequest.getQueryPassthroughArguments()); - String customerPassedQuery = getTableRequest.getQueryPassthroughArguments().get(VerticaQueryPassthrough.QUERY); + String customerPassedQuery = getTableRequest.getQueryPassthroughArguments().get(JdbcQueryPassthrough.QUERY); - try (Connection connection = getConnection(getTableRequest)) { + try (Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider())) { PreparedStatement preparedStatement = connection.prepareStatement(customerPassedQuery); ResultSetMetaData metadata = preparedStatement.getMetaData(); if (metadata == null) { @@ -247,6 +203,12 @@ public GetTableResponse doGetQueryPassthroughSchema(final BlockAllocator blockAl } } + @Override + public Schema getPartitionSchema(String catalogName) + { + return null; + } + /** * Used to get definition (field names, types, descriptions, etc...) of a Table. * @@ -259,11 +221,10 @@ public GetTableResponse doGetQueryPassthroughSchema(final BlockAllocator blockAl * 4. A catalog name corresponding the Athena catalog that was queried. */ @Override - public GetTableResponse doGetTable(BlockAllocator allocator, GetTableRequest request) + public GetTableResponse doGetTable(BlockAllocator allocator, GetTableRequest request) throws Exception { - logger.info("doGetTable: " + request.getTableName()); Set partitionCols = new HashSet<>(); - Connection connection = getConnection(request); + Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider()); //build the schema as per columns in Vertica Schema schema = verticaSchemaUtils.buildTableSchema(connection, request.getTableName()); @@ -300,8 +261,7 @@ public void enhancePartitionSchema(SchemaBuilder partitionSchemaBuilder, GetTabl * @param queryStatusChecker A QueryStatusChecker that you can use to stop doing work for a query that has already terminated */ @Override - public void getPartitions(BlockWriter blockWriter, GetTableLayoutRequest request, QueryStatusChecker queryStatusChecker) throws SQLException { - logger.info("in getPartitions: "+ request); + public void getPartitions(BlockWriter blockWriter, GetTableLayoutRequest request, QueryStatusChecker queryStatusChecker) throws Exception { Schema schemaName = request.getSchema(); TableName tableName = request.getTableName(); @@ -314,7 +274,7 @@ public void getPartitions(BlockWriter blockWriter, GetTableLayoutRequest request String queryID = request.getQueryId().replace("-","").concat(randomStr); //Build the SQL query - Connection connection = getConnection(request); + Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider()); // if QPT get input query from Athena console //else old logic @@ -367,11 +327,14 @@ public void getPartitions(BlockWriter blockWriter, GetTableLayoutRequest request */ @Override public GetSplitsResponse doGetSplits(BlockAllocator allocator, GetSplitsRequest request) - throws SQLException { //ToDo: implement use of a continuation token to use in case of larger queries - - Connection connection = getConnection(request); + Connection connection; + try { + connection = getJdbcConnectionFactory().getConnection(getCredentialProvider()); + } catch (Exception e) { + throw new RuntimeException("connection failed ", e); + } Set splits = new HashSet<>(); String exportBucket = getS3ExportBucket(); String queryId = request.getQueryId().replace("-",""); @@ -387,6 +350,7 @@ public GetSplitsResponse doGetSplits(BlockAllocator allocator, GetSplitsRequest FieldReader fieldReaderPS = request.getPartitions().getFieldReader("preparedStmt"); if (constraints.isQueryPassThrough()) { String preparedSQL = buildQueryPassthroughSql(constraints); + VerticaExportQueryBuilder queryBuilder = queryFactory.createQptVerticaExportQueryBuilder(); sqlStatement = queryBuilder.withS3ExportBucket(s3ExportBucket) .withQueryID(queryID) @@ -418,14 +382,12 @@ public GetSplitsResponse doGetSplits(BlockAllocator allocator, GetSplitsRequest { split = Split.newBuilder(makeSpillLocation(request), makeEncryptionKey()) .add("query_id", queryID) - .add(VERTICA_CONN_STR, getConnStr(request)) .add("exportBucket", exportBucket) .add("s3ObjectKey", objectSummary.getKey()) .build(); splits.add(split); } - logger.info("doGetSplits: exit - " + splits.size()); return new GetSplitsResponse(catalogName, splits); } else @@ -434,12 +396,10 @@ public GetSplitsResponse doGetSplits(BlockAllocator allocator, GetSplitsRequest logger.info("No records were exported by Vertica"); split = Split.newBuilder(makeSpillLocation(request), makeEncryptionKey()) .add("query_id", queryID) - .add(VERTICA_CONN_STR, getConnStr(request)) .add("exportBucket", exportBucket) .add("s3ObjectKey", EMPTY_STRING) .build(); splits.add(split); - logger.info("doGetSplits: exit - " + splits.size()); return new GetSplitsResponse(catalogName,split); } @@ -450,16 +410,19 @@ public GetSplitsResponse doGetSplits(BlockAllocator allocator, GetSplitsRequest * and executes the queries */ private void executeQueriesOnVertica(Connection connection, String sqlStatement, String awsRegionSql) - throws SQLException { - PreparedStatement setAwsRegion = connection.prepareStatement(awsRegionSql); - PreparedStatement exportSQL = connection.prepareStatement(sqlStatement); + try { + PreparedStatement setAwsRegion = connection.prepareStatement(awsRegionSql); + PreparedStatement exportSQL = connection.prepareStatement(sqlStatement); - //execute the query to set region - setAwsRegion.execute(); + //execute the query to set region + setAwsRegion.execute(); - //execute the query to export the data to S3 - exportSQL.execute(); + //execute the query to export the data to S3 + exportSQL.execute(); + } catch (SQLException e) { + throw new RuntimeException("Exception in executing query in vertica ", e); + } } /* @@ -505,10 +468,10 @@ public String getS3ExportBucket() return configOptions.get(EXPORT_BUCKET_KEY); } - public String buildQueryPassthroughSql(Constraints constraints) throws SQLException + public String buildQueryPassthroughSql(Constraints constraints) { queryPassthrough.verify(constraints.getQueryPassthroughArguments()); - return constraints.getQueryPassthroughArguments().get(VerticaQueryPassthrough.QUERY); + return constraints.getQueryPassthroughArguments().get(JdbcQueryPassthrough.QUERY); } } diff --git a/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaQueryPassthrough.java b/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaQueryPassthrough.java deleted file mode 100644 index e8ff0bc37a..0000000000 --- a/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaQueryPassthrough.java +++ /dev/null @@ -1,68 +0,0 @@ -/*- - * #%L - * athena-vertica - * %% - * Copyright (C) 2019 - 2024 Amazon Web Services - * %% - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * #L% - */ -package com.amazonaws.athena.connectors.vertica; - -import com.amazonaws.athena.connector.lambda.metadata.optimizations.querypassthrough.QueryPassthroughSignature; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.Arrays; -import java.util.List; - -public class VerticaQueryPassthrough implements QueryPassthroughSignature -{ - // Constant value representing the name of the query. - public static final String NAME = "query"; - - // Constant value representing the domain of the query. - public static final String SCHEMA_NAME = "system"; - - // List of arguments for the query, statically initialized as it always contains the same value. - public static final String QUERY = "QUERY"; - - public static final List ARGUMENTS = Arrays.asList(QUERY); - - private static final Logger LOGGER = LoggerFactory.getLogger(VerticaQueryPassthrough.class); - - @Override - public String getFunctionSchema() - { - return SCHEMA_NAME; - } - - @Override - public String getFunctionName() - { - return NAME; - } - - @Override - public List getFunctionArguments() - { - return ARGUMENTS; - } - - @Override - public Logger getLogger() - { - return LOGGER; - } - -} diff --git a/athena-vertica/src/test/java/com/amazonaws/athena/connectors/vertica/VerticaConnectionFactoryTest.java b/athena-vertica/src/test/java/com/amazonaws/athena/connectors/vertica/VerticaConnectionFactoryTest.java deleted file mode 100644 index 0d7842301e..0000000000 --- a/athena-vertica/src/test/java/com/amazonaws/athena/connectors/vertica/VerticaConnectionFactoryTest.java +++ /dev/null @@ -1,53 +0,0 @@ -/*- - * #%L - * athena-vertica - * %% - * Copyright (C) 2019 - 2020 Amazon Web Services - * %% - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * #L% - */ -package com.amazonaws.athena.connectors.vertica; - -import org.junit.Before; -import org.junit.Test; - -import java.sql.Connection; -import java.sql.SQLException; -import static org.junit.Assert.assertEquals; -import static org.mockito.Mockito.*; - -public class VerticaConnectionFactoryTest { - private VerticaConnectionFactory connectionFactory; - private int CONNECTION_TIMEOUT = 60; - - @Before - public void setUp() - { - connectionFactory = new VerticaConnectionFactory(); - } - - @Test - public void clientCacheTest() throws SQLException - { - Connection mockConnection = mock(Connection.class); - when(mockConnection.isValid(CONNECTION_TIMEOUT)).thenReturn(true); - - connectionFactory.addConnection("conStr", mockConnection); - Connection connection = connectionFactory.getOrCreateConn("conStr"); - - assertEquals(mockConnection, connection); - - } - -} diff --git a/athena-vertica/src/test/java/com/amazonaws/athena/connectors/vertica/VerticaMetadataHandlerTest.java b/athena-vertica/src/test/java/com/amazonaws/athena/connectors/vertica/VerticaMetadataHandlerTest.java index 70ceb679b5..2a5adad8a4 100644 --- a/athena-vertica/src/test/java/com/amazonaws/athena/connectors/vertica/VerticaMetadataHandlerTest.java +++ b/athena-vertica/src/test/java/com/amazonaws/athena/connectors/vertica/VerticaMetadataHandlerTest.java @@ -7,9 +7,9 @@ * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -20,16 +20,31 @@ package com.amazonaws.athena.connectors.vertica; import com.amazonaws.athena.connector.lambda.QueryStatusChecker; -import com.amazonaws.athena.connector.lambda.data.*; +import com.amazonaws.athena.connector.lambda.data.Block; +import com.amazonaws.athena.connector.lambda.data.BlockAllocatorImpl; +import com.amazonaws.athena.connector.lambda.data.BlockUtils; +import com.amazonaws.athena.connector.lambda.data.BlockWriter; +import com.amazonaws.athena.connector.lambda.data.SchemaBuilder; import com.amazonaws.athena.connector.lambda.domain.Split; import com.amazonaws.athena.connector.lambda.domain.TableName; import com.amazonaws.athena.connector.lambda.domain.predicate.Constraints; import com.amazonaws.athena.connector.lambda.domain.predicate.Range; import com.amazonaws.athena.connector.lambda.domain.predicate.SortedRangeSet; import com.amazonaws.athena.connector.lambda.domain.predicate.ValueSet; -import com.amazonaws.athena.connector.lambda.metadata.*; +import com.amazonaws.athena.connector.lambda.metadata.GetSplitsRequest; +import com.amazonaws.athena.connector.lambda.metadata.GetSplitsResponse; +import com.amazonaws.athena.connector.lambda.metadata.GetTableLayoutRequest; +import com.amazonaws.athena.connector.lambda.metadata.GetTableLayoutResponse; +import com.amazonaws.athena.connector.lambda.metadata.ListSchemasRequest; +import com.amazonaws.athena.connector.lambda.metadata.ListSchemasResponse; +import com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest; +import com.amazonaws.athena.connector.lambda.metadata.ListTablesResponse; +import com.amazonaws.athena.connector.lambda.metadata.MetadataRequestType; +import com.amazonaws.athena.connector.lambda.metadata.MetadataResponse; import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; -import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; +import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; +import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; +import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.vertica.query.QueryFactory; import com.amazonaws.athena.connectors.vertica.query.VerticaExportQueryBuilder; import com.amazonaws.services.athena.AmazonAthena; @@ -55,15 +70,25 @@ import org.slf4j.LoggerFactory; import org.stringtemplate.v4.ST; -import java.sql.*; -import java.util.*; +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.ResultSet; +import java.sql.Types; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; import static com.amazonaws.athena.connector.lambda.domain.predicate.Constraints.DEFAULT_NO_LIMIT; import static com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest.UNLIMITED_PAGE_SIZE_VALUE; -import static org.junit.Assert.*; +import static com.amazonaws.athena.connectors.vertica.VerticaConstants.VERTICA_NAME; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; -import static org.mockito.ArgumentMatchers.any; +import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.nullable; @RunWith(MockitoJUnitRunner.class) @@ -71,11 +96,10 @@ public class VerticaMetadataHandlerTest extends TestBase { private static final Logger logger = LoggerFactory.getLogger(VerticaMetadataHandlerTest.class); - private static final String[] TABLE_TYPES = {"TABLE"}; + private static final String[] TABLE_TYPES = new String[]{"TABLE"}; private QueryFactory queryFactory; - + private JdbcConnectionFactory jdbcConnectionFactory; private VerticaMetadataHandler verticaMetadataHandler; - private VerticaConnectionFactory verticaConnectionFactory; private VerticaExportQueryBuilder verticaExportQueryBuilder; private VerticaSchemaUtils verticaSchemaUtils; private Connection connection; @@ -98,13 +122,14 @@ public class VerticaMetadataHandlerTest extends TestBase private ListObjectsRequest listObjectsRequest; @Mock private ObjectListing objectListing; + private DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", VERTICA_NAME, + "vertica://jdbc:vertica:thin:username/password@//127.0.0.1:1521/vrt"); @Before - public void setUp() throws SQLException + public void setUp() throws Exception { - this.verticaConnectionFactory = Mockito.mock(VerticaConnectionFactory.class); this.verticaSchemaUtils = Mockito.mock(VerticaSchemaUtils.class); this.queryFactory = Mockito.mock(QueryFactory.class); this.verticaExportQueryBuilder = Mockito.mock(VerticaExportQueryBuilder.class); @@ -122,20 +147,17 @@ public void setUp() throws SQLException this.amazonS3 = Mockito.mock(AmazonS3.class); Mockito.lenient().when(this.secretsManager.getSecretValue(Mockito.eq(new GetSecretValueRequest().withSecretId("testSecret")))).thenReturn(new GetSecretValueResult().withSecretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}")); - Mockito.when(this.verticaConnectionFactory.getOrCreateConn(nullable(String.class))).thenReturn(connection); Mockito.when(connection.getMetaData()).thenReturn(databaseMetaData); Mockito.when(amazonS3.getRegion()).thenReturn(Region.US_West_2); - this.verticaMetadataHandler = new VerticaMetadataHandler(new LocalKeyFactory(), - verticaConnectionFactory, - secretsManager, - athena, - "spill-bucket", - "spill-prefix", - verticaSchemaUtils, - amazonS3, - com.google.common.collect.ImmutableMap.of()); - this.allocator = new BlockAllocatorImpl(); + this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class, Mockito.RETURNS_DEEP_STUBS); + this.connection = Mockito.mock(Connection.class, Mockito.RETURNS_DEEP_STUBS); + Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); + this.secretsManager = Mockito.mock(AWSSecretsManager.class); + this.athena = Mockito.mock(AmazonAthena.class); + this.verticaMetadataHandler = new VerticaMetadataHandler(databaseConnectionConfig, this.jdbcConnectionFactory, com.google.common.collect.ImmutableMap.of(), amazonS3, verticaSchemaUtils); + this.federatedIdentity = Mockito.mock(FederatedIdentity.class); + this.allocator = new BlockAllocatorImpl(); this.databaseMetaData = this.connection.getMetaData(); verticaMetadataHandlerMocked = Mockito.spy(this.verticaMetadataHandler); } @@ -148,48 +170,49 @@ public void tearDown() allocator.close(); } - @Test - public void doListSchemaNames() throws SQLException + public void doListTables() throws Exception { + String[] schema = {"TABLE_SCHEM", "TABLE_NAME",}; + Object[][] values = {{"testSchema", "testTable1"}}; + List expectedTables = new ArrayList<>(); + expectedTables.add(new TableName("testSchema", "testTable1")); - String[] schema = {"TABLE_SCHEM"}; - Object[][] values = {{"testDB1"}}; - String[] expected = {"testDB1"}; AtomicInteger rowNumber = new AtomicInteger(-1); ResultSet resultSet = mockResultSet(schema, values, rowNumber); - Mockito.when(databaseMetaData.getTables(null,null, null, TABLE_TYPES)).thenReturn(resultSet); + Mockito.when(databaseMetaData.getTables(null, tableName.getSchemaName(), null, new String[]{"TABLE", "VIEW", "EXTERNAL TABLE", "MATERIALIZED VIEW"})).thenReturn(resultSet); Mockito.when(resultSet.next()).thenReturn(true).thenReturn(false); - ListSchemasResponse listSchemasResponse = this.verticaMetadataHandler.doListSchemaNames(this.allocator, - new ListSchemasRequest(this.federatedIdentity, - "testQueryId", "testCatalog")); - Assert.assertArrayEquals(expected, listSchemasResponse.getSchemas().toArray()); + ListTablesResponse listTablesResponse = this.verticaMetadataHandler.doListTables(this.allocator, + new ListTablesRequest(this.federatedIdentity, + "testQueryId", + "testCatalog", + tableName.getSchemaName(), + null, UNLIMITED_PAGE_SIZE_VALUE)); + + Assert.assertArrayEquals(expectedTables.toArray(), listTablesResponse.getTables().toArray()); + } @Test - public void doListTables() throws SQLException { - String[] schema = {"TABLE_SCHEM", "TABLE_NAME", }; - Object[][] values = {{"testSchema", "testTable1"}}; - List expectedTables = new ArrayList<>(); - expectedTables.add(new TableName("testSchema", "testTable1")); + public void doListSchemaNames() throws Exception + { + String[] schema = {"TABLE_SCHEM"}; + Object[][] values = {{"testDB1"}}; + String[] expected = {"testDB1"}; AtomicInteger rowNumber = new AtomicInteger(-1); ResultSet resultSet = mockResultSet(schema, values, rowNumber); - Mockito.when(databaseMetaData.getTables(null, tableName.getSchemaName(), null, TABLE_TYPES)).thenReturn(resultSet); + Mockito.when(databaseMetaData.getTables(null, null, null, TABLE_TYPES)).thenReturn(resultSet); Mockito.when(resultSet.next()).thenReturn(true).thenReturn(false); - ListTablesResponse listTablesResponse = this.verticaMetadataHandler.doListTables(this.allocator, - new ListTablesRequest(this.federatedIdentity, - "testQueryId", - "testCatalog", - tableName.getSchemaName(), - null, UNLIMITED_PAGE_SIZE_VALUE)); - - Assert.assertArrayEquals(expectedTables.toArray(), listTablesResponse.getTables().toArray()); + ListSchemasResponse listSchemasResponse = this.verticaMetadataHandler.doListSchemaNames(this.allocator, + new ListSchemasRequest(this.federatedIdentity, + "testQueryId", "testCatalog")); + Assert.assertArrayEquals(expected, listSchemasResponse.getSchemas().toArray()); } @Test @@ -216,7 +239,7 @@ public void enhancePartitionSchema() @Test public void getPartitions() throws Exception { - Schema tableSchema = SchemaBuilder.newBuilder() + Schema tableSchema = SchemaBuilder.newBuilder() .addIntField("day") .addIntField("month") .addIntField("year") @@ -240,19 +263,14 @@ public void getPartitions() throws Exception constraintsMap.put("year", SortedRangeSet.copyOf(org.apache.arrow.vector.types.Types.MinorType.INT.getType(), ImmutableList.of(Range.greaterThan(allocator, org.apache.arrow.vector.types.Types.MinorType.INT.getType(), 2000)), false)); - - GetTableLayoutRequest req = null; - GetTableLayoutResponse res = null; - - String testSql = "Select * from schema1.table1"; String[] test = new String[]{"Select * from schema1.table1", "Select * from schema1.table1"}; String[] schema = {"TABLE_SCHEM", "TABLE_NAME", "COLUMN_NAME", "TYPE_NAME"}; Object[][] values = {{"testSchema", "testTable1", "day", "int"}, {"testSchema", "testTable1", "month", "int"}, {"testSchema", "testTable1", "year", "int"}, {"testSchema", "testTable1", "preparedStmt", "varchar"}, - {"testSchema", "testTable1", "queryId", "varchar"}, {"testSchema", "testTable1", "awsRegionSql", "varchar"}}; - int[] types = {Types.INTEGER, Types.INTEGER, Types.INTEGER,Types.VARCHAR, Types.VARCHAR, Types.VARCHAR}; + {"testSchema", "testTable1", "queryId", "varchar"}, {"testSchema", "testTable1", "awsRegionSql", "varchar"}}; + int[] types = {Types.INTEGER, Types.INTEGER, Types.INTEGER, Types.VARCHAR, Types.VARCHAR, Types.VARCHAR}; List expectedTables = new ArrayList<>(); expectedTables.add(new TableName("testSchema", "testTable1")); @@ -265,15 +283,13 @@ public void getPartitions() throws Exception Mockito.lenient().when(queryFactory.createVerticaExportQueryBuilder()).thenReturn(new VerticaExportQueryBuilder(new ST("templateVerticaExportQuery"))); Mockito.when(verticaMetadataHandlerMocked.getS3ExportBucket()).thenReturn("testS3Bucket"); - try { - req = new GetTableLayoutRequest(this.federatedIdentity, "queryId", "default", - new TableName("schema1", "table1"), - new Constraints(constraintsMap, Collections.emptyList(), Collections.emptyList(), DEFAULT_NO_LIMIT), - tableSchema, - partitionCols); - - res = verticaMetadataHandlerMocked.doGetTableLayout(allocator, req); + try (GetTableLayoutRequest req = new GetTableLayoutRequest(this.federatedIdentity, "queryId", "default", + new TableName("schema1", "table1"), + new Constraints(constraintsMap, Collections.emptyList(), Collections.emptyList(), DEFAULT_NO_LIMIT), + tableSchema, + partitionCols); + GetTableLayoutResponse res = verticaMetadataHandlerMocked.doGetTableLayout(allocator, req)) { Block partitions = res.getPartitions(); String actualQueryID = partitions.getFieldReader("queryId").readText().toString(); @@ -294,26 +310,12 @@ public void getPartitions() throws Exception logger.info("doGetTableLayout: partitions[{}]", partitions.getRowCount()); } - finally { - try { - req.close(); - res.close(); - } - catch (Exception ex) { - logger.error("doGetTableLayout: ", ex); - } - } - - logger.info("doGetTableLayout - exit"); - } @Test - public void doGetSplits() throws SQLException { - - logger.info("doGetSplits: enter"); - + public void doGetSplits() throws Exception + { Schema schema = SchemaBuilder.newBuilder() .addIntField("day") .addIntField("month") @@ -383,9 +385,6 @@ public void doGetSplits() throws SQLException { } assertTrue(!response.getSplits().isEmpty()); - - logger.info("doGetSplits: exit"); } - } diff --git a/athena-vertica/src/test/java/com/amazonaws/athena/connectors/vertica/VerticaSchemaUtilsTest.java b/athena-vertica/src/test/java/com/amazonaws/athena/connectors/vertica/VerticaSchemaUtilsTest.java index 96062f1852..37f5f7b903 100644 --- a/athena-vertica/src/test/java/com/amazonaws/athena/connectors/vertica/VerticaSchemaUtilsTest.java +++ b/athena-vertica/src/test/java/com/amazonaws/athena/connectors/vertica/VerticaSchemaUtilsTest.java @@ -47,10 +47,8 @@ public void setUp() throws SQLException { this.connection = Mockito.mock(Connection.class, Mockito.RETURNS_DEEP_STUBS); this.databaseMetaData = Mockito.mock(DatabaseMetaData.class); - VerticaConnectionFactory verticaConnectionFactory = Mockito.mock(VerticaConnectionFactory.class); this.tableName = Mockito.mock(TableName.class); - Mockito.when(verticaConnectionFactory.getOrCreateConn(nullable(String.class))).thenReturn(connection); Mockito.when(connection.getMetaData()).thenReturn(databaseMetaData); }