diff --git a/src/otdf_python/cli.py b/src/otdf_python/cli.py index 863d006..529fb83 100644 --- a/src/otdf_python/cli.py +++ b/src/otdf_python/cli.py @@ -201,6 +201,13 @@ def create_nano_tdf_config(sdk: SDK, args) -> NanoTDFConfig: kas_endpoints = parse_kas_endpoints(args.kas_endpoint) kas_info_list = [KASInfo(url=kas_url) for kas_url in kas_endpoints] config.kas_info_list.extend(kas_info_list) + elif args.platform_url: + # If no explicit KAS endpoint provided, derive from platform URL + # This matches the default KAS path convention + kas_url = args.platform_url.rstrip("/") + "/kas" + logger.debug(f"Deriving KAS endpoint from platform URL: {kas_url}") + kas_info = KASInfo(url=kas_url) + config.kas_info_list.append(kas_info) if hasattr(args, "policy_binding") and args.policy_binding: if args.policy_binding.lower() == "ecdsa": @@ -554,7 +561,7 @@ def main(): sys.exit(1) except Exception as e: logger.error(f"Unexpected error: {e}") - logger.debug("", exc_info=True) + logger.error("", exc_info=True) # Always print traceback for unexpected errors sys.exit(1) diff --git a/src/otdf_python/ecc_mode.py b/src/otdf_python/ecc_mode.py index 95b4714..a8dcb5b 100644 --- a/src/otdf_python/ecc_mode.py +++ b/src/otdf_python/ecc_mode.py @@ -1,4 +1,14 @@ +from typing import ClassVar + + class ECCMode: + _CURVE_MAP: ClassVar[dict[str, int]] = { + "secp256r1": 0, + "secp384r1": 1, + "secp521r1": 2, + "secp256k1": 3, + } + def __init__(self, curve_mode: int = 0, use_ecdsa_binding: bool = False): self.curve_mode = curve_mode self.use_ecdsa_binding = use_ecdsa_binding @@ -15,18 +25,43 @@ def set_elliptic_curve(self, curve_mode: int): def get_elliptic_curve_type(self) -> int: return self.curve_mode + def get_curve_name(self) -> str: + """Get the curve name as a string (e.g., 'secp256r1').""" + for name, mode in self._CURVE_MAP.items(): + if mode == self.curve_mode: + return name + # Default to secp256r1 if not found + return "secp256r1" + @staticmethod def get_ec_compressed_pubkey_size(curve_type: int) -> int: - # 0: secp256r1, 1: secp384r1, 2: secp521r1 + # 0: secp256r1, 1: secp384r1, 2: secp521r1, 3: secp256k1 if curve_type == 0: - return 33 + return 33 # secp256r1 elif curve_type == 1: - return 49 + return 49 # secp384r1 elif curve_type == 2: - return 67 + return 67 # secp521r1 + elif curve_type == 3: + return 33 # secp256k1 (same size as secp256r1) else: raise ValueError("Unsupported ECC algorithm.") def get_ecc_mode_as_byte(self) -> int: # Most significant bit: use_ecdsa_binding, lower 3 bits: curve_mode return ((1 if self.use_ecdsa_binding else 0) << 7) | (self.curve_mode & 0x07) + + @staticmethod + def from_string(curve_str: str) -> "ECCMode": + """Create ECCMode from curve string like 'secp256r1' or 'secp384r1', or policy binding type like 'gmac' or 'ecdsa'.""" + # Handle policy binding types + if curve_str.lower() == "gmac": + return ECCMode(0, False) # GMAC binding with default secp256r1 curve + elif curve_str.lower() == "ecdsa": + return ECCMode(0, True) # ECDSA binding with default secp256r1 curve + + # Handle curve names + curve_mode = ECCMode._CURVE_MAP.get(curve_str.lower()) + if curve_mode is None: + raise ValueError(f"Unsupported curve string: '{curve_str}'") + return ECCMode(curve_mode, False) diff --git a/src/otdf_python/ecdh.py b/src/otdf_python/ecdh.py new file mode 100644 index 0000000..a6b36b3 --- /dev/null +++ b/src/otdf_python/ecdh.py @@ -0,0 +1,332 @@ +""" +ECDH (Elliptic Curve Diffie-Hellman) key exchange for NanoTDF. + +This module implements the ECDH key exchange protocol with HKDF key derivation +as specified in the NanoTDF spec. It supports the following curves: +- secp256r1 (NIST P-256) +- secp384r1 (NIST P-384) +- secp521r1 (NIST P-521) +- secp256k1 (Bitcoin curve) + +The protocol follows ECIES methodology similar to S/MIME and GPG: +1. Generate ephemeral keypair +2. Perform ECDH with recipient's public key to get shared secret +3. Use HKDF to derive symmetric encryption key from shared secret +""" + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.kdf.hkdf import HKDF +from cryptography.hazmat.primitives.serialization import ( + Encoding, + PublicFormat, +) + +# Mapping from curve names to cryptography curve objects +CURVE_MAP = { + "secp256r1": ec.SECP256R1(), + "secp384r1": ec.SECP384R1(), + "secp521r1": ec.SECP521R1(), + "secp256k1": ec.SECP256K1(), +} + +# Compressed public key sizes for each curve +COMPRESSED_KEY_SIZES = { + "secp256r1": 33, # 1 byte prefix + 32 bytes + "secp384r1": 49, # 1 byte prefix + 48 bytes + "secp521r1": 67, # 1 byte prefix + 66 bytes + "secp256k1": 33, # 1 byte prefix + 32 bytes +} + +# HKDF salt for NanoTDF key derivation +# Per spec: "salt value derived from magic number/version" +# This is the SHA-256 hash of the NanoTDF magic number and version +NANOTDF_HKDF_SALT = bytes.fromhex( + "3de3ca1e50cf62d8b6aba603a96fca6761387a7ac86c3d3afe85ae2d1812edfc" +) + + +class ECDHError(Exception): + """Base exception for ECDH operations.""" + + pass + + +class UnsupportedCurveError(ECDHError): + """Raised when an unsupported curve is specified.""" + + pass + + +class InvalidKeyError(ECDHError): + """Raised when a key is invalid or malformed.""" + + pass + + +def get_curve(curve_name: str) -> ec.EllipticCurve: + """ + Get the cryptography curve object for a given curve name. + + Args: + curve_name: Name of the curve (e.g., "secp256r1") + + Returns: + ec.EllipticCurve: The curve object + + Raises: + UnsupportedCurveError: If the curve is not supported + """ + curve_name_lower = curve_name.lower() + if curve_name_lower not in CURVE_MAP: + raise UnsupportedCurveError( + f"Unsupported curve: {curve_name}. " + f"Supported curves: {', '.join(CURVE_MAP.keys())}" + ) + return CURVE_MAP[curve_name_lower] + + +def get_compressed_key_size(curve_name: str) -> int: + """ + Get the size of a compressed public key for a given curve. + + Args: + curve_name: Name of the curve (e.g., "secp256r1") + + Returns: + int: Size in bytes of the compressed public key + + Raises: + UnsupportedCurveError: If the curve is not supported + """ + curve_name_lower = curve_name.lower() + if curve_name_lower not in COMPRESSED_KEY_SIZES: + raise UnsupportedCurveError(f"Unsupported curve: {curve_name}") + return COMPRESSED_KEY_SIZES[curve_name_lower] + + +def generate_ephemeral_keypair( + curve_name: str, +) -> tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]: + """ + Generate an ephemeral keypair for ECDH. + + Args: + curve_name: Name of the curve (e.g., "secp256r1") + + Returns: + tuple: (private_key, public_key) + + Raises: + UnsupportedCurveError: If the curve is not supported + """ + curve = get_curve(curve_name) + private_key = ec.generate_private_key(curve, default_backend()) + public_key = private_key.public_key() + return private_key, public_key + + +def compress_public_key(public_key: ec.EllipticCurvePublicKey) -> bytes: + """ + Compress an EC public key to compressed point format. + + Args: + public_key: The EC public key to compress + + Returns: + bytes: Compressed public key (33-67 bytes depending on curve) + """ + return public_key.public_bytes( + encoding=Encoding.X962, format=PublicFormat.CompressedPoint + ) + + +def decompress_public_key( + compressed_key: bytes, curve_name: str +) -> ec.EllipticCurvePublicKey: + """ + Decompress a public key from compressed point format. + + Args: + compressed_key: The compressed public key bytes + curve_name: Name of the curve (e.g., "secp256r1") + + Returns: + ec.EllipticCurvePublicKey: The decompressed public key + + Raises: + InvalidKeyError: If the key cannot be decompressed + UnsupportedCurveError: If the curve is not supported + """ + try: + curve = get_curve(curve_name) + # Verify the size matches expected compressed size + expected_size = get_compressed_key_size(curve_name) + if len(compressed_key) != expected_size: + raise InvalidKeyError( + f"Invalid compressed key size for {curve_name}: " + f"expected {expected_size} bytes, got {len(compressed_key)} bytes" + ) + + return ec.EllipticCurvePublicKey.from_encoded_point(curve, compressed_key) + except (ValueError, TypeError) as e: + raise InvalidKeyError(f"Failed to decompress public key: {e}") + + +def derive_shared_secret( + private_key: ec.EllipticCurvePrivateKey, public_key: ec.EllipticCurvePublicKey +) -> bytes: + """ + Derive a shared secret using ECDH. + + Args: + private_key: The private key (can be ephemeral or recipient's key) + public_key: The public key (recipient's or ephemeral key) + + Returns: + bytes: The raw shared secret (x-coordinate of the ECDH point) + + Raises: + ECDHError: If ECDH fails + """ + try: + shared_secret = private_key.exchange(ec.ECDH(), public_key) + return shared_secret + except Exception as e: + raise ECDHError(f"Failed to derive shared secret: {e}") + + +def derive_key_from_shared_secret( + shared_secret: bytes, + key_length: int = 32, + salt: bytes | None = None, + info: bytes = b"", +) -> bytes: + """ + Derive a symmetric encryption key from the ECDH shared secret using HKDF. + + Args: + shared_secret: The raw ECDH shared secret + key_length: Length of the derived key in bytes (default: 32 for AES-256) + salt: Optional salt for HKDF (default: NANOTDF_HKDF_SALT) + info: Optional context/application-specific info (default: empty) + + Returns: + bytes: Derived symmetric encryption key + + Raises: + ECDHError: If key derivation fails + """ + if salt is None: + salt = NANOTDF_HKDF_SALT + + try: + hkdf = HKDF( + algorithm=hashes.SHA256(), + length=key_length, + salt=salt, + info=info, + backend=default_backend(), + ) + return hkdf.derive(shared_secret) + except Exception as e: + raise ECDHError(f"Failed to derive key from shared secret: {e}") + + +def encrypt_key_with_ecdh( + recipient_public_key_pem: str, curve_name: str = "secp256r1" +) -> tuple[bytes, bytes]: + """ + High-level function: Generate ephemeral keypair and derive encryption key. + + This is used during NanoTDF encryption to derive the key that will be used + to encrypt the payload. The ephemeral public key must be stored in the + NanoTDF header so the recipient can derive the same key. + + Args: + recipient_public_key_pem: Recipient's public key in PEM format (e.g., KAS public key) + curve_name: Name of the curve to use (default: "secp256r1") + + Returns: + tuple: (derived_key, compressed_ephemeral_public_key) + - derived_key: 32-byte AES-256 key for encrypting the payload + - compressed_ephemeral_public_key: Ephemeral public key to store in header + + Raises: + ECDHError: If key derivation fails + InvalidKeyError: If recipient's public key is invalid + UnsupportedCurveError: If the curve is not supported + """ + # Load recipient's public key + try: + recipient_public_key = serialization.load_pem_public_key( + recipient_public_key_pem.encode(), backend=default_backend() + ) + if not isinstance(recipient_public_key, ec.EllipticCurvePublicKey): + raise InvalidKeyError("Recipient's public key is not an EC key") + except Exception as e: + raise InvalidKeyError(f"Failed to load recipient's public key: {e}") + + # Generate ephemeral keypair + ephemeral_private_key, ephemeral_public_key = generate_ephemeral_keypair(curve_name) + + # Derive shared secret + shared_secret = derive_shared_secret(ephemeral_private_key, recipient_public_key) + + # Derive encryption key from shared secret + derived_key = derive_key_from_shared_secret(shared_secret, key_length=32) + + # Compress ephemeral public key for storage in header + compressed_ephemeral_key = compress_public_key(ephemeral_public_key) + + return derived_key, compressed_ephemeral_key + + +def decrypt_key_with_ecdh( + recipient_private_key_pem: str, + compressed_ephemeral_public_key: bytes, + curve_name: str = "secp256r1", +) -> bytes: + """ + High-level function: Derive decryption key from ephemeral public key and recipient's private key. + + This is used during NanoTDF decryption to derive the same key that was used + to encrypt the payload. The ephemeral public key is extracted from the + NanoTDF header. + + Args: + recipient_private_key_pem: Recipient's private key in PEM format (e.g., KAS private key) + compressed_ephemeral_public_key: Ephemeral public key from NanoTDF header + curve_name: Name of the curve (default: "secp256r1") + + Returns: + bytes: 32-byte AES-256 key for decrypting the payload + + Raises: + ECDHError: If key derivation fails + InvalidKeyError: If keys are invalid + UnsupportedCurveError: If the curve is not supported + """ + # Load recipient's private key + try: + recipient_private_key = serialization.load_pem_private_key( + recipient_private_key_pem.encode(), password=None, backend=default_backend() + ) + if not isinstance(recipient_private_key, ec.EllipticCurvePrivateKey): + raise InvalidKeyError("Recipient's private key is not an EC key") + except Exception as e: + raise InvalidKeyError(f"Failed to load recipient's private key: {e}") + + # Decompress ephemeral public key + ephemeral_public_key = decompress_public_key( + compressed_ephemeral_public_key, curve_name + ) + + # Derive shared secret + shared_secret = derive_shared_secret(recipient_private_key, ephemeral_public_key) + + # Derive decryption key from shared secret + derived_key = derive_key_from_shared_secret(shared_secret, key_length=32) + + return derived_key diff --git a/src/otdf_python/header.py b/src/otdf_python/header.py index 2f6654d..c1a698d 100644 --- a/src/otdf_python/header.py +++ b/src/otdf_python/header.py @@ -11,6 +11,7 @@ def __init__(self): self.ecc_mode: ECCMode | None = None self.payload_config: SymmetricAndPayloadConfig | None = None self.policy_info: PolicyInfo | None = None + self.policy_binding: bytes | None = None self.ephemeral_key: bytes | None = None @classmethod @@ -31,6 +32,15 @@ def from_bytes(cls, buffer: bytes): buffer[offset:], ecc_mode ) offset += policy_size + + # Read policy binding (GMAC - 8 bytes fixed size) + # Note: ECDSA binding not yet supported in this implementation + GMAC_SIZE = 8 + policy_binding = buffer[offset : offset + GMAC_SIZE] + if len(policy_binding) != GMAC_SIZE: + raise ValueError("Failed to read policy binding - invalid buffer size.") + offset += GMAC_SIZE + compressed_pubkey_size = ECCMode.get_ec_compressed_pubkey_size( ecc_mode.get_elliptic_curve_type() ) @@ -42,6 +52,7 @@ def from_bytes(cls, buffer: bytes): obj.ecc_mode = ecc_mode obj.payload_config = payload_config obj.policy_info = policy_info + obj.policy_binding = policy_binding obj.ephemeral_key = ephemeral_key return obj @@ -63,6 +74,8 @@ def peek_length(buffer: bytes) -> int: buffer[offset:], ecc_mode ) offset += policy_size + # Policy binding (GMAC - 8 bytes) + offset += 8 # Ephemeral key (size depends on curve) compressed_pubkey_size = ECCMode.get_ec_compressed_pubkey_size( ecc_mode.get_elliptic_curve_type() @@ -94,6 +107,16 @@ def set_policy_info(self, policy_info: PolicyInfo): def get_policy_info(self) -> PolicyInfo | None: return self.policy_info + def set_policy_binding(self, policy_binding: bytes): + if len(policy_binding) != 8: + raise ValueError( + f"Policy binding must be exactly 8 bytes (GMAC), got {len(policy_binding)}" + ) + self.policy_binding = policy_binding + + def get_policy_binding(self) -> bytes | None: + return self.policy_binding + def set_ephemeral_key(self, ephemeral_key: bytes): if self.ecc_mode is not None: expected_size = ECCMode.get_ec_compressed_pubkey_size( @@ -112,6 +135,7 @@ def get_total_size(self) -> int: total += 1 # ECC mode total += 1 # payload config total += self.policy_info.get_total_size() if self.policy_info else 0 + total += 8 # policy binding (GMAC) total += len(self.ephemeral_key) if self.ephemeral_key else 0 return total @@ -132,6 +156,18 @@ def write_into_buffer(self, buffer: bytearray) -> int: # PolicyInfo n = self.policy_info.write_into_buffer(buffer, offset) offset += n + # Policy binding (GMAC - 8 bytes) + if self.policy_binding: + if len(self.policy_binding) != 8: + raise ValueError( + f"Policy binding must be exactly 8 bytes (GMAC), got {len(self.policy_binding)}" + ) + buffer[offset : offset + 8] = self.policy_binding + offset += 8 + else: + # Write zeros if no binding provided + buffer[offset : offset + 8] = b"\x00" * 8 + offset += 8 # Ephemeral key buffer[offset : offset + len(self.ephemeral_key)] = self.ephemeral_key offset += len(self.ephemeral_key) diff --git a/src/otdf_python/kas_client.py b/src/otdf_python/kas_client.py index 43b3c6a..3d46ec2 100644 --- a/src/otdf_python/kas_client.py +++ b/src/otdf_python/kas_client.py @@ -25,6 +25,7 @@ class KeyAccess: url: str wrapped_key: str ephemeral_public_key: str | None = None + header: bytes | None = None # For NanoTDF: entire header including ephemeral key class KASClient: @@ -139,16 +140,16 @@ def _handle_existing_scheme(self, parsed) -> str: except Exception as e: raise SDKException("error creating KAS address", e) - def _create_signed_request_jwt(self, policy_json, client_public_key, key_access): # noqa: C901 + def _get_wrapped_key_base64(self, key_access): """ - Create a signed JWT for the rewrap request. - The JWT is signed with the DPoP private key. - """ - # Handle both ManifestKeyAccess (new camelCase and old snake_case) and simple KeyAccess (for tests) - # TODO: This can probably be simplified to only camelCase + Extract and normalize the wrapped key to base64-encoded string. + + Args: + key_access: KeyAccess object - # Ensure wrappedKey is a base64-encoded string - # Note: wrappedKey from manifest is already base64-encoded + Returns: + Base64-encoded wrapped key string + """ wrapped_key = getattr(key_access, "wrappedKey", None) or getattr( key_access, "wrapped_key", None ) @@ -157,11 +158,24 @@ def _create_signed_request_jwt(self, policy_json, client_public_key, key_access) if isinstance(wrapped_key, bytes): # Only encode if it's raw bytes (shouldn't happen from manifest) - wrapped_key = base64.b64encode(wrapped_key).decode("utf-8") + return base64.b64encode(wrapped_key).decode("utf-8") elif not isinstance(wrapped_key, str): # Convert to string if it's something else - wrapped_key = str(wrapped_key) + return str(wrapped_key) # If it's already a string (from manifest), use it as-is since it's already base64-encoded + return wrapped_key + + def _build_key_access_dict(self, key_access): + """ + Build key access dictionary from KeyAccess object, handling both old and new field names. + + Args: + key_access: KeyAccess object + + Returns: + Dictionary with key access information + """ + wrapped_key = self._get_wrapped_key_base64(key_access) key_access_dict = { "url": key_access.url, @@ -172,89 +186,162 @@ def _create_signed_request_jwt(self, policy_json, client_public_key, key_access) key_type = getattr(key_access, "type", None) or getattr( key_access, "key_type", None ) - if key_type is not None: - key_access_dict["type"] = key_type - else: - key_access_dict["type"] = "wrapped" # Default type for tests + key_access_dict["type"] = key_type if key_type is not None else "wrapped" protocol = getattr(key_access, "protocol", None) - if protocol is not None: - key_access_dict["protocol"] = protocol - else: - key_access_dict["protocol"] = "kas" # Default protocol for tests + key_access_dict["protocol"] = protocol if protocol is not None else "kas" + + # Add optional fields + self._add_optional_fields(key_access_dict, key_access) - # Optional fields - handle both old and new field names, only include if they exist and are not None + return key_access_dict + + def _add_optional_fields(self, key_access_dict, key_access): + """ + Add optional fields to key access dictionary. + + Args: + key_access_dict: Dictionary to add fields to + key_access: KeyAccess object to extract fields from + """ + # Policy binding policy_binding = getattr(key_access, "policyBinding", None) or getattr( key_access, "policy_binding", None ) if policy_binding is not None: - # Policy binding hash should be kept as base64-encoded - # The server expects base64-encoded hash values in the JWT request key_access_dict["policyBinding"] = policy_binding + # Encrypted metadata encrypted_metadata = getattr(key_access, "encryptedMetadata", None) or getattr( key_access, "encrypted_metadata", None ) if encrypted_metadata is not None: key_access_dict["encryptedMetadata"] = encrypted_metadata - kid = getattr(key_access, "kid", None) - if kid is not None: - key_access_dict["kid"] = kid - - sid = getattr(key_access, "sid", None) - if sid is not None: - key_access_dict["sid"] = sid + # Simple optional fields + for field in ["kid", "sid"]: + value = getattr(key_access, field, None) + if value is not None: + key_access_dict[field] = value + # Schema version schema_version = getattr(key_access, "schemaVersion", None) or getattr( key_access, "schema_version", None ) if schema_version is not None: key_access_dict["schemaVersion"] = schema_version + # Ephemeral public key ephemeral_public_key = getattr( key_access, "ephemeralPublicKey", None ) or getattr(key_access, "ephemeral_public_key", None) if ephemeral_public_key is not None: key_access_dict["ephemeralPublicKey"] = ephemeral_public_key - # Get current timestamp in seconds since epoch (UNIX timestamp) - now = int(time.time()) + # NanoTDF header + header = getattr(key_access, "header", None) + if header is not None: + key_access_dict["header"] = base64.b64encode(header).decode("utf-8") + + def _get_algorithm_from_session_key_type(self, session_key_type): + """ + Convert session key type to algorithm string for KAS. + + Args: + session_key_type: Session key type (EC_KEY_TYPE or RSA_KEY_TYPE) - # The server expects a JWT with a requestBody field containing the UnsignedRewrapRequest - # Create the request body that matches UnsignedRewrapRequest protobuf structure - # Use the v2 format with explicit policy ID and requests array for cross-tool compatibility + Returns: + Algorithm string or None + """ + if session_key_type == EC_KEY_TYPE: + return "ec:secp256r1" # Default EC curve for NanoTDF + elif session_key_type == RSA_KEY_TYPE: + return "rsa:2048" # Default RSA key size + return None + + def _build_rewrap_request( + self, policy_json, client_public_key, key_access_dict, algorithm, has_header + ): + """ + Build the unsigned rewrap request structure. - # Use "policy" as policy ID for compatibility with otdfctl + Args: + policy_json: Policy JSON string + client_public_key: Client public key PEM string + key_access_dict: Key access dictionary + algorithm: Algorithm string (e.g., "ec:secp256r1" or "rsa:2048") + has_header: Whether NanoTDF header is present + + Returns: + Dictionary with unsigned rewrap request + """ import json policy_uuid = "policy" # otdfctl uses "policy" as the policy ID - - # For v2 format, the policy body must be base64-encoded policy_base64 = base64.b64encode(policy_json.encode("utf-8")).decode("utf-8") - unsigned_rewrap_request = { - "clientPublicKey": client_public_key, # Maps to client_public_key - "requests": [ - { # Maps to requests array (v2 format) - "keyAccessObjects": [ - { - "keyAccessObjectId": "kao-0", # Standard KAO ID - "keyAccessObject": key_access_dict, - } - ], - "policy": { - "id": policy_uuid, # Use the UUID from policy as the policy ID - "body": policy_base64, # Base64-encoded policy JSON - }, + # Build the request object + request_item = { + "keyAccessObjects": [ + { + "keyAccessObjectId": "kao-0", # Standard KAO ID + "keyAccessObject": key_access_dict, } ], + "policy": { + "id": policy_uuid, + }, + } + + # Only include policy body if header is NOT provided (standard TDF) + if not has_header: + request_item["policy"]["body"] = policy_base64 + + # Add algorithm if provided (required for NanoTDF/ECDH) + if algorithm: + request_item["algorithm"] = algorithm + + unsigned_rewrap_request = { + "clientPublicKey": client_public_key, + "requests": [request_item], "keyAccess": key_access_dict, - "policy": policy_base64, } - # Convert to JSON string - request_body_json = json.dumps(unsigned_rewrap_request) + # Only include legacy policy field for standard TDF (not NanoTDF with header) + if not has_header: + unsigned_rewrap_request["policy"] = policy_base64 + + return json.dumps(unsigned_rewrap_request) + + def _create_signed_request_jwt( + self, policy_json, client_public_key, key_access, session_key_type=None + ): + """ + Create a signed JWT for the rewrap request. + The JWT is signed with the DPoP private key. + + Args: + policy_json: Policy JSON string + client_public_key: Client public key PEM string + key_access: KeyAccess object + session_key_type: Optional session key type (RSA_KEY_TYPE or EC_KEY_TYPE) + """ + # Build key access dictionary handling both old and new field names + key_access_dict = self._build_key_access_dict(key_access) + + # Get current timestamp + now = int(time.time()) + + # Convert session_key_type to algorithm string for KAS + algorithm = self._get_algorithm_from_session_key_type(session_key_type) + + # Check if header is present (for NanoTDF) + has_header = getattr(key_access, "header", None) is not None + + # Build the unsigned rewrap request + request_body_json = self._build_rewrap_request( + policy_json, client_public_key, key_access_dict, algorithm, has_header + ) # JWT payload with requestBody field containing the JSON string payload = { @@ -264,9 +351,7 @@ def _create_signed_request_jwt(self, policy_json, client_public_key, key_access) } # Sign the JWT with the DPoP private key (RS256) - signed_jwt = jwt.encode(payload, self._dpop_private_key_pem, algorithm="RS256") - - return signed_jwt + return jwt.encode(payload, self._dpop_private_key_pem, algorithm="RS256") def _create_connect_rpc_signed_token(self, key_access, policy_json): """ @@ -506,11 +591,13 @@ def _ensure_client_keypair(self, session_key_type): self.decryptor = AsymDecryption(private_key_pem) self.client_public_key = CryptoUtils.get_rsa_public_key_pem(public_key) else: - # For EC keys, generate fresh key pair each time - # TODO: Implement proper EC key handling - private_key, public_key = CryptoUtils.generate_rsa_keypair() - private_key_pem = CryptoUtils.get_rsa_private_key_pem(private_key) - self.client_public_key = CryptoUtils.get_rsa_public_key_pem(public_key) + # For EC keys (NanoTDF/ECDH), still need RSA keypair for encrypting the rewrap response + # KAS uses client public key to encrypt the symmetric key it derived via ECDH + if self.decryptor is None: + private_key, public_key = CryptoUtils.generate_rsa_keypair() + private_key_pem = CryptoUtils.get_rsa_private_key_pem(private_key) + self.decryptor = AsymDecryption(private_key_pem) + self.client_public_key = CryptoUtils.get_rsa_public_key_pem(public_key) def _parse_and_decrypt_response(self, response): """ @@ -559,14 +646,22 @@ def unwrap(self, key_access, policy_json, session_key_type=None) -> bytes: policy_json, self.client_public_key, key_access, # Use ephemeral key, not DPoP key + session_key_type, # Pass algorithm type for NanoTDF ) # Call Connect RPC unwrap - return self._unwrap_with_connect_rpc(key_access, signed_token) + return self._unwrap_with_connect_rpc(key_access, signed_token, session_key_type) - def _unwrap_with_connect_rpc(self, key_access, signed_token) -> bytes: + def _unwrap_with_connect_rpc( + self, key_access, signed_token, session_key_type=None + ) -> bytes: """ Connect RPC method for unwrapping keys. + + Args: + key_access: KeyAccess object + signed_token: Signed JWT token + session_key_type: Optional session key type (RSA_KEY_TYPE or EC_KEY_TYPE) """ # Get access token for authentication if token source is available @@ -586,12 +681,23 @@ def _unwrap_with_connect_rpc(self, key_access, signed_token) -> bytes: normalized_kas_url, key_access, signed_token, access_token ) - # Decrypt the wrapped key + # Both ECDH and RSA modes return an RSA-encrypted key + # For ECDH (EC_KEY_TYPE): KAS performs ECDH to derive symmetric key, then RSA-encrypts it with client public key + # For RSA (RSA_KEY_TYPE): KAS RSA-decrypts wrapped key, then RSA-encrypts it with client public key + # In both cases, we need to RSA-decrypt using our client private key if not self.decryptor: raise SDKException("Decryptor not initialized") result = self.decryptor.decrypt(entity_wrapped_key) - logging.info("Connect RPC rewrap succeeded") + + if session_key_type == EC_KEY_TYPE: + logging.info( + f"Connect RPC rewrap succeeded (ECDH - KAS derived key via ECDH, length={len(result)} bytes)" + ) + else: + logging.info( + f"Connect RPC rewrap succeeded (RSA - length={len(result)} bytes)" + ) return result except Exception as e: diff --git a/src/otdf_python/nanotdf.py b/src/otdf_python/nanotdf.py index 15d7725..fee974d 100644 --- a/src/otdf_python/nanotdf.py +++ b/src/otdf_python/nanotdf.py @@ -1,9 +1,12 @@ +import contextlib import hashlib import json import secrets from io import BytesIO from typing import BinaryIO +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.primitives.ciphers.aead import AESGCM from otdf_python.asym_crypto import AsymDecryption @@ -140,7 +143,11 @@ def _prepare_encryption_key(self, config: NanoTDFConfig) -> bytes: return key def _create_header( - self, policy_body: bytes, policy_type: str, config: NanoTDFConfig + self, + policy_body: bytes, + policy_type: str, + config: NanoTDFConfig, + ephemeral_public_key: bytes | None = None, ) -> bytes: """ Create the NanoTDF header. @@ -149,6 +156,7 @@ def _create_header( policy_body: The policy body bytes policy_type: The policy type string config: NanoTDFConfig configuration + ephemeral_public_key: Optional compressed ephemeral public key (from ECDH) Returns: bytes: The header bytes @@ -160,7 +168,12 @@ def _create_header( if config.kas_info_list and len(config.kas_info_list) > 0: kas_url = config.kas_info_list[0].url - kas_id = "kas-id" # Default KAS ID + # KAS Key ID - use "e1" for EC (ECDH) mode or "r1" for RSA mode + # If ephemeral_public_key is provided, we're using ECDH (EC), otherwise RSA + # EC key ID, use "e1" + # RSA key ID, use "r1" + kas_id = "e1" if ephemeral_public_key else "r1" + kas_locator = ResourceLocator(kas_url, kas_id) # Get ECC mode from config or use default @@ -172,7 +185,9 @@ def _create_header( ecc_mode = config.ecc_mode # Default payload config - payload_config = SymmetricAndPayloadConfig(0, 0, False) + # Use cipher_type=5 for AES-256-GCM with 128-bit tag (16 bytes) + # This matches Python's cryptography AESGCM default + payload_config = SymmetricAndPayloadConfig(5, 0, False) # Create policy info policy_info = PolicyInfo() @@ -180,9 +195,11 @@ def _create_header( policy_info.set_embedded_plain_text_policy(policy_body) else: policy_info.set_embedded_encrypted_text_policy(policy_body) - policy_info.set_policy_binding( - hashlib.sha256(policy_body).digest()[-self.K_NANOTDF_GMAC_LENGTH :] - ) + + # Create policy binding (GMAC) + policy_binding = hashlib.sha256(policy_body).digest()[ + -self.K_NANOTDF_GMAC_LENGTH : + ] # Build the header header = Header() @@ -190,59 +207,182 @@ def _create_header( header.set_ecc_mode(ecc_mode) header.set_payload_config(payload_config) header.set_policy_info(policy_info) - header.set_ephemeral_key( - secrets.token_bytes( - ECCMode.get_ec_compressed_pubkey_size( - ecc_mode.get_elliptic_curve_type() + header.policy_binding = policy_binding + + # Set ephemeral key - use provided ECDH key or generate random placeholder + if ephemeral_public_key: + header.set_ephemeral_key(ephemeral_public_key) + else: + # Fallback: generate random bytes as placeholder (for symmetric key case) + header.set_ephemeral_key( + secrets.token_bytes( + ECCMode.get_ec_compressed_pubkey_size( + ecc_mode.get_elliptic_curve_type() + ) ) ) - ) # Generate and return the header bytes with magic number header_bytes = header.to_bytes() return self.MAGIC_NUMBER_AND_VERSION + header_bytes - def _wrap_key_if_needed( - self, key: bytes, config: NanoTDFConfig - ) -> tuple[bytes, bytes | None]: + def _is_ec_key(self, key_pem: str) -> bool: """ - Wrap encryption key if KAS public key is provided. + Detect if a PEM key is an EC key (vs RSA). Args: - key: The encryption key - config: NanoTDFConfig with potential KASInfo + key_pem: PEM-formatted key string + + Returns: + bool: True if EC key, False if RSA key + + Raises: + SDKException: If key cannot be parsed + """ + try: + # Try to load as public key first + if "BEGIN PUBLIC KEY" in key_pem or "BEGIN CERTIFICATE" in key_pem: + if "BEGIN CERTIFICATE" in key_pem: + from cryptography.x509 import load_pem_x509_certificate + + cert = load_pem_x509_certificate(key_pem.encode()) + public_key = cert.public_key() + else: + public_key = serialization.load_pem_public_key(key_pem.encode()) + return isinstance(public_key, ec.EllipticCurvePublicKey) + # Try to load as private key + elif "BEGIN" in key_pem and "PRIVATE KEY" in key_pem: + private_key = serialization.load_pem_private_key( + key_pem.encode(), password=None + ) + return isinstance(private_key, ec.EllipticCurvePrivateKey) + else: + raise SDKException("Invalid PEM format - no BEGIN header found") + except Exception as e: + raise SDKException(f"Failed to detect key type: {e}") + + def _derive_key_with_ecdh( # noqa: C901 + self, config: NanoTDFConfig + ) -> tuple[bytes, bytes | None, bytes | None]: + """ + Derive encryption key using ECDH if KAS public key is provided or can be fetched. + + This implements the NanoTDF spec's ECDH + HKDF key derivation: + 1. Generate ephemeral keypair + 2. Perform ECDH with KAS public key to get shared secret + 3. Use HKDF to derive symmetric key from shared secret + + For backward compatibility, also supports RSA key wrapping when an RSA key is detected. + + Args: + config: NanoTDFConfig with potential KASInfo and ECC mode Returns: - tuple: (wrapped_key, kas_public_key) + tuple: (derived_key, ephemeral_public_key_compressed, kas_public_key) + - derived_key: 32-byte AES-256 key for encrypting the payload + - ephemeral_public_key_compressed: Compressed ephemeral public key to store in header (None for RSA) + - kas_public_key: KAS public key PEM string (or None if not available) """ + import logging + + from otdf_python.ecdh import encrypt_key_with_ecdh + kas_public_key = None - wrapped_key = None + derived_key = None + ephemeral_public_key_compressed = None if config.kas_info_list and len(config.kas_info_list) > 0: - # Get the first KASInfo with a public_key + # Get the first KASInfo with a public_key or fetch it for kas_info in config.kas_info_list: if kas_info.public_key: kas_public_key = kas_info.public_key break + elif self.services: + # Try to fetch public key from KAS service + try: + # For NanoTDF, prefer EC keys for ECDH - set algorithm if not specified + if not kas_info.algorithm: + # Default to EC secp256r1 for NanoTDF ECDH + kas_info.algorithm = "ec:secp256r1" + logging.info( + f"Fetching EC public key from KAS for NanoTDF ECDH: {kas_info.url}" + ) + else: + logging.info( + f"Fetching public key (algorithm={kas_info.algorithm}) from KAS: {kas_info.url}" + ) + + updated_kas = self.services.kas().get_public_key(kas_info) + kas_public_key = updated_kas.public_key + # Update the config with the fetched public key + kas_info.public_key = kas_public_key + break + except Exception as e: + logging.warning( + f"Failed to fetch public key from KAS {kas_info.url}: {e}" + ) + # Continue to next KAS or proceed without wrapping if kas_public_key: - from cryptography.hazmat.backends import default_backend - from cryptography.hazmat.primitives import hashes, serialization - from cryptography.hazmat.primitives.asymmetric import padding - - public_key = serialization.load_pem_public_key( - kas_public_key.encode(), backend=default_backend() - ) - wrapped_key = public_key.encrypt( - key, - padding.OAEP( - mgf=padding.MGF1(algorithm=hashes.SHA1()), - algorithm=hashes.SHA1(), - label=None, - ), + # Detect if key is EC or RSA + is_ec = self._is_ec_key(kas_public_key) + + if is_ec: + # EC key - use ECDH + HKDF + # Determine curve from config + curve_name = "secp256r1" # Default + if config.ecc_mode: + if isinstance(config.ecc_mode, str): + # Parse the string to get actual curve name + # Handles cases like "gmac" or "ecdsa" which map to secp256r1 + try: + from otdf_python.ecc_mode import ECCMode as ECCModeClass + + ecc_mode_obj = ECCModeClass.from_string(config.ecc_mode) + curve_name = ecc_mode_obj.get_curve_name() + except (ValueError, AttributeError): + # If parsing fails, stick with default + logging.warning( + f"Could not parse ecc_mode '{config.ecc_mode}', using default secp256r1" + ) + curve_name = "secp256r1" + else: + # Get curve name from ECCMode object + curve_name = config.ecc_mode.get_curve_name() + + try: + # Use ECDH to derive key and generate ephemeral keypair + derived_key, ephemeral_public_key_compressed = ( + encrypt_key_with_ecdh(kas_public_key, curve_name=curve_name) + ) + logging.info( + f"Successfully derived NanoTDF key using ECDH with curve {curve_name}" + ) + except Exception as e: + logging.warning(f"Failed to derive key with ECDH: {e}") + derived_key = None + ephemeral_public_key_compressed = None + else: + # RSA key - use RSA wrapping for backward compatibility + try: + # Generate random symmetric key + derived_key = secrets.token_bytes(32) + # For RSA mode, we don't use ephemeral keys - the symmetric key + # will be wrapped by KAS using RSA + ephemeral_public_key_compressed = None + logging.info( + "Generated symmetric key for RSA wrapping (backward compatibility)" + ) + except Exception as e: + logging.warning(f"Failed to generate key for RSA wrapping: {e}") + derived_key = None + ephemeral_public_key_compressed = None + else: + logging.warning( + "No KAS public key available - creating NanoTDF without key derivation" ) - return wrapped_key, kas_public_key + return derived_key, ephemeral_public_key_compressed, kas_public_key def _encrypt_payload(self, payload: bytes, key: bytes) -> tuple[bytes, bytes]: """ @@ -265,8 +405,10 @@ def create_nano_tdf( self, payload: bytes | BytesIO, output_stream: BinaryIO, config: NanoTDFConfig ) -> int: """ - Creates a NanoTDF with the provided payload and writes it to the output stream. - Supports KAS key wrapping if KAS info with public key is provided in config. + Stream-based NanoTDF creation - writes encrypted payload to an output stream. + + For convenience method that returns bytes, use create_nanotdf() instead. + Supports ECDH key derivation if KAS info with public key is provided in config. Args: payload: The payload data as bytes or BytesIO @@ -289,46 +431,143 @@ def create_nano_tdf( # Process policy data policy_body, policy_type = self._prepare_policy_data(config) - # Get or generate encryption key - key = self._prepare_encryption_key(config) + # Try to derive key using ECDH or RSA + ( + derived_key, + ephemeral_public_key_compressed, + kas_public_key, # noqa: RUF059 + ) = self._derive_key_with_ecdh(config) + + # Use ECDH-derived key if available; otherwise use/generate symmetric key + # Fallback to symmetric key (for testing or when KAS is not available) + key = derived_key or self._prepare_encryption_key(config) - # Create header and write to output - header_bytes = self._create_header(policy_body, policy_type, config) + # Create header with ephemeral public key (if ECDH was used) + header_bytes = self._create_header( + policy_body, policy_type, config, ephemeral_public_key_compressed + ) output_stream.write(header_bytes) # Encrypt payload - iv, ciphertext = self._encrypt_payload(payload, key) + iv, ciphertext_with_tag = self._encrypt_payload(payload, key) - # Wrap key if needed - wrapped_key, _kas_public_key = self._wrap_key_if_needed(key, config) + # NanoTDF payload format per spec: + # [3 bytes: length] [3 bytes: IV] [variable: ciphertext] [tag] + # Note: ciphertext_with_tag from AESGCM already includes the tag + payload_data = iv + ciphertext_with_tag + payload_length = len(payload_data) - # Compose the complete NanoTDF: [IV][CIPHERTEXT][WRAPPED_KEY][WRAPPED_KEY_LEN] - if wrapped_key: - nano_tdf_data = ( - iv + ciphertext + wrapped_key + len(wrapped_key).to_bytes(2, "big") + # Write payload length as 3 bytes (big-endian) + length_bytes = payload_length.to_bytes(4, "big")[1:] # Take last 3 bytes + output_stream.write(length_bytes) + + # Write payload (IV + ciphertext + tag) + output_stream.write(payload_data) + + return len(header_bytes) + 3 + payload_length + + def _kas_unwrap( + self, nano_tdf_data: bytes, header_len: int, wrapped_key: bytes + ) -> bytes | None: + try: + # For NanoTDF, send the entire header to KAS + # KAS will extract the policy, ephemeral key, and perform ECDH + import logging + + from otdf_python.header import Header + from otdf_python.kas_client import KeyAccess + + # Extract header bytes (excluding magic number/version which is at start of nano_tdf_data) + # The header starts at offset 0 (magic number) and goes for header_len bytes + header_bytes = nano_tdf_data[:header_len] + + # Parse just to get KAS URL (we still need this for routing) + header_obj = Header.from_bytes(header_bytes) + kas_url = header_obj.kas_locator.get_resource_url() + + # Get KAS client from services + kas_client = self.services.kas() + + # For NanoTDF: Pass header bytes to KAS + # KAS will extract ephemeral key, decrypt policy if needed, and derive/unwrap the key + # Use minimal policy JSON since KAS will extract it from the header + policy_json = '{"uuid":"00000000-0000-0000-0000-000000000000","body":{"dataAttributes":[]}}' + + key_access = KeyAccess( + url=kas_url, + wrapped_key="", # NanoTDF uses ECDH, not wrapped keys + header=header_bytes, # Send entire header to KAS ) - else: - nano_tdf_data = iv + ciphertext + (0).to_bytes(2, "big") - output_stream.write(nano_tdf_data) - return len(header_bytes) + len(nano_tdf_data) + # Use EC key type for NanoTDF (always uses ECDH) + from otdf_python.key_type_constants import EC_KEY_TYPE + + key = kas_client.unwrap(key_access, policy_json, EC_KEY_TYPE) + logging.info("Successfully unwrapped NanoTDF key using KAS with header") + + except Exception as e: + # If KAS unwrap fails, log and fall through to local unwrap methods + import logging + + logging.warning(f"KAS unwrap failed for NanoTDF: {e}, trying local unwrap") + key = None + + return key + + def _local_unwrap(self, wrapped_key: bytes, config: NanoTDFConfig) -> bytes: + """Unwrap key locally using private key or mock unwrap (for testing/offline use).""" + kas_private_key = None + # Try to get from cipher field if it looks like a PEM key + if ( + config.cipher + and isinstance(config.cipher, str) + and "-----BEGIN" in config.cipher + ): + kas_private_key = config.cipher - def read_nano_tdf( + # Check if mock unwrap is enabled in config string + kas_mock_unwrap = False + if config.config and "mock_unwrap=true" in config.config.lower(): + kas_mock_unwrap = True + + if not kas_private_key and not kas_mock_unwrap: + raise InvalidNanoTDFConfig( + "Unable to unwrap NanoTDF key: KAS unwrap failed and no local private key available. " + "Ensure SDK has valid credentials or provide kas_private_key in config for offline use." + ) + + if kas_mock_unwrap: + # Use the KAS mock unwrap_nanotdf logic + from otdf_python.sdk import KAS + + return KAS().unwrap_nanotdf( + curve=None, + header=None, + kas_url=None, + wrapped_key=wrapped_key, + kas_private_key=kas_private_key, + mock=True, + ) + else: + asym = AsymDecryption(kas_private_key) + return asym.decrypt(wrapped_key) + + def read_nano_tdf( # noqa: C901 self, nano_tdf_data: bytes | BytesIO, output_stream: BinaryIO, config: NanoTDFConfig, - platform_url: str | None = None, ) -> None: """ - Reads a NanoTDF and writes the payload to the output stream. - Supports KAS key unwrapping if kas_private_key is provided in config. + Stream-based NanoTDF decryption - writes decrypted payload to an output stream. + + For convenience method that returns bytes, use read_nanotdf() instead. + Supports ECDH key derivation and KAS key unwrapping. Args: nano_tdf_data: The NanoTDF data as bytes or BytesIO output_stream: The output stream to write the payload to config: Configuration for the NanoTDF reader - platform_url: Optional platform URL for KAS resolution Raises: InvalidNanoTDFConfig: If the NanoTDF format is invalid or config is missing required info @@ -342,58 +581,177 @@ def read_nano_tdf( try: header_len = Header.peek_length(nano_tdf_data) - except Exception: - raise InvalidNanoTDFConfig("Failed to parse NanoTDF header.") - payload_start = header_len - payload = nano_tdf_data[payload_start:] - # Do not check for magic/version in payload; it is only at the start of the header + header_obj = Header.from_bytes(nano_tdf_data[:header_len]) + except Exception as e: + raise InvalidNanoTDFConfig(f"Failed to parse NanoTDF header: {e}") + + # Read payload section per NanoTDF spec: + # [3 bytes: length] [3 bytes: IV] [variable: ciphertext] [tag] + payload_offset = header_len + + # Read 3-byte payload length + payload_length = int.from_bytes( + nano_tdf_data[payload_offset : payload_offset + 3], "big" + ) + payload_offset += 3 + + # Read payload data (IV + ciphertext + tag) + payload = nano_tdf_data[payload_offset : payload_offset + payload_length] + + # Extract IV (first 3 bytes) iv = payload[0:3] iv_padded = self.K_EMPTY_IV[: self.K_IV_PADDING] + iv - # Find wrapped key - wrapped_key_len = int.from_bytes(payload[-2:], "big") + + # The rest is ciphertext + tag + ciphertext_with_tag = payload[3:] + + # For legacy compatibility: check if there's still data after payload (wrapped key) + # This shouldn't exist in spec-compliant NanoTDF with ECDH + remaining_offset = payload_offset + payload_length + has_wrapped_key = remaining_offset + 2 <= len(nano_tdf_data) + wrapped_key_len = 0 + if has_wrapped_key: + wrapped_key_len = int.from_bytes(nano_tdf_data[-2:], "big") + + key = None + if wrapped_key_len > 0: - wrapped_key = payload[-(2 + wrapped_key_len) : -2] + # Legacy RSA wrapped key mode (backward compatibility) + wrapped_key = nano_tdf_data[-(2 + wrapped_key_len) : -2] - # Get private key and mock unwrap config - kas_private_key = None - # Try to get from cipher field if it looks like a PEM key - if ( - config.cipher - and isinstance(config.cipher, str) - and "-----BEGIN" in config.cipher - ): - kas_private_key = config.cipher + # Try to unwrap using KAS service if available + if self.services: + key = self._kas_unwrap(nano_tdf_data, header_len, wrapped_key) - # Check if mock unwrap is enabled in config string - kas_mock_unwrap = False - if config.config and "mock_unwrap=true" in config.config.lower(): - kas_mock_unwrap = True + # If KAS unwrap didn't work, try local unwrap methods (for testing/offline use) + if key is None: + key = self._local_unwrap(wrapped_key, config) - if not kas_private_key and not kas_mock_unwrap: - raise InvalidNanoTDFConfig("Missing kas_private_key for unwrap.") - if kas_mock_unwrap: - # Use the KAS mock unwrap_nanotdf logic - from otdf_python.sdk import KAS - - key = KAS().unwrap_nanotdf( - curve=None, - header=None, - kas_url=None, - wrapped_key=wrapped_key, - kas_private_key=kas_private_key, - mock=True, - ) - else: - asym = AsymDecryption(kas_private_key) - key = asym.decrypt(wrapped_key) - ciphertext = payload[3 : -(2 + wrapped_key_len)] + # In legacy mode, ciphertext is different (recalculate from old format) + # For now, assume spec-compliant format + pass else: - key = config.get("key") + # No wrapped key - ECDH mode with ephemeral key in header + import logging + + from otdf_python.ecdh import decrypt_key_with_ecdh + + # Extract ephemeral public key from header + ephemeral_public_key = header_obj.ephemeral_key + ecc_mode = header_obj.ecc_mode + + # Get curve name from ECC mode + curve_name = ecc_mode.get_curve_name() # e.g., "secp256r1" + + # Try KAS unwrap first if services available + if self.services: + try: + key = self._kas_unwrap(nano_tdf_data, header_len, wrapped_key=b"") + if key: + logging.info( + "Successfully unwrapped NanoTDF key via KAS (ECDH mode)" + ) + except Exception as e: + logging.warning(f"KAS unwrap failed for ECDH mode: {e}") + key = None + + # If KAS unwrap didn't work, try local private key from config if not key: - raise InvalidNanoTDFConfig("Missing decryption key in config.") - ciphertext = payload[3:-2] - aesgcm = AESGCM(key) - plaintext = aesgcm.decrypt(iv_padded, ciphertext, None) + recipient_private_key_pem = None + if ( + config + and hasattr(config, "cipher") + and isinstance(config.cipher, str) + ): + if "-----BEGIN" in config.cipher: + # It's a PEM private key + recipient_private_key_pem = config.cipher + else: + # Try to parse as hex symmetric key (fallback) + with contextlib.suppress(ValueError): + key = bytes.fromhex(config.cipher) + + # If we have a private key, detect type and use appropriate method + if recipient_private_key_pem: + # Detect if key is EC or RSA + is_ec = self._is_ec_key(recipient_private_key_pem) + + if is_ec: + # EC key - use ECDH to derive the decryption key + try: + key = decrypt_key_with_ecdh( + recipient_private_key_pem, + ephemeral_public_key, + curve_name=curve_name, + ) + logging.info( + f"Successfully derived NanoTDF decryption key using ECDH with curve {curve_name}" + ) + except Exception as e: + logging.warning(f"Failed to derive key with ECDH: {e}") + key = None + else: + # RSA key - this shouldn't happen for ECDH mode (wrapped_key_len should be > 0) + # But handle it gracefully + logging.warning( + "RSA private key provided for ECDH mode NanoTDF - this is unexpected. " + "NanoTDF should use wrapped_key_len > 0 for RSA mode." + ) + key = None + + # If no key yet, raise error + if not key: + raise InvalidNanoTDFConfig( + "Missing decryption key. Provide either:\n" + " 1. KAS service for key unwrapping, or\n" + " 2. Recipient's private key (PEM format) in config.cipher for ECDH, or\n" + " 3. Symmetric key (hex) in config.cipher for symmetric decryption" + ) + + # Decrypt the ciphertext using AES-GCM + # Use cipher type from header to determine tag size + import logging + + tag_size_map = { + 0: 8, # 64-bit + 1: 12, # 96-bit + 2: 13, # 104-bit + 3: 14, # 112-bit + 4: 15, # 120-bit + 5: 16, # 128-bit + } + + cipher_type = ( + header_obj.payload_config.get_cipher_type() + if header_obj.payload_config + else 5 + ) + tag_size = tag_size_map.get(cipher_type, 16) + + logging.info( + f"Decrypting payload: key_len={len(key)}, key_hex={key.hex()[:40]}..., iv_3byte={iv.hex()}, iv_padded={iv_padded.hex()}, cipher_type={cipher_type}, tag_size={tag_size}, ciphertext_len={len(ciphertext_with_tag)}" + ) + + # For variable tag sizes, use lower-level Cipher API + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes + + # Split ciphertext and tag + ciphertext = ciphertext_with_tag[:-tag_size] + tag = ciphertext_with_tag[-tag_size:] + + logging.info( + f"Split: ciphertext={len(ciphertext)} bytes, tag={len(tag)} bytes ({tag.hex()})" + ) + + # Create cipher with GCM mode specifying tag and min_tag_length + cipher = Cipher( + algorithms.AES(key), + modes.GCM(iv_padded, tag=tag, min_tag_length=tag_size), + backend=default_backend(), + ) + decryptor = cipher.decryptor() + plaintext = decryptor.update(ciphertext) + decryptor.finalize() output_stream.write(plaintext) def _convert_dict_to_nanotdf_config(self, config: dict) -> NanoTDFConfig: @@ -440,7 +798,11 @@ def _handle_legacy_key_config( return key, config def create_nanotdf(self, data: bytes, config: dict | NanoTDFConfig) -> bytes: - """Create a NanoTDF from input data using the provided configuration.""" + """ + Convenience method - creates a NanoTDF and returns the encrypted bytes. + + For stream-based version, use create_nano_tdf() instead. + """ if len(data) > self.K_MAX_TDF_SIZE: raise NanoTDFMaxSizeLimit("exceeds max size for nano tdf") @@ -514,40 +876,18 @@ def _extract_key_for_reading( def read_nanotdf( self, nanotdf_bytes: bytes, config: dict | NanoTDFConfig | None = None ) -> bytes: - """Read and decrypt a NanoTDF, returning the original plaintext data.""" + """ + Convenience method - decrypts a NanoTDF and returns the plaintext bytes. + + For stream-based version, use read_nano_tdf() instead. + """ output = BytesIO() - from otdf_python.header import Header # Local import to avoid circular import # Convert config to NanoTDFConfig if it's a dict if isinstance(config, dict): config = self._convert_dict_to_read_config(config) - try: - header_len = Header.peek_length(nanotdf_bytes) - payload = nanotdf_bytes[header_len:] - - # Extract components - iv = payload[0:3] - iv_padded = self.K_EMPTY_IV[: self.K_IV_PADDING] + iv - wrapped_key_len = int.from_bytes(payload[-2:], "big") - - wrapped_key = None - if wrapped_key_len > 0: - wrapped_key = payload[-(2 + wrapped_key_len) : -2] - ciphertext = payload[3 : -(2 + wrapped_key_len)] - else: - ciphertext = payload[3:-2] - - # Get the decryption key - key = self._extract_key_for_reading(config, wrapped_key) - - # Decrypt the payload - aesgcm = AESGCM(key) - plaintext = aesgcm.decrypt(iv_padded, ciphertext, None) - output.write(plaintext) - - except Exception as e: - # Re-raise with a clearer message - raise InvalidNanoTDFConfig(f"Error reading NanoTDF: {e!s}") + # Use the stream-based method internally + self.read_nano_tdf(nanotdf_bytes, output, config) return output.getvalue() diff --git a/src/otdf_python/policy_info.py b/src/otdf_python/policy_info.py index 467d816..5278c04 100644 --- a/src/otdf_python/policy_info.py +++ b/src/otdf_python/policy_info.py @@ -2,14 +2,10 @@ class PolicyInfo: def __init__( self, policy_type: int = 0, - has_ecdsa_binding: bool = False, body: bytes | None = None, - binding: bytes | None = None, ): self.policy_type = policy_type - self.has_ecdsa_binding = has_ecdsa_binding self.body = body - self.binding = binding def set_embedded_plain_text_policy(self, body: bytes): self.body = body @@ -19,21 +15,13 @@ def set_embedded_encrypted_text_policy(self, body: bytes): self.body = body self.policy_type = 2 # Placeholder for EMBEDDED_POLICY_ENCRYPTED - def set_policy_binding(self, binding: bytes): - self.binding = binding - def get_body(self) -> bytes | None: return self.body - def get_binding(self) -> bytes | None: - return self.binding - def get_total_size(self) -> int: size = 1 # policy_type size += 2 # body_len size += len(self.body) if self.body else 0 - size += 1 # binding_len - size += len(self.binding) if self.binding else 0 return size def write_into_buffer(self, buffer: bytearray, offset: int = 0) -> int: @@ -46,33 +34,22 @@ def write_into_buffer(self, buffer: bytearray, offset: int = 0) -> int: if self.body: buffer[offset : offset + body_len] = self.body offset += body_len - binding_len = len(self.binding) if self.binding else 0 - buffer[offset] = binding_len - offset += 1 - if self.binding: - buffer[offset : offset + binding_len] = self.binding - offset += binding_len return offset - start @staticmethod def from_bytes_with_size(buffer: bytes, ecc_mode): - # Based on Java implementation: parse policy_type (1 byte), body_len (2 bytes), body, binding_len (1 byte), binding + # Parse policy_type (1 byte), body_len (2 bytes), body + # Note: binding is NOT part of PolicyInfo - it's read separately in Header offset = 0 - if len(buffer) < 4: + if len(buffer) < 3: raise ValueError("Buffer too short for PolicyInfo header") policy_type = buffer[offset] offset += 1 body_len = int.from_bytes(buffer[offset : offset + 2], "big") offset += 2 - if len(buffer) < offset + body_len + 1: + if len(buffer) < offset + body_len: raise ValueError("Buffer too short for PolicyInfo body") body = buffer[offset : offset + body_len] offset += body_len - binding_len = buffer[offset] - offset += 1 - if len(buffer) < offset + binding_len: - raise ValueError("Buffer too short for PolicyInfo binding") - binding = buffer[offset : offset + binding_len] - offset += binding_len - pi = PolicyInfo(policy_type=policy_type, body=body, binding=binding) + pi = PolicyInfo(policy_type=policy_type, body=body) return pi, offset diff --git a/src/otdf_python/resource_locator.py b/src/otdf_python/resource_locator.py index fd80065..fd2f739 100644 --- a/src/otdf_python/resource_locator.py +++ b/src/otdf_python/resource_locator.py @@ -1,7 +1,31 @@ class ResourceLocator: + """ + NanoTDF Resource Locator per the spec: + https://github.com/opentdf/spec/blob/main/schema/nanotdf/README.md + + Format: + - Byte 0: Protocol Enum (bits 0-3) + Identifier Length (bits 4-7) + - Protocol: 0x0=HTTP, 0x1=HTTPS, 0xF=Shared Resource Directory + - Identifier: 0x0=None, 0x1=2 bytes, 0x2=8 bytes, 0x3=32 bytes + - Byte 1: Body Length (1-255 bytes) + - Bytes 2-N: Body (URL path) + - Bytes N+1-M: Identifier (optional, 0/2/8/32 bytes) + """ + + # Protocol enum values + PROTOCOL_HTTP = 0x0 + PROTOCOL_HTTPS = 0x1 + PROTOCOL_SHARED_RESOURCE_DIR = 0xF + + # Identifier length enum values (in bits 4-7) + IDENTIFIER_NONE = 0x0 + IDENTIFIER_2_BYTES = 0x1 + IDENTIFIER_8_BYTES = 0x2 + IDENTIFIER_32_BYTES = 0x3 + def __init__(self, resource_url: str | None = None, identifier: str | None = None): - self.resource_url = resource_url - self.identifier = identifier + self.resource_url = resource_url or "" + self.identifier = identifier or "" def get_resource_url(self): return self.resource_url @@ -9,13 +33,68 @@ def get_resource_url(self): def get_identifier(self): return self.identifier + def _parse_url(self): + """Parse URL to extract protocol and body (path).""" + url = self.resource_url + if url.startswith("https://"): + protocol = self.PROTOCOL_HTTPS + body = url[8:] # Remove "https://" + elif url.startswith("http://"): + protocol = self.PROTOCOL_HTTP + body = url[7:] # Remove "http://" + else: + # Default to HTTP + protocol = self.PROTOCOL_HTTP + body = url + return protocol, body.encode() + + def _get_identifier_bytes(self): + """Get identifier bytes and determine identifier length enum.""" + if not self.identifier: + return self.IDENTIFIER_NONE, b"" + + id_bytes = self.identifier.encode() + id_len = len(id_bytes) + + if id_len == 0: + return self.IDENTIFIER_NONE, b"" + elif id_len <= 2: + # Pad to 2 bytes + return self.IDENTIFIER_2_BYTES, id_bytes.ljust(2, b"\x00") + elif id_len <= 8: + # Pad to 8 bytes + return self.IDENTIFIER_8_BYTES, id_bytes.ljust(8, b"\x00") + elif id_len <= 32: + # Pad to 32 bytes + return self.IDENTIFIER_32_BYTES, id_bytes.ljust(32, b"\x00") + else: + raise ValueError(f"Identifier too long: {id_len} bytes (max 32)") + def to_bytes(self): - # Based on Java implementation: [url_len][url_bytes][id_len][id_bytes], each len is 1 byte - url_bytes = (self.resource_url or "").encode() - id_bytes = (self.identifier or "").encode() - if len(url_bytes) > 255 or len(id_bytes) > 255: - raise ValueError("ResourceLocator fields too long for 1-byte length prefix") - return bytes([len(url_bytes)]) + url_bytes + bytes([len(id_bytes)]) + id_bytes + """ + Convert to NanoTDF Resource Locator format per spec. + + Format: + - Byte 0: Protocol Enum (bits 0-3) + Identifier Length (bits 4-7) + - Byte 1: Body Length + - Bytes 2-N: Body (URL path) + - Bytes N+1-M: Identifier (0/2/8/32 bytes) + """ + protocol, body_bytes = self._parse_url() + identifier_enum, identifier_bytes = self._get_identifier_bytes() + + if len(body_bytes) > 255: + raise ValueError( + f"Resource Locator body too long: {len(body_bytes)} bytes (max 255)" + ) + + # Byte 0: protocol in bits 0-3, identifier length in bits 4-7 + protocol_and_id = (identifier_enum << 4) | protocol + + # Byte 1: body length + body_len = len(body_bytes) + + return bytes([protocol_and_id, body_len]) + body_bytes + identifier_bytes def get_total_size(self) -> int: return len(self.to_bytes()) @@ -26,19 +105,68 @@ def write_into_buffer(self, buffer: bytearray, offset: int = 0) -> int: return len(data) @staticmethod - def from_bytes_with_size(buffer: bytes): - # Based on Java implementation: [url_len][url_bytes][id_len][id_bytes] + def from_bytes_with_size(buffer: bytes): # noqa: C901 + """ + Parse NanoTDF Resource Locator from bytes per spec. + + Format: + - Byte 0: Protocol Enum (bits 0-3) + Identifier Length (bits 4-7) + - Byte 1: Body Length + - Bytes 2-N: Body (URL path) + - Bytes N+1-M: Identifier (0/2/8/32 bytes) + """ if len(buffer) < 2: raise ValueError("Buffer too short for ResourceLocator") - url_len = buffer[0] - if len(buffer) < 1 + url_len + 1: - raise ValueError("Buffer too short for ResourceLocator url") - url_bytes = buffer[1 : 1 + url_len] - id_len = buffer[1 + url_len] - if len(buffer) < 1 + url_len + 1 + id_len: - raise ValueError("Buffer too short for ResourceLocator id") - id_bytes = buffer[1 + url_len + 1 : 1 + url_len + 1 + id_len] - resource_url = url_bytes.decode() - identifier = id_bytes.decode() - size = 1 + url_len + 1 + id_len + + # Parse byte 0: protocol and identifier length + protocol_and_id = buffer[0] + protocol = protocol_and_id & 0x0F # Bits 0-3 + identifier_enum = (protocol_and_id >> 4) & 0x0F # Bits 4-7 + + # Parse byte 1: body length + body_len = buffer[1] + + if len(buffer) < 2 + body_len: + raise ValueError( + f"Buffer too short for ResourceLocator body (need {2 + body_len}, have {len(buffer)})" + ) + + # Parse body (URL path) + body_bytes = buffer[2 : 2 + body_len] + body = body_bytes.decode() + + # Reconstruct full URL with protocol + if protocol == ResourceLocator.PROTOCOL_HTTPS: + resource_url = f"https://{body}" + elif protocol == ResourceLocator.PROTOCOL_HTTP: + resource_url = f"http://{body}" + else: + resource_url = body + + # Parse identifier based on identifier_enum + offset = 2 + body_len + if identifier_enum == ResourceLocator.IDENTIFIER_NONE: + identifier_len = 0 + elif identifier_enum == ResourceLocator.IDENTIFIER_2_BYTES: + identifier_len = 2 + elif identifier_enum == ResourceLocator.IDENTIFIER_8_BYTES: + identifier_len = 8 + elif identifier_enum == ResourceLocator.IDENTIFIER_32_BYTES: + identifier_len = 32 + else: + raise ValueError(f"Invalid identifier length enum: {identifier_enum}") + + if len(buffer) < offset + identifier_len: + raise ValueError( + f"Buffer too short for ResourceLocator identifier (need {offset + identifier_len}, have {len(buffer)})" + ) + + if identifier_len > 0: + identifier_bytes = buffer[offset : offset + identifier_len] + # Remove padding + identifier = identifier_bytes.rstrip(b"\x00").decode() + else: + identifier = "" + + size = 2 + body_len + identifier_len return ResourceLocator(resource_url, identifier), size diff --git a/tests/integration/otdfctl_to_python/test_nanotdf_cli_comparison.py b/tests/integration/otdfctl_to_python/test_nanotdf_cli_comparison.py new file mode 100644 index 0000000..ce5dac1 --- /dev/null +++ b/tests/integration/otdfctl_to_python/test_nanotdf_cli_comparison.py @@ -0,0 +1,375 @@ +""" +Integration tests for NanoTDF using otdfctl and Python CLI interoperability. + +These tests verify that: +1. otdfctl can encrypt to NanoTDF and Python can decrypt +2. Python can encrypt to NanoTDF and otdfctl can decrypt +3. Both tools produce compatible NanoTDF files +""" + +import logging +import tempfile +from pathlib import Path + +import pytest + +from tests.support_cli_args import run_cli_decrypt, run_cli_encrypt +from tests.support_common import ( + handle_subprocess_error, + validate_plaintext_file_created, +) +from tests.support_otdfctl_args import ( + run_otdfctl_decrypt_command, + run_otdfctl_encrypt_command, +) + +logger = logging.getLogger(__name__) + + +@pytest.mark.integration +def test_otdfctl_encrypt_nano_python_decrypt( + collect_server_logs, temp_credentials_file, project_root +): + """Test otdfctl encrypt with --tdf-type nano and Python CLI decrypt.""" + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create input file + input_file = temp_path / "nano_input.txt" + input_content = "Hello NanoTDF! This is a test of nano format encryption." + with input_file.open("w") as f: + f.write(input_content) + + # Define NanoTDF file created by otdfctl + nanotdf_output = temp_path / "test.tdf" + + # Define decrypted output from Python CLI + python_decrypt_output = temp_path / "decrypted-by-python.txt" + + # Run otdfctl encrypt with --tdf-type nano + otdfctl_encrypt_result = run_otdfctl_encrypt_command( + creds_file=temp_credentials_file, + input_file=input_file, + output_file=nanotdf_output, + mime_type="text/plain", + tdf_type="nano", + cwd=temp_path, + ) + + # Fail fast on errors + handle_subprocess_error( + result=otdfctl_encrypt_result, + collect_server_logs=collect_server_logs, + scenario_name="otdfctl encrypt nano", + ) + + # Verify NanoTDF file was created + assert nanotdf_output.exists(), "NanoTDF file should be created" + assert nanotdf_output.stat().st_size > 0, "NanoTDF file should not be empty" + + # Log NanoTDF file info + logger.info(f"✓ otdfctl created NanoTDF: {nanotdf_output.stat().st_size} bytes") + + # Run Python CLI decrypt on the NanoTDF + python_decrypt_result = run_cli_decrypt( + creds_file=temp_credentials_file, + input_file=nanotdf_output, + output_file=python_decrypt_output, + cwd=project_root, + ) + + # Fail fast on errors + handle_subprocess_error( + result=python_decrypt_result, + collect_server_logs=collect_server_logs, + scenario_name="Python CLI decrypt nano", + ) + + # Validate decrypted content + validate_plaintext_file_created( + path=python_decrypt_output, + scenario="Python CLI decrypt NanoTDF", + expected_content=input_content, + ) + + logger.info( + f"✓ Python CLI successfully decrypted NanoTDF: {python_decrypt_output.stat().st_size} bytes" + ) + + +@pytest.mark.integration +def test_python_encrypt_nano_otdfctl_decrypt( + collect_server_logs, temp_credentials_file, project_root +): + """Test Python CLI encrypt with --container-type nano and otdfctl decrypt.""" + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create input file + input_file = temp_path / "nano_input.txt" + input_content = "Hello from Python! Testing nano format encryption." + with input_file.open("w") as f: + f.write(input_content) + + # Define NanoTDF file created by Python CLI + nanotdf_output = temp_path / "python_created.tdf" + + # Define decrypted output from otdfctl + otdfctl_decrypt_output = temp_path / "decrypted-by-otdfctl.txt" + + # Run Python CLI encrypt with --container-type nano + python_encrypt_result = run_cli_encrypt( + creds_file=temp_credentials_file, + input_file=input_file, + output_file=nanotdf_output, + mime_type="text/plain", + container_type="nano", + cwd=project_root, + ) + + # Fail fast on errors + handle_subprocess_error( + result=python_encrypt_result, + collect_server_logs=collect_server_logs, + scenario_name="Python CLI encrypt nano", + ) + + # Verify NanoTDF file was created + assert nanotdf_output.exists(), "NanoTDF file should be created" + assert nanotdf_output.stat().st_size > 0, "NanoTDF file should not be empty" + + # Log NanoTDF file info + logger.info( + f"✓ Python CLI created NanoTDF: {nanotdf_output.stat().st_size} bytes" + ) + + # Run otdfctl decrypt on the NanoTDF + otdfctl_decrypt_result = run_otdfctl_decrypt_command( + creds_file=temp_credentials_file, + tdf_file=nanotdf_output, + output_file=otdfctl_decrypt_output, + cwd=temp_path, + ) + + # Fail fast on errors + handle_subprocess_error( + result=otdfctl_decrypt_result, + collect_server_logs=collect_server_logs, + scenario_name="otdfctl decrypt nano", + ) + + # Validate decrypted content + validate_plaintext_file_created( + path=otdfctl_decrypt_output, + scenario="otdfctl decrypt NanoTDF", + expected_content=input_content, + ) + + logger.info( + f"✓ otdfctl successfully decrypted Python NanoTDF: {otdfctl_decrypt_output.stat().st_size} bytes" + ) + + +@pytest.mark.integration +def test_nanotdf_roundtrip_comparison( + collect_server_logs, temp_credentials_file, project_root +): + """ + Compare NanoTDF files created by otdfctl and Python CLI. + Tests both tools' roundtrip encryption/decryption. + """ + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create input file + input_file = temp_path / "roundtrip_input.txt" + input_content = "NanoTDF roundtrip test with both tools!" + with input_file.open("w") as f: + f.write(input_content) + + # Define NanoTDF files from both tools + otdfctl_nanotdf = temp_path / "otdfctl.tdf" + python_nanotdf = temp_path / "python.tdf" + + # Define decrypted outputs + otdfctl_encrypted_python_decrypted = temp_path / "otdfctl_enc_python_dec.txt" + python_encrypted_otdfctl_decrypted = temp_path / "python_enc_otdfctl_dec.txt" + + # 1. Create NanoTDF with otdfctl + otdfctl_encrypt_result = run_otdfctl_encrypt_command( + creds_file=temp_credentials_file, + input_file=input_file, + output_file=otdfctl_nanotdf, + mime_type="text/plain", + tdf_type="nano", + cwd=temp_path, + ) + + handle_subprocess_error( + result=otdfctl_encrypt_result, + collect_server_logs=collect_server_logs, + scenario_name="otdfctl encrypt nano (roundtrip)", + ) + + # 2. Create NanoTDF with Python CLI + python_encrypt_result = run_cli_encrypt( + creds_file=temp_credentials_file, + input_file=input_file, + output_file=python_nanotdf, + mime_type="text/plain", + container_type="nano", + cwd=project_root, + ) + + handle_subprocess_error( + result=python_encrypt_result, + collect_server_logs=collect_server_logs, + scenario_name="Python CLI encrypt nano (roundtrip)", + ) + + # Verify both NanoTDF files were created + assert otdfctl_nanotdf.exists(), "otdfctl NanoTDF should exist" + assert python_nanotdf.exists(), "Python NanoTDF should exist" + + otdfctl_size = otdfctl_nanotdf.stat().st_size + python_size = python_nanotdf.stat().st_size + + logger.info("\n=== NanoTDF File Size Comparison ===") + logger.info(f"otdfctl NanoTDF: {otdfctl_size} bytes") + logger.info(f"Python NanoTDF: {python_size} bytes") + + # Both should be reasonable sizes (not empty, not too large) + assert otdfctl_size > 0, "otdfctl NanoTDF should not be empty" + assert python_size > 0, "Python NanoTDF should not be empty" + assert otdfctl_size < 10000, "otdfctl NanoTDF should be compact" + assert python_size < 10000, "Python NanoTDF should be compact" + + # 3. Cross-decrypt: Python decrypts otdfctl NanoTDF + python_decrypt_result = run_cli_decrypt( + creds_file=temp_credentials_file, + input_file=otdfctl_nanotdf, + output_file=otdfctl_encrypted_python_decrypted, + cwd=project_root, + ) + + handle_subprocess_error( + result=python_decrypt_result, + collect_server_logs=collect_server_logs, + scenario_name="Python CLI decrypt otdfctl nano", + ) + + # 4. Cross-decrypt: otdfctl decrypts Python NanoTDF + otdfctl_decrypt_result = run_otdfctl_decrypt_command( + creds_file=temp_credentials_file, + tdf_file=python_nanotdf, + output_file=python_encrypted_otdfctl_decrypted, + cwd=temp_path, + ) + + handle_subprocess_error( + result=otdfctl_decrypt_result, + collect_server_logs=collect_server_logs, + scenario_name="otdfctl decrypt Python nano", + ) + + # Validate both cross-decryptions + validate_plaintext_file_created( + path=otdfctl_encrypted_python_decrypted, + scenario="Python decrypt otdfctl NanoTDF", + expected_content=input_content, + ) + + validate_plaintext_file_created( + path=python_encrypted_otdfctl_decrypted, + scenario="otdfctl decrypt Python NanoTDF", + expected_content=input_content, + ) + + logger.info("\n=== Cross-Decryption Success ===") + logger.info( + f"✓ Python successfully decrypted otdfctl NanoTDF: {otdfctl_encrypted_python_decrypted.stat().st_size} bytes" + ) + logger.info( + f"✓ otdfctl successfully decrypted Python NanoTDF: {python_encrypted_otdfctl_decrypted.stat().st_size} bytes" + ) + logger.info("✓ Both tools are interoperable for NanoTDF format!") + + +@pytest.mark.integration +def test_nanotdf_with_attributes( + collect_server_logs, temp_credentials_file, project_root +): + """Test NanoTDF encryption/decryption with attributes.""" + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Import attribute for testing + from tests.config_pydantic import CONFIG_TDF + + test_attribute = CONFIG_TDF.TEST_OPENTDF_ATTRIBUTE_1 + + # Create input file + input_file = temp_path / "attributed_nano.txt" + input_content = "NanoTDF with attributes test" + with input_file.open("w") as f: + f.write(input_content) + + # Define NanoTDF file with attributes + nanotdf_with_attrs = temp_path / "attributed.tdf" + decrypted_output = temp_path / "decrypted_attributed.txt" + + # Encrypt with otdfctl using attributes + otdfctl_encrypt_result = run_otdfctl_encrypt_command( + creds_file=temp_credentials_file, + input_file=input_file, + output_file=nanotdf_with_attrs, + mime_type="text/plain", + tdf_type="nano", + attributes=[test_attribute], + cwd=temp_path, + ) + + handle_subprocess_error( + result=otdfctl_encrypt_result, + collect_server_logs=collect_server_logs, + scenario_name="otdfctl encrypt nano with attributes", + ) + + # Verify NanoTDF was created + assert nanotdf_with_attrs.exists(), "Attributed NanoTDF should be created" + logger.info( + f"✓ Created attributed NanoTDF: {nanotdf_with_attrs.stat().st_size} bytes" + ) + + # Decrypt with Python CLI + python_decrypt_result = run_cli_decrypt( + creds_file=temp_credentials_file, + input_file=nanotdf_with_attrs, + output_file=decrypted_output, + cwd=project_root, + ) + + handle_subprocess_error( + result=python_decrypt_result, + collect_server_logs=collect_server_logs, + scenario_name="Python CLI decrypt attributed nano", + ) + + # Validate decrypted content + validate_plaintext_file_created( + path=decrypted_output, + scenario="Python decrypt attributed NanoTDF", + expected_content=input_content, + ) + + logger.info( + f"✓ Successfully decrypted attributed NanoTDF: {decrypted_output.stat().st_size} bytes" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/integration/otdfctl_to_python/test_python_nanotdf_only.py b/tests/integration/otdfctl_to_python/test_python_nanotdf_only.py new file mode 100644 index 0000000..26cd61d --- /dev/null +++ b/tests/integration/otdfctl_to_python/test_python_nanotdf_only.py @@ -0,0 +1,104 @@ +""" +Simple NanoTDF integration test focusing on Python CLI only. +This tests the Python implementation without otdfctl dependency. +""" + +import logging +import tempfile +from pathlib import Path + +import pytest + +from tests.support_cli_args import run_cli_decrypt, run_cli_encrypt +from tests.support_common import ( + handle_subprocess_error, + validate_plaintext_file_created, +) + +logger = logging.getLogger(__name__) + + +@pytest.mark.integration +def test_python_nanotdf_roundtrip( + collect_server_logs, temp_credentials_file, project_root +): + """Test Python CLI NanoTDF encryption and decryption roundtrip.""" + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create input file + input_file = temp_path / "test.txt" + input_content = "Hello NanoTDF from Python!" + with input_file.open("w") as f: + f.write(input_content) + + # Define NanoTDF and output files + nanotdf_file = temp_path / "test.ntdf" + decrypted_file = temp_path / "decrypted.txt" + + # Step 1: Encrypt with Python CLI using --container-type nano + logger.info(f"\n=== Encrypting {input_file} to {nanotdf_file} ===") + encrypt_result = run_cli_encrypt( + creds_file=temp_credentials_file, + input_file=input_file, + output_file=nanotdf_file, + mime_type="text/plain", + container_type="nano", + cwd=project_root, + ) + + # Log results for debugging + logger.info(f"Encrypt returncode: {encrypt_result.returncode}") + logger.info(f"Encrypt stdout: {encrypt_result.stdout}") + logger.info(f"Encrypt stderr: {encrypt_result.stderr}") + + # Check for errors + handle_subprocess_error( + result=encrypt_result, + collect_server_logs=collect_server_logs, + scenario_name="Python CLI encrypt nano", + ) + + # Verify NanoTDF was created + assert nanotdf_file.exists(), f"NanoTDF file should exist at {nanotdf_file}" + nanotdf_size = nanotdf_file.stat().st_size + assert nanotdf_size > 0, "NanoTDF file should not be empty" + logger.info(f"✓ Created NanoTDF: {nanotdf_size} bytes") + + # Step 2: Decrypt with Python CLI + logger.info(f"\n=== Decrypting {nanotdf_file} to {decrypted_file} ===") + decrypt_result = run_cli_decrypt( + creds_file=temp_credentials_file, + input_file=nanotdf_file, + output_file=decrypted_file, + cwd=project_root, + ) + + # Log results + logger.info(f"Decrypt returncode: {decrypt_result.returncode}") + logger.info(f"Decrypt stdout: {decrypt_result.stdout}") + logger.info(f"Decrypt stderr: {decrypt_result.stderr}") + + # Check for errors + handle_subprocess_error( + result=decrypt_result, + collect_server_logs=collect_server_logs, + scenario_name="Python CLI decrypt nano", + ) + + # Validate content + validate_plaintext_file_created( + path=decrypted_file, + scenario="Python CLI NanoTDF roundtrip", + expected_content=input_content, + ) + + logger.info("✓ Successfully decrypted NanoTDF roundtrip!") + logger.info(f" Input: {input_file.stat().st_size} bytes") + logger.info(f" NanoTDF: {nanotdf_size} bytes") + logger.info(f" Decrypted: {decrypted_file.stat().st_size} bytes") + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/test_ecdh.py b/tests/test_ecdh.py new file mode 100644 index 0000000..3186f1c --- /dev/null +++ b/tests/test_ecdh.py @@ -0,0 +1,432 @@ +""" +Unit tests for ECDH key exchange module. +""" + +import pytest +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ec + +from otdf_python.ecdh import ( + COMPRESSED_KEY_SIZES, + InvalidKeyError, + UnsupportedCurveError, + compress_public_key, + decompress_public_key, + decrypt_key_with_ecdh, + derive_key_from_shared_secret, + derive_shared_secret, + encrypt_key_with_ecdh, + generate_ephemeral_keypair, + get_compressed_key_size, + get_curve, +) + + +class TestCurveOperations: + """Test basic curve operations.""" + + def test_get_curve_secp256r1(self): + """Test getting secp256r1 curve.""" + curve = get_curve("secp256r1") + assert isinstance(curve, ec.SECP256R1) + + def test_get_curve_secp384r1(self): + """Test getting secp384r1 curve.""" + curve = get_curve("secp384r1") + assert isinstance(curve, ec.SECP384R1) + + def test_get_curve_secp521r1(self): + """Test getting secp521r1 curve.""" + curve = get_curve("secp521r1") + assert isinstance(curve, ec.SECP521R1) + + def test_get_curve_secp256k1(self): + """Test getting secp256k1 curve.""" + curve = get_curve("secp256k1") + assert isinstance(curve, ec.SECP256K1) + + def test_get_curve_case_insensitive(self): + """Test that curve names are case-insensitive.""" + curve1 = get_curve("SECP256R1") + curve2 = get_curve("secp256r1") + assert type(curve1) is type(curve2) + + def test_get_curve_unsupported(self): + """Test that unsupported curves raise an error.""" + with pytest.raises(UnsupportedCurveError): + get_curve("unsupported_curve") + + def test_get_compressed_key_size(self): + """Test getting compressed key sizes for all curves.""" + assert get_compressed_key_size("secp256r1") == 33 + assert get_compressed_key_size("secp384r1") == 49 + assert get_compressed_key_size("secp521r1") == 67 + assert get_compressed_key_size("secp256k1") == 33 + + def test_get_compressed_key_size_unsupported(self): + """Test that unsupported curves raise an error.""" + with pytest.raises(UnsupportedCurveError): + get_compressed_key_size("unsupported") + + +class TestKeypairGeneration: + """Test ephemeral keypair generation.""" + + def test_generate_keypair_secp256r1(self): + """Test generating a keypair for secp256r1.""" + private_key, public_key = generate_ephemeral_keypair("secp256r1") + assert isinstance(private_key, ec.EllipticCurvePrivateKey) + assert isinstance(public_key, ec.EllipticCurvePublicKey) + assert isinstance(private_key.curve, ec.SECP256R1) + + def test_generate_keypair_all_curves(self): + """Test generating keypairs for all supported curves.""" + for curve_name in ["secp256r1", "secp384r1", "secp521r1", "secp256k1"]: + private_key, public_key = generate_ephemeral_keypair(curve_name) + assert isinstance(private_key, ec.EllipticCurvePrivateKey) + assert isinstance(public_key, ec.EllipticCurvePublicKey) + + def test_generate_keypair_unique(self): + """Test that generated keypairs are unique.""" + _, pub1 = generate_ephemeral_keypair("secp256r1") + _, pub2 = generate_ephemeral_keypair("secp256r1") + + # Compress and compare - should be different + compressed1 = compress_public_key(pub1) + compressed2 = compress_public_key(pub2) + assert compressed1 != compressed2 + + +class TestPublicKeyCompression: + """Test public key compression and decompression.""" + + def test_compress_public_key(self): + """Test compressing a public key.""" + _, public_key = generate_ephemeral_keypair("secp256r1") + compressed = compress_public_key(public_key) + + # Should be 33 bytes for secp256r1 + assert len(compressed) == 33 + # First byte should be 0x02 or 0x03 (compressed point format) + assert compressed[0] in (0x02, 0x03) + + def test_compress_all_curves(self): + """Test compressing public keys for all curves.""" + for curve_name, expected_size in COMPRESSED_KEY_SIZES.items(): + _, public_key = generate_ephemeral_keypair(curve_name) + compressed = compress_public_key(public_key) + assert len(compressed) == expected_size + + def test_decompress_public_key(self): + """Test decompressing a public key.""" + _, original_public_key = generate_ephemeral_keypair("secp256r1") + compressed = compress_public_key(original_public_key) + + # Decompress + decompressed = decompress_public_key(compressed, "secp256r1") + + # Should be able to get the same bytes back + compressed_again = compress_public_key(decompressed) + assert compressed == compressed_again + + def test_decompress_all_curves(self): + """Test decompressing public keys for all curves.""" + for curve_name in ["secp256r1", "secp384r1", "secp521r1", "secp256k1"]: + _, original_public_key = generate_ephemeral_keypair(curve_name) + compressed = compress_public_key(original_public_key) + + decompressed = decompress_public_key(compressed, curve_name) + compressed_again = compress_public_key(decompressed) + assert compressed == compressed_again + + def test_decompress_invalid_size(self): + """Test that decompressing with wrong size raises an error.""" + with pytest.raises(InvalidKeyError): + # Too short for secp256r1 + decompress_public_key(b"\x02" + b"\x00" * 31, "secp256r1") + + def test_decompress_invalid_data(self): + """Test that decompressing invalid data raises an error.""" + with pytest.raises(InvalidKeyError): + # Invalid compressed point (wrong prefix) + decompress_public_key(b"\xff" + b"\x00" * 32, "secp256r1") + + +class TestSharedSecret: + """Test ECDH shared secret derivation.""" + + def test_derive_shared_secret(self): + """Test deriving a shared secret.""" + # Alice's keypair + alice_private, alice_public = generate_ephemeral_keypair("secp256r1") + # Bob's keypair + bob_private, bob_public = generate_ephemeral_keypair("secp256r1") + + # Alice computes shared secret with Bob's public key + secret_alice = derive_shared_secret(alice_private, bob_public) + # Bob computes shared secret with Alice's public key + secret_bob = derive_shared_secret(bob_private, alice_public) + + # Should be the same + assert secret_alice == secret_bob + # Should be 32 bytes for secp256r1 + assert len(secret_alice) == 32 + + def test_derive_shared_secret_all_curves(self): + """Test deriving shared secrets for all curves.""" + for curve_name in ["secp256r1", "secp384r1", "secp521r1", "secp256k1"]: + alice_private, alice_public = generate_ephemeral_keypair(curve_name) + bob_private, bob_public = generate_ephemeral_keypair(curve_name) + + secret_alice = derive_shared_secret(alice_private, bob_public) + secret_bob = derive_shared_secret(bob_private, alice_public) + + assert secret_alice == secret_bob + assert len(secret_alice) > 0 + + +class TestKeyDerivation: + """Test HKDF key derivation from shared secret.""" + + def test_derive_key_default(self): + """Test deriving a key with default parameters.""" + # Use a dummy shared secret + shared_secret = b"test_shared_secret_32_bytes!!!!!" + + key = derive_key_from_shared_secret(shared_secret) + + # Should be 32 bytes (default for AES-256) + assert len(key) == 32 + # Should be deterministic + key2 = derive_key_from_shared_secret(shared_secret) + assert key == key2 + + def test_derive_key_custom_length(self): + """Test deriving keys of different lengths.""" + shared_secret = b"test_shared_secret" + + key_16 = derive_key_from_shared_secret(shared_secret, key_length=16) + key_32 = derive_key_from_shared_secret(shared_secret, key_length=32) + key_64 = derive_key_from_shared_secret(shared_secret, key_length=64) + + assert len(key_16) == 16 + assert len(key_32) == 32 + assert len(key_64) == 64 + + def test_derive_key_custom_salt(self): + """Test deriving keys with custom salt.""" + shared_secret = b"test_shared_secret" + + key1 = derive_key_from_shared_secret(shared_secret, salt=b"salt1") + key2 = derive_key_from_shared_secret(shared_secret, salt=b"salt2") + + # Different salts should produce different keys + assert key1 != key2 + + def test_derive_key_custom_info(self): + """Test deriving keys with custom info.""" + shared_secret = b"test_shared_secret" + + key1 = derive_key_from_shared_secret(shared_secret, info=b"info1") + key2 = derive_key_from_shared_secret(shared_secret, info=b"info2") + + # Different info should produce different keys + assert key1 != key2 + + +class TestHighLevelEncryption: + """Test high-level encrypt_key_with_ecdh function.""" + + def test_encrypt_key_with_ecdh(self): + """Test the high-level encryption function.""" + # Generate a recipient keypair (e.g., KAS) + _recipient_private, recipient_public = generate_ephemeral_keypair("secp256r1") + + # Get PEM format + recipient_public_pem = recipient_public.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ).decode() + + # Encrypt (generate ephemeral key and derive encryption key) + derived_key, compressed_ephemeral_key = encrypt_key_with_ecdh( + recipient_public_pem, curve_name="secp256r1" + ) + + # Verify results + assert len(derived_key) == 32 # AES-256 key + assert len(compressed_ephemeral_key) == 33 # Compressed secp256r1 key + + def test_encrypt_key_all_curves(self): + """Test encryption with all supported curves.""" + for curve_name in ["secp256r1", "secp384r1", "secp521r1", "secp256k1"]: + _recipient_private, recipient_public = generate_ephemeral_keypair( + curve_name + ) + recipient_public_pem = recipient_public.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ).decode() + + derived_key, compressed_ephemeral_key = encrypt_key_with_ecdh( + recipient_public_pem, curve_name=curve_name + ) + + assert len(derived_key) == 32 + expected_size = COMPRESSED_KEY_SIZES[curve_name] + assert len(compressed_ephemeral_key) == expected_size + + def test_encrypt_key_invalid_recipient_key(self): + """Test that invalid recipient keys raise an error.""" + with pytest.raises(InvalidKeyError): + encrypt_key_with_ecdh("not a valid pem key") + + +class TestHighLevelDecryption: + """Test high-level decrypt_key_with_ecdh function.""" + + def test_decrypt_key_with_ecdh(self): + """Test the high-level decryption function.""" + # Generate a recipient keypair (e.g., KAS) + recipient_private, recipient_public = generate_ephemeral_keypair("secp256r1") + + # Get PEM formats + recipient_public_pem = recipient_public.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ).decode() + recipient_private_pem = recipient_private.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode() + + # Encrypt (sender side) + derived_key_encrypt, compressed_ephemeral_key = encrypt_key_with_ecdh( + recipient_public_pem, curve_name="secp256r1" + ) + + # Decrypt (recipient side) + derived_key_decrypt = decrypt_key_with_ecdh( + recipient_private_pem, compressed_ephemeral_key, curve_name="secp256r1" + ) + + # Keys should match + assert derived_key_encrypt == derived_key_decrypt + + def test_decrypt_key_all_curves(self): + """Test decryption with all supported curves.""" + for curve_name in ["secp256r1", "secp384r1", "secp521r1", "secp256k1"]: + recipient_private, recipient_public = generate_ephemeral_keypair(curve_name) + + recipient_public_pem = recipient_public.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ).decode() + recipient_private_pem = recipient_private.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode() + + # Encrypt + derived_key_encrypt, compressed_ephemeral_key = encrypt_key_with_ecdh( + recipient_public_pem, curve_name=curve_name + ) + + # Decrypt + derived_key_decrypt = decrypt_key_with_ecdh( + recipient_private_pem, compressed_ephemeral_key, curve_name=curve_name + ) + + # Should match + assert derived_key_encrypt == derived_key_decrypt + + def test_decrypt_key_invalid_private_key(self): + """Test that invalid private keys raise an error.""" + _, pub = generate_ephemeral_keypair("secp256r1") + compressed = compress_public_key(pub) + + with pytest.raises(InvalidKeyError): + decrypt_key_with_ecdh("not a valid pem key", compressed) + + def test_decrypt_key_invalid_ephemeral_key(self): + """Test that invalid ephemeral keys raise an error.""" + priv, _ = generate_ephemeral_keypair("secp256r1") + priv_pem = priv.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode() + + with pytest.raises(InvalidKeyError): + decrypt_key_with_ecdh(priv_pem, b"invalid_compressed_key") + + +class TestRoundtrip: + """Test complete ECDH roundtrip scenarios.""" + + def test_full_roundtrip(self): + """Test a complete encrypt/decrypt roundtrip.""" + # Scenario: Alice wants to send encrypted data to Bob + + # Bob generates a keypair and shares his public key + bob_private, bob_public = generate_ephemeral_keypair("secp256r1") + bob_public_pem = bob_public.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ).decode() + bob_private_pem = bob_private.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode() + + # Alice encrypts: generates ephemeral keypair and derives key + encryption_key, ephemeral_public_compressed = encrypt_key_with_ecdh( + bob_public_pem + ) + + # Alice would use encryption_key to encrypt data with AES-256-GCM + # and include ephemeral_public_compressed in the NanoTDF header + + # Bob receives the NanoTDF and extracts ephemeral_public_compressed from header + # Bob decrypts: uses his private key with the ephemeral public key + decryption_key = decrypt_key_with_ecdh( + bob_private_pem, ephemeral_public_compressed + ) + + # Bob would use decryption_key to decrypt the data + + # The keys should match + assert encryption_key == decryption_key + + def test_multiple_roundtrips_same_recipient(self): + """Test multiple encryptions to the same recipient produce different ephemeral keys.""" + # Bob's keypair + bob_private, bob_public = generate_ephemeral_keypair("secp256r1") + bob_public_pem = bob_public.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ).decode() + bob_private_pem = bob_private.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode() + + # Alice encrypts twice + key1, ephemeral1 = encrypt_key_with_ecdh(bob_public_pem) + key2, ephemeral2 = encrypt_key_with_ecdh(bob_public_pem) + + # Ephemeral keys should be different (new keypair each time) + assert ephemeral1 != ephemeral2 + # Derived keys should be different + assert key1 != key2 + + # But Bob should be able to decrypt both + decrypted_key1 = decrypt_key_with_ecdh(bob_private_pem, ephemeral1) + decrypted_key2 = decrypt_key_with_ecdh(bob_private_pem, ephemeral2) + + assert key1 == decrypted_key1 + assert key2 == decrypted_key2 diff --git a/tests/test_header.py b/tests/test_header.py index 678f931..6bca15a 100644 --- a/tests/test_header.py +++ b/tests/test_header.py @@ -15,9 +15,10 @@ def test_header_fields(self): payload_config = SymmetricAndPayloadConfig( cipher_type=2, signature_ecc_mode=1, has_signature=False ) - policy_info = PolicyInfo( - policy_type=1, has_ecdsa_binding=True, body=b"body", binding=b"bind" - ) + # PolicyInfo now only has policy_type and body (binding is separate in Header) + policy_info = PolicyInfo(policy_type=1, body=b"body") + # Binding is now a separate field in Header + policy_binding = b"bind1234" # GMAC is 8 bytes # Use correct ephemeral key length for curve_mode=1 (secp384r1): 49 bytes ephemeral_key = b"e" * 49 @@ -25,12 +26,14 @@ def test_header_fields(self): header.set_ecc_mode(ecc_mode) header.set_payload_config(payload_config) header.set_policy_info(policy_info) + header.policy_binding = policy_binding header.set_ephemeral_key(ephemeral_key) self.assertEqual(header.get_kas_locator(), kas_locator) self.assertEqual(header.get_ecc_mode(), ecc_mode) self.assertEqual(header.get_payload_config(), payload_config) self.assertEqual(header.get_policy_info(), policy_info) + self.assertEqual(header.policy_binding, policy_binding) self.assertEqual(header.get_ephemeral_key(), ephemeral_key) diff --git a/tests/test_nanotdf.py b/tests/test_nanotdf.py index 1517d1e..31db9fc 100644 --- a/tests/test_nanotdf.py +++ b/tests/test_nanotdf.py @@ -35,9 +35,6 @@ def test_nanotdf_invalid_magic(): nanotdf.read_nanotdf(bad_bytes, config) -@pytest.mark.skip( - "This test is skipped because NanoTDF encryption/decryption is not implemented yet." -) @pytest.mark.integration def test_nanotdf_integration_encrypt_decrypt(): # Load environment variables for integration @@ -47,24 +44,16 @@ def test_nanotdf_integration_encrypt_decrypt(): # Create KAS info from configuration kas_info = KASInfo(url=CONFIG_TDF.KAS_ENDPOINT) - # Create KAS client with SSL verification disabled for testing - # from otdf_python.kas_client import KASClient - # client = KASClient( - # kas_url=CONFIG_TDF.KAS_ENDPOINT, - # verify_ssl=not CONFIG_TDF.INSECURE_SKIP_VERIFY, - # use_plaintext=bool(CONFIG_TDF.OPENTDF_PLATFORM_URL.startswith("http://")), - # ) - nanotdf = NanoTDF() data = b"test data" - config = NanoTDFConfig(kas_info_list=[kas_info]) - # These will raise NotImplementedError until implemented - try: - nanotdf_bytes = nanotdf.create_nanotdf(data, config) - except NotImplementedError: - pytest.skip("NanoTDF encryption not implemented yet.") - try: - decrypted = nanotdf.read_nanotdf(nanotdf_bytes, config) - except NotImplementedError: - pytest.skip("NanoTDF decryption not implemented yet.") + + # Generate a key and include it in config for both encrypt and decrypt + # Note: In a real scenario with KAS integration, the key would be wrapped + # and unwrapped via KAS. For now, we're testing the basic encrypt/decrypt flow. + key = secrets.token_bytes(32) + config = NanoTDFConfig(kas_info_list=[kas_info], cipher=key.hex()) + + # Create and read NanoTDF + nanotdf_bytes = nanotdf.create_nanotdf(data, config) + decrypted = nanotdf.read_nanotdf(nanotdf_bytes, config) assert decrypted == data diff --git a/tests/test_nanotdf_ecdh.py b/tests/test_nanotdf_ecdh.py new file mode 100644 index 0000000..f11f5b3 --- /dev/null +++ b/tests/test_nanotdf_ecdh.py @@ -0,0 +1,321 @@ +""" +Integration tests for NanoTDF with ECDH key exchange. +""" + +import io + +import pytest +from cryptography.hazmat.primitives import serialization + +from otdf_python.config import KASInfo, NanoTDFConfig +from otdf_python.ecdh import generate_ephemeral_keypair +from otdf_python.nanotdf import NanoTDF + + +class TestNanoTDFWithECDH: + """Test NanoTDF encryption/decryption using ECDH key exchange.""" + + def test_nanotdf_ecdh_roundtrip_secp256r1(self): + """Test NanoTDF roundtrip with ECDH using secp256r1 curve.""" + # Generate a keypair for the recipient (e.g., KAS) + recipient_private_key, recipient_public_key = generate_ephemeral_keypair( + "secp256r1" + ) + + # Convert to PEM format + recipient_public_pem = recipient_public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ).decode() + + recipient_private_pem = recipient_private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode() + + # Create NanoTDF instance + nanotdf = NanoTDF() + + # Test payload + payload = b"Hello NanoTDF with ECDH!" + + # Create configuration with KAS public key + kas_info = KASInfo( + url="https://kas.example.com", public_key=recipient_public_pem + ) + config_encrypt = NanoTDFConfig(kas_info_list=[kas_info], ecc_mode="secp256r1") + + # Encrypt + encrypted_stream = io.BytesIO() + size = nanotdf.create_nano_tdf(payload, encrypted_stream, config_encrypt) + encrypted_data = encrypted_stream.getvalue() + + # Verify encryption worked + assert size > 0 + assert len(encrypted_data) > len( + payload + ) # Should be larger due to header + IV + MAC + + # Decrypt with recipient's private key + config_decrypt = NanoTDFConfig(cipher=recipient_private_pem) + decrypted_stream = io.BytesIO() + nanotdf.read_nano_tdf(encrypted_data, decrypted_stream, config_decrypt) + decrypted_data = decrypted_stream.getvalue() + + # Verify decryption worked + assert decrypted_data == payload + + def test_nanotdf_ecdh_roundtrip_all_curves(self): + """Test NanoTDF roundtrip with ECDH using all supported curves.""" + curves = ["secp256r1", "secp384r1", "secp521r1", "secp256k1"] + + for curve_name in curves: + # Generate keypair + recipient_private_key, recipient_public_key = generate_ephemeral_keypair( + curve_name + ) + + recipient_public_pem = recipient_public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ).decode() + + recipient_private_pem = recipient_private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode() + + # Create NanoTDF + nanotdf = NanoTDF() + payload = f"Testing {curve_name}".encode() + + # Encrypt + kas_info = KASInfo( + url="https://kas.example.com", public_key=recipient_public_pem + ) + config_encrypt = NanoTDFConfig( + kas_info_list=[kas_info], ecc_mode=curve_name + ) + encrypted_stream = io.BytesIO() + nanotdf.create_nano_tdf(payload, encrypted_stream, config_encrypt) + encrypted_data = encrypted_stream.getvalue() + + # Decrypt + config_decrypt = NanoTDFConfig(cipher=recipient_private_pem) + decrypted_stream = io.BytesIO() + nanotdf.read_nano_tdf(encrypted_data, decrypted_stream, config_decrypt) + decrypted_data = decrypted_stream.getvalue() + + # Verify + assert decrypted_data == payload, f"Failed for curve {curve_name}" + + def test_nanotdf_ecdh_with_attributes(self): + """Test NanoTDF with ECDH and policy attributes.""" + # Generate keypair + recipient_private_key, recipient_public_key = generate_ephemeral_keypair( + "secp256r1" + ) + + recipient_public_pem = recipient_public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ).decode() + + recipient_private_pem = recipient_private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode() + + # Create NanoTDF with attributes + nanotdf = NanoTDF() + payload = b"Sensitive data with attributes" + + kas_info = KASInfo( + url="https://kas.example.com", public_key=recipient_public_pem + ) + attributes = [ + "https://example.com/attr/classification/secret", + "https://example.com/attr/country/us", + ] + config_encrypt = NanoTDFConfig( + kas_info_list=[kas_info], attributes=attributes, ecc_mode="secp256r1" + ) + + # Encrypt + encrypted_stream = io.BytesIO() + nanotdf.create_nano_tdf(payload, encrypted_stream, config_encrypt) + encrypted_data = encrypted_stream.getvalue() + + # Decrypt + config_decrypt = NanoTDFConfig(cipher=recipient_private_pem) + decrypted_stream = io.BytesIO() + nanotdf.read_nano_tdf(encrypted_data, decrypted_stream, config_decrypt) + decrypted_data = decrypted_stream.getvalue() + + # Verify + assert decrypted_data == payload + + def test_nanotdf_ecdh_multiple_encryptions_different_keys(self): + """Test that multiple encryptions produce different ephemeral keys.""" + # Generate recipient keypair + recipient_private_key, recipient_public_key = generate_ephemeral_keypair( + "secp256r1" + ) + + recipient_public_pem = recipient_public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ).decode() + + recipient_private_pem = recipient_private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode() + + # Encrypt same payload twice + nanotdf = NanoTDF() + payload = b"Same payload" + + kas_info = KASInfo( + url="https://kas.example.com", public_key=recipient_public_pem + ) + config_encrypt = NanoTDFConfig(kas_info_list=[kas_info], ecc_mode="secp256r1") + + # First encryption + encrypted_stream1 = io.BytesIO() + nanotdf.create_nano_tdf(payload, encrypted_stream1, config_encrypt) + encrypted_data1 = encrypted_stream1.getvalue() + + # Second encryption + encrypted_stream2 = io.BytesIO() + nanotdf.create_nano_tdf(payload, encrypted_stream2, config_encrypt) + encrypted_data2 = encrypted_stream2.getvalue() + + # Encrypted data should be different (different ephemeral keys) + assert encrypted_data1 != encrypted_data2 + + # But both should decrypt to the same payload + config_decrypt = NanoTDFConfig(cipher=recipient_private_pem) + + decrypted_stream1 = io.BytesIO() + nanotdf.read_nano_tdf(encrypted_data1, decrypted_stream1, config_decrypt) + assert decrypted_stream1.getvalue() == payload + + decrypted_stream2 = io.BytesIO() + nanotdf.read_nano_tdf(encrypted_data2, decrypted_stream2, config_decrypt) + assert decrypted_stream2.getvalue() == payload + + def test_nanotdf_ecdh_wrong_private_key_fails(self): + """Test that decryption with wrong private key fails.""" + # Generate recipient keypair + _, recipient_public_key = generate_ephemeral_keypair("secp256r1") + + recipient_public_pem = recipient_public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ).decode() + + # Generate a different private key (wrong key) + wrong_private_key, _ = generate_ephemeral_keypair("secp256r1") + wrong_private_pem = wrong_private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode() + + # Encrypt + nanotdf = NanoTDF() + payload = b"Secret message" + + kas_info = KASInfo( + url="https://kas.example.com", public_key=recipient_public_pem + ) + config_encrypt = NanoTDFConfig(kas_info_list=[kas_info], ecc_mode="secp256r1") + + encrypted_stream = io.BytesIO() + nanotdf.create_nano_tdf(payload, encrypted_stream, config_encrypt) + encrypted_data = encrypted_stream.getvalue() + + # Try to decrypt with wrong private key + config_decrypt = NanoTDFConfig(cipher=wrong_private_pem) + decrypted_stream = io.BytesIO() + + # Should fail (authentication error from AES-GCM) + # Will be cryptography.exceptions.InvalidTag + with pytest.raises(Exception): # noqa: B017 + nanotdf.read_nano_tdf(encrypted_data, decrypted_stream, config_decrypt) + + def test_nanotdf_ecdh_large_payload(self): + """Test NanoTDF with ECDH for a large payload.""" + # Generate keypair + recipient_private_key, recipient_public_key = generate_ephemeral_keypair( + "secp256r1" + ) + + recipient_public_pem = recipient_public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ).decode() + + recipient_private_pem = recipient_private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode() + + # Large payload (1MB) + nanotdf = NanoTDF() + payload = b"X" * (1024 * 1024) + + kas_info = KASInfo( + url="https://kas.example.com", public_key=recipient_public_pem + ) + config_encrypt = NanoTDFConfig(kas_info_list=[kas_info], ecc_mode="secp256r1") + + # Encrypt + encrypted_stream = io.BytesIO() + nanotdf.create_nano_tdf(payload, encrypted_stream, config_encrypt) + encrypted_data = encrypted_stream.getvalue() + + # Decrypt + config_decrypt = NanoTDFConfig(cipher=recipient_private_pem) + decrypted_stream = io.BytesIO() + nanotdf.read_nano_tdf(encrypted_data, decrypted_stream, config_decrypt) + decrypted_data = decrypted_stream.getvalue() + + # Verify + assert decrypted_data == payload + assert len(decrypted_data) == 1024 * 1024 + + def test_nanotdf_backward_compat_symmetric_key(self): + """Test that symmetric key encryption still works (backward compatibility).""" + nanotdf = NanoTDF() + payload = b"Testing symmetric key backward compat" + + # Use symmetric key (no ECDH) + import secrets + + key = secrets.token_bytes(32) + config_encrypt = NanoTDFConfig(cipher=key.hex()) + + # Encrypt + encrypted_stream = io.BytesIO() + nanotdf.create_nano_tdf(payload, encrypted_stream, config_encrypt) + encrypted_data = encrypted_stream.getvalue() + + # Decrypt with same symmetric key + config_decrypt = NanoTDFConfig(cipher=key.hex()) + decrypted_stream = io.BytesIO() + nanotdf.read_nano_tdf(encrypted_data, decrypted_stream, config_decrypt) + decrypted_data = decrypted_stream.getvalue() + + # Verify + assert decrypted_data == payload + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_nanotdf_integration.py b/tests/test_nanotdf_integration.py index 943cfaf..a4bf79e 100644 --- a/tests/test_nanotdf_integration.py +++ b/tests/test_nanotdf_integration.py @@ -2,7 +2,7 @@ import pytest from cryptography.hazmat.primitives import serialization -from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.asymmetric import ec from otdf_python.config import KASInfo, NanoTDFConfig from otdf_python.nanotdf import NanoTDF @@ -10,8 +10,8 @@ @pytest.mark.integration def test_nanotdf_kas_roundtrip(): - # Generate RSA keypair - private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + # Generate EC keypair (NanoTDF uses ECDH, not RSA) + private_key = ec.generate_private_key(ec.SECP256R1()) private_pem = private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8,