Skip to content

Commit 2ac676c

Browse files
Gang Y Chenigcbot
authored andcommitted
refactor the private-base code
also add comments
1 parent a992f25 commit 2ac676c

File tree

1 file changed

+39
-37
lines changed

1 file changed

+39
-37
lines changed

IGC/Compiler/Optimizer/OpenCLPasses/PrivateMemory/PrivateMemoryResolution.cpp

Lines changed: 39 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,52 +1091,54 @@ bool PrivateMemoryResolution::resolveAllocaInstructions(bool privateOnStack)
10911091
Instruction* simdSize = entryBuilder.CreateCall(simdSizeFunc, llvm::None, VALUE_NAME("simdSize"));
10921092

10931093
Value* privateBase = nullptr;
1094+
ADDRESS_SPACE scratchMemoryAddressSpace = ADDRESS_SPACE_PRIVATE;
10941095
if (modMD->compOpt.UseScratchSpacePrivateMemory)
10951096
{
1096-
Value* r0Val = implicitArgs.getImplicitArgValue(*m_currFunction, ImplicitArg::R0, m_pMdUtils);
1097-
Value* r0_5 = entryBuilder.CreateExtractElement(r0Val, ConstantInt::get(typeInt32, 5), VALUE_NAME("r0.5"));
1098-
privateBase = entryBuilder.CreateAnd(r0_5, ConstantInt::get(typeInt32, 0xFFFFFC00), VALUE_NAME("privateBase"));
1097+
if (Ctx.platform.hasScratchSurface())
1098+
{
1099+
// when we use per-thread scratch-surface with SSH bindless
1100+
// R0_5[32:10] is the offset of the surface-state for scratch
1101+
// surface slot#0, NOT the offset into the surface.
1102+
privateBase = entryBuilder.getInt32(0);
1103+
}
1104+
else
1105+
{ // the old mechanism
1106+
Value* r0Val = implicitArgs.getImplicitArgValue(*m_currFunction, ImplicitArg::R0, m_pMdUtils);
1107+
Value* r0_5 = entryBuilder.CreateExtractElement(r0Val, ConstantInt::get(typeInt32, 5), VALUE_NAME("r0.5"));
1108+
privateBase = entryBuilder.CreateAnd(r0_5, ConstantInt::get(typeInt32, 0xFFFFFC00), VALUE_NAME("privateBase"));
1109+
}
10991110
}
1100-
1101-
ADDRESS_SPACE scratchMemoryAddressSpace = ADDRESS_SPACE_PRIVATE;
1102-
if (Ctx.platform.hasScratchSurface())
1111+
else
11031112
{
1104-
if (modMD->compOpt.UseScratchSpacePrivateMemory)
1113+
scratchMemoryAddressSpace = ADDRESS_SPACE_GLOBAL;
1114+
modMD->compOpt.UseStatelessforPrivateMemory = true;
1115+
1116+
const uint32_t dwordSizeInBits = 32;
1117+
const uint32_t pointerSizeInDwords = Ctx.getRegisterPointerSizeInBits(scratchMemoryAddressSpace) / dwordSizeInBits;
1118+
IGC_ASSERT(pointerSizeInDwords <= 2);
1119+
llvm::Type* resultType = entryBuilder.getInt32Ty();
1120+
if (pointerSizeInDwords > 1)
11051121
{
1106-
privateBase = entryBuilder.getInt32(0);
1122+
resultType = IGCLLVM::FixedVectorType::get(resultType, 2);
11071123
}
1108-
else if (nullptr == privateBase)
1124+
Function* pFunc = GenISAIntrinsic::getDeclaration(
1125+
m_currFunction->getParent(),
1126+
GenISAIntrinsic::GenISA_RuntimeValue,
1127+
resultType);
1128+
privateBase = entryBuilder.CreateCall(pFunc, entryBuilder.getInt32(modMD->MinNOSPushConstantSize - pointerSizeInDwords));
1129+
if (privateBase->getType()->isVectorTy())
11091130
{
1110-
scratchMemoryAddressSpace = ADDRESS_SPACE_GLOBAL;
1111-
modMD->compOpt.UseStatelessforPrivateMemory = true;
1112-
1113-
const uint32_t dwordSizeInBits = 32;
1114-
const uint32_t pointerSizeInDwords = Ctx.getRegisterPointerSizeInBits(scratchMemoryAddressSpace) / dwordSizeInBits;
1115-
IGC_ASSERT(pointerSizeInDwords <= 2);
1116-
llvm::Type* resultType = entryBuilder.getInt32Ty();
1117-
if (pointerSizeInDwords > 1)
1118-
{
1119-
resultType = IGCLLVM::FixedVectorType::get(resultType, 2);
1120-
}
1121-
Function* pFunc = GenISAIntrinsic::getDeclaration(
1122-
m_currFunction->getParent(),
1123-
GenISAIntrinsic::GenISA_RuntimeValue,
1124-
resultType);
1125-
privateBase = entryBuilder.CreateCall(pFunc, entryBuilder.getInt32(modMD->MinNOSPushConstantSize - pointerSizeInDwords));
1126-
if (privateBase->getType()->isVectorTy())
1127-
{
1128-
privateBase = entryBuilder.CreateBitCast(privateBase, entryBuilder.getInt64Ty());
1129-
}
1131+
privateBase = entryBuilder.CreateBitCast(privateBase, entryBuilder.getInt64Ty());
1132+
}
11301133

1131-
ConstantInt* totalPrivateMemPerWIValue = ConstantInt::get(typeInt32, totalPrivateMemPerWI);
1132-
Value* totalPrivateMemPerThread = entryBuilder.CreateMul(simdSize, totalPrivateMemPerWIValue, VALUE_NAME("totalPrivateMemPerThread"));
1134+
ConstantInt* totalPrivateMemPerWIValue = ConstantInt::get(typeInt32, totalPrivateMemPerWI);
1135+
Value* totalPrivateMemPerThread = entryBuilder.CreateMul(simdSize, totalPrivateMemPerWIValue, VALUE_NAME("totalPrivateMemPerThread"));
11331136

1134-
Function* pHWTIDFunc = GenISAIntrinsic::getDeclaration(m_currFunction->getParent(), GenISAIntrinsic::GenISA_hw_thread_id_alloca, Type::getInt32Ty(C));
1135-
llvm::Value* threadId = entryBuilder.CreateCall(pHWTIDFunc);
1136-
llvm::Value* perThreadOffset = entryBuilder.CreateMul(threadId, totalPrivateMemPerThread, VALUE_NAME("perThreadOffset"));
1137-
perThreadOffset = entryBuilder.CreateZExt(perThreadOffset, privateBase->getType());
1138-
privateBase = entryBuilder.CreateAdd(privateBase, perThreadOffset);
1139-
}
1137+
Function* pHWTIDFunc = GenISAIntrinsic::getDeclaration(m_currFunction->getParent(), GenISAIntrinsic::GenISA_hw_thread_id_alloca, Type::getInt32Ty(C));
1138+
llvm::Value* threadId = entryBuilder.CreateCall(pHWTIDFunc);
1139+
llvm::Value* perThreadOffset = entryBuilder.CreateMul(threadId, totalPrivateMemPerThread, VALUE_NAME("perThreadOffset"));
1140+
perThreadOffset = entryBuilder.CreateZExt(perThreadOffset, privateBase->getType());
1141+
privateBase = entryBuilder.CreateAdd(privateBase, perThreadOffset);
11401142
}
11411143

11421144
for (auto pAI : allocaInsts)

0 commit comments

Comments
 (0)