diff --git a/src/oso/framework/plugin/addons/signing_server/__init__.py b/src/oso/framework/plugin/addons/signing_server/__init__.py
index 218c3ca..8994050 100644
--- a/src/oso/framework/plugin/addons/signing_server/__init__.py
+++ b/src/oso/framework/plugin/addons/signing_server/__init__.py
@@ -17,25 +17,27 @@
 from __future__ import annotations
+
 import uuid
 import logging
-import pathlib
 import base64
+import sqlite3
+import pathlib
 
-from typing import TYPE_CHECKING
-from typing import Callable
+from pathlib import Path
+from typing import TYPE_CHECKING, Callable
 
 from pydantic import field_validator
 
-from ..main import AddonProtocol, BaseAddonConfig
-
 from ._key import KeyPair, KeyType
 from ._grep11_client import Grep11Client
 
+from ..main import AddonProtocol, BaseAddonConfig
+
 from oso.framework.data.types import V1_3
 from oso.framework.core.logging import get_logger
 
 if TYPE_CHECKING:
-    from typing import Any, Callable, ClassVar, Literal
+    from typing import Any, ClassVar, Literal
 
 NAME: Literal["SigningServer"] = "SigningServer"
@@ -43,7 +45,6 @@
 
 def configure(
     framework_config: Any, addon_config: SigningServerConfig
 ) -> SigningServerAddon:
-    """Return the addon instance."""
     return SigningServerAddon(framework_config, addon_config)
 
@@ -69,7 +70,8 @@ class SigningServerConfig(BaseAddonConfig):
     client_cert: str
     client_key: str
    grep11_endpoint: str = "localhost"
-    keystore_path: str
+    keystore_path: str  # SQLite DB file
+    legacy_keystore_dir: str | None = None  # Old filesystem store
 
     @field_validator("ca_cert", "client_cert", "client_key", mode="before")
     def _decode_base64_fields(cls, v: str) -> str:
@@ -97,14 +99,112 @@ class SigningServerAddon(AddonProtocol):
 
     def __init__(self, framework_config: Any, addon_config: SigningServerConfig):
         self._config = addon_config
         self._logger = get_logger(name="signing_server")
-        self._keystore = pathlib.Path(self._config.keystore_path)
+        db_path = Path(self._config.keystore_path)
+
+        if db_path.is_dir():
+            db_file = db_path / "keystore.db"
+
+        else:
+            db_file = db_path
+
+        db_file.parent.mkdir(parents=True, exist_ok=True)
+
+        self._conn = sqlite3.connect(str(db_file))
+        with self._conn:
+            self._conn.execute("""
+                CREATE TABLE IF NOT EXISTS keys (
+                    id TEXT PRIMARY KEY,
+                    key_type TEXT NOT NULL,
+                    private_key TEXT NOT NULL,
+                    public_key TEXT NOT NULL
+                )
+            """)
+
+        if self._config.legacy_keystore_dir:
+            self._migrate_and_cleanup_legacy(self._config.legacy_keystore_dir)
 
         self._grep11_client = Grep11Client(self._config)
         self._grep11_client.health_check()
 
+    def _migrate_and_cleanup_legacy(self, legacy_dir: str):
+        legacy_path = pathlib.Path(legacy_dir)
+
+        if not legacy_path.exists():
+            self._logger.debug(f"No legacy keystore found at {legacy_dir}")
+            return
+
+        # Collect key list
+
+        to_migrate: list[tuple[pathlib.Path, pathlib.Path, KeyType]] = []
+
+        for key_type_dir in legacy_path.iterdir():
+            if not key_type_dir.is_dir():
+                continue
+
+            key_type = self._get_key_type(key_type_dir.name)
+
+            if key_type is None:
+                continue
+
+            for priv_file in key_type_dir.glob("*.key"):
+                pub_file = priv_file.with_suffix(".pub")
+
+                if pub_file.exists():
+                    to_migrate.append((priv_file, pub_file, key_type))
+
+        # Migrate keys to DB
+
+        migrated_files: list[tuple[pathlib.Path, pathlib.Path]] = []
+
+        try:
+            with self._conn:
+                for priv_file, pub_file, key_type in to_migrate:
+                    key_id = priv_file.stem
+
+                    # Insert if not already migrated
+
+                    if not self._conn.execute(
+                        "SELECT 1 FROM keys WHERE id = ?", (key_id,)
+                    ).fetchone():
+                        priv_bytes = priv_file.read_bytes()
+                        pub_bytes = pub_file.read_bytes()
+
+                        self._conn.execute(
+                            "INSERT INTO keys (id, key_type, private_key, public_key) VALUES (?, ?, ?, ?)",
+                            (key_id, key_type.name, priv_bytes.hex(), pub_bytes.hex()),
+                        )
+
+                        migrated_files.append((priv_file, pub_file))
+
+            if migrated_files:
+                self._logger.info(f"Migrated {len(migrated_files)} key(s) to SQLite")
+
+            else:
+                self._logger.info("No keys migrated from filesystem")
+
+        except Exception as e:
+            self._logger.error(f"Migration failed: {e}")
+            return
+
+        # Delete keys from filesystem
+
+        deleted = 0
+
+        for priv_file, pub_file in migrated_files:
+            try:
+                priv_file.unlink()
+                pub_file.unlink()
+                deleted += 1
+
+            except Exception as e:
+                self._logger.debug(f"Failed to delete {priv_file}: {e}")
+
+        if deleted:
+            self._logger.info(f"Deleted {deleted} legacy key pair(s)")
+
     def generate_key_pair(self, key_type: KeyType) -> tuple[str, str]:
         """Generate a new key pair.
 
@@ -115,12 +214,13 @@ def generate_key_pair(self, key_type: KeyType) -> tuple[str, str]:
 
         Returns
         -------
-        tuple[str, bytes]
+        tuple[str, str]
             - key_id : str
                 The unique identifier for the generated key.
            - pub_key_pem : str
                The public key in PEM format.
         """
+        self._logger.info(f"Generating new key pair of type {key_type.name}")
 
         key_pair = self._grep11_client.generate_key_pair(key_type=key_type)
 
@@ -149,31 +249,12 @@ def list_keys(self, key_type: KeyType) -> list[str]:
         list[str]
             List of key ids of the given key type.
         """
-        key_id_list = []
-        key_type_dir = self._keystore / key_type.name
-
-        if key_type_dir.exists():
-            if not key_type_dir.is_dir():
-                raise Exception(
-                    f"{key_type_dir} is an existing file, it should be a directory"
-                )
-
-            for key_file in key_type_dir.glob("*.key"):
-                key_id = key_file.stem
-
-                if key_file.with_suffix(".pub").exists():
-                    key_id_list.append(key_id)
-
-                else:
-                    self._logger.info(
-                        f"Corresponding public key does not exist for {key_file}"
-                    )
-
-        else:
-            self._logger.debug(f"'{key_type.name}' dir does not exist in the key store")
+        cur = self._conn.execute(
+            "SELECT id FROM keys WHERE key_type = ?", (key_type.name,)
+        )
 
-        return key_id_list
+        return [row[0] for row in cur.fetchall()]
 
     def get_key_pem(self, key_id: str) -> str | None:
         """Get the public key PEM for a given key ID.
 
@@ -189,7 +270,8 @@
             The PEM-encoded public key str if the key is found and conversion
             succeeds, otherwise None.
         """
-        keys = self._find_keys(key_id=key_id)
+
+        keys = self._find_keys(key_id)
 
         if not keys:
             self._logger.info(f"Could not find key pair for key id: '{key_id}'")
@@ -197,12 +279,10 @@
 
         key_type, key_pair = keys
 
-        pub_key_pem = self._grep11_client.serialized_key_to_pem(
+        return self._grep11_client.serialized_key_to_pem(
             key_type=key_type, pub_key_bytes=key_pair.PublicKey
         )
 
-        return pub_key_pem
-
     def _find_keys(self, key_id: str) -> tuple[KeyType, KeyPair] | None:
         """Find private and public keys for the given key ID.
 
@@ -225,74 +305,52 @@ def _find_keys(self, key_id: str) -> tuple[KeyType, KeyPair] | None:
         FileNotFoundError
             If either the private or public key file exists but is not a valid file.
""" - for priv_key_file in self._keystore.glob("*/*.key"): - file_id = priv_key_file.stem - - if key_id == file_id: - pub_key_file = priv_key_file.with_suffix(".pub") - if not priv_key_file.is_file(): - raise FileNotFoundError( - f"Private key path '{priv_key_file}' is not a valid file" - ) + row = self._conn.execute( + "SELECT key_type, private_key, public_key FROM keys WHERE id = ?", (key_id,) + ).fetchone() - if not pub_key_file.is_file(): - raise FileNotFoundError( - f"Corresponding public key for '{priv_key_file}' does not exist" - ) - - key_type_name = priv_key_file.parent.name + if not row: + return None - key_type = self._get_key_type(key_type_name=key_type_name) + key_type_name, priv_hex, pub_hex = row - if key_type is None: - self._logger.info("Key ID does not match with known key type") - self._logger.debug(f"Key ID: {key_id}") - return None + key_type = self._get_key_type(key_type_name) - key_pair = KeyPair( - PrivateKey=priv_key_file.read_bytes(), - PublicKey=pub_key_file.read_bytes(), - ) + if key_type is None: + return None - return key_type, key_pair + key_pair = KeyPair( + PrivateKey=bytes.fromhex(priv_hex), PublicKey=bytes.fromhex(pub_hex) + ) - return None + return key_type, key_pair def _get_key_type(self, key_type_name: str) -> KeyType | None: - key_type = None - for kt in KeyType: if kt.name == key_type_name: - key_type = kt + return kt - return key_type + return None def _save_key_pair(self, key_type: KeyType, key_pair: KeyPair) -> str: - key_type_dir = self._keystore / key_type.name - - if key_type_dir.exists(): - if not key_type_dir.is_dir(): - raise NotADirectoryError( - f"{key_type_dir} exists but is not a directory." - ) - else: - key_type_dir.mkdir(parents=True, exist_ok=True) - key_id = str(uuid.uuid4()) - self._logger.info(f"Writing {key_type.name} key with key ID: '{key_id}'") - - priv_key_filename = key_type_dir / f"{key_id}.key" - priv_key_filename.write_bytes(key_pair.PrivateKey) - - pub_key_filename = key_type_dir / f"{key_id}.pub" - pub_key_filename.write_bytes(key_pair.PublicKey) - - self._logger.debug( - f"Wrote priv key to {priv_key_filename} and pub key to {pub_key_filename}" + self._logger.info( + f"Saving {key_type.name} key with key ID: '{key_id}' to SQLite" ) + with self._conn: + self._conn.execute( + "INSERT INTO keys (id, key_type, private_key, public_key) VALUES (?, ?, ?, ?)", + ( + key_id, + key_type.name, + key_pair.PrivateKey.hex(), + key_pair.PublicKey.hex(), + ), + ) + return key_id def sign(self, key_id: str, data: bytes) -> str: @@ -301,19 +359,19 @@ def sign(self, key_id: str, data: bytes) -> str: Parameters ---------- key_id : str - Key ID used to find stored key, prefixed with key type OID + Key ID used to find stored key, UUID4. data : bytes Data to be signed. Returns ------- str - Signature as a string. + Signature as a hex string. """ - keys = self._find_keys(key_id=key_id) + + keys = self._find_keys(key_id) if not keys: - self._logger.info(f"Could not find key pair for key id: '{key_id}'") raise Exception(f"Could not find key pair for key id: '{key_id}'") key_type, key_pair = keys @@ -322,6 +380,33 @@ def sign(self, key_id: str, data: bytes) -> str: key_type=key_type, priv_key_bytes=key_pair.PrivateKey, data=data ) + def count_keys(self, key_type: KeyType | None = None) -> int: + """ + Return the number of keys stored in the database. + + Parameters + ---------- + key_type : KeyType | None + If provided, count only keys of this type. Otherwise, count all keys. + + Returns + ------- + int + Number of keys. 
+        """
+
+        if key_type is not None:
+            cur = self._conn.execute(
+                "SELECT COUNT(*) FROM keys WHERE key_type = ?", (key_type.name,)
+            )
+
+        else:
+            cur = self._conn.execute("SELECT COUNT(*) FROM keys")
+
+        row = cur.fetchone()
+
+        return row[0] if row else 0
+
     def health_check(self) -> V1_3.ComponentStatus:
         """Check the GREP11 server health status.
 
@@ -330,4 +415,44 @@
         `oso.framework.data.types.ComponentStatus`
             OSO component status.
         """
+        return self._grep11_client.health_check()
+
+    def verify(self, key_id: str, data: bytes, signature: str) -> bool:
+        """
+        Verify a signature using the public key stored in the keystore.
+
+        Parameters
+        ----------
+        key_id : str
+            The ID (UUID4) of the key used to generate the signature.
+        data : bytes
+            The original data that was signed.
+        signature : str
+            The signature to verify as a hex string.
+
+        Returns
+        -------
+        bool
+            True if the signature is valid, False otherwise.
+        """
+
+        keys = self._find_keys(key_id)
+
+        if not keys:
+            self._logger.info(f"Could not find key pair for key id: '{key_id}'")
+            return False
+
+        key_type, key_pair = keys
+
+        try:
+            return self._grep11_client.verify(
+                key_type=key_type,
+                pub_key_bytes=key_pair.PublicKey,
+                data=data,
+                signature=signature,
+            )
+
+        except Exception as e:
+            self._logger.error(f"Signature verification failed for key '{key_id}': {e}")
+            return False
diff --git a/src/oso/framework/plugin/addons/signing_server/_grep11_client.py b/src/oso/framework/plugin/addons/signing_server/_grep11_client.py
index baf4ca8..0d42a48 100644
--- a/src/oso/framework/plugin/addons/signing_server/_grep11_client.py
+++ b/src/oso/framework/plugin/addons/signing_server/_grep11_client.py
@@ -23,6 +23,10 @@
 from cryptography.hazmat.primitives.asymmetric import ed25519
 from cryptography.hazmat.primitives import serialization
 
+from typing import TYPE_CHECKING
+
 from ._key import KeyPair, KeyType, SupportedMechanism
 from .generated import server_pb2, server_pb2_grpc
 
+if TYPE_CHECKING:
+    from . import SigningServerConfig
@@ -31,7 +32,7 @@
 
 class Grep11Client:
-    def __init__(self, signing_server_config) -> None:
+    def __init__(self, signing_server_config: "SigningServerConfig") -> None:
         super().__init__()
 
         self.logger = get_logger("grep11-client")
@@ -127,7 +128,7 @@ def health_check(self) -> V1_3.ComponentStatus:
         response = self.stub.GetMechanismList(request)
         assert isinstance(response, server_pb2.GetMechanismListResponse)
 
-        errors = []
+        errors: list[V1_3.Error] = []
 
         for mechanism in SupportedMechanism:
             if mechanism not in response.Mechs:
@@ -176,6 +177,56 @@ def sign(self, key_type: KeyType, priv_key_bytes: bytes, data: bytes) -> str:
 
         return signature
 
+    def verify(
+        self, key_type: KeyType, pub_key_bytes: bytes, data: bytes, signature: str
+    ) -> bool:
+        """
+        Verify a signature using the GREP11 server.
+
+        Parameters
+        ----------
+        key_type : KeyType
+            The type of key (should match the key used for signing).
+        pub_key_bytes : bytes
+            The public key in raw bytes.
+        data : bytes
+            The original data that was signed.
+        signature : str
+            Hex-encoded signature to verify.
+
+        Returns
+        -------
+        bool
+            True if the signature is valid, False otherwise.
+        """
+
+        self.logger.info("Performing signature verification")
+        self.logger.debug(
+            f"Verifying signature: '{signature}' for data: '{data.hex()}' with key type: '{key_type.name}'"
+        )
+
+        try:
+            pub_key_blob = server_pb2.KeyBlob(KeyBlobs=[pub_key_bytes])
+
+            verify_request = server_pb2.VerifySingleRequest(
+                Mech=server_pb2.Mechanism(Mechanism=key_type.value.Mechanism),
+                Data=data,
+                PubKey=pub_key_blob,
+                Signature=bytes.fromhex(signature),
+            )
+
+            verify_response = self.stub.VerifySingle(verify_request)
+            assert isinstance(verify_response, server_pb2.VerifySingleResponse)
+
+            self.logger.info("Completed verification")
+            self.logger.debug(f"Received VerifySingleResponse: {verify_response=}")
+
+            return True
+
+        except Exception as e:
+            self.logger.debug(f"Signature verification failed: {e}")
+            return False
+
     def serialized_key_to_pem(self, key_type: KeyType, pub_key_bytes: bytes) -> str:
         self.logger.info("Converting Public Key Blob to PEM")
         self.logger.debug(f"KeyType: '{key_type.name}'")
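
Reviewer note (not part of the patch): a rough usage sketch of the reworked SQLite-backed keystore API. This is illustrative only — the certificate values, endpoint, and paths are placeholders, it assumes SigningServerConfig can be constructed directly with just the fields shown, and a reachable GREP11 server is required for the addon to initialize.

    import base64

    from oso.framework.plugin.addons.signing_server import SigningServerConfig, configure
    from oso.framework.plugin.addons.signing_server._key import KeyType

    def b64(s: str) -> str:
        # Cert/key fields are expected base64-encoded; the config validator decodes them.
        return base64.b64encode(s.encode()).decode()

    config = SigningServerConfig(
        ca_cert=b64("<ca pem>"),
        client_cert=b64("<client cert pem>"),
        client_key=b64("<client key pem>"),
        grep11_endpoint="grep11.example.com:9876",    # placeholder endpoint
        keystore_path="/var/lib/oso/keystore.db",     # SQLite file (or a directory)
        legacy_keystore_dir="/var/lib/oso/keystore",  # optional one-time migration source
    )

    addon = configure(framework_config=None, addon_config=config)

    key_type = next(iter(KeyType))                    # any supported key type
    key_id, pub_pem = addon.generate_key_pair(key_type)

    signature = addon.sign(key_id, b"payload")        # hex string
    assert addon.verify(key_id, b"payload", signature)
    print(addon.count_keys(), "key(s) in the SQLite keystore")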