@@ -753,3 +753,367 @@ fn main(@builtin(subgroup_size) size : u32,
753753 )
754754 ) ;
755755 } ) ;
756+
/**
 * Sentinel stored in the subgroup_id slot (component 2) of an output entry
 * when that entry does not carry a subgroup_id value.
 */
const skipValue = 0xffff0000;

/**
 * Checks subgroup_id consistency
 *
 * For every workgroup this verifies that:
 *  - each invocation reported that its whole subgroup agrees on subgroup_id,
 *  - each invocation reported subgroup_id < num_subgroups,
 *  - no subgroup_id value is recorded more than once in the workgroup,
 *  - the recorded subgroup_id values are densely packed starting at 0.
 *
 * @param outputData An array of vec4u
 *                   * 0: comparison of subgroup_id among subgroup
 *                   * 1: comparison of subgroup_id < num_subgroups
 *                   * 2: subgroup_id (for first member) or skipValue
 *                   * 3: unused
 * @param wgSize Invocations in the workgroup
 * @param numWGs Number of workgroups
 * @returns An Error describing the first inconsistency found, or undefined if all checks pass
 */
function checkSubgroupIdConsistency(
  outputData: Uint32Array,
  wgSize: number,
  numWGs: number
): Error | undefined {
  // Max wgSize is 256 and min subgroup size is 4
  const maxSubgroups = Math.ceil(wgSize / 4);

  for (let wg = 0; wg < numWGs; wg++) {
    // seen[id] is 1 once some invocation in this workgroup recorded subgroup_id == id.
    const seen = new Array<number>(maxSubgroups).fill(0);

    for (let inv = 0; inv < wgSize; inv++) {
      const gid = wg * wgSize + inv;
      const outputIdx = gid * 4;
      const compare = outputData[outputIdx];
      const in_range = outputData[outputIdx + 1];
      const sid = outputData[outputIdx + 2];

      if (compare !== 1) {
        return new Error(
          `Invocation ${gid}: not all invocations in subgroup have same subgroup_id: ${compare}`
        );
      }
      if (in_range !== 1) {
        return new Error(
          `Invocation ${gid}: subgroup_id out of range of num_subgroups: ${in_range}`
        );
      }

      if (sid !== skipValue) {
        // Without this guard an oversized id reads `undefined` from `seen` and
        // would be misreported as a reused id.
        if (sid >= seen.length) {
          return new Error(
            `Invocation ${gid}: subgroup_id ${sid} exceeds the maximum possible subgroup count ${seen.length}`
          );
        }
        if (seen[sid] !== 0) {
          return new Error(`Invocation ${gid}: subgroup_id reused among different subgroups`);
        }
        seen[sid] = 1;
      }
    }

    const firstZero = seen.indexOf(0);
    const lastOne = seen.lastIndexOf(1);

    // A non-empty workgroup must record at least one subgroup_id.
    if (wgSize > 0 && lastOne === -1) {
      return new Error(`Workgroup ${wg}: no subgroup_id was reported by any invocation`);
    }
    // Any gap below the highest recorded id means the ids are not 0..n-1.
    if (firstZero !== -1 && firstZero < lastOne) {
      return new Error(`Subgroup id values are not densely packed: missing ${firstZero}`);
    }
  }

  return undefined;
}
814+
g.test('subgroup_id')
  .desc(
    'Tests subgroup_id values. No mapping between local_invocation_index and subgroup_id can be relied upon.'
  )
  .params(u =>
    u
      .combine('sizes', kWGSizes)
      .beginSubcases()
      .combine('numWGs', [1, 2] as const)
      .combine('lid', [
        [0, 1, 2],
        [0, 2, 1],
        [1, 0, 2],
        [1, 2, 0],
        [2, 0, 1],
        [2, 1, 0],
      ] as const)
  )
  .fn(async t => {
    t.skipIfDeviceDoesNotHaveFeature('subgroups' as GPUFeatureName);
    t.skipIfLanguageFeatureNotSupported('subgroup_id');

    const { sizes, numWGs, lid } = t.params;
    const [wgx, wgy, wgz] = sizes;
    const wgThreads = wgx * wgy * wgz;

    // Compatibility mode has lower workgroup limits.
    const limits = t.device.limits;
    const exceedsLimits =
      limits.maxComputeInvocationsPerWorkgroup < wgThreads ||
      limits.maxComputeWorkgroupSizeX < wgx ||
      limits.maxComputeWorkgroupSizeY < wgy ||
      limits.maxComputeWorkgroupSizeZ < wgz;
    t.skipIf(exceedsLimits, 'Workgroup size too large');

    // Each invocation writes one vec4u:
    //   (subgroup agrees on subgroup_id, subgroup_id < num_subgroups,
    //    subgroup_id if elected else skipValue, 0).
    const shaderCode = `
enable subgroups;
requires subgroup_id;

const stride = ${wgThreads};

${genLID(lid[0], lid[1], lid[2], sizes)}

@group(0) @binding(0)
var<storage, read_write> output : array<vec4u>;

@compute @workgroup_size(${wgx}, ${wgy}, ${wgz})
fn main(@builtin(local_invocation_id) local_id : vec3u,
        @builtin(workgroup_id) wgid : vec3u,
        @builtin(subgroup_id) sid : u32,
        @builtin(num_subgroups) num_subgroups : u32) {
  // Remapped local id.
  let lid = getLID(local_id);

  let gid = lid + stride * wgid.x;

  // Is the subgroup_id equivalent for all members?
  let broadcast_id = subgroupBroadcastFirst(sid);
  let compare = subgroupAll(broadcast_id == sid);

  // Is subgroup_id in the range of num_subgroups?
  let in_range = sid < num_subgroups;

  var out_sid = ${skipValue}u;
  if subgroupElect() {
    out_sid = sid;
  }

  output[gid] = vec4u(
    select(0u, 1u, compare),
    select(0u, 1u, in_range),
    out_sid,
    0);
}
`;

    // Four u32 values (one vec4u) per invocation, pre-filled with a value the
    // checker rejects in the first two components.
    const uintCount = 4 * wgThreads * numWGs;
    const placeholderValue = 999;
    const initialContents = new Uint32Array(uintCount).fill(placeholderValue);
    const outputBuffer = t.makeBufferWithContents(
      initialContents,
      GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST
    );
    t.trackForCleanup(outputBuffer);

    const shaderModule = t.device.createShaderModule({ code: shaderCode });
    const pipeline = t.device.createComputePipeline({
      layout: 'auto',
      compute: { module: shaderModule, entryPoint: 'main' },
    });
    const bindGroup = t.device.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: [{ binding: 0, resource: { buffer: outputBuffer } }],
    });

    const commandEncoder = t.device.createCommandEncoder();
    const computePass = commandEncoder.beginComputePass();
    computePass.setPipeline(pipeline);
    computePass.setBindGroup(0, bindGroup);
    computePass.dispatchWorkgroups(numWGs, 1, 1);
    computePass.end();
    t.queue.submit([commandEncoder.finish()]);

    const readback = await t.readGPUBufferRangeTyped(outputBuffer, {
      srcByteOffset: 0,
      type: Uint32Array,
      typedLength: uintCount,
      method: 'copy',
    });
    const results: Uint32Array = readback.data;

    t.expectOK(checkSubgroupIdConsistency(results, wgThreads, numWGs));
  });
946+
/**
 * Checks num_subgroups consistency
 *
 * Every value recorded for a workgroup must equal the subgroup count that was
 * tallied for that same workgroup.
 *
 * @param countData An array with numWGs elements containing the counted number of subgroups
 * @param outputData An array of numWGs * wgSize elements containing the value of num_subgroups
 * @param wgSize Number of invocations in the workgroup
 * @param numWGs Number of workgroups
 * @returns An Error for the first mismatching value, or undefined if everything agrees
 */
function checkNumSubgroupsConsistency(
  countData: Uint32Array,
  outputData: Uint32Array,
  wgSize: number,
  numWGs: number
): Error | undefined {
  let wg = 0;
  while (wg < numWGs) {
    const expected = countData[wg];
    // View over this workgroup's entries; no copy is made.
    const reported = outputData.subarray(wg * wgSize, (wg + 1) * wgSize);

    for (let i = 0; i < reported.length; i++) {
      const actual = reported[i];
      if (actual !== expected) {
        return new Error(`Workgroup ${wg}: inconsistent num_subgroups:
- expected: ${expected}
- got: ${actual}`);
      }
    }
    wg++;
  }

  return undefined;
}
974+
g.test('num_subgroups')
  .desc('Tests num_subgroups values.')
  .params(u =>
    u
      .combine('sizes', kWGSizes)
      .beginSubcases()
      .combine('numWGs', [1, 2] as const)
      .combine('lid', [
        [0, 1, 2],
        [0, 2, 1],
        [1, 0, 2],
        [1, 2, 0],
        [2, 0, 1],
        [2, 1, 0],
      ] as const)
  )
  .fn(async t => {
    t.skipIfDeviceDoesNotHaveFeature('subgroups' as GPUFeatureName);
    t.skipIfLanguageFeatureNotSupported('subgroup_id');
    const wgx = t.params.sizes[0];
    const wgy = t.params.sizes[1];
    const wgz = t.params.sizes[2];
    const lid = t.params.lid;
    const wgThreads = wgx * wgy * wgz;

    // Compatibility mode has lower workgroup limits.
    const {
      maxComputeInvocationsPerWorkgroup,
      maxComputeWorkgroupSizeX,
      maxComputeWorkgroupSizeY,
      maxComputeWorkgroupSizeZ,
    } = t.device.limits;
    t.skipIf(
      maxComputeInvocationsPerWorkgroup < wgThreads ||
        maxComputeWorkgroupSizeX < wgx ||
        maxComputeWorkgroupSizeY < wgy ||
        maxComputeWorkgroupSizeZ < wgz,
      'Workgroup size too large'
    );

    // The shader counts subgroups by having the elected invocation of each
    // subgroup bump a workgroup-shared atomic. After the barrier, the
    // invocation with remapped id 0 stores that tally for its workgroup.
    // Every invocation also records the num_subgroups builtin it observed.
    const wgsl = `
enable subgroups;
requires subgroup_id;

const stride = ${wgThreads};

${genLID(lid[0], lid[1], lid[2], t.params.sizes)}

@group(0) @binding(0)
var<storage, read_write> numSubgroups : array<u32>;

@group(0) @binding(1)
var<storage, read_write> output : array<u32>;

var<workgroup> count : atomic<u32>;

@compute @workgroup_size(${wgx}, ${wgy}, ${wgz})
fn main(@builtin(local_invocation_id) local_id : vec3u,
        @builtin(workgroup_id) wgid : vec3u,
        @builtin(subgroup_id) sid : u32,
        @builtin(num_subgroups) num_subgroups : u32) {
  // Remapped local id.
  let lid = getLID(local_id);

  let gid = lid + stride * wgid.x;

  if subgroupElect() {
    atomicAdd(&count, 1);
  }

  workgroupBarrier();

  if lid == 0 {
    numSubgroups[wgid.x] = atomicLoad(&count);
  }

  output[gid] = num_subgroups;
}
`;

    const numInvocations = wgThreads * t.params.numWGs;
    const placeholderValue = 999;

    // One u32 per workgroup: the counted number of subgroups.
    const countBuffer = t.makeBufferWithContents(
      new Uint32Array([...iterRange(t.params.numWGs, x => placeholderValue)]),
      GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST
    );
    t.trackForCleanup(countBuffer);

    // One u32 per invocation: the shader's `output` is array<u32> and the
    // readback below reads exactly numInvocations elements.
    const outputBuffer = t.makeBufferWithContents(
      new Uint32Array([...iterRange(numInvocations, x => placeholderValue)]),
      GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST
    );
    t.trackForCleanup(outputBuffer);

    const pipeline = t.device.createComputePipeline({
      layout: 'auto',
      compute: {
        module: t.device.createShaderModule({
          code: wgsl,
        }),
        entryPoint: 'main',
      },
    });
    const bg = t.device.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: [
        {
          binding: 0,
          resource: {
            buffer: countBuffer,
          },
        },
        {
          binding: 1,
          resource: {
            buffer: outputBuffer,
          },
        },
      ],
    });

    const encoder = t.device.createCommandEncoder();
    const pass = encoder.beginComputePass();
    pass.setPipeline(pipeline);
    pass.setBindGroup(0, bg);
    pass.dispatchWorkgroups(t.params.numWGs, 1, 1);
    pass.end();
    t.queue.submit([encoder.finish()]);

    const countReadback = await t.readGPUBufferRangeTyped(countBuffer, {
      srcByteOffset: 0,
      type: Uint32Array,
      typedLength: t.params.numWGs,
      method: 'copy',
    });
    const countData: Uint32Array = countReadback.data;
    const outputReadback = await t.readGPUBufferRangeTyped(outputBuffer, {
      srcByteOffset: 0,
      type: Uint32Array,
      typedLength: numInvocations,
      method: 'copy',
    });
    const outputData: Uint32Array = outputReadback.data;

    // Each invocation's num_subgroups must match the tally for its workgroup.
    t.expectOK(checkNumSubgroupsConsistency(countData, outputData, wgThreads, t.params.numWGs));
  });
0 commit comments