Skip to content

Commit

Permalink
Transition to AES-GCM for UDFs (#708)
Browse files Browse the repository at this point in the history
Co-authored-by: Abdul Al-Faraj <[email protected]>
Co-authored-by: AbdulRehman Faraj <[email protected]>
  • Loading branch information
3 people authored Sep 11, 2023
1 parent b5b2a1d commit 911b8eb
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,31 @@

import com.amazonaws.athena.connector.lambda.handlers.UserDefinedFunctionHandler;
import org.apache.arrow.util.VisibleForTesting;
import org.apache.commons.codec.binary.Base64;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.SecretKeySpec;

import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Map;
import java.util.Base64;

public class ExampleUserDefinedFuncHandler
extends UserDefinedFunctionHandler
{
private static final Logger logger = LoggerFactory.getLogger(ExampleUserDefinedFuncHandler.class);

private static final String SOURCE_TYPE = "custom";
public static final int GCM_IV_LENGTH = 12;
public static final int GCM_TAG_LENGTH = 16;

public ExampleUserDefinedFuncHandler()
{
Expand Down Expand Up @@ -114,31 +119,61 @@ public String decrypt(String payload)
}

/**
* This is an extremely POOR usage of AES-GCM and is only mean to illustrate how one could
* This usage of AES-GCM and is only meant to illustrate how one could
* use a UDF for masking a field using encryption. In production scenarios we would recommend
* using AWS KMS for Key Management and a strong cipher like AES-GCM.
*
* @param text The text to decrypt.
* @param ciphertext The text to decrypt.
* @param secretKey The password/key to use to decrypt the text.
* @return The decrypted text.
*/
@VisibleForTesting
protected String symmetricDecrypt(String text, String secretKey)
throws NoSuchPaddingException, NoSuchAlgorithmException, InvalidKeyException, BadPaddingException,
IllegalBlockSizeException
public String symmetricDecrypt(String ciphertext, String secretKey)
{
Cipher cipher;
String encryptedString;
byte[] encryptText;
byte[] raw;
SecretKeySpec skeySpec;
raw = Base64.decodeBase64(secretKey);
skeySpec = new SecretKeySpec(raw, "AES");
encryptText = Base64.decodeBase64(text);
cipher = Cipher.getInstance("AES");
cipher.init(Cipher.DECRYPT_MODE, skeySpec);
encryptedString = new String(cipher.doFinal(encryptText));
return encryptedString;
if (ciphertext == null) {
return null;
}
byte[] plaintextKey = Base64.getDecoder().decode(secretKey);

try {
byte[] encryptedContent = Base64.getDecoder().decode(ciphertext.getBytes());
// extract IV from first GCM_IV_LENGTH bytes of ciphertext
Cipher cipher = getCipher(Cipher.DECRYPT_MODE, plaintextKey, getGCMSpecDecryption(encryptedContent));
byte[] plainTextBytes = cipher.doFinal(encryptedContent, GCM_IV_LENGTH, encryptedContent.length - GCM_IV_LENGTH);
return new String(plainTextBytes);
}
catch (IllegalBlockSizeException | BadPaddingException e) {
throw new RuntimeException(e);
}
}

private static GCMParameterSpec getGCMSpecDecryption(byte[] encryptedText)
{
return new GCMParameterSpec(GCM_TAG_LENGTH * Byte.SIZE, encryptedText, 0, GCM_IV_LENGTH);
}

static GCMParameterSpec getGCMSpecEncryption()
{
byte[] iv = new byte[GCM_IV_LENGTH];
SecureRandom random = new SecureRandom();
random.nextBytes(iv);

return new GCMParameterSpec(GCM_TAG_LENGTH * Byte.SIZE, iv);
}

static Cipher getCipher(int cipherMode, byte[] plainTextDataKey, GCMParameterSpec gcmParameterSpec)
{
try {
Cipher cipher = Cipher.getInstance("AES_256/GCM/NoPadding");
SecretKeySpec skeySpec = new SecretKeySpec(plainTextDataKey, "AES");

cipher.init(cipherMode, skeySpec, gcmParameterSpec);
return cipher;
}
catch (NoSuchPaddingException | NoSuchAlgorithmException | InvalidKeyException |
InvalidAlgorithmParameterException e) {
throw new RuntimeException(e);
}
}

/**
Expand All @@ -151,7 +186,8 @@ protected String symmetricDecrypt(String text, String secretKey)
@VisibleForTesting
protected String getEncryptionKey()
{
//must be exactly 24 chars or the KeySpec will fail. In general this is a poor, but simple, way to store the key.
return "AMzDLG4D039Km2IxIzQwfg==";
// The algorithm used requires 32 Byte Key!
// Can be generated for testing using `openssl rand -base64 32`
return "i5YnyBO4gJKWuIQ+gjuJjcJ/5kUph9pmYFUbW7zf3PE=";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
*/
package com.amazonaws.athena.connectors.example;

import com.amazonaws.athena.connector.lambda.handlers.UserDefinedFunctionHandler;
import org.apache.commons.codec.binary.Base64;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -29,25 +27,21 @@
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.SecretKey;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.PBEKeySpec;
import javax.crypto.spec.SecretKeySpec;

import java.io.BufferedWriter;
import java.io.FileWriter;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.InvalidParameterSpecException;
import java.security.spec.KeySpec;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

import static com.amazonaws.athena.connectors.example.ExampleUserDefinedFuncHandler.GCM_IV_LENGTH;
import static org.junit.Assert.*;

public class ExampleUserDefinedFuncHandlerTest
Expand Down Expand Up @@ -86,7 +80,8 @@ public void decrypt()
return;
}

assertTrue(handler.decrypt("0UTIXoWnKqtQe8y+BSHNmdEXmWfQalRQH60pobsgwws=").equals("SecretText-1755604178"));
String encryptedValue = symmetricEncrypt("SecretText-1755604178", handler.getEncryptionKey());
assertTrue(handler.decrypt(encryptedValue).equals("SecretText-1755604178"));
}

@Test
Expand All @@ -100,33 +95,32 @@ public void testEncryption()
String encrypted = symmetricEncrypt(value, key);
String actual = handler.symmetricDecrypt(encrypted, key);
assertEquals(value, actual);

//TODO: find and test the sample_data file automatically
//NOTE!!!!!! _______IF_THIS_REQUIRES_A_CHANGE_THEN_YOU_NEED_TO_UPDATE_THE_SAMPLE_DATA.CSV___________
assertTrue(handler.symmetricDecrypt("0UTIXoWnKqtQe8y+BSHNmdEXmWfQalRQH60pobsgwws=", key).equals("SecretText-1755604178"));
String encryptedValue = symmetricEncrypt("SecretText-1755604178", key);
assertTrue(handler.symmetricDecrypt(encryptedValue, key).equals("SecretText-1755604178"));
}

/**
* Used to test the decrypt function in the handler.
* This example is taken from the UDF handle example
*/
private static String symmetricEncrypt(String text, String secretKey)
private static String symmetricEncrypt(String plaintext, String secretKey)
{
byte[] raw;
String encryptedString;
SecretKeySpec skeySpec;
byte[] encryptText = text.getBytes();
Cipher cipher;
try {
raw = Base64.decodeBase64(secretKey);
skeySpec = new SecretKeySpec(raw, "AES");
cipher = Cipher.getInstance("AES");
cipher.init(Cipher.ENCRYPT_MODE, skeySpec);
encryptedString = Base64.encodeBase64String(cipher.doFinal(encryptText));
byte[] plaintextKey = Base64.getDecoder().decode(secretKey);
Cipher cipher = ExampleUserDefinedFuncHandler.getCipher(Cipher.ENCRYPT_MODE, plaintextKey, ExampleUserDefinedFuncHandler.getGCMSpecEncryption());
byte[] encryptedContent = cipher.doFinal(plaintext.getBytes());
// prepend ciphertext with IV
ByteBuffer byteBuffer = ByteBuffer.allocate(GCM_IV_LENGTH + encryptedContent.length);
byteBuffer.put(cipher.getIV());
byteBuffer.put(encryptedContent);

byte[] encodedContent = Base64.getEncoder().encode(byteBuffer.array());
return new String(encodedContent);
}
catch (Exception e) {
e.printStackTrace();
return "Error";
catch (IllegalBlockSizeException | BadPaddingException e) {
throw new RuntimeException(e);
}
return encryptedString;
}
}
17 changes: 15 additions & 2 deletions athena-udfs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ This would return result 'StringToBeCompressed'.

3. "encrypt": encrypt the data with a data key stored in AWS Secrets Manager*

Before testing this query, you would need to create a secret in AWS Secrets Manager. Make sure to use "DefaultEncryptionKey". If you choose to use your KMS key, you would need to update ./athena-udfs.yaml to allow access to your KMS key. Remove all the json brackets and store a base64 encoded string as data key. Sample data is like `AQIDBAUGBwgJAAECAwQFBg==`.
Before testing this query, you would need to create a secret in AWS Secrets Manager. Make sure to use "DefaultEncryptionKey". If you choose to use your KMS key, you would need to update ./athena-udfs.yaml to allow access to your KMS key. Remove all the json brackets and store a base64 encoded string as data key. Sample data is like `i5YnyBO4gJKWuIQ+gjuJjcJ/5kUph9pmYFUbW7zf3PE=`.

Example query:

Expand All @@ -30,7 +30,7 @@ Example query:

Example query:

`USING EXTERNAL FUNCTION decrypt(col VARCHAR, secretName VARCHAR) RETURNS VARCHAR LAMBDA '<lambda name>' SELECT decrypt('tEgyixKs1d0RsnL51ypMgg==', 'my_secret_name');`
`USING EXTERNAL FUNCTION decrypt(col VARCHAR, secretName VARCHAR) RETURNS VARCHAR LAMBDA '<lambda name>' SELECT decrypt('G/VP2sbMb7d4zE2HVl2XkiB5xUHpszlEjccEBsTVji209IaCjg==', 'my_secret_name');`

*To use the Athena Federated Query feature with AWS Secrets Manager, the VPC connected to your Lambda function should have [internet access](https://aws.amazon.com/premiumsupport/knowledge-center/internet-access-lambda-function/) or a [VPC endpoint](https://docs.aws.amazon.com/secretsmanager/latest/userguide/vpc-endpoint-overview.html#vpc-endpoint-create) to connect to Secrets Manager.

Expand All @@ -52,6 +52,19 @@ To use this connector in your queries, navigate to AWS Serverless Application Re
3. From the athena-udfs dir, run `sam deploy --template-file athena-udfs.yaml -g` and follow the guided prompt to synthesize your CloudFormation template and create your IAM policies and Lambda function.
4. Try using your UDF(s) in a query.

## Migrating To V2
This UDF includes a sample encryption/decryption method to showcase the benefits of integrating UDFs into your queries. If you were using the prior version of this UDF with AES-based encryption and wish to transition to the new version, please follow these steps:

1. Deploy the new connector with new name (say v2) as shown in previous example.
2. Use the previously deployed connector to decrypt, and the new one to encrypt:

```
USING EXTERNAL FUNCTION decrypt(col VARCHAR, secretName VARCHAR) RETURNS VARCHAR LAMBDA 'athena_udf_v1',
EXTERNAL FUNCTION encrypt(col VARCHAR, secretName VARCHAR) RETURNS VARCHAR LAMBDA 'athena_udf_v2'
SELECT encrypt(t.plaintext, 'SOME_SECRET')
FROM (SELECT decrypt('PREVIOUSLY_ENCRYPTED_MESSAGE', 'SOME_SECRET') as plaintext) as t
```

## License

This project is licensed under the Apache-2.0 License.
4 changes: 2 additions & 2 deletions athena-udfs/athena-udfs.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Transform: 'AWS::Serverless-2016-10-31'
Metadata:
'AWS::ServerlessRepo::Application':
Name: AthenaUserDefinedFunctions
Name: AthenaUserDefinedFunctionsV2
Description: 'This connector enables Amazon Athena to leverage common UDFs made available via Lambda.'
Author: 'default author'
SpdxLicenseId: Apache-2.0
Expand Down Expand Up @@ -52,4 +52,4 @@ Resources:
- secretsmanager:GetSecretValue
Effect: Allow
Resource: !Sub 'arn:${AWS::Partition}:secretsmanager:*:*:secret:${SecretNameOrPrefix}'
Version: '2012-10-17'
Version: '2012-10-17'
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,17 @@
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.SecretKeySpec;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Base64;
import java.util.zip.DataFormatException;
import java.util.zip.Deflater;
Expand All @@ -44,6 +48,8 @@ public class AthenaUDFHandler
extends UserDefinedFunctionHandler
{
private static final String SOURCE_TYPE = "athena_common_udfs";
public static final int GCM_IV_LENGTH = 12;
public static final int GCM_TAG_LENGTH = 16; // max allowable

private final CachableSecretsManager cachableSecretsManager;

Expand Down Expand Up @@ -166,9 +172,10 @@ public String decrypt(String ciphertext, String secretName)
byte[] plaintextKey = Base64.getDecoder().decode(secretString);

try {
Cipher cipher = getCipher(Cipher.DECRYPT_MODE, plaintextKey);
byte[] encryptedContent = Base64.getDecoder().decode(ciphertext.getBytes());
byte[] plainTextBytes = cipher.doFinal(encryptedContent);
// extract IV from first GCM_IV_LENGTH bytes of ciphertext
Cipher cipher = getCipher(Cipher.DECRYPT_MODE, plaintextKey, getGCMSpecDecryption(encryptedContent));
byte[] plainTextBytes = cipher.doFinal(encryptedContent, GCM_IV_LENGTH, encryptedContent.length - GCM_IV_LENGTH);
return new String(plainTextBytes);
}
catch (IllegalBlockSizeException | BadPaddingException e) {
Expand Down Expand Up @@ -196,25 +203,44 @@ public String encrypt(String plaintext, String secretName)
byte[] plaintextKey = Base64.getDecoder().decode(secretString);

try {
Cipher cipher = getCipher(Cipher.ENCRYPT_MODE, plaintextKey);
Cipher cipher = getCipher(Cipher.ENCRYPT_MODE, plaintextKey, getGCMSpecEncryption());
byte[] encryptedContent = cipher.doFinal(plaintext.getBytes());
byte[] encodedContent = Base64.getEncoder().encode(encryptedContent);
// prepend ciphertext with IV
ByteBuffer byteBuffer = ByteBuffer.allocate(GCM_IV_LENGTH + encryptedContent.length);
byteBuffer.put(cipher.getIV());
byteBuffer.put(encryptedContent);

byte[] encodedContent = Base64.getEncoder().encode(byteBuffer.array());
return new String(encodedContent);
}
catch (IllegalBlockSizeException | BadPaddingException e) {
throw new RuntimeException(e);
}
}

private Cipher getCipher(int cipherMode, byte[] plainTextDataKey)
private static GCMParameterSpec getGCMSpecDecryption(byte[] encryptedText)
{
return new GCMParameterSpec(GCM_TAG_LENGTH * Byte.SIZE, encryptedText, 0, GCM_IV_LENGTH);
}

static GCMParameterSpec getGCMSpecEncryption()
{
byte[] iv = new byte[GCM_IV_LENGTH];
SecureRandom random = new SecureRandom();
random.nextBytes(iv);

return new GCMParameterSpec(GCM_TAG_LENGTH * Byte.SIZE, iv);
}

static Cipher getCipher(int cipherMode, byte[] plainTextDataKey, GCMParameterSpec gcmParameterSpec)
{
try {
Cipher cipher = Cipher.getInstance("AES");
Cipher cipher = Cipher.getInstance("AES_256/GCM/NoPadding");
SecretKeySpec skeySpec = new SecretKeySpec(plainTextDataKey, "AES");
cipher.init(cipherMode, skeySpec);
cipher.init(cipherMode, skeySpec, gcmParameterSpec);
return cipher;
}
catch (NoSuchPaddingException | NoSuchAlgorithmException | InvalidKeyException e) {
catch (NoSuchPaddingException | NoSuchAlgorithmException | InvalidKeyException | InvalidAlgorithmParameterException e) {
throw new RuntimeException(e);
}
}
Expand Down
Loading

0 comments on commit 911b8eb

Please sign in to comment.