Skip to content

Commit b20c79e

Browse files
committed
upstream rustc_codegen_llvm changes for enzyme/autodiff
1 parent 3fee0f1 commit b20c79e

File tree

13 files changed

+644
-29
lines changed

13 files changed

+644
-29
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

+3-16
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
use std::fmt::{self, Display, Formatter};
77
use std::str::FromStr;
88

9-
use crate::expand::typetree::TypeTree;
109
use crate::expand::{Decodable, Encodable, HashStable_Generic};
1110
use crate::ptr::P;
1211
use crate::{Ty, TyKind};
@@ -79,10 +78,6 @@ pub struct AutoDiffItem {
7978
/// The name of the function being generated
8079
pub target: String,
8180
pub attrs: AutoDiffAttrs,
82-
/// Describe the memory layout of input types
83-
pub inputs: Vec<TypeTree>,
84-
/// Describe the memory layout of the output type
85-
pub output: TypeTree,
8681
}
8782
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
8883
pub struct AutoDiffAttrs {
@@ -262,22 +257,14 @@ impl AutoDiffAttrs {
262257
!matches!(self.mode, DiffMode::Error | DiffMode::Source)
263258
}
264259

265-
pub fn into_item(
266-
self,
267-
source: String,
268-
target: String,
269-
inputs: Vec<TypeTree>,
270-
output: TypeTree,
271-
) -> AutoDiffItem {
272-
AutoDiffItem { source, target, inputs, output, attrs: self }
260+
pub fn into_item(self, source: String, target: String) -> AutoDiffItem {
261+
AutoDiffItem { source, target, attrs: self }
273262
}
274263
}
275264

276265
impl fmt::Display for AutoDiffItem {
277266
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
278267
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
279-
write!(f, " with attributes: {:?}", self.attrs)?;
280-
write!(f, " with inputs: {:?}", self.inputs)?;
281-
write!(f, " with output: {:?}", self.output)
268+
write!(f, " with attributes: {:?}", self.attrs)
282269
}
283270
}

compiler/rustc_codegen_llvm/messages.ftl

+4
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ codegen_llvm_prepare_thin_lto_module_with_llvm_err = failed to prepare thin LTO
5656
codegen_llvm_run_passes = failed to run LLVM passes
5757
codegen_llvm_run_passes_with_llvm_err = failed to run LLVM passes: {$llvm_err}
5858
59+
codegen_llvm_prepare_autodiff = failed to prepare autodiff: src: {$src}, target: {$target}, {$error}
60+
codegen_llvm_prepare_autodiff_with_llvm_err = failed to prepare autodiff: {$llvm_err}, src: {$src}, target: {$target}, {$error}
61+
codegen_llvm_autodiff_without_lto = using the autodiff feature requires using fat-lto
62+
5963
codegen_llvm_sanitizer_memtag_requires_mte =
6064
`-Zsanitizer=memtag` requires `-Ctarget-feature=+mte`
6165

compiler/rustc_codegen_llvm/src/back/lto.rs

+8-1
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,14 @@ pub(crate) fn run_pass_manager(
604604
debug!("running the pass manager");
605605
let opt_stage = if thin { llvm::OptStage::ThinLTO } else { llvm::OptStage::FatLTO };
606606
let opt_level = config.opt_level.unwrap_or(config::OptLevel::No);
607-
unsafe { write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage) }?;
607+
608+
// If this rustc version was built with enzyme/autodiff enabled, and if users applied the
609+
// `#[autodiff]` macro at least once, then we will later call llvm_optimize a second time.
610+
let first_run = true;
611+
debug!("running llvm pm opt pipeline");
612+
unsafe {
613+
write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run)?;
614+
}
608615
debug!("lto done");
609616
Ok(())
610617
}

compiler/rustc_codegen_llvm/src/back/write.rs

+196-8
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@ use std::path::{Path, PathBuf};
44
use std::sync::Arc;
55
use std::{fs, slice, str};
66

7-
use libc::{c_char, c_int, c_void, size_t};
7+
use libc::{c_char, c_int, c_uint, c_void, size_t};
88
use llvm::{
99
LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols,
1010
};
11+
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
1112
use rustc_codegen_ssa::back::link::ensure_removed;
1213
use rustc_codegen_ssa::back::versioned_llvm_target;
1314
use rustc_codegen_ssa::back::write::{
@@ -28,7 +29,7 @@ use rustc_session::config::{
2829
use rustc_span::InnerSpan;
2930
use rustc_span::symbol::sym;
3031
use rustc_target::spec::{CodeModel, RelocModel, SanitizerSet, SplitDebuginfo, TlsModel};
31-
use tracing::debug;
32+
use tracing::{debug, trace};
3233

3334
use crate::back::lto::ThinBuffer;
3435
use crate::back::owned_target_machine::OwnedTargetMachine;
@@ -40,8 +41,14 @@ use crate::errors::{
4041
CopyBitcode, FromLlvmDiag, FromLlvmOptimizationDiag, LlvmError, UnknownCompression,
4142
WithLlvmError, WriteBytecode,
4243
};
44+
use crate::llvm::AttributePlace::Function;
4345
use crate::llvm::diagnostic::OptimizationDiagnosticKind::*;
44-
use crate::llvm::{self, DiagnosticInfo, PassManager};
46+
use crate::llvm::{
47+
self, AttributeKind, DiagnosticInfo, LLVMGetFirstFunction,
48+
LLVMGetNextFunction, LLVMGetStringAttributeAtIndex, LLVMIsEnumAttribute, LLVMIsStringAttribute,
49+
LLVMRemoveStringAttributeAtIndex, LLVMRustGetEnumAttributeAtIndex,
50+
LLVMRustRemoveEnumAttributeAtIndex, PassManager,
51+
};
4552
use crate::type_::Type;
4653
use crate::{LlvmCodegenBackend, ModuleLlvm, base, common, llvm_util};
4754

@@ -517,9 +524,38 @@ pub(crate) unsafe fn llvm_optimize(
517524
config: &ModuleConfig,
518525
opt_level: config::OptLevel,
519526
opt_stage: llvm::OptStage,
527+
skip_size_increasing_opts: bool,
520528
) -> Result<(), FatalError> {
521-
let unroll_loops =
522-
opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin;
529+
// Enzyme:
530+
// The whole point of compiler-based AD is to differentiate optimized IR instead of unoptimized
531+
// source code. However, benchmarks show that optimizations increasing the code size
532+
// tend to reduce AD performance. Therefore deactivate them before AD, then differentiate the code
533+
// and finally re-optimize the module, now with all optimizations available.
534+
// TODO: In a future update we could figure out how to only optimize individual functions getting
535+
// differentiated.
536+
537+
let unroll_loops;
538+
let vectorize_slp;
539+
let vectorize_loop;
540+
541+
// When we build rustc with enzyme/autodiff support, we want to postpone size-increasing
542+
// optimizations until after differentiation. FIXME(ZuseZ4): Before shipping on nightly,
543+
// we should make this more granular, or at least check that the user has at least one autodiff
544+
// call in their code, to justify altering the compilation pipeline.
545+
if skip_size_increasing_opts && cfg!(llvm_enzyme) {
546+
unroll_loops = false;
547+
vectorize_slp = false;
548+
vectorize_loop = false;
549+
} else {
550+
unroll_loops =
551+
opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin;
552+
vectorize_slp = config.vectorize_slp;
553+
vectorize_loop = config.vectorize_loop;
554+
}
555+
trace!(
556+
"Enzyme: Running with unroll_loops: {}, vectorize_slp: {}, vectorize_loop: {}",
557+
unroll_loops, vectorize_slp, vectorize_loop
558+
);
523559
let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed();
524560
let pgo_gen_path = get_pgo_gen_path(config);
525561
let pgo_use_path = get_pgo_use_path(config);
@@ -583,8 +619,8 @@ pub(crate) unsafe fn llvm_optimize(
583619
using_thin_buffers,
584620
config.merge_functions,
585621
unroll_loops,
586-
config.vectorize_slp,
587-
config.vectorize_loop,
622+
vectorize_slp,
623+
vectorize_loop,
588624
config.no_builtins,
589625
config.emit_lifetime_markers,
590626
sanitizer_options.as_ref(),
@@ -606,6 +642,115 @@ pub(crate) unsafe fn llvm_optimize(
606642
result.into_result().map_err(|()| llvm_err(dcx, LlvmError::RunLlvmPasses))
607643
}
608644

645+
pub(crate) fn differentiate(
646+
module: &ModuleCodegen<ModuleLlvm>,
647+
cgcx: &CodegenContext<LlvmCodegenBackend>,
648+
diff_items: Vec<AutoDiffItem>,
649+
config: &ModuleConfig,
650+
) -> Result<(), FatalError> {
651+
for item in &diff_items {
652+
trace!("{}", item);
653+
}
654+
655+
let llmod = module.module_llvm.llmod();
656+
let llcx = &module.module_llvm.llcx;
657+
let diag_handler = cgcx.create_dcx();
658+
659+
// Before dumping the module, we want all the tt to become part of the module.
660+
for item in diff_items.iter() {
661+
let name = CString::new(item.source.clone()).unwrap();
662+
let fn_def: Option<&llvm::Value> =
663+
unsafe { llvm::LLVMGetNamedFunction(llmod, name.as_ptr()) };
664+
let fn_def = match fn_def {
665+
Some(x) => x,
666+
None => {
667+
return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff {
668+
src: item.source.clone(),
669+
target: item.target.clone(),
670+
error: "could not find source function".to_owned(),
671+
}));
672+
}
673+
};
674+
let target_name = CString::new(item.target.clone()).unwrap();
675+
debug!("target name: {:?}", &target_name);
676+
let fn_target: Option<&llvm::Value> =
677+
unsafe { llvm::LLVMGetNamedFunction(llmod, target_name.as_ptr()) };
678+
let fn_target = match fn_target {
679+
Some(x) => x,
680+
None => {
681+
return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff {
682+
src: item.source.clone(),
683+
target: item.target.clone(),
684+
error: "could not find target function".to_owned(),
685+
}));
686+
}
687+
};
688+
689+
crate::builder::generate_enzyme_call(llmod, llcx, fn_def, fn_target, item.attrs.clone());
690+
}
691+
692+
// We needed the SanitizeHWAddress attribute to prevent LLVM from optimizing enums in a way
693+
// which Enzyme doesn't understand.
694+
unsafe {
695+
let mut f = LLVMGetFirstFunction(llmod);
696+
loop {
697+
if let Some(lf) = f {
698+
f = LLVMGetNextFunction(lf);
699+
let myhwattr = "enzyme_hw";
700+
let attr = LLVMGetStringAttributeAtIndex(
701+
lf,
702+
c_uint::MAX,
703+
myhwattr.as_ptr() as *const c_char,
704+
myhwattr.as_bytes().len() as c_uint,
705+
);
706+
if LLVMIsStringAttribute(attr) {
707+
LLVMRemoveStringAttributeAtIndex(
708+
lf,
709+
c_uint::MAX,
710+
myhwattr.as_ptr() as *const c_char,
711+
myhwattr.as_bytes().len() as c_uint,
712+
);
713+
} else {
714+
LLVMRustRemoveEnumAttributeAtIndex(
715+
lf,
716+
c_uint::MAX,
717+
AttributeKind::SanitizeHWAddress,
718+
);
719+
}
720+
} else {
721+
break;
722+
}
723+
}
724+
}
725+
726+
if let Some(opt_level) = config.opt_level {
727+
let opt_stage = match cgcx.lto {
728+
Lto::Fat => llvm::OptStage::PreLinkFatLTO,
729+
Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
730+
_ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
731+
_ => llvm::OptStage::PreLinkNoLTO,
732+
};
733+
// This is our second opt call, so now we run all opts,
734+
// to make sure we get the best performance.
735+
let skip_size_increasing_opts = false;
736+
trace!("running Module Optimization after differentiation");
737+
unsafe {
738+
llvm_optimize(
739+
cgcx,
740+
diag_handler.handle(),
741+
module,
742+
config,
743+
opt_level,
744+
opt_stage,
745+
skip_size_increasing_opts,
746+
)?
747+
};
748+
}
749+
trace!("done with differentiate()");
750+
751+
Ok(())
752+
}
753+
609754
// Unsafe due to LLVM calls.
610755
pub(crate) unsafe fn optimize(
611756
cgcx: &CodegenContext<LlvmCodegenBackend>,
@@ -628,14 +773,57 @@ pub(crate) unsafe fn optimize(
628773
unsafe { llvm::LLVMWriteBitcodeToFile(llmod, out.as_ptr()) };
629774
}
630775

776+
// This code enables Enzyme to differentiate code containing Rust enums.
777+
// By adding the SanitizeHWAddress attribute we prevent LLVM from optimizing
778+
// away the enums and allow Enzyme to understand why a value can be of different types in
779+
// different code sections. We remove this attribute after Enzyme is done, to not affect the
780+
// rest of the compilation.
781+
//#[cfg(llvm_enzyme)]
782+
unsafe {
783+
let mut f = LLVMGetFirstFunction(llmod);
784+
loop {
785+
if let Some(lf) = f {
786+
f = LLVMGetNextFunction(lf);
787+
let myhwattr = "enzyme_hw";
788+
let prevattr = LLVMRustGetEnumAttributeAtIndex(
789+
lf,
790+
c_uint::MAX,
791+
AttributeKind::SanitizeHWAddress,
792+
);
793+
if LLVMIsEnumAttribute(prevattr) {
794+
let attr = llvm::CreateAttrString(llcx, myhwattr);
795+
crate::attributes::apply_to_llfn(lf, Function, &[attr]);
796+
} else {
797+
let attr = AttributeKind::SanitizeHWAddress.create_attr(llcx);
798+
crate::attributes::apply_to_llfn(lf, Function, &[attr]);
799+
}
800+
} else {
801+
break;
802+
}
803+
}
804+
}
805+
631806
if let Some(opt_level) = config.opt_level {
632807
let opt_stage = match cgcx.lto {
633808
Lto::Fat => llvm::OptStage::PreLinkFatLTO,
634809
Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
635810
_ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
636811
_ => llvm::OptStage::PreLinkNoLTO,
637812
};
638-
return unsafe { llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage) };
813+
814+
// If we know that we will later run AD, then we disable vectorization and loop unrolling
815+
let skip_size_increasing_opts = cfg!(llvm_enzyme);
816+
return unsafe {
817+
llvm_optimize(
818+
cgcx,
819+
dcx,
820+
module,
821+
config,
822+
opt_level,
823+
opt_stage,
824+
skip_size_increasing_opts,
825+
)
826+
};
639827
}
640828
Ok(())
641829
}

0 commit comments

Comments
 (0)