Skip to content

Commit

Permalink
tidy up
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Jun 11, 2024
1 parent dd9a010 commit 9410339
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 72 deletions.
73 changes: 6 additions & 67 deletions src/execution/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use inkwell::context::AsContextRef;
use inkwell::intrinsics::Intrinsic;
use inkwell::module::Module;
use inkwell::passes::PassManager;
use inkwell::types::{AnyTypeEnum, BasicMetadataTypeEnum, BasicTypeEnum, FloatType, IntType};
use inkwell::types::{BasicMetadataTypeEnum, BasicTypeEnum, FloatType, IntType};
use inkwell::values::{
AsValueRef, BasicMetadataValueEnum, BasicValue, BasicValueEnum, FloatValue, FunctionValue, GlobalValue, IntValue, PointerValue
};
Expand All @@ -19,7 +19,7 @@ type RealType = f64;
use crate::ast::{Ast, AstKind};
use crate::discretise::{DiscreteModel, Tensor, TensorBlock};
use crate::enzyme::{
CConcreteType_DT_Double, CConcreteType_DT_Pointer, CDerivativeMode_DEM_ForwardMode, CFnTypeInfo, CreateEnzymeLogic, CreateTypeAnalysis, EnzymeCreateForwardDiff, EnzymeFreeTypeTree, EnzymeLogicRef, EnzymeMergeTypeTree, EnzymeNewTypeTree, EnzymeNewTypeTreeCT, EnzymeTypeAnalysisRef, EnzymeTypeTreeOnlyEq, FreeEnzymeLogic, FreeTypeAnalysis, IntList, LLVMOpaqueContext, LLVMOpaqueValue, CDIFFE_TYPE_DFT_CONSTANT, CDIFFE_TYPE_DFT_DUP_ARG, CDIFFE_TYPE_DFT_DUP_NONEED
CConcreteType_DT_Anything, CConcreteType_DT_Double, CConcreteType_DT_Pointer, CDerivativeMode_DEM_ForwardMode, CFnTypeInfo, CreateEnzymeLogic, CreateTypeAnalysis, EnzymeCreateForwardDiff, EnzymeFreeTypeTree, EnzymeLogicRef, EnzymeNewTypeTreeCT, EnzymeTypeAnalysisRef, EnzymeTypeTreeOnlyEq, FreeEnzymeLogic, FreeTypeAnalysis, IntList, LLVMOpaqueContext, LLVMOpaqueValue, CDIFFE_TYPE_DFT_CONSTANT, CDIFFE_TYPE_DFT_DUP_ARG, CDIFFE_TYPE_DFT_DUP_NONEED
};
use crate::execution::{DataLayout, Translation, TranslationFrom, TranslationTo};

Expand Down Expand Up @@ -92,9 +92,6 @@ pub type GetOutFunc = unsafe extern "C" fn(
);

struct Globals<'ctx> {
enzyme_dup: GlobalValue<'ctx>,
enzyme_const: GlobalValue<'ctx>,
enzyme_dupnoneed: GlobalValue<'ctx>,
indices: GlobalValue<'ctx>,
}

Expand All @@ -115,17 +112,6 @@ impl<'ctx> Globals<'ctx> {
let indices_value = int_type.const_array(indices_array_values.as_slice());
let _int_ptr_type = int_type.ptr_type(AddressSpace::default());
let globals = Self {
enzyme_dup: module.add_global(int_type, Some(AddressSpace::default()), "enzyme_dup"),
enzyme_const: module.add_global(
int_type,
Some(AddressSpace::default()),
"enzyme_const",
),
enzyme_dupnoneed: module.add_global(
int_type,
Some(AddressSpace::default()),
"enzyme_dupnoneed",
),
indices: module.add_global(
indices_array_type,
Some(AddressSpace::default()),
Expand Down Expand Up @@ -1673,11 +1659,6 @@ impl<'ctx> CodeGen<'ctx> {
) -> Result<FunctionValue<'ctx>> {
self.clear();

let globals = match self.globals {
Some(ref globals) => globals,
None => panic!("globals not set"),
};

// construct the gradient function
let mut fn_type: Vec<BasicMetadataTypeEnum> = Vec::new();
let orig_fn_type_ptr = original_function
Expand Down Expand Up @@ -1709,20 +1690,7 @@ impl<'ctx> CodeGen<'ctx> {
self.fn_value_opt = Some(function);
self.builder.position_at_end(basic_block);

let mut enzyme_fn_args: Vec<BasicMetadataValueEnum> = vec![original_function
.as_global_value()
.as_pointer_value()
.into()];
let enzyme_const = self
.builder
.build_load(globals.enzyme_const.as_pointer_value(), "enzyme_const")?;
let enzyme_dup = self
.builder
.build_load(globals.enzyme_dup.as_pointer_value(), "enzyme_dup")?;
let enzyme_dupnoneed = self.builder.build_load(
globals.enzyme_dupnoneed.as_pointer_value(),
"enzyme_dupnoneed",
)?;
let mut enzyme_fn_args: Vec<BasicMetadataValueEnum> = Vec::new();
let mut input_activity = Vec::new();
let mut arg_trees = Vec::new();
for (i, _arg) in original_function.get_param_iter().enumerate() {
Expand All @@ -1738,22 +1706,9 @@ impl<'ctx> CodeGen<'ctx> {
let new_tree = unsafe { EnzymeNewTypeTreeCT(concrete_type, self.context.as_ctx_ref() as *mut LLVMOpaqueContext) };
unsafe { EnzymeTypeTreeOnlyEq(new_tree, -1) };

// pointer to double
if concrete_type == CConcreteType_DT_Pointer {
let inner_concrete_type = match _arg.get_type().into_pointer_type().get_element_type() {
AnyTypeEnum::FloatType(_) => CConcreteType_DT_Double,
_ => panic!("unsupported type"),
};
let inner_new_tree = unsafe { EnzymeNewTypeTreeCT(inner_concrete_type, self.context.as_ctx_ref() as *mut LLVMOpaqueContext) };
//unsafe { EnzymeTypeTreeOnlyEq(inner_new_tree, -1) };
unsafe { EnzymeMergeTypeTree(new_tree, inner_new_tree) };
}
arg_trees.push(new_tree);
match args_type[i] {
CompileGradientArgType::Dup => {
// let enzyme know its an active arg
enzyme_fn_args.push(enzyme_dup.into());

// pass in the arg value
enzyme_fn_args.push(fn_arg.into());

Expand All @@ -1764,9 +1719,6 @@ impl<'ctx> CodeGen<'ctx> {
input_activity.push(CDIFFE_TYPE_DFT_DUP_ARG);
}
CompileGradientArgType::DupNoNeed => {
// let enzyme know its an active arg we don't need
enzyme_fn_args.push(enzyme_dupnoneed.into());

// pass in the arg value
enzyme_fn_args.push(fn_arg.into());

Expand All @@ -1777,9 +1729,6 @@ impl<'ctx> CodeGen<'ctx> {
input_activity.push(CDIFFE_TYPE_DFT_DUP_NONEED);
}
CompileGradientArgType::Const => {
// let enzyme know its a constant arg
enzyme_fn_args.push(enzyme_const.into());

// pass in the arg value
enzyme_fn_args.push(fn_arg.into());

Expand All @@ -1789,8 +1738,8 @@ impl<'ctx> CodeGen<'ctx> {
}
// if we have void ret, this must be false;
let ret_primary_ret = false;
let ret_activity = CDIFFE_TYPE_DFT_DUP_NONEED;
let ret_tree = unsafe { EnzymeNewTypeTree() };
let ret_activity = CDIFFE_TYPE_DFT_CONSTANT;
let ret_tree = unsafe { EnzymeNewTypeTreeCT(CConcreteType_DT_Anything, self.context.as_ctx_ref() as *mut LLVMOpaqueContext) };

// always optimize
let fnc_opt_base = true;
Expand Down Expand Up @@ -1843,17 +1792,6 @@ impl<'ctx> CodeGen<'ctx> {
unsafe { EnzymeFreeTypeTree(tree) };
}

// construct enzyme function
// double df = __enzyme_fwddiff<double>((void*)f, enzyme_dup, x, dx, enzyme_dup, y, dy);
//let enzyme_fn_type = void_type.fn_type(&enzyme_fn_type, false);
//let orig_fn_name = original_function.get_name().to_str().unwrap();
//let enzyme_fn_name = format!("__enzyme_fwddiff_{}", orig_fn_name);
//let enzyme_function = self.module.add_function(
// enzyme_fn_name.as_str(),
// enzyme_fn_type,
// Some(Linkage::External),
//);

// call enzyme function
let enzyme_function =
unsafe { FunctionValue::new(enzyme_function as LLVMValueRef) }.unwrap();
Expand All @@ -1868,6 +1806,7 @@ impl<'ctx> CodeGen<'ctx> {
Ok(function)
} else {
function.print_to_stderr();
enzyme_function.print_to_stderr();
unsafe {
function.delete();
}
Expand Down
6 changes: 1 addition & 5 deletions wrapper.h
Original file line number Diff line number Diff line change
@@ -1,5 +1 @@
#include <Enzyme/CApi.h>



void *EnzymeAnalyzeTypes(EnzymeTypeAnalysisRef TAR, CFnTypeInfo CTI, LLVMValueRef F);
#include <Enzyme/CApi.h>

0 comments on commit 9410339

Please sign in to comment.