11#version 430 core
22
3- #extension GL_KHR_shader_subgroup_basic : enable
4- #extension GL_KHR_shader_subgroup_ballot : enable
5- #extension GL_KHR_shader_subgroup_arithmetic : enable
6-
73struct PGTO {
84 float coeff;
95 float alpha;
@@ -48,7 +44,7 @@ layout(std140, binding = 0) uniform UniformBlock {
4844
4945layout(binding = 0) writeonly restrict uniform image3D out_vol;
5046
51- #if 0
47+ #if 1
5248float safe_pow(float base, uint exponent) {
5349 switch(exponent) {
5450 case 0u: return 1.0;
@@ -60,7 +56,11 @@ float safe_pow(float base, uint exponent) {
6056 return b2 * b2;
6157 }
6258 default: {
63- return 1.0;
59+ float v = 1.0;
60+ for (uint e = 0u; e < exponent; ++e) {
61+ v *= base;
62+ }
63+ return v;
6464 }
6565 }
6666}
@@ -96,7 +96,11 @@ uint pack_offset_length(uint offset, uint length) {
9696 return (length << 24u) | offset;
9797}
9898
99- #define WG_SIZE 512
99+ #define WG_X 8
100+ #define WG_Y 8
101+ #define WG_Z 8
102+
103+ #define WG_SIZE (WG_X * WG_Y * WG_Z)
100104#define TILE_SIZE 32
101105#define TILE_AREA (TILE_SIZE*TILE_SIZE)
102106#define MAX_SCREENED_CGTOS 2048
@@ -119,26 +123,28 @@ uint get_index(uint i, uint j, uint N) {
119123 return row_offset + (col - row);
120124}
121125
122- // Populate the D GSM from the global D_matrix
126+ // Full D_tile from global D_matrix
123127void fill_D_tile(uint tile_i, uint tile_j) {
124- uint tid = gl_LocalInvocationIndex;
125- if (tid >= TILE_SIZE) return;
128+ const uint tid = gl_LocalInvocationIndex;
126129
127130 uint baseI = tile_i * TILE_SIZE;
128131 uint baseJ = tile_j * TILE_SIZE;
129132
130- uint idxI = baseI + tid;
131- uint gi = (idxI < MAX_SCREENED_CGTOS) ? screened_cgtos[idxI] : INVALID_CGTO_IDX;
133+ for (uint index = tid; index < TILE_AREA; index += WG_SIZE) {
134+ uint row = index / TILE_SIZE;
135+ uint col = index % TILE_SIZE;
136+
137+ uint idxI = baseI + row;
138+ uint gi = (idxI < MAX_SCREENED_CGTOS) ? screened_cgtos[idxI] : INVALID_CGTO_IDX;
132139
133- for (uint col = 0; col < TILE_SIZE; ++col) {
134140 uint idxJ = baseJ + col;
135141 uint gj = (idxJ < MAX_SCREENED_CGTOS) ? screened_cgtos[idxJ] : INVALID_CGTO_IDX;
136142
137143 float value = 0.0;
138144 if (gi != INVALID_CGTO_IDX && gj != INVALID_CGTO_IDX) {
139145 value = D_matrix[get_index(gi, gj, D_matrix_dim)];
140146 }
141- D_tile[tid ][col] = value;
147+ D_tile[row ][col] = value;
142148 }
143149}
144150
@@ -162,29 +168,27 @@ void fill_cgtos_tile(uint tile_number, vec3 model_aabb_min, vec3 model_aabb_max)
162168 uint src_end = cgto_offset[global_cgto_idx + 1u];
163169 uint src_len = src_end - src_beg;
164170
165- if (src_beg != src_end) {
166- // Reserve contiguous space for this CGTO's PGTOs
167- uint dst = atomicAdd(num_pgtos, src_len);
168-
169- cgto_center = cgto_xyzr[global_cgto_idx].xyz;
170- cgto_pgto_off = dst;
171-
172- vec3 model_xyz = vec3(world_to_model * vec4(cgto_center, 1.0));
173- vec3 d = clamp(model_xyz, model_aabb_min, model_aabb_max) - model_xyz;
174- float d2 = dot(d, d);
171+ // Reserve contiguous space for this CGTO's PGTOs
172+ uint dst = atomicAdd(num_pgtos, src_len);
175173
176- // Copy PGTOs
177- for (uint k = 0; k < src_len; ++k) {
178- float r = pgto_radius[src_beg + k];
179- // Cull based on radius
180- if (d2 < r * r) {
181- PGTO g;
182- g.coeff = pgto_coeff[src_beg + k];
183- g.alpha = pgto_alpha[src_beg + k];
184- g.ijkl = pgto_ijkl[src_beg + k];
185- pgtos_tile[cgto_pgto_off + cgto_pgto_len] = g;
186- cgto_pgto_len++;
187- }
174+ cgto_center = cgto_xyzr[global_cgto_idx].xyz;
175+ cgto_pgto_off = dst;
176+
177+ vec3 model_xyz = vec3(world_to_model * vec4(cgto_center, 1.0));
178+ vec3 d = clamp(model_xyz, model_aabb_min, model_aabb_max) - model_xyz;
179+ float d2 = dot(d, d);
180+
181+ // Copy PGTOs
182+ for (uint k = 0; k < src_len; ++k) {
183+ float r = pgto_radius[src_beg + k];
184+ // Cull based on radius
185+ if (d2 < r * r) {
186+ PGTO g;
187+ g.coeff = pgto_coeff[src_beg + k];
188+ g.alpha = pgto_alpha[src_beg + k];
189+ g.ijkl = pgto_ijkl[src_beg + k];
190+ pgtos_tile[cgto_pgto_off + cgto_pgto_len] = g;
191+ cgto_pgto_len++;
188192 }
189193 }
190194 }
@@ -205,16 +209,33 @@ void eval_phis(out float out_phi[TILE_SIZE], vec3 coord) {
205209 uint pgto_len = cgtos_tile[i].pgto_len;
206210
207211 vec3 center = cgtos_tile[i].coord;
208- vec3 d = coord - center;
212+ vec3 d = coord - center;
209213 float r2 = dot(d, d);
210-
214+
215+ // Precompute powers once per CGTO
216+ float dx2 = d.x * d.x;
217+ float dy2 = d.y * d.y;
218+ float dz2 = d.z * d.z;
219+
211220 float phi = 0.0;
212221 for (uint j = pgto_off; j < pgto_off + pgto_len; ++j) {
213222 PGTO pgto = pgtos_tile[j];
214223 uvec4 ijkl = unpack_ijkl(pgto.ijkl);
215- float fx = safe_pow(d.x, ijkl.x);
216- float fy = safe_pow(d.y, ijkl.y);
217- float fz = safe_pow(d.z, ijkl.z);
224+ // Use ternary for common cases (no divergence, compiles to select)
225+ float fx = (ijkl.x == 0u) ? 1.0 :
226+ (ijkl.x == 1u) ? d.x :
227+ (ijkl.x == 2u) ? dx2 :
228+ (ijkl.x == 3u) ? dx2 * d.x : dx2 * dx2;
229+
230+ float fy = (ijkl.y == 0u) ? 1.0 :
231+ (ijkl.y == 1u) ? d.y :
232+ (ijkl.y == 2u) ? dy2 :
233+ (ijkl.y == 3u) ? dy2 * d.y : dy2 * dy2;
234+
235+ float fz = (ijkl.z == 0u) ? 1.0 :
236+ (ijkl.z == 1u) ? d.z :
237+ (ijkl.z == 2u) ? dz2 :
238+ (ijkl.z == 3u) ? dz2 * d.z : dz2 * dz2;
218239 phi += pgto.coeff * fx * fy * fz * exp(-pgto.alpha * r2);
219240 }
220241 out_phi[i] = phi;
@@ -225,9 +246,10 @@ float symmetric_contraction(float phi[TILE_SIZE], float D[TILE_SIZE][TILE_SIZE])
225246 float result = 0.0;
226247 for (uint i = 0; i < TILE_SIZE; ++i) {
227248 float ai = phi[i];
228- result += D[i][i] * ai * ai; // Diagonal
249+ result = fma(D[i][i] * ai, ai, result); // Diagonal
250+ ai = 2.0 * ai;
229251 for (uint j = i + 1; j < TILE_SIZE; ++j) {
230- result += 2.0 * D[i][j] * ai * phi[j]; // Off-diagonal
252+ result = fma( D[i][j] * ai, phi[j], result); // Off-diagonal
231253 }
232254 }
233255 return result;
@@ -236,15 +258,15 @@ float symmetric_contraction(float phi[TILE_SIZE], float D[TILE_SIZE][TILE_SIZE])
236258float contraction(float phi_mu[TILE_SIZE], float phi_nu[TILE_SIZE], float D[TILE_SIZE][TILE_SIZE]) {
237259 float result = 0.0;
238260 for (uint i = 0; i < TILE_SIZE; ++i) {
239- float ai = phi_mu[i];
261+ float ai = 2.0 * phi_mu[i];
240262 for (uint j = 0; j < TILE_SIZE; ++j) {
241- result += 2.0 * D[i][j] * ai * phi_nu[j];
263+ result = fma( D[i][j] * ai, phi_nu[j], result); // Off-diagonal
242264 }
243265 }
244266 return result;
245267}
246268
247- layout (local_size_x = 8 , local_size_y = 8 , local_size_z = 8 ) in;
269+ layout (local_size_x = WG_X , local_size_y = WG_Y , local_size_z = WG_Z ) in;
248270void main() {
249271 uint tid = gl_LocalInvocationIndex;
250272 if (tid == 0) {
@@ -256,7 +278,6 @@ void main() {
256278 vec3 model_aabb_max = vec3((gl_WorkGroupID.xyz + uvec3(1,1,1)) * gl_WorkGroupSize.xyz) * step.xyz;
257279 // Stage 1: Screening. Prune CGTOs to identify which are relevant for region
258280 {
259- // Stream matches directly; avoid large per-thread stacks and subgroup prefix sums
260281 for (uint i = tid; i < D_matrix_dim; i += WG_SIZE) {
261282 vec4 cgto = cgto_xyzr[i];
262283 if (cgto.w == 0.0) continue;
@@ -291,6 +312,7 @@ void main() {
291312 screened_cgtos[i] = INVALID_CGTO_IDX;
292313 }
293314 }
315+ barrier();
294316 }
295317
296318 float phi_tile_mu[TILE_SIZE]; // evaluated φ_μ(r) in registers
@@ -330,8 +352,6 @@ void main() {
330352 rho += symmetric_contraction(phi_tile_mu, D_tile);
331353
332354 for (uint tile_j = tile_i + 1; tile_j < num_tiles; ++tile_j) {
333- barrier();
334-
335355 // OFF DIAGONAL TILE
336356 if (tid == 0) {
337357 num_pgtos = 0;
0 commit comments