Skip to content

[NFC][Clang][OpenMP] Refactor mapinfo generation for captured vars #146891

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 86 additions & 42 deletions clang/lib/CodeGen/CGOpenMPRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6801,6 +6801,11 @@ class MappableExprsHandler {
llvm::OpenMPIRBuilder::MapNonContiguousArrayTy;
using MapExprsArrayTy = SmallVector<MappingExprInfo, 4>;
using MapValueDeclsArrayTy = SmallVector<const ValueDecl *, 4>;
using MapData =
std::tuple<OMPClauseMappableExprCommon::MappableExprComponentListRef,
OpenMPMapClauseKind, ArrayRef<OpenMPMapModifierKind>,
bool /*IsImplicit*/, const ValueDecl *, const Expr *>;
using MapDataArrayTy = SmallVector<MapData, 4>;

/// This structure contains combined information generated for mappable
/// clauses, including base pointers, pointers, sizes, map types, user-defined
Expand Down Expand Up @@ -8496,6 +8501,7 @@ class MappableExprsHandler {
const StructRangeInfoTy &PartialStruct, bool IsMapThis,
llvm::OpenMPIRBuilder &OMPBuilder,
const ValueDecl *VD = nullptr,
unsigned OffsetForMemberOfFlag = 0,
bool NotTargetParams = true) const {
if (CurTypes.size() == 1 &&
((CurTypes.back() & OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF) !=
Expand Down Expand Up @@ -8583,8 +8589,8 @@ class MappableExprsHandler {
// All other current entries will be MEMBER_OF the combined entry
// (except for PTR_AND_OBJ entries which do not have a placeholder value
// 0xFFFF in the MEMBER_OF field).
OpenMPOffloadMappingFlags MemberOfFlag =
OMPBuilder.getMemberOfFlag(CombinedInfo.BasePointers.size() - 1);
OpenMPOffloadMappingFlags MemberOfFlag = OMPBuilder.getMemberOfFlag(
OffsetForMemberOfFlag + CombinedInfo.BasePointers.size() - 1);
for (auto &M : CurTypes)
OMPBuilder.setCorrectMemberOfFlag(M, MemberOfFlag);
}
Expand Down Expand Up @@ -8727,11 +8733,13 @@ class MappableExprsHandler {
}
}

/// Generate the base pointers, section pointers, sizes, map types, and
/// mappers associated to a given capture (all included in \a CombinedInfo).
void generateInfoForCapture(const CapturedStmt::Capture *Cap,
llvm::Value *Arg, MapCombinedInfoTy &CombinedInfo,
StructRangeInfoTy &PartialStruct) const {
/// For a capture that has an associated clause, generate the base pointers,
/// section pointers, sizes, map types, and mappers (all included in
/// \a CurCaptureVarInfo).
void generateInfoForCaptureFromClauseInfo(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original function is now split into two. The first one here creates the DeclComponentLists, and the second one handles these DeclComponentLists.

const CapturedStmt::Capture *Cap, llvm::Value *Arg,
MapCombinedInfoTy &CurCaptureVarInfo, llvm::OpenMPIRBuilder &OMPBuilder,
unsigned OffsetForMemberOfFlag) const {
assert(!Cap->capturesVariableArrayType() &&
"Not expecting to generate map info for a variable array type!");

Expand All @@ -8749,26 +8757,22 @@ class MappableExprsHandler {
// pass the pointer by value. If it is a reference to a declaration, we just
// pass its value.
if (VD && (DevPointersMap.count(VD) || HasDevAddrsMap.count(VD))) {
CombinedInfo.Exprs.push_back(VD);
CombinedInfo.BasePointers.emplace_back(Arg);
CombinedInfo.DevicePtrDecls.emplace_back(VD);
CombinedInfo.DevicePointers.emplace_back(DeviceInfoTy::Pointer);
CombinedInfo.Pointers.push_back(Arg);
CombinedInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
CurCaptureVarInfo.Exprs.push_back(VD);
CurCaptureVarInfo.BasePointers.emplace_back(Arg);
CurCaptureVarInfo.DevicePtrDecls.emplace_back(VD);
CurCaptureVarInfo.DevicePointers.emplace_back(DeviceInfoTy::Pointer);
CurCaptureVarInfo.Pointers.push_back(Arg);
CurCaptureVarInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
CGF.getTypeSize(CGF.getContext().VoidPtrTy), CGF.Int64Ty,
/*isSigned=*/true));
CombinedInfo.Types.push_back(
CurCaptureVarInfo.Types.push_back(
OpenMPOffloadMappingFlags::OMP_MAP_LITERAL |
OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM);
CombinedInfo.Mappers.push_back(nullptr);
CurCaptureVarInfo.Mappers.push_back(nullptr);
return;
}

using MapData =
std::tuple<OMPClauseMappableExprCommon::MappableExprComponentListRef,
OpenMPMapClauseKind, ArrayRef<OpenMPMapModifierKind>, bool,
const ValueDecl *, const Expr *>;
SmallVector<MapData, 4> DeclComponentLists;
MapDataArrayTy DeclComponentLists;
// For member fields list in is_device_ptr, store it in
// DeclComponentLists for generating components info.
static const OpenMPMapModifierKind Unknown = OMPC_MAP_MODIFIER_unknown;
Expand Down Expand Up @@ -8826,6 +8830,51 @@ class MappableExprsHandler {
return (HasPresent && !HasPresentR) || (HasAllocs && !HasAllocsR);
});

auto GenerateInfoForComponentLists =
[&](ArrayRef<MapData> DeclComponentLists,
bool IsEligibleForTargetParamFlag) {
MapCombinedInfoTy CurInfoForComponentLists;
StructRangeInfoTy PartialStruct;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moving PartialStruct here allows us to make it local to each set of DeclComponentLists, so when GenerateInfoForComponentLists is called multiple times with different lists, we'll get different combined entries and member-of flags for each of those lists.

For now, this is called only once with the full set of DeclComponentLists from line 8867.


if (DeclComponentLists.empty())
return;

generateInfoForCaptureFromComponentLists(
VD, DeclComponentLists, CurInfoForComponentLists, PartialStruct,
IsEligibleForTargetParamFlag,
/*AreBothBasePtrAndPteeMapped=*/HasMapBasePtr && HasMapArraySec);

// If there is an entry in PartialStruct it means we have a
// struct with individual members mapped. Emit an extra combined
// entry.
if (PartialStruct.Base.isValid()) {
CurCaptureVarInfo.append(PartialStruct.PreliminaryMapData);
emitCombinedEntry(
CurCaptureVarInfo, CurInfoForComponentLists.Types,
PartialStruct, Cap->capturesThis(), OMPBuilder, nullptr,
OffsetForMemberOfFlag,
/*NotTargetParams*/ !IsEligibleForTargetParamFlag);
}

// Return if we didn't add any entries.
if (CurInfoForComponentLists.BasePointers.empty())
return;

CurCaptureVarInfo.append(CurInfoForComponentLists);
};

GenerateInfoForComponentLists(DeclComponentLists,
/*IsEligibleForTargetParamFlag=*/true);
}

/// Generate the base pointers, section pointers, sizes, map types, and
/// mappers associated to \a DeclComponentLists for a given capture
/// \a VD (all included in \a CurComponentListInfo).
void generateInfoForCaptureFromComponentLists(
const ValueDecl *VD, ArrayRef<MapData> DeclComponentLists,
MapCombinedInfoTy &CurComponentListInfo, StructRangeInfoTy &PartialStruct,
bool IsListEligibleForTargetParamFlag,
bool AreBothBasePtrAndPteeMapped = false) const {
// Find overlapping elements (including the offset from the base element).
llvm::SmallDenseMap<
const MapData *,
Expand Down Expand Up @@ -8949,7 +8998,7 @@ class MappableExprsHandler {

// Associated with a capture, because the mapping flags depend on it.
// Go through all of the elements with the overlapped elements.
bool IsFirstComponentList = true;
bool AddTargetParamFlag = IsListEligibleForTargetParamFlag;
MapCombinedInfoTy StructBaseCombinedInfo;
for (const auto &Pair : OverlappedData) {
const MapData &L = *Pair.getFirst();
Expand All @@ -8964,11 +9013,11 @@ class MappableExprsHandler {
ArrayRef<OMPClauseMappableExprCommon::MappableExprComponentListRef>
OverlappedComponents = Pair.getSecond();
generateInfoForComponentList(
MapType, MapModifiers, {}, Components, CombinedInfo,
StructBaseCombinedInfo, PartialStruct, IsFirstComponentList,
IsImplicit, /*GenerateAllInfoForClauses*/ false, Mapper,
MapType, MapModifiers, {}, Components, CurComponentListInfo,
StructBaseCombinedInfo, PartialStruct, AddTargetParamFlag, IsImplicit,
/*GenerateAllInfoForClauses*/ false, Mapper,
/*ForDeviceAddr=*/false, VD, VarRef, OverlappedComponents);
IsFirstComponentList = false;
AddTargetParamFlag = false;
}
// Go through other elements without overlapped elements.
for (const MapData &L : DeclComponentLists) {
Expand All @@ -8983,12 +9032,12 @@ class MappableExprsHandler {
auto It = OverlappedData.find(&L);
if (It == OverlappedData.end())
generateInfoForComponentList(
MapType, MapModifiers, {}, Components, CombinedInfo,
StructBaseCombinedInfo, PartialStruct, IsFirstComponentList,
MapType, MapModifiers, {}, Components, CurComponentListInfo,
StructBaseCombinedInfo, PartialStruct, AddTargetParamFlag,
IsImplicit, /*GenerateAllInfoForClauses*/ false, Mapper,
/*ForDeviceAddr=*/false, VD, VarRef,
/*OverlappedElements*/ {}, HasMapBasePtr && HasMapArraySec);
IsFirstComponentList = false;
/*OverlappedElements*/ {}, AreBothBasePtrAndPteeMapped);
AddTargetParamFlag = false;
}
}

Expand Down Expand Up @@ -9467,7 +9516,6 @@ static void genMapInfoForCaptures(
CE = CS.capture_end();
CI != CE; ++CI, ++RI, ++CV) {
MappableExprsHandler::MapCombinedInfoTy CurInfo;
MappableExprsHandler::StructRangeInfoTy PartialStruct;

// VLA sizes are passed to the outlined region by copy and do not have map
// information associated.
Expand All @@ -9488,37 +9536,33 @@ static void genMapInfoForCaptures(
} else {
// If we have any information in the map clause, we use it, otherwise we
// just do a default mapping.
MEHandler.generateInfoForCapture(CI, *CV, CurInfo, PartialStruct);
MEHandler.generateInfoForCaptureFromClauseInfo(
CI, *CV, CurInfo, OMPBuilder,
/*OffsetForMemberOfFlag=*/CombinedInfo.BasePointers.size());

if (!CI->capturesThis())
MappedVarSet.insert(CI->getCapturedVar());
else
MappedVarSet.insert(nullptr);
if (CurInfo.BasePointers.empty() && !PartialStruct.Base.isValid())

if (CurInfo.BasePointers.empty())
MEHandler.generateDefaultMapInfo(*CI, **RI, *CV, CurInfo);

// Generate correct mapping for variables captured by reference in
// lambdas.
if (CI->capturesVariable())
MEHandler.generateInfoForLambdaCaptures(CI->getCapturedVar(), *CV,
CurInfo, LambdaPointers);
}
// We expect to have at least an element of information for this capture.
assert((!CurInfo.BasePointers.empty() || PartialStruct.Base.isValid()) &&
assert(!CurInfo.BasePointers.empty() &&
"Non-existing map pointer for capture!");
assert(CurInfo.BasePointers.size() == CurInfo.Pointers.size() &&
CurInfo.BasePointers.size() == CurInfo.Sizes.size() &&
CurInfo.BasePointers.size() == CurInfo.Types.size() &&
CurInfo.BasePointers.size() == CurInfo.Mappers.size() &&
"Inconsistent map information sizes!");

// If there is an entry in PartialStruct it means we have a struct with
// individual members mapped. Emit an extra combined entry.
if (PartialStruct.Base.isValid()) {
CombinedInfo.append(PartialStruct.PreliminaryMapData);
MEHandler.emitCombinedEntry(CombinedInfo, CurInfo.Types, PartialStruct,
CI->capturesThis(), OMPBuilder, nullptr,
/*NotTargetParams*/ false);
}

// We need to append the results of this capture to what we already have.
CombinedInfo.append(CurInfo);
}
Expand Down
Loading