Skip to content

Commit 9b78a1c

Browse files
authored
Add jacobian op (#2706)
* Add enzyme.jacobian op * Remove JVP_Apply and VJP_Apply ops * unfix fmt
1 parent ad6c5f6 commit 9b78a1c

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,17 @@ def AutoDiffOp : Enzyme_Op<"autodiff",
161161
}];
162162
}
163163

164+
def JacobianOp : Enzyme_Op<"jacobian",
165+
[DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
166+
let summary = "Compute Jacobian for a function";
167+
let arguments = (ins FlatSymbolRefAttr:$fn, Variadic<AnyType>:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity, DefaultValuedAttr<I64Attr, "1">:$width, DefaultValuedAttr<BoolAttr, "false">:$strong_zero);
168+
let results = (outs Variadic<AnyType>:$outputs);
169+
170+
let assemblyFormat = [{
171+
$fn `(` $inputs `)` attr-dict `:` functional-type($inputs, results)
172+
}];
173+
}
174+
164175
def AutoDiffRegionOp : Enzyme_Op<"autodiff_region", [AutomaticAllocationScope]> {
165176
let summary = "Perform reverse mode AD on a child region";
166177
let arguments = (ins Variadic<AnyType>:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity, DefaultValuedAttr<I64Attr, "1">:$width, DefaultValuedAttr<BoolAttr, "false">:$strong_zero, OptionalAttr<StrAttr>:$fn);

enzyme/Enzyme/MLIR/Dialect/Ops.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,16 @@ ForwardDiffOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
162162
return success();
163163
}
164164

165+
LogicalResult JacobianOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
166+
auto global =
167+
symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFnAttr());
168+
if (!global)
169+
return emitOpError("'")
170+
<< getFn() << "' does not reference a valid global funcOp";
171+
172+
return success();
173+
}
174+
165175
//===----------------------------------------------------------------------===//
166176
// ForwardDiffOp
167177
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)