diff --git a/otdf-python-proto/scripts/generate_connect_proto.py b/otdf-python-proto/scripts/generate_connect_proto.py index c5602a6..ff0d0d8 100644 --- a/otdf-python-proto/scripts/generate_connect_proto.py +++ b/otdf-python-proto/scripts/generate_connect_proto.py @@ -26,7 +26,7 @@ def check_dependencies() -> bool: try: subprocess.run(check_cmd, shell=True, capture_output=True, check=True) print(f"✓ {name} is available") - except (subprocess.CalledProcessError, FileNotFoundError): + except (subprocess.CalledProcessError, FileNotFoundError): # noqa: PERF203 missing.append(name) print(f"✗ {name} is missing") @@ -104,7 +104,7 @@ def copy_opentdf_proto_files(proto_gen_dir: Path) -> bool: copied_files += 1 - except Exception as e: + except Exception as e: # noqa: PERF203 print(f" Warning: Failed to copy {relative_path}: {e}") print(f"Found and copied {copied_files} proto files from repository") @@ -235,7 +235,7 @@ def _fix_ignore_if_default_value(proto_files_dir): print(f"Updated {proto_file.name} to use IGNORE_IF_ZERO_VALUE") - except Exception as e: + except Exception as e: # noqa: PERF203 print(f"Error updating {proto_file.name}: {e}") diff --git a/src/otdf_python/cli.py b/src/otdf_python/cli.py index f57ed83..516d766 100644 --- a/src/otdf_python/cli.py +++ b/src/otdf_python/cli.py @@ -3,7 +3,7 @@ OpenTDF Python CLI A command-line interface for encrypting and decrypting files using OpenTDF. -Provides encrypt, decrypt, and inspect commands similar to the TypeScript CLI. +Provides encrypt, decrypt, and inspect commands similar to the otdfctl CLI. """ import argparse @@ -146,7 +146,7 @@ def build_sdk(args) -> SDK: else: raise CLIError( "CRITICAL", - "Authentication required: provide --with-client-creds-file, --client-id and --client-secret, or --auth", + "Authentication required: provide --with-client-creds-file OR --client-id and --client-secret", ) if hasattr(args, "plaintext") and args.plaintext: @@ -479,11 +479,6 @@ def create_parser() -> argparse.ArgumentParser: ) auth_group.add_argument("--client-id", help="OAuth client ID") auth_group.add_argument("--client-secret", help="OAuth client secret") - # Keep --auth for backward compatibility - auth_group.add_argument( - "--auth", - help="Combined OAuth credentials (clientId:clientSecret) - deprecated, use --with-client-creds-file", - ) # Security options security_group = parser.add_argument_group("Security") diff --git a/src/otdf_python/crypto_utils.py b/src/otdf_python/crypto_utils.py index 7e2bcc2..2b80e79 100644 --- a/src/otdf_python/crypto_utils.py +++ b/src/otdf_python/crypto_utils.py @@ -75,8 +75,3 @@ def get_rsa_private_key_from_pem(pem_data: str) -> rsa.RSAPrivateKey: if not isinstance(private_key, rsa.RSAPrivateKey): raise ValueError("Not an RSA private key") return private_key - - -# Aliases for compatibility -rsa_private_key_from_pem = CryptoUtils.get_rsa_private_key_from_pem -rsa_public_key_from_pem = CryptoUtils.get_rsa_public_key_from_pem diff --git a/src/otdf_python/kas_client.py b/src/otdf_python/kas_client.py index 33f7360..9be78c9 100644 --- a/src/otdf_python/kas_client.py +++ b/src/otdf_python/kas_client.py @@ -141,10 +141,11 @@ def _handle_existing_scheme(self, parsed) -> str: def _create_signed_request_jwt(self, policy_json, client_public_key, key_access): # noqa: C901 """ Create a signed JWT for the rewrap request. - The JWT is signed with the DPoP private key, matching Java SDK implementation exactly. + The JWT is signed with the DPoP private key, mimicking the Java SDK implementation. """ # Convert the KeyAccess to a dict that matches Java SDK structure # Handle both ManifestKeyAccess (new camelCase and old snake_case) and simple KeyAccess (for tests) + # TODO: This can probably be simplified to only camelCase # Ensure wrappedKey is a base64-encoded string # Note: wrappedKey from manifest is already base64-encoded @@ -248,8 +249,8 @@ def _create_signed_request_jwt(self, policy_json, client_public_key, key_access) }, } ], - "keyAccess": key_access_dict, # Keep legacy field for backward compatibility - "policy": policy_base64, # Keep legacy field for backward compatibility + "keyAccess": key_access_dict, + "policy": policy_base64, } # Convert to JSON string @@ -532,7 +533,7 @@ def _parse_and_decrypt_response(self, response): encrypted_key = b64decode(entity_wrapped_key) return self.decryptor.decrypt(encrypted_key) - def unwrap(self, key_access, policy_json, session_key_type=None): + def unwrap(self, key_access, policy_json, session_key_type=None) -> bytes: """ Unwrap a key using Connect RPC. @@ -544,7 +545,7 @@ def unwrap(self, key_access, policy_json, session_key_type=None): Returns: Unwrapped key bytes """ - # Default to RSA if not specified (for backward compatibility) + # Default to RSA if not specified if session_key_type is None: session_key_type = RSA_KEY_TYPE @@ -561,9 +562,9 @@ def unwrap(self, key_access, policy_json, session_key_type=None): ) # Call Connect RPC unwrap - return self._unwrap_with_connect_rpc(key_access, signed_token, policy_json) + return self._unwrap_with_connect_rpc(key_access, signed_token) - def _unwrap_with_connect_rpc(self, key_access, signed_token, policy_json): + def _unwrap_with_connect_rpc(self, key_access, signed_token) -> bytes: """ Connect RPC method for unwrapping keys. """ diff --git a/src/otdf_python/manifest.py b/src/otdf_python/manifest.py index b86fc77..1d771da 100644 --- a/src/otdf_python/manifest.py +++ b/src/otdf_python/manifest.py @@ -151,7 +151,8 @@ def _root_sig(rs): return ManifestRootSignature(**rs) def _integrity(i): - # Handle both old snake_case and new camelCase formats for backward compatibility + # Handle both snake_case and camelCase fields + # TODO: This can probably be simplified to only camelCase return ManifestIntegrityInformation( rootSignature=_root_sig( i.get("rootSignature", i.get("root_signature")) @@ -174,7 +175,8 @@ def _key_access(k): return ManifestKeyAccess(**k) def _enc_info(e): - # Handle both old snake_case and new camelCase formats + # Handle both snake_case and camelCase fields + # TODO: This can probably be simplified to only camelCase return ManifestEncryptionInformation( type=e.get("type", e.get("key_access_type", "split")), policy=e["policy"], diff --git a/src/otdf_python/nanotdf.py b/src/otdf_python/nanotdf.py index bea366e..9e896fa 100644 --- a/src/otdf_python/nanotdf.py +++ b/src/otdf_python/nanotdf.py @@ -371,10 +371,10 @@ def read_nano_tdf( if not kas_private_key and not kas_mock_unwrap: raise InvalidNanoTDFConfig("Missing kas_private_key for unwrap.") if kas_mock_unwrap: - # Use the SDK.KAS mock unwrap_nanotdf logic - from otdf_python.sdk import SDK + # Use the KAS mock unwrap_nanotdf logic + from otdf_python.sdk import KAS - key = SDK.KAS().unwrap_nanotdf( + key = KAS().unwrap_nanotdf( curve=None, header=None, kas_url=None, @@ -395,7 +395,6 @@ def read_nano_tdf( plaintext = aesgcm.decrypt(iv_padded, ciphertext, None) output_stream.write(plaintext) - # Legacy method names for backward compatibility def _convert_dict_to_nanotdf_config(self, config: dict) -> NanoTDFConfig: """Convert a dictionary config to a NanoTDFConfig object.""" converted_config = NanoTDFConfig() diff --git a/src/otdf_python/sdk.py b/src/otdf_python/sdk.py index f835a22..543bb5e 100644 --- a/src/otdf_python/sdk.py +++ b/src/otdf_python/sdk.py @@ -42,6 +42,155 @@ class Interceptor: ... # Can be dict in Python implementation class TrustManager: ... +class KAS(AbstractContextManager): + """ + KAS (Key Access Service) interface to define methods related to key access and management. + """ + + def get_public_key(self, kas_info: Any) -> Any: + """ + Retrieves the public key from the KAS for RSA operations. + If the public key is cached, returns the cached value. + Otherwise, makes a request to the KAS. + + Args: + kas_info: KASInfo object containing the URL and algorithm + + Returns: + Updated KASInfo object with KID and PublicKey populated + + Raises: + SDKException: If there's an error retrieving the public key + """ + # Delegate to the underlying KAS client which handles authentication properly + return self._kas_client.get_public_key(kas_info) + + def __init__( + self, + platform_url=None, + token_source=None, + sdk_ssl_verify=True, + use_plaintext=False, + auth_headers: dict | None = None, + ): + """ + Initialize the KAS client + + Args: + platform_url: URL of the platform + token_source: Function that returns an authentication token + sdk_ssl_verify: Whether to verify SSL certificates + use_plaintext: Whether to use plaintext HTTP connections instead of HTTPS + auth_headers: Dictionary of authentication headers to include in requests + """ + from .kas_client import KASClient + + self._kas_client = KASClient( + kas_url=platform_url, + token_source=token_source, + verify_ssl=sdk_ssl_verify, + use_plaintext=use_plaintext, + ) + # Store the parameters for potential use + self._sdk_ssl_verify = sdk_ssl_verify + self._use_plaintext = use_plaintext + self._auth_headers = auth_headers + + def get_ec_public_key(self, kas_info: Any, curve: Any) -> Any: + """ + Retrieves the EC public key from the KAS. + + Args: + kas_info: KASInfo object containing the URL + curve: The EC curve to use + + Returns: + Updated KASInfo object with KID and PublicKey populated + """ + # Set algorithm to "ec:" + from copy import copy + + kas_info_copy = copy(kas_info) + kas_info_copy.algorithm = f"ec:{curve}" + return self.get_public_key_from_kas(kas_info_copy) + + def get_public_key_from_kas(self, kas_info: Any) -> Any: + """ + Retrieves the public key from the KAS for RSA operations. + Wrapper around the KAS client's get_public_key method. + + Args: + kas_info: KASInfo object containing the URL and algorithm + + Returns: + Updated KASInfo object with KID and PublicKey populated + """ + return self._kas_client.get_public_key(kas_info) + + def unwrap(self, key_access: Any, policy: str, session_key_type: Any) -> bytes: + """ + Unwraps the key using the KAS. + + Args: + key_access: KeyAccess object containing the wrapped key + policy: Policy JSON string + session_key_type: Type of session key (RSA, EC) + + Returns: + Unwrapped key as bytes + """ + print("Pause") + return self._kas_client.unwrap(key_access, policy, session_key_type) + + def unwrap_nanotdf( + self, + curve: Any, + header: str, + kas_url: str, + wrapped_key: bytes | None = None, + kas_private_key: str | None = None, + mock: bool = False, + ) -> bytes: + """ + Unwraps the NanoTDF key using the KAS. If mock=True, performs local unwrap using the private key (for tests). + + Args: + curve: EC curve used + header: NanoTDF header + kas_url: URL of the KAS + wrapped_key: Optional wrapped key bytes (for mock mode) + kas_private_key: Optional KAS private key (for mock mode) + mock: If True, unwrap locally using provided private key + + Returns: + Unwrapped key as bytes + """ + if mock and wrapped_key and kas_private_key: + from .asym_decryption import AsymDecryption + + asym = AsymDecryption(private_key_pem=kas_private_key) + return asym.decrypt(wrapped_key) + + # This would be implemented using nanotdf-specific logic + raise NotImplementedError("KAS unwrap_nanotdf not implemented.") + + def get_key_cache(self) -> Any: + """ + Returns the KAS key cache. + + Returns: + The KAS key cache object + """ + return self._kas_client.get_key_cache() + + def close(self): + """Closes resources associated with the KAS interface""" + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + class SDK(AbstractContextManager): def new_tdf_config( self, attributes: list[str] | None = None, **kwargs @@ -111,153 +260,6 @@ def new_tdf_config( Provides various services for TDF/NanoTDF operations and platform API calls. """ - class KAS(AbstractContextManager): - """ - KAS (Key Access Service) interface to define methods related to key access and management. - """ - - def get_public_key(self, kas_info: Any) -> Any: - """ - Retrieves the public key from the KAS for RSA operations. - If the public key is cached, returns the cached value. - Otherwise, makes a request to the KAS. - - Args: - kas_info: KASInfo object containing the URL and algorithm - - Returns: - Updated KASInfo object with KID and PublicKey populated - - Raises: - SDKException: If there's an error retrieving the public key - """ - # Delegate to the underlying KAS client which handles authentication properly - return self._kas_client.get_public_key(kas_info) - - def __init__( - self, - platform_url=None, - token_source=None, - sdk_ssl_verify=True, - use_plaintext=False, - auth_headers: dict | None = None, - ): - """ - Initialize the KAS client - - Args: - platform_url: URL of the platform - token_source: Function that returns an authentication token - sdk_ssl_verify: Whether to verify SSL certificates - use_plaintext: Whether to use plaintext HTTP connections instead of HTTPS - auth_headers: Dictionary of authentication headers to include in requests - """ - from .kas_client import KASClient - - self._kas_client = KASClient( - kas_url=platform_url, - token_source=token_source, - verify_ssl=sdk_ssl_verify, - use_plaintext=use_plaintext, - ) - # Store the parameters for potential use - self._sdk_ssl_verify = sdk_ssl_verify - self._use_plaintext = use_plaintext - self._auth_headers = auth_headers - - def get_ec_public_key(self, kas_info: Any, curve: Any) -> Any: - """ - Retrieves the EC public key from the KAS. - - Args: - kas_info: KASInfo object containing the URL - curve: The EC curve to use - - Returns: - Updated KASInfo object with KID and PublicKey populated - """ - # Set algorithm to "ec:" - from copy import copy - - kas_info_copy = copy(kas_info) - kas_info_copy.algorithm = f"ec:{curve}" - return self.get_public_key_from_kas(kas_info_copy) - - def get_public_key_from_kas(self, kas_info: Any) -> Any: - """ - Retrieves the public key from the KAS for RSA operations. - Wrapper around the KAS client's get_public_key method. - - Args: - kas_info: KASInfo object containing the URL and algorithm - - Returns: - Updated KASInfo object with KID and PublicKey populated - """ - return self._kas_client.get_public_key(kas_info) - - def unwrap(self, key_access: Any, policy: str, session_key_type: Any) -> bytes: - """ - Unwraps the key using the KAS. - - Args: - key_access: KeyAccess object containing the wrapped key - policy: Policy JSON string - session_key_type: Type of session key (RSA, EC) - - Returns: - Unwrapped key as bytes - """ - return self._kas_client.unwrap(key_access, policy, session_key_type) - - def unwrap_nanotdf( - self, - curve: Any, - header: str, - kas_url: str, - wrapped_key: bytes | None = None, - kas_private_key: str | None = None, - mock: bool = False, - ) -> bytes: - """ - Unwraps the NanoTDF key using the KAS. If mock=True, performs local unwrap using the private key (for tests). - - Args: - curve: EC curve used - header: NanoTDF header - kas_url: URL of the KAS - wrapped_key: Optional wrapped key bytes (for mock mode) - kas_private_key: Optional KAS private key (for mock mode) - mock: If True, unwrap locally using provided private key - - Returns: - Unwrapped key as bytes - """ - if mock and wrapped_key and kas_private_key: - from .asym_decryption import AsymDecryption - - asym = AsymDecryption(private_key_pem=kas_private_key) - return asym.decrypt(wrapped_key) - - # This would be implemented using nanotdf-specific logic - raise NotImplementedError("KAS unwrap_nanotdf not implemented.") - - def get_key_cache(self) -> Any: - """ - Returns the KAS key cache. - - Returns: - The KAS key cache object - """ - return self._kas_client.get_key_cache() - - def close(self): - """Closes resources associated with the KAS interface""" - pass - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - class Services(AbstractContextManager): """ The Services interface provides access to various platform service clients and KAS. @@ -287,25 +289,12 @@ def kas_registry(self) -> KeyAccessServerRegistryServiceClientInterface: """Returns the KAS registry service client""" raise NotImplementedError - def kas(self) -> "SDK.KAS": + def kas(self) -> KAS: """ - Returns the KAS interface. - - Returns: - KAS: The KAS interface implementation + Returns the KAS client for key access operations. + This should be implemented to return an instance of KAS. """ - # Return a KAS implementation with the SDK's platform URL and settings - # This is where we would get the platform URL and token source from the SDK - from .sdk_builder import SDKBuilder - - platform_url = SDKBuilder.get_platform_url() - - # Create the KAS implementation with the platform URL and use_plaintext setting from SDK - kas_impl = SDK.KAS( - platform_url=platform_url, - use_plaintext=getattr(self, "_use_plaintext", False), - ) - return kas_impl + raise NotImplementedError def close(self): """Closes resources associated with the services""" diff --git a/src/otdf_python/sdk_builder.py b/src/otdf_python/sdk_builder.py index df4e935..55019a8 100644 --- a/src/otdf_python/sdk_builder.py +++ b/src/otdf_python/sdk_builder.py @@ -10,7 +10,7 @@ import httpx from dataclasses import dataclass -from otdf_python.sdk import SDK +from otdf_python.sdk import SDK, KAS from otdf_python.sdk_exceptions import AutoConfigureException # Configure logging @@ -37,14 +37,14 @@ class SDKBuilder: # Class variable to store the latest platform URL _platform_url = None - def __init__(self): + def __init__(self, auth_token: str | None = None): self.platform_endpoint: str | None = None self.issuer_endpoint: str | None = None self.oauth_config: OAuthConfig | None = None self.use_plaintext: bool = False self.insecure_skip_verify: bool = False self.ssl_context: ssl.SSLContext | None = None - self.auth_token: str | None = None + self.auth_token: str | None = auth_token self.cert_paths: list[str] = [] @staticmethod @@ -390,7 +390,7 @@ def __init__(self, builder_instance): self._auth_headers = auth_interceptor if auth_interceptor else {} self._builder = builder_instance - def kas(self) -> "SDK.KAS": + def kas(self) -> KAS: """ Returns the KAS interface with SSL verification settings. """ @@ -406,7 +406,7 @@ def token_source(): return self._builder._get_token_from_client_credentials() return None - kas_impl = SDK.KAS( + kas_impl = KAS( platform_url=platform_url, token_source=token_source, sdk_ssl_verify=self._ssl_verify, diff --git a/src/otdf_python/tdf.py b/src/otdf_python/tdf.py index ed6d18f..ad8e578 100644 --- a/src/otdf_python/tdf.py +++ b/src/otdf_python/tdf.py @@ -1,10 +1,14 @@ -from typing import BinaryIO +from typing import BinaryIO, TYPE_CHECKING import io import os import hashlib import hmac import base64 import zipfile + +if TYPE_CHECKING: + from otdf_python.kas_client import KASClient + from otdf_python.manifest import ( Manifest, ManifestSegment, @@ -196,7 +200,8 @@ def _enforce_policy(self, manifest: Manifest, config: TDFReaderConfig): # noqa: required_attrs = set() if "body" in policy_dict: - # Check for both dataAttributes (new camelCase) and data_attributes (old snake_case) for backward compatibility + # Handle both snake_case and camelCase fields + # TODO: This can probably be simplified to only camelCase data_attrs = policy_dict["body"].get( "dataAttributes" ) or policy_dict["body"].get("data_attributes") @@ -240,7 +245,7 @@ def _unwrap_key(self, key_access_objs, private_key_pem): raise ValueError("No matching KAS private key could unwrap any payload key") return key - def _unwrap_key_with_kas(self, key_access_objs, policy_b64): + def _unwrap_key_with_kas(self, key_access_objs, policy_b64) -> bytes: """ Unwraps the key using the KAS service (production method) """ @@ -248,7 +253,7 @@ def _unwrap_key_with_kas(self, key_access_objs, policy_b64): if not self.services: raise ValueError("SDK services required for KAS operations") - kas_client = ( + kas_client: KASClient = ( self.services.kas() ) # The 'kas_client' should be typed as KASClient @@ -262,7 +267,7 @@ def _unwrap_key_with_kas(self, key_access_objs, policy_b64): # Try each key access object for ka in key_access_objs: try: - # Pass the manifest key access object directly to match Java SDK + # Pass the manifest key access object directly key_access = ka # Determine session key type from key_access properties @@ -281,7 +286,7 @@ def _unwrap_key_with_kas(self, key_access_objs, policy_b64): if key: return key - except Exception as e: + except Exception as e: # noqa: PERF203 import logging logging.warning(f"Error unwrapping key with KAS: {e}") diff --git a/tests/config_pydantic.py b/tests/config_pydantic.py index f189f48..730a5c6 100644 --- a/tests/config_pydantic.py +++ b/tests/config_pydantic.py @@ -73,6 +73,9 @@ class ConfigureTdf(BaseSettings): TEST_OPENTDF_ATTRIBUTE_1: str = "https://example.com/attr/attr1/value/value1" TEST_OPENTDF_ATTRIBUTE_2: str = "https://example.com/attr/attr1/value/value2" + TEST_USER_ID: str = "sample-user" + TEST_USER_PASSWORD: str = "sample-password" + class ConfigureTesting(BaseSettings): """ diff --git a/tests/integration/support_sdk.py b/tests/integration/support_sdk.py new file mode 100644 index 0000000..8212bb4 --- /dev/null +++ b/tests/integration/support_sdk.py @@ -0,0 +1,101 @@ +from otdf_python.sdk_builder import SDKBuilder +from otdf_python.sdk import SDK +from tests.config_pydantic import CONFIG_TDF +import httpx + + +def get_sdk() -> SDK: + if CONFIG_TDF.OPENTDF_PLATFORM_URL.startswith("http://"): + sdk = ( + SDKBuilder() + .set_platform_endpoint(CONFIG_TDF.OPENTDF_PLATFORM_URL) + .set_issuer_endpoint(CONFIG_TDF.OPENTDF_KEYCLOAK_HOST) + .client_secret( + CONFIG_TDF.OPENTDF_CLIENT_ID, + CONFIG_TDF.OPENTDF_CLIENT_SECRET, + ) + .use_insecure_plaintext_connection(True) + .use_insecure_skip_verify(CONFIG_TDF.INSECURE_SKIP_VERIFY) + .build() + ) + elif CONFIG_TDF.OPENTDF_PLATFORM_URL.startswith("https://"): + sdk = ( + SDKBuilder() + .set_platform_endpoint(CONFIG_TDF.OPENTDF_PLATFORM_URL) + .set_issuer_endpoint(CONFIG_TDF.OPENTDF_KEYCLOAK_HOST) + .client_secret( + CONFIG_TDF.OPENTDF_CLIENT_ID, + CONFIG_TDF.OPENTDF_CLIENT_SECRET, + ) + .use_insecure_skip_verify(CONFIG_TDF.INSECURE_SKIP_VERIFY) + .build() + ) + else: + raise ValueError( + f"Invalid platform URL: {CONFIG_TDF.OPENTDF_PLATFORM_URL}. " + "It must start with 'http://' or 'https://'." + ) + + return sdk + + +def get_sdk_for_pe() -> SDK: + user_token: str = get_user_access_token( + CONFIG_TDF.OIDC_OP_TOKEN_ENDPOINT, + CONFIG_TDF.TEST_USER_ID, + CONFIG_TDF.TEST_USER_PASSWORD, + ) + + if CONFIG_TDF.OPENTDF_PLATFORM_URL.startswith("http://"): + sdk = ( + SDKBuilder(auth_token=user_token) + .set_platform_endpoint(CONFIG_TDF.OPENTDF_PLATFORM_URL) + .set_issuer_endpoint(CONFIG_TDF.OPENTDF_KEYCLOAK_HOST) + .use_insecure_plaintext_connection(True) + .use_insecure_skip_verify(CONFIG_TDF.INSECURE_SKIP_VERIFY) + .build() + ) + elif CONFIG_TDF.OPENTDF_PLATFORM_URL.startswith("https://"): + sdk = ( + SDKBuilder() + .set_platform_endpoint(CONFIG_TDF.OPENTDF_PLATFORM_URL) + .set_issuer_endpoint(CONFIG_TDF.OPENTDF_KEYCLOAK_HOST) + .use_insecure_skip_verify(CONFIG_TDF.INSECURE_SKIP_VERIFY) + .build() + ) + else: + raise ValueError( + f"Invalid platform URL: {CONFIG_TDF.OPENTDF_PLATFORM_URL}. " + "It must start with 'http://' or 'https://'." + ) + + return sdk + + +def get_user_access_token( + token_endpoint, + pe_username, + pe_password, +): + """ + When using this function, ensure that: + + 1. The client has "fine-grained access control" enabled (in the Advanced tab for the client in Keycloak). + 2. The client is allowed to use "Direct access grants" (in the Settings tab for the client in Keycloak). + + """ + headers = {"Content-Type": "application/x-www-form-urlencoded"} + + data = { + "grant_type": "password", + "client_id": CONFIG_TDF.OPENTDF_CLIENT_ID, + "client_secret": CONFIG_TDF.OPENTDF_CLIENT_SECRET, + "username": pe_username, + "password": pe_password, + } + + with httpx.Client(verify=False) as client: + response = client.post(token_endpoint, headers=headers, data=data) + response.raise_for_status() + token_data = response.json() + return token_data.get("access_token") diff --git a/tests/integration/test_pe_interaction.py b/tests/integration/test_pe_interaction.py new file mode 100644 index 0000000..b748246 --- /dev/null +++ b/tests/integration/test_pe_interaction.py @@ -0,0 +1,104 @@ +""" +Integration test: Single attribute encryption/decryption using SDK and otdfctl +""" + +import logging +import os +import tempfile +from pathlib import Path +import pytest + +from otdf_python.sdk import SDK +from tests.config_pydantic import CONFIG_TDF +from otdf_python.sdk_exceptions import SDKException +from tests.integration.support_sdk import get_sdk_for_pe + +# Test files (adjust paths as needed) +DECRYPTED_FILE_OTDFCTL = "decrypted_otdfctl.txt" + +_test_attributes = [CONFIG_TDF.TEST_OPENTDF_ATTRIBUTE_1] +logger = logging.getLogger(__name__) + + +def decrypt(input_path: Path, output_path: Path, sdk: SDK): + # # Determine output + # with open(output_path, "wb") as output_file: + with open(input_path, "rb") as infile, open(output_path, "wb") as outfile: + try: + logger.debug("Decrypting TDF") + # tdf_reader = sdk.load_tdf(infile.read(), reader_config) + tdf_reader = sdk.load_tdf_without_config(infile.read()) + # Access payload directly from TDFReader + payload_bytes = tdf_reader.payload + outfile.write(payload_bytes) + logger.info("Successfully decrypted TDF") + + except Exception as e: + logger.error(f"Decryption failed: {e}") + # Clean up the output file if there was an error + output_path.unlink(missing_ok=True) + raise SDKException("Decryption failed") + + +@pytest.mark.skip(reason="Skipping until PE environment issues are resolved") +@pytest.mark.integration +def test_single_attribute_encryption_decryption(): + # Encrypt with SDK using a single attribute + sdk = get_sdk_for_pe() + + with tempfile.TemporaryDirectory() as tmpDir: + print("Created temporary directory", tmpDir) + some_plaintext_file = Path(tmpDir) / "new-file.txt" + some_plaintext_file.write_text("Hello world") + + INPUT_FILE = some_plaintext_file + + config = sdk.new_tdf_config( + attributes=_test_attributes, + ) + + input_path = Path(INPUT_FILE) + + output_path = input_path.with_suffix(input_path.suffix + ".tdf") + with open(input_path, "rb") as infile, open(output_path, "wb") as outfile: + sdk.create_tdf(infile.read(), config, output_stream=outfile) + + TDF_FILE = output_path + + assert TDF_FILE.exists() + + # Decrypt with SDK + DECRYPTED_FILE_SDK = Path(tmpDir) / "decrypted.txt" + DECRYPTED_FILE_SDK.touch() # Ensure the file exists + + decrypt(TDF_FILE, DECRYPTED_FILE_SDK, sdk) + with open(INPUT_FILE, "rb") as f1, open(DECRYPTED_FILE_SDK, "rb") as f2: + assert f1.read() == f2.read(), "SDK decrypted output does not match input" + + # # Decrypt with otdfctl + # otdfctl_cmd = [ + # "otdfctl", + # "decrypt", + # "--kas-url", + # kas_info["url"], + # "--kas-public-key", + # kas_info["public_key"], + # "--kas-token", + # kas_info["token"], + # "--attribute", + # _test_attributes, + # "-i", + # TDF_FILE, + # "-o", + # DECRYPTED_FILE_OTDFCTL, + # ] + # subprocess.run(otdfctl_cmd, check=True) + # with open(INPUT_FILE, "rb") as f1, open(DECRYPTED_FILE_OTDFCTL, "rb") as f2: + # assert f1.read() == f2.read(), ( + # "otdfctl decrypted output does not match input" + # ) + + # Clean up + for f in [TDF_FILE, DECRYPTED_FILE_SDK, DECRYPTED_FILE_OTDFCTL]: + if os.path.exists(f): + os.remove(f) diff --git a/tests/test_sdk_mock.py b/tests/test_sdk_mock.py index 61ad98c..ec507f8 100644 --- a/tests/test_sdk_mock.py +++ b/tests/test_sdk_mock.py @@ -1,5 +1,6 @@ from otdf_python.sdk import ( SDK, + KAS, AttributesServiceClientInterface, NamespaceServiceClientInterface, SubjectMappingServiceClientInterface, @@ -9,7 +10,7 @@ ) -class MockKAS(SDK.KAS): +class MockKAS(KAS): def get_public_key(self, kas_info): return "mock-public-key" diff --git a/tests/test_validate_otdf_python.py b/tests/test_validate_otdf_python.py index fad6e4e..ee702e7 100644 --- a/tests/test_validate_otdf_python.py +++ b/tests/test_validate_otdf_python.py @@ -10,55 +10,19 @@ import tempfile import logging from pathlib import Path -from otdf_python.sdk_builder import SDKBuilder -from otdf_python.sdk import SDK from otdf_python.tdf import TDFReaderConfig -from tests.config_pydantic import CONFIG_TDF import pytest +from tests.integration.support_sdk import get_sdk + # Set up detailed logging logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s:%(message)s") _test_attributes = [] -def _get_sdk() -> SDK: - if CONFIG_TDF.OPENTDF_PLATFORM_URL.startswith("http://"): - sdk = ( - SDKBuilder() - .set_platform_endpoint(CONFIG_TDF.OPENTDF_PLATFORM_URL) - .set_issuer_endpoint(CONFIG_TDF.OPENTDF_KEYCLOAK_HOST) - .client_secret( - CONFIG_TDF.OPENTDF_CLIENT_ID, - CONFIG_TDF.OPENTDF_CLIENT_SECRET, - ) - .use_insecure_plaintext_connection(True) - .use_insecure_skip_verify(CONFIG_TDF.INSECURE_SKIP_VERIFY) - .build() - ) - elif CONFIG_TDF.OPENTDF_PLATFORM_URL.startswith("https://"): - sdk = ( - SDKBuilder() - .set_platform_endpoint(CONFIG_TDF.OPENTDF_PLATFORM_URL) - .set_issuer_endpoint(CONFIG_TDF.OPENTDF_KEYCLOAK_HOST) - .client_secret( - CONFIG_TDF.OPENTDF_CLIENT_ID, - CONFIG_TDF.OPENTDF_CLIENT_SECRET, - ) - .use_insecure_skip_verify(CONFIG_TDF.INSECURE_SKIP_VERIFY) - .build() - ) - else: - raise ValueError( - f"Invalid platform URL: {CONFIG_TDF.OPENTDF_PLATFORM_URL}. " - "It must start with 'http://' or 'https://'." - ) - - return sdk - - def _get_sdk_and_tdf_config() -> tuple: - sdk = _get_sdk() + sdk = get_sdk() # Let the SDK create the default KAS info from the platform URL # This will automatically append /kas to the platform URL @@ -83,7 +47,7 @@ def encrypt_file(input_path: Path) -> Path: def decrypt_file(encrypted_path: Path) -> Path: """Decrypt a file and return the path to the decrypted file.""" - sdk = _get_sdk() + sdk = get_sdk() output_path = encrypted_path.with_suffix(".decrypted") with open(encrypted_path, "rb") as infile, open(output_path, "wb") as outfile: @@ -102,7 +66,7 @@ def decrypt_file(encrypted_path: Path) -> Path: def verify_encrypt_str() -> None: print("Validating string encryption (local TDF)") try: - sdk = _get_sdk() + sdk = get_sdk() payload = b"Hello from Python"