@@ -1715,7 +1715,7 @@ struct KbkdfHmac {
17151715enum CounterLocation {
17161716 BeforeFixed ,
17171717 AfterFixed ,
1718- MiddleFixed ,
1718+ MiddleFixed ( usize ) ,
17191719}
17201720
17211721struct KbkdfParams {
@@ -1725,7 +1725,6 @@ struct KbkdfParams {
17251725 label : Option < pyo3:: Py < pyo3:: types:: PyBytes > > ,
17261726 context : Option < pyo3:: Py < pyo3:: types:: PyBytes > > ,
17271727 fixed : Option < pyo3:: Py < pyo3:: types:: PyBytes > > ,
1728- break_location : Option < usize > ,
17291728}
17301729
17311730#[ allow( clippy:: too_many_arguments) ]
@@ -1765,16 +1764,15 @@ fn validate_kbkdf_parameters(
17651764 CounterLocation :: AfterFixed
17661765 } else {
17671766 // There are only 3 options so this is MiddleFixed
1768- CounterLocation :: MiddleFixed
1767+ if break_location. is_none ( ) {
1768+ return Err ( CryptographyError :: from (
1769+ pyo3:: exceptions:: PyValueError :: new_err ( "Please specify a break_location" ) ,
1770+ ) ) ;
1771+ }
1772+ CounterLocation :: MiddleFixed ( break_location. unwrap ( ) )
17691773 } ;
17701774
1771- if rust_location == CounterLocation :: MiddleFixed && break_location. is_none ( ) {
1772- return Err ( CryptographyError :: from (
1773- pyo3:: exceptions:: PyValueError :: new_err ( "Please specify a break_location" ) ,
1774- ) ) ;
1775- }
1776-
1777- if break_location. is_some ( ) && rust_location != CounterLocation :: MiddleFixed {
1775+ if break_location. is_some ( ) && !matches ! ( rust_location, CounterLocation :: MiddleFixed ( _) ) {
17781776 return Err ( CryptographyError :: from (
17791777 pyo3:: exceptions:: PyValueError :: new_err (
17801778 "break_location is ignored when location is not CounterLocation.MiddleFixed" ,
@@ -1815,7 +1813,6 @@ fn validate_kbkdf_parameters(
18151813 label,
18161814 context,
18171815 fixed,
1818- break_location,
18191816 } )
18201817}
18211818
@@ -1847,21 +1844,20 @@ impl KbkdfHmac {
18471844
18481845 let fixed = self . generate_fixed_input ( py) ?;
18491846
1850- let ( data_before_ctr, data_after_ctr) = match & self . params . location {
1847+ let ( data_before_ctr, data_after_ctr) = match self . params . location {
18511848 CounterLocation :: BeforeFixed => ( & b"" [ ..] , & fixed[ ..] ) ,
18521849 CounterLocation :: AfterFixed => ( & fixed[ ..] , & b"" [ ..] ) ,
1853- CounterLocation :: MiddleFixed => {
1850+ CounterLocation :: MiddleFixed ( break_location ) => {
18541851 // We validate break_location is Some when counter_location is MiddleFixed
18551852 // in the validate function
1856- let break_loc = self . params . break_location . unwrap ( ) ;
1857- if break_loc > fixed. len ( ) {
1853+ if break_location > fixed. len ( ) {
18581854 return Err ( CryptographyError :: from (
18591855 pyo3:: exceptions:: PyValueError :: new_err (
18601856 "break_location offset > len(fixed)" ,
18611857 ) ,
18621858 ) ) ;
18631859 }
1864- ( & fixed[ ..break_loc ] , & fixed[ break_loc ..] )
1860+ ( & fixed[ ..break_location ] , & fixed[ break_location ..] )
18651861 }
18661862 } ;
18671863
@@ -1892,9 +1888,14 @@ impl KbkdfHmac {
18921888 }
18931889
18941890 // llen will exist if fixed data is not provided
1895- let py_bitlength = pyo3:: types:: PyInt :: new ( py, self . length )
1896- . mul ( 8 ) ?
1897- . extract :: < pyo3:: Bound < ' _ , pyo3:: types:: PyInt > > ( ) ?;
1891+ let py_bitlength = pyo3:: types:: PyInt :: new (
1892+ py,
1893+ self . length
1894+ . checked_mul ( 8 )
1895+ . ok_or ( pyo3:: exceptions:: PyOverflowError :: new_err (
1896+ "Length too large, would cause overflow in bit length calculation" ,
1897+ ) ) ?,
1898+ ) ;
18981899 let l_val = py_uint_to_be_bytes_with_length ( py, py_bitlength, self . params . llen . unwrap ( ) ) ?;
18991900
19001901 let mut result = Vec :: new ( ) ;
0 commit comments