Skip to content

gpu offload host code generation #142097

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 152 additions & 1 deletion compiler/rustc_codegen_llvm/src/back/lto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,16 @@ use rustc_middle::middle::exported_symbols::{SymbolExportInfo, SymbolExportLevel
use rustc_session::config::{self, CrateType, Lto};
use tracing::{debug, info};

use llvm::Linkage::*;

use crate::back::write::{
self, CodegenDiagnosticsStage, DiagnosticHandlers, bitcode_section_name, save_temp_bitcode,
};
use crate::errors::{
DynamicLinkingWithLTO, LlvmError, LtoBitcodeFromRlib, LtoDisallowed, LtoDylib, LtoProcMacro,
};
use crate::llvm::AttributePlace::Function;
use crate::llvm::{self, build_string};
use crate::llvm::{self, build_string, Linkage};
use crate::{LlvmCodegenBackend, ModuleLlvm, SimpleCx, attributes};

/// We keep track of the computed LTO cache keys from the previous
Expand Down Expand Up @@ -653,6 +655,7 @@ pub(crate) fn run_pass_manager(
// We then run the llvm_optimize function a second time, to optimize the code which we generated
// in the enzyme differentiation pass.
let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable);
let enable_gpu = true;//config.offload.contains(&config::Offload::Enable);
let stage = if thin {
write::AutodiffStage::PreAD
} else {
Expand All @@ -667,6 +670,154 @@ pub(crate) fn run_pass_manager(
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
}

if cfg!(llvm_enzyme) && enable_gpu && !thin {
// first we need to add all the fun to the host module
// %struct.__tgt_offload_entry = type { i64, i16, i16, i32, ptr, ptr, i64, i64, ptr }
// %struct.__tgt_kernel_arguments = type { i32, i32, ptr, ptr, ptr, ptr, ptr, ptr, i64, i64, [3 x i32], [3 x i32], i32 }
let cx =
SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);
if cx.get_function("gen_tgt_offload").is_some() {
let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
let tptr = cx.type_ptr();
let ti64 = cx.type_i64();
let ti32 = cx.type_i32();
let ti16 = cx.type_i16();
let ti8 = cx.type_i8();
let tarr = cx.type_array(ti32, 3);

let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
let kernel_elements = vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];

cx.set_struct_body(offload_entry_ty, &entry_elements, false);
cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
let global = cx.declare_global("my_struct_global", offload_entry_ty);
let global = cx.declare_global("my_struct_global2", kernel_arguments_ty);
dbg!(&offload_entry_ty);
dbg!(&kernel_arguments_ty);
//LLVMTypeRef elements[9] = {i64Ty, i16Ty, i16Ty, i32Ty, ptrTy, ptrTy, i64Ty, i64Ty, ptrTy};
//LLVMStructSetBody(structTy, elements, 9, 0);
dbg!("created struct");
for num in 0..5 {
if !cx.get_function(&format!("kernel_{num}")).is_some() {
continue;
}

fn add_priv_unnamed_arr<'ll>(cx: &SimpleCx<'ll>, name: &str, vals: &[u64]) -> &'ll llvm::Value{
let ti64 = cx.type_i64();
let size_ty = cx.type_array(ti64, vals.len() as u64);
let mut size_val = Vec::with_capacity(vals.len());
for &val in vals {
size_val.push(cx.get_const_i64(val));
}
let initializer = cx.const_array(ti64, &size_val);
add_global(cx, name, initializer, PrivateLinkage)
}

fn add_global<'ll>(cx: &SimpleCx<'ll>, name: &str, initializer: &'ll llvm::Value, l: Linkage) -> &'ll llvm::Value {
let c_name = CString::new(name).unwrap();
let llglobal: &'ll llvm::Value = llvm::add_global(cx.llmod, cx.val_ty(initializer), &c_name);
llvm::set_global_constant(llglobal, true);
unsafe {llvm::LLVMSetUnnamedAddress(llglobal, llvm::UnnamedAddr::Global)};
llvm::set_linkage(llglobal, l);
llvm::set_initializer(llglobal, initializer);
llglobal
}

// We add a pair of sizes and maptypes per offloadable function.
// @.offload_maptypes = private unnamed_addr constant [4 x i64] [i64 800, i64 544, i64 547, i64 544]
let o_sizes = add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{num}"), &vec![8u64,0,16,0]);
let o_types = add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{num}"), &vec![800u64, 544, 547, 544]);
// TODO: We should add another pair per call to offloadable functions
// @.offload_sizes.5 = private unnamed_addr constant [2 x i64] [i64 16384, i64 16384]
// @.offload_maptypes.6 = private unnamed_addr constant [2 x i64] [i64 1, i64 3]

// Next: For each function, generate these three entries. A weak constant,
// the llvm.rodata entry name, and the omp_offloading_entries value

// @.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7.region_id = weak constant i8 0
let name = format!(".kernel_{num}.region_id");
let initializer = cx.get_const_i8(0);
add_global(&cx, &name, initializer, WeakAnyLinkage);

let entry_name = format!("kernel_{num}");
let c_entry_name = CString::new(entry_name).unwrap();
let c_val = c_entry_name.as_bytes_with_nul();
let foo = format!(".offloading.entry_name.{num}");

let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
let llglobal = add_global(&cx, &foo, initializer, InternalLinkage);
llvm::set_alignment(llglobal, rustc_abi::Align::ONE);
let c_section_name = CString::new(".llvm.rodata.offloading").unwrap();
llvm::set_section(llglobal, &c_section_name);
// @.offloading.entry_name = internal unnamed_addr constant [66 x i8] c"__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7\00", section ".llvm.rodata.offloading", align 1
// @.offloading.entry.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7.region_id, ptr @.offloading.entry_name, i64 0, i64 0, ptr null }, section "omp_offloading_entries", align 1

// typedef struct {
// uint64_t Reserved;
// uint16_t Version;
// uint16_t Kind;
// uint32_t Flags;
// void *Address;
// char *SymbolName;
// uint64_t Size;
// uint64_t Data;
// void *AuxAddr;
// } __tgt_offload_entry;
//
// enum Flags {
// OMP_REGISTER_REQUIRES = 0x10,
// };
//
// typedef struct {
// void *ImageStart;
// void *ImageEnd;
// __tgt_offload_entry *EntriesBegin;
// __tgt_offload_entry *EntriesEnd;
// } __tgt_device_image;
//
// typedef struct {
// int32_t NumDeviceImages;
// __tgt_device_image *DeviceImages;
// __tgt_offload_entry *HostEntriesBegin;
// __tgt_offload_entry *HostEntriesEnd;
// } __tgt_bin_desc;
// 1. @.offload_sizes.{num} = private unnamed_addr constant [4 x i64] [i64 8, i64 0, i64 16, i64 0]
// 2. @.offload_maptypes
// 3. @.__omp_offloading_<hash>_fnc_name_<hash> = weak constant i8 0
// 4. @.offloading.entry_name = internal unnamed_addr constant [66 x i8] c"__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7\00", section ".llvm.rodata.offloading", align 1
// 5. @.offloading.entry.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7.region_id, ptr @.offloading.entry_name, i64 0, i64 0, ptr null }, section "omp_offloading_entries", align 1
}
// @.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7.region_id = weak constant i8 0
// @.offload_sizes = private unnamed_addr constant [4 x i64] [i64 8, i64 0, i64 16, i64 0]
// @.offload_maptypes = private unnamed_addr constant [4 x i64] [i64 800, i64 544, i64 547, i64 544]
// @.__omp_offloading_86fafab6_c40006a1__Z3barPSt7complexIdES1_S0_m_l13.region_id = weak constant i8 0
// @.offload_sizes.1 = private unnamed_addr constant [4 x i64] [i64 8, i64 0, i64 16, i64 0]
// @.offload_maptypes.2 = private unnamed_addr constant [4 x i64] [i64 800, i64 544, i64 547, i64 544]
// @.__omp_offloading_86fafab6_c40006a1__Z5zaxpyPSt7complexIdES1_S0_m_l19.region_id = weak constant i8 0
// @.offload_sizes.3 = private unnamed_addr constant [4 x i64] [i64 8, i64 0, i64 16, i64 0]
// @.offload_maptypes.4 = private unnamed_addr constant [4 x i64] [i64 800, i64 544, i64 547, i64 544]
// @.offload_sizes.5 = private unnamed_addr constant [2 x i64] [i64 16384, i64 16384]
// @.offload_maptypes.6 = private unnamed_addr constant [2 x i64] [i64 1, i64 3]
// @_ZSt4cout = external global %"class.std::basic_ostream", align 8
// @.str = private unnamed_addr constant [3 x i8] c"hi\00", align 1
// @.offload_sizes.7 = private unnamed_addr constant [2 x i64] [i64 16384, i64 16384]
// @.offload_maptypes.8 = private unnamed_addr constant [2 x i64] [i64 1, i64 3]
// @.str.9 = private unnamed_addr constant [3 x i8] c"ho\00", align 1
// @.offloading.entry_name = internal unnamed_addr constant [66 x i8] c"__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7\00", section ".llvm.rodata.offloading", align 1
// @.offloading.entry.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7.region_id, ptr @.offloading.entry_name, i64 0, i64 0, ptr null }, section "omp_offloading_entries", align 1
// @.offloading.entry_name.10 = internal unnamed_addr constant [67 x i8] c"__omp_offloading_86fafab6_c40006a1__Z3barPSt7complexIdES1_S0_m_l13\00", section ".llvm.rodata.offloading", align 1
// @.offloading.entry.__omp_offloading_86fafab6_c40006a1__Z3barPSt7complexIdES1_S0_m_l13 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.__omp_offloading_86fafab6_c40006a1__Z3barPSt7complexIdES1_S0_m_l13.region_id, ptr @.offloading.entry_name.10, i64 0, i64 0, ptr null }, section "omp_offloading_entries", align 1
// @.offloading.entry_name.11 = internal unnamed_addr constant [69 x i8] c"__omp_offloading_86fafab6_c40006a1__Z5zaxpyPSt7complexIdES1_S0_m_l19\00", section ".llvm.rodata.offloading", align 1
// @.offloading.entry.__omp_offloading_86fafab6_c40006a1__Z5zaxpyPSt7complexIdES1_S0_m_l19 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.__omp_offloading_86fafab6_c40006a1__Z5zaxpyPSt7complexIdES1_S0_m_l19.region_id, ptr @.offloading.entry_name.11, i64 0, i64 0, ptr null }, section "omp_offloading_entries", align 1
// @llvm.global_ctors = appending global [1 x { i32, ptr, ptr }] [{ i32, ptr, ptr } { i32 65535, ptr @_GLOBAL__sub_I_zaxpy.cpp, ptr null }]
} else {
dbg!("no marker found");
}
} else {
dbg!("Not creating struct");
}

if cfg!(llvm_enzyme) && enable_ad && !thin {
let cx =
SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_codegen_llvm/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,14 @@ impl<'ll, CX: Borrow<SCx<'ll>>> BackendTypes for GenericCx<'ll, CX> {
type DIVariable = &'ll llvm::debuginfo::DIVariable;
}

impl<'ll> CodegenCx<'ll, '_> {
impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
pub(crate) fn const_array(&self, ty: &'ll Type, elts: &[&'ll Value]) -> &'ll Value {
let len = u64::try_from(elts.len()).expect("LLVMConstArray2 elements len overflow");
unsafe { llvm::LLVMConstArray2(ty, elts.as_ptr(), len) }
}

pub(crate) fn const_bytes(&self, bytes: &[u8]) -> &'ll Value {
bytes_in_context(self.llcx, bytes)
bytes_in_context(self.llcx(), bytes)
}

pub(crate) fn const_get_elt(&self, v: &'ll Value, idx: u64) -> &'ll Value {
Expand Down
5 changes: 5 additions & 0 deletions compiler/rustc_codegen_llvm/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,11 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
unsafe { llvm::LLVMConstInt(ty, n, llvm::False) }
}

pub(crate) fn get_const_i8(&self, n: u64) -> &'ll Value {
let ty = unsafe { llvm::LLVMInt8TypeInContext(self.llcx()) };
unsafe { llvm::LLVMConstInt(ty, n, llvm::False) }
}

pub(crate) fn get_function(&self, name: &str) -> Option<&'ll Value> {
let name = SmallCStr::new(name);
unsafe { llvm::LLVMGetNamedFunction((**self).borrow().llmod, name.as_ptr()) }
Expand Down
7 changes: 5 additions & 2 deletions compiler/rustc_codegen_llvm/src/declare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
)
}
}

}

impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
Expand Down Expand Up @@ -215,7 +216,9 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {

llfn
}
}

impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
/// Declare a global with an intention to define it.
///
/// Use this function when you intend to define a global. This function will
Expand All @@ -234,13 +237,13 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
///
/// Use this function when you intend to define a global without a name.
pub(crate) fn define_private_global(&self, ty: &'ll Type) -> &'ll Value {
unsafe { llvm::LLVMRustInsertPrivateGlobal(self.llmod, ty) }
unsafe { llvm::LLVMRustInsertPrivateGlobal(self.llmod(), ty) }
}

/// Gets declared value by name.
pub(crate) fn get_declared_value(&self, name: &str) -> Option<&'ll Value> {
debug!("get_declared_value(name={:?})", name);
unsafe { llvm::LLVMRustGetNamedValue(self.llmod, name.as_c_char_ptr(), name.len()) }
unsafe { llvm::LLVMRustGetNamedValue(self.llmod(), name.as_c_char_ptr(), name.len()) }
}

/// Gets defined or externally defined (AvailableExternally linkage) value by
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_codegen_ssa/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ pub struct ModuleConfig {
pub emit_lifetime_markers: bool,
pub llvm_plugins: Vec<String>,
pub autodiff: Vec<config::AutoDiff>,
pub offload: Vec<config::Offload>,
}

impl ModuleConfig {
Expand Down Expand Up @@ -270,6 +271,7 @@ impl ModuleConfig {
emit_lifetime_markers: sess.emit_lifetime_markers(),
llvm_plugins: if_regular!(sess.opts.unstable_opts.llvm_plugins.clone(), vec![]),
autodiff: if_regular!(sess.opts.unstable_opts.autodiff.clone(), vec![]),
offload: if_regular!(sess.opts.unstable_opts.offload.clone(), vec![]),
}
}

Expand Down
10 changes: 9 additions & 1 deletion compiler/rustc_session/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,13 @@ pub enum CoverageLevel {
Mcdc,
}

// The different settings that the `-Z offload` flag can have.
#[derive(Clone, Copy, PartialEq, Hash, Debug)]
pub enum Offload {
/// Enable the llvm offload pipeline
Enable,
}

/// The different settings that the `-Z autodiff` flag can have.
#[derive(Clone, Copy, PartialEq, Hash, Debug)]
pub enum AutoDiff {
Expand Down Expand Up @@ -3061,7 +3068,7 @@ pub(crate) mod dep_tracking {
};

use super::{
AutoDiff, BranchProtection, CFGuard, CFProtection, CollapseMacroDebuginfo, CoverageOptions,
AutoDiff, Offload, BranchProtection, CFGuard, CFProtection, CollapseMacroDebuginfo, CoverageOptions,
CrateType, DebugInfo, DebugInfoCompression, ErrorOutputType, FmtDebug, FunctionReturn,
InliningThreshold, InstrumentCoverage, InstrumentXRay, LinkerPluginLto, LocationDetail,
LtoCli, MirStripDebugInfo, NextSolverConfig, OomStrategy, OptLevel, OutFileName,
Expand Down Expand Up @@ -3110,6 +3117,7 @@ pub(crate) mod dep_tracking {

impl_dep_tracking_hash_via_hash!(
AutoDiff,
Offload,
bool,
usize,
NonZero<usize>,
Expand Down
27 changes: 27 additions & 0 deletions compiler/rustc_session/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,7 @@ mod desc {
pub(crate) const parse_list_with_polarity: &str =
"a comma-separated list of strings, with elements beginning with + or -";
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`";
pub(crate) const parse_offload: &str = "a comma separated list of settings: `Enable`";
pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
pub(crate) const parse_number: &str = "a number";
Expand Down Expand Up @@ -1343,6 +1344,27 @@ pub mod parse {
}
}

pub(crate) fn parse_offload(slot: &mut Vec<Offload>, v: Option<&str>) -> bool {
let Some(v) = v else {
*slot = vec![];
return true;
};
let mut v: Vec<&str> = v.split(",").collect();
v.sort_unstable();
for &val in v.iter() {
let variant = match val {
"Enable" => Offload::Enable,
_ => {
// FIXME(ZuseZ4): print an error saying which value is not recognized
return false;
}
};
slot.push(variant);
}

true
}

pub(crate) fn parse_autodiff(slot: &mut Vec<AutoDiff>, v: Option<&str>) -> bool {
let Some(v) = v else {
*slot = vec![];
Expand Down Expand Up @@ -2372,6 +2394,11 @@ options! {
"do not use unique names for text and data sections when -Z function-sections is used"),
normalize_docs: bool = (false, parse_bool, [TRACKED],
"normalize associated items in rustdoc when generating documentation"),
offload: Vec<crate::config::Offload> = (Vec::new(), parse_offload, [TRACKED],
"a list of offload flags to enable
Mandatory setting:
`=Enable`
Currently the only option available"),
on_broken_pipe: OnBrokenPipe = (OnBrokenPipe::Default, parse_on_broken_pipe, [TRACKED],
"behavior of std::io::ErrorKind::BrokenPipe (SIGPIPE)"),
oom: OomStrategy = (OomStrategy::Abort, parse_oom_strategy, [TRACKED],
Expand Down
Loading