Skip to content

Commit ac46d2a

Browse files
committed
feedback
1 parent 4427698 commit ac46d2a

File tree

2 files changed

+25
-31
lines changed

2 files changed

+25
-31
lines changed

src/rust/src/asn1.rs

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ pub(crate) fn py_uint_to_big_endian_bytes<'p>(
7171
py: pyo3::Python<'p>,
7272
v: pyo3::Bound<'p, pyo3::types::PyInt>,
7373
) -> pyo3::PyResult<PyBackedBytes> {
74-
reject_negative_integer(&v)?;
7574
// Round the length up so that we prefix an extra \x00. This ensures that
7675
// integers that'd have the high bit set in their first octet are not
7776
// encoded as negative in DER.
@@ -83,22 +82,16 @@ pub(crate) fn py_uint_to_big_endian_bytes<'p>(
8382
py_uint_to_be_bytes_with_length(py, v, length)
8483
}
8584

86-
fn reject_negative_integer(v: &pyo3::Bound<'_, pyo3::types::PyInt>) -> pyo3::PyResult<()> {
87-
if v.lt(0)? {
88-
Err(pyo3::exceptions::PyValueError::new_err(
89-
"Negative integers are not supported",
90-
))
91-
} else {
92-
Ok(())
93-
}
94-
}
95-
9685
pub(crate) fn py_uint_to_be_bytes_with_length<'p>(
9786
py: pyo3::Python<'p>,
9887
v: pyo3::Bound<'p, pyo3::types::PyInt>,
9988
length: usize,
10089
) -> pyo3::PyResult<PyBackedBytes> {
101-
reject_negative_integer(&v)?;
90+
if v.lt(0)? {
91+
return Err(pyo3::exceptions::PyValueError::new_err(
92+
"Negative integers are not supported",
93+
));
94+
}
10295
Ok(
10396
v.call_method1(pyo3::intern!(py, "to_bytes"), (length, "big"))?
10497
.extract()?,

src/rust/src/backend/kdf.rs

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1715,7 +1715,7 @@ struct KbkdfHmac {
17151715
enum CounterLocation {
17161716
BeforeFixed,
17171717
AfterFixed,
1718-
MiddleFixed,
1718+
MiddleFixed(usize),
17191719
}
17201720

17211721
struct KbkdfParams {
@@ -1725,7 +1725,6 @@ struct KbkdfParams {
17251725
label: Option<pyo3::Py<pyo3::types::PyBytes>>,
17261726
context: Option<pyo3::Py<pyo3::types::PyBytes>>,
17271727
fixed: Option<pyo3::Py<pyo3::types::PyBytes>>,
1728-
break_location: Option<usize>,
17291728
}
17301729

17311730
#[allow(clippy::too_many_arguments)]
@@ -1765,16 +1764,15 @@ fn validate_kbkdf_parameters(
17651764
CounterLocation::AfterFixed
17661765
} else {
17671766
// There are only 3 options so this is MiddleFixed
1768-
CounterLocation::MiddleFixed
1767+
if break_location.is_none() {
1768+
return Err(CryptographyError::from(
1769+
pyo3::exceptions::PyValueError::new_err("Please specify a break_location"),
1770+
));
1771+
}
1772+
CounterLocation::MiddleFixed(break_location.unwrap())
17691773
};
17701774

1771-
if rust_location == CounterLocation::MiddleFixed && break_location.is_none() {
1772-
return Err(CryptographyError::from(
1773-
pyo3::exceptions::PyValueError::new_err("Please specify a break_location"),
1774-
));
1775-
}
1776-
1777-
if break_location.is_some() && rust_location != CounterLocation::MiddleFixed {
1775+
if break_location.is_some() && !matches!(rust_location, CounterLocation::MiddleFixed(_)) {
17781776
return Err(CryptographyError::from(
17791777
pyo3::exceptions::PyValueError::new_err(
17801778
"break_location is ignored when location is not CounterLocation.MiddleFixed",
@@ -1815,7 +1813,6 @@ fn validate_kbkdf_parameters(
18151813
label,
18161814
context,
18171815
fixed,
1818-
break_location,
18191816
})
18201817
}
18211818

@@ -1847,21 +1844,20 @@ impl KbkdfHmac {
18471844

18481845
let fixed = self.generate_fixed_input(py)?;
18491846

1850-
let (data_before_ctr, data_after_ctr) = match &self.params.location {
1847+
let (data_before_ctr, data_after_ctr) = match self.params.location {
18511848
CounterLocation::BeforeFixed => (&b""[..], &fixed[..]),
18521849
CounterLocation::AfterFixed => (&fixed[..], &b""[..]),
1853-
CounterLocation::MiddleFixed => {
1850+
CounterLocation::MiddleFixed(break_location) => {
18541851
// We validate break_location is Some when counter_location is MiddleFixed
18551852
// in the validate function
1856-
let break_loc = self.params.break_location.unwrap();
1857-
if break_loc > fixed.len() {
1853+
if break_location > fixed.len() {
18581854
return Err(CryptographyError::from(
18591855
pyo3::exceptions::PyValueError::new_err(
18601856
"break_location offset > len(fixed)",
18611857
),
18621858
));
18631859
}
1864-
(&fixed[..break_loc], &fixed[break_loc..])
1860+
(&fixed[..break_location], &fixed[break_location..])
18651861
}
18661862
};
18671863

@@ -1892,9 +1888,14 @@ impl KbkdfHmac {
18921888
}
18931889

18941890
// llen will exist if fixed data is not provided
1895-
let py_bitlength = pyo3::types::PyInt::new(py, self.length)
1896-
.mul(8)?
1897-
.extract::<pyo3::Bound<'_, pyo3::types::PyInt>>()?;
1891+
let py_bitlength = pyo3::types::PyInt::new(
1892+
py,
1893+
self.length
1894+
.checked_mul(8)
1895+
.ok_or(pyo3::exceptions::PyOverflowError::new_err(
1896+
"Length too large, would cause overflow in bit length calculation",
1897+
))?,
1898+
);
18981899
let l_val = py_uint_to_be_bytes_with_length(py, py_bitlength, self.params.llen.unwrap())?;
18991900

19001901
let mut result = Vec::new();

0 commit comments

Comments
 (0)