@@ -286,6 +286,119 @@ impl<'tcx> ConvertVisitor<'tcx> {
286
286
)
287
287
}
288
288
289
+ mir_op:: RewriteKind :: MallocSafe {
290
+ ref zero_ty,
291
+ elem_size,
292
+ single,
293
+ }
294
+ | mir_op:: RewriteKind :: CallocSafe {
295
+ ref zero_ty,
296
+ elem_size,
297
+ single,
298
+ } => {
299
+ // `malloc(n)` -> `Box::new(z)` or similar
300
+ assert ! ( matches!( hir_rw, Rewrite :: Identity ) ) ;
301
+ let zeroize_expr = generate_zeroize_expr ( zero_ty) ;
302
+ let mut stmts = match * rw {
303
+ mir_op:: RewriteKind :: MallocSafe { .. } => vec ! [
304
+ Rewrite :: Let ( vec![ ( "byte_len" . into( ) , self . get_subexpr( ex, 0 ) ) ] ) ,
305
+ Rewrite :: Let1 (
306
+ "n" . into( ) ,
307
+ Box :: new( format_rewrite!( "byte_len as usize / {elem_size}" ) ) ,
308
+ ) ,
309
+ ] ,
310
+ mir_op:: RewriteKind :: CallocSafe { .. } => vec ! [
311
+ Rewrite :: Let ( vec![
312
+ ( "count" . into( ) , self . get_subexpr( ex, 0 ) ) ,
313
+ ( "size" . into( ) , self . get_subexpr( ex, 1 ) ) ,
314
+ ] ) ,
315
+ format_rewrite!( "assert_eq!(size, {elem_size})" ) ,
316
+ Rewrite :: Let1 ( "n" . into( ) , Box :: new( format_rewrite!( "count as usize" ) ) ) ,
317
+ ] ,
318
+ _ => unreachable ! ( ) ,
319
+ } ;
320
+ let expr = if single {
321
+ stmts. push ( Rewrite :: Text ( "assert_eq!(n, 1)" . into ( ) ) ) ;
322
+ format_rewrite ! ( "Box::new({})" , zeroize_expr)
323
+ } else {
324
+ stmts. push ( Rewrite :: Let1 (
325
+ "mut v" . into ( ) ,
326
+ Box :: new ( Rewrite :: Text ( "Vec::with_capacity(n)" . into ( ) ) ) ,
327
+ ) ) ;
328
+ stmts. push ( format_rewrite ! (
329
+ "for i in 0..n {{\n v.push({});\n }}" ,
330
+ zeroize_expr,
331
+ ) ) ;
332
+ Rewrite :: Text ( "v.into_boxed_slice()" . into ( ) )
333
+ } ;
334
+ Rewrite :: Block ( stmts, Some ( Box :: new ( expr) ) )
335
+ }
336
+
337
+ mir_op:: RewriteKind :: FreeSafe { single : _ } => {
338
+ // `free(p)` -> `drop(p)`
339
+ assert ! ( matches!( hir_rw, Rewrite :: Identity ) ) ;
340
+ Rewrite :: Call ( "std::mem::drop" . to_string ( ) , vec ! [ self . get_subexpr( ex, 0 ) ] )
341
+ }
342
+
343
+ mir_op:: RewriteKind :: ReallocSafe {
344
+ ref zero_ty,
345
+ elem_size,
346
+ src_single,
347
+ dest_single,
348
+ } => {
349
+ // `realloc(p, n)` -> `Box::new(...)`
350
+ assert ! ( matches!( hir_rw, Rewrite :: Identity ) ) ;
351
+ let zeroize_expr = generate_zeroize_expr ( zero_ty) ;
352
+ let mut stmts = vec ! [
353
+ Rewrite :: Let ( vec![
354
+ ( "src_ptr" . into( ) , self . get_subexpr( ex, 0 ) ) ,
355
+ ( "dest_byte_len" . into( ) , self . get_subexpr( ex, 1 ) ) ,
356
+ ] ) ,
357
+ Rewrite :: Let1 (
358
+ "dest_n" . into( ) ,
359
+ Box :: new( format_rewrite!( "dest_byte_len as usize / {elem_size}" ) ) ,
360
+ ) ,
361
+ ] ;
362
+ if dest_single {
363
+ stmts. push ( Rewrite :: Text ( "assert_eq!(dest_n, 1)" . into ( ) ) ) ;
364
+ }
365
+ let expr = match ( src_single, dest_single) {
366
+ ( false , false ) => {
367
+ stmts. push ( Rewrite :: Let1 (
368
+ "mut dest_ptr" . into ( ) ,
369
+ Box :: new ( Rewrite :: Text ( "Vec::from(src_ptr)" . into ( ) ) ) ,
370
+ ) ) ;
371
+ stmts. push ( format_rewrite ! (
372
+ "dest_ptr.resize_with(dest_n, || {})" ,
373
+ zeroize_expr,
374
+ ) ) ;
375
+ Rewrite :: Text ( "dest_ptr.into_boxed_slice()" . into ( ) )
376
+ }
377
+ ( false , true ) => {
378
+ format_rewrite ! (
379
+ "src_ptr.into_iter().next().unwrap_or_else(|| {})" ,
380
+ zeroize_expr
381
+ )
382
+ }
383
+ ( true , false ) => {
384
+ stmts. push ( Rewrite :: Let1 (
385
+ "mut dest_ptr" . into ( ) ,
386
+ Box :: new ( Rewrite :: Text ( "Vec::with_capacity(dest_n)" . into ( ) ) ) ,
387
+ ) ) ;
388
+ stmts. push ( Rewrite :: Text (
389
+ "if dest_n >= 1 { dest_ptr.push(*src_ptr); }" . into ( ) ,
390
+ ) ) ;
391
+ stmts. push ( format_rewrite ! (
392
+ "dest_ptr.resize_with(dest_n, || {})" ,
393
+ zeroize_expr,
394
+ ) ) ;
395
+ Rewrite :: Text ( "dest_ptr.into_boxed_slice()" . into ( ) )
396
+ }
397
+ ( true , true ) => Rewrite :: Text ( "src_ptr" . into ( ) ) ,
398
+ } ;
399
+ Rewrite :: Block ( stmts, Some ( Box :: new ( expr) ) )
400
+ }
401
+
289
402
mir_op:: RewriteKind :: CellGet => {
290
403
// `*x` to `Cell::get(x)`
291
404
assert ! ( matches!( hir_rw, Rewrite :: Identity ) ) ;
@@ -566,7 +679,7 @@ fn generate_zeroize_code(zero_ty: &ZeroizeType, lv: &str) -> String {
566
679
match * zero_ty {
567
680
ZeroizeType :: Int => format ! ( "{lv} = 0" ) ,
568
681
ZeroizeType :: Bool => format ! ( "{lv} = false" ) ,
569
- ZeroizeType :: Iterable ( ref elem_zero_ty) => format ! (
682
+ ZeroizeType :: Array ( ref elem_zero_ty) => format ! (
570
683
"
571
684
{{
572
685
for elem in {lv}.iter_mut() {{
@@ -576,7 +689,7 @@ fn generate_zeroize_code(zero_ty: &ZeroizeType, lv: &str) -> String {
576
689
" ,
577
690
generate_zeroize_code( elem_zero_ty, "(*elem)" )
578
691
) ,
579
- ZeroizeType :: Struct ( ref fields) => {
692
+ ZeroizeType :: Struct ( _ , ref fields) => {
580
693
eprintln ! ( "zeroize: {} fields on {lv}: {fields:?}" , fields. len( ) ) ;
581
694
let mut s = String :: new ( ) ;
582
695
writeln ! ( s, "{{" ) . unwrap ( ) ;
@@ -594,6 +707,27 @@ fn generate_zeroize_code(zero_ty: &ZeroizeType, lv: &str) -> String {
594
707
}
595
708
}
596
709
710
+ /// Generate an expression to produce a zeroized version of a value.
711
+ fn generate_zeroize_expr ( zero_ty : & ZeroizeType ) -> String {
712
+ match * zero_ty {
713
+ ZeroizeType :: Int => format ! ( "0" ) ,
714
+ ZeroizeType :: Bool => format ! ( "false" ) ,
715
+ ZeroizeType :: Array ( ref elem_zero_ty) => format ! (
716
+ "std::array::from_fn(|| {})" ,
717
+ generate_zeroize_expr( elem_zero_ty)
718
+ ) ,
719
+ ZeroizeType :: Struct ( ref name, ref fields) => {
720
+ let mut s = String :: new ( ) ;
721
+ write ! ( s, "{} {{\n " , name) . unwrap ( ) ;
722
+ for ( name, field_zero_ty) in fields {
723
+ write ! ( s, "{}: {},\n " , name, generate_zeroize_expr( field_zero_ty) , ) . unwrap ( ) ;
724
+ }
725
+ write ! ( s, "}}\n " ) . unwrap ( ) ;
726
+ s
727
+ }
728
+ }
729
+ }
730
+
597
731
fn take_prefix_while < ' a , T > ( slice : & mut & ' a [ T ] , mut pred : impl FnMut ( & ' a T ) -> bool ) -> & ' a [ T ] {
598
732
let i = slice. iter ( ) . position ( |x| !pred ( x) ) . unwrap_or ( slice. len ( ) ) ;
599
733
let ( a, b) = slice. split_at ( i) ;
@@ -614,14 +748,14 @@ pub fn convert_cast_rewrite(kind: &mir_op::RewriteKind, hir_rw: Rewrite) -> Rewr
614
748
Rewrite :: Ref ( Box :: new ( elem) , mutbl_from_bool ( mutbl) )
615
749
}
616
750
617
- mir_op:: RewriteKind :: MutToImm => {
618
- // `p` -> `&*p`
751
+ mir_op:: RewriteKind :: Reborrow { mutbl } => {
752
+ // `p` -> `&*p` / `&mut *p`
619
753
let hir_rw = match fold_mut_to_imm ( hir_rw) {
620
754
Ok ( folded_rw) => return folded_rw,
621
755
Err ( rw) => rw,
622
756
} ;
623
757
let place = Rewrite :: Deref ( Box :: new ( hir_rw) ) ;
624
- Rewrite :: Ref ( Box :: new ( place) , hir :: Mutability :: Not )
758
+ Rewrite :: Ref ( Box :: new ( place) , mutbl_from_bool ( mutbl ) )
625
759
}
626
760
627
761
mir_op:: RewriteKind :: OptionUnwrap => {
@@ -661,6 +795,33 @@ pub fn convert_cast_rewrite(kind: &mir_op::RewriteKind, hir_rw: Rewrite) -> Rewr
661
795
Rewrite :: MethodCall ( ref_method, Box :: new ( hir_rw) , vec ! [ ] )
662
796
}
663
797
798
+ mir_op:: RewriteKind :: DynOwnedUnwrap => {
799
+ Rewrite :: MethodCall ( "unwrap" . to_string ( ) , Box :: new ( hir_rw) , vec ! [ ] )
800
+ }
801
+ mir_op:: RewriteKind :: DynOwnedTake => {
802
+ // `p` -> `mem::replace(&mut p, Err(()))`
803
+ Rewrite :: Call (
804
+ "std::mem::replace" . to_string ( ) ,
805
+ vec ! [
806
+ Rewrite :: Ref ( Box :: new( hir_rw) , hir:: Mutability :: Mut ) ,
807
+ Rewrite :: Text ( "Err(())" . into( ) ) ,
808
+ ] ,
809
+ )
810
+ }
811
+ mir_op:: RewriteKind :: DynOwnedWrap => {
812
+ Rewrite :: Call ( "std::result::Result::<_, ()>::Ok" . to_string ( ) , vec ! [ hir_rw] )
813
+ }
814
+
815
+ mir_op:: RewriteKind :: DynOwnedDowngrade { mutbl } => {
816
+ let ref_method = if mutbl {
817
+ "as_deref_mut" . into ( )
818
+ } else {
819
+ "as_deref" . into ( )
820
+ } ;
821
+ let hir_rw = Rewrite :: MethodCall ( ref_method, Box :: new ( hir_rw) , vec ! [ ] ) ;
822
+ Rewrite :: MethodCall ( "unwrap" . into ( ) , Box :: new ( hir_rw) , vec ! [ ] )
823
+ }
824
+
664
825
mir_op:: RewriteKind :: CastRefToRaw { mutbl } => {
665
826
// `addr_of!(*p)` is cleaner than `p as *const _`; we don't know the pointee
666
827
// type here, so we can't emit `p as *const T`.
0 commit comments