diff --git a/deps/key-value-storage/src/lib.rs b/deps/key-value-storage/src/lib.rs index 2fea49240f..ffde909bfe 100644 --- a/deps/key-value-storage/src/lib.rs +++ b/deps/key-value-storage/src/lib.rs @@ -27,10 +27,15 @@ pub struct SetParameters { pub overwrite: bool, } +pub enum SetResult { + Inserted, + AlreadyExists, +} + #[async_trait] pub trait KeyValueStorage: Send + Sync { /// Set a value for a key. - async fn set(&self, key: &str, value: &[u8], parameters: SetParameters) -> Result<()>; + async fn set(&self, key: &str, value: &[u8], parameters: SetParameters) -> Result; /// List all keys. async fn list(&self) -> Result>; diff --git a/deps/key-value-storage/src/local_fs/mod.rs b/deps/key-value-storage/src/local_fs/mod.rs index 950efe8612..d177ec9805 100644 --- a/deps/key-value-storage/src/local_fs/mod.rs +++ b/deps/key-value-storage/src/local_fs/mod.rs @@ -13,7 +13,9 @@ use serde::Deserialize; use tokio::sync::RwLock; use tracing::instrument; -use crate::{is_valid_key, KeyValueStorage, KeyValueStorageError, Result, SetParameters}; +use crate::{ + is_valid_key, KeyValueStorage, KeyValueStorageError, Result, SetParameters, SetResult, +}; /// Default file path for the local JSON file. const FILE_PATH: &str = "/opt/confidential-containers/storage/local_fs"; @@ -54,7 +56,7 @@ impl LocalFs { #[async_trait] impl KeyValueStorage for LocalFs { #[instrument(skip_all, name = "LocalFs::set", fields(key = key))] - async fn set(&self, key: &str, value: &[u8], parameters: SetParameters) -> Result<()> { + async fn set(&self, key: &str, value: &[u8], parameters: SetParameters) -> Result { if !is_valid_key(key) { return Err(KeyValueStorageError::SetKeyFailed { source: anyhow::anyhow!("key contains invalid characters"), @@ -75,13 +77,10 @@ impl KeyValueStorage for LocalFs { } if !parameters.overwrite && file_path.exists() { - return Err(KeyValueStorageError::SetKeyFailed { - source: anyhow::anyhow!("key already exists"), - key: key.to_string(), - }); + return Ok(SetResult::AlreadyExists); } - Ok(()) + Ok(SetResult::Inserted) } #[instrument(skip_all, name = "LocalFs::get", fields(key = key))] diff --git a/deps/key-value-storage/src/local_json/mod.rs b/deps/key-value-storage/src/local_json/mod.rs index be46b5bd8d..339edf0866 100644 --- a/deps/key-value-storage/src/local_json/mod.rs +++ b/deps/key-value-storage/src/local_json/mod.rs @@ -14,7 +14,7 @@ use serde::Deserialize; use tokio::sync::RwLock; use tracing::{debug, instrument}; -use crate::{KeyValueStorage, KeyValueStorageError, Result, SetParameters}; +use crate::{KeyValueStorage, KeyValueStorageError, Result, SetParameters, SetResult}; /// Default file path for the local JSON file. const FILE_PATH: &str = "/opt/confidential-containers/storage/local_json/key_value.json"; @@ -70,7 +70,7 @@ impl LocalJson { #[async_trait] impl KeyValueStorage for LocalJson { #[instrument(skip_all, name = "LocalJson::set", fields(key = key))] - async fn set(&self, key: &str, value: &[u8], parameters: SetParameters) -> Result<()> { + async fn set(&self, key: &str, value: &[u8], parameters: SetParameters) -> Result { let _ = self.lock.write().await; let file = tokio::fs::read(&self.file_path).await.map_err(|e| { KeyValueStorageError::GetKeyFailed { @@ -82,10 +82,7 @@ impl KeyValueStorage for LocalJson { .map_err(|e| KeyValueStorageError::MalformedValue { source: e.into() })?; let value_b64 = URL_SAFE.encode(value); if parameters.overwrite && items.contains_key(key) { - return Err(KeyValueStorageError::SetKeyFailed { - source: anyhow::anyhow!("key already exists"), - key: key.to_string(), - }); + return Ok(SetResult::AlreadyExists); } items.insert(key.to_string(), value_b64); @@ -101,7 +98,7 @@ impl KeyValueStorage for LocalJson { source: e.into(), key: key.to_string(), })?; - Ok(()) + Ok(SetResult::Inserted) } #[instrument(skip_all, name = "LocalJson::get", fields(key = key))] diff --git a/deps/key-value-storage/src/memory.rs b/deps/key-value-storage/src/memory.rs index 945cb4c77c..6af276a901 100644 --- a/deps/key-value-storage/src/memory.rs +++ b/deps/key-value-storage/src/memory.rs @@ -7,7 +7,7 @@ use async_trait::async_trait; use tokio::sync::RwLock; -use crate::{KeyValueStorage, KeyValueStorageError, Result, SetParameters}; +use crate::{KeyValueStorage, Result, SetParameters, SetResult}; use std::collections::HashMap; use tracing::instrument; @@ -19,7 +19,7 @@ pub struct MemoryKeyValueStorage { #[async_trait] impl KeyValueStorage for MemoryKeyValueStorage { #[instrument(skip_all, name = "MemoryKeyValueStorage::set", fields(key = key))] - async fn set(&self, key: &str, value: &[u8], parameters: SetParameters) -> Result<()> { + async fn set(&self, key: &str, value: &[u8], parameters: SetParameters) -> Result { if parameters.overwrite { self.items .write() @@ -27,17 +27,14 @@ impl KeyValueStorage for MemoryKeyValueStorage { .insert(key.to_string(), value.to_vec()); } else { if self.items.read().await.contains_key(key) { - return Err(KeyValueStorageError::SetKeyFailed { - source: anyhow::anyhow!("key already exists"), - key: key.to_string(), - }); + return Ok(SetResult::AlreadyExists); } self.items .write() .await .insert(key.to_string(), value.to_vec()); } - Ok(()) + Ok(SetResult::Inserted) } #[instrument(skip_all, name = "MemoryKeyValueStorage::list")] diff --git a/deps/key-value-storage/src/postgres/mod.rs b/deps/key-value-storage/src/postgres/mod.rs index 8fd736803f..d0c81fe5f2 100644 --- a/deps/key-value-storage/src/postgres/mod.rs +++ b/deps/key-value-storage/src/postgres/mod.rs @@ -16,7 +16,9 @@ use sqlx::PgPool; use sqlx::{postgres::PgPoolOptions, query, Row}; use tracing::{debug, info, instrument}; -use crate::{is_valid_key, KeyValueStorage, KeyValueStorageError, Result, SetParameters}; +use crate::{ + is_valid_key, KeyValueStorage, KeyValueStorageError, Result, SetParameters, SetResult, +}; /// The maximum number of connections to the PostgreSQL database. pub const MAX_CONNECTIONS: u32 = 5; @@ -117,7 +119,7 @@ pub struct PolicyItem { #[async_trait] impl KeyValueStorage for PostgresClient { #[instrument(skip_all, name = "PostgresClient::set")] - async fn set(&self, key: &str, value: &[u8], parameters: SetParameters) -> Result<()> { + async fn set(&self, key: &str, value: &[u8], parameters: SetParameters) -> Result { if !is_valid_key(key) { return Err(KeyValueStorageError::SetKeyFailed { source: anyhow::anyhow!("key contains invalid characters"), @@ -154,14 +156,11 @@ impl KeyValueStorage for PostgresClient { key: key.to_string(), })?; if result.is_none() { - return Err(KeyValueStorageError::SetKeyFailed { - source: anyhow::anyhow!("key already exists"), - key: key.to_string(), - }); + return Ok(SetResult::AlreadyExists); } } - Ok(()) + Ok(SetResult::Inserted) } #[instrument(skip_all, name = "PostgresClient::list")] diff --git a/deps/policy-engine/src/error.rs b/deps/policy-engine/src/error.rs index 81624a9414..abe07c65bf 100644 --- a/deps/policy-engine/src/error.rs +++ b/deps/policy-engine/src/error.rs @@ -35,7 +35,7 @@ pub enum PolicyError { source: std::string::FromUtf8Error, }, - // Opa Related Errors + // Regorus Related Errors #[error("Failed to load policy: {0}")] LoadPolicyFailed(#[source] anyhow::Error), @@ -50,4 +50,12 @@ pub enum PolicyError { #[error("Failed to eval policy: {0}")] EvalPolicyFailed(#[source] anyhow::Error), + + #[error("Failed to add regorus extension: {name} with id {id}: {source}")] + AddRegorusExtensionFailed { + name: String, + id: u8, + #[source] + source: anyhow::Error, + }, } diff --git a/deps/policy-engine/src/lib.rs b/deps/policy-engine/src/lib.rs index b4c208fb32..ca50ab881c 100644 --- a/deps/policy-engine/src/lib.rs +++ b/deps/policy-engine/src/lib.rs @@ -19,56 +19,57 @@ pub use key_value_storage::{KeyValueStorage, KeyValueStorageConfig}; pub struct PolicyEngineConfig { /// The storage to store the policies. pub storage: KeyValueStorageConfig, +} - /// The type of policy engine to use. - /// Currently, only Rego is supported. - pub policy_type: PolicyType, +pub trait EngineTrait { + /// The suffix of the policy file. + /// Concrete policy engine backend may handle the policy in different ways. + /// For example, the policy engine may store the policy in a different format. + /// In this case, the policy engine may need to add a suffix to the policy id to distinguish the policy. + /// This is also for compatibility with the existing policy setting and getting + /// APIs. Concretely, users do not need to specify the `.rego` suffix. + fn policy_suffix() -> &'static str { + "" + } } #[derive(Clone)] -pub struct PolicyEngine { +pub struct PolicyEngine { pub storage: Arc, - pub engine: Arc, + pub engine: T, } -impl PolicyEngine { - pub async fn new(config: PolicyEngineConfig) -> Result { - let storage = config.storage.to_key_value_storage().await?; - let engine = config.policy_type.to_engine(); - Ok(Self { storage, engine }) - } - - pub async fn evaluate( - &self, - data: &str, - input: &str, - policy_id: &str, - ) -> Result { - let policy = self.get_policy(policy_id).await?; - self.engine.evaluate(data, input, &policy).await - } - +impl PolicyEngine { /// Set a policy to the backend. /// The policy is expected to be provided as string. /// Concrete policy engine backend may handle the policy in different ways. pub async fn set_policy(&self, policy_id: &str, policy: &str, overwrite: bool) -> Result<()> { let params = SetParameters { overwrite }; - self.storage - .set(policy_id, policy.as_bytes(), params) + let policy_id = format!("{}{}", policy_id, T::policy_suffix()); + let _ = self + .storage + .set(&policy_id, policy.as_bytes(), params) .await - .map_err(From::from) + .map_err(PolicyError::from)?; + Ok(()) } /// List all policies in the backend. pub async fn list_policies(&self) -> Result> { - self.storage.list().await.map_err(From::from) + let policies = self.storage.list().await?; + let policies = policies + .into_iter() + .map(|policy| policy.strip_suffix(T::policy_suffix()).map(|p|p.to_string()).ok_or(PolicyError::MalformedPolicy(anyhow::anyhow!("There is at least one policy in the storage with invalid name. The policy name should contain the policy suffix {}.", T::policy_suffix())))) + .collect::>>()?; + Ok(policies) } /// Get a policy from the backend. /// The policy is expected to be provided as string. /// Concrete policy engine backend may handle the policy in different ways. pub async fn get_policy(&self, policy_id: &str) -> Result { - let policy_str = self.storage.get(policy_id).await?; + let policy_id = format!("{}{}", policy_id, T::policy_suffix()); + let policy_str = self.storage.get(&policy_id).await?; match policy_str { Some(policy_str) => { diff --git a/deps/policy-engine/src/policy/mod.rs b/deps/policy-engine/src/policy/mod.rs index 2fba169225..5265ebe9ab 100644 --- a/deps/policy-engine/src/policy/mod.rs +++ b/deps/policy-engine/src/policy/mod.rs @@ -2,43 +2,22 @@ // Licensed under the Apache License, Version 2.0, see LICENSE for details. // SPDX-License-Identifier: Apache-2.0 -use std::sync::Arc; +use std::collections::HashMap; -use crate::Result; -use async_trait::async_trait; use serde::Deserialize; use serde_json::Value; -use strum::EnumString; pub mod rego; -#[async_trait] -pub trait Engine: Send + Sync { - /// The inputs to an policy engine. Inspired by OPA, we divided the inputs - /// into three parts: - /// - `data`: static data that will help to enforce the policy. - /// - `input`: dynamic data that will help to enforce the policy. - /// - `policy`: the policy to be enforced. - async fn evaluate(&self, data: &str, input: &str, policy: &str) -> Result; -} - -#[derive(Debug, EnumString, Deserialize, Clone, Default, PartialEq)] -#[strum(ascii_case_insensitive)] -pub enum PolicyType { - #[default] - Rego, -} - -#[derive(Debug)] +#[derive(Debug, Clone, Deserialize, Default, PartialEq)] +#[serde(default)] pub struct EvaluationResult { - pub rules_result: Value, + pub eval_rules_result: HashMap>, pub policy_hash: String, } -impl PolicyType { - pub fn to_engine(&self) -> Arc { - match self { - PolicyType::Rego => Arc::new(crate::policy::rego::Regorus::default()), - } - } +#[derive(Debug, Clone, Deserialize, Default, PartialEq)] +pub enum PolicyLanguage { + #[default] + Rego, } diff --git a/deps/policy-engine/src/policy/rego.rs b/deps/policy-engine/src/policy/rego.rs index 15dd8d19f9..e8e8172f90 100644 --- a/deps/policy-engine/src/policy/rego.rs +++ b/deps/policy-engine/src/policy/rego.rs @@ -2,28 +2,48 @@ // Licensed under the Apache License, Version 2.0, see LICENSE for details. // SPDX-License-Identifier: Apache-2.0 -use async_trait::async_trait; +use std::collections::HashMap; + +use regorus::Extension; use serde_json::Value; -use tracing::instrument; +use tracing::{info, instrument}; -use crate::{Engine, EvaluationResult, PolicyError}; +use crate::{EngineTrait, EvaluationResult, PolicyEngine, PolicyEngineConfig, PolicyError, Result}; /// The rule to evaluate the policy. /// Note that only the result of this rule will be returned. pub const EVAL_RULE: &str = "data.policy.result"; +pub struct RegorusExtension { + pub name: String, + pub id: u8, + pub extension: Box, +} + #[derive(Debug, Clone, Default)] pub struct Regorus {} -#[async_trait] -impl Engine for Regorus { +impl EngineTrait for Regorus { + fn policy_suffix() -> &'static str { + ".rego" + } +} + +impl Regorus { + /// The inputs to an policy engine. Inspired by OPA, we divided the inputs + /// into three parts: + /// - `data`: static data that will help to enforce the policy. + /// - `input`: dynamic data that will help to enforce the policy. + /// - `policy`: the policy to be enforced. #[instrument(skip_all, name = "Regorus")] - async fn evaluate( + pub async fn evaluate( &self, - data: &str, + data: Option<&str>, input: &str, policy: &str, - ) -> Result { + eval_rules: Vec<&str>, + extensions: Vec, + ) -> Result { let mut engine = regorus::Engine::new(); let policy_hash = { @@ -42,28 +62,53 @@ impl Engine for Regorus { .add_policy("".to_string(), policy.to_string()) .map_err(PolicyError::LoadPolicyFailed)?; - let data = - regorus::Value::from_json_str(data).map_err(PolicyError::JsonSerializationFailed)?; + if let Some(data) = data { + let data = regorus::Value::from_json_str(data) + .map_err(PolicyError::JsonSerializationFailed)?; - engine - .add_data(data) - .map_err(PolicyError::LoadReferenceDataFailed)?; + engine + .add_data(data) + .map_err(PolicyError::LoadReferenceDataFailed)?; + } engine .set_input_json(input) .map_err(PolicyError::SetInputDataFailed)?; - let claim_value = engine - .eval_rule(EVAL_RULE.to_string()) - .map_err(PolicyError::EvalPolicyFailed)?; - - let claim_value = claim_value - .to_json_str() - .map_err(PolicyError::JsonSerializationFailed)?; - let rules_result = serde_json::from_str::(&claim_value)?; + for extension in extensions { + engine + .add_extension(extension.name.clone(), extension.id, extension.extension) + .map_err(|e| PolicyError::AddRegorusExtensionFailed { + name: extension.name, + id: extension.id, + source: e, + })?; + } + + let eval_rules_result = eval_rules + .iter() + .map(|rule| { + let value = match engine.eval_rule(rule.to_string()) { + Ok(r) => Some(r), + // Extensions claim is optional. + Err(e) if e.to_string().contains("not a valid rule path") => { + info!("No claim {rule} found in policy."); + None + } + Err(e) => return Err(PolicyError::EvalPolicyFailed(e)), + }; + if let Some(value) = value { + let value = serde_json::to_value(value) + .map_err(|e| PolicyError::JsonSerializationFailed(e.into()))?; + Ok((rule.to_string(), Some(value))) + } else { + Ok((rule.to_string(), None)) + } + }) + .collect::>>>()?; let res = EvaluationResult { - rules_result, + eval_rules_result, policy_hash, }; @@ -71,6 +116,28 @@ impl Engine for Regorus { } } +impl PolicyEngine { + pub async fn new(config: PolicyEngineConfig) -> Result { + let storage = config.storage.to_key_value_storage().await?; + let engine = Regorus::default(); + Ok(Self { storage, engine }) + } + + pub async fn evaluate_rego( + &self, + data: Option<&str>, + input: &str, + policy_id: &str, + eval_rules: Vec<&str>, + extensions: Vec, + ) -> Result { + let policy = self.get_policy(policy_id).await?; + self.engine + .evaluate(data, input, &policy, eval_rules, extensions) + .await + } +} + #[cfg(test)] mod tests { use rstest::rstest; @@ -94,7 +161,7 @@ mod tests { #[case] policy_path: &str, #[case] expected: bool, ) { - use crate::{rego::Regorus, Engine}; + use crate::rego::Regorus; let input = format!( r#" @@ -126,7 +193,26 @@ mod tests { ); let policy = fs::read_to_string(policy_path).unwrap(); let engine = Regorus::default(); - let result = engine.evaluate(&data, &input, &policy).await.unwrap(); - assert_eq!(result.rules_result.as_bool().unwrap(), expected); + let result = engine + .evaluate( + Some(&data), + &input, + &policy, + vec!["data.policy.result"], + vec![], + ) + .await + .unwrap(); + assert_eq!( + result + .eval_rules_result + .get("data.policy.result") + .unwrap() + .as_ref() + .unwrap() + .as_bool() + .unwrap(), + expected + ); } }