diff --git a/src/decoding.rs b/src/decoding.rs index 8d87f03..7d87dda 100644 --- a/src/decoding.rs +++ b/src/decoding.rs @@ -213,12 +213,10 @@ fn verify_signature<'a>( return Err(new_error(ErrorKind::MissingAlgorithm)); } - if validation.validate_signature { - for alg in &validation.algorithms { - if key.family != alg.family() { - return Err(new_error(ErrorKind::InvalidAlgorithm)); - } - } + if validation.validate_signature + && !validation.algorithms.iter().any(|alg| alg.family() == key.family) + { + return Err(new_error(ErrorKind::InvalidAlgorithm)); } let (signature, message) = expect_two!(token.rsplitn(2, '.')); @@ -229,6 +227,10 @@ fn verify_signature<'a>( return Err(new_error(ErrorKind::InvalidAlgorithm)); } + if header.alg.family() != key.family { + return Err(new_error(ErrorKind::InvalidAlgorithm)); + } + if validation.validate_signature && !verify(signature, message.as_bytes(), key, header.alg)? { return Err(new_error(ErrorKind::InvalidSignature)); } diff --git a/tests/hmac.rs b/tests/hmac.rs index 47ca448..2e1b4ce 100644 --- a/tests/hmac.rs +++ b/tests/hmac.rs @@ -86,6 +86,16 @@ fn decode_token() { claims.unwrap(); } +#[test] +#[wasm_bindgen_test] +fn decode_token_with_multiple_algorithms_allowed() { + let token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUiLCJleHAiOjI1MzI1MjQ4OTF9.9r56oF7ZliOBlOAyiOFperTGxBtPykRQiWNFxhDCW98"; + let mut validation = Validation::new(Algorithm::HS256); + validation.algorithms.push(Algorithm::RS256); + let claims = decode::(token, &DecodingKey::from_secret(b"secret"), &validation); + claims.unwrap(); +} + #[test] #[wasm_bindgen_test] #[should_panic(expected = "InvalidToken")] @@ -126,6 +136,50 @@ fn decode_token_wrong_algorithm() { claims.unwrap(); } +#[test] +#[wasm_bindgen_test] +#[should_panic(expected = "MissingAlgorithm")] +fn decode_missing_algorithm() { + let token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUifQ.I1BvFoHe94AFf09O6tDbcSB8-jp8w6xZqmyHIwPeSdY"; + let mut validation = Validation::new(Algorithm::HS256); + validation.algorithms = vec![]; + let claims = decode::(token, &DecodingKey::from_secret(b"secret"), &validation); + claims.unwrap(); +} + +#[test] +#[wasm_bindgen_test] +#[should_panic(expected = "InvalidAlgorithm")] +fn decode_mismatched_key_algorithm() { + let token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUifQ.I1BvFoHe94AFf09O6tDbcSB8-jp8w6xZqmyHIwPeSdY"; + let mut validation = Validation::new(Algorithm::HS256); + validation.algorithms.push(Algorithm::RS256); + let key = &DecodingKey::from_rsa_components("aGk", "aGk").unwrap(); + let claims = decode::(token, &key, &validation); + claims.unwrap(); +} + +#[test] +#[wasm_bindgen_test] +#[should_panic(expected = "InvalidAlgorithm")] +fn decode_invalid_header_algorithm() { + let token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUifQ.I1BvFoHe94AFf09O6tDbcSB8-jp8w6xZqmyHIwPeSdY"; + let validation = Validation::new(Algorithm::RS256); + let key = &DecodingKey::from_rsa_components("aGk", "aGk").unwrap(); + let claims = decode::(token, &key, &validation); + claims.unwrap(); +} + +#[test] +#[wasm_bindgen_test] +#[should_panic(expected = "InvalidAlgorithm")] +fn wrong_decoding_key_family() { + let token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUifQ.I1BvFoHe94AFf09O6tDbcSB8-jp8w6xZqmyHIwPeSdY"; + let validation = Validation::new(Algorithm::RS256); + let claims = decode::(token, &DecodingKey::from_secret(b"secret"), &validation); + claims.unwrap(); +} + #[test] #[wasm_bindgen_test] #[should_panic(expected = "InvalidAlgorithm")]