@@ -37,10 +37,9 @@ struct BlockDiagonalLDLT {
3737 solve (const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs,
3838 ThreadPool *pool) const ;
3939
40- template <class _Scalar , int _Rows, int _Cols>
41- Eigen::Matrix<_Scalar, _Rows, _Cols>
42- sqrt_solve (const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs,
43- ThreadPool *pool) const ;
40+ template <typename Derived>
41+ Eigen::MatrixXd sqrt_solve (const Eigen::DenseBase<Derived> &rhs,
42+ ThreadPool *pool) const ;
4443
4544 BlockDiagonal sqrt_transpose () const ;
4645
@@ -51,6 +50,8 @@ struct BlockDiagonalLDLT {
5150 Eigen::Index rows () const ;
5251
5352 Eigen::Index cols () const ;
53+
54+ bool operator ==(const BlockDiagonalLDLT &other) const ;
5455};
5556
5657struct BlockDiagonal {
@@ -141,20 +142,23 @@ BlockDiagonalLDLT::solve(const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs,
141142 return output;
142143}
143144
144- template <class _Scalar , int _Rows, int _Cols >
145- inline Eigen::Matrix<_Scalar, _Rows, _Cols>
146- BlockDiagonalLDLT::sqrt_solve (const Eigen::Matrix<_Scalar, _Rows, _Cols > &rhs,
145+ template <typename Derived >
146+ inline Eigen::MatrixXd
147+ BlockDiagonalLDLT::sqrt_solve (const Eigen::DenseBase<Derived > &rhs,
147148 ThreadPool *pool) const {
148149 ALBATROSS_ASSERT (cols () == rhs.rows ());
149- Eigen::Matrix<_Scalar, _Rows, _Cols> output (rows (), rhs.cols ());
150+ Eigen::MatrixXd output (rows (), rhs.cols ());
150151
151152 auto solve_and_fill_one_block = [&](const size_t i, const Eigen::Index row) {
152- const auto rhs_chunk = rhs.block (row, 0 , blocks[i].rows (), rhs.cols ());
153+ const auto rhs_chunk =
154+ rhs.derived ().block (row, 0 , blocks[i].rows (), rhs.cols ());
153155 output.block (row, 0 , blocks[i].rows (), rhs.cols ()) =
154156 blocks[i].sqrt_solve (rhs_chunk);
155157 };
156158
157- apply_map (block_to_row_map (), solve_and_fill_one_block, pool);
159+ // Intentionally leaving pool out here due to an unknown bug
160+ // in which the thread pool version crashes in sqrt_solve.
161+ apply_map (block_to_row_map (), solve_and_fill_one_block);
158162 return output;
159163}
160164
@@ -182,6 +186,10 @@ inline Eigen::Index BlockDiagonalLDLT::cols() const {
182186 return n;
183187}
184188
189+ inline bool
190+ BlockDiagonalLDLT::operator ==(const BlockDiagonalLDLT &other) const {
191+ return blocks == other.blocks ;
192+ }
185193/*
186194 * Block Diagonal
187195 */
0 commit comments