@@ -4,8 +4,12 @@ package tdf
44
55import (
66 "bytes"
7+ "crypto/hmac"
78 "crypto/rand"
9+ "crypto/sha256"
10+ "encoding/base64"
811 "encoding/json"
12+ "io"
913 "os"
1014 "path/filepath"
1115 "runtime"
@@ -14,6 +18,7 @@ import (
1418
1519 "github.com/opentdf/platform/lib/ocrypto"
1620 "github.com/opentdf/platform/protocol/go/policy"
21+ "github.com/opentdf/platform/sdk/internal/zipstream"
1722 "github.com/stretchr/testify/assert"
1823 "github.com/stretchr/testify/require"
1924 "github.com/xeipuuv/gojsonschema"
@@ -58,6 +63,8 @@ func TestWriterEndToEnd(t *testing.T) {
5863 {"GetManifestIncludesInitialPolicy" , testGetManifestIncludesInitialPolicy },
5964 {"SparseIndicesInOrder" , testSparseIndicesInOrder },
6065 {"SparseIndicesOutOfOrder" , testSparseIndicesOutOfOrder },
66+ {"SegmentHashCoversNonceAndCipher" , testSegmentHashCoversNonceAndCipher },
67+ {"FinalizeWithURIOnlyGrant" , testFinalizeWithURIOnlyGrant },
6168 }
6269
6370 for _ , tc := range testCases {
@@ -188,6 +195,95 @@ func testSparseIndicesOutOfOrder(t *testing.T) {
188195 assert .Equal (t , int64 (expectedPlain ), fin .TotalSize )
189196}
190197
198+ // testSegmentHashCoversNonceAndCipher is a regression test ensuring that the
199+ // HS256 segment hash covers nonce+ciphertext, not ciphertext alone.
200+ //
201+ // The standard SDK's Encrypt() returns nonce prepended to ciphertext and
202+ // hashes that combined blob; the experimental SDK's EncryptInPlace() returns
203+ // them separately, so the writer must concatenate before hashing.
204+ //
205+ // Only HS256 is tested because GMAC extracts the last 16 bytes of data as
206+ // the tag — stripping the nonce prefix doesn't change the tail, so GMAC is
207+ // structurally unable to detect a nonce-exclusion regression.
208+ func testSegmentHashCoversNonceAndCipher (t * testing.T ) {
209+ ctx := t .Context ()
210+
211+ writer , err := NewWriter (ctx , WithSegmentIntegrityAlgorithm (HS256 ))
212+ require .NoError (t , err )
213+
214+ testData := []byte ("segment hash regression test payload" )
215+ result , err := writer .WriteSegment (ctx , 0 , testData )
216+ require .NoError (t , err )
217+
218+ // Read all bytes from the TDFData reader to get the full segment output.
219+ allBytes , err := io .ReadAll (result .TDFData )
220+ require .NoError (t , err )
221+
222+ // The last EncryptedSize bytes are the encrypted segment (nonce + cipher).
223+ // Everything before that is the ZIP local file header.
224+ encryptedData := allBytes [int64 (len (allBytes ))- result .EncryptedSize :]
225+
226+ // Positive assertion: independently compute HMAC-SHA256 over nonce+cipher
227+ // using crypto/hmac directly (not the production calculateSignature path)
228+ // and verify it matches the stored hash.
229+ mac := hmac .New (sha256 .New , writer .dek )
230+ mac .Write (encryptedData )
231+ expectedHash := base64 .StdEncoding .EncodeToString (mac .Sum (nil ))
232+ assert .Equal (t , expectedHash , result .Hash , "hash should equal independent HMAC-SHA256 over nonce+ciphertext" )
233+
234+ // Negative / regression assertion: independently compute HMAC-SHA256 over
235+ // cipher-only (stripping the 12-byte GCM nonce). If someone reverts the
236+ // fix so only cipher is hashed, the stored hash would match this value.
237+ cipherOnly := encryptedData [ocrypto .GcmStandardNonceSize :]
238+ wrongMac := hmac .New (sha256 .New , writer .dek )
239+ wrongMac .Write (cipherOnly )
240+ wrongHash := base64 .StdEncoding .EncodeToString (wrongMac .Sum (nil ))
241+ assert .NotEqual (t , wrongHash , result .Hash , "hash must NOT match cipher-only (nonce must be included)" )
242+ }
243+
244+ // testFinalizeWithURIOnlyGrant is an end-to-end regression test ensuring
245+ // that Finalize succeeds when attribute grants reference a KAS URL without
246+ // embedding the public key (URI-only legacy grants). The default KAS must
247+ // supply the missing key information. Without the merge fix in
248+ // GenerateSplits, key wrapping fails with "no valid key access objects".
249+ func testFinalizeWithURIOnlyGrant (t * testing.T ) {
250+ ctx := t .Context ()
251+
252+ defaultKAS := & policy.SimpleKasKey {
253+ KasUri : testKAS1 ,
254+ PublicKey : & policy.SimpleKasPublicKey {
255+ Algorithm : policy .Algorithm_ALGORITHM_RSA_2048 ,
256+ Kid : "default-kid" ,
257+ Pem : mockRSAPublicKey1 ,
258+ },
259+ }
260+
261+ writer , err := NewWriter (ctx , WithDefaultKASForWriter (defaultKAS ))
262+ require .NoError (t , err )
263+
264+ _ , err = writer .WriteSegment (ctx , 0 , []byte ("uri-only grant test" ))
265+ require .NoError (t , err )
266+
267+ // Create attribute with a URI-only grant (no KasKeys / no embedded public key).
268+ uriOnlyAttr := createTestAttributeWithRule (
269+ "https://example.com/attr/Level/value/Secret" ,
270+ "" , "" , // no KAS URL → no grants added by helper
271+ policy .AttributeRuleTypeEnum_ATTRIBUTE_RULE_TYPE_ENUM_ALL_OF ,
272+ )
273+ uriOnlyAttr .Grants = []* policy.KeyAccessServer {
274+ {Uri : testKAS1 }, // URI-only, no KasKeys
275+ }
276+
277+ fin , err := writer .Finalize (ctx , WithAttributeValues ([]* policy.Value {uriOnlyAttr }))
278+ require .NoError (t , err , "Finalize must succeed when default KAS fills in missing key for URI-only grant" )
279+ require .NotNil (t , fin .Manifest )
280+
281+ // Verify the key access object references the right KAS
282+ require .GreaterOrEqual (t , len (fin .Manifest .KeyAccessObjs ), 1 )
283+ assert .Equal (t , testKAS1 , fin .Manifest .KeyAccessObjs [0 ].KasURL )
284+ assert .NotEmpty (t , fin .Manifest .KeyAccessObjs [0 ].WrappedKey )
285+ }
286+
191287// testInitialAttributesOnWriter verifies that attributes/KAS supplied at
192288// NewWriter are used by Finalize when not overridden, and that Finalize
193289// overrides take precedence.
@@ -996,6 +1092,196 @@ func BenchmarkTDFCreation(b *testing.B) {
9961092 })
9971093}
9981094
1095+ // TestCrossDecryptWithSharedDEK verifies that the experimental writer's
1096+ // encryption format is compatible with the production SDK by injecting a
1097+ // shared DEK into the experimental writer and cross-validating with the
1098+ // same crypto primitives the production reader uses:
1099+ //
1100+ // - ocrypto.AesGcm.Encrypt() (production encrypt: returns nonce||ciphertext)
1101+ // - ocrypto.AesGcm.Decrypt() (production decrypt: expects nonce||ciphertext)
1102+ // - HMAC-SHA256(dek, nonce||ciphertext) (production segment hash verification)
1103+ //
1104+ // The test also assembles a complete TDF ZIP from the experimental writer
1105+ // and parses it with zipstream.TDFReader (the same reader the production
1106+ // SDK uses internally) to verify structural compatibility.
1107+ func TestCrossDecryptWithSharedDEK (t * testing.T ) {
1108+ ctx := t .Context ()
1109+
1110+ sharedDEK , err := ocrypto .RandomBytes (kKeySize )
1111+ require .NoError (t , err )
1112+
1113+ t .Run ("SingleSegment" , func (t * testing.T ) {
1114+ original := []byte ("Cross-SDK format compatibility: single segment" )
1115+
1116+ sharedCipher , err := ocrypto .NewAESGcm (sharedDEK )
1117+ require .NoError (t , err )
1118+
1119+ // --- Experimental writer with injected DEK ---
1120+ writer , err := NewWriter (ctx , WithSegmentIntegrityAlgorithm (HS256 ))
1121+ require .NoError (t , err )
1122+ writer .dek = sharedDEK
1123+ writer .block , err = ocrypto .NewAESGcm (sharedDEK )
1124+ require .NoError (t , err )
1125+
1126+ expInput := append ([]byte (nil ), original ... )
1127+ expResult , err := writer .WriteSegment (ctx , 0 , expInput )
1128+ require .NoError (t , err )
1129+
1130+ allBytes , err := io .ReadAll (expResult .TDFData )
1131+ require .NoError (t , err )
1132+ expEncrypted := allBytes [int64 (len (allBytes ))- expResult .EncryptedSize :]
1133+
1134+ // --- Production-style encrypt with the same DEK ---
1135+ prodEncrypted , err := sharedCipher .Encrypt (original )
1136+ require .NoError (t , err )
1137+
1138+ // --- Cross-decrypt: production Decrypt() on experimental output ---
1139+ decryptedFromExp , err := sharedCipher .Decrypt (expEncrypted )
1140+ require .NoError (t , err , "production Decrypt must handle experimental output" )
1141+ assert .Equal (t , decryptedFromExp , original )
1142+
1143+ // --- Cross-decrypt: Decrypt() on production output ---
1144+ decryptedFromProd , err := sharedCipher .Decrypt (prodEncrypted )
1145+ require .NoError (t , err )
1146+ assert .Equal (t , original , decryptedFromProd )
1147+
1148+ // --- Hash cross-verification ---
1149+ // The production reader computes HMAC-SHA256(payloadKey, encryptedSegment)
1150+ // and compares it against the manifest segment hash. Verify the
1151+ // experimental writer's stored hash matches this computation.
1152+ mac := hmac .New (sha256 .New , sharedDEK )
1153+ mac .Write (expEncrypted )
1154+ independentHash := base64 .StdEncoding .EncodeToString (mac .Sum (nil ))
1155+ assert .Equal (t , expResult .Hash , independentHash ,
1156+ "experimental hash must equal production-style HMAC-SHA256" )
1157+
1158+ // Verify production-encrypted data also hashes correctly
1159+ prodMac := hmac .New (sha256 .New , sharedDEK )
1160+ prodMac .Write (prodEncrypted )
1161+ prodHash := base64 .StdEncoding .EncodeToString (prodMac .Sum (nil ))
1162+ assert .NotEmpty (t , prodHash )
1163+ // Both hashes are valid HMACs but differ because nonces are random
1164+ assert .NotEqual (t , independentHash , prodHash )
1165+ })
1166+
1167+ t .Run ("MultiSegment" , func (t * testing.T ) {
1168+ sharedCipher , err := ocrypto .NewAESGcm (sharedDEK )
1169+ require .NoError (t , err )
1170+
1171+ writer , err := NewWriter (ctx , WithSegmentIntegrityAlgorithm (HS256 ))
1172+ require .NoError (t , err )
1173+ writer .dek = sharedDEK
1174+ writer .block , err = ocrypto .NewAESGcm (sharedDEK )
1175+ require .NoError (t , err )
1176+
1177+ segments := [][]byte {
1178+ []byte ("segment zero" ),
1179+ []byte ("segment one with longer content for variety" ),
1180+ []byte ("s2" ),
1181+ }
1182+
1183+ for i , original := range segments {
1184+ input := append ([]byte (nil ), original ... )
1185+ result , err := writer .WriteSegment (ctx , i , input )
1186+ require .NoError (t , err )
1187+
1188+ raw , err := io .ReadAll (result .TDFData )
1189+ require .NoError (t , err )
1190+ encrypted := raw [int64 (len (raw ))- result .EncryptedSize :]
1191+
1192+ // Cross-decrypt each segment with production-style Decrypt
1193+ decrypted , err := sharedCipher .Decrypt (encrypted )
1194+ require .NoError (t , err , "segment %d cross-decrypt" , i )
1195+ assert .Equal (t , original , decrypted , "segment %d plaintext" , i )
1196+
1197+ // Verify hash matches independent HMAC
1198+ mac := hmac .New (sha256 .New , sharedDEK )
1199+ mac .Write (encrypted )
1200+ assert .Equal (t ,
1201+ base64 .StdEncoding .EncodeToString (mac .Sum (nil )),
1202+ result .Hash , "segment %d hash" , i )
1203+ }
1204+ })
1205+
1206+ t .Run ("FullTDFAssembly" , func (t * testing.T ) {
1207+ // Assemble a complete TDF ZIP from the experimental writer and
1208+ // parse it with the same zipstream.TDFReader the production SDK uses.
1209+ writer , err := NewWriter (ctx , WithSegmentIntegrityAlgorithm (HS256 ))
1210+ require .NoError (t , err )
1211+ writer .dek = sharedDEK
1212+ writer .block , err = ocrypto .NewAESGcm (sharedDEK )
1213+ require .NoError (t , err )
1214+
1215+ plainSegments := [][]byte {
1216+ []byte ("first segment payload" ),
1217+ []byte ("second segment payload - a bit longer" ),
1218+ }
1219+ sharedCipher , err := ocrypto .NewAESGcm (sharedDEK )
1220+ require .NoError (t , err )
1221+
1222+ // Collect segment TDFData (ZIP local headers + encrypted data)
1223+ var tdfBuf bytes.Buffer
1224+ for i , original := range plainSegments {
1225+ input := append ([]byte (nil ), original ... )
1226+ result , err := writer .WriteSegment (ctx , i , input )
1227+ require .NoError (t , err )
1228+ _ , err = io .Copy (& tdfBuf , result .TDFData )
1229+ require .NoError (t , err )
1230+ }
1231+
1232+ // Finalize (adds central directory + manifest entry)
1233+ attrs := []* policy.Value {
1234+ createTestAttribute ("https://example.com/attr/Cross/value/Test" , testKAS1 , "kid1" ),
1235+ }
1236+ fin , err := writer .Finalize (ctx , WithAttributeValues (attrs ))
1237+ require .NoError (t , err )
1238+ tdfBuf .Write (fin .Data )
1239+
1240+ // Parse with zipstream.TDFReader — the production SDK's ZIP parser
1241+ tdfReader , err := zipstream .NewTDFReader (bytes .NewReader (tdfBuf .Bytes ()))
1242+ require .NoError (t , err , "production TDFReader must parse experimental TDF ZIP" )
1243+
1244+ // Verify manifest is valid JSON with expected fields
1245+ manifestJSON , err := tdfReader .Manifest ()
1246+ require .NoError (t , err )
1247+ assert .Contains (t , manifestJSON , `"algorithm":"AES-256-GCM"` )
1248+ assert .Contains (t , manifestJSON , `"isStreamable":true` )
1249+
1250+ var manifest Manifest
1251+ require .NoError (t , json .Unmarshal ([]byte (manifestJSON ), & manifest ))
1252+ require .Len (t , manifest .Segments , len (plainSegments ))
1253+ assert .Equal (t , "HS256" , manifest .SegmentHashAlgorithm )
1254+ assert .NotEmpty (t , manifest .Signature , "root signature must be present" )
1255+
1256+ // Verify payload is readable and each segment decrypts correctly
1257+ payloadSize , err := tdfReader .PayloadSize ()
1258+ require .NoError (t , err )
1259+
1260+ var offset int64
1261+ for i , seg := range manifest .Segments {
1262+ require .LessOrEqual (t , offset + seg .EncryptedSize , payloadSize ,
1263+ "segment %d exceeds payload bounds" , i )
1264+
1265+ readBuf , err := tdfReader .ReadPayload (offset , seg .EncryptedSize )
1266+ require .NoError (t , err , "segment %d ReadPayload" , i )
1267+
1268+ // This is exactly what the production reader does:
1269+ // 1. Verify segment hash
1270+ mac := hmac .New (sha256 .New , sharedDEK )
1271+ mac .Write (readBuf )
1272+ computedHash := base64 .StdEncoding .EncodeToString (mac .Sum (nil ))
1273+ assert .Equal (t , seg .Hash , computedHash , "segment %d hash verification" , i )
1274+
1275+ // 2. Decrypt
1276+ decrypted , err := sharedCipher .Decrypt (readBuf )
1277+ require .NoError (t , err , "segment %d decrypt" , i )
1278+ assert .Equal (t , plainSegments [i ], decrypted , "segment %d plaintext" , i )
1279+
1280+ offset += seg .EncryptedSize
1281+ }
1282+ })
1283+ }
1284+
9991285// testGetManifestBeforeAndAfterFinalize verifies GetManifest returns a stub
10001286// before finalization and the final manifest after finalization.
10011287func testGetManifestBeforeAndAfterFinalize (t * testing.T ) {
0 commit comments