diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp index 8ccc37ef98a74..a5f2f0efa2c3b 100644 --- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp @@ -6801,6 +6801,11 @@ class MappableExprsHandler { llvm::OpenMPIRBuilder::MapNonContiguousArrayTy; using MapExprsArrayTy = SmallVector; using MapValueDeclsArrayTy = SmallVector; + using MapData = + std::tuple, + bool /*IsImplicit*/, const ValueDecl *, const Expr *>; + using MapDataArrayTy = SmallVector; /// This structure contains combined information generated for mappable /// clauses, including base pointers, pointers, sizes, map types, user-defined @@ -8496,6 +8501,7 @@ class MappableExprsHandler { const StructRangeInfoTy &PartialStruct, bool IsMapThis, llvm::OpenMPIRBuilder &OMPBuilder, const ValueDecl *VD = nullptr, + unsigned OffsetForMemberOfFlag = 0, bool NotTargetParams = true) const { if (CurTypes.size() == 1 && ((CurTypes.back() & OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF) != @@ -8583,8 +8589,8 @@ class MappableExprsHandler { // All other current entries will be MEMBER_OF the combined entry // (except for PTR_AND_OBJ entries which do not have a placeholder value // 0xFFFF in the MEMBER_OF field). - OpenMPOffloadMappingFlags MemberOfFlag = - OMPBuilder.getMemberOfFlag(CombinedInfo.BasePointers.size() - 1); + OpenMPOffloadMappingFlags MemberOfFlag = OMPBuilder.getMemberOfFlag( + OffsetForMemberOfFlag + CombinedInfo.BasePointers.size() - 1); for (auto &M : CurTypes) OMPBuilder.setCorrectMemberOfFlag(M, MemberOfFlag); } @@ -8727,11 +8733,13 @@ class MappableExprsHandler { } } - /// Generate the base pointers, section pointers, sizes, map types, and - /// mappers associated to a given capture (all included in \a CombinedInfo). - void generateInfoForCapture(const CapturedStmt::Capture *Cap, - llvm::Value *Arg, MapCombinedInfoTy &CombinedInfo, - StructRangeInfoTy &PartialStruct) const { + /// For a capture that has an associated clause, generate the base pointers, + /// section pointers, sizes, map types, and mappers (all included in + /// \a CurCaptureVarInfo). + void generateInfoForCaptureFromClauseInfo( + const CapturedStmt::Capture *Cap, llvm::Value *Arg, + MapCombinedInfoTy &CurCaptureVarInfo, llvm::OpenMPIRBuilder &OMPBuilder, + unsigned OffsetForMemberOfFlag) const { assert(!Cap->capturesVariableArrayType() && "Not expecting to generate map info for a variable array type!"); @@ -8749,26 +8757,22 @@ class MappableExprsHandler { // pass the pointer by value. If it is a reference to a declaration, we just // pass its value. if (VD && (DevPointersMap.count(VD) || HasDevAddrsMap.count(VD))) { - CombinedInfo.Exprs.push_back(VD); - CombinedInfo.BasePointers.emplace_back(Arg); - CombinedInfo.DevicePtrDecls.emplace_back(VD); - CombinedInfo.DevicePointers.emplace_back(DeviceInfoTy::Pointer); - CombinedInfo.Pointers.push_back(Arg); - CombinedInfo.Sizes.push_back(CGF.Builder.CreateIntCast( + CurCaptureVarInfo.Exprs.push_back(VD); + CurCaptureVarInfo.BasePointers.emplace_back(Arg); + CurCaptureVarInfo.DevicePtrDecls.emplace_back(VD); + CurCaptureVarInfo.DevicePointers.emplace_back(DeviceInfoTy::Pointer); + CurCaptureVarInfo.Pointers.push_back(Arg); + CurCaptureVarInfo.Sizes.push_back(CGF.Builder.CreateIntCast( CGF.getTypeSize(CGF.getContext().VoidPtrTy), CGF.Int64Ty, /*isSigned=*/true)); - CombinedInfo.Types.push_back( + CurCaptureVarInfo.Types.push_back( OpenMPOffloadMappingFlags::OMP_MAP_LITERAL | OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM); - CombinedInfo.Mappers.push_back(nullptr); + CurCaptureVarInfo.Mappers.push_back(nullptr); return; } - using MapData = - std::tuple, bool, - const ValueDecl *, const Expr *>; - SmallVector DeclComponentLists; + MapDataArrayTy DeclComponentLists; // For member fields list in is_device_ptr, store it in // DeclComponentLists for generating components info. static const OpenMPMapModifierKind Unknown = OMPC_MAP_MODIFIER_unknown; @@ -8826,6 +8830,51 @@ class MappableExprsHandler { return (HasPresent && !HasPresentR) || (HasAllocs && !HasAllocsR); }); + auto GenerateInfoForComponentLists = + [&](ArrayRef DeclComponentLists, + bool IsEligibleForTargetParamFlag) { + MapCombinedInfoTy CurInfoForComponentLists; + StructRangeInfoTy PartialStruct; + + if (DeclComponentLists.empty()) + return; + + generateInfoForCaptureFromComponentLists( + VD, DeclComponentLists, CurInfoForComponentLists, PartialStruct, + IsEligibleForTargetParamFlag, + /*AreBothBasePtrAndPteeMapped=*/HasMapBasePtr && HasMapArraySec); + + // If there is an entry in PartialStruct it means we have a + // struct with individual members mapped. Emit an extra combined + // entry. + if (PartialStruct.Base.isValid()) { + CurCaptureVarInfo.append(PartialStruct.PreliminaryMapData); + emitCombinedEntry( + CurCaptureVarInfo, CurInfoForComponentLists.Types, + PartialStruct, Cap->capturesThis(), OMPBuilder, nullptr, + OffsetForMemberOfFlag, + /*NotTargetParams*/ !IsEligibleForTargetParamFlag); + } + + // Return if we didn't add any entries. + if (CurInfoForComponentLists.BasePointers.empty()) + return; + + CurCaptureVarInfo.append(CurInfoForComponentLists); + }; + + GenerateInfoForComponentLists(DeclComponentLists, + /*IsEligibleForTargetParamFlag=*/true); + } + + /// Generate the base pointers, section pointers, sizes, map types, and + /// mappers associated to \a DeclComponentLists for a given capture + /// \a VD (all included in \a CurComponentListInfo). + void generateInfoForCaptureFromComponentLists( + const ValueDecl *VD, ArrayRef DeclComponentLists, + MapCombinedInfoTy &CurComponentListInfo, StructRangeInfoTy &PartialStruct, + bool IsListEligibleForTargetParamFlag, + bool AreBothBasePtrAndPteeMapped = false) const { // Find overlapping elements (including the offset from the base element). llvm::SmallDenseMap< const MapData *, @@ -8949,7 +8998,7 @@ class MappableExprsHandler { // Associated with a capture, because the mapping flags depend on it. // Go through all of the elements with the overlapped elements. - bool IsFirstComponentList = true; + bool AddTargetParamFlag = IsListEligibleForTargetParamFlag; MapCombinedInfoTy StructBaseCombinedInfo; for (const auto &Pair : OverlappedData) { const MapData &L = *Pair.getFirst(); @@ -8964,11 +9013,11 @@ class MappableExprsHandler { ArrayRef OverlappedComponents = Pair.getSecond(); generateInfoForComponentList( - MapType, MapModifiers, {}, Components, CombinedInfo, - StructBaseCombinedInfo, PartialStruct, IsFirstComponentList, - IsImplicit, /*GenerateAllInfoForClauses*/ false, Mapper, + MapType, MapModifiers, {}, Components, CurComponentListInfo, + StructBaseCombinedInfo, PartialStruct, AddTargetParamFlag, IsImplicit, + /*GenerateAllInfoForClauses*/ false, Mapper, /*ForDeviceAddr=*/false, VD, VarRef, OverlappedComponents); - IsFirstComponentList = false; + AddTargetParamFlag = false; } // Go through other elements without overlapped elements. for (const MapData &L : DeclComponentLists) { @@ -8983,12 +9032,12 @@ class MappableExprsHandler { auto It = OverlappedData.find(&L); if (It == OverlappedData.end()) generateInfoForComponentList( - MapType, MapModifiers, {}, Components, CombinedInfo, - StructBaseCombinedInfo, PartialStruct, IsFirstComponentList, + MapType, MapModifiers, {}, Components, CurComponentListInfo, + StructBaseCombinedInfo, PartialStruct, AddTargetParamFlag, IsImplicit, /*GenerateAllInfoForClauses*/ false, Mapper, /*ForDeviceAddr=*/false, VD, VarRef, - /*OverlappedElements*/ {}, HasMapBasePtr && HasMapArraySec); - IsFirstComponentList = false; + /*OverlappedElements*/ {}, AreBothBasePtrAndPteeMapped); + AddTargetParamFlag = false; } } @@ -9467,7 +9516,6 @@ static void genMapInfoForCaptures( CE = CS.capture_end(); CI != CE; ++CI, ++RI, ++CV) { MappableExprsHandler::MapCombinedInfoTy CurInfo; - MappableExprsHandler::StructRangeInfoTy PartialStruct; // VLA sizes are passed to the outlined region by copy and do not have map // information associated. @@ -9488,13 +9536,18 @@ static void genMapInfoForCaptures( } else { // If we have any information in the map clause, we use it, otherwise we // just do a default mapping. - MEHandler.generateInfoForCapture(CI, *CV, CurInfo, PartialStruct); + MEHandler.generateInfoForCaptureFromClauseInfo( + CI, *CV, CurInfo, OMPBuilder, + /*OffsetForMemberOfFlag=*/CombinedInfo.BasePointers.size()); + if (!CI->capturesThis()) MappedVarSet.insert(CI->getCapturedVar()); else MappedVarSet.insert(nullptr); - if (CurInfo.BasePointers.empty() && !PartialStruct.Base.isValid()) + + if (CurInfo.BasePointers.empty()) MEHandler.generateDefaultMapInfo(*CI, **RI, *CV, CurInfo); + // Generate correct mapping for variables captured by reference in // lambdas. if (CI->capturesVariable()) @@ -9502,7 +9555,7 @@ static void genMapInfoForCaptures( CurInfo, LambdaPointers); } // We expect to have at least an element of information for this capture. - assert((!CurInfo.BasePointers.empty() || PartialStruct.Base.isValid()) && + assert(!CurInfo.BasePointers.empty() && "Non-existing map pointer for capture!"); assert(CurInfo.BasePointers.size() == CurInfo.Pointers.size() && CurInfo.BasePointers.size() == CurInfo.Sizes.size() && @@ -9510,15 +9563,6 @@ static void genMapInfoForCaptures( CurInfo.BasePointers.size() == CurInfo.Mappers.size() && "Inconsistent map information sizes!"); - // If there is an entry in PartialStruct it means we have a struct with - // individual members mapped. Emit an extra combined entry. - if (PartialStruct.Base.isValid()) { - CombinedInfo.append(PartialStruct.PreliminaryMapData); - MEHandler.emitCombinedEntry(CombinedInfo, CurInfo.Types, PartialStruct, - CI->capturesThis(), OMPBuilder, nullptr, - /*NotTargetParams*/ false); - } - // We need to append the results of this capture to what we already have. CombinedInfo.append(CurInfo); }