Skip to content

Commit

Permalink
Unsteady sampling (#46)
Browse files Browse the repository at this point in the history
* MultiBlockSolver::Solve takes sample generator as input.

* MultiBlockSolver::SaveSnapshots takes sample generator.

* ROM linear elements

* templates for ROMTensorElement and ROMEQPElement
  • Loading branch information
dreamer2368 authored Jun 19, 2024
1 parent cdf4280 commit ec9f027
Show file tree
Hide file tree
Showing 20 changed files with 569 additions and 445 deletions.
5 changes: 3 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,9 @@ set(scaleupROMObj_SOURCES
include/block_smoother.hpp
src/block_smoother.cpp

# include/navier_solver.hpp
# src/navier_solver.cpp
include/rom_element_collection.hpp
src/rom_element_collection.cpp

include/unsteady_ns_solver.hpp
src/unsteady_ns_solver.cpp

Expand Down
4 changes: 2 additions & 2 deletions include/advdiff_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ friend class ParameterizedProblem;
void BuildDomainOperators() override;

// Component-wise assembly
void BuildCompROMLinElems(Array<FiniteElementSpace *> &fes_comp) override;
void BuildCompROMLinElems() override;

bool Solve() override;
bool Solve(SampleGenerator *sample_generator = NULL) override;

void SetFlowAtSubdomain(std::function<void(const Vector &, double, Vector &)> F, const int m=-1);

Expand Down
8 changes: 4 additions & 4 deletions include/linelast_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class LinElastSolver : public MultiBlockSolver
// system-specific.
virtual void AssembleInterfaceMatrices();

virtual bool Solve();
bool Solve(SampleGenerator *sample_generator = NULL) override;

virtual void SetupBCVariables() override;
virtual void SetupIC(std::function<void(const Vector &, double, Vector &)> F);
Expand All @@ -93,9 +93,9 @@ class LinElastSolver : public MultiBlockSolver
virtual void SetupDomainBCOperators();

// Component-wise assembly
virtual void BuildCompROMLinElems(Array<FiniteElementSpace *> &fes_comp);
virtual void BuildBdrROMLinElems(Array<FiniteElementSpace *> &fes_comp);
virtual void BuildItfaceROMLinElems(Array<FiniteElementSpace *> &fes_comp);
void BuildCompROMLinElems() override;
void BuildBdrROMLinElems() override;
void BuildItfaceROMLinElems() override;

virtual void ProjectOperatorOnReducedBasis();

Expand Down
44 changes: 17 additions & 27 deletions include/multiblock_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "rom_handler.hpp"
#include "hdf5_utils.hpp"
#include "sample_generator.hpp"
#include "rom_element_collection.hpp"

// By convention we only use mfem namespace as default, not CAROM.
using namespace mfem;
Expand Down Expand Up @@ -112,13 +113,8 @@ friend class ParameterizedProblem;
bool separate_variable_basis = false;

// Used for bottom-up building, only with ComponentTopologyHandler.
// For now, assumes ROM basis represents the entire vector solution.
Array<MatrixBlocks *> comp_mats; // Size(num_components);
// boundary condition is enforced via forcing term.
Array<Array<MatrixBlocks *> *> bdr_mats;
Array<MatrixBlocks *> port_mats; // reference ports.
// DenseTensor objects from nonlinear operators
// will be defined per each derived MultiBlockSolver.
Array<FiniteElementSpace *> comp_fes;
ROMLinearElement *rom_elems = NULL;

public:
MultiBlockSolver();
Expand Down Expand Up @@ -202,31 +198,23 @@ friend class ParameterizedProblem;

// Component-wise assembly
void GetComponentFESpaces(Array<FiniteElementSpace *> &comp_fes);
virtual void AllocateROMLinElems();
// virtual void AllocateROMLinElems();

void BuildROMLinElems();
virtual void BuildCompROMLinElems(Array<FiniteElementSpace *> &fes_comp) = 0;
virtual void BuildBdrROMLinElems(Array<FiniteElementSpace *> &fes_comp) = 0;
virtual void BuildCompROMLinElems() = 0;
virtual void BuildBdrROMLinElems() = 0;
// TODO(kevin): part of this can be transferred to InterfaceForm.
virtual void BuildItfaceROMLinElems(Array<FiniteElementSpace *> &fes_comp) = 0;

void SaveROMLinElems(const std::string &filename);
// Save ROM Elements in a hdf5-format file specified with file_id.
// TODO: add more arguments to support more general data structures of ROM Elements.
virtual void SaveCompBdrROMLinElems(hid_t &file_id);
void SaveBdrROMLinElems(hid_t &comp_grp_id, const int &comp_idx);
void SaveItfaceROMLinElems(hid_t &file_id);

void LoadROMLinElems(const std::string &filename);
// Load ROM Elements in a hdf5-format file specified with file_id.
// TODO: add more arguments to support more general data structures of ROM Elements.
virtual void LoadCompBdrROMLinElems(hid_t &file_id);
void LoadBdrROMLinElems(hid_t &comp_grp_id, const int &comp_idx);
void LoadItfaceROMLinElems(hid_t &file_id);
virtual void BuildItfaceROMLinElems() = 0;

void SaveROMLinElems(const std::string &filename)
{ assert(rom_elems); rom_elems->Save(filename); }

void LoadROMLinElems(const std::string &filename)
{ assert(rom_elems); rom_elems->Load(filename); }

void AssembleROMMat();

virtual bool Solve() = 0;
virtual bool Solve(SampleGenerator *sample_generator = NULL) = 0;

virtual void InitVisualization(const std::string& output_dir = "");
virtual void InitUnifiedParaview(const std::string &file_prefix);
Expand Down Expand Up @@ -271,7 +259,9 @@ friend class ParameterizedProblem;
void InitROMHandler();
void GetBasisTags(std::vector<std::string> &basis_tags);

virtual void PrepareSnapshots(BlockVector* &U_snapshots, std::vector<std::string> &basis_tags);
virtual BlockVector* PrepareSnapshots(std::vector<std::string> &basis_tags);
void SaveSnapshots(SampleGenerator *sample_generator);

virtual void LoadReducedBasis() { rom_handler->LoadReducedBasis(); }
virtual void ProjectOperatorOnReducedBasis() = 0;
virtual void ProjectRHSOnReducedBasis();
Expand Down
8 changes: 4 additions & 4 deletions include/poisson_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,11 @@ friend class ParameterizedProblem;
virtual void AssembleInterfaceMatrices();

// Component-wise assembly
virtual void BuildCompROMLinElems(Array<FiniteElementSpace *> &fes_comp);
virtual void BuildBdrROMLinElems(Array<FiniteElementSpace *> &fes_comp);
virtual void BuildItfaceROMLinElems(Array<FiniteElementSpace *> &fes_comp);
virtual void BuildCompROMLinElems() override;
void BuildBdrROMLinElems() override;
void BuildItfaceROMLinElems() override;

virtual bool Solve();
virtual bool Solve(SampleGenerator *sample_generator = NULL);

virtual void ProjectOperatorOnReducedBasis();

Expand Down
118 changes: 118 additions & 0 deletions include/rom_element_collection.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// Copyright 2023 Lawrence Livermore National Security, LLC. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: MIT

#ifndef SCALEUPROM_ROM_ELEMENT_COLLECTION_HPP
#define SCALEUPROM_ROM_ELEMENT_COLLECTION_HPP

#include "topology_handler.hpp"
#include "rom_nonlinearform.hpp"
#include "rom_interfaceform.hpp"
#include "mfem.hpp"
#include "hdf5_utils.hpp"

// By convention we only use mfem namespace as default, not CAROM.
using namespace mfem;

class ROMElementCollection
{
protected:
const int num_var;
const int num_comp;
const int num_ref_ports;
const bool separate_variable;

TopologyHandler *topol_handler = NULL; // not owned
Array<FiniteElementSpace *> fes; // not owned

public:
ROMElementCollection(TopologyHandler *topol_handler_, const Array<FiniteElementSpace *> &fes_,
const bool separate_variable_)
: topol_handler(topol_handler_), fes(fes_),
num_comp(topol_handler_->GetNumComponents()),
num_var(fes_.Size() / topol_handler_->GetNumComponents()),
num_ref_ports(topol_handler_->GetNumRefPorts()),
separate_variable(separate_variable_)
{
assert(num_comp * num_var == fes.Size());
assert(topol_handler->GetType() == TopologyHandlerMode::COMPONENT);
}

virtual ~ROMElementCollection() {}

virtual void Save(const std::string &filename) = 0;
virtual void Load(const std::string &filename) = 0;
};

class ROMLinearElement : public ROMElementCollection
{
public:
Array<MatrixBlocks *> comp; // Size(num_components);
// boundary condition is enforced via forcing term.
Array<Array<MatrixBlocks *> *> bdr;
Array<MatrixBlocks *> port; // reference ports.

public:
ROMLinearElement(TopologyHandler *topol_handler_,
const Array<FiniteElementSpace *> &fes_,
const bool separate_variable_);

virtual ~ROMLinearElement();

void Save(const std::string &filename) override;
void Load(const std::string &filename) override;

private:
void SaveCompBdrElems(hid_t &file_id);
void SaveBdrElems(hid_t &comp_grp_id, const int &comp_idx);
void SaveItfaceElems(hid_t &file_id);

void LoadCompBdrElems(hid_t &file_id);
void LoadBdrElems(hid_t &comp_grp_id, const int &comp_idx);
void LoadItfaceElems(hid_t &file_id);
};

class ROMTensorElement : public ROMElementCollection
{
public:
Array<DenseTensor *> comp; // Size(num_components);

/* boundary/interface is not implemented yet.. should consider */
// Array<Array<DenseTensor *> *> bdr;
// Array<DenseTensor *> port; // reference ports.

public:
ROMTensorElement(TopologyHandler *topol_handler_,
const Array<FiniteElementSpace *> &fes_,
const bool separate_variable_)
: ROMElementCollection(topol_handler_, fes_, separate_variable_) {}

virtual ~ROMTensorElement();

void Save(const std::string &filename) override {}
void Load(const std::string &filename) override {}

};

class ROMEQPElement : public ROMElementCollection
{
public:
Array<ROMNonlinearForm *> comp; // Size(num_components);
// boundary condition is enforced via forcing term.
Array<Array<ROMNonlinearForm *> *> bdr;
Array<ROMInterfaceForm *> port; // reference ports.

public:
ROMEQPElement(TopologyHandler *topol_handler_,
const Array<FiniteElementSpace *> &fes_,
const bool separate_variable_)
: ROMElementCollection(topol_handler_, fes_, separate_variable_) {}

virtual ~ROMEQPElement() {}

void Save(const std::string &filename) override {}
void Load(const std::string &filename) override {}

};

#endif
4 changes: 2 additions & 2 deletions include/sample_generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ class SampleGenerator
std::vector<std::string> basis_tags;
std::map<std::string, int> basis_tag2idx;

/* snapshot pairs per interface port, for nonlinear EQP */
/* snapshot pairs per interface port, for nonlinear interface EQP */
std::vector<PortTag> port_tags;
std::map<PortTag, int> port_tag2idx;
Array<Array<int> *> port_colidxs;
Array<Array2D<int> *> port_colidxs;

public:
SampleGenerator(MPI_Comm comm);
Expand Down
3 changes: 1 addition & 2 deletions include/steady_ns_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ friend class SteadyNSOperator;
// component ROM element for nonlinear convection.
Array<DenseTensor *> comp_tensors, subdomain_tensors;
Array<ROMNonlinearForm *> comp_eqps, subdomain_eqps;
Array<FiniteElementSpace *> comp_fes; // pointers to existing fespace, no need to delete

Solver *J_solver = NULL;
GMRESSolver *J_gmres = NULL;
Expand All @@ -161,7 +160,7 @@ friend class SteadyNSOperator;
void SaveROMOperator(const std::string input_prefix="") override;
void LoadROMOperatorFromFile(const std::string input_prefix="") override;

bool Solve() override;
bool Solve(SampleGenerator *sample_generator = NULL) override;

void ProjectOperatorOnReducedBasis() override;

Expand Down
8 changes: 4 additions & 4 deletions include/stokes_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,11 @@ friend class ParameterizedProblem;
virtual void SetupPressureMassMatrix();

// Component-wise assembly
virtual void BuildCompROMLinElems(Array<FiniteElementSpace *> &fes_comp);
virtual void BuildBdrROMLinElems(Array<FiniteElementSpace *> &fes_comp);
virtual void BuildItfaceROMLinElems(Array<FiniteElementSpace *> &fes_comp);
void BuildCompROMLinElems() override;
void BuildBdrROMLinElems() override;
void BuildItfaceROMLinElems() override;

virtual bool Solve();
virtual bool Solve(SampleGenerator *sample_generator = NULL);
virtual void Solve_obsolete();

virtual void LoadReducedBasis() override;
Expand Down
2 changes: 1 addition & 1 deletion include/unsteady_ns_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ friend class SteadyNSOperator;
void LoadROMOperatorFromFile(const std::string input_prefix="") override
{ mfem_error("UnsteadyNSSolver::LoadROMOperatorFromFile is not implemented yet!\n"); }

bool Solve() override;
bool Solve(SampleGenerator *sample_generator = NULL) override;

using MultiBlockSolver::SaveVisualization;
void SaveVisualization(const int step, const double time) override;
Expand Down
20 changes: 11 additions & 9 deletions src/advdiff_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,18 @@ void AdvDiffSolver::BuildDomainOperators()
}
}

void AdvDiffSolver::BuildCompROMLinElems(Array<FiniteElementSpace *> &fes_comp)
void AdvDiffSolver::BuildCompROMLinElems()
{
mfem_error("AdvDiffSolver::BuildCompROMLinElems is not implemented yet!\n");

assert(train_mode == UNIVERSAL);
assert(rom_handler->BasisLoaded());
assert(rom_elems);

const int num_comp = fes_comp.Size();
assert(comp_mats.Size() == num_comp);

for (int c = 0; c < num_comp; c++)
for (int c = 0; c < topol_handler->GetNumComponents(); c++)
{
Mesh *comp = topol_handler->GetComponentMesh(c);
BilinearForm a_comp(fes_comp[c]);
BilinearForm a_comp(comp_fes[c]);

a_comp.AddDomainIntegrator(new DiffusionIntegrator);
if (full_dg)
Expand All @@ -72,12 +70,12 @@ void AdvDiffSolver::BuildCompROMLinElems(Array<FiniteElementSpace *> &fes_comp)
a_comp.Finalize();

// Poisson equation has only one solution variable.
comp_mats[c]->SetSize(1, 1);
(*comp_mats[c])(0, 0) = rom_handler->ProjectToRefBasis(c, c, &(a_comp.SpMat()));
rom_elems->comp[c]->SetSize(1, 1);
(*rom_elems->comp[c])(0, 0) = rom_handler->ProjectToRefBasis(c, c, &(a_comp.SpMat()));
}
}

bool AdvDiffSolver::Solve()
bool AdvDiffSolver::Solve(SampleGenerator *sample_generator)
{
// If using direct solver, returns always true.
bool converged = true;
Expand Down Expand Up @@ -151,6 +149,10 @@ bool AdvDiffSolver::Solve()
delete solver;
}

/* save solution if sample generator is provided */
if (converged && sample_generator)
SaveSnapshots(sample_generator);

return converged;
}

Expand Down
Loading

0 comments on commit ec9f027

Please sign in to comment.