Skip to content

Commit

Permalink
Various simplifications in cppdlr/utils.hpp
Browse files Browse the repository at this point in the history
  • Loading branch information
Wentzell committed Jan 30, 2025
1 parent e59f710 commit 14263a6
Showing 1 changed file with 48 additions and 89 deletions.
137 changes: 48 additions & 89 deletions c++/cppdlr/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,15 @@
#include <nda/nda.hpp>
#include <nda/blas.hpp>


namespace cppdlr {
using dcomplex = std::complex<double>;

/**
* Calculate the squared norm of a vector
*
* @param v The input vector
* @return x The squared norm of the vector
*/
double normsq(nda::MemoryVector auto const &v) { return nda::real(nda::blas::dotc(v, v)); }

/**
* Class constructor for barycheb: barycentric Lagrange interpolation at
Expand Down Expand Up @@ -116,10 +122,10 @@ namespace cppdlr {
// Compute norms of rows of input matrix, and rescale eps tolerance
auto norms = nda::vector<double>(m);
double epssq = eps * eps;
for (int j = 0; j < m; ++j) { norms(j) = nda::real(nda::blas::dotc(aa(j, _), aa(j, _))); }
for (int j = 0; j < m; ++j) { norms(j) = normsq(aa(j, _)); }

// Begin pivoted double Gram-Schmidt procedure
int jpiv = 0, jj = 0;
int jpiv = 0;
double nrm = 0;
auto piv = nda::arange(m);
auto tmp = nda::vector<S>(n);
Expand All @@ -137,38 +143,29 @@ namespace cppdlr {
}

// Swap current row with chosen pivot row
tmp = aa(j, _);
aa(j, _) = aa(jpiv, _);
aa(jpiv, _) = tmp;

nrm = norms(j);
norms(j) = norms(jpiv);
norms(jpiv) = nrm;

jj = piv(j);
piv(j) = piv(jpiv);
piv(jpiv) = jj;
deep_swap(aa(j, _), aa(jpiv, _));
std::swap(norms(j), norms(jpiv));
std::swap(piv(j), piv(jpiv));

// Orthogonalize current rows (now the chosen pivot row) against all
// previously chosen rows
for (int k = 0; k < j; ++k) { aa(j, _) = aa(j, _) - aa(k, _) * nda::blas::dotc(aa(k, _), aa(j, _)); }

// Get norm of current row
nrm = nda::real(nda::blas::dotc(aa(j, _), aa(j, _)));
//nrm = nda::norm(aa(j, _));
nrm = normsq(aa(j, _));

// Terminate if sufficiently small, and return previously selected rows
// (not including current row)
if (nrm <= epssq) { return {aa(nda::range(0, j), _), norms(nda::range(0, j)), piv(nda::range(0, j))}; };

// Normalize current row
aa(j, _) = aa(j, _) * (1 / sqrt(nrm));
aa(j, _) /= sqrt(nrm);

// Orthogonalize remaining rows against current row
for (int k = j + 1; k < m; ++k) {
if (norms(k) <= epssq) { continue; } // Can skip rows with norm less than tolerance
aa(k, _) = aa(k, _) - aa(j, _) * nda::blas::dotc(aa(j, _), aa(k, _));
norms(k) = nda::real(nda::blas::dotc(aa(k, _), aa(k, _)));
norms(k) = normsq(aa(k, _));
}
}

Expand Down Expand Up @@ -211,22 +208,21 @@ namespace cppdlr {
if (m % 2 != 0) { throw std::runtime_error("Input matrix must have even number of rows."); }

// Copy input data, re-ordering rows to make symmetric rows adjacent.
auto aa = typename T::regular_type(m, n);
auto aa = typename T::regular_type(m, n);
aa(nda::range(0, m, 2), _) = a(nda::range(0, m / 2), _);
aa(nda::range(1, m, 2), _) = a(nda::range(m - 1, m / 2 - 1, -1), _);

// Compute norms of rows of input matrix, and rescale eps tolerance
auto norms = nda::vector<double>(m);
double epssq = eps * eps;
for (int j = 0; j < m; ++j) { norms(j) = nda::real(nda::blas::dotc(aa(j, _), aa(j, _))); }
for (int j = 0; j < m; ++j) { norms(j) = normsq(aa(j, _)); }

// Begin pivoted double Gram-Schmidt procedure
int jpiv = 0, jj = 0;
double nrm = 0;
auto piv = nda::arange(0, m);
int jpiv = 0;
double nrm = 0;
auto piv = nda::arange(0, m);
piv(nda::range(0, m, 2)) = nda::arange(0, m / 2); // Re-order pivots to match re-ordered input matrix
piv(nda::range(1, m, 2)) = nda::arange(m - 1, m / 2 - 1, -1);
auto tmp = nda::vector<S>(n);

if (maxrnk % 2 != 0) { // If n < m and n is odd, decrease maxrnk to maintain symmetry
maxrnk -= 1;
Expand All @@ -245,61 +241,46 @@ namespace cppdlr {
}

// Swap current row pair with chosen pivot row pair
tmp = aa(j, _);
aa(j, _) = aa(jpiv, _);
aa(jpiv, _) = tmp;
tmp = aa(j + 1, _);
aa(j + 1, _) = aa(jpiv + 1, _);
aa(jpiv + 1, _) = tmp;

nrm = norms(j);
norms(j) = norms(jpiv);
norms(jpiv) = nrm;
nrm = norms(j + 1);
norms(j + 1) = norms(jpiv + 1);
norms(jpiv + 1) = nrm;

jj = piv(j);
piv(j) = piv(jpiv);
piv(jpiv) = jj;
jj = piv(j + 1);
piv(j + 1) = piv(jpiv + 1);
piv(jpiv + 1) = jj;
deep_swap(aa(j, _), aa(jpiv, _));
deep_swap(aa(j + 1, _), aa(jpiv + 1, _));
std::swap(norms(j), norms(jpiv));
std::swap(norms(j + 1), norms(jpiv + 1));
std::swap(piv(j), piv(jpiv));
std::swap(piv(j + 1), piv(jpiv + 1));

// Orthogonalize current row (now the first chosen pivot row) against all
// previously chosen rows
for (int k = 0; k < j; ++k) { aa(j, _) = aa(j, _) - aa(k, _) * nda::blas::dotc(aa(k, _), aa(j, _)); }

// Get norm of current row
nrm = nda::real(nda::blas::dotc(aa(j, _), aa(j, _)));
nrm = normsq(aa(j, _));

// Terminate if sufficiently small, and return previously selected rows
// (not including current row)
if (nrm <= epssq) { return {aa(nda::range(0, j), _), norms(nda::range(0, j)), piv(nda::range(0, j))}; };

// Normalize current row
aa(j, _) = aa(j, _) * (1 / sqrt(nrm));
aa(j, _) /= sqrt(nrm);

// Orthogonalize remaining rows against current row
for (int k = j + 1; k < m; ++k) {
if (norms(k) <= epssq) { continue; } // Can skip rows with norm less than tolerance
aa(k, _) = aa(k, _) - aa(j, _) * nda::blas::dotc(aa(j, _), aa(k, _));
norms(k) = nda::real(nda::blas::dotc(aa(k, _), aa(k, _)));
norms(k) = normsq(aa(k, _));
}

// Orthogonalize current row (now the second chosen pivot row) against all
// previously chosen rows
for (int k = 0; k < j + 1; ++k) { aa(j + 1, _) = aa(j + 1, _) - aa(k, _) * nda::blas::dotc(aa(k, _), aa(j + 1, _)); }

// Normalize current row
nrm = nda::real(nda::blas::dotc(aa(j + 1, _), aa(j + 1, _)));
aa(j + 1, _) = aa(j + 1, _) * (1 / sqrt(nrm));
aa(j + 1, _) /= sqrt(normsq(aa(j + 1, _)));

// Orthogonalize remaining rows against current row
for (int k = j + 2; k < m; ++k) {
if (norms(k) <= epssq) { continue; } // Can skip rows with norm less than tolerance
aa(k, _) = aa(k, _) - aa(j + 1, _) * nda::blas::dotc(aa(j + 1, _), aa(k, _));
norms(k) = nda::real(nda::blas::dotc(aa(k, _), aa(k, _)));
norms(k) = normsq(aa(k, _));
}
}

Expand Down Expand Up @@ -352,18 +333,18 @@ namespace cppdlr {
aa(nda::range(0, m, 2), _) = a(nda::range(0, m / 2), _);
aa(nda::range(1, m, 2), _) = a(nda::range(m - 1, m / 2 - 1, -1), _);
} else {
aa(0, _) = a((m - 1) / 2, _);
aa(0, _) = a((m - 1) / 2, _);
aa(nda::range(1, m, 2), _) = a(nda::range(0, (m - 1) / 2), _);
aa(nda::range(2, m, 2), _) = a(nda::range(m - 1, (m - 1) / 2, -1), _);
//aa(m - 1, _) = a((m - 1) / 2, _);
}

// Compute norms of rows of input matrix
auto norms = nda::vector<double>(m);
for (int j = 0; j < m; ++j) { norms(j) = nda::real(nda::blas::dotc(aa(j, _), aa(j, _))); }
for (int j = 0; j < m; ++j) { norms(j) = normsq(aa(j, _)); }

// Begin pivoted double Gram-Schmidt procedure
int jpiv = 0, jj = 0;
int jpiv = 0;
double nrm = 0;
auto piv = nda::arange(0, m);
if (m % 2 == 0) {
Expand All @@ -375,23 +356,17 @@ namespace cppdlr {
piv(nda::range(2, m, 2)) = nda::arange(m - 1, (m - 1) / 2, -1);
//piv(m - 1) = (m - 1) / 2;
}
auto tmp = nda::vector<S>(n);

// If m odd, first choose middle row (now last row) as first pivot

if (m % 2 == 1) {
//int j = 0; // Index of current row
//jpiv = 0; // Index of pivot row

// Normalize
nrm = nda::real(nda::blas::dotc(aa(0, _), aa(0, _)));
aa(0, _) = aa(0, _) * (1 / sqrt(nrm));
//aa(0, _) /= sqrt(nda::real(nda::blas::dotc(aa(0, _), aa(0, _))));
aa(0, _) /= sqrt(normsq(aa(0, _)));

// Orthogonalize remaining rows against current row
for (int k = 1; k < m; ++k) {
aa(k, _) = aa(k, _) - aa(0, _) * nda::blas::dotc(aa(0, _), aa(k, _));
norms(k) = nda::real(nda::blas::dotc(aa(k, _), aa(k, _)));
norms(k) = normsq(aa(k, _));
}
}

Expand All @@ -410,53 +385,37 @@ namespace cppdlr {
}

// Swap current row pair with chosen pivot row pair
tmp = aa(j, _);
aa(j, _) = aa(jpiv, _);
aa(jpiv, _) = tmp;
tmp = aa(j + 1, _);
aa(j + 1, _) = aa(jpiv + 1, _);
aa(jpiv + 1, _) = tmp;

nrm = norms(j);
norms(j) = norms(jpiv);
norms(jpiv) = nrm;
nrm = norms(j + 1);
norms(j + 1) = norms(jpiv + 1);
norms(jpiv + 1) = nrm;

jj = piv(j);
piv(j) = piv(jpiv);
piv(jpiv) = jj;
jj = piv(j + 1);
piv(j + 1) = piv(jpiv + 1);
piv(jpiv + 1) = jj;
deep_swap(aa(j, _), aa(jpiv, _));
deep_swap(aa(j + 1, _), aa(jpiv + 1, _));
std::swap(norms(j), norms(jpiv));
std::swap(norms(j + 1), norms(jpiv + 1));
std::swap(piv(j), piv(jpiv));
std::swap(piv(j + 1), piv(jpiv + 1));

// Orthogonalize current row (now the first chosen pivot row) against all
// previously chosen rows
for (int k = 0; k < j; ++k) { aa(j, _) = aa(j, _) - aa(k, _) * nda::blas::dotc(aa(k, _), aa(j, _)); }

// Normalize current row
nrm = nda::real(nda::blas::dotc(aa(j, _), aa(j, _)));
aa(j, _) = aa(j, _) * (1 / sqrt(nrm));
aa(j, _) /= sqrt(normsq(aa(j, _)));

// Orthogonalize remaining rows against current row
for (int k = j + 1; k < m; ++k) {
aa(k, _) = aa(k, _) - aa(j, _) * nda::blas::dotc(aa(j, _), aa(k, _));
norms(k) = nda::real(nda::blas::dotc(aa(k, _), aa(k, _)));
norms(k) = normsq(aa(k, _));
}

// Orthogonalize current row (now the second chosen pivot row) against all
// previously chosen rows
for (int k = 0; k < j + 1; ++k) { aa(j + 1, _) = aa(j + 1, _) - aa(k, _) * nda::blas::dotc(aa(k, _), aa(j + 1, _)); }

// Normalize current row
nrm = nda::real(nda::blas::dotc(aa(j + 1, _), aa(j + 1, _)));
aa(j + 1, _) = aa(j + 1, _) * (1 / sqrt(nrm));
aa(j + 1, _) /= sqrt(normsq(aa(j + 1, _)));

// Orthogonalize remaining rows against current row
for (int k = j + 2; k < m; ++k) {
aa(k, _) = aa(k, _) - aa(j + 1, _) * nda::blas::dotc(aa(j + 1, _), aa(k, _));
norms(k) = nda::real(nda::blas::dotc(aa(k, _), aa(k, _)));
norms(k) = normsq(aa(k, _));
}
}

Expand Down Expand Up @@ -551,7 +510,7 @@ namespace cppdlr {
* @return Contraction of the inner dimensions of \p a and \p b
*/
template <nda::MemoryArray Ta, nda::MemoryArray Tb, nda::Scalar Sa = nda::get_value_t<Ta>, nda::Scalar Sb = nda::get_value_t<Tb>,
nda::Scalar S = typename std::common_type<Sa, Sb>::type>
nda::Scalar S = std::common_type_t<Sa, Sb>>
nda::array<S, Ta::rank + Tb::rank - 2> arraymult(Ta const &a, Tb const &b) {

// Get ranks of input arrays
Expand Down

0 comments on commit 14263a6

Please sign in to comment.