Skip to content

Commit 0339725

Browse files
bowenxue-intel authored and igcbot committed
Fix bug with WaveAllJointReduction
Fix a bug that caused the last-layer destination to be uniform, resulting in incorrect registers being used in the .asm after joint reduction.
1 parent 12a25d4 commit 0339725

File tree

3 files changed: +22 / -25 lines changed

3 files changed: +22 / -25 lines changed

IGC/Compiler/CISACodeGen/EmitVISAPass.cpp

Lines changed: 19 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -14490,10 +14490,19 @@ void EmitPass::emitReductionTree( e_opcode op, VISA_Type type, CVariable* src, C
1449014490
for( unsigned int i = 0; i < numIterations; i++ )
1449114491
{
1449214492
// Get alias for src0, src1, and dst based on offsets and SIMD size
14493-
auto* layerSrc0 = m_currShader->GetNewAlias( src, type, i * 2 * layerMaxSimdLanes * m_encoder->GetCISADataTypeSize( type ), layerMaxSimdLanes );
14494-
auto* layerSrc1 = m_currShader->GetNewAlias( src, type, ( i * 2 * layerMaxSimdLanes + src1Offset ) * m_encoder->GetCISADataTypeSize( type ), layerMaxSimdLanes );
14495-
auto* layerDst = m_currShader->GetNewAlias( src, type, i * layerMaxSimdLanes * m_encoder->GetCISADataTypeSize( type ), layerMaxSimdLanes );
14496-
14493+
auto* layerSrc0 = m_currShader->GetNewAlias( src, type, i * 2 * layerMaxSimdLanes * m_encoder->GetCISADataTypeSize( type ), layerMaxSimdLanes, false );
14494+
auto* layerSrc1 = m_currShader->GetNewAlias( src, type, ( i * 2 * layerMaxSimdLanes + src1Offset ) * m_encoder->GetCISADataTypeSize( type ), layerMaxSimdLanes, false );
14495+
CVariable* layerDst;
14496+
if( (srcElementCount >> 1 <= dst->GetNumberElement()) && (i + 1 == numIterations ))
14497+
{
14498+
// Final layer, use destination of WaveAll vector intrinsic inst (passed in with correct offset)
14499+
layerDst = dst;
14500+
}
14501+
else
14502+
{
14503+
// Use src as workspace to store intermediate values
14504+
layerDst = m_currShader->GetNewAlias( src, type, i * layerMaxSimdLanes * m_encoder->GetCISADataTypeSize( type ), layerMaxSimdLanes, false );
14505+
}
1449714506
if( !int64EmulationNeeded )
1449814507
{
1449914508
m_encoder->SetNoMask();
@@ -14522,13 +14531,6 @@ void EmitPass::emitReductionTree( e_opcode op, VISA_Type type, CVariable* src, C
1452214531
srcElementCount >>= 1;
1452314532
reductionElementCount >>= 1;
1452414533
}
14525-
14526-
// copy fully reduced elements from src to dst
14527-
auto* finalLayerDst = m_currShader->GetNewAlias( src, type, 0, dst->GetNumberElement() );
14528-
m_encoder->SetNoMask();
14529-
m_encoder->SetSimdSize( lanesToSIMDMode( dst->GetNumberElement() ) );
14530-
m_encoder->Copy( dst, finalLayerDst );
14531-
m_encoder->Push();
1453214534
}
1453314535

1453414536
// Recursive function that emits one or more joint reduction trees based on the joint output width
@@ -14542,8 +14544,8 @@ void EmitPass::emitReductionTrees( e_opcode op, VISA_Type type, SIMDMode simdMod
1454214544
// Do full tree reduction
1454314545
unsigned int reductionElements = src->GetNumberElement() / dst->GetNumberElement();
1454414546
unsigned int groupReductionElementCount = reductionElements * simdLanes;
14545-
CVariable* srcAlias = m_currShader->GetNewAlias( src, type, startIdx * reductionElements * m_encoder->GetCISADataTypeSize( type ), groupReductionElementCount );
14546-
CVariable* dstAlias = m_currShader->GetNewAlias( dst, type, startIdx * m_encoder->GetCISADataTypeSize( type ), simdLanes);
14547+
CVariable* srcAlias = m_currShader->GetNewAlias( src, type, startIdx * reductionElements * m_encoder->GetCISADataTypeSize( type ), groupReductionElementCount, false );
14548+
CVariable* dstAlias = m_currShader->GetNewAlias( dst, type, startIdx * m_encoder->GetCISADataTypeSize( type ), simdLanes, false);
1454714549
emitReductionTree( op, type, srcAlias, dstAlias );
1454814550
// Start new recursive tree if any elements are left
1454914551
if ( numGroups > simdLanes )
@@ -23010,13 +23012,13 @@ void EmitPass::emitWaveAll(llvm::GenIntrinsicInst* inst)
2301023012
for( uint16_t i = 0; i < dst->GetNumberElement(); i++ )
2301123013
{
2301223014
// Prepare reduceSrc
23013-
CVariable* srcAlias = m_currShader->GetNewAlias( src, type, i * numLanes( m_currShader->m_SIMDSize ) * m_encoder->GetCISADataTypeSize( type ), numLanes( m_currShader->m_SIMDSize ) );
23014-
CVariable* reduceSrcAlias = m_currShader->GetNewAlias( reduceSrc, type, i * numLanes( m_currShader->m_SIMDSize ) * m_encoder->GetCISADataTypeSize( type ), numLanes( m_currShader->m_SIMDSize ) );
23015+
CVariable* srcAlias = m_currShader->GetNewAlias( src, type, i * numLanes( m_currShader->m_SIMDSize ) * m_encoder->GetCISADataTypeSize( type ), numLanes( m_currShader->m_SIMDSize ), false);
23016+
CVariable* reduceSrcAlias = m_currShader->GetNewAlias( reduceSrc, type, i * numLanes( m_currShader->m_SIMDSize ) * m_encoder->GetCISADataTypeSize( type ), numLanes( m_currShader->m_SIMDSize ), false );
2301523017
ScanReducePrepareSrc( type, identity, false, false, srcAlias, reduceSrcAlias );
2301623018

2301723019
// Prepare reduceSrcSecondHalf
23018-
CVariable* srcSecondHalfAlias = m_currShader->GetNewAlias( src, type, i * numLanes( m_currShader->m_SIMDSize ) * m_encoder->GetCISADataTypeSize( type ), numLanes( m_currShader->m_SIMDSize ) );
23019-
CVariable* reduceSrcSecondHalfAlias = m_currShader->GetNewAlias( reduceSrcSecondHalf, type, i * numLanes( m_currShader->m_SIMDSize ) * m_encoder->GetCISADataTypeSize( type ), numLanes( m_currShader->m_SIMDSize ) );
23020+
CVariable* srcSecondHalfAlias = m_currShader->GetNewAlias( src, type, i * numLanes( m_currShader->m_SIMDSize ) * m_encoder->GetCISADataTypeSize( type ), numLanes( m_currShader->m_SIMDSize ), false );
23021+
CVariable* reduceSrcSecondHalfAlias = m_currShader->GetNewAlias( reduceSrcSecondHalf, type, i * numLanes( m_currShader->m_SIMDSize ) * m_encoder->GetCISADataTypeSize( type ), numLanes( m_currShader->m_SIMDSize ), false);
2302023022
ScanReducePrepareSrc( type, identity, false, true, srcSecondHalfAlias, reduceSrcSecondHalfAlias );
2302123023

2302223024
// Emit correct operations

IGC/Compiler/tests/EmitVISAPass/wave-all-joint-reduction-dual-simd16-group4.ll

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -84,9 +84,7 @@ define void @CSMain(i32 %runtime_value_0, i32 %runtime_value_1, i32 %runtime_val
8484
; layer 3
8585
; CHECK: add (M1_NM, 8) reduceSrc_waveAllSrc0(0,0)<1> reduceSrc_waveAllSrc0(0,0)<4;2,1> reduceSrc_waveAllSrc0(0,2)<4;2,1>
8686
; layer 4
87-
; CHECK: add (M1_NM, 4) reduceSrc_waveAllSrc0(0,0)<1> reduceSrc_waveAllSrc0(0,0)<2;1,1> reduceSrc_waveAllSrc0(0,1)<2;1,1>
88-
; copy to dest
89-
; CHECK: mov (M1_NM, 1) waveAllJoint(0,0)<1> reduceSrc_waveAllSrc0(0,0)<1;1,0>
87+
; CHECK: add (M1_NM, 4) waveAllJoint(0,0)<1> reduceSrc_waveAllSrc0(0,0)<2;1,1> reduceSrc_waveAllSrc0(0,1)<2;1,1>
9088
%waveAllJoint = call <4 x i32> @llvm.genx.GenISA.WaveAll.v4i32.i8.i32(<4 x i32> %waveAllSrc3, i8 0, i32 0)
9189
%res_a = extractelement <4 x i32> %waveAllJoint, i32 0
9290
%res_b = extractelement <4 x i32> %waveAllJoint, i32 1

IGC/Compiler/tests/EmitVISAPass/wave-all-joint-reduction-simd32-group17.ll

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -144,16 +144,13 @@ define void @CSMain(i32 %runtime_value_0, i32 %runtime_value_1, i32 %runtime_val
144144
; layer 4
145145
; CHECK: add (M1_NM, 32) reduceSrc_waveAllSrc0(0,0)<1> reduceSrc_waveAllSrc0(0,0)<4;2,1> reduceSrc_waveAllSrc0(0,2)<4;2,1>
146146
; layer 5
147-
; CHECK: add (M1_NM, 16) reduceSrc_waveAllSrc0(0,0)<1> reduceSrc_waveAllSrc0(0,0)<2;1,1> reduceSrc_waveAllSrc0(0,1)<2;1,1>
148-
; copy to dest
149-
; CHECK: mov (M1_NM, 1) waveAllJoint(0,0)<1> reduceSrc_waveAllSrc0(0,0)<1;1,0>
147+
; CHECK: add (M1_NM, 16) waveAllJoint(0,0)<1> reduceSrc_waveAllSrc0(0,0)<2;1,1> reduceSrc_waveAllSrc0(0,1)<2;1,1>
150148
; Joint Reduction Tree (1-wide, leftover from splitting the 17-wide vector into 16 and 1, almost identical to existing non-joint reduction tree generated from scalar WaveAll intrinsic further below)
151149
; CHECK: add (M1_NM, 16) reduceSrc_waveAllSrc0(32,0)<1> reduceSrc_waveAllSrc0(32,0)<32;16,1> reduceSrc_waveAllSrc0(33,0)<32;16,1>
152150
; CHECK: add (M1_NM, 8) reduceSrc_waveAllSrc0(32,0)<1> reduceSrc_waveAllSrc0(32,0)<16;8,1> reduceSrc_waveAllSrc0(32,8)<16;8,1>
153151
; CHECK: add (M1_NM, 4) reduceSrc_waveAllSrc0(32,0)<1> reduceSrc_waveAllSrc0(32,0)<8;4,1> reduceSrc_waveAllSrc0(32,4)<8;4,1>
154152
; CHECK: add (M1_NM, 2) reduceSrc_waveAllSrc0(32,0)<1> reduceSrc_waveAllSrc0(32,0)<4;2,1> reduceSrc_waveAllSrc0(32,2)<4;2,1>
155-
; CHECK: add (M1_NM, 1) reduceSrc_waveAllSrc0(32,0)<1> reduceSrc_waveAllSrc0(32,0)<2;1,1> reduceSrc_waveAllSrc0(32,1)<2;1,1>
156-
; CHECK: mov (M1_NM, 1) waveAllJoint(1,0)<1> reduceSrc_waveAllSrc0(32,0)<1;1,0>
153+
; CHECK: add (M1_NM, 1) waveAllJoint(1,0)<1> reduceSrc_waveAllSrc0(32,0)<2;1,1> reduceSrc_waveAllSrc0(32,1)<2;1,1>
157154
%waveAllJoint = call <17 x i32> @llvm.genx.GenISA.WaveAll.v17i32.i8.i32(<17 x i32> %waveAllSrc16, i8 0, i32 0)
158155
%res_f = call i32 @llvm.genx.GenISA.WaveAll.i32.i8.i32(i32 %f, i8 0, i32 0)
159156
%res_add_0 = extractelement <17 x i32> %waveAllJoint, i32 0

0 commit comments

Comments
 (0)