Skip to content

Commit 7b695a3

Browse files
authored
fix: Remove PostComputeError wrapper (#9)
1 parent d351e5d commit 7b695a3

File tree

5 files changed

+56
-62
lines changed

5 files changed

+56
-62
lines changed

src/api/worker_api.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ use serde::Serialize;
2929
/// ```
3030
#[derive(Serialize, Debug)]
3131
pub struct ExitMessage<'a> {
32-
#[serde(rename = "cause")]
3332
pub cause: &'a ReplicateStatusCause,
3433
}
3534

@@ -171,15 +170,15 @@ mod tests {
171170
let causes = [
172171
(
173172
ReplicateStatusCause::PostComputeInvalidTeeSignature,
174-
"PostComputeInvalidTeeSignature",
173+
"POST_COMPUTE_INVALID_TEE_SIGNATURE",
175174
),
176175
(
177176
ReplicateStatusCause::PostComputeWorkerAddressMissing,
178-
"PostComputeWorkerAddressMissing",
177+
"POST_COMPUTE_WORKER_ADDRESS_MISSING",
179178
),
180179
(
181180
ReplicateStatusCause::PostComputeFailedUnknownIssue,
182-
"PostComputeFailedUnknownIssue",
181+
"POST_COMPUTE_FAILED_UNKNOWN_ISSUE",
183182
),
184183
];
185184

src/compute/app_runner.rs

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::api::worker_api::{ExitMessage, WorkerApiClient};
22
use crate::compute::{
3-
errors::{PostComputeError, ReplicateStatusCause},
3+
errors::ReplicateStatusCause,
44
signer::get_challenge,
55
utils::env_utils::{TeeSessionEnvironmentVariable, get_env_var_or_error},
66
};
@@ -14,7 +14,7 @@ use std::error::Error;
1414
/// the post-compute workflow.
1515
pub trait PostComputeRunnerInterface {
1616
fn run_post_compute(&self, chain_task_id: &str) -> Result<(), Box<dyn Error>>;
17-
fn get_challenge(&self, chain_task_id: &str) -> Result<String, PostComputeError>;
17+
fn get_challenge(&self, chain_task_id: &str) -> Result<String, ReplicateStatusCause>;
1818
fn send_exit_cause(
1919
&self,
2020
authorization: &str,
@@ -41,11 +41,11 @@ impl DefaultPostComputeRunner {
4141
}
4242

4343
impl PostComputeRunnerInterface for DefaultPostComputeRunner {
44-
fn run_post_compute(&self, chain_task_id: &str) -> Result<(), Box<dyn Error>> {
44+
fn run_post_compute(&self, _chain_task_id: &str) -> Result<(), Box<dyn Error>> {
4545
Err("run_post_compute not implemented yet".into())
4646
}
4747

48-
fn get_challenge(&self, chain_task_id: &str) -> Result<String, PostComputeError> {
48+
fn get_challenge(&self, chain_task_id: &str) -> Result<String, ReplicateStatusCause> {
4949
get_challenge(chain_task_id)
5050
}
5151

@@ -95,7 +95,7 @@ pub fn start_with_runner<R: PostComputeRunnerInterface>(runner: &R) -> i32 {
9595
println!("Tee worker post-compute started");
9696
let chain_task_id: String = match get_env_var_or_error(
9797
TeeSessionEnvironmentVariable::IEXEC_TASK_ID,
98-
ReplicateStatusCause::PostComputeChainTaskIdMissing,
98+
ReplicateStatusCause::PostComputeTaskIdMissing,
9999
) {
100100
Ok(id) => id,
101101
Err(e) => {
@@ -114,9 +114,9 @@ pub fn start_with_runner<R: PostComputeRunnerInterface>(runner: &R) -> i32 {
114114
}
115115
Err(error) => {
116116
let exit_cause: &ReplicateStatusCause;
117-
match error.downcast_ref::<PostComputeError>() {
117+
match error.downcast_ref::<ReplicateStatusCause>() {
118118
Some(post_compute_error) => {
119-
exit_cause = post_compute_error.exit_cause();
119+
exit_cause = post_compute_error;
120120
error!(
121121
"TEE post-compute failed with exit cause [errorMessage:{}]",
122122
&exit_cause
@@ -224,19 +224,17 @@ mod tests {
224224
if self.run_post_compute_success {
225225
Ok(())
226226
} else if let Some(cause) = &self.error_cause {
227-
Err(Box::new(PostComputeError::new(cause.clone())))
227+
Err(Box::new(cause.clone()))
228228
} else {
229229
Err("Mock error".into())
230230
}
231231
}
232232

233-
fn get_challenge(&self, _chain_task_id: &str) -> Result<String, PostComputeError> {
233+
fn get_challenge(&self, _chain_task_id: &str) -> Result<String, ReplicateStatusCause> {
234234
if self.get_challenge_success {
235235
Ok("mock_challenge".to_string())
236236
} else {
237-
Err(PostComputeError::new(
238-
ReplicateStatusCause::PostComputeTeeChallengePrivateKeyMissing,
239-
))
237+
Err(ReplicateStatusCause::PostComputeTeeChallengePrivateKeyMissing)
240238
}
241239
}
242240

src/compute/errors.rs

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,19 @@ use serde::{Deserialize, Serialize};
22
use thiserror::Error;
33

44
#[derive(Debug, PartialEq, Clone, Error, Serialize, Deserialize)]
5+
#[serde(rename_all(serialize = "SCREAMING_SNAKE_CASE"))]
6+
#[allow(clippy::enum_variant_names)]
57
pub enum ReplicateStatusCause {
6-
#[error("Failed to verify TeeEnclaveChallenge signature (exiting)")]
7-
PostComputeInvalidTeeSignature,
8+
#[error("Task ID related environment variable is missing")]
9+
PostComputeTaskIdMissing,
10+
#[error("Unexpected error occurred")]
11+
PostComputeFailedUnknownIssue,
812
#[error("Invalid enclave challenge private key")]
913
PostComputeInvalidEnclaveChallengePrivateKey,
10-
#[error("Worker address related environment variable is missing")]
11-
PostComputeWorkerAddressMissing,
14+
#[error("Invalid TEE signature")]
15+
PostComputeInvalidTeeSignature,
1216
#[error("Tee challenge private key related environment variable is missing")]
1317
PostComputeTeeChallengePrivateKeyMissing,
14-
#[error("Chain task ID related environment variable is missing")]
15-
PostComputeChainTaskIdMissing,
16-
#[error("Unexpected error occured")]
17-
PostComputeFailedUnknownIssue,
18-
}
19-
20-
#[derive(Debug, Error, Clone)]
21-
#[error("PostCompute failed: {exit_cause}")]
22-
pub struct PostComputeError {
23-
pub exit_cause: ReplicateStatusCause,
24-
}
25-
26-
impl PostComputeError {
27-
pub fn new(cause: ReplicateStatusCause) -> Self {
28-
Self { exit_cause: cause }
29-
}
30-
31-
pub fn exit_cause(&self) -> &ReplicateStatusCause {
32-
&self.exit_cause
33-
}
18+
#[error("Worker address related environment variable is missing")]
19+
PostComputeWorkerAddressMissing,
3420
}

src/compute/signer.rs

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::compute::{
2-
errors::{PostComputeError, ReplicateStatusCause::*},
2+
errors::ReplicateStatusCause,
33
utils::{
44
env_utils::{TeeSessionEnvironmentVariable, get_env_var_or_error},
55
hash_utils::{concatenate_and_hash, hex_string_to_byte_array},
@@ -22,7 +22,7 @@ use alloy_signer_local::PrivateKeySigner;
2222
/// # Returns
2323
///
2424
/// * `Ok(String)` - The signature as a hexadecimal string if successful
25-
/// * `Err(PostComputeError)` - An error if the private key is invalid or if signing fails
25+
/// * `Err(ReplicateStatusCause)` - An error if the private key is invalid or if signing fails
2626
///
2727
/// # Errors
2828
///
@@ -44,14 +44,14 @@ use alloy_signer_local::PrivateKeySigner;
4444
pub fn sign_enclave_challenge(
4545
message_hash: &str,
4646
enclave_challenge_private_key: &str,
47-
) -> Result<String, PostComputeError> {
47+
) -> Result<String, ReplicateStatusCause> {
4848
let signer: PrivateKeySigner = enclave_challenge_private_key
4949
.parse::<PrivateKeySigner>()
50-
.map_err(|_| PostComputeError::new(PostComputeInvalidEnclaveChallengePrivateKey))?;
50+
.map_err(|_| ReplicateStatusCause::PostComputeInvalidEnclaveChallengePrivateKey)?;
5151

5252
let signature: Signature = signer
5353
.sign_message_sync(&hex_string_to_byte_array(message_hash))
54-
.map_err(|_| PostComputeError::new(PostComputeInvalidTeeSignature))?;
54+
.map_err(|_| ReplicateStatusCause::PostComputeInvalidTeeSignature)?;
5555

5656
Ok(signature.to_string())
5757
}
@@ -69,7 +69,7 @@ pub fn sign_enclave_challenge(
6969
/// # Returns
7070
///
7171
/// * `Ok(String)` - The challenge signature as a hexadecimal string if successful
72-
/// * `Err(PostComputeError)` - An error if required environment variables are missing or if signing fails
72+
/// * `Err(ReplicateStatusCause)` - An error if required environment variables are missing or if signing fails
7373
///
7474
/// # Errors
7575
///
@@ -97,14 +97,14 @@ pub fn sign_enclave_challenge(
9797
/// Err(e) => eprintln!("Error generating challenge: {:?}", e),
9898
/// }
9999
/// ```
100-
pub fn get_challenge(chain_task_id: &str) -> Result<String, PostComputeError> {
100+
pub fn get_challenge(chain_task_id: &str) -> Result<String, ReplicateStatusCause> {
101101
let worker_address: String = get_env_var_or_error(
102102
TeeSessionEnvironmentVariable::SIGN_WORKER_ADDRESS,
103-
PostComputeWorkerAddressMissing,
103+
ReplicateStatusCause::PostComputeWorkerAddressMissing,
104104
)?;
105105
let tee_challenge_private_key: String = get_env_var_or_error(
106106
TeeSessionEnvironmentVariable::SIGN_TEE_CHALLENGE_PRIVATE_KEY,
107-
PostComputeTeeChallengePrivateKeyMissing,
107+
ReplicateStatusCause::PostComputeTeeChallengePrivateKeyMissing,
108108
)?;
109109
let message_hash: String = concatenate_and_hash(&[chain_task_id, &worker_address]);
110110
sign_enclave_challenge(&message_hash, &tee_challenge_private_key)
@@ -113,7 +113,6 @@ pub fn get_challenge(chain_task_id: &str) -> Result<String, PostComputeError> {
113113
#[cfg(test)]
114114
mod tests {
115115
use super::*;
116-
use crate::compute::utils::env_utils::TeeSessionEnvironmentVariable::*;
117116
use temp_env::with_vars;
118117

119118
const CHAIN_TASK_ID: &str = "0x123456789abcdef";
@@ -141,7 +140,7 @@ mod tests {
141140
assert!(
142141
matches!(
143142
result,
144-
Err(ref err) if err.exit_cause == PostComputeInvalidEnclaveChallengePrivateKey
143+
Err(err) if err == ReplicateStatusCause::PostComputeInvalidEnclaveChallengePrivateKey
145144
),
146145
"Should return missing TEE challenge private key error"
147146
);
@@ -151,9 +150,12 @@ mod tests {
151150
fn should_get_challenge() {
152151
with_vars(
153152
vec![
154-
(SIGN_WORKER_ADDRESS.name(), Some(WORKER_ADDRESS)),
155153
(
156-
SIGN_TEE_CHALLENGE_PRIVATE_KEY.name(),
154+
TeeSessionEnvironmentVariable::SIGN_WORKER_ADDRESS.name(),
155+
Some(WORKER_ADDRESS),
156+
),
157+
(
158+
TeeSessionEnvironmentVariable::SIGN_TEE_CHALLENGE_PRIVATE_KEY.name(),
157159
Some(ENCLAVE_CHALLENGE_PRIVATE_KEY),
158160
),
159161
],
@@ -181,9 +183,12 @@ mod tests {
181183
fn should_fail_on_missing_worker_address_env_var() {
182184
with_vars(
183185
vec![
184-
(SIGN_WORKER_ADDRESS.name(), None),
185186
(
186-
SIGN_TEE_CHALLENGE_PRIVATE_KEY.name(),
187+
TeeSessionEnvironmentVariable::SIGN_WORKER_ADDRESS.name(),
188+
None,
189+
),
190+
(
191+
TeeSessionEnvironmentVariable::SIGN_TEE_CHALLENGE_PRIVATE_KEY.name(),
187192
Some(ENCLAVE_CHALLENGE_PRIVATE_KEY),
188193
),
189194
],
@@ -192,7 +197,7 @@ mod tests {
192197
assert!(
193198
matches!(
194199
result,
195-
Err(ref err) if err.exit_cause == PostComputeWorkerAddressMissing
200+
Err(err) if err == ReplicateStatusCause::PostComputeWorkerAddressMissing
196201
),
197202
"Should return missing worker address error"
198203
);
@@ -204,15 +209,21 @@ mod tests {
204209
fn should_fail_on_missing_private_key_env_var() {
205210
with_vars(
206211
vec![
207-
(SIGN_WORKER_ADDRESS.name(), Some(WORKER_ADDRESS)),
208-
(SIGN_TEE_CHALLENGE_PRIVATE_KEY.name(), None),
212+
(
213+
TeeSessionEnvironmentVariable::SIGN_WORKER_ADDRESS.name(),
214+
Some(WORKER_ADDRESS),
215+
),
216+
(
217+
TeeSessionEnvironmentVariable::SIGN_TEE_CHALLENGE_PRIVATE_KEY.name(),
218+
None,
219+
),
209220
],
210221
|| {
211222
let result = get_challenge(CHAIN_TASK_ID);
212223
assert!(
213224
matches!(
214225
result,
215-
Err(ref err) if err.exit_cause == PostComputeTeeChallengePrivateKeyMissing
226+
Err(err) if err == ReplicateStatusCause::PostComputeTeeChallengePrivateKeyMissing
216227
),
217228
"Should return missing private key error"
218229
);

src/compute/utils/env_utils.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::compute::errors::{PostComputeError, ReplicateStatusCause};
1+
use crate::compute::errors::ReplicateStatusCause;
22
use std::env;
33

44
pub enum TeeSessionEnvironmentVariable {
@@ -24,9 +24,9 @@ impl TeeSessionEnvironmentVariable {
2424
pub fn get_env_var_or_error(
2525
env_var: TeeSessionEnvironmentVariable,
2626
status_cause_if_missing: ReplicateStatusCause,
27-
) -> Result<String, PostComputeError> {
27+
) -> Result<String, ReplicateStatusCause> {
2828
match env::var(env_var.name()) {
2929
Ok(value) if !value.is_empty() => Ok(value),
30-
_ => Err(PostComputeError::new(status_cause_if_missing)),
30+
_ => Err(status_cause_if_missing),
3131
}
3232
}

0 commit comments

Comments
 (0)