diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td index c080c3fac87d4..deb2fba1cd796 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td @@ -223,21 +223,19 @@ def ScheduleModifier : OpenMP_I32EnumAttr< def ScheduleModifierAttr : OpenMP_EnumAttr; //===----------------------------------------------------------------------===// -// target_region_flags enum. +// target_exec_mode enum. //===----------------------------------------------------------------------===// -def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">; -def TargetRegionFlagsGeneric : I32BitEnumAttrCaseBit<"generic", 0>; -def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 1>; -def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 2>; - -def TargetRegionFlags : OpenMP_BitEnumAttr< - "TargetRegionFlags", - "target region property flags", [ - TargetRegionFlagsNone, - TargetRegionFlagsGeneric, - TargetRegionFlagsSpmd, - TargetRegionFlagsTripCount +def TargetExecModeBare : I32EnumAttrCase<"bare", 0>; +def TargetExecModeGeneric : I32EnumAttrCase<"generic", 1>; +def TargetExecModeSpmd : I32EnumAttrCase<"spmd", 2>; + +def TargetExecMode : OpenMP_I32EnumAttr< + "TargetExecMode", + "target execution mode, mirroring the `OMPTgtExecModeFlags` LLVM enum", [ + TargetExecModeBare, + TargetExecModeGeneric, + TargetExecModeSpmd, ]>; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index be114ea4fb631..6569905c5fae4 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -1517,13 +1517,17 @@ def TargetOp : OpenMP_Op<"target", traits = [ /// operations, the top level one will be the one captured. Operation *getInnermostCapturedOmpOp(); - /// Infers the kernel type (Generic, SPMD or Generic-SPMD) based on the - /// contents of the target region. + /// Infers the kernel type (Bare, Generic or SPMD) based on the contents of + /// the target region. /// /// \param capturedOp result of a still valid (no modifications made to any /// nested operations) previous call to `getInnermostCapturedOmpOp()`. - static ::mlir::omp::TargetRegionFlags - getKernelExecFlags(Operation *capturedOp); + /// \param hostEvalTripCount output argument to store whether this kernel + /// wraps a loop whose bounds must be evaluated on the host prior to + /// launching it. + static ::mlir::omp::TargetExecMode + getKernelExecFlags(Operation *capturedOp, + bool *hostEvalTripCount = nullptr); }] # clausesExtraClassDeclaration; let assemblyFormat = clausesAssemblyFormat # [{ diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index c1c1767ef90b0..c3c17006fe571 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1974,8 +1974,9 @@ LogicalResult TargetOp::verifyRegions() { return emitError("target containing multiple 'omp.teams' nested ops"); // Check that host_eval values are only used in legal ways. + bool hostEvalTripCount; Operation *capturedOp = getInnermostCapturedOmpOp(); - TargetRegionFlags execFlags = getKernelExecFlags(capturedOp); + TargetExecMode execMode = getKernelExecFlags(capturedOp, &hostEvalTripCount); for (Value hostEvalArg : cast(getOperation()).getHostEvalBlockArgs()) { for (Operation *user : hostEvalArg.getUsers()) { @@ -1990,7 +1991,7 @@ LogicalResult TargetOp::verifyRegions() { "and 'thread_limit' in 'omp.teams'"; } if (auto parallelOp = dyn_cast(user)) { - if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) && + if (execMode == TargetExecMode::spmd && parallelOp->isAncestor(capturedOp) && hostEvalArg == parallelOp.getNumThreads()) continue; @@ -2000,8 +2001,7 @@ LogicalResult TargetOp::verifyRegions() { "'omp.parallel' when representing target SPMD"; } if (auto loopNestOp = dyn_cast(user)) { - if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) && - loopNestOp.getOperation() == capturedOp && + if (hostEvalTripCount && loopNestOp.getOperation() == capturedOp && (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) || llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) || llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg))) @@ -2106,7 +2106,9 @@ Operation *TargetOp::getInnermostCapturedOmpOp() { }); } -TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) { +TargetExecMode TargetOp::getKernelExecFlags(Operation *capturedOp, + bool *hostEvalTripCount) { + // TODO: Support detection of bare kernel mode. // A non-null captured op is only valid if it resides inside of a TargetOp // and is the result of calling getInnermostCapturedOmpOp() on it. TargetOp targetOp = @@ -2115,9 +2117,12 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) { (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) && "unexpected captured op"); + if (hostEvalTripCount) + *hostEvalTripCount = false; + // If it's not capturing a loop, it's a default target region. if (!isa_and_present(capturedOp)) - return TargetRegionFlags::generic; + return TargetExecMode::generic; // Get the innermost non-simd loop wrapper. SmallVector loopWrappers; @@ -2130,79 +2135,59 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) { auto numWrappers = std::distance(innermostWrapper, loopWrappers.end()); if (numWrappers != 1 && numWrappers != 2) - return TargetRegionFlags::generic; + return TargetExecMode::generic; // Detect target-teams-distribute-parallel-wsloop[-simd]. if (numWrappers == 2) { if (!isa(innermostWrapper)) - return TargetRegionFlags::generic; + return TargetExecMode::generic; innermostWrapper = std::next(innermostWrapper); if (!isa(innermostWrapper)) - return TargetRegionFlags::generic; + return TargetExecMode::generic; Operation *parallelOp = (*innermostWrapper)->getParentOp(); if (!isa_and_present(parallelOp)) - return TargetRegionFlags::generic; + return TargetExecMode::generic; Operation *teamsOp = parallelOp->getParentOp(); if (!isa_and_present(teamsOp)) - return TargetRegionFlags::generic; + return TargetExecMode::generic; - if (teamsOp->getParentOp() == targetOp.getOperation()) - return TargetRegionFlags::spmd | TargetRegionFlags::trip_count; + if (teamsOp->getParentOp() == targetOp.getOperation()) { + if (hostEvalTripCount) + *hostEvalTripCount = true; + return TargetExecMode::spmd; + } } // Detect target-teams-distribute[-simd] and target-teams-loop. else if (isa(innermostWrapper)) { Operation *teamsOp = (*innermostWrapper)->getParentOp(); if (!isa_and_present(teamsOp)) - return TargetRegionFlags::generic; + return TargetExecMode::generic; if (teamsOp->getParentOp() != targetOp.getOperation()) - return TargetRegionFlags::generic; + return TargetExecMode::generic; + + if (hostEvalTripCount) + *hostEvalTripCount = true; if (isa(innermostWrapper)) - return TargetRegionFlags::spmd | TargetRegionFlags::trip_count; - - // Find single immediately nested captured omp.parallel and add spmd flag - // (generic-spmd case). - // - // TODO: This shouldn't have to be done here, as it is too easy to break. - // The openmp-opt pass should be updated to be able to promote kernels like - // this from "Generic" to "Generic-SPMD". However, the use of the - // `kmpc_distribute_static_loop` family of functions produced by the - // OMPIRBuilder for these kernels prevents that from working. - Dialect *ompDialect = targetOp->getDialect(); - Operation *nestedCapture = findCapturedOmpOp( - capturedOp, /*checkSingleMandatoryExec=*/false, - [&](Operation *sibling) { - return sibling && (ompDialect != sibling->getDialect() || - sibling->hasTrait()); - }); - - TargetRegionFlags result = - TargetRegionFlags::generic | TargetRegionFlags::trip_count; - - if (!nestedCapture) - return result; - - while (nestedCapture->getParentOp() != capturedOp) - nestedCapture = nestedCapture->getParentOp(); - - return isa(nestedCapture) ? result | TargetRegionFlags::spmd - : result; + return TargetExecMode::spmd; + + return TargetExecMode::generic; } // Detect target-parallel-wsloop[-simd]. else if (isa(innermostWrapper)) { Operation *parallelOp = (*innermostWrapper)->getParentOp(); if (!isa_and_present(parallelOp)) - return TargetRegionFlags::generic; + return TargetExecMode::generic; if (parallelOp->getParentOp() == targetOp.getOperation()) - return TargetRegionFlags::spmd; + return TargetExecMode::spmd; } - return TargetRegionFlags::generic; + return TargetExecMode::generic; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index eb96cb211fdd5..d49cc38cd7925 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -5354,17 +5354,18 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, } // Update kernel bounds structure for the `OpenMPIRBuilder` to use. - omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp); - assert( - omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic | - omp::TargetRegionFlags::spmd) && - "invalid kernel flags"); - attrs.ExecFlags = - omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic) - ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd) - ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD - : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC - : llvm::omp::OMP_TGT_EXEC_MODE_SPMD; + omp::TargetExecMode execMode = targetOp.getKernelExecFlags(capturedOp); + switch (execMode) { + case omp::TargetExecMode::bare: + attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_BARE; + break; + case omp::TargetExecMode::generic: + attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_GENERIC; + break; + case omp::TargetExecMode::spmd: + attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD; + break; + } attrs.MinTeams = minTeamsVal; attrs.MaxTeams.front() = maxTeamsVal; attrs.MinThreads = 1; @@ -5414,8 +5415,9 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, if (numThreads) attrs.MaxThreads = moduleTranslation.lookupValue(numThreads); - if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp), - omp::TargetRegionFlags::trip_count)) { + bool hostEvalTripCount; + targetOp.getKernelExecFlags(capturedOp, &hostEvalTripCount); + if (hostEvalTripCount) { llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); attrs.LoopTripCount = nullptr; diff --git a/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir index 9bb2b40a43def..fd190a7b95f66 100644 --- a/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir +++ b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir @@ -87,7 +87,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo } } -// DEVICE: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:3]] +// DEVICE: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:1]] // DEVICE: @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata" // DEVICE: @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy { // DEVICE-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE]], {{.*}}},