11// RUN: %eopt --probprog %s | FileCheck %s
22
33module {
4- func.func @logpdf (%x : tensor <1 x 2 x f64 >) -> tensor <f64 > {
5- %sum_sq = enzyme.dot %x , %x {lhs_batching_dimensions = array<i64 >, rhs_batching_dimensions = array<i64 >, lhs_contracting_dimensions = array<i64 : 0 , 1 >, rhs_contracting_dimensions = array<i64 : 0 , 1 >} : (tensor <1 x 2 x f64 >, tensor <1 x 2 x f64 >) -> tensor <f64 >
4+ func.func @logpdf (%x : tensor <2 x f64 >) -> tensor <f64 > {
5+ %sum_sq = enzyme.dot %x , %x {lhs_batching_dimensions = array<i64 >, rhs_batching_dimensions = array<i64 >, lhs_contracting_dimensions = array<i64 : 0 >, rhs_contracting_dimensions = array<i64 : 0 >} : (tensor <2 x f64 >, tensor <2 x f64 >) -> tensor <f64 >
66 %neg_half = arith.constant dense <-5.000000e-01 > : tensor <f64 >
77 %result = arith.mulf %neg_half , %sum_sq : tensor <f64 >
88 return %result : tensor <f64 >
@@ -65,9 +65,9 @@ module {
6565 return %res#0 , %res#1 , %res#2 : tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >
6666 }
6767
68- func.func @shifted_logpdf (%x : tensor <1 x 2 x f64 >, %mu : tensor <1 x 2 x f64 >) -> tensor <f64 > {
69- %diff = arith.subf %x , %mu : tensor <1 x 2 x f64 >
70- %sum_sq = enzyme.dot %diff , %diff {lhs_batching_dimensions = array<i64 >, rhs_batching_dimensions = array<i64 >, lhs_contracting_dimensions = array<i64 : 0 , 1 >, rhs_contracting_dimensions = array<i64 : 0 , 1 >} : (tensor <1 x 2 x f64 >, tensor <1 x 2 x f64 >) -> tensor <f64 >
68+ func.func @shifted_logpdf (%x : tensor <2 x f64 >, %mu : tensor <2 x f64 >) -> tensor <f64 > {
69+ %diff = arith.subf %x , %mu : tensor <2 x f64 >
70+ %sum_sq = enzyme.dot %diff , %diff {lhs_batching_dimensions = array<i64 >, rhs_batching_dimensions = array<i64 >, lhs_contracting_dimensions = array<i64 : 0 >, rhs_contracting_dimensions = array<i64 : 0 >} : (tensor <2 x f64 >, tensor <2 x f64 >) -> tensor <f64 >
7171 %neg_half = arith.constant dense <-5.000000e-01 > : tensor <f64 >
7272 %result = arith.mulf %neg_half , %sum_sq : tensor <f64 >
7373 return %result : tensor <f64 >
@@ -80,12 +80,7 @@ module {
8080 // CHECK: func.call @shifted_logpdf
8181 // CHECK-NEXT: %[[NEG:.+]] = arith.negf
8282 // CHECK-NEXT: enzyme.yield
83- // CHECK: enzyme.for_loop
84- // CHECK: enzyme.autodiff_region
85- // CHECK: func.call @shifted_logpdf
86- // CHECK-NEXT: %{{.+}} = arith.negf
87- // CHECK-NEXT: enzyme.yield
88- func.func @nuts_shifted_logpdf (%rng : tensor <2 xui64 >, %mu : tensor <1 x2 xf64 >) -> (tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >) {
83+ func.func @nuts_shifted_logpdf (%rng : tensor <2 xui64 >, %mu : tensor <2 xf64 >) -> (tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >) {
8984 %init_pos = arith.constant dense <[[0.5 , -0.5 ]]> : tensor <1 x2 xf64 >
9085 %step_size = arith.constant dense <0.1 > : tensor <f64 >
9186 %res:8 = " enzyme.mcmc" (%rng , %mu , %step_size , %init_pos ) {
@@ -97,7 +92,7 @@ module {
9792 num_warmup = 0 ,
9893 num_samples = 1 ,
9994 operand_segment_sizes = array <i32 : 2 , 0 , 0 , 1 , 1 , 0 , 0 >
100- } : (tensor <2 xui64 >, tensor <1 x 2 x f64 >, tensor <f64 >, tensor <1 x2 xf64 >) -> (tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >, tensor <1 x2 xf64 >, tensor <1 x2 xf64 >, tensor <f64 >, tensor <f64 >, tensor <1 x2 xf64 >)
95+ } : (tensor <2 xui64 >, tensor <2 x f64 >, tensor <f64 >, tensor <1 x2 xf64 >) -> (tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >, tensor <1 x2 xf64 >, tensor <1 x2 xf64 >, tensor <f64 >, tensor <f64 >, tensor <1 x2 xf64 >)
10196 return %res#0 , %res#1 , %res#2 : tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >
10297 }
10398
@@ -108,12 +103,7 @@ module {
108103 // CHECK: func.call @shifted_logpdf
109104 // CHECK-NEXT: %{{.+}} = arith.negf
110105 // CHECK-NEXT: enzyme.yield
111- // CHECK: enzyme.for_loop
112- // CHECK: enzyme.autodiff_region
113- // CHECK: func.call @shifted_logpdf
114- // CHECK-NEXT: %{{.+}} = arith.negf
115- // CHECK-NEXT: enzyme.yield
116- func.func @hmc_shifted_logpdf (%rng : tensor <2 xui64 >, %mu : tensor <1 x2 xf64 >) -> (tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >) {
106+ func.func @hmc_shifted_logpdf (%rng : tensor <2 xui64 >, %mu : tensor <2 xf64 >) -> (tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >) {
117107 %init_pos = arith.constant dense <[[0.5 , -0.5 ]]> : tensor <1 x2 xf64 >
118108 %step_size = arith.constant dense <0.1 > : tensor <f64 >
119109 %res:8 = " enzyme.mcmc" (%rng , %mu , %step_size , %init_pos ) {
@@ -125,16 +115,16 @@ module {
125115 num_warmup = 0 ,
126116 num_samples = 1 ,
127117 operand_segment_sizes = array <i32 : 2 , 0 , 0 , 1 , 1 , 0 , 0 >
128- } : (tensor <2 xui64 >, tensor <1 x 2 x f64 >, tensor <f64 >, tensor <1 x2 xf64 >) -> (tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >, tensor <1 x2 xf64 >, tensor <1 x2 xf64 >, tensor <f64 >, tensor <f64 >, tensor <1 x2 xf64 >)
118+ } : (tensor <2 xui64 >, tensor <2 x f64 >, tensor <f64 >, tensor <1 x2 xf64 >) -> (tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >, tensor <1 x2 xf64 >, tensor <1 x2 xf64 >, tensor <f64 >, tensor <f64 >, tensor <1 x2 xf64 >)
129119 return %res#0 , %res#1 , %res#2 : tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >
130120 }
131121
132- func.func @anisotropic_logpdf (%x : tensor <1 x 2 x f64 >, %mu : tensor <1 x 2 x f64 >, %precision : tensor <1 x 2 x f64 >) -> tensor <f64 > {
133- %diff = arith.subf %x , %mu : tensor <1 x 2 x f64 >
134- %diff_sq = arith.mulf %diff , %diff : tensor <1 x 2 x f64 >
135- %weighted = arith.mulf %precision , %diff_sq : tensor <1 x 2 x f64 >
136- %ones = arith.constant dense <1.0 > : tensor <1 x 2 x f64 >
137- %sum = enzyme.dot %ones , %weighted {lhs_batching_dimensions = array<i64 >, rhs_batching_dimensions = array<i64 >, lhs_contracting_dimensions = array<i64 : 0 , 1 >, rhs_contracting_dimensions = array<i64 : 0 , 1 >} : (tensor <1 x 2 x f64 >, tensor <1 x 2 x f64 >) -> tensor <f64 >
122+ func.func @anisotropic_logpdf (%x : tensor <2 x f64 >, %mu : tensor <2 x f64 >, %precision : tensor <2 x f64 >) -> tensor <f64 > {
123+ %diff = arith.subf %x , %mu : tensor <2 x f64 >
124+ %diff_sq = arith.mulf %diff , %diff : tensor <2 x f64 >
125+ %weighted = arith.mulf %precision , %diff_sq : tensor <2 x f64 >
126+ %ones = arith.constant dense <1.0 > : tensor <2 x f64 >
127+ %sum = enzyme.dot %ones , %weighted {lhs_batching_dimensions = array<i64 >, rhs_batching_dimensions = array<i64 >, lhs_contracting_dimensions = array<i64 : 0 >, rhs_contracting_dimensions = array<i64 : 0 >} : (tensor <2 x f64 >, tensor <2 x f64 >) -> tensor <f64 >
138128 %neg_half = arith.constant dense <-5.000000e-01 > : tensor <f64 >
139129 %result = arith.mulf %neg_half , %sum : tensor <f64 >
140130 return %result : tensor <f64 >
@@ -147,12 +137,7 @@ module {
147137 // CHECK: func.call @anisotropic_logpdf
148138 // CHECK-NEXT: %[[NEG:.+]] = arith.negf
149139 // CHECK-NEXT: enzyme.yield
150- // CHECK: enzyme.for_loop
151- // CHECK: enzyme.autodiff_region
152- // CHECK: func.call @anisotropic_logpdf
153- // CHECK-NEXT: %{{.+}} = arith.negf
154- // CHECK-NEXT: enzyme.yield
155- func.func @nuts_anisotropic_logpdf (%rng : tensor <2 xui64 >, %mu : tensor <1 x2 xf64 >, %precision : tensor <1 x2 xf64 >) -> (tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >) {
140+ func.func @nuts_anisotropic_logpdf (%rng : tensor <2 xui64 >, %mu : tensor <2 xf64 >, %precision : tensor <2 xf64 >) -> (tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >) {
156141 %init_pos = arith.constant dense <[[0.5 , -0.5 ]]> : tensor <1 x2 xf64 >
157142 %step_size = arith.constant dense <0.1 > : tensor <f64 >
158143 %res:8 = " enzyme.mcmc" (%rng , %mu , %precision , %step_size , %init_pos ) {
@@ -164,7 +149,7 @@ module {
164149 num_warmup = 0 ,
165150 num_samples = 1 ,
166151 operand_segment_sizes = array <i32 : 3 , 0 , 0 , 1 , 1 , 0 , 0 >
167- } : (tensor <2 xui64 >, tensor <1 x 2 x f64 >, tensor <1 x 2 x f64 >, tensor <f64 >, tensor <1 x2 xf64 >) -> (tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >, tensor <1 x2 xf64 >, tensor <1 x2 xf64 >, tensor <f64 >, tensor <f64 >, tensor <1 x2 xf64 >)
152+ } : (tensor <2 xui64 >, tensor <2 x f64 >, tensor <2 x f64 >, tensor <f64 >, tensor <1 x2 xf64 >) -> (tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >, tensor <1 x2 xf64 >, tensor <1 x2 xf64 >, tensor <f64 >, tensor <f64 >, tensor <1 x2 xf64 >)
168153 return %res#0 , %res#1 , %res#2 : tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >
169154 }
170155
@@ -175,12 +160,7 @@ module {
175160 // CHECK: func.call @anisotropic_logpdf
176161 // CHECK-NEXT: %{{.+}} = arith.negf
177162 // CHECK-NEXT: enzyme.yield
178- // CHECK: enzyme.for_loop
179- // CHECK: enzyme.autodiff_region
180- // CHECK: func.call @anisotropic_logpdf
181- // CHECK-NEXT: %{{.+}} = arith.negf
182- // CHECK-NEXT: enzyme.yield
183- func.func @hmc_anisotropic_logpdf (%rng : tensor <2 xui64 >, %mu : tensor <1 x2 xf64 >, %precision : tensor <1 x2 xf64 >) -> (tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >) {
163+ func.func @hmc_anisotropic_logpdf (%rng : tensor <2 xui64 >, %mu : tensor <2 xf64 >, %precision : tensor <2 xf64 >) -> (tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >) {
184164 %init_pos = arith.constant dense <[[0.5 , -0.5 ]]> : tensor <1 x2 xf64 >
185165 %step_size = arith.constant dense <0.1 > : tensor <f64 >
186166 %res:8 = " enzyme.mcmc" (%rng , %mu , %precision , %step_size , %init_pos ) {
@@ -192,7 +172,7 @@ module {
192172 num_warmup = 0 ,
193173 num_samples = 1 ,
194174 operand_segment_sizes = array <i32 : 3 , 0 , 0 , 1 , 1 , 0 , 0 >
195- } : (tensor <2 xui64 >, tensor <1 x 2 x f64 >, tensor <1 x 2 x f64 >, tensor <f64 >, tensor <1 x2 xf64 >) -> (tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >, tensor <1 x2 xf64 >, tensor <1 x2 xf64 >, tensor <f64 >, tensor <f64 >, tensor <1 x2 xf64 >)
175+ } : (tensor <2 xui64 >, tensor <2 x f64 >, tensor <2 x f64 >, tensor <f64 >, tensor <1 x2 xf64 >) -> (tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >, tensor <1 x2 xf64 >, tensor <1 x2 xf64 >, tensor <f64 >, tensor <f64 >, tensor <1 x2 xf64 >)
196176 return %res#0 , %res#1 , %res#2 : tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >
197177 }
198178}
0 commit comments