Skip to content

Commit 35b7c42

Browse files
committed
backup code cleanups
1 parent eda4968 commit 35b7c42

File tree

5 files changed

+848
-920
lines changed

5 files changed

+848
-920
lines changed

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 36 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
1-
#![allow(unused_imports)]
2-
#![allow(unused_variables)]
3-
use std::ffi::{CStr, CString};
1+
use std::ffi::CString;
42
use std::io::{self, Write};
53
use std::path::{Path, PathBuf};
64
use std::sync::Arc;
75
use std::{fs, slice, str};
86

97
use libc::{c_char, c_int, c_uint, c_void, size_t};
108
use llvm::{
11-
IntPredicate, LLVMGetNextBasicBlock, LLVMRustDISetInstMetadata,
9+
IntPredicate,
1210
LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols,
1311
};
1412
use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity, DiffMode};
@@ -47,27 +45,26 @@ use crate::errors::{
4745
};
4846
use crate::llvm::diagnostic::OptimizationDiagnosticKind;
4947
use crate::llvm::{
50-
self, enzyme_rust_forward_diff, enzyme_rust_reverse_diff, AttributeKind, BasicBlock,
48+
self, enzyme_rust_forward_diff, enzyme_rust_reverse_diff, AttributeKind,
5149
CreateEnzymeLogic, CreateTypeAnalysis, DiagnosticInfo, EnzymeLogicRef, EnzymeTypeAnalysisRef,
52-
FreeTypeAnalysis, LLVMAddFunction, LLVMAppendBasicBlockInContext, LLVMBuildCall2,
50+
FreeTypeAnalysis, LLVMAppendBasicBlockInContext, LLVMBuildCall2,
5351
LLVMBuildCondBr, LLVMBuildExtractValue, LLVMBuildICmp, LLVMBuildRet, LLVMBuildRetVoid,
5452
LLVMCountParams, LLVMCountStructElementTypes, LLVMCreateBuilderInContext,
55-
LLVMCreateStringAttribute, LLVMDeleteFunction, LLVMDisposeBuilder, LLVMDumpModule,
56-
LLVMGetBasicBlockTerminator, LLVMGetFirstBasicBlock, LLVMGetFirstFunction,
57-
LLVMGetModuleContext, LLVMGetNextFunction, LLVMGetParams, LLVMGetReturnType,
53+
LLVMCreateStringAttribute, LLVMDisposeBuilder, LLVMDumpModule,
54+
LLVMGetFirstBasicBlock, LLVMGetFirstFunction,
55+
LLVMGetNextFunction, LLVMGetParams, LLVMGetReturnType,
5856
LLVMGetStringAttributeAtIndex, LLVMGlobalGetValueType, LLVMIsEnumAttribute,
5957
LLVMIsStringAttribute, LLVMMetadataAsValue, LLVMPositionBuilderAtEnd,
60-
LLVMRemoveStringAttributeAtIndex, LLVMReplaceAllUsesWith, LLVMRustAddEnumAttributeAtIndex,
61-
LLVMRustAddFunctionAttributes, LLVMRustDIGetInstMetadata, LLVMRustDIGetInstMetadataOfTy,
62-
LLVMRustEraseBBFromParent, LLVMRustEraseInstBefore, LLVMRustEraseInstFromParent,
58+
LLVMRemoveStringAttributeAtIndex, LLVMRustAddEnumAttributeAtIndex,
59+
LLVMRustAddFunctionAttributes, LLVMRustDIGetInstMetadata,
60+
LLVMRustEraseInstBefore, LLVMRustEraseInstFromParent,
6361
LLVMRustGetEnumAttributeAtIndex, LLVMRustGetFunctionType, LLVMRustGetLastInstruction,
64-
LLVMRustGetTerminator, LLVMRustHasDbgMetadata, LLVMRustHasMetadata,
65-
LLVMRustRemoveEnumAttributeAtIndex, LLVMRustRemoveFncAttr,
66-
LLVMRustgetFirstNonPHIOrDbgOrLifetime, LLVMSetValueName2, LLVMVerifyFunction,
62+
LLVMRustGetTerminator, LLVMRustHasMetadata,
63+
LLVMRustRemoveEnumAttributeAtIndex,
64+
LLVMVerifyFunction,
6765
LLVMVoidTypeInContext, PassManager, Value,
6866
};
6967
use crate::type_::Type;
70-
use crate::typetree::to_enzyme_typetree;
7168
use crate::{base, common, llvm_util, DiffTypeTree, LlvmCodegenBackend, ModuleLlvm};
7269

7370
pub fn llvm_err<'a>(dcx: DiagCtxtHandle<'_>, err: LlvmError<'a>) -> FatalError {
@@ -669,50 +666,27 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
669666

670667
// DESIGN:
671668
// Today we have our placeholder function, and our Enzyme generated one.
672-
// We create a wrapper function and delete the placeholder body.
673-
// We then call the wrapper from the placeholder.
669+
// We create a wrapper function and delete the placeholder body. You can see the
670+
// placeholder by running `cargo expand` on an autodiff invocation. We call the wrapper
671+
// from the placeholder. This function is a bit longer, because it matches the Rust level
672+
// autodiff macro with LLVM level Enzyme autodiff expectations.
674673
//
675-
// Soon, we won't delete the whole placeholder, but just the loop,
676-
// and the two inline asm sections. For now we can still call the wrapper.
677-
// In the future we call our Enzyme generated function directly and unwrap the return
678-
// struct in our original placeholder.
679-
//
680-
// define internal double @_ZN2ad3bar17ha38374e821680177E(ptr align 8 %0, ptr align 8 %1, double %2) unnamed_addr #17 !dbg !13678 {
681-
// %4 = alloca double, align 8
682-
// %5 = alloca ptr, align 8
683-
// %6 = alloca ptr, align 8
684-
// %7 = alloca { ptr, double }, align 8
685-
// store ptr %0, ptr %6, align 8
686-
// call void @llvm.dbg.declare(metadata ptr %6, metadata !13682, metadata !DIExpression()), !dbg !13685
687-
// store ptr %1, ptr %5, align 8
688-
// call void @llvm.dbg.declare(metadata ptr %5, metadata !13683, metadata !DIExpression()), !dbg !13685
689-
// store double %2, ptr %4, align 8
690-
// call void @llvm.dbg.declare(metadata ptr %4, metadata !13684, metadata !DIExpression()), !dbg !13686
691-
// call void asm sideeffect alignstack inteldialect "NOP", "~{dirflag},~{fpsr},~{flags},~{memory}"(), !dbg !13687, !srcloc !23
692-
// %8 = call double @_ZN2ad3foo17h95b548a9411653b2E(ptr align 8 %0), !dbg !13687
693-
// %9 = call double @_ZN4core4hint9black_box17h7bd67a41b0f12bdfE(double %8), !dbg !13687
694-
// store ptr %1, ptr %7, align 8, !dbg !13687
695-
// %10 = getelementptr inbounds { ptr, double }, ptr %7, i32 0, i32 1, !dbg !13687
696-
// store double %2, ptr %10, align 8, !dbg !13687
697-
// %11 = getelementptr inbounds { ptr, double }, ptr %7, i32 0, i32 0, !dbg !13687
698-
// %12 = load ptr, ptr %11, align 8, !dbg !13687, !nonnull !23, !align !1047, !noundef !23
699-
// %13 = getelementptr inbounds { ptr, double }, ptr %7, i32 0, i32 1, !dbg !13687
700-
// %14 = load double, ptr %13, align 8, !dbg !13687, !noundef !23
701-
// %15 = call { ptr, double } @_ZN4core4hint9black_box17h669f3b22afdcb487E(ptr align 8 %12, double %14), !dbg !13687
702-
// %16 = extractvalue { ptr, double } %15, 0, !dbg !13687
703-
// %17 = extractvalue { ptr, double } %15, 1, !dbg !13687
704-
// br label %18, !dbg !13687
705-
//
706-
//18: ; preds = %18, %3
707-
// br label %18, !dbg !13687
674+
// Think of computing the derivative with respect to &[f32] by marking it as duplicated.
675+
// The user will then pass an extra &mut [f32] and we want add the derivative to that.
676+
// On LLVM/Enzyme level, &[f32] however becomes `ptr, i64` and we mark ptr as duplicated,
677+
// and i64 (len) as const. Enzyme will then expect `ptr, ptr, i64` as arguments. See how the
678+
// second i64 from the mut slice isn't used? That's why we add a safety check to assert
679+
// that the second (mut) slice is at least as long as the first (const) slice. Otherwise,
680+
// Enzyme would write out of bounds if the first (const) slice is longer than the second.
708681

709682
unsafe fn create_call<'a>(
710683
tgt: &'a Value,
711684
src: &'a Value,
712-
rev_mode: bool,
713685
llmod: &'a llvm::Module,
714686
llcx: &llvm::Context,
715-
size_positions: &[usize],
687+
// FIXME: Instead of recomputing the positions as we do it below, we should
688+
// start using this list of positions that indicate length integers.
689+
_size_positions: &[usize],
716690
ad: &[AutoDiff],
717691
) {
718692
unsafe {
@@ -756,9 +730,10 @@ unsafe fn create_call<'a>(
756730
inner_pos += 1;
757731
outer_pos += 1;
758732
} else {
759-
// out: (ptr, <>int1, ptr, int2)
733+
// out: rust: (&[f32], &mut [f32])
734+
// out: llvm: (ptr, <>int1, ptr, int2)
760735
// inner: (ptr, <>ptr, int)
761-
// goal: (ptr, ptr, int1), skipping int2
736+
// goal: call (ptr, ptr, int1), skipping int2
762737
// we are here: <>
763738
assert!(llvm::LLVMRustGetTypeKind(outer_arg_ty) == llvm::TypeKind::Integer);
764739
assert!(llvm::LLVMRustGetTypeKind(inner_arg_ty) == llvm::TypeKind::Pointer);
@@ -872,17 +847,17 @@ unsafe fn create_call<'a>(
872847
);
873848

874849
// Add dummy dbg info to our newly generated call, if we have any.
875-
let inst = LLVMRustgetFirstNonPHIOrDbgOrLifetime(bb).unwrap();
876850
let md_ty = llvm::LLVMGetMDKindIDInContext(
877851
llcx,
878852
"dbg".as_ptr() as *const c_char,
879853
"dbg".len() as c_uint,
880854
);
881855

856+
882857
if LLVMRustHasMetadata(last_inst, md_ty) {
883858
let md = LLVMRustDIGetInstMetadata(last_inst);
884859
let md_val = LLVMMetadataAsValue(llcx, md);
885-
let md2 = llvm::LLVMSetMetadata(struct_ret, md_ty, md_val);
860+
let _md2 = llvm::LLVMSetMetadata(struct_ret, md_ty, md_val);
886861
} else {
887862
trace!("No dbg info");
888863
}
@@ -938,8 +913,8 @@ unsafe fn get_panic_name(llmod: &llvm::Module) -> CString {
938913
// For now we only check if shadow arguments are large enough. In this case we look for Rust panic
939914
// functions in the module and call it. Due to hashing we can't hardcode the panic function name.
940915
// Note: This worked even for panic=abort tests so seems solid enough for now.
941-
// TODO: Pick a panic function which allows displaying an errormessage.
942-
// TODO: We probably want to keep a handle at higher level and pass it down instead of searching.
916+
// FIXME: Pick a panic function which allows displaying an error message.
917+
// FIXME: We probably want to keep a handle at higher level and pass it down instead of searching.
943918
unsafe fn add_panic_msg_to_global<'a>(
944919
llmod: &'a llvm::Module,
945920
llcx: &'a llvm::Context,
@@ -961,7 +936,7 @@ unsafe fn add_panic_msg_to_global<'a>(
961936
let i8_array_type = LLVMArrayType2(LLVMInt8TypeInContext(llcx), msg_len as u64);
962937

963938
// Create the string constant
964-
let string_const_val =
939+
let _string_const_val =
965940
LLVMConstStringInContext2(llcx, cmsg.as_ptr() as *const i8, msg_len as usize, 0);
966941

967942
// Create the array initializer
@@ -1098,8 +1073,7 @@ pub(crate) unsafe fn enzyme_ad(
10981073

10991074
let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(res));
11001075

1101-
let rev_mode = item.attrs.mode == DiffMode::Reverse;
1102-
create_call(target_fnc, res, rev_mode, llmod, llcx, &size_positions, ad);
1076+
create_call(target_fnc, res, llmod, llcx, &size_positions, ad);
11031077
// TODO: implement drop for wrapper type?
11041078
FreeTypeAnalysis(type_analysis);
11051079
}
@@ -1133,10 +1107,6 @@ pub(crate) unsafe fn differentiate(
11331107

11341108
// Before dumping the module, we want all the tt to become part of the module.
11351109
for (i, item) in diff_items.iter().enumerate() {
1136-
let llvm_data_layout = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) };
1137-
let llvm_data_layout =
1138-
std::str::from_utf8(unsafe { CStr::from_ptr(llvm_data_layout) }.to_bytes())
1139-
.expect("got a non-UTF8 data-layout from LLVM");
11401110
let tt: FncTree = FncTree { args: item.inputs.clone(), ret: item.output.clone() };
11411111
let name = CString::new(item.source.clone()).unwrap();
11421112
let fn_def: &llvm::Value =

0 commit comments

Comments
 (0)