@@ -1717,10 +1717,18 @@ fn int_to_bytes(value: usize, length: usize) -> Vec<u8> {
17171717 bytes
17181718}
17191719
1720+ #[ allow( clippy:: enum_variant_names) ]
1721+ #[ derive( PartialEq ) ]
1722+ enum CounterLocation {
1723+ BeforeFixed ,
1724+ AfterFixed ,
1725+ MiddleFixed ,
1726+ }
1727+
17201728struct KbkdfParams {
17211729 rlen : usize ,
17221730 llen : Option < usize > ,
1723- location : pyo3 :: Py < pyo3 :: PyAny > ,
1731+ location : CounterLocation ,
17241732 label : Option < pyo3:: Py < pyo3:: types:: PyBytes > > ,
17251733 context : Option < pyo3:: Py < pyo3:: types:: PyBytes > > ,
17261734 fixed : Option < pyo3:: Py < pyo3:: types:: PyBytes > > ,
@@ -1748,23 +1756,32 @@ fn validate_kbkdf_parameters(
17481756 }
17491757
17501758 let location_bound = location. bind ( py) ;
1751- let counter_location_type = crate :: types:: KBKDF_COUNTER_LOCATION . get ( py) ?;
1752- if !location_bound. is_instance ( & counter_location_type ) ? {
1759+ let counter_location = crate :: types:: KBKDF_COUNTER_LOCATION . get ( py) ?;
1760+ if !location_bound. is_instance ( & counter_location ) ? {
17531761 return Err ( CryptographyError :: from (
17541762 pyo3:: exceptions:: PyTypeError :: new_err ( "location must be of type CounterLocation" ) ,
17551763 ) ) ;
17561764 }
17571765
1758- let counter_location_middle_fixed = crate :: types:: KBKDF_COUNTER_LOCATION
1759- . get ( py) ?
1760- . getattr ( pyo3:: intern!( py, "MiddleFixed" ) ) ?;
1761- if location_bound. eq ( & counter_location_middle_fixed) ? && break_location. is_none ( ) {
1766+ let counter_location_before_fixed =
1767+ counter_location. getattr ( pyo3:: intern!( py, "BeforeFixed" ) ) ?;
1768+ let counter_location_after_fixed = counter_location. getattr ( pyo3:: intern!( py, "AfterFixed" ) ) ?;
1769+ let rust_location = if location_bound. eq ( & counter_location_before_fixed) ? {
1770+ CounterLocation :: BeforeFixed
1771+ } else if location_bound. eq ( & counter_location_after_fixed) ? {
1772+ CounterLocation :: AfterFixed
1773+ } else {
1774+ // There are only 3 options so this is MiddleFixed
1775+ CounterLocation :: MiddleFixed
1776+ } ;
1777+
1778+ if rust_location == CounterLocation :: MiddleFixed && break_location. is_none ( ) {
17621779 return Err ( CryptographyError :: from (
17631780 pyo3:: exceptions:: PyValueError :: new_err ( "Please specify a break_location" ) ,
17641781 ) ) ;
17651782 }
17661783
1767- if break_location. is_some ( ) && !location_bound . eq ( & counter_location_middle_fixed ) ? {
1784+ if break_location. is_some ( ) && rust_location != CounterLocation :: MiddleFixed {
17681785 return Err ( CryptographyError :: from (
17691786 pyo3:: exceptions:: PyValueError :: new_err (
17701787 "break_location is ignored when location is not CounterLocation.MiddleFixed" ,
@@ -1801,7 +1818,7 @@ fn validate_kbkdf_parameters(
18011818 Ok ( KbkdfParams {
18021819 rlen,
18031820 llen,
1804- location,
1821+ location : rust_location ,
18051822 label,
18061823 context,
18071824 fixed,
@@ -1837,31 +1854,22 @@ impl KbkdfHmac {
18371854
18381855 let fixed = self . generate_fixed_input ( py) ?;
18391856
1840- let counter_location = self . params . location . bind ( py) ;
1841- let counter_location_before_fixed = crate :: types:: KBKDF_COUNTER_LOCATION
1842- . get ( py) ?
1843- . getattr ( pyo3:: intern!( py, "BeforeFixed" ) ) ?;
1844- let counter_location_after_fixed = crate :: types:: KBKDF_COUNTER_LOCATION
1845- . get ( py) ?
1846- . getattr ( pyo3:: intern!( py, "AfterFixed" ) ) ?;
1847-
1848- let ( data_before_ctr, data_after_ctr) = if counter_location
1849- . eq ( & counter_location_before_fixed) ?
1850- {
1851- ( & b"" [ ..] , & fixed[ ..] )
1852- } else if counter_location. eq ( & counter_location_after_fixed) ? {
1853- ( & fixed[ ..] , & b"" [ ..] )
1854- } else {
1855- // There are only 3 counter locations so this is MiddleFixed
1856- // We validate break_location is Some when counter_location is MiddleFixed
1857- // in the validate function
1858- let break_loc = self . params . break_location . unwrap ( ) ;
1859- if break_loc > fixed. len ( ) {
1860- return Err ( CryptographyError :: from (
1861- pyo3:: exceptions:: PyValueError :: new_err ( "break_location offset > len(fixed)" ) ,
1862- ) ) ;
1857+ let ( data_before_ctr, data_after_ctr) = match & self . params . location {
1858+ CounterLocation :: BeforeFixed => ( & b"" [ ..] , & fixed[ ..] ) ,
1859+ CounterLocation :: AfterFixed => ( & fixed[ ..] , & b"" [ ..] ) ,
1860+ CounterLocation :: MiddleFixed => {
1861+ // We validate break_location is Some when counter_location is MiddleFixed
1862+ // in the validate function
1863+ let break_loc = self . params . break_location . unwrap ( ) ;
1864+ if break_loc > fixed. len ( ) {
1865+ return Err ( CryptographyError :: from (
1866+ pyo3:: exceptions:: PyValueError :: new_err (
1867+ "break_location offset > len(fixed)" ,
1868+ ) ,
1869+ ) ) ;
1870+ }
1871+ ( & fixed[ ..break_loc] , & fixed[ break_loc..] )
18631872 }
1864- ( & fixed[ ..break_loc] , & fixed[ break_loc..] )
18651873 } ;
18661874
18671875 let mut pos = 0usize ;
0 commit comments