diff --git a/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeConstants.java b/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeConstants.java
index 385b18bcd4..9ff862c2df 100644
--- a/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeConstants.java
+++ b/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeConstants.java
@@ -29,7 +29,7 @@ public final class SnowflakeConstants
      * This constant limits the number of partitions. The default is set to 50. A large number may cause a timeout issue.
      * We arrived at this number after performance testing with datasets of different sizes.
      */
-    public static final int MAX_PARTITION_COUNT = 1;
+    public static final int MAX_PARTITION_COUNT = 50;
     /**
      * This constant limits the number of records to be returned in a single split.
      */
diff --git a/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMetadataHandler.java b/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMetadataHandler.java
index a3c22a04f2..234718760a 100644
--- a/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMetadataHandler.java
+++ b/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMetadataHandler.java
@@ -58,6 +58,7 @@
 import com.amazonaws.services.athena.AmazonAthena;
 import com.amazonaws.services.secretsmanager.AWSSecretsManager;
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Strings;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
 import org.apache.arrow.vector.complex.reader.FieldReader;
@@ -73,11 +74,13 @@
 import java.sql.PreparedStatement;
 import java.sql.ResultSet;
 import java.sql.SQLException;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 import java.util.stream.Collectors;
@@ -103,6 +106,9 @@ public class SnowflakeMetadataHandler extends JdbcMetadataHandler
             "WHERE table_type = 'BASE TABLE'\n" +
             "AND table_schema= ?\n" +
             "AND TABLE_NAME = ? ";
+    static final String SHOW_PRIMARY_KEYS_QUERY = "SHOW PRIMARY KEYS IN ";
+    static final String PRIMARY_KEY_COLUMN_NAME = "column_name";
+    static final String COUNTS_COLUMN_NAME = "COUNTS";
     private static final String CASE_UPPER = "upper";
     private static final String CASE_LOWER = "lower";
     /**
@@ -178,6 +184,48 @@ public Schema getPartitionSchema(final String catalogName)
             .addField(BLOCK_PARTITION_COLUMN_NAME, Types.MinorType.VARCHAR.getType());
         return schemaBuilder.build();
     }
+
+    private Optional<String> getPrimaryKey(TableName tableName) throws Exception
+    {
+        List<String> primaryKeys = new ArrayList<>();
+        try (Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider())) {
+            try (PreparedStatement preparedStatement = connection.prepareStatement(SHOW_PRIMARY_KEYS_QUERY + tableName.getTableName());
+                    ResultSet rs = preparedStatement.executeQuery()) {
+                while (rs.next()) {
+                    // Concatenate multiple primary keys if they exist
+                    primaryKeys.add(rs.getString(PRIMARY_KEY_COLUMN_NAME));
+                }
+            }
+        }
+        String primaryKey = String.join(", ", primaryKeys);
+        if (!Strings.isNullOrEmpty(primaryKey) && hasUniquePrimaryKey(tableName, primaryKey)) {
+            return Optional.of(primaryKey);
+        }
+        return Optional.empty();
+    }
+
+    /**
+     * Snowflake does not enforce primary key constraints, so we double-check that the user has a unique
+     * primary key before partitioning.
+     */
+    private boolean hasUniquePrimaryKey(TableName tableName, String primaryKey) throws Exception
+    {
+        try (Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider())) {
+            try (PreparedStatement preparedStatement = connection.prepareStatement("SELECT " + primaryKey + ", count(*) as COUNTS FROM " + tableName.getTableName() + " GROUP BY " + primaryKey + " ORDER BY COUNTS DESC");
+                    ResultSet rs = preparedStatement.executeQuery()) {
+                if (rs.next()) {
+                    if (rs.getInt(COUNTS_COLUMN_NAME) == 1) {
+                        // Since the counts are sorted in descending order and 1 is the first count seen,
+                        // this table has a unique primary key
+                        return true;
+                    }
+                }
+            }
+        }
+        LOGGER.warn("Primary key {} is not unique. Falling back to a single partition...", primaryKey);
+        return false;
+    }
+
     /**
      * Snowflake manual partition logic based upon number of records
      * @param blockWriter
@@ -217,18 +265,19 @@ public void getPartitions(BlockWriter blockWriter, GetTableLayoutRequest getTableLayoutRequest)
                 totalRecordCount = rs.getLong(1);
             }
             if (totalRecordCount > 0) {
-                long pageCount = (long) (Math.ceil(totalRecordCount / MAX_PARTITION_COUNT));
-                long partitionRecordCount = (totalRecordCount <= SINGLE_SPLIT_LIMIT_COUNT) ? (long) totalRecordCount : pageCount;
-                LOGGER.info(" Total Page Count" + partitionRecordCount);
-                double limit = (int) Math.ceil(totalRecordCount / partitionRecordCount);
+                Optional<String> primaryKey = getPrimaryKey(getTableLayoutRequest.getTableName());
+                long recordsInPartition = (long) (Math.ceil(totalRecordCount / MAX_PARTITION_COUNT));
+                long partitionRecordCount = (totalRecordCount <= SINGLE_SPLIT_LIMIT_COUNT || !primaryKey.isPresent()) ? (long) totalRecordCount : recordsInPartition;
+                LOGGER.info("Records per partition: " + partitionRecordCount);
+                double numberOfPartitions = (int) Math.ceil(totalRecordCount / partitionRecordCount);
                 long offset = 0;
                 /**
                  * Custom pagination based partition logic will be applied with limit and offset clauses.
                  * There will be at most 50 partitions, and the number of records in each partition is decided by dividing the total number of records by 50.
                  * In the partition values we set the limit and offset, e.g. p-limit-3000-offset-0.
                  */
-                for (int i = 1; i <= limit; i++) {
-                    final String partitionVal = BLOCK_PARTITION_COLUMN_NAME + "-limit-" + partitionRecordCount + "-offset-" + offset;
+                for (int i = 1; i <= numberOfPartitions; i++) {
+                    final String partitionVal = BLOCK_PARTITION_COLUMN_NAME + "-primary-" + primaryKey.orElse("") + "-limit-" + partitionRecordCount + "-offset-" + offset;
                     LOGGER.info("partitionVal {} ", partitionVal);
                     blockWriter.writeRows((Block block, int rowNum) -> {
diff --git a/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeQueryStringBuilder.java b/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeQueryStringBuilder.java
index aea0bc9a6b..2b46b579ef 100644
--- a/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeQueryStringBuilder.java
+++ b/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeQueryStringBuilder.java
@@ -72,17 +72,24 @@ protected List<String> getPartitionWhereClauses(Split split)
     @Override
     protected String appendLimitOffset(Split split)
     {
+        String primaryKey = "";
         String xLimit = "";
         String xOffset = "";
-        String partitionVal = split.getProperty(split.getProperties().keySet().iterator().next()); //p-limit-3000-offset-0
+        String partitionVal = split.getProperty(split.getProperties().keySet().iterator().next()); //p-primary--limit-3000-offset-0
         if (!partitionVal.contains("-")) {
             return EMPTY_STRING;
         }
         else {
             String[] arr = partitionVal.split("-");
-            xLimit = arr[2];
-            xOffset = arr[4];
+            primaryKey = arr[2];
+            xLimit = arr[4];
+            xOffset = arr[6];
         }
-        return " limit " + xLimit + " offset " + xOffset;
+
+        // if there is no primary key, fall back to a single split with no pagination
+        if (primaryKey.isEmpty()) {
+            return EMPTY_STRING;
+        }
+        return " ORDER BY " + primaryKey + " limit " + xLimit + " offset " + xOffset;
     }
 }
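
To make the split-property round trip concrete, a small sketch of how appendLimitOffset() now unpacks a partition value (the key name pkey and the numbers are made up). Splitting on "-" yields seven tokens, which is also why the mocked property in SnowflakeQueryStringBuilderTest at the bottom of this diff grows from five tokens to seven:

    public class PartitionValueParseSketch
    {
        public static void main(String[] args)
        {
            // Hypothetical partition value as written by SnowflakeMetadataHandler.getPartitions()
            String partitionVal = "p-primary-pkey-limit-3000-offset-0";
            String[] arr = partitionVal.split("-");
            // arr = ["p", "primary", "pkey", "limit", "3000", "offset", "0"]
            String primaryKey = arr[2];  // "pkey"; empty when the table has no unique primary key
            String xLimit = arr[4];      // "3000"
            String xOffset = arr[6];     // "0"
            // Prints " ORDER BY pkey limit 3000 offset 0"
            System.out.println(" ORDER BY " + primaryKey + " limit " + xLimit + " offset " + xOffset);
        }
    }

A composite key such as "col1, col2" still parses, because the join delimiter is ", " rather than "-"; a quoted identifier that itself contains "-" would, however, break this index-based scheme.
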
diff --git a/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMetadataHandlerTest.java b/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMetadataHandlerTest.java
index b03df0f097..b18fde1cda 100644
--- a/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMetadataHandlerTest.java
+++ b/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMetadataHandlerTest.java
@@ -100,12 +100,28 @@ public void doGetTableLayout()
         Mockito.when(this.connection.prepareStatement(SnowflakeMetadataHandler.COUNT_RECORDS_QUERY)).thenReturn(preparedStatement);
         String[] columns = {"partition"};
         int[] types = {Types.VARCHAR};
-        Object[][] values = {{"partition : partition-limit-500000-offset-0"}, {"partition : partition-limit-500000-offset-500000"}};
+        Object[][] values = {{"partition : partition-primary-pkey-limit-500000-offset-0"}, {"partition : partition-primary-pkey-limit-500000-offset-500000"}};
         ResultSet resultSet = mockResultSet(columns, types, values, new AtomicInteger(-1));
         Mockito.when(preparedStatement.executeQuery()).thenReturn(resultSet);
         Mockito.when(this.connection.getMetaData().getSearchStringEscape()).thenReturn(null);
         double totalActualRecordCount = 10001;
         Mockito.when(resultSet.getLong(1)).thenReturn((long) totalActualRecordCount);
+
+        PreparedStatement primaryKeyPreparedStatement = Mockito.mock(PreparedStatement.class);
+        String[] primaryKeyColumns = new String[] {SnowflakeMetadataHandler.PRIMARY_KEY_COLUMN_NAME};
+        String[][] primaryKeyValues = new String[][] {new String[] {"pkey"}};
+        ResultSet primaryKeyResultSet = mockResultSet(primaryKeyColumns, primaryKeyValues, new AtomicInteger(-1));
+        Mockito.when(this.connection.prepareStatement(SnowflakeMetadataHandler.SHOW_PRIMARY_KEYS_QUERY + "testTable")).thenReturn(primaryKeyPreparedStatement);
+        Mockito.when(primaryKeyPreparedStatement.executeQuery()).thenReturn(primaryKeyResultSet);
+
+        PreparedStatement countsPreparedStatement = Mockito.mock(PreparedStatement.class);
+        String GET_PKEY_COUNTS_QUERY = "SELECT pkey, count(*) as COUNTS FROM testTable GROUP BY pkey ORDER BY COUNTS DESC";
+        String[] countsColumns = new String[] {"pkey", SnowflakeMetadataHandler.COUNTS_COLUMN_NAME};
+        Object[][] countsValues = {{"a", 1}};
+        ResultSet countsResultSet = mockResultSet(countsColumns, countsValues, new AtomicInteger(-1));
+        Mockito.when(this.connection.prepareStatement(GET_PKEY_COUNTS_QUERY)).thenReturn(countsPreparedStatement);
+        Mockito.when(countsPreparedStatement.executeQuery()).thenReturn(countsResultSet);
+
         GetTableLayoutResponse getTableLayoutResponse = this.snowflakeMetadataHandler.doGetTableLayout(blockAllocator, getTableLayoutRequest);
         List<String> expectedValues = new ArrayList<>();
         for (int i = 0; i < getTableLayoutResponse.getPartitions().getRowCount(); i++) {
@@ -120,7 +136,7 @@ public void doGetTableLayout()
             if (i > 1) {
                 offset = offset + partitionActualRecordCount;
             }
-            actualValues.add("[partition : partition-limit-" + partitionActualRecordCount + "-offset-" + offset + "]");
+            actualValues.add("[partition : partition-primary-pkey-limit-" + partitionActualRecordCount + "-offset-" + offset + "]");
         }
         Assert.assertEquals((int) limit, getTableLayoutResponse.getPartitions().getRowCount());
         Assert.assertEquals(expectedValues, actualValues);
@@ -148,12 +164,28 @@ public void doGetTableLayoutSinglePartition()
         String[] columns = {"partition"};
         int[] types = {Types.VARCHAR};
-        Object[][] values = {{"partition : partition-limit-500000-offset-0"}};
+        Object[][] values = {{"partition : partition-primary-pkey-limit-500000-offset-0"}};
         ResultSet resultSet = mockResultSet(columns, types, values, new AtomicInteger(-1));
         Mockito.when(preparedStatement.executeQuery()).thenReturn(resultSet);
         Mockito.when(this.connection.getMetaData().getSearchStringEscape()).thenReturn(null);
         Mockito.when(resultSet.getLong(1)).thenReturn(1001L);
+
+        PreparedStatement primaryKeyPreparedStatement = Mockito.mock(PreparedStatement.class);
+        String[] primaryKeyColumns = new String[] {SnowflakeMetadataHandler.PRIMARY_KEY_COLUMN_NAME};
+        String[][] primaryKeyValues = new String[][] {new String[] {""}};
+        ResultSet primaryKeyResultSet = mockResultSet(primaryKeyColumns, primaryKeyValues, new AtomicInteger(-1));
+        Mockito.when(this.connection.prepareStatement(SnowflakeMetadataHandler.SHOW_PRIMARY_KEYS_QUERY + "testTable")).thenReturn(primaryKeyPreparedStatement);
+        Mockito.when(primaryKeyPreparedStatement.executeQuery()).thenReturn(primaryKeyResultSet);
+
+        PreparedStatement countsPreparedStatement = Mockito.mock(PreparedStatement.class);
+        String GET_PKEY_COUNTS_QUERY = "SELECT pkey, count(*) as COUNTS FROM testTable GROUP BY pkey ORDER BY COUNTS DESC";
+        String[] countsColumns = new String[] {"pkey", SnowflakeMetadataHandler.COUNTS_COLUMN_NAME};
+        Object[][] countsValues = {{"a", 1}};
+        ResultSet countsResultSet = mockResultSet(countsColumns, countsValues, new AtomicInteger(-1));
+        Mockito.when(this.connection.prepareStatement(GET_PKEY_COUNTS_QUERY)).thenReturn(countsPreparedStatement);
+        Mockito.when(countsPreparedStatement.executeQuery()).thenReturn(countsResultSet);
+
         GetTableLayoutResponse getTableLayoutResponse = this.snowflakeMetadataHandler.doGetTableLayout(blockAllocator, getTableLayoutRequest);

         Assert.assertEquals(values.length, getTableLayoutResponse.getPartitions().getRowCount());
@@ -162,7 +194,7 @@ public void doGetTableLayoutSinglePartition()
         for (int i = 0; i < getTableLayoutResponse.getPartitions().getRowCount(); i++) {
             expectedValues.add(BlockUtils.rowToString(getTableLayoutResponse.getPartitions(), i));
         }
-        Assert.assertEquals(expectedValues, Arrays.asList("[partition : partition-limit-1001-offset-0]"));
+        Assert.assertEquals(expectedValues, Arrays.asList("[partition : partition-primary--limit-1001-offset-0]"));
         SchemaBuilder expectedSchemaBuilder = SchemaBuilder.newBuilder();
         expectedSchemaBuilder.addField(FieldBuilder.newBuilder("partition", org.apache.arrow.vector.types.Types.MinorType.VARCHAR.getType()).build());
@@ -194,11 +226,27 @@ public void doGetTableLayoutMaxPartition()
         long offset = 0;
         String[] columns = {"partition"};
         int[] types = {Types.VARCHAR};
-        Object[][] values = {{"partition : partition-limit-500000-offset-0"}};
+        Object[][] values = {{"partition : partition-primary-pkey-limit-500000-offset-0"}};
         ResultSet resultSet = mockResultSet(columns, types, values, new AtomicInteger(-1));
         Mockito.when(preparedStatement.executeQuery()).thenReturn(resultSet);
         Mockito.when(this.connection.getMetaData().getSearchStringEscape()).thenReturn(null);
         Mockito.when(resultSet.getLong(1)).thenReturn((long) totalActualRecordCount);
+
+        PreparedStatement primaryKeyPreparedStatement = Mockito.mock(PreparedStatement.class);
+        String[] primaryKeyColumns = new String[] {SnowflakeMetadataHandler.PRIMARY_KEY_COLUMN_NAME};
+        String[][] primaryKeyValues = new String[][] {new String[] {"pkey"}};
+        ResultSet primaryKeyResultSet = mockResultSet(primaryKeyColumns, primaryKeyValues, new AtomicInteger(-1));
+        Mockito.when(this.connection.prepareStatement(SnowflakeMetadataHandler.SHOW_PRIMARY_KEYS_QUERY + "testTable")).thenReturn(primaryKeyPreparedStatement);
+        Mockito.when(primaryKeyPreparedStatement.executeQuery()).thenReturn(primaryKeyResultSet);
+
+        PreparedStatement countsPreparedStatement = Mockito.mock(PreparedStatement.class);
+        String GET_PKEY_COUNTS_QUERY = "SELECT pkey, count(*) as COUNTS FROM testTable GROUP BY pkey ORDER BY COUNTS DESC";
+        String[] countsColumns = new String[] {"pkey", SnowflakeMetadataHandler.COUNTS_COLUMN_NAME};
+        Object[][] countsValues = {{"a", 1}};
+        ResultSet countsResultSet = mockResultSet(countsColumns, countsValues, new AtomicInteger(-1));
+        Mockito.when(this.connection.prepareStatement(GET_PKEY_COUNTS_QUERY)).thenReturn(countsPreparedStatement);
+        Mockito.when(countsPreparedStatement.executeQuery()).thenReturn(countsResultSet);
+
         GetTableLayoutResponse getTableLayoutResponse = this.snowflakeMetadataHandler.doGetTableLayout(blockAllocator, getTableLayoutRequest);
         List<String> actualValues = new ArrayList<>();
         List<String> expectedValues = new ArrayList<>();
@@ -209,7 +257,7 @@ public void doGetTableLayoutMaxPartition()
             if (i > 1) {
                 offset = offset + partitionActualRecordCount;
             }
-            actualValues.add("[partition : partition-limit-" + partitionActualRecordCount + "-offset-" + offset + "]");
+            actualValues.add("[partition : partition-primary-pkey-limit-" + partitionActualRecordCount + "-offset-" + offset + "]");
         }
         Assert.assertEquals(expectedValues, actualValues);

         SchemaBuilder expectedSchemaBuilder = SchemaBuilder.newBuilder();
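
The mocked counts query above drives hasUniquePrimaryKey(). The check only needs the first row because ORDER BY COUNTS DESC puts the largest group first: the key is unique exactly when that maximum count is 1. A minimal plain-Java sketch of the same invariant, with made-up key values:

    import java.util.Arrays;
    import java.util.List;
    import java.util.Map;
    import java.util.function.Function;
    import java.util.stream.Collectors;

    public class UniqueKeySketch
    {
        public static void main(String[] args)
        {
            // Hypothetical primary-key column values read from the table
            List<String> keyValues = Arrays.asList("a", "b", "c");

            // GROUP BY pkey, count(*): how often does each key value occur?
            Map<String, Long> counts = keyValues.stream()
                    .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));

            // ORDER BY COUNTS DESC would surface the largest count first;
            // the key is unique exactly when that maximum is 1
            long maxCount = counts.values().stream().max(Long::compare).orElse(0L);
            System.out.println("unique primary key: " + (maxCount == 1));
        }
    }
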
diff --git a/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeQueryStringBuilderTest.java b/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeQueryStringBuilderTest.java
index c373fed085..847b55cea5 100644
--- a/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeQueryStringBuilderTest.java
+++ b/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeQueryStringBuilderTest.java
@@ -41,7 +41,7 @@ public void testQueryBuilderNew()
         Split split = Mockito.mock(Split.class);
         SnowflakeQueryStringBuilder builder = new SnowflakeQueryStringBuilder(SNOWFLAKE_QUOTE_CHARACTER, new SnowflakeFederationExpressionParser(SNOWFLAKE_QUOTE_CHARACTER));
         Mockito.when(split.getProperties()).thenReturn(Collections.singletonMap("partition", "p0"));
-        Mockito.when(split.getProperty(Mockito.eq("partition"))).thenReturn("p1-p2-p3-p4-p5");
+        Mockito.when(split.getProperty(Mockito.eq("partition"))).thenReturn("p1-p2-p3-p4-p5-p6-p7");
         builder.getFromClauseWithSplit("default", "", "table", split);
         builder.appendLimitOffset(split);
     }