Skip to content

Commit 9176242

Browse files
committed
Add TypeTree support for memcpy
1 parent bd70be1 commit 9176242

File tree

15 files changed

+123
-18
lines changed

15 files changed

+123
-18
lines changed

compiler/rustc_codegen_llvm/src/abi.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
244244
scratch_align,
245245
bx.const_usize(copy_bytes),
246246
MemFlags::empty(),
247+
None,
247248
);
248249
bx.lifetime_end(llscratch, scratch_size);
249250
}

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::borrow::{Borrow, Cow};
22
use std::ops::Deref;
33
use std::{iter, ptr};
44

5+
use rustc_ast::expand::typetree::FncTree;
56
pub(crate) mod autodiff;
67
pub(crate) mod gpu_offload;
78

@@ -1118,11 +1119,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
11181119
src_align: Align,
11191120
size: &'ll Value,
11201121
flags: MemFlags,
1122+
tt: Option<FncTree>,
11211123
) {
11221124
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
11231125
let size = self.intcast(size, self.type_isize(), false);
11241126
let is_volatile = flags.contains(MemFlags::VOLATILE);
1125-
unsafe {
1127+
let memcpy = unsafe {
11261128
llvm::LLVMRustBuildMemCpy(
11271129
self.llbuilder,
11281130
dst,
@@ -1131,7 +1133,16 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
11311133
src_align.bytes() as c_uint,
11321134
size,
11331135
is_volatile,
1134-
);
1136+
)
1137+
};
1138+
1139+
// TypeTree metadata for memcpy is especially important: when Enzyme encounters
1140+
// a memcpy during autodiff, it needs to know the structure of the data being
1141+
// copied to properly track derivatives. For example, copying an array of floats
1142+
// vs. copying a struct with mixed types requires different derivative handling.
1143+
// The TypeTree tells Enzyme exactly what memory layout to expect.
1144+
if let Some(tt) = tt {
1145+
crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt);
11351146
}
11361147
}
11371148

compiler/rustc_codegen_llvm/src/va_arg.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,7 @@ fn copy_to_temporary_if_more_aligned<'ll, 'tcx>(
735735
src_align,
736736
bx.const_u32(layout.layout.size().bytes() as u32),
737737
MemFlags::empty(),
738+
None,
738739
);
739740
tmp
740741
} else {

compiler/rustc_codegen_ssa/src/mir/block.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1623,6 +1623,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
16231623
align,
16241624
bx.const_usize(copy_bytes),
16251625
MemFlags::empty(),
1626+
None,
16261627
);
16271628
// ...and then load it with the ABI type.
16281629
llval = load_cast(bx, cast, llscratch, scratch_align);

compiler/rustc_codegen_ssa/src/mir/intrinsic.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
3030
if allow_overlap {
3131
bx.memmove(dst, align, src, align, size, flags);
3232
} else {
33-
bx.memcpy(dst, align, src, align, size, flags);
33+
bx.memcpy(dst, align, src, align, size, flags, None);
3434
}
3535
}
3636

compiler/rustc_codegen_ssa/src/mir/statement.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
9090
let align = pointee_layout.align;
9191
let dst = dst_val.immediate();
9292
let src = src_val.immediate();
93-
bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty());
93+
bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty(), None);
9494
}
9595
mir::StatementKind::FakeRead(..)
9696
| mir::StatementKind::Retag { .. }

compiler/rustc_codegen_ssa/src/traits/builder.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,7 @@ pub trait BuilderMethods<'a, 'tcx>:
424424
src_align: Align,
425425
size: Self::Value,
426426
flags: MemFlags,
427+
tt: Option<rustc_ast::expand::typetree::FncTree>,
427428
);
428429
fn memmove(
429430
&mut self,
@@ -480,7 +481,7 @@ pub trait BuilderMethods<'a, 'tcx>:
480481
temp.val.store_with_flags(self, dst.with_type(layout), flags);
481482
} else if !layout.is_zst() {
482483
let bytes = self.const_usize(layout.size.bytes());
483-
self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags);
484+
self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags, None);
484485
}
485486
}
486487

compiler/rustc_middle/src/ty/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2286,7 +2286,7 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
22862286
let child = typetree_from_ty(tcx, inner_ty);
22872287
return TypeTree(vec![Type {
22882288
offset: -1,
2289-
size: 8, // TODO(KMJ-007): Get actual pointer size from target
2289+
size: tcx.data_layout.pointer_size().bytes_usize(),
22902290
kind: Kind::Pointer,
22912291
child,
22922292
}]);

tests/codegen-llvm/autodiff/typetree.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ fn main() {
3030
let output_ = d_simple(&x, &mut df_dx, 1.0);
3131
assert_eq!(output, output_);
3232
assert_eq!(2.0, df_dx);
33-
}
33+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
; Check that enzyme_type attributes are present in the LLVM IR function definition
2+
; This verifies our TypeTree system correctly attaches metadata for Enzyme
3+
4+
CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_memcpy({{.*}}"enzyme_type"="{[]:Pointer}"
5+
6+
; Check that llvm.memcpy is referenced (as a call or a declaration); this does not verify any metadata attached to the call itself
7+
CHECK: {{(call|declare).*}}@llvm.memcpy
8+

0 commit comments

Comments
 (0)