Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 133 additions & 2 deletions src/Parallel/Callback.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,45 @@

#pragma once

#include <memory>
#include <pup.h>
#include <tuple>
#include <utility>

#include "Parallel/Invoke.hpp"
#include "Utilities/PrettyType.hpp"
#include "Utilities/Serialization/CharmPupable.hpp"
#include "Utilities/Serialization/RegisterDerivedClassesWithCharm.hpp"
#include "Utilities/TypeTraits/HasEquivalence.hpp"

namespace Parallel {
namespace detail {
// Not all tuple arguments are guaranteed to have operator==, so we check the
// ones we can.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit worried about silently ignoring entries. What types typically occur that cause problems?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problematic types were proxies and I didn't want to deal with how to evaluate equivalence of those

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. charmplusplus/charm#3848 is probably the real fix, for what it's worth, but we can't wait for that.

template <typename... Args>
bool tuple_equal(const std::tuple<Args...>& tuple_1,
const std::tuple<Args...>& tuple_2) {
bool result = true;
tmpl::for_each<tmpl::make_sequence<tmpl::size_t<0>,
tmpl::size<tmpl::list<Args...>>::value>>(
[&](const auto index_v) {
constexpr size_t index = tmpl::type_from<decltype(index_v)>::value;

if (not result) {
return;
}

if constexpr (tt::has_equivalence_v<decltype(std::get<index>(
tuple_1))>) {
result =
result and std::get<index>(tuple_1) == std::get<index>(tuple_2);
}
});

return result;
}
} // namespace detail

/// An abstract base class, whose derived class holds a function that
/// can be invoked at a later time. The function is intended to be
/// invoked only once.
Expand All @@ -30,6 +60,12 @@ class Callback : public PUP::able {
explicit Callback(CkMigrateMessage* msg) : PUP::able(msg) {}
virtual void invoke() = 0;
virtual void register_with_charm() = 0;
/*!
* \brief Returns if this callback is equal to the one passed in.
*/
virtual bool is_equal_to(const Callback& rhs) const = 0;
virtual std::string name() const = 0;
virtual std::unique_ptr<Callback> get_clone() = 0;
};

/// Wraps a call to a simple action and its arguments.
Expand Down Expand Up @@ -65,6 +101,27 @@ class SimpleActionCallback : public Callback {
register_classes_with_charm<SimpleActionCallback>();
}

bool is_equal_to(const Callback& rhs) const override {
const auto* downcast_ptr = dynamic_cast<const SimpleActionCallback*>(&rhs);
if (downcast_ptr == nullptr) {
return false;
}
return detail::tuple_equal(args_, downcast_ptr->args_);
}

std::string name() const override {
// Use pretty_type::get_name with the action since we want to differentiate
// template paremeters. Only use pretty_type::name for proxy because it'll
// likely be really long with the template parameters which is unnecessary
return "SimpleActionCallback(" + pretty_type::get_name<SimpleAction>() +
"," + pretty_type::name<Proxy>() + ")";
}

std::unique_ptr<Callback> get_clone() override {
return std::make_unique<SimpleActionCallback<SimpleAction, Proxy, Args...>>(
*this);
}

private:
std::decay_t<Proxy> proxy_{};
std::tuple<std::decay_t<Args>...> args_{};
Expand Down Expand Up @@ -93,6 +150,23 @@ class SimpleActionCallback<SimpleAction, Proxy> : public Callback {
register_classes_with_charm<SimpleActionCallback>();
}

bool is_equal_to(const Callback& rhs) const override {
const auto* downcast_ptr = dynamic_cast<const SimpleActionCallback*>(&rhs);
return downcast_ptr != nullptr;
}

std::string name() const override {
// Use pretty_type::get_name with the action since we want to differentiate
// template paremeters. Only use pretty_type::name for proxy because it'll
// likely be really long with the template parameters which is unnecessary
return "SimpleActionCallback(" + pretty_type::get_name<SimpleAction>() +
"," + pretty_type::name<Proxy>() + ")";
}

std::unique_ptr<Callback> get_clone() override {
return std::make_unique<SimpleActionCallback<SimpleAction, Proxy>>(*this);
}

private:
std::decay_t<Proxy> proxy_{};
};
Expand Down Expand Up @@ -130,6 +204,28 @@ class ThreadedActionCallback : public Callback {
register_classes_with_charm<ThreadedActionCallback>();
}

bool is_equal_to(const Callback& rhs) const override {
const auto* downcast_ptr =
dynamic_cast<const ThreadedActionCallback*>(&rhs);
if (downcast_ptr == nullptr) {
return false;
}
return detail::tuple_equal(args_, downcast_ptr->args_);
}

std::string name() const override {
// Use pretty_type::get_name with the action since we want to differentiate
// template paremeters. Only use pretty_type::name for proxy because it'll
// likely be really long with the template parameters which is unnecessary
return "ThreadedActionCallback(" + pretty_type::get_name<ThreadedAction>() +
"," + pretty_type::name<Proxy>() + ")";
}

std::unique_ptr<Callback> get_clone() override {
return std::make_unique<
ThreadedActionCallback<ThreadedAction, Proxy, Args...>>(*this);
}

private:
std::decay_t<Proxy> proxy_{};
std::tuple<std::decay_t<Args>...> args_{};
Expand Down Expand Up @@ -158,6 +254,25 @@ class ThreadedActionCallback<ThreadedAction, Proxy> : public Callback {
register_classes_with_charm<ThreadedActionCallback>();
}

bool is_equal_to(const Callback& rhs) const override {
const auto* downcast_ptr =
dynamic_cast<const ThreadedActionCallback*>(&rhs);
return downcast_ptr != nullptr;
}

std::string name() const override {
// Use pretty_type::get_name with the action since we want to differentiate
// template paremeters. Only use pretty_type::name for proxy because it'll
// likely be really long with the template parameters which is unnecessary
return "ThreadedActionCallback(" + pretty_type::get_name<ThreadedAction>() +
"," + pretty_type::name<Proxy>() + ")";
}

std::unique_ptr<Callback> get_clone() override {
return std::make_unique<ThreadedActionCallback<ThreadedAction, Proxy>>(
*this);
}

private:
std::decay_t<Proxy> proxy_{};
};
Expand All @@ -184,6 +299,22 @@ class PerformAlgorithmCallback : public Callback {
register_classes_with_charm<PerformAlgorithmCallback>();
}

bool is_equal_to(const Callback& rhs) const override {
const auto* downcast_ptr =
dynamic_cast<const PerformAlgorithmCallback*>(&rhs);
return downcast_ptr != nullptr;
}

std::string name() const override {
// Only use pretty_type::name for proxy because it'll likely be really long
// with the template parameters which is unnecessary
return "PerformAlgorithmCallback(" + pretty_type::name<Proxy>() + ")";
}

std::unique_ptr<Callback> get_clone() override {
return std::make_unique<PerformAlgorithmCallback<Proxy>>(*this);
}

private:
std::decay_t<Proxy> proxy_{};
};
Expand All @@ -194,7 +325,7 @@ template <typename Proxy>
PUP::able::PUP_ID PerformAlgorithmCallback<Proxy>::my_PUP_ID = 0;
template <typename SimpleAction, typename Proxy, typename... Args>
PUP::able::PUP_ID
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
SimpleActionCallback<SimpleAction, Proxy, Args...>::my_PUP_ID =
0; // NOLINT
template <typename SimpleAction, typename Proxy>
Expand All @@ -203,7 +334,7 @@ PUP::able::PUP_ID SimpleActionCallback<SimpleAction, Proxy>::my_PUP_ID =
0; // NOLINT
template <typename ThreadedAction, typename Proxy, typename... Args>
PUP::able::PUP_ID
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
ThreadedActionCallback<ThreadedAction, Proxy, Args...>::my_PUP_ID =
0; // NOLINT
template <typename ThreadedAction, typename Proxy>
Expand Down
73 changes: 56 additions & 17 deletions src/Parallel/GlobalCache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "Parallel/ParallelComponentHelpers.hpp"
#include "Parallel/ResourceInfo.hpp"
#include "Parallel/Tags/ResourceInfo.hpp"
#include "Utilities/Algorithm.hpp"
#include "Utilities/ErrorHandling/Assert.hpp"
#include "Utilities/ErrorHandling/Error.hpp"
#include "Utilities/Gsl.hpp"
Expand Down Expand Up @@ -77,10 +78,10 @@ CREATE_GET_TYPE_ALIAS_OR_DEFAULT(component_being_mocked)

template <typename... Tags>
auto make_mutable_cache_tag_storage(tuples::TaggedTuple<Tags...>&& input) {
return tuples::TaggedTuple<MutableCacheTag<Tags>...>(
std::make_tuple(std::move(tuples::get<Tags>(input)),
std::unordered_map<Parallel::ArrayComponentId,
std::unique_ptr<Callback>>{})...);
return tuples::TaggedTuple<MutableCacheTag<Tags>...>(std::make_tuple(
std::move(tuples::get<Tags>(input)),
std::unordered_map<Parallel::ArrayComponentId,
std::vector<std::unique_ptr<Callback>>>{})...);
}

template <typename ParallelComponent, typename ComponentList>
Expand Down Expand Up @@ -487,14 +488,34 @@ bool GlobalCache<Metavariables>::mutable_cache_item_is_ready(
optional_callback->register_with_charm();
// Second mutex is for vector of callbacks
std::mutex& mutex = tuples::get<MutexTag<tag>>(mutexes_).second;
const std::unique_ptr<Callback> clone_of_optional_callback =
optional_callback->get_clone();
{
// Scoped for lock guard
const std::lock_guard<std::mutex> lock(mutex);
std::unordered_map<Parallel::ArrayComponentId, std::unique_ptr<Callback>>&
callbacks = std::get<1>(tuples::get<tag>(mutable_global_cache_));

if (callbacks.count(array_component_id) != 1) {
callbacks[array_component_id] = std::move(optional_callback);
std::unordered_map<Parallel::ArrayComponentId,
std::vector<std::unique_ptr<Callback>>>& callbacks =
std::get<1>(tuples::get<tag>(mutable_global_cache_));

if (callbacks.contains(array_component_id)) {
// If this array component id already exists, we don't want to add
// multiple of the same callback, so we loop over the existing callbacks
// and only if none of the existing callbacks are equal to the optional
// callback do we move the optional callback into the vector
auto& vec_callbacks = callbacks.at(array_component_id);
if (alg::none_of(vec_callbacks,
[&](const std::unique_ptr<Callback>& local_callback) {
return local_callback->is_equal_to(
*optional_callback);
})) {
vec_callbacks.emplace_back(std::move(optional_callback));
}
} else {
// If we don't have this array component id, then we create the vector
// and move the optional callback into the vector
callbacks[array_component_id] =
std::vector<std::unique_ptr<Callback>>(1);
callbacks.at(array_component_id)[0] = std::move(optional_callback);
}
}

Expand Down Expand Up @@ -531,10 +552,26 @@ bool GlobalCache<Metavariables>::mutable_cache_item_is_ready(
const bool cache_item_is_ready = not callback_was_registered();
if (cache_item_is_ready) {
const std::lock_guard<std::mutex> lock(mutex);
std::unordered_map<Parallel::ArrayComponentId, std::unique_ptr<Callback>>&
callbacks = std::get<1>(tuples::get<tag>(mutable_global_cache_));

callbacks.erase(array_component_id);
std::unordered_map<Parallel::ArrayComponentId,
std::vector<std::unique_ptr<Callback>>>& callbacks =
std::get<1>(tuples::get<tag>(mutable_global_cache_));

// It's possible that no new callbacks were registered, so make sure this
// array component id still has callbacks before trying to remove them.
if (callbacks.contains(array_component_id)) {
// If this callback was a duplicate, we'll have to search through all
// callbacks to determine which to remove. If it wasn't a duplicate,
// then it'll just be the last callback in the vector.
auto& vec_callbacks = callbacks.at(array_component_id);
std::erase_if(vec_callbacks,
[&clone_of_optional_callback](const auto& t) {
return t->is_equal_to(*clone_of_optional_callback);
});

if (callbacks.at(array_component_id).empty()) {
callbacks.erase(array_component_id);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check if the vector is empty after removing the element, instead of before.

}
}
}

return cache_item_is_ready;
Expand Down Expand Up @@ -573,7 +610,8 @@ void GlobalCache<Metavariables>::mutate(const std::tuple<Args...>& args) {
// Therefore, after locking it, we std::move the map of callbacks into a
// temporary map, clear the original map, and invoke the callbacks in the
// temporary map.
std::unordered_map<Parallel::ArrayComponentId, std::unique_ptr<Callback>>
std::unordered_map<Parallel::ArrayComponentId,
std::vector<std::unique_ptr<Callback>>>
callbacks{};
// Second mutex is for map of callbacks
std::mutex& mutex = tuples::get<MutexTag<tag>>(mutexes_).second;
Expand All @@ -587,9 +625,10 @@ void GlobalCache<Metavariables>::mutate(const std::tuple<Args...>& args) {
// Invoke the callbacks. Any new callbacks that are added to the
// list (if a callback calls mutable_cache_item_is_ready) will be
// saved and will not be invoked here.
for (auto& [array_component_id, callback] : callbacks) {
(void)array_component_id;
callback->invoke();
for (auto& [array_component_id, vec_callbacks] : callbacks) {
for (auto& callback : vec_callbacks) {
callback->invoke();
}
}
}

Expand Down
7 changes: 4 additions & 3 deletions src/Parallel/ParallelComponentHelpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,10 @@ struct MutexTag {
template <typename Tag>
struct MutableCacheTag {
using tag = Tag;
using type = std::tuple<typename Tag::type,
std::unordered_map<Parallel::ArrayComponentId,
std::unique_ptr<Callback>>>;
using type =
std::tuple<typename Tag::type,
std::unordered_map<Parallel::ArrayComponentId,
std::vector<std::unique_ptr<Callback>>>>;
};

template <typename Tag>
Expand Down
20 changes: 20 additions & 0 deletions tests/Unit/Parallel/Test_Callback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,15 +247,34 @@ struct RunCallbacks {
Parallel::SimpleActionCallback<MultiplyValueByFactor, decltype(proxy_2),
double>
callback_2(proxy_2, 1.5);
SPECTRE_PARALLEL_REQUIRE(
callback_0.name().find("PerformAlgorithmCallback") !=
std::string::npos);
SPECTRE_PARALLEL_REQUIRE(
(callback_1.name().find("SimpleActionCallback") != std::string::npos and
callback_1.name().find("IncrementValue") != std::string::npos));
SPECTRE_PARALLEL_REQUIRE(
(callback_2.name().find("SimpleActionCallback") != std::string::npos and
callback_2.name().find("MultiplyValueByFactor") != std::string::npos));
callback_0.invoke();
callback_1.invoke();
callback_2.invoke();
auto callback_3 = serialize_and_deserialize(callback_0);
auto callback_4 = serialize_and_deserialize(callback_1);
auto callback_5 = serialize_and_deserialize(callback_2);
SPECTRE_PARALLEL_REQUIRE(callback_0.is_equal_to(callback_3));
SPECTRE_PARALLEL_REQUIRE_FALSE(callback_0.is_equal_to(callback_4));
SPECTRE_PARALLEL_REQUIRE(callback_1.is_equal_to(callback_4));
SPECTRE_PARALLEL_REQUIRE_FALSE(callback_1.is_equal_to(callback_5));
callback_3.invoke();
callback_4.invoke();
callback_5.invoke();
const auto callback_6 = callback_0.get_clone();
const auto callback_7 = callback_1.get_clone();
const auto callback_8 = callback_2.get_clone();
SPECTRE_PARALLEL_REQUIRE(callback_0.is_equal_to(*callback_6));
SPECTRE_PARALLEL_REQUIRE(callback_1.is_equal_to(*callback_7));
SPECTRE_PARALLEL_REQUIRE(callback_2.is_equal_to(*callback_8));
std::vector<std::unique_ptr<Parallel::Callback>> callbacks;
callbacks.emplace_back(
std::make_unique<Parallel::PerformAlgorithmCallback<decltype(proxy_0)>>(
Expand All @@ -267,6 +286,7 @@ struct RunCallbacks {
callbacks.emplace_back(
std::make_unique<Parallel::SimpleActionCallback<
MultiplyValueByFactor, decltype(proxy_2), double>>(proxy_2, 2.0));
SPECTRE_PARALLEL_REQUIRE_FALSE(callback_2.is_equal_to(*callbacks.back()));

auto& nodegroup_proxy =
Parallel::get_parallel_component<TestNodegroup<Metavariables>>(cache);
Expand Down
1 change: 1 addition & 0 deletions tests/Unit/Parallel/Test_GlobalCache.ci
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@ mainmodule Test_GlobalCache {
entry void run_test_three();
entry void run_test_four();
entry void run_test_five();
entry void mutate_name();
};
}
Loading
Loading