Skip to content

Commit 1dc83a8

Browse files
committed
Update kas_client.py & tdf.py, expand tests
1 parent 01b8169 commit 1dc83a8

File tree

5 files changed

+380
-69
lines changed

5 files changed

+380
-69
lines changed

src/otdf_python/kas_client.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -222,18 +222,37 @@ def _create_signed_request_jwt(self, policy_json, client_public_key, key_access)
222222

223223
# The server expects a JWT with a requestBody field containing the UnsignedRewrapRequest
224224
# Create the request body that matches UnsignedRewrapRequest protobuf structure
225-
# For legacy v1 SRT format, the policy must be base64-encoded
225+
# Use the v2 format with explicit policy ID and requests array for cross-tool compatibility
226+
227+
# Use "policy" as policy ID for compatibility with otdfctl
228+
import json
229+
230+
policy_uuid = "policy" # otdfctl uses "policy" as the policy ID
231+
232+
# For v2 format, the policy body must be base64-encoded
226233
policy_base64 = base64.b64encode(policy_json.encode("utf-8")).decode("utf-8")
227234

228235
unsigned_rewrap_request = {
229236
"clientPublicKey": client_public_key, # Maps to client_public_key
230-
"policy": policy_base64, # Maps to policy (legacy field) - base64-encoded
231-
"keyAccess": key_access_dict, # Maps to key_access (legacy field)
237+
"requests": [
238+
{ # Maps to requests array (v2 format)
239+
"keyAccessObjects": [
240+
{
241+
"keyAccessObjectId": "kao-0", # Standard KAO ID
242+
"keyAccessObject": key_access_dict,
243+
}
244+
],
245+
"policy": {
246+
"id": policy_uuid, # Use the UUID from policy as the policy ID
247+
"body": policy_base64, # Base64-encoded policy JSON
248+
},
249+
}
250+
],
251+
"keyAccess": key_access_dict, # Keep legacy field for backward compatibility
252+
"policy": policy_base64, # Keep legacy field for backward compatibility
232253
}
233254

234255
# Convert to JSON string
235-
import json
236-
237256
request_body_json = json.dumps(unsigned_rewrap_request)
238257

239258
# JWT payload with requestBody field containing the JSON string

src/otdf_python/tdf.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import io
33
import os
44
import hashlib
5+
import hmac
56
import base64
67
import zipfile
78
from otdf_python.manifest import (
@@ -334,7 +335,7 @@ def create_tdf(
334335
segment_size = (
335336
getattr(config, "default_segment_size", None) or self.SEGMENT_SIZE
336337
)
337-
hasher = hashlib.sha256()
338+
segment_hashes_raw = []
338339
total = 0
339340
# Write encrypted payload in segments
340341
with writer.payload() as f:
@@ -346,9 +347,14 @@ def create_tdf(
346347
break
347348
encrypted = aesgcm.encrypt(chunk)
348349
f.write(encrypted.as_bytes())
349-
seg_hash = base64.b64encode(
350-
hashlib.sha256(encrypted.as_bytes()).digest()
351-
).decode()
350+
# Calculate segment hash using GMAC (last 16 bytes of encrypted segment)
351+
# This matches the platform SDK when segmentHashAlg is "GMAC"
352+
encrypted_bytes = encrypted.as_bytes()
353+
gmac_length = 16 # kGMACPayloadLength from platform SDK
354+
if len(encrypted_bytes) < gmac_length:
355+
raise ValueError("Encrypted segment too short for GMAC")
356+
seg_hash_raw = encrypted_bytes[-gmac_length:] # Take last 16 bytes
357+
seg_hash = base64.b64encode(seg_hash_raw).decode()
352358
segments.append(
353359
ManifestSegment(
354360
hash=seg_hash,
@@ -360,14 +366,19 @@ def create_tdf(
360366
), # Changed from encrypted_segment_size to encryptedSegmentSize
361367
)
362368
)
363-
hasher.update(encrypted.as_bytes())
369+
# Collect raw segment hash bytes for root signature calculation
370+
segment_hashes_raw.append(seg_hash_raw)
364371
total += len(chunk)
365372
# Use config fields for policy
366373
policy_json = self._build_policy_json(config)
367374
# Encode policy as base64 to match Java SDK
368375
policy_b64 = base64.b64encode(policy_json.encode()).decode()
369376

370-
root_sig = base64.b64encode(hasher.digest()).decode()
377+
# Calculate root signature: HMAC-SHA256 over concatenated segment hash raw bytes
378+
# This matches the platform SDK approach
379+
aggregate_hash = b"".join(segment_hashes_raw)
380+
root_sig_raw = hmac.new(key, aggregate_hash, hashlib.sha256).digest()
381+
root_sig = base64.b64encode(root_sig_raw).decode()
371382
integrity_info = ManifestIntegrityInformation(
372383
rootSignature=ManifestRootSignature(
373384
alg="HS256", sig=root_sig
@@ -454,7 +465,6 @@ def read_payload(
454465
from otdf_python.aesgcm import AesGcm
455466
from otdf_python.asym_crypto import AsymDecryption
456467
import base64
457-
import hashlib
458468

459469
with zipfile.ZipFile(io.BytesIO(tdf_bytes), "r") as z:
460470
manifest_json = z.read("0.manifest.json").decode()
@@ -485,8 +495,15 @@ def read_payload(
485495
for seg in segments:
486496
enc_len = seg.encryptedSegmentSize # Changed field name
487497
enc_bytes = encrypted_payload[offset : offset + enc_len]
488-
# Integrity check (SHA256 HMAC)
489-
seg_hash = base64.b64encode(hashlib.sha256(enc_bytes).digest()).decode()
498+
# Integrity check using GMAC (last 16 bytes of encrypted segment)
499+
# This matches how segments are hashed when segmentHashAlg is "GMAC"
500+
gmac_length = 16 # kGMACPayloadLength from platform SDK
501+
if len(enc_bytes) < gmac_length:
502+
raise ValueError(
503+
"Encrypted segment too short for GMAC verification"
504+
)
505+
seg_hash_raw = enc_bytes[-gmac_length:] # Take last 16 bytes
506+
seg_hash = base64.b64encode(seg_hash_raw).decode()
490507
if seg.hash != seg_hash:
491508
raise ValueError("Segment signature mismatch")
492509
iv = enc_bytes[: AesGcm.GCM_NONCE_LENGTH]

0 commit comments

Comments
 (0)