Skip to content

Commit

Permalink
simplify autodiff - less generic annotations (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
strasdat authored Feb 2, 2025
1 parent 2fabf71 commit 7f185c7
Show file tree
Hide file tree
Showing 23 changed files with 352 additions and 738 deletions.
19 changes: 19 additions & 0 deletions crates/sophus_autodiff/src/dual/dual_batch_matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,25 @@ where
}
}

/// Scalar-field derivative for batch dual matrices whose infinitesimal
/// part is 1x1 (differentiation with respect to a single real scalar).
impl<const ROWS: usize, const COLS: usize, const BATCH: usize>
    IsScalarFieldDualMatrix<DualBatchScalar<BATCH, 1, 1>, ROWS, COLS, BATCH>
    for DualBatchMatrix<ROWS, COLS, BATCH, 1, 1>
where
    BatchScalarF64<BATCH>: IsCoreScalar,
    LaneCount<BATCH>: SupportedLaneCount,
{
    /// Collects the (0, 0) entry of every element's derivative into a
    /// real ROWS x COLS batch matrix.
    fn scalarfield_derivative(&self) -> BatchMatF64<ROWS, COLS, BATCH> {
        let mut derivative = BatchMatF64::<ROWS, COLS, BATCH>::zeros();
        for row in 0..ROWS {
            for col in 0..COLS {
                // Each element's derivative is 1x1 here, so (0, 0) is the
                // whole derivative of that element.
                let entry = self.inner[(row, col)].derivative()[(0, 0)];
                derivative.set_elem([row, col], entry);
            }
        }
        derivative
    }
}

impl<
const ROWS: usize,
const COLS: usize,
Expand Down
16 changes: 16 additions & 0 deletions crates/sophus_autodiff/src/dual/dual_matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use num_traits::Zero;
use super::matrix::MatrixValuedDerivative;
use crate::{
dual::{
matrix::IsScalarFieldDualMatrix,
DualScalar,
DualVector,
},
Expand Down Expand Up @@ -78,6 +79,21 @@ impl<const ROWS: usize, const COLS: usize, const DM: usize, const DN: usize>
}
}

/// Scalar-field derivative for (non-batch) dual matrices whose
/// infinitesimal part is 1x1.
impl<const ROWS: usize, const COLS: usize> IsScalarFieldDualMatrix<DualScalar<1, 1>, ROWS, COLS, 1>
    for DualMatrix<ROWS, COLS, 1, 1>
{
    /// Collects the (0, 0) entry of every element's derivative into a
    /// real ROWS x COLS matrix.
    fn scalarfield_derivative(&self) -> MatF64<ROWS, COLS> {
        let mut derivative = MatF64::<ROWS, COLS>::zeros();
        for row in 0..ROWS {
            for col in 0..COLS {
                // Each element's derivative is 1x1, so (0, 0) is the whole
                // derivative of that element.
                let entry = self.inner[(row, col)].derivative()[(0, 0)];
                derivative.set_elem([row, col], entry);
            }
        }
        derivative
    }
}

impl<const ROWS: usize, const COLS: usize, const DM: usize, const DN: usize> PartialEq
for DualMatrix<ROWS, COLS, DM, DN>
{
Expand Down
112 changes: 57 additions & 55 deletions crates/sophus_autodiff/src/dual/matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,24 @@ pub trait IsDualMatrix<
>: IsMatrix<S, ROWS, COLS, BATCH, DM, DN>
{
/// Create a new dual matrix from a real matrix for auto-differentiation with respect to self
///
/// Typically this is not called directly, but through using a map auto-differentiation call:
///
/// - ScalarValueMatrixMap::fw_autodiff(...);
/// - VectorValuedMatrixMap::fw_autodiff(...);
/// - MatrixValuedMatrixMap::fw_autodiff(...);
fn var(val: S::RealMatrix<ROWS, COLS>) -> Self;

/// Get the derivative
fn derivative(self) -> MatrixValuedDerivative<S::RealScalar, ROWS, COLS, BATCH, DM, DN>;
}

/// Trait for dual matrices with a 1x1 (scalar) infinitesimal part, i.e.
/// matrices of dual numbers differentiated with respect to a single real
/// scalar variable.
pub trait IsScalarFieldDualMatrix<
    S: IsDualScalar<BATCH, 1, 1>,
    const ROWS: usize,
    const COLS: usize,
    const BATCH: usize,
>: IsDualMatrix<S, ROWS, COLS, BATCH, 1, 1>
{
    /// Returns the element-wise derivative with respect to the scalar
    /// variable, as a real ROWS x COLS matrix.
    fn scalarfield_derivative(&self) -> S::RealMatrix<ROWS, COLS>;
}

#[test]
fn dual_matrix_tests() {
#[cfg(feature = "simd")]
Expand All @@ -84,16 +90,18 @@ fn dual_matrix_tests() {
#[cfg(test)]
impl Test for $scalar {
fn run() {
let m_2x4 = <$scalar as IsScalar<$batch,0,0>>::Matrix::<2, 4>::from_f64_array2([
[1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0],
]);
let m_4x1 = <$scalar as IsScalar<$batch,0,0>>::Matrix::<4, 1>::from_f64_array2([
[1.0],
[2.0],
[3.0],
[4.0],
]);
let m_2x4 =
<$scalar as IsScalar<$batch, 0, 0>>::Matrix::<2, 4>::from_f64_array2([
[1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0],
]);
let m_4x1 =
<$scalar as IsScalar<$batch, 0, 0>>::Matrix::<4, 1>::from_f64_array2([
[1.0],
[2.0],
[3.0],
[4.0],
]);

fn mat_mul_fn<
S: IsScalar<BATCH, DM, DN>,
Expand All @@ -106,21 +114,18 @@ fn dual_matrix_tests() {
) -> S::Matrix<2, 1> {
x.mat_mul(y)
}
let finite_diff =
MatrixValuedMatrixMap::<$scalar, $batch,0,0>::sym_diff_quotient(
|x| mat_mul_fn::<$scalar, $batch,0,0>(x, m_4x1),
m_2x4,
EPS_F64,
);
let auto_grad = MatrixValuedMatrixMap::<$dual_scalar_2_4, $batch, 2, 4>::fw_autodiff(
|x| {
mat_mul_fn::<$dual_scalar_2_4, $batch, 2, 4>(
x,
<$dual_scalar_2_4 as IsScalar<$batch, 2, 4>>::Matrix::from_real_matrix(m_4x1),
)
},
let finite_diff = MatrixValuedMatrixMap::<$scalar, $batch>::sym_diff_quotient(
|x| mat_mul_fn::<$scalar, $batch, 0, 0>(x, m_4x1),
m_2x4,
EPS_F64,
);
let auto_grad = mat_mul_fn::<$dual_scalar_2_4, $batch, 2, 4>(
<$dual_scalar_2_4>::matrix_var(m_2x4),
<$dual_scalar_2_4 as IsScalar<$batch, 2, 4>>::Matrix::from_real_matrix(
m_4x1,
),
)
.derivative();

for i in 0..2 {
for j in 0..1 {
Expand All @@ -133,19 +138,17 @@ fn dual_matrix_tests() {
}

let finite_diff = MatrixValuedMatrixMap::sym_diff_quotient(
|x| mat_mul_fn::<$scalar, $batch,0,0>(m_2x4, x),
|x| mat_mul_fn::<$scalar, $batch, 0, 0>(m_2x4, x),
m_4x1,
EPS_F64,
);
let auto_grad = MatrixValuedMatrixMap::<$dual_scalar_4_1, $batch, 4, 1>::fw_autodiff(
|x| {
mat_mul_fn::<$dual_scalar_4_1, $batch, 4, 1>(
<$dual_scalar_4_1 as IsScalar<$batch, 4, 1>>::Matrix::from_real_matrix(m_2x4),
x,
)
},
m_4x1,
);
let auto_grad = mat_mul_fn::<$dual_scalar_4_1, $batch, 4, 1>(
<$dual_scalar_4_1 as IsScalar<$batch, 4, 1>>::Matrix::from_real_matrix(
m_2x4,
),
<$dual_scalar_4_1>::matrix_var(m_4x1),
)
.derivative();

for i in 0..2 {
for j in 0..1 {
Expand All @@ -168,23 +171,23 @@ fn dual_matrix_tests() {
x.mat_mul(&x)
}

let m_4x4 = <$scalar as IsScalar<$batch,0,0>>::Matrix::<4, 4>::from_f64_array2([
[1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0],
[1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0],
]);

let finite_diff =
MatrixValuedMatrixMap::<$scalar, $batch,0,0>::sym_diff_quotient(
mat_mul2_fn::<$scalar, $batch,0,0>,
m_4x4,
EPS_F64,
);
let auto_grad = MatrixValuedMatrixMap::<$dual_scalar_4_4, $batch,4, 4>::fw_autodiff(
mat_mul2_fn::<$dual_scalar_4_4, $batch, 4 ,4>,
let m_4x4 =
<$scalar as IsScalar<$batch, 0, 0>>::Matrix::<4, 4>::from_f64_array2([
[1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0],
[1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0],
]);

let finite_diff = MatrixValuedMatrixMap::<$scalar, $batch>::sym_diff_quotient(
mat_mul2_fn::<$scalar, $batch, 0, 0>,
m_4x4,
EPS_F64,
);
let auto_grad = mat_mul2_fn::<$dual_scalar_4_4, $batch, 4, 4>(
<$dual_scalar_4_4>::matrix_var(m_4x4),
)
.derivative();

for i in 0..2 {
for j in 0..4 {
Expand All @@ -195,7 +198,6 @@ fn dual_matrix_tests() {
);
}
}

}
}
};
Expand Down
Loading

0 comments on commit 7f185c7

Please sign in to comment.