Skip to content

Commit 00c89d6

Browse files
phdum-ahughcars
authored andcommitted
WIP: Add tests
1 parent 91a745f commit 00c89d6

File tree

4 files changed

+57
-6
lines changed

4 files changed

+57
-6
lines changed

palace/models/romoperator.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -189,15 +189,14 @@ inline void ProlongatePROMSolution(std::size_t n, const std::vector<Vector> &V,
189189

190190
MinimalRationalInterpolation::MinimalRationalInterpolation(std::size_t max_size)
191191
{
192+
z.reserve(max_size);
192193
Q.resize(max_size, ComplexVector());
193194
}
194195

195196
void MinimalRationalInterpolation::AddSolutionSample(double omega, const ComplexVector &u,
196-
const SpaceOperator &space_op,
197+
MPI_Comm comm,
197198
Orthogonalization orthog_type)
198199
{
199-
MPI_Comm comm = space_op.GetComm();
200-
201200
// Compute the coefficients for the minimal rational interpolation of the state u used
202201
// as an error indicator. The complex-valued snapshot matrix U = [{u_i, (iω) u_i}] is
203202
// stored by its QR decomposition.
@@ -479,7 +478,7 @@ void RomOperator::UpdatePROM(const ComplexVector &u, std::string_view node_label
479478
void RomOperator::UpdateMRI(int excitation_idx, double omega, const ComplexVector &u)
480479
{
481480
BlockTimer bt(Timer::CONSTRUCT_PROM);
482-
mri.at(excitation_idx).AddSolutionSample(omega, u, space_op, orthog_type);
481+
mri.at(excitation_idx).AddSolutionSample(omega, u, space_op.GetComm(), orthog_type);
483482
}
484483

485484
void RomOperator::SolvePROM(int excitation_idx, double omega, ComplexVector &u)

palace/models/romoperator.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ class MinimalRationalInterpolation
3838

3939
public:
4040
MinimalRationalInterpolation(std::size_t max_size);
41-
void AddSolutionSample(double omega, const ComplexVector &u,
42-
const SpaceOperator &space_op, Orthogonalization orthog_type);
41+
void AddSolutionSample(double omega, const ComplexVector &u, MPI_Comm comm,
42+
Orthogonalization orthog_type);
4343
std::vector<double> FindMaxError(int N) const;
4444

4545
const auto &GetSamplePoints() const { return z; }

test/unit/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ add_executable(unit-tests
2121
${CMAKE_CURRENT_SOURCE_DIR}/test-libceed.cpp
2222
${CMAKE_CURRENT_SOURCE_DIR}/test-postoperator.cpp
2323
${CMAKE_CURRENT_SOURCE_DIR}/test-postoperatorcsv.cpp
24+
${CMAKE_CURRENT_SOURCE_DIR}/test-romoperator.cpp
2425
${CMAKE_CURRENT_SOURCE_DIR}/test-tablecsv.cpp
2526
)
2627
target_link_libraries(unit-tests PRIVATE ${LIB_TARGET_NAME} Catch2::Catch2)

test/unit/test-romoperator.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#include <iterator>
2+
#include <fmt/format.h>
3+
#include <catch2/catch_approx.hpp>
4+
#include <catch2/catch_test_macros.hpp>
5+
#include <catch2/benchmark/catch_benchmark_all.hpp>
6+
#include <catch2/generators/catch_generators_all.hpp>
7+
#include <catch2/matchers/catch_matchers_floating_point.hpp>
8+
#include "models/romoperator.hpp"
9+
10+
using namespace palace;
11+
12+
// auto complex_circle_sample_points(int nr_sample_points, double radius = 5.5)
13+
// {
14+
// double end_point_linscale = double(nr_sample_points - 1) / nr_sample_points;
15+
// Eigen::ArrayXcd zj_sample =
16+
// Eigen::ArrayXcd::LinSpaced(nr_sample_points, 0, 2 * M_PI * end_point_linscale);
17+
// zj_sample =
18+
// zj_sample.unaryExpr([radius](std::complex<double> z)
19+
// { return radius * std::exp(std::complex<double>(0., 1.) * z);
20+
// });
21+
// return zj_sample;
22+
// }
23+
24+
TEST_CASE("MinimalRationalInterpolation", "[romoperator]")
25+
{
26+
auto *comm = Mpi::World();
27+
28+
auto fn_tan_shift = [](double z)
29+
{ return std::tan(0.5 * M_PI * (z - std::complex<double>(1., 1.))); };
30+
31+
// Test scalar case: 2 sample points for 2 x 2 vector
32+
MinimalRationalInterpolation mri_1(2);
33+
34+
CHECK(mri_1.GetSamplePoints() == std::vector<double>{});
35+
CHECK_THROWS(mri_1.FindMaxError(1));
36+
37+
for (double x_sample : {-1.0, 1.0})
38+
{
39+
ComplexVector c_vec(1);
40+
c_vec = fn_tan_shift(x_sample) / double(Mpi::Size(comm));
41+
42+
mri_1.AddSolutionSample(x_sample, c_vec, comm, Orthogonalization::MGS);
43+
}
44+
45+
CHECK(mri_1.GetSamplePoints().size() == 2);
46+
CHECK(mri_1.GetSamplePoints() == std::vector<double>{-1.0, 1.0});
47+
// By symmetry of poles max erro should be at zero.
48+
auto max_err_1 = mri_1.FindMaxError(1);
49+
REQUIRE(max_err_1.size() == 1);
50+
CHECK_THAT(max_err_1[0], Catch::Matchers::WithinAbsMatcher(0., 1e-6));
51+
}

0 commit comments

Comments
 (0)