Skip to content

Commit 61fc100

Browse files
committed
Add regression tests for default KAS key merging and segment integrity hashing
Signed-off-by: Paul Flynn <pflynn@virtru.com>
1 parent abdeb69 commit 61fc100

File tree

2 files changed

+356
-0
lines changed

2 files changed

+356
-0
lines changed

sdk/experimental/tdf/keysplit/xor_splitter_test.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,3 +526,73 @@ func TestXORSplitter_ComplexScenarios(t *testing.T) {
526526
assert.True(t, found, "Should find split with multiple KAS URLs")
527527
})
528528
}
529+
530+
// TestXORSplitter_DefaultKASMergedForURIOnlyGrant is a regression test
531+
// ensuring that when an attribute grant references a KAS URL without
532+
// embedding the public key (URI-only legacy grant), the default KAS's
533+
// full public key info is merged into the result. Without the merge fix
534+
// in GenerateSplits, collectAllPublicKeys returns an incomplete map and
535+
// key wrapping fails.
536+
func TestXORSplitter_DefaultKASMergedForURIOnlyGrant(t *testing.T) {
537+
defaultKAS := &policy.SimpleKasKey{
538+
KasUri: kasUs,
539+
PublicKey: &policy.SimpleKasPublicKey{
540+
Algorithm: policy.Algorithm_ALGORITHM_RSA_2048,
541+
Kid: "default-kid",
542+
Pem: mockRSAPublicKey1,
543+
},
544+
}
545+
splitter := NewXORSplitter(WithDefaultKAS(defaultKAS))
546+
547+
dek := make([]byte, 32)
548+
_, err := rand.Read(dek)
549+
require.NoError(t, err)
550+
551+
// Create an attribute whose grant references kasUs by URI only (no KasKeys).
552+
attr := createMockValue("https://test.com/attr/level/value/secret", "", "", policy.AttributeRuleTypeEnum_ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF)
553+
attr.Grants = []*policy.KeyAccessServer{
554+
{Uri: kasUs}, // URI-only, no embedded public key
555+
}
556+
557+
result, err := splitter.GenerateSplits(t.Context(), []*policy.Value{attr}, dek)
558+
require.NoError(t, err)
559+
require.NotNil(t, result)
560+
561+
// The default KAS public key must be merged into the result.
562+
require.Contains(t, result.KASPublicKeys, kasUs, "default KAS key should be merged for URI-only grant")
563+
pubKey := result.KASPublicKeys[kasUs]
564+
assert.Equal(t, "default-kid", pubKey.KID)
565+
assert.Equal(t, mockRSAPublicKey1, pubKey.PEM)
566+
assert.Equal(t, "rsa:2048", pubKey.Algorithm)
567+
}
568+
569+
// TestXORSplitter_DefaultKASDoesNotOverwriteExistingKey verifies that when
570+
// an attribute grant already embeds a full public key for the same KAS URL
571+
// as the default, the grant's key is preserved and not overwritten.
572+
func TestXORSplitter_DefaultKASDoesNotOverwriteExistingKey(t *testing.T) {
573+
defaultKAS := &policy.SimpleKasKey{
574+
KasUri: kasUs,
575+
PublicKey: &policy.SimpleKasPublicKey{
576+
Algorithm: policy.Algorithm_ALGORITHM_RSA_2048,
577+
Kid: "default-kid",
578+
Pem: mockRSAPublicKey1,
579+
},
580+
}
581+
splitter := NewXORSplitter(WithDefaultKAS(defaultKAS))
582+
583+
dek := make([]byte, 32)
584+
_, err := rand.Read(dek)
585+
require.NoError(t, err)
586+
587+
// Create an attribute with a fully-embedded grant for the same KAS URL
588+
// but with a different KID.
589+
attr := createMockValue("https://test.com/attr/level/value/secret", kasUs, "grant-kid", policy.AttributeRuleTypeEnum_ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF)
590+
591+
result, err := splitter.GenerateSplits(t.Context(), []*policy.Value{attr}, dek)
592+
require.NoError(t, err)
593+
require.NotNil(t, result)
594+
595+
require.Contains(t, result.KASPublicKeys, kasUs)
596+
pubKey := result.KASPublicKeys[kasUs]
597+
assert.Equal(t, "grant-kid", pubKey.KID, "grant's key should not be overwritten by default KAS")
598+
}

sdk/experimental/tdf/writer_test.go

Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,12 @@ package tdf
44

55
import (
66
"bytes"
7+
"crypto/hmac"
78
"crypto/rand"
9+
"crypto/sha256"
10+
"encoding/base64"
811
"encoding/json"
12+
"io"
913
"os"
1014
"path/filepath"
1115
"runtime"
@@ -14,6 +18,7 @@ import (
1418

1519
"github.com/opentdf/platform/lib/ocrypto"
1620
"github.com/opentdf/platform/protocol/go/policy"
21+
"github.com/opentdf/platform/sdk/internal/zipstream"
1722
"github.com/stretchr/testify/assert"
1823
"github.com/stretchr/testify/require"
1924
"github.com/xeipuuv/gojsonschema"
@@ -58,6 +63,8 @@ func TestWriterEndToEnd(t *testing.T) {
5863
{"GetManifestIncludesInitialPolicy", testGetManifestIncludesInitialPolicy},
5964
{"SparseIndicesInOrder", testSparseIndicesInOrder},
6065
{"SparseIndicesOutOfOrder", testSparseIndicesOutOfOrder},
66+
{"SegmentHashCoversNonceAndCipher", testSegmentHashCoversNonceAndCipher},
67+
{"FinalizeWithURIOnlyGrant", testFinalizeWithURIOnlyGrant},
6168
}
6269

6370
for _, tc := range testCases {
@@ -188,6 +195,95 @@ func testSparseIndicesOutOfOrder(t *testing.T) {
188195
assert.Equal(t, int64(expectedPlain), fin.TotalSize)
189196
}
190197

198+
// testSegmentHashCoversNonceAndCipher is a regression test ensuring that the
199+
// HS256 segment hash covers nonce+ciphertext, not ciphertext alone.
200+
//
201+
// The standard SDK's Encrypt() returns nonce prepended to ciphertext and
202+
// hashes that combined blob; the experimental SDK's EncryptInPlace() returns
203+
// them separately, so the writer must concatenate before hashing.
204+
//
205+
// Only HS256 is tested because GMAC extracts the last 16 bytes of data as
206+
// the tag — stripping the nonce prefix doesn't change the tail, so GMAC is
207+
// structurally unable to detect a nonce-exclusion regression.
208+
func testSegmentHashCoversNonceAndCipher(t *testing.T) {
209+
ctx := t.Context()
210+
211+
writer, err := NewWriter(ctx, WithSegmentIntegrityAlgorithm(HS256))
212+
require.NoError(t, err)
213+
214+
testData := []byte("segment hash regression test payload")
215+
result, err := writer.WriteSegment(ctx, 0, testData)
216+
require.NoError(t, err)
217+
218+
// Read all bytes from the TDFData reader to get the full segment output.
219+
allBytes, err := io.ReadAll(result.TDFData)
220+
require.NoError(t, err)
221+
222+
// The last EncryptedSize bytes are the encrypted segment (nonce + cipher).
223+
// Everything before that is the ZIP local file header.
224+
encryptedData := allBytes[int64(len(allBytes))-result.EncryptedSize:]
225+
226+
// Positive assertion: independently compute HMAC-SHA256 over nonce+cipher
227+
// using crypto/hmac directly (not the production calculateSignature path)
228+
// and verify it matches the stored hash.
229+
mac := hmac.New(sha256.New, writer.dek)
230+
mac.Write(encryptedData)
231+
expectedHash := base64.StdEncoding.EncodeToString(mac.Sum(nil))
232+
assert.Equal(t, expectedHash, result.Hash, "hash should equal independent HMAC-SHA256 over nonce+ciphertext")
233+
234+
// Negative / regression assertion: independently compute HMAC-SHA256 over
235+
// cipher-only (stripping the 12-byte GCM nonce). If someone reverts the
236+
// fix so only cipher is hashed, the stored hash would match this value.
237+
cipherOnly := encryptedData[ocrypto.GcmStandardNonceSize:]
238+
wrongMac := hmac.New(sha256.New, writer.dek)
239+
wrongMac.Write(cipherOnly)
240+
wrongHash := base64.StdEncoding.EncodeToString(wrongMac.Sum(nil))
241+
assert.NotEqual(t, wrongHash, result.Hash, "hash must NOT match cipher-only (nonce must be included)")
242+
}
243+
244+
// testFinalizeWithURIOnlyGrant is an end-to-end regression test ensuring
245+
// that Finalize succeeds when attribute grants reference a KAS URL without
246+
// embedding the public key (URI-only legacy grants). The default KAS must
247+
// supply the missing key information. Without the merge fix in
248+
// GenerateSplits, key wrapping fails with "no valid key access objects".
249+
func testFinalizeWithURIOnlyGrant(t *testing.T) {
250+
ctx := t.Context()
251+
252+
defaultKAS := &policy.SimpleKasKey{
253+
KasUri: testKAS1,
254+
PublicKey: &policy.SimpleKasPublicKey{
255+
Algorithm: policy.Algorithm_ALGORITHM_RSA_2048,
256+
Kid: "default-kid",
257+
Pem: mockRSAPublicKey1,
258+
},
259+
}
260+
261+
writer, err := NewWriter(ctx, WithDefaultKASForWriter(defaultKAS))
262+
require.NoError(t, err)
263+
264+
_, err = writer.WriteSegment(ctx, 0, []byte("uri-only grant test"))
265+
require.NoError(t, err)
266+
267+
// Create attribute with a URI-only grant (no KasKeys / no embedded public key).
268+
uriOnlyAttr := createTestAttributeWithRule(
269+
"https://example.com/attr/Level/value/Secret",
270+
"", "", // no KAS URL → no grants added by helper
271+
policy.AttributeRuleTypeEnum_ATTRIBUTE_RULE_TYPE_ENUM_ALL_OF,
272+
)
273+
uriOnlyAttr.Grants = []*policy.KeyAccessServer{
274+
{Uri: testKAS1}, // URI-only, no KasKeys
275+
}
276+
277+
fin, err := writer.Finalize(ctx, WithAttributeValues([]*policy.Value{uriOnlyAttr}))
278+
require.NoError(t, err, "Finalize must succeed when default KAS fills in missing key for URI-only grant")
279+
require.NotNil(t, fin.Manifest)
280+
281+
// Verify the key access object references the right KAS
282+
require.GreaterOrEqual(t, len(fin.Manifest.KeyAccessObjs), 1)
283+
assert.Equal(t, testKAS1, fin.Manifest.KeyAccessObjs[0].KasURL)
284+
assert.NotEmpty(t, fin.Manifest.KeyAccessObjs[0].WrappedKey)
285+
}
286+
191287
// testInitialAttributesOnWriter verifies that attributes/KAS supplied at
192288
// NewWriter are used by Finalize when not overridden, and that Finalize
193289
// overrides take precedence.
@@ -996,6 +1092,196 @@ func BenchmarkTDFCreation(b *testing.B) {
9961092
})
9971093
}
9981094

1095+
// TestCrossDecryptWithSharedDEK verifies that the experimental writer's
1096+
// encryption format is compatible with the production SDK by injecting a
1097+
// shared DEK into the experimental writer and cross-validating with the
1098+
// same crypto primitives the production reader uses:
1099+
//
1100+
// - ocrypto.AesGcm.Encrypt() (production encrypt: returns nonce||ciphertext)
1101+
// - ocrypto.AesGcm.Decrypt() (production decrypt: expects nonce||ciphertext)
1102+
// - HMAC-SHA256(dek, nonce||ciphertext) (production segment hash verification)
1103+
//
1104+
// The test also assembles a complete TDF ZIP from the experimental writer
1105+
// and parses it with zipstream.TDFReader (the same reader the production
1106+
// SDK uses internally) to verify structural compatibility.
1107+
func TestCrossDecryptWithSharedDEK(t *testing.T) {
1108+
ctx := t.Context()
1109+
1110+
sharedDEK, err := ocrypto.RandomBytes(kKeySize)
1111+
require.NoError(t, err)
1112+
1113+
t.Run("SingleSegment", func(t *testing.T) {
1114+
original := []byte("Cross-SDK format compatibility: single segment")
1115+
1116+
sharedCipher, err := ocrypto.NewAESGcm(sharedDEK)
1117+
require.NoError(t, err)
1118+
1119+
// --- Experimental writer with injected DEK ---
1120+
writer, err := NewWriter(ctx, WithSegmentIntegrityAlgorithm(HS256))
1121+
require.NoError(t, err)
1122+
writer.dek = sharedDEK
1123+
writer.block, err = ocrypto.NewAESGcm(sharedDEK)
1124+
require.NoError(t, err)
1125+
1126+
expInput := append([]byte(nil), original...)
1127+
expResult, err := writer.WriteSegment(ctx, 0, expInput)
1128+
require.NoError(t, err)
1129+
1130+
allBytes, err := io.ReadAll(expResult.TDFData)
1131+
require.NoError(t, err)
1132+
expEncrypted := allBytes[int64(len(allBytes))-expResult.EncryptedSize:]
1133+
1134+
// --- Production-style encrypt with the same DEK ---
1135+
prodEncrypted, err := sharedCipher.Encrypt(original)
1136+
require.NoError(t, err)
1137+
1138+
// --- Cross-decrypt: production Decrypt() on experimental output ---
1139+
decryptedFromExp, err := sharedCipher.Decrypt(expEncrypted)
1140+
require.NoError(t, err, "production Decrypt must handle experimental output")
1141+
assert.Equal(t, decryptedFromExp, original)
1142+
1143+
// --- Cross-decrypt: Decrypt() on production output ---
1144+
decryptedFromProd, err := sharedCipher.Decrypt(prodEncrypted)
1145+
require.NoError(t, err)
1146+
assert.Equal(t, original, decryptedFromProd)
1147+
1148+
// --- Hash cross-verification ---
1149+
// The production reader computes HMAC-SHA256(payloadKey, encryptedSegment)
1150+
// and compares it against the manifest segment hash. Verify the
1151+
// experimental writer's stored hash matches this computation.
1152+
mac := hmac.New(sha256.New, sharedDEK)
1153+
mac.Write(expEncrypted)
1154+
independentHash := base64.StdEncoding.EncodeToString(mac.Sum(nil))
1155+
assert.Equal(t, expResult.Hash, independentHash,
1156+
"experimental hash must equal production-style HMAC-SHA256")
1157+
1158+
// Verify production-encrypted data also hashes correctly
1159+
prodMac := hmac.New(sha256.New, sharedDEK)
1160+
prodMac.Write(prodEncrypted)
1161+
prodHash := base64.StdEncoding.EncodeToString(prodMac.Sum(nil))
1162+
assert.NotEmpty(t, prodHash)
1163+
// Both hashes are valid HMACs but differ because nonces are random
1164+
assert.NotEqual(t, independentHash, prodHash)
1165+
})
1166+
1167+
t.Run("MultiSegment", func(t *testing.T) {
1168+
sharedCipher, err := ocrypto.NewAESGcm(sharedDEK)
1169+
require.NoError(t, err)
1170+
1171+
writer, err := NewWriter(ctx, WithSegmentIntegrityAlgorithm(HS256))
1172+
require.NoError(t, err)
1173+
writer.dek = sharedDEK
1174+
writer.block, err = ocrypto.NewAESGcm(sharedDEK)
1175+
require.NoError(t, err)
1176+
1177+
segments := [][]byte{
1178+
[]byte("segment zero"),
1179+
[]byte("segment one with longer content for variety"),
1180+
[]byte("s2"),
1181+
}
1182+
1183+
for i, original := range segments {
1184+
input := append([]byte(nil), original...)
1185+
result, err := writer.WriteSegment(ctx, i, input)
1186+
require.NoError(t, err)
1187+
1188+
raw, err := io.ReadAll(result.TDFData)
1189+
require.NoError(t, err)
1190+
encrypted := raw[int64(len(raw))-result.EncryptedSize:]
1191+
1192+
// Cross-decrypt each segment with production-style Decrypt
1193+
decrypted, err := sharedCipher.Decrypt(encrypted)
1194+
require.NoError(t, err, "segment %d cross-decrypt", i)
1195+
assert.Equal(t, original, decrypted, "segment %d plaintext", i)
1196+
1197+
// Verify hash matches independent HMAC
1198+
mac := hmac.New(sha256.New, sharedDEK)
1199+
mac.Write(encrypted)
1200+
assert.Equal(t,
1201+
base64.StdEncoding.EncodeToString(mac.Sum(nil)),
1202+
result.Hash, "segment %d hash", i)
1203+
}
1204+
})
1205+
1206+
t.Run("FullTDFAssembly", func(t *testing.T) {
1207+
// Assemble a complete TDF ZIP from the experimental writer and
1208+
// parse it with the same zipstream.TDFReader the production SDK uses.
1209+
writer, err := NewWriter(ctx, WithSegmentIntegrityAlgorithm(HS256))
1210+
require.NoError(t, err)
1211+
writer.dek = sharedDEK
1212+
writer.block, err = ocrypto.NewAESGcm(sharedDEK)
1213+
require.NoError(t, err)
1214+
1215+
plainSegments := [][]byte{
1216+
[]byte("first segment payload"),
1217+
[]byte("second segment payload - a bit longer"),
1218+
}
1219+
sharedCipher, err := ocrypto.NewAESGcm(sharedDEK)
1220+
require.NoError(t, err)
1221+
1222+
// Collect segment TDFData (ZIP local headers + encrypted data)
1223+
var tdfBuf bytes.Buffer
1224+
for i, original := range plainSegments {
1225+
input := append([]byte(nil), original...)
1226+
result, err := writer.WriteSegment(ctx, i, input)
1227+
require.NoError(t, err)
1228+
_, err = io.Copy(&tdfBuf, result.TDFData)
1229+
require.NoError(t, err)
1230+
}
1231+
1232+
// Finalize (adds central directory + manifest entry)
1233+
attrs := []*policy.Value{
1234+
createTestAttribute("https://example.com/attr/Cross/value/Test", testKAS1, "kid1"),
1235+
}
1236+
fin, err := writer.Finalize(ctx, WithAttributeValues(attrs))
1237+
require.NoError(t, err)
1238+
tdfBuf.Write(fin.Data)
1239+
1240+
// Parse with zipstream.TDFReader — the production SDK's ZIP parser
1241+
tdfReader, err := zipstream.NewTDFReader(bytes.NewReader(tdfBuf.Bytes()))
1242+
require.NoError(t, err, "production TDFReader must parse experimental TDF ZIP")
1243+
1244+
// Verify manifest is valid JSON with expected fields
1245+
manifestJSON, err := tdfReader.Manifest()
1246+
require.NoError(t, err)
1247+
assert.Contains(t, manifestJSON, `"algorithm":"AES-256-GCM"`)
1248+
assert.Contains(t, manifestJSON, `"isStreamable":true`)
1249+
1250+
var manifest Manifest
1251+
require.NoError(t, json.Unmarshal([]byte(manifestJSON), &manifest))
1252+
require.Len(t, manifest.Segments, len(plainSegments))
1253+
assert.Equal(t, "HS256", manifest.SegmentHashAlgorithm)
1254+
assert.NotEmpty(t, manifest.Signature, "root signature must be present")
1255+
1256+
// Verify payload is readable and each segment decrypts correctly
1257+
payloadSize, err := tdfReader.PayloadSize()
1258+
require.NoError(t, err)
1259+
1260+
var offset int64
1261+
for i, seg := range manifest.Segments {
1262+
require.LessOrEqual(t, offset+seg.EncryptedSize, payloadSize,
1263+
"segment %d exceeds payload bounds", i)
1264+
1265+
readBuf, err := tdfReader.ReadPayload(offset, seg.EncryptedSize)
1266+
require.NoError(t, err, "segment %d ReadPayload", i)
1267+
1268+
// This is exactly what the production reader does:
1269+
// 1. Verify segment hash
1270+
mac := hmac.New(sha256.New, sharedDEK)
1271+
mac.Write(readBuf)
1272+
computedHash := base64.StdEncoding.EncodeToString(mac.Sum(nil))
1273+
assert.Equal(t, seg.Hash, computedHash, "segment %d hash verification", i)
1274+
1275+
// 2. Decrypt
1276+
decrypted, err := sharedCipher.Decrypt(readBuf)
1277+
require.NoError(t, err, "segment %d decrypt", i)
1278+
assert.Equal(t, plainSegments[i], decrypted, "segment %d plaintext", i)
1279+
1280+
offset += seg.EncryptedSize
1281+
}
1282+
})
1283+
}
1284+
9991285
// testGetManifestBeforeAndAfterFinalize verifies GetManifest returns a stub
10001286
// before finalization and the final manifest after finalization.
10011287
func testGetManifestBeforeAndAfterFinalize(t *testing.T) {

0 commit comments

Comments
 (0)