Skip to content

Commit 0e3599d

Browse files
committed
internal enum
1 parent 65a4e0f commit 0e3599d

File tree

1 file changed

+41
-33
lines changed

1 file changed

+41
-33
lines changed

src/rust/src/backend/kdf.rs

Lines changed: 41 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1717,10 +1717,18 @@ fn int_to_bytes(value: usize, length: usize) -> Vec<u8> {
17171717
bytes
17181718
}
17191719

1720+
#[allow(clippy::enum_variant_names)]
1721+
#[derive(PartialEq)]
1722+
enum CounterLocation {
1723+
BeforeFixed,
1724+
AfterFixed,
1725+
MiddleFixed,
1726+
}
1727+
17201728
struct KbkdfParams {
17211729
rlen: usize,
17221730
llen: Option<usize>,
1723-
location: pyo3::Py<pyo3::PyAny>,
1731+
location: CounterLocation,
17241732
label: Option<pyo3::Py<pyo3::types::PyBytes>>,
17251733
context: Option<pyo3::Py<pyo3::types::PyBytes>>,
17261734
fixed: Option<pyo3::Py<pyo3::types::PyBytes>>,
@@ -1748,23 +1756,32 @@ fn validate_kbkdf_parameters(
17481756
}
17491757

17501758
let location_bound = location.bind(py);
1751-
let counter_location_type = crate::types::KBKDF_COUNTER_LOCATION.get(py)?;
1752-
if !location_bound.is_instance(&counter_location_type)? {
1759+
let counter_location = crate::types::KBKDF_COUNTER_LOCATION.get(py)?;
1760+
if !location_bound.is_instance(&counter_location)? {
17531761
return Err(CryptographyError::from(
17541762
pyo3::exceptions::PyTypeError::new_err("location must be of type CounterLocation"),
17551763
));
17561764
}
17571765

1758-
let counter_location_middle_fixed = crate::types::KBKDF_COUNTER_LOCATION
1759-
.get(py)?
1760-
.getattr(pyo3::intern!(py, "MiddleFixed"))?;
1761-
if location_bound.eq(&counter_location_middle_fixed)? && break_location.is_none() {
1766+
let counter_location_before_fixed =
1767+
counter_location.getattr(pyo3::intern!(py, "BeforeFixed"))?;
1768+
let counter_location_after_fixed = counter_location.getattr(pyo3::intern!(py, "AfterFixed"))?;
1769+
let rust_location = if location_bound.eq(&counter_location_before_fixed)? {
1770+
CounterLocation::BeforeFixed
1771+
} else if location_bound.eq(&counter_location_after_fixed)? {
1772+
CounterLocation::AfterFixed
1773+
} else {
1774+
// There are only 3 options so this is MiddleFixed
1775+
CounterLocation::MiddleFixed
1776+
};
1777+
1778+
if rust_location == CounterLocation::MiddleFixed && break_location.is_none() {
17621779
return Err(CryptographyError::from(
17631780
pyo3::exceptions::PyValueError::new_err("Please specify a break_location"),
17641781
));
17651782
}
17661783

1767-
if break_location.is_some() && !location_bound.eq(&counter_location_middle_fixed)? {
1784+
if break_location.is_some() && rust_location != CounterLocation::MiddleFixed {
17681785
return Err(CryptographyError::from(
17691786
pyo3::exceptions::PyValueError::new_err(
17701787
"break_location is ignored when location is not CounterLocation.MiddleFixed",
@@ -1801,7 +1818,7 @@ fn validate_kbkdf_parameters(
18011818
Ok(KbkdfParams {
18021819
rlen,
18031820
llen,
1804-
location,
1821+
location: rust_location,
18051822
label,
18061823
context,
18071824
fixed,
@@ -1837,31 +1854,22 @@ impl KbkdfHmac {
18371854

18381855
let fixed = self.generate_fixed_input(py)?;
18391856

1840-
let counter_location = self.params.location.bind(py);
1841-
let counter_location_before_fixed = crate::types::KBKDF_COUNTER_LOCATION
1842-
.get(py)?
1843-
.getattr(pyo3::intern!(py, "BeforeFixed"))?;
1844-
let counter_location_after_fixed = crate::types::KBKDF_COUNTER_LOCATION
1845-
.get(py)?
1846-
.getattr(pyo3::intern!(py, "AfterFixed"))?;
1847-
1848-
let (data_before_ctr, data_after_ctr) = if counter_location
1849-
.eq(&counter_location_before_fixed)?
1850-
{
1851-
(&b""[..], &fixed[..])
1852-
} else if counter_location.eq(&counter_location_after_fixed)? {
1853-
(&fixed[..], &b""[..])
1854-
} else {
1855-
// There are only 3 counter locations so this is MiddleFixed
1856-
// We validate break_location is Some when counter_location is MiddleFixed
1857-
// in the validate function
1858-
let break_loc = self.params.break_location.unwrap();
1859-
if break_loc > fixed.len() {
1860-
return Err(CryptographyError::from(
1861-
pyo3::exceptions::PyValueError::new_err("break_location offset > len(fixed)"),
1862-
));
1857+
let (data_before_ctr, data_after_ctr) = match &self.params.location {
1858+
CounterLocation::BeforeFixed => (&b""[..], &fixed[..]),
1859+
CounterLocation::AfterFixed => (&fixed[..], &b""[..]),
1860+
CounterLocation::MiddleFixed => {
1861+
// We validate break_location is Some when counter_location is MiddleFixed
1862+
// in the validate function
1863+
let break_loc = self.params.break_location.unwrap();
1864+
if break_loc > fixed.len() {
1865+
return Err(CryptographyError::from(
1866+
pyo3::exceptions::PyValueError::new_err(
1867+
"break_location offset > len(fixed)",
1868+
),
1869+
));
1870+
}
1871+
(&fixed[..break_loc], &fixed[break_loc..])
18631872
}
1864-
(&fixed[..break_loc], &fixed[break_loc..])
18651873
};
18661874

18671875
let mut pos = 0usize;

0 commit comments

Comments
 (0)