diff --git a/offload/DeviceRTL/include/DeviceTypes.h b/offload/DeviceRTL/include/DeviceTypes.h index 2e5d92380f040..111143a5578f1 100644 --- a/offload/DeviceRTL/include/DeviceTypes.h +++ b/offload/DeviceRTL/include/DeviceTypes.h @@ -136,6 +136,12 @@ struct omp_lock_t { void *Lock; }; +// see definition in openmp/runtime kmp.h +typedef enum omp_severity_t { + severity_warning = 1, + severity_fatal = 2 +} omp_severity_t; + using InterWarpCopyFnTy = void (*)(void *src, int32_t warp_num); using ShuffleReductFnTy = void (*)(void *rhsData, int16_t lane_id, int16_t lane_offset, int16_t shortCircuit); diff --git a/offload/DeviceRTL/src/Parallelism.cpp b/offload/DeviceRTL/src/Parallelism.cpp index 08ce616aee1c4..aa5e74029ec3e 100644 --- a/offload/DeviceRTL/src/Parallelism.cpp +++ b/offload/DeviceRTL/src/Parallelism.cpp @@ -45,7 +45,24 @@ using namespace ompx; namespace { -uint32_t determineNumberOfThreads(int32_t NumThreadsClause) { +void numThreadsStrictError(int32_t nt_strict, int32_t nt_severity, + const char *nt_message, int32_t requested, + int32_t actual) { + if (nt_message) + printf("%s\n", nt_message); + else + printf("The computed number of threads (%u) does not match the requested " + "number of threads (%d). Consider that it might not be supported " + "to select exactly %d threads on this target device.\n", + actual, requested, requested); + if (nt_severity == severity_fatal) + __builtin_trap(); +} + +uint32_t determineNumberOfThreads(int32_t NumThreadsClause, + int32_t nt_strict = false, + int32_t nt_severity = severity_fatal, + const char *nt_message = nullptr) { uint32_t NThreadsICV = NumThreadsClause != -1 ? NumThreadsClause : icv::NThreads; uint32_t NumThreads = mapping::getMaxTeamThreads(); @@ -55,13 +72,17 @@ uint32_t determineNumberOfThreads(int32_t NumThreadsClause) { // SPMD mode allows any number of threads, for generic mode we round down to a // multiple of WARPSIZE since it is legal to do so in OpenMP. - if (mapping::isSPMDMode()) - return NumThreads; + if (!mapping::isSPMDMode()) { + if (NumThreads < mapping::getWarpSize()) + NumThreads = 1; + else + NumThreads = (NumThreads & ~((uint32_t)mapping::getWarpSize() - 1)); + } - if (NumThreads < mapping::getWarpSize()) - NumThreads = 1; - else - NumThreads = (NumThreads & ~((uint32_t)mapping::getWarpSize() - 1)); + if (NumThreadsClause != -1 && nt_strict && + NumThreads != static_cast(NumThreadsClause)) + numThreadsStrictError(nt_strict, nt_severity, nt_message, NumThreadsClause, + NumThreads); return NumThreads; } @@ -82,12 +103,14 @@ uint32_t determineNumberOfThreads(int32_t NumThreadsClause) { extern "C" { -[[clang::always_inline]] void __kmpc_parallel_spmd(IdentTy *ident, - int32_t num_threads, - void *fn, void **args, - const int64_t nargs) { +[[clang::always_inline]] void +__kmpc_parallel_spmd(IdentTy *ident, int32_t num_threads, void *fn, void **args, + const int64_t nargs, int32_t nt_strict = false, + int32_t nt_severity = severity_fatal, + const char *nt_message = nullptr) { uint32_t TId = mapping::getThreadIdInBlock(); - uint32_t NumThreads = determineNumberOfThreads(num_threads); + uint32_t NumThreads = + determineNumberOfThreads(num_threads, nt_strict, nt_severity, nt_message); uint32_t PTeamSize = NumThreads == mapping::getMaxTeamThreads() ? 0 : NumThreads; // Avoid the race between the read of the `icv::Level` above and the write @@ -140,10 +163,11 @@ extern "C" { return; } -[[clang::always_inline]] void -__kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr, - int32_t num_threads, int proc_bind, void *fn, - void *wrapper_fn, void **args, int64_t nargs) { +[[clang::always_inline]] void __kmpc_parallel_51( + IdentTy *ident, int32_t, int32_t if_expr, int32_t num_threads, + int proc_bind, void *fn, void *wrapper_fn, void **args, int64_t nargs, + int32_t nt_strict = false, int32_t nt_severity = severity_fatal, + const char *nt_message = nullptr) { uint32_t TId = mapping::getThreadIdInBlock(); // Assert the parallelism level is zero if disabled by the user. @@ -156,6 +180,11 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr, // 3) nested parallel regions if (OMP_UNLIKELY(!if_expr || state::HasThreadState || (config::mayUseNestedParallelism() && icv::Level))) { + // OpenMP 6.0 12.1.2 requires the num_threads 'strict' modifier to also have + // effect when parallel execution is disabled by a corresponding if clause + // attached to the parallel directive. + if (nt_strict && num_threads > 1) + numThreadsStrictError(nt_strict, nt_severity, nt_message, num_threads, 1); state::DateEnvironmentRAII DERAII(ident); ++icv::Level; invokeMicrotask(TId, 0, fn, args, nargs); @@ -169,12 +198,14 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr, // This was moved to its own routine so it could be called directly // in certain situations to avoid resource consumption of unused // logic in parallel_51. - __kmpc_parallel_spmd(ident, num_threads, fn, args, nargs); + __kmpc_parallel_spmd(ident, num_threads, fn, args, nargs, nt_strict, + nt_severity, nt_message); return; } - uint32_t NumThreads = determineNumberOfThreads(num_threads); + uint32_t NumThreads = + determineNumberOfThreads(num_threads, nt_strict, nt_severity, nt_message); uint32_t MaxTeamThreads = mapping::getMaxTeamThreads(); uint32_t PTeamSize = NumThreads == MaxTeamThreads ? 0 : NumThreads; @@ -277,6 +308,16 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr, __kmpc_end_sharing_variables(); } +[[clang::always_inline]] void __kmpc_parallel_60( + IdentTy *ident, int32_t id, int32_t if_expr, int32_t num_threads, + int proc_bind, void *fn, void *wrapper_fn, void **args, int64_t nargs, + int32_t nt_strict = false, int32_t nt_severity = severity_fatal, + const char *nt_message = nullptr) { + return __kmpc_parallel_51(ident, id, if_expr, num_threads, proc_bind, fn, + wrapper_fn, args, nargs, nt_strict, nt_severity, + nt_message); +} + [[clang::noinline]] bool __kmpc_kernel_parallel(ParallelRegionFnTy *WorkFn) { // Work function and arguments for L1 parallel region. *WorkFn = state::ParallelRegionFn; diff --git a/openmp/runtime/src/kmp.h b/openmp/runtime/src/kmp.h index a2cacc8792b15..983e1c34f76b8 100644 --- a/openmp/runtime/src/kmp.h +++ b/openmp/runtime/src/kmp.h @@ -4666,6 +4666,7 @@ static inline int __kmp_adjust_gtid_for_hidden_helpers(int gtid) { } // Support for error directive +// See definition in offload/DeviceRTL DeviceTypes.h typedef enum kmp_severity_t { severity_warning = 1, severity_fatal = 2 diff --git a/openmp/runtime/src/kmp_runtime.cpp b/openmp/runtime/src/kmp_runtime.cpp index 417eceb8ebecc..6afea9b994de4 100644 --- a/openmp/runtime/src/kmp_runtime.cpp +++ b/openmp/runtime/src/kmp_runtime.cpp @@ -1214,6 +1214,12 @@ void __kmp_serialized_parallel(ident_t *loc, kmp_int32 global_tid) { // Reset for next parallel region this_thr->th.th_set_proc_bind = proc_bind_default; + // OpenMP 6.0 12.1.2 requires the num_threads 'strict' modifier to also have + // effect when parallel execution is disabled by a corresponding if clause + // attached to the parallel directive. + if (this_thr->th.th_nt_strict && this_thr->th.th_set_nproc > 1) + __kmpc_error(this_thr->th.th_nt_loc, this_thr->th.th_nt_sev, + this_thr->th.th_nt_msg); // Reset num_threads for next parallel region this_thr->th.th_set_nproc = 0;