Skip to content

Commit b2cd498

Browse files
committed
Update codegen and pretty tests
UI tests are pending, will depend on error messages change.
1 parent 3ab490b commit b2cd498

11 files changed

+39
-39
lines changed

tests/codegen/autodiff/batched.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111

1212
#![feature(autodiff)]
1313

14-
use std::autodiff::autodiff;
14+
use std::autodiff::autodiff_forward;
1515

16-
#[autodiff(d_square3, Forward, Dual, DualOnly)]
17-
#[autodiff(d_square2, Forward, 4, Dual, DualOnly)]
18-
#[autodiff(d_square1, Forward, 4, Dual, Dual)]
16+
#[autodiff_forward(d_square3, Dual, DualOnly)]
17+
#[autodiff_forward(d_square2, 4, Dual, DualOnly)]
18+
#[autodiff_forward(d_square1, 4, Dual, Dual)]
1919
#[no_mangle]
2020
fn square(x: &f32) -> f32 {
2121
x * x

tests/codegen/autodiff/identical_fnc.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@
1111
// identical function calls in the LLVM-IR, while having two different calls in the Rust code.
1212
#![feature(autodiff)]
1313

14-
use std::autodiff::autodiff;
14+
use std::autodiff::autodiff_reverse;
1515

16-
#[autodiff(d_square, Reverse, Duplicated, Active)]
16+
#[autodiff_reverse(d_square, Duplicated, Active)]
1717
fn square(x: &f64) -> f64 {
1818
x * x
1919
}
2020

21-
#[autodiff(d_square2, Reverse, Duplicated, Active)]
21+
#[autodiff_reverse(d_square2, Duplicated, Active)]
2222
fn square2(x: &f64) -> f64 {
2323
x * x
2424
}

tests/codegen/autodiff/inline.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44

55
#![feature(autodiff)]
66

7-
use std::autodiff::autodiff;
7+
use std::autodiff::autodiff_reverse;
88

9-
#[autodiff(d_square, Reverse, Duplicated, Active)]
9+
#[autodiff_reverse(d_square, Duplicated, Active)]
1010
fn square(x: &f64) -> f64 {
1111
x * x
1212
}

tests/codegen/autodiff/scalar.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
//@ needs-enzyme
44
#![feature(autodiff)]
55

6-
use std::autodiff::autodiff;
6+
use std::autodiff::autodiff_reverse;
77

8-
#[autodiff(d_square, Reverse, Duplicated, Active)]
8+
#[autodiff_reverse(d_square, Duplicated, Active)]
99
#[no_mangle]
1010
fn square(x: &f64) -> f64 {
1111
x * x

tests/codegen/autodiff/sret.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99

1010
#![feature(autodiff)]
1111

12-
use std::autodiff::autodiff;
12+
use std::autodiff::autodiff_reverse;
1313

1414
#[no_mangle]
15-
#[autodiff(df, Reverse, Active, Active, Active)]
15+
#[autodiff_reverse(df, Active, Active, Active)]
1616
fn primal(x: f32, y: f32) -> f64 {
1717
(x * x * y) as f64
1818
}

tests/pretty/autodiff/autodiff_forward.pp

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
// Test that forward mode ad macros are expanded correctly.
1515

16-
use std::autodiff::autodiff;
16+
use std::autodiff::{autodiff_forward, autodiff_reverse};
1717

1818
#[rustc_autodiff]
1919
#[inline(never)]

tests/pretty/autodiff/autodiff_forward.rs

+15-15
Original file line numberDiff line numberDiff line change
@@ -7,57 +7,57 @@
77

88
// Test that forward mode ad macros are expanded correctly.
99

10-
use std::autodiff::autodiff;
10+
use std::autodiff::{autodiff_forward, autodiff_reverse};
1111

12-
#[autodiff(df1, Forward, Dual, Const, Dual)]
12+
#[autodiff_forward(df1, Dual, Const, Dual)]
1313
pub fn f1(x: &[f64], y: f64) -> f64 {
1414
unimplemented!()
1515
}
1616

17-
#[autodiff(df2, Forward, Dual, Const, Const)]
17+
#[autodiff_forward(df2, Dual, Const, Const)]
1818
pub fn f2(x: &[f64], y: f64) -> f64 {
1919
unimplemented!()
2020
}
2121

22-
#[autodiff(df3, Forward, Dual, Const, Const)]
22+
#[autodiff_forward(df3, Dual, Const, Const)]
2323
pub fn f3(x: &[f64], y: f64) -> f64 {
2424
unimplemented!()
2525
}
2626

2727
// Not the most interesting derivative, but who are we to judge
28-
#[autodiff(df4, Forward)]
28+
#[autodiff_forward(df4)]
2929
pub fn f4() {}
3030

3131
// We want to be sure that the same function can be differentiated in different ways
32-
#[autodiff(df5_rev, Reverse, Duplicated, Const, Active)]
33-
#[autodiff(df5_x, Forward, Dual, Const, Const)]
34-
#[autodiff(df5_y, Forward, Const, Dual, Const)]
32+
#[autodiff_reverse(df5_rev, Duplicated, Const, Active)]
33+
#[autodiff_forward(df5_x, Dual, Const, Const)]
34+
#[autodiff_forward(df5_y, Const, Dual, Const)]
3535
pub fn f5(x: &[f64], y: f64) -> f64 {
3636
unimplemented!()
3737
}
3838

3939
struct DoesNotImplDefault;
40-
#[autodiff(df6, Forward, Const)]
40+
#[autodiff_forward(df6, Const)]
4141
pub fn f6() -> DoesNotImplDefault {
4242
unimplemented!()
4343
}
4444

4545
// Make sure, that we add the None for the default return.
46-
#[autodiff(df7, Forward, Const)]
46+
#[autodiff_forward(df7, Const)]
4747
pub fn f7(x: f32) -> () {}
4848

49-
#[autodiff(f8_1, Forward, Dual, DualOnly)]
50-
#[autodiff(f8_2, Forward, 4, Dual, DualOnly)]
51-
#[autodiff(f8_3, Forward, 4, Dual, Dual)]
49+
#[autodiff_forward(f8_1, Dual, DualOnly)]
50+
#[autodiff_forward(f8_2, 4, Dual, DualOnly)]
51+
#[autodiff_forward(f8_3, 4, Dual, Dual)]
5252
#[no_mangle]
5353
fn f8(x: &f32) -> f32 {
5454
unimplemented!()
5555
}
5656

5757
// We want to make sure that we can use the macro for functions defined inside of functions
5858
pub fn f9() {
59-
#[autodiff(d_inner_1, Forward, Dual, DualOnly)]
60-
#[autodiff(d_inner_2, Forward, Dual, Dual)]
59+
#[autodiff_forward(d_inner_1, Dual, DualOnly)]
60+
#[autodiff_forward(d_inner_2, Dual, Dual)]
6161
fn inner(x: f32) -> f32 {
6262
x * x
6363
}

tests/pretty/autodiff/autodiff_reverse.pp

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
// Test that reverse mode ad macros are expanded correctly.
1515

16-
use std::autodiff::autodiff;
16+
use std::autodiff::autodiff_reverse;
1717

1818
#[rustc_autodiff]
1919
#[inline(never)]

tests/pretty/autodiff/autodiff_reverse.rs

+6-6
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,18 @@
77

88
// Test that reverse mode ad macros are expanded correctly.
99

10-
use std::autodiff::autodiff;
10+
use std::autodiff::autodiff_reverse;
1111

12-
#[autodiff(df1, Reverse, Duplicated, Const, Active)]
12+
#[autodiff_reverse(df1, Duplicated, Const, Active)]
1313
pub fn f1(x: &[f64], y: f64) -> f64 {
1414
unimplemented!()
1515
}
1616

1717
// Not the most interesting derivative, but who are we to judge
18-
#[autodiff(df2, Reverse)]
18+
#[autodiff_reverse(df2)]
1919
pub fn f2() {}
2020

21-
#[autodiff(df3, Reverse, Duplicated, Const, Active)]
21+
#[autodiff_reverse(df3, Duplicated, Const, Active)]
2222
pub fn f3(x: &[f64], y: f64) -> f64 {
2323
unimplemented!()
2424
}
@@ -27,12 +27,12 @@ enum Foo { Reverse }
2727
use Foo::Reverse;
2828
// What happens if we already have Reverse in type (enum variant decl) and value (enum variant
2929
// constructor) namespace? > It's expected to work normally.
30-
#[autodiff(df4, Reverse, Const)]
30+
#[autodiff_reverse(df4, Const)]
3131
pub fn f4(x: f32) {
3232
unimplemented!()
3333
}
3434

35-
#[autodiff(df5, Reverse, DuplicatedOnly, Duplicated)]
35+
#[autodiff_reverse(df5, DuplicatedOnly, Duplicated)]
3636
pub fn f5(x: *const f32, y: &f32) {
3737
unimplemented!()
3838
}

tests/pretty/autodiff/inherent_impl.pp

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
//@ pretty-compare-only
1212
//@ pp-exact:inherent_impl.pp
1313

14-
use std::autodiff::autodiff;
14+
use std::autodiff::autodiff_reverse;
1515

1616
struct Foo {
1717
a: f64,

tests/pretty/autodiff/inherent_impl.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
//@ pretty-compare-only
66
//@ pp-exact:inherent_impl.pp
77

8-
use std::autodiff::autodiff;
8+
use std::autodiff::autodiff_reverse;
99

1010
struct Foo {
1111
a: f64,
@@ -17,7 +17,7 @@ trait MyTrait {
1717
}
1818

1919
impl MyTrait for Foo {
20-
#[autodiff(df, Reverse, Const, Active, Active)]
20+
#[autodiff_reverse(df, Const, Active, Active)]
2121
fn f(&self, x: f64) -> f64 {
2222
self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln())
2323
}

0 commit comments

Comments
 (0)