diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index ee46b49a094c6..b7445ef0f3732 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -22,6 +22,8 @@ use rustc_middle::middle::exported_symbols::{SymbolExportInfo, SymbolExportLevel use rustc_session::config::{self, CrateType, Lto}; use tracing::{debug, info}; +use llvm::Linkage::*; + use crate::back::write::{ self, CodegenDiagnosticsStage, DiagnosticHandlers, bitcode_section_name, save_temp_bitcode, }; @@ -29,7 +31,7 @@ use crate::errors::{ DynamicLinkingWithLTO, LlvmError, LtoBitcodeFromRlib, LtoDisallowed, LtoDylib, LtoProcMacro, }; use crate::llvm::AttributePlace::Function; -use crate::llvm::{self, build_string}; +use crate::llvm::{self, build_string, Linkage}; use crate::{LlvmCodegenBackend, ModuleLlvm, SimpleCx, attributes}; /// We keep track of the computed LTO cache keys from the previous @@ -653,6 +655,7 @@ pub(crate) fn run_pass_manager( // We then run the llvm_optimize function a second time, to optimize the code which we generated // in the enzyme differentiation pass. 
let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable);
+    let enable_gpu = config.offload.contains(&config::Offload::Enable);
     let stage = if thin {
         write::AutodiffStage::PreAD
     } else {
@@ -667,6 +670,160 @@ pub(crate) fn run_pass_manager(
         write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
     }
 
+    if cfg!(llvm_enzyme) && enable_gpu && !thin {
+        // First, add the offload struct types and per-kernel globals to the host module:
+        // %struct.__tgt_offload_entry = type { i64, i16, i16, i32, ptr, ptr, i64, i64, ptr }
+        // %struct.__tgt_kernel_arguments = type { i32, i32, ptr, ptr, ptr, ptr, ptr, ptr, i64, i64, [3 x i32], [3 x i32], i32 }
+        let cx =
+            SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);
+        if cx.get_function("gen_tgt_offload").is_some() {
+            let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
+            let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
+            let tptr = cx.type_ptr();
+            let ti64 = cx.type_i64();
+            let ti32 = cx.type_i32();
+            let ti16 = cx.type_i16();
+            let tarr = cx.type_array(ti32, 3);
+
+            // Copied from LLVM:
+            // typedef struct {
+            //   uint64_t Reserved;
+            //   uint16_t Version;
+            //   uint16_t Kind;
+            //   uint32_t Flags;
+            //   void *Address;
+            //   char *SymbolName;
+            //   uint64_t Size;
+            //   uint64_t Data;
+            //   void *AuxAddr;
+            // } __tgt_offload_entry;
+            let entry_elements = [ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
+            let kernel_elements = [
+                ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32,
+            ];
+
+            cx.set_struct_body(offload_entry_ty, &entry_elements, false);
+            cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
+            // NOTE(review): presumably only here so both struct types show up in the emitted IR — confirm before merging.
+            cx.declare_global("my_struct_global", offload_entry_ty);
+            cx.declare_global("my_struct_global2", kernel_arguments_ty);
+//@my_struct_global = external global %struct.__tgt_offload_entry
+//@my_struct_global2 = external global
%struct.__tgt_kernel_arguments
+            debug!("offload entry type: {:?}", offload_entry_ty);
+            debug!("kernel arguments type: {:?}", kernel_arguments_ty);
+            //LLVMTypeRef elements[9] = {i64Ty, i16Ty, i16Ty, i32Ty, ptrTy, ptrTy, i64Ty, i64Ty, ptrTy};
+            //LLVMStructSetBody(structTy, elements, 9, 0);
+            debug!("created offload struct types");
+            for num in 0..9 {
+                if cx.get_function(&format!("kernel_{num}")).is_none() {
+                    continue;
+                }
+
+                /// Adds `vals` as a private, `unnamed_addr`, constant `[N x i64]` global called `name`.
+                fn add_priv_unnamed_arr<'ll>(cx: &SimpleCx<'ll>, name: &str, vals: &[u64]) -> &'ll llvm::Value {
+                    let ti64 = cx.type_i64();
+                    let elems: Vec<_> = vals.iter().map(|&val| cx.get_const_i64(val)).collect();
+                    let initializer = cx.const_array(ti64, &elems);
+                    add_unnamed_global(cx, name, initializer, PrivateLinkage)
+                }
+
+                /// Adds a constant global `name` with linkage `l`, typed after and set to `initializer`.
+                fn add_global<'ll>(cx: &SimpleCx<'ll>, name: &str, initializer: &'ll llvm::Value, l: Linkage) -> &'ll llvm::Value {
+                    let c_name = CString::new(name).unwrap();
+                    let llglobal: &'ll llvm::Value = llvm::add_global(cx.llmod, cx.val_ty(initializer), &c_name);
+                    llvm::set_global_constant(llglobal, true);
+                    llvm::set_linkage(llglobal, l);
+                    llvm::set_initializer(llglobal, initializer);
+                    llglobal
+                }
+
+                /// Like `add_global`, but additionally marks the new global as `unnamed_addr`.
+                fn add_unnamed_global<'ll>(cx: &SimpleCx<'ll>, name: &str, initializer: &'ll llvm::Value, l: Linkage) -> &'ll llvm::Value {
+                    let llglobal = add_global(cx, name, initializer, l);
+                    // SAFETY: `llglobal` is the global value that `add_global` just created.
+                    unsafe { llvm::LLVMSetUnnamedAddress(llglobal, llvm::UnnamedAddr::Global) };
+                    llglobal
+                }
+
+                // We add a pair of sizes and maptypes per offloadable function.
+                // @.offload_maptypes = private unnamed_addr constant [4 x i64] [i64 800, i64 544, i64 547, i64 544]
+                let _o_sizes = add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{num}"), &[8u64, 0, 16, 0]);
+                let _o_types = add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{num}"), &[800u64, 544, 547, 544]);
+                // TODO: We should add another pair per call to offloadable functions
+                // @.offload_sizes.5 = private unnamed_addr constant [2 x i64] [i64 16384, i64 16384]
+                // @.offload_maptypes.6 = private unnamed_addr constant [2 x i64] [i64 1, i64 3]
+
+                // Next: For each function, generate these three entries. A weak constant,
+                // the llvm.rodata entry name, and the omp_offloading_entries value
+
+                // @.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7.region_id = weak constant i8 0
+                // @.offloading.entry_name = internal unnamed_addr constant [66 x i8] c"__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7\00", section ".llvm.rodata.offloading", align 1
+                let name = format!(".kernel_{num}.region_id");
+                let initializer = cx.get_const_i8(0);
+                // Like the clang output above, the region id is not `unnamed_addr` (its address stays significant).
+                let region_id = add_global(&cx, &name, initializer, WeakAnyLinkage);
+
+                let c_entry_name = CString::new(format!("kernel_{num}")).unwrap();
+                let c_val = c_entry_name.as_bytes_with_nul();
+                let entry_name_sym = format!(".offloading.entry_name.{num}");
+
+                let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
+                let entry_name = add_unnamed_global(&cx, &entry_name_sym, initializer, InternalLinkage);
+                llvm::set_alignment(entry_name, rustc_abi::Align::ONE);
+                let c_section_name = CString::new(".llvm.rodata.offloading").unwrap();
+                llvm::set_section(entry_name, &c_section_name);
+
+                // Field order follows `__tgt_offload_entry`. TODO: cleanup
+                let name = format!(".offloading.entry.kernel_{num}");
+                let ci64_0 = cx.get_const_i64(0);
+                let ci16_1 = cx.get_const_i16(1);
+                let elems = [ci64_0, ci16_1, ci16_1, cx.get_const_i32(0), region_id, entry_name, ci64_0, ci64_0, cx.const_null(cx.type_ptr())];
+
+                let initializer =
crate::common::named_struct(offload_entry_ty, &elems);
+                // `add_global` creates the constant typed after `initializer` (that is, as
+                // `offload_entry_ty`) and sets its linkage and initializer.
+                let llglobal = add_global(&cx, &name, initializer, WeakAnyLinkage);
+                llvm::set_alignment(llglobal, rustc_abi::Align::ONE);
+                // No leading dot, matching clang (see below). NOTE(review): ELF linkers only
+                // provide `__start_<sec>`/`__stop_<sec>` for C-identifier section names — confirm.
+                let c_section_name = CString::new("omp_offloading_entries").unwrap();
+                llvm::set_section(llglobal, &c_section_name);
+                // rustc
+                // @.offloading.entry.kernel_3 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.kernel_3.region_id, ptr @.offloading.entry_name.3, i64 0, i64 0, ptr null }, section "omp_offloading_entries", align 1
+                // clang
+                // @.offloading.entry.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7.region_id, ptr @.offloading.entry_name, i64 0, i64 0, ptr null }, section "omp_offloading_entries", align 1
+
+                //
+                // enum Flags {
+                //   OMP_REGISTER_REQUIRES = 0x10,
+                // };
+                //
+                // typedef struct {
+                //   void *ImageStart;
+                //   void *ImageEnd;
+                //   __tgt_offload_entry *EntriesBegin;
+                //   __tgt_offload_entry *EntriesEnd;
+                // } __tgt_device_image;
+                //
+                // typedef struct {
+                //   int32_t NumDeviceImages;
+                //   __tgt_device_image *DeviceImages;
+                //   __tgt_offload_entry *HostEntriesBegin;
+                //   __tgt_offload_entry *HostEntriesEnd;
+                // } __tgt_bin_desc;
+                // 1. @.offload_sizes.{num} = private unnamed_addr constant [4 x i64] [i64 8, i64 0, i64 16, i64 0]
+                // 2. @.offload_maptypes
+                // 3. @.__omp_offloading__fnc_name_ = weak constant i8 0
+                // 4. @.offloading.entry_name = internal unnamed_addr constant [66 x i8] c"__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7\00", section ".llvm.rodata.offloading", align 1
+                // 5.
@.offloading.entry.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7.region_id, ptr @.offloading.entry_name, i64 0, i64 0, ptr null }, section "omp_offloading_entries", align 1
+            }
+        } else {
+            debug!("no `gen_tgt_offload` marker function found, skipping offload globals");
+        }
+    } else {
+        debug!("offload not enabled for this module, not creating offload globals");
+    }
+
     if cfg!(llvm_enzyme) && enable_ad && !thin {
         let cx =
             SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);
diff --git a/compiler/rustc_codegen_llvm/src/common.rs b/compiler/rustc_codegen_llvm/src/common.rs
index 3cfa96393e920..8bf6a059d8195 100644
--- a/compiler/rustc_codegen_llvm/src/common.rs
+++ b/compiler/rustc_codegen_llvm/src/common.rs
@@ -99,14 +99,14 @@ impl<'ll, CX: Borrow<SCx<'ll>>> BackendTypes for GenericCx<'ll, CX> {
     type DIVariable = &'ll llvm::debuginfo::DIVariable;
 }
 
-impl<'ll> CodegenCx<'ll, '_> {
+impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
     pub(crate) fn const_array(&self, ty: &'ll Type, elts: &[&'ll Value]) -> &'ll Value {
         let len = u64::try_from(elts.len()).expect("LLVMConstArray2 elements len overflow");
         unsafe { llvm::LLVMConstArray2(ty, elts.as_ptr(), len) }
     }
 
     pub(crate) fn const_bytes(&self, bytes: &[u8]) -> &'ll Value {
-        bytes_in_context(self.llcx, bytes)
+        bytes_in_context(self.llcx(), bytes)
     }
 
     pub(crate) fn const_get_elt(&self, v: &'ll Value, idx: u64) -> &'ll Value {
@@ -119,6 +119,11 @@ impl<'ll> CodegenCx<'ll, '_> {
             r
         }
     }
+
+    /// Returns the null constant of type `t` (wraps `LLVMConstNull`).
+    pub(crate) fn const_null(&self, t: &'ll Type) -> &'ll Value {
+        unsafe { llvm::LLVMConstNull(t) }
+    }
 }
 
 impl<'ll, 'tcx> ConstCodegenMethods for CodegenCx<'ll, 'tcx> {
@@ -373,6 +378,13 @@ pub(crate) fn bytes_in_context<'ll>(llcx: &'ll llvm::Context, bytes: &[u8]) -> &
     }
 }
 
+/// Creates a constant of the named struct type `ty` with `elts` as its field values
+/// (wraps `LLVMConstNamedStruct`).
+pub(crate) fn named_struct<'ll>(ty: &'ll Type, elts: &[&'ll Value]) -> &'ll Value {
+    let len = c_uint::try_from(elts.len()).expect("LLVMConstNamedStruct elements len overflow");
+
unsafe { llvm::LLVMConstNamedStruct(ty, elts.as_ptr(), len) }
+}
+
 fn struct_in_context<'ll>(
     llcx: &'ll llvm::Context,
     elts: &[&'ll Value],
diff --git a/compiler/rustc_codegen_llvm/src/context.rs b/compiler/rustc_codegen_llvm/src/context.rs
index 8d6e1d8941b72..ef29494a3737d 100644
--- a/compiler/rustc_codegen_llvm/src/context.rs
+++ b/compiler/rustc_codegen_llvm/src/context.rs
@@ -685,6 +685,24 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
         unsafe { llvm::LLVMConstInt(ty, n, llvm::False) }
     }
 
+    /// Returns the constant `n` as an LLVM `i32` in this context.
+    pub(crate) fn get_const_i32(&self, n: u64) -> &'ll Value {
+        let ty = unsafe { llvm::LLVMInt32TypeInContext(self.llcx()) };
+        unsafe { llvm::LLVMConstInt(ty, n, llvm::False) }
+    }
+
+    /// Returns the constant `n` as an LLVM `i16` in this context.
+    pub(crate) fn get_const_i16(&self, n: u64) -> &'ll Value {
+        let ty = unsafe { llvm::LLVMInt16TypeInContext(self.llcx()) };
+        unsafe { llvm::LLVMConstInt(ty, n, llvm::False) }
+    }
+
+    /// Returns the constant `n` as an LLVM `i8` in this context.
+    pub(crate) fn get_const_i8(&self, n: u64) -> &'ll Value {
+        let ty = unsafe { llvm::LLVMInt8TypeInContext(self.llcx()) };
+        unsafe { llvm::LLVMConstInt(ty, n, llvm::False) }
+    }
+
     pub(crate) fn get_function(&self, name: &str) -> Option<&'ll Value> {
         let name = SmallCStr::new(name);
         unsafe { llvm::LLVMGetNamedFunction((**self).borrow().llmod, name.as_ptr()) }
diff --git a/compiler/rustc_codegen_llvm/src/declare.rs b/compiler/rustc_codegen_llvm/src/declare.rs
index 2419ec1f88854..12e7f45be41f7 100644
--- a/compiler/rustc_codegen_llvm/src/declare.rs
+++ b/compiler/rustc_codegen_llvm/src/declare.rs
@@ -215,7 +215,9 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
 
         llfn
     }
+}
 
+impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
     /// Declare a global with an intention to define it.
     ///
     /// Use this function when you intend to define a global. This function will
@@ -234,13 +236,13 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
     ///
     /// Use this function when you intend to define a global without a name.
pub(crate) fn define_private_global(&self, ty: &'ll Type) -> &'ll Value {
-        unsafe { llvm::LLVMRustInsertPrivateGlobal(self.llmod, ty) }
+        unsafe { llvm::LLVMRustInsertPrivateGlobal(self.llmod(), ty) }
     }
 
     /// Gets declared value by name.
     pub(crate) fn get_declared_value(&self, name: &str) -> Option<&'ll Value> {
         debug!("get_declared_value(name={:?})", name);
-        unsafe { llvm::LLVMRustGetNamedValue(self.llmod, name.as_c_char_ptr(), name.len()) }
+        unsafe { llvm::LLVMRustGetNamedValue(self.llmod(), name.as_c_char_ptr(), name.len()) }
     }
 
     /// Gets defined or externally defined (AvailableExternally linkage) value by
diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs
index e27fbf94f341d..b4d38d79e2d67 100644
--- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs
+++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs
@@ -1139,6 +1139,11 @@ unsafe extern "C" {
         Count: c_uint,
         Packed: Bool,
     ) -> &'a Value;
+    pub(crate) fn LLVMConstNamedStruct<'a>(
+        StructTy: &'a Type,
+        ConstantVals: *const &'a Value,
+        Count: c_uint,
+    ) -> &'a Value;
     pub(crate) fn LLVMConstVector(ScalarConstantVals: *const &Value, Size: c_uint) -> &Value;
 
     // Constant expressions
diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs
index a41ca8ce28bce..91d16f6b99e1c 100644
--- a/compiler/rustc_codegen_ssa/src/back/write.rs
+++ b/compiler/rustc_codegen_ssa/src/back/write.rs
@@ -121,6 +121,7 @@ pub struct ModuleConfig {
     pub emit_lifetime_markers: bool,
     pub llvm_plugins: Vec<String>,
     pub autodiff: Vec<config::AutoDiff>,
+    pub offload: Vec<config::Offload>,
 }
 
 impl ModuleConfig {
@@ -270,6 +271,7 @@ impl ModuleConfig {
             emit_lifetime_markers: sess.emit_lifetime_markers(),
             llvm_plugins: if_regular!(sess.opts.unstable_opts.llvm_plugins.clone(), vec![]),
             autodiff: if_regular!(sess.opts.unstable_opts.autodiff.clone(), vec![]),
+            offload: if_regular!(sess.opts.unstable_opts.offload.clone(), vec![]),
         }
     }
 
diff --git a/compiler/rustc_session/src/config.rs
b/compiler/rustc_session/src/config.rs
index 60e1b465ba96d..2c2f82461ec6f 100644
--- a/compiler/rustc_session/src/config.rs
+++ b/compiler/rustc_session/src/config.rs
@@ -226,6 +226,13 @@ pub enum CoverageLevel {
     Mcdc,
 }
 
+/// The different settings that the `-Z offload` flag can have.
+#[derive(Clone, Copy, PartialEq, Hash, Debug)]
+pub enum Offload {
+    /// Enable the LLVM offload pipeline.
+    Enable,
+}
+
 /// The different settings that the `-Z autodiff` flag can have.
 #[derive(Clone, Copy, PartialEq, Hash, Debug)]
 pub enum AutoDiff {
@@ -3061,7 +3068,7 @@ pub(crate) mod dep_tracking {
     };
 
     use super::{
-        AutoDiff, BranchProtection, CFGuard, CFProtection, CollapseMacroDebuginfo, CoverageOptions,
+        AutoDiff, Offload, BranchProtection, CFGuard, CFProtection, CollapseMacroDebuginfo, CoverageOptions,
         CrateType, DebugInfo, DebugInfoCompression, ErrorOutputType, FmtDebug, FunctionReturn,
         InliningThreshold, InstrumentCoverage, InstrumentXRay, LinkerPluginLto, LocationDetail,
         LtoCli, MirStripDebugInfo, NextSolverConfig, OomStrategy, OptLevel, OutFileName,
@@ -3110,6 +3117,7 @@ pub(crate) mod dep_tracking {
 
     impl_dep_tracking_hash_via_hash!(
         AutoDiff,
+        Offload,
         bool,
         usize,
         NonZero<usize>,
diff --git a/compiler/rustc_session/src/options.rs b/compiler/rustc_session/src/options.rs
index 5b4068740a159..4f9e7b4d9a88d 100644
--- a/compiler/rustc_session/src/options.rs
+++ b/compiler/rustc_session/src/options.rs
@@ -712,6 +712,7 @@ mod desc {
     pub(crate) const parse_list_with_polarity: &str =
         "a comma-separated list of strings, with elements beginning with + or -";
     pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`";
+    pub(crate) const parse_offload: &str = "a comma separated list of settings: `Enable`";
     pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
     pub(crate) const
parse_opt_comma_list: &str = parse_comma_list;
     pub(crate) const parse_number: &str = "a number";
@@ -1343,6 +1344,36 @@ pub mod parse {
         }
     }
 
+    /// Parses the comma-separated `-Z offload` settings in `v` and appends them to `slot`.
+    ///
+    /// A `None` value resets `slot` to empty. Returns `false` on the first unrecognized
+    /// setting, in which case `slot` is left unchanged.
+    pub(crate) fn parse_offload(slot: &mut Vec<Offload>, v: Option<&str>) -> bool {
+        let Some(v) = v else {
+            *slot = vec![];
+            return true;
+        };
+        // Sort and dedup so the stored list does not depend on the order or repetition of
+        // the settings within one flag value.
+        let mut v: Vec<&str> = v.split(',').collect();
+        v.sort_unstable();
+        v.dedup();
+        let mut parsed = Vec::with_capacity(v.len());
+        for val in v {
+            let variant = match val {
+                "Enable" => Offload::Enable,
+                _ => {
+                    // FIXME(ZuseZ4): print an error saying which value is not recognized
+                    return false;
+                }
+            };
+            parsed.push(variant);
+        }
+        slot.extend(parsed);
+
+        true
+    }
+
     pub(crate) fn parse_autodiff(slot: &mut Vec<AutoDiff>, v: Option<&str>) -> bool {
         let Some(v) = v else {
             *slot = vec![];
@@ -2372,6 +2403,11 @@ options! {
         "do not use unique names for text and data sections when -Z function-sections is used"),
     normalize_docs: bool = (false, parse_bool, [TRACKED],
         "normalize associated items in rustdoc when generating documentation"),
+    offload: Vec<crate::config::Offload> = (Vec::new(), parse_offload, [TRACKED],
+        "a list of offload flags to enable
+        Mandatory setting:
+        `=Enable`
+        Currently the only option available"),
     on_broken_pipe: OnBrokenPipe = (OnBrokenPipe::Default, parse_on_broken_pipe, [TRACKED],
         "behavior of std::io::ErrorKind::BrokenPipe (SIGPIPE)"),
     oom: OomStrategy = (OomStrategy::Abort, parse_oom_strategy, [TRACKED],