Skip to content

Commit ae20296

Browse files
authored
amrex::callNoinline: Call given function without inline (#4606)
This works on lambdas, functors, and normal functions, but it does not work on overloaded functions like std::sin. If needed, one can wrap an overloaded function such as std::sin inside a lambda. It also does not work with normal functions for SYCL; there, too, one has to wrap the function inside a lambda. Here is the motivation behind this PR. In this ImpactX PR (BLAST-ImpactX/impactx#964), a GPU kernel uses 8 amrex::Parser objects. The CUDA CI fails if more than one job is used in the build. Apparently the kernel is too big because all those parser functions are inlined. This PR provides a way to reduce the kernel size by forcing those calls not to be inlined.
1 parent 9fdcf3c commit ae20296

File tree

7 files changed

+158
-4
lines changed

7 files changed

+158
-4
lines changed

Src/Base/AMReX_TypeTraits.H

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ namespace amrex
8787
struct IsParticleContainer : public std::is_base_of<ParticleContainerBase, T>::type {};
8888
#endif
8989

90+
template <class T, class Enable = void>
91+
struct DefinitelyNotHostRunnable : std::false_type {};
92+
9093
#ifdef AMREX_USE_GPU
9194

9295
template <class T, class Enable = void>
@@ -95,9 +98,6 @@ namespace amrex
9598
template <class T, class Enable = void>
9699
struct MaybeHostDeviceRunnable : std::true_type {};
97100

98-
template <class T, class Enable = void>
99-
struct DefinitelyNotHostRunnable : std::false_type {};
100-
101101
#if defined(AMREX_USE_CUDA) && defined(__NVCC__)
102102

103103
template <class T>

Src/Base/AMReX_Utility.H

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <AMReX_GpuQualifiers.H>
1717
#include <AMReX_FileSystem.H>
1818
#include <AMReX_String.H>
19+
#include <AMReX_TypeTraits.H>
1920

2021
#include <cfloat>
2122
#include <chrono>
@@ -242,6 +243,19 @@ namespace amrex
242243
int limit = 100,
243244
std::ostringstream ss = {});
244245

246+
/**
247+
* \brief Call given function without inline.
248+
*
249+
* This works on lambdas, functors, normal functions. But it does not
250+
* work with overloaded functions like std::sin. If needed, one could
251+
* however wrap functions like std::sin inside a lambda function. It
252+
* also does not work with normal functions for SYCL and one would have
253+
* to wrap it inside a lambda.
254+
*/
255+
template <typename F, typename... T>
256+
AMREX_GPU_HOST_DEVICE AMREX_NO_INLINE
257+
auto callNoinline (F const& f, T&&... arg)
258+
-> decltype(std::declval<F>()(std::declval<T>()...)); // needed for nvcc
245259
}
246260

247261
template <typename T>
@@ -506,5 +520,17 @@ std::string amrex::ToString(const T& t, const char* symbol_begin, const char* sy
506520
return ss.str();
507521
}
508522

523+
template <typename F, typename... T>
524+
AMREX_GPU_HOST_DEVICE AMREX_NO_INLINE
525+
auto amrex::callNoinline (F const& f, T&&... arg)
526+
-> decltype(std::declval<F>()(std::declval<T>()...)) // needed for nvcc
527+
{
528+
AMREX_IF_ON_HOST((
529+
if constexpr (!amrex::DefinitelyNotHostRunnable<F>::value) {
530+
return f(std::forward<T>(arg)...);
531+
}
532+
))
533+
AMREX_IF_ON_DEVICE(( return f(std::forward<T>(arg)...); ))
534+
}
509535

510536
#endif /*BL_UTILITY_H*/

Tests/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ else()
125125
#
126126
# List of subdirectories to search for CMakeLists.
127127
#
128-
set( AMREX_TESTS_SUBDIRS Amr AsyncOut CLZ CommType CTOParFor DeviceGlobal Enum
128+
set( AMREX_TESTS_SUBDIRS Amr AsyncOut CallNoinline CLZ CommType CTOParFor DeviceGlobal Enum
129129
MultiBlock MultiPeriod ParmParse Parser Parser2 ParserUserFn Reinit
130130
RoundoffDomain SmallMatrix)
131131

Tests/CallNoinline/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Set up the CallNoinline test once for every entry of AMReX_SPACEDIM.
foreach(D IN LISTS AMReX_SPACEDIM)
    # The test consists of a single source file and has no input files.
    set(_sources main.cpp)
    set(_input_files)

    # setup_test is defined elsewhere; it receives the dimension and the
    # names of the two list variables above.
    setup_test(${D} _sources _input_files)

    # Do not leak the helper variables into the next iteration.
    unset(_sources)
    unset(_input_files)
endforeach()

Tests/CallNoinline/GNUmakefile

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Default location of the AMReX source tree. This makefile lives in
# Tests/CallNoinline, so the repository root (which holds the Tools/ and
# Src/ directories included below) is two directories up. The value can
# still be overridden from the environment or the make command line.
AMREX_HOME ?= ../..

DEBUG = FALSE

DIM = 3

COMP = gcc

USE_MPI = FALSE
USE_OMP = FALSE
USE_CUDA = FALSE
USE_HIP = FALSE
USE_SYCL = FALSE

BL_NO_FORT = TRUE

TINY_PROFILE = FALSE

include $(AMREX_HOME)/Tools/GNUMake/Make.defs

# Sources of this test, followed by the AMReX Base package it depends on.
include ./Make.package
include $(AMREX_HOME)/Src/Base/Make.package

include $(AMREX_HOME)/Tools/GNUMake/Make.rules

Tests/CallNoinline/Make.package

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
CEXE_sources += main.cpp

Tests/CallNoinline/main.cpp

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
#include <AMReX.H>
2+
#include <AMReX_Gpu.H>
3+
#include <AMReX_Utility.H>
4+
5+
#include <cmath>
6+
7+
using namespace amrex;
8+
9+
// Multiplies x by 0.5 in place. Marked AMREX_GPU_DEVICE; the test calls it
// from inside a ParallelFor kernel.
AMREX_GPU_DEVICE AMREX_FORCE_INLINE
void halve (double& x)
{
    x = x * 0.5;
}
14+
15+
// Returns cos(x*y). Marked AMREX_GPU_HOST_DEVICE; the test calls it both
// from a ParallelFor kernel and directly from main.
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
double f (double x, double y)
{
    double const product = x*y;
    return std::cos(product);
}
20+
21+
struct S {
22+
AMREX_GPU_DEVICE AMREX_FORCE_INLINE
23+
double operator() (double a, double b, double c) const {
24+
return std::cos(a*b*c);
25+
}
26+
};
27+
28+
struct P {
29+
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
30+
double operator() (double a, double b, double c) const {
31+
return std::cos(a*b*c);
32+
}
33+
};
34+
35+
#if defined(__GNUC__) && !defined(__clang__)
36+
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
37+
#endif
38+
39+
int main(int argc, char* argv[])
40+
{
41+
amrex::Initialize(argc,argv);
42+
{
43+
double pi = amrex::Math::pi<double>();
44+
double one = 1.0;
45+
46+
auto g = [] AMREX_GPU_DEVICE (double a, double b, double c)
47+
{
48+
return std::sin(a*b*c);
49+
};
50+
51+
S s{};
52+
P p{};
53+
54+
amrex::Gpu::HostVector<double> ones(2);
55+
amrex::Gpu::HostVector<double> zeroes(3);
56+
auto* p1 = ones.data();
57+
auto* p0 = zeroes.data();
58+
59+
amrex::ParallelFor(1, [=] AMREX_GPU_DEVICE (int)
60+
{
61+
p1[0] = callNoinline([] (double a) { return std::sin(a); }, pi*one*0.5);
62+
p1[1] = callNoinline(g, pi, one, 0.5);
63+
64+
auto half = one;
65+
#ifdef AMREX_USE_SYCL
66+
callNoinline([] (double& x) { halve(x); }, half);
67+
p0[0] = callNoinline([] (double a, double b) { return f(a,b); }, pi, half);
68+
#else
69+
callNoinline(halve, half);
70+
p0[0] = callNoinline(f, pi, half);
71+
#endif
72+
auto half2 = one;
73+
callNoinline([] (double& a) { a *= 0.5; }, half2);
74+
p0[1] = callNoinline(s, pi, one, half2);
75+
p0[2] = callNoinline(p, pi, one, half2);
76+
});
77+
Gpu::streamSynchronize();
78+
79+
zeroes.push_back(callNoinline(f, pi, 0.5));
80+
zeroes.push_back(callNoinline(P{}, pi, one, 0.5));
81+
82+
amrex::Print() << "ones: " << amrex::ToString(ones) << "\n"
83+
<< "zeroes: " << amrex::ToString(zeroes) << "\n";
84+
85+
for (auto x : ones) {
86+
AMREX_ALWAYS_ASSERT(almostEqual(x, 1.0, 10));
87+
}
88+
89+
for (auto x : zeroes) {
90+
AMREX_ALWAYS_ASSERT(std::abs(x) < 1.e-15);
91+
}
92+
}
93+
amrex::Finalize();
94+
}

0 commit comments

Comments
 (0)