Skip to content

Commit 06c680d

Browse files
committed
Better MRI multiple max error
- Replace manual vector search and insert with stl functions
1 parent 474a54c commit 06c680d

File tree

3 files changed

+32
-27
lines changed

3 files changed

+32
-27
lines changed

palace/models/romoperator.cpp

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ void MinimalRationalInterpolation::AddSolutionSample(double omega, const Complex
222222
z.push_back(omega);
223223
}
224224

225-
std::vector<double> MinimalRationalInterpolation::FindMaxError(int N) const
225+
std::vector<double> MinimalRationalInterpolation::FindMaxError(std::size_t N) const
226226
{
227227
// Return an estimate for argmax_z ||u(z) - V y(z)|| as argmin_z |Q(z)| with Q(z) =
228228
// sum_i q_z / (z - z_i) (denominator of the barycentric interpolation of u). The roots of
@@ -279,36 +279,40 @@ std::vector<double> MinimalRationalInterpolation::FindMaxError(int N) const
279279
// }
280280

281281
// Fall back to sampling Q on discrete points if no roots exist in [start, end].
282-
if (std::abs(z_star[0]) == 0.0)
282+
// TODO: currently we always us this. Consider other optimization above again.
283+
284+
// We could use priority queue here to keep the N lowest values. However, we don't use
285+
// std::priority_queue class since we want to have access to the vector and also binary
286+
// tree structure of heap class as rebalancing is excessive overhead for tiny size N.
287+
using q_t = std::pair<std::complex<double>, double>;
288+
std::vector<q_t> queue{};
289+
queue.reserve(N);
290+
291+
const std::size_t nr_sample = 1.0e6; // must be >= N
292+
const auto delta = (end - start) / nr_sample;
293+
for (double z_sample = start; z_sample <= end; z_sample += delta)
283294
{
284-
const auto delta = (end - start) / 1.0e6;
285-
std::vector<double> Q_star(N, mfem::infinity());
286-
while (start <= end)
295+
const double Q_sample = std::abs((q.array() / (z_map.array() - z_sample)).sum());
296+
297+
bool partial_full = (queue.size() < N);
298+
if (partial_full || Q_sample < queue.back().second)
287299
{
288-
const double Q = std::abs((q.array() / (z_map.array() - start)).sum());
289-
for (int i = 0; i < N; i++)
300+
auto it_loc = std::upper_bound(queue.begin(), queue.end(), Q_sample,
301+
[](double q, const q_t &p2) { return q < p2.second; });
302+
queue.insert(it_loc, std::make_pair(z_sample, Q_sample));
303+
if (!partial_full)
290304
{
291-
if (Q < Q_star[i])
292-
{
293-
for (int j = N - 1; j > i; j--)
294-
{
295-
z_star[j] = z_star[j - 1];
296-
Q_star[j] = Q_star[j - 1];
297-
}
298-
z_star[i] = start;
299-
Q_star[i] = Q;
300-
break;
301-
}
305+
queue.pop_back();
302306
}
303-
start += delta;
304307
}
305-
MFEM_VERIFY(
306-
N == 0 || std::abs(z_star[0]) > 0.0,
307-
fmt::format("Could not locate a maximum error in the range [{}, {}]!", start, end));
308308
}
309+
MFEM_VERIFY(queue.size() == N,
310+
fmt::format("Internal failure: queue should be size should be N={} (got {})",
311+
N, queue.size()));
312+
309313
std::vector<double> vals(z_star.size());
310-
std::transform(z_star.begin(), z_star.end(), vals.begin(),
311-
[](std::complex<double> z) { return std::real(z); });
314+
std::transform(queue.begin(), queue.end(), vals.begin(),
315+
[](const q_t &p) { return p.first.real(); });
312316
return vals;
313317
}
314318

palace/models/romoperator.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class MinimalRationalInterpolation
4040
MinimalRationalInterpolation(std::size_t max_size);
4141
void AddSolutionSample(double omega, const ComplexVector &u, MPI_Comm comm,
4242
Orthogonalization orthog_type);
43-
std::vector<double> FindMaxError(int N) const;
43+
std::vector<double> FindMaxError(std::size_t N) const;
4444

4545
const auto &GetSamplePoints() const { return z; }
4646
};
@@ -132,7 +132,7 @@ class RomOperator
132132

133133
// Compute the location(s) of the maximum error in the range of the previously sampled
134134
// parameter points.
135-
std::vector<double> FindMaxError(int excitation_idx, int N = 1) const
135+
std::vector<double> FindMaxError(int excitation_idx, std::size_t N = 1) const
136136
{
137137
return mri.at(excitation_idx).FindMaxError(N);
138138
}

test/unit/test-romoperator.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ TEST_CASE("MinimalRationalInterpolation", "[romoperator]")
4646
// By symmetry highest error should be at at zero.
4747
CHECK_THAT(max_err_1[0], Catch::Matchers::WithinAbsMatcher(0.0, 1e-6));
4848

49-
// Test that elements of max_error are unique
49+
// Test that elements of max_error are unique.
50+
// TODO: get better test for multiple N.
5051
std::sort(max_err_1.begin(), max_err_1.end());
5152
CHECK(std::adjacent_find(max_err_1.begin(), max_err_1.end()) == max_err_1.end());
5253

0 commit comments

Comments
 (0)