Skip to content

Commit 7f4aa5a

Browse files
committed
fix reshape (partial 1d position convention)
1 parent aa73de6 commit 7f4aa5a

4 files changed

Lines changed: 67 additions & 55 deletions

File tree

enzyme/Enzyme/MLIR/Interfaces/HMCUtils.cpp

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -361,9 +361,21 @@ GradientResult MCMC::computePotentialAndGradient(OpBuilder &builder,
361361
builder, loc, scalarType,
362362
DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 1.0)));
363363

364-
SmallVector<Value> autodiffInputs{position, gradSeed};
364+
bool isCustomLogpdf = ctx.hasCustomLogpdf();
365+
auto flatType = RankedTensorType::get({ctx.positionSize}, elemType);
366+
Value autodiffPosition = position;
367+
auto autodiffPositionType = positionType;
368+
auto autodiffGradType = positionType;
369+
if (isCustomLogpdf) {
370+
autodiffPosition =
371+
enzyme::ReshapeOp::create(builder, loc, flatType, position);
372+
autodiffPositionType = flatType;
373+
autodiffGradType = flatType;
374+
}
375+
376+
SmallVector<Value> autodiffInputs{autodiffPosition, gradSeed};
365377
auto autodiffOp = enzyme::AutoDiffRegionOp::create(
366-
builder, loc, TypeRange{scalarType, rng.getType(), positionType},
378+
builder, loc, TypeRange{scalarType, rng.getType(), autodiffGradType},
367379
autodiffInputs,
368380
builder.getArrayAttr({enzyme::ActivityAttr::get(
369381
builder.getContext(), enzyme::Activity::enzyme_active)}),
@@ -376,12 +388,12 @@ GradientResult MCMC::computePotentialAndGradient(OpBuilder &builder,
376388
nullptr);
377389

378390
Block *autodiffBlock = builder.createBlock(&autodiffOp.getBody());
379-
autodiffBlock->addArgument(positionType, loc);
391+
autodiffBlock->addArgument(autodiffPositionType, loc);
380392

381393
builder.setInsertionPointToStart(autodiffBlock);
382394
Value qArg = autodiffBlock->getArgument(0);
383395

384-
if (ctx.hasCustomLogpdf()) {
396+
if (isCustomLogpdf) {
385397
SmallVector<Value> callArgs;
386398
callArgs.push_back(qArg);
387399
callArgs.append(ctx.fnInputs.begin(), ctx.fnInputs.end());
@@ -425,9 +437,14 @@ GradientResult MCMC::computePotentialAndGradient(OpBuilder &builder,
425437

426438
builder.setInsertionPointAfter(autodiffOp);
427439

440+
Value grad = autodiffOp.getResult(2);
441+
if (isCustomLogpdf) {
442+
grad = enzyme::ReshapeOp::create(builder, loc, positionType, grad);
443+
}
444+
428445
return {
429446
autodiffOp.getResult(0), // U
430-
autodiffOp.getResult(2), // grad
447+
grad, // grad
431448
autodiffOp.getResult(1) // rng
432449
};
433450
}
@@ -691,8 +708,10 @@ InitialHMCState MCMC::InitHMC(OpBuilder &builder, Location loc, Value rng,
691708

692709
if (ctx.hasCustomLogpdf()) {
693710
q0 = initialPosition;
711+
auto flatType = RankedTensorType::get({ctx.positionSize}, elemType);
712+
auto q0Flat = enzyme::ReshapeOp::create(builder, loc, flatType, q0);
694713
SmallVector<Value> callArgs;
695-
callArgs.push_back(q0);
714+
callArgs.push_back(q0Flat);
696715
callArgs.append(ctx.fnInputs.begin(), ctx.fnInputs.end());
697716
auto callOp = func::CallOp::create(builder, loc, ctx.logpdfFn,
698717
TypeRange{scalarType}, callArgs);
@@ -735,13 +754,24 @@ InitialHMCState MCMC::InitHMC(OpBuilder &builder, Location loc, Value rng,
735754
}
736755

737756
// 4. Compute initial gradient at q0
757+
bool isCustomLogpdf = ctx.hasCustomLogpdf();
758+
auto flatType = RankedTensorType::get({ctx.positionSize}, elemType);
759+
Value autodiffQ0 = q0;
760+
auto autodiffQ0Type = positionType;
761+
auto autodiffGradType = positionType;
762+
if (isCustomLogpdf) {
763+
autodiffQ0 = enzyme::ReshapeOp::create(builder, loc, flatType, q0);
764+
autodiffQ0Type = flatType;
765+
autodiffGradType = flatType;
766+
}
767+
738768
auto gradSeedInit = arith::ConstantOp::create(
739769
builder, loc, scalarType,
740770
DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 1.0)));
741-
SmallVector<Value> autodiffInputs{q0, gradSeedInit};
771+
SmallVector<Value> autodiffInputs{autodiffQ0, gradSeedInit};
742772
auto autodiffInit = enzyme::AutoDiffRegionOp::create(
743773
builder, loc,
744-
TypeRange{scalarType, rngForAutodiff.getType(), positionType},
774+
TypeRange{scalarType, rngForAutodiff.getType(), autodiffGradType},
745775
autodiffInputs,
746776
builder.getArrayAttr({enzyme::ActivityAttr::get(
747777
builder.getContext(), enzyme::Activity::enzyme_active)}),
@@ -754,12 +784,12 @@ InitialHMCState MCMC::InitHMC(OpBuilder &builder, Location loc, Value rng,
754784
nullptr);
755785

756786
Block *autodiffInitBlock = builder.createBlock(&autodiffInit.getBody());
757-
autodiffInitBlock->addArgument(positionType, loc);
787+
autodiffInitBlock->addArgument(autodiffQ0Type, loc);
758788

759789
builder.setInsertionPointToStart(autodiffInitBlock);
760790
auto q0Arg = autodiffInitBlock->getArgument(0);
761791

762-
if (ctx.hasCustomLogpdf()) {
792+
if (isCustomLogpdf) {
763793
SmallVector<Value> callArgs;
764794
callArgs.push_back(q0Arg);
765795
callArgs.append(ctx.fnInputs.begin(), ctx.fnInputs.end());
@@ -803,8 +833,10 @@ InitialHMCState MCMC::InitHMC(OpBuilder &builder, Location loc, Value rng,
803833
}
804834
builder.setInsertionPointAfter(autodiffInit);
805835

806-
// (U, rng, grad)
807-
auto grad0 = autodiffInit.getResult(2);
836+
Value grad0 = autodiffInit.getResult(2);
837+
if (isCustomLogpdf) {
838+
grad0 = enzyme::ReshapeOp::create(builder, loc, positionType, grad0);
839+
}
808840

809841
return {q0, U0, grad0, rngForSampleKernel};
810842
}

enzyme/Enzyme/MLIR/Passes/ProbProgMLIRPass.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -655,8 +655,8 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase<ProbProgPass> {
655655
logpdfFnAttr = mcmcOp.getLogpdfFnAttr();
656656
fnInputs.assign(inputs.begin() + 1, inputs.end());
657657
auto initialPos = mcmcOp.getInitialPosition();
658-
positionSize =
659-
cast<RankedTensorType>(initialPos.getType()).getShape()[1];
658+
auto initPosType = cast<RankedTensorType>(initialPos.getType());
659+
positionSize = initPosType.getNumElements();
660660
selection = mcmcOp.getSelectionAttr();
661661
allAddresses = mcmcOp.getAllAddressesAttr();
662662
} else {

enzyme/test/MLIR/ProbProg/mcmc_custom_logpdf.mlir

Lines changed: 19 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
// RUN: %eopt --probprog %s | FileCheck %s
22

33
module {
4-
func.func @logpdf(%x : tensor<1x2xf64>) -> tensor<f64> {
5-
%sum_sq = enzyme.dot %x, %x {lhs_batching_dimensions = array<i64>, rhs_batching_dimensions = array<i64>, lhs_contracting_dimensions = array<i64: 0, 1>, rhs_contracting_dimensions = array<i64: 0, 1>} : (tensor<1x2xf64>, tensor<1x2xf64>) -> tensor<f64>
4+
func.func @logpdf(%x : tensor<2xf64>) -> tensor<f64> {
5+
%sum_sq = enzyme.dot %x, %x {lhs_batching_dimensions = array<i64>, rhs_batching_dimensions = array<i64>, lhs_contracting_dimensions = array<i64: 0>, rhs_contracting_dimensions = array<i64: 0>} : (tensor<2xf64>, tensor<2xf64>) -> tensor<f64>
66
%neg_half = arith.constant dense<-5.000000e-01> : tensor<f64>
77
%result = arith.mulf %neg_half, %sum_sq : tensor<f64>
88
return %result : tensor<f64>
@@ -65,9 +65,9 @@ module {
6565
return %res#0, %res#1, %res#2 : tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>
6666
}
6767

68-
func.func @shifted_logpdf(%x : tensor<1x2xf64>, %mu : tensor<1x2xf64>) -> tensor<f64> {
69-
%diff = arith.subf %x, %mu : tensor<1x2xf64>
70-
%sum_sq = enzyme.dot %diff, %diff {lhs_batching_dimensions = array<i64>, rhs_batching_dimensions = array<i64>, lhs_contracting_dimensions = array<i64: 0, 1>, rhs_contracting_dimensions = array<i64: 0, 1>} : (tensor<1x2xf64>, tensor<1x2xf64>) -> tensor<f64>
68+
func.func @shifted_logpdf(%x : tensor<2xf64>, %mu : tensor<2xf64>) -> tensor<f64> {
69+
%diff = arith.subf %x, %mu : tensor<2xf64>
70+
%sum_sq = enzyme.dot %diff, %diff {lhs_batching_dimensions = array<i64>, rhs_batching_dimensions = array<i64>, lhs_contracting_dimensions = array<i64: 0>, rhs_contracting_dimensions = array<i64: 0>} : (tensor<2xf64>, tensor<2xf64>) -> tensor<f64>
7171
%neg_half = arith.constant dense<-5.000000e-01> : tensor<f64>
7272
%result = arith.mulf %neg_half, %sum_sq : tensor<f64>
7373
return %result : tensor<f64>
@@ -80,12 +80,7 @@ module {
8080
// CHECK: func.call @shifted_logpdf
8181
// CHECK-NEXT: %[[NEG:.+]] = arith.negf
8282
// CHECK-NEXT: enzyme.yield
83-
// CHECK: enzyme.for_loop
84-
// CHECK: enzyme.autodiff_region
85-
// CHECK: func.call @shifted_logpdf
86-
// CHECK-NEXT: %{{.+}} = arith.negf
87-
// CHECK-NEXT: enzyme.yield
88-
func.func @nuts_shifted_logpdf(%rng : tensor<2xui64>, %mu : tensor<1x2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>) {
83+
func.func @nuts_shifted_logpdf(%rng : tensor<2xui64>, %mu : tensor<2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>) {
8984
%init_pos = arith.constant dense<[[0.5, -0.5]]> : tensor<1x2xf64>
9085
%step_size = arith.constant dense<0.1> : tensor<f64>
9186
%res:8 = "enzyme.mcmc"(%rng, %mu, %step_size, %init_pos) {
@@ -97,7 +92,7 @@ module {
9792
num_warmup = 0,
9893
num_samples = 1,
9994
operand_segment_sizes = array<i32: 2, 0, 0, 1, 1, 0, 0>
100-
} : (tensor<2xui64>, tensor<1x2xf64>, tensor<f64>, tensor<1x2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>, tensor<1x2xf64>, tensor<1x2xf64>, tensor<f64>, tensor<f64>, tensor<1x2xf64>)
95+
} : (tensor<2xui64>, tensor<2xf64>, tensor<f64>, tensor<1x2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>, tensor<1x2xf64>, tensor<1x2xf64>, tensor<f64>, tensor<f64>, tensor<1x2xf64>)
10196
return %res#0, %res#1, %res#2 : tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>
10297
}
10398

@@ -108,12 +103,7 @@ module {
108103
// CHECK: func.call @shifted_logpdf
109104
// CHECK-NEXT: %{{.+}} = arith.negf
110105
// CHECK-NEXT: enzyme.yield
111-
// CHECK: enzyme.for_loop
112-
// CHECK: enzyme.autodiff_region
113-
// CHECK: func.call @shifted_logpdf
114-
// CHECK-NEXT: %{{.+}} = arith.negf
115-
// CHECK-NEXT: enzyme.yield
116-
func.func @hmc_shifted_logpdf(%rng : tensor<2xui64>, %mu : tensor<1x2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>) {
106+
func.func @hmc_shifted_logpdf(%rng : tensor<2xui64>, %mu : tensor<2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>) {
117107
%init_pos = arith.constant dense<[[0.5, -0.5]]> : tensor<1x2xf64>
118108
%step_size = arith.constant dense<0.1> : tensor<f64>
119109
%res:8 = "enzyme.mcmc"(%rng, %mu, %step_size, %init_pos) {
@@ -125,16 +115,16 @@ module {
125115
num_warmup = 0,
126116
num_samples = 1,
127117
operand_segment_sizes = array<i32: 2, 0, 0, 1, 1, 0, 0>
128-
} : (tensor<2xui64>, tensor<1x2xf64>, tensor<f64>, tensor<1x2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>, tensor<1x2xf64>, tensor<1x2xf64>, tensor<f64>, tensor<f64>, tensor<1x2xf64>)
118+
} : (tensor<2xui64>, tensor<2xf64>, tensor<f64>, tensor<1x2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>, tensor<1x2xf64>, tensor<1x2xf64>, tensor<f64>, tensor<f64>, tensor<1x2xf64>)
129119
return %res#0, %res#1, %res#2 : tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>
130120
}
131121

132-
func.func @anisotropic_logpdf(%x : tensor<1x2xf64>, %mu : tensor<1x2xf64>, %precision : tensor<1x2xf64>) -> tensor<f64> {
133-
%diff = arith.subf %x, %mu : tensor<1x2xf64>
134-
%diff_sq = arith.mulf %diff, %diff : tensor<1x2xf64>
135-
%weighted = arith.mulf %precision, %diff_sq : tensor<1x2xf64>
136-
%ones = arith.constant dense<1.0> : tensor<1x2xf64>
137-
%sum = enzyme.dot %ones, %weighted {lhs_batching_dimensions = array<i64>, rhs_batching_dimensions = array<i64>, lhs_contracting_dimensions = array<i64: 0, 1>, rhs_contracting_dimensions = array<i64: 0, 1>} : (tensor<1x2xf64>, tensor<1x2xf64>) -> tensor<f64>
122+
func.func @anisotropic_logpdf(%x : tensor<2xf64>, %mu : tensor<2xf64>, %precision : tensor<2xf64>) -> tensor<f64> {
123+
%diff = arith.subf %x, %mu : tensor<2xf64>
124+
%diff_sq = arith.mulf %diff, %diff : tensor<2xf64>
125+
%weighted = arith.mulf %precision, %diff_sq : tensor<2xf64>
126+
%ones = arith.constant dense<1.0> : tensor<2xf64>
127+
%sum = enzyme.dot %ones, %weighted {lhs_batching_dimensions = array<i64>, rhs_batching_dimensions = array<i64>, lhs_contracting_dimensions = array<i64: 0>, rhs_contracting_dimensions = array<i64: 0>} : (tensor<2xf64>, tensor<2xf64>) -> tensor<f64>
138128
%neg_half = arith.constant dense<-5.000000e-01> : tensor<f64>
139129
%result = arith.mulf %neg_half, %sum : tensor<f64>
140130
return %result : tensor<f64>
@@ -147,12 +137,7 @@ module {
147137
// CHECK: func.call @anisotropic_logpdf
148138
// CHECK-NEXT: %[[NEG:.+]] = arith.negf
149139
// CHECK-NEXT: enzyme.yield
150-
// CHECK: enzyme.for_loop
151-
// CHECK: enzyme.autodiff_region
152-
// CHECK: func.call @anisotropic_logpdf
153-
// CHECK-NEXT: %{{.+}} = arith.negf
154-
// CHECK-NEXT: enzyme.yield
155-
func.func @nuts_anisotropic_logpdf(%rng : tensor<2xui64>, %mu : tensor<1x2xf64>, %precision : tensor<1x2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>) {
140+
func.func @nuts_anisotropic_logpdf(%rng : tensor<2xui64>, %mu : tensor<2xf64>, %precision : tensor<2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>) {
156141
%init_pos = arith.constant dense<[[0.5, -0.5]]> : tensor<1x2xf64>
157142
%step_size = arith.constant dense<0.1> : tensor<f64>
158143
%res:8 = "enzyme.mcmc"(%rng, %mu, %precision, %step_size, %init_pos) {
@@ -164,7 +149,7 @@ module {
164149
num_warmup = 0,
165150
num_samples = 1,
166151
operand_segment_sizes = array<i32: 3, 0, 0, 1, 1, 0, 0>
167-
} : (tensor<2xui64>, tensor<1x2xf64>, tensor<1x2xf64>, tensor<f64>, tensor<1x2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>, tensor<1x2xf64>, tensor<1x2xf64>, tensor<f64>, tensor<f64>, tensor<1x2xf64>)
152+
} : (tensor<2xui64>, tensor<2xf64>, tensor<2xf64>, tensor<f64>, tensor<1x2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>, tensor<1x2xf64>, tensor<1x2xf64>, tensor<f64>, tensor<f64>, tensor<1x2xf64>)
168153
return %res#0, %res#1, %res#2 : tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>
169154
}
170155

@@ -175,12 +160,7 @@ module {
175160
// CHECK: func.call @anisotropic_logpdf
176161
// CHECK-NEXT: %{{.+}} = arith.negf
177162
// CHECK-NEXT: enzyme.yield
178-
// CHECK: enzyme.for_loop
179-
// CHECK: enzyme.autodiff_region
180-
// CHECK: func.call @anisotropic_logpdf
181-
// CHECK-NEXT: %{{.+}} = arith.negf
182-
// CHECK-NEXT: enzyme.yield
183-
func.func @hmc_anisotropic_logpdf(%rng : tensor<2xui64>, %mu : tensor<1x2xf64>, %precision : tensor<1x2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>) {
163+
func.func @hmc_anisotropic_logpdf(%rng : tensor<2xui64>, %mu : tensor<2xf64>, %precision : tensor<2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>) {
184164
%init_pos = arith.constant dense<[[0.5, -0.5]]> : tensor<1x2xf64>
185165
%step_size = arith.constant dense<0.1> : tensor<f64>
186166
%res:8 = "enzyme.mcmc"(%rng, %mu, %precision, %step_size, %init_pos) {
@@ -192,7 +172,7 @@ module {
192172
num_warmup = 0,
193173
num_samples = 1,
194174
operand_segment_sizes = array<i32: 3, 0, 0, 1, 1, 0, 0>
195-
} : (tensor<2xui64>, tensor<1x2xf64>, tensor<1x2xf64>, tensor<f64>, tensor<1x2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>, tensor<1x2xf64>, tensor<1x2xf64>, tensor<f64>, tensor<f64>, tensor<1x2xf64>)
175+
} : (tensor<2xui64>, tensor<2xf64>, tensor<2xf64>, tensor<f64>, tensor<1x2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>, tensor<1x2xf64>, tensor<1x2xf64>, tensor<f64>, tensor<f64>, tensor<1x2xf64>)
196176
return %res#0, %res#1, %res#2 : tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>
197177
}
198178
}

enzyme/test/MLIR/ProbProg/mcmc_strong_zero.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
// RUN: %eopt --probprog %s | FileCheck %s
22

33
module {
4-
func.func @logpdf(%x : tensor<1x2xf64>) -> tensor<f64> {
5-
%sum_sq = enzyme.dot %x, %x {lhs_batching_dimensions = array<i64>, rhs_batching_dimensions = array<i64>, lhs_contracting_dimensions = array<i64: 0, 1>, rhs_contracting_dimensions = array<i64: 0, 1>} : (tensor<1x2xf64>, tensor<1x2xf64>) -> tensor<f64>
4+
func.func @logpdf(%x : tensor<2xf64>) -> tensor<f64> {
5+
%sum_sq = enzyme.dot %x, %x {lhs_batching_dimensions = array<i64>, rhs_batching_dimensions = array<i64>, lhs_contracting_dimensions = array<i64: 0>, rhs_contracting_dimensions = array<i64: 0>} : (tensor<2xf64>, tensor<2xf64>) -> tensor<f64>
66
%neg_half = arith.constant dense<-5.000000e-01> : tensor<f64>
77
%result = arith.mulf %neg_half, %sum_sq : tensor<f64>
88
return %result : tensor<f64>

0 commit comments

Comments
 (0)