Skip to content

Commit 79670d4

Browse files
author
Robin Skånberg
committed
gto fixes.
1 parent 902c40c commit 79670d4

File tree

6 files changed

+501
-112
lines changed

6 files changed

+501
-112
lines changed

CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ set(GTO_SHADER_FILES
101101
src/shaders/eval_alie.comp
102102
src/shaders/eval_gto.comp
103103
src/shaders/eval_gto_density.comp
104-
src/shaders/segment_and_attribute_to_group.comp
104+
src/shaders/eval_gto_density_grad.comp
105+
src/shaders/voronoi_segment.comp
105106
)
106107

107108
set(MD_DEFINES MD_GL_SPLINE_SUBDIVISION_COUNT=${MD_GL_SPLINE_SUBDIVISION_COUNT})

src/md_gto.c

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,11 @@ static GLuint get_alie_program(void) {
9393
return program;
9494
}
9595

96-
static GLuint get_vol_segment_to_groups_program(void) {
96+
static GLuint get_voronoi_segment_program(void) {
9797
static GLuint program = 0;
9898
if (!program) {
9999
GLuint shader = glCreateShader(GL_COMPUTE_SHADER);
100-
if (md_gl_shader_compile(shader, (str_t){(const char*)segment_and_attribute_to_group_comp, segment_and_attribute_to_group_comp_size}, 0, 0)) {
100+
if (md_gl_shader_compile(shader, (str_t){(const char*)voronoi_segment_comp, voronoi_segment_comp_size}, 0, 0)) {
101101
GLuint prog = glCreateProgram();
102102
if (md_gl_program_attach_and_link(prog, &shader, 1)) {
103103
program = prog;
@@ -247,10 +247,10 @@ void md_gto_grid_evaluate_ALIE_GPU(uint32_t vol_tex, const md_grid_t* vol_grid,
247247
gto_grid_evaluate_orb_GPU(vol_tex, vol_grid, orb, mode, program);
248248
}
249249

250-
void md_gto_segment_and_attribute_to_groups_GPU(float* out_group_values, size_t cap_groups, uint32_t vol_tex, const md_grid_t* grid, const float* point_xyzr, const uint32_t* point_group_idx, size_t num_points) {
251-
ASSERT(out_group_values);
250+
void md_gto_voronoi_segment_GPU(float* out_values, const float* point_xyzr, size_t num_points, uint32_t vol_tex, const md_grid_t* grid) {
251+
ASSERT(out_values);
252252
ASSERT(point_xyzr);
253-
ASSERT(point_group_idx);
253+
ASSERT(grid);
254254

255255
GLenum format = 0;
256256
if (glGetTextureLevelParameteriv) {
@@ -273,37 +273,30 @@ void md_gto_segment_and_attribute_to_groups_GPU(float* out_group_values, size_t
273273

274274
md_gl_debug_push("SEGMENT VOL TO GROUP");
275275

276-
GLintptr ssbo_group_value_offset = 0;
277-
GLsizeiptr ssbo_group_value_size = sizeof(float) * 16;
276+
GLintptr ssbo_point_value_offset = 0;
277+
GLsizeiptr ssbo_point_value_size = sizeof(float) * 16;
278278

279-
GLintptr ssbo_point_xyzr_offset = ALIGN_TO(ssbo_group_value_offset + ssbo_group_value_size, 256);
279+
GLintptr ssbo_point_xyzr_offset = ALIGN_TO(ssbo_point_value_offset + ssbo_point_value_size, 256);
280280
GLsizeiptr ssbo_point_xyzr_size = sizeof(float) * 4 * num_points;
281281

282-
GLintptr ssbo_point_group_offset = ALIGN_TO(ssbo_point_xyzr_offset + ssbo_point_xyzr_size, 256);
283-
GLsizeiptr ssbo_point_group_size = sizeof(uint32_t) * num_points;
284-
285-
size_t total_size = ALIGN_TO(ssbo_point_group_offset + ssbo_point_group_size, 256);
286-
282+
size_t total_size = ALIGN_TO(ssbo_point_xyzr_offset + ssbo_point_xyzr_size, 256);
287283
GLuint ssbo = get_buffer(total_size);
288284

289285
glBindBuffer(GL_SHADER_STORAGE_BUFFER, ssbo);
290286

291-
// Clear first 16 bytes which represents the result (group_values)
292-
glClearBufferSubData(GL_SHADER_STORAGE_BUFFER, GL_R32F, ssbo_group_value_offset, ssbo_group_value_size, GL_RED, GL_FLOAT, NULL);
287+
// Clear first portion of buffer holding point values
288+
glClearBufferSubData(GL_SHADER_STORAGE_BUFFER, GL_R32F, ssbo_point_value_offset, ssbo_point_value_size, GL_RED, GL_FLOAT, NULL);
293289
// Fill next portion of buffer with point xyzr
294290
glBufferSubData(GL_SHADER_STORAGE_BUFFER, ssbo_point_xyzr_offset, ssbo_point_xyzr_size, point_xyzr);
295-
// Fill last portion of buffer with point indices
296-
glBufferSubData(GL_SHADER_STORAGE_BUFFER, ssbo_point_group_offset, ssbo_point_group_size, point_group_idx);
297291

298292
glBindBufferRange(GL_SHADER_STORAGE_BUFFER, 0, ssbo, ssbo_point_xyzr_offset, ssbo_point_xyzr_size);
299-
glBindBufferRange(GL_SHADER_STORAGE_BUFFER, 1, ssbo, ssbo_point_group_offset, ssbo_point_group_size);
300-
glBindBufferRange(GL_SHADER_STORAGE_BUFFER, 2, ssbo, ssbo_group_value_offset, ssbo_group_value_size);
293+
glBindBufferRange(GL_SHADER_STORAGE_BUFFER, 1, ssbo, ssbo_point_value_offset, ssbo_point_value_size);
301294

302295
glBindImageTexture(0, vol_tex, 0, GL_TRUE, 0, GL_READ_ONLY, format);
303296

304297
glMemoryBarrier(GL_BUFFER_UPDATE_BARRIER_BIT);
305298

306-
GLuint program = get_vol_segment_to_groups_program();
299+
GLuint program = get_voronoi_segment_program();
307300
glUseProgram(program);
308301

309302
float world_to_model[4][4];
@@ -321,18 +314,16 @@ void md_gto_segment_and_attribute_to_groups_GPU(float* out_group_values, size_t
321314
DIV_UP(grid->dim[1], 8),
322315
DIV_UP(grid->dim[2], 8),
323316
};
324-
325317
glDispatchCompute(num_groups[0], num_groups[1], num_groups[2]);
326318

327319
glMemoryBarrier(GL_SHADER_STORAGE_BARRIER_BIT);
328320

329-
uint32_t temp_group_values[16];
330-
glGetBufferSubData(GL_SHADER_STORAGE_BUFFER, ssbo_group_value_offset, ssbo_group_value_size, temp_group_values);
331-
332-
for (size_t i = 0; i < MIN(cap_groups, 16); ++i) {
333-
double value = temp_group_values[i] / QUANTIZATION_SCALE_FACTOR;
334-
out_group_values[i] = (float)value;
321+
uint32_t* temp_values = (uint32_t*)md_temp_push(sizeof(uint32_t) * num_points);
322+
glGetBufferSubData(GL_SHADER_STORAGE_BUFFER, ssbo_point_value_offset, ssbo_point_value_size, temp_values);
323+
for (size_t i = 0; i < num_points; ++i) {
324+
out_values[i] = (float)(temp_values[i] / QUANTIZATION_SCALE_FACTOR);
335325
}
326+
md_temp_pop(sizeof(uint32_t) * num_points);
336327

337328
glUseProgram(0);
338329
glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0);
@@ -485,7 +476,7 @@ void md_gto_grid_evaluate_matrix_GPU(uint32_t vol_tex, const md_grid_t* grid, co
485476
GLuint64 elapsedTime = 0;
486477
glGetQueryObjectui64v(query, GL_QUERY_RESULT, &elapsedTime); // nanoseconds
487478

488-
MD_LOG_DEBUG("GTO Density evaluation GPU time: %.3f ms", elapsedTime / 1e6);
479+
MD_LOG_DEBUG("GTO Density evaluation of [%i,%i,%i] GPU time: %.3f ms", grid->dim[0], grid->dim[1], grid->dim[2], elapsedTime / 1e6);
489480

490481
glUseProgram(0);
491482

src/md_gto.h

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -91,15 +91,13 @@ void md_gto_grid_evaluate_orb_GPU(uint32_t vol_tex, const md_grid_t* vol_grid, c
9191
void md_gto_grid_evaluate_ALIE_GPU(uint32_t vol_tex, const md_grid_t* vol_grid, const md_orbital_data_t* orb, md_gto_eval_mode_t mode);
9292

9393
// This is misplaced at the moment, but for now this is the best available place for the functionality
94-
// Performs voronoi segmentation of the supplied volume to points with a supplied radius and accumulates the value of each voxel into the corresponding group of the closest point
95-
// - out_group_values: Destination array holding the group values that are written to
96-
// - cap_groups: Capacity of group array
97-
// - vol_tex: The texture handle to the volume
98-
// - vol_grid: The grid defining the volume
94+
// Performs voronoi segmentation of the supplied volume to points with a supplied radius and accumulates the value of each voxel into the corresponding point
95+
// - out_values: Destination array holding the point values that are written to, should have length 'num_points'
9996
// - point_xyzr: Point coordinates + radius, packed xyzrxyzrxyzr
100-
// - point_group_idx: Point group index [0, num_groups-1]
10197
// - num_points: Number of points
102-
void md_gto_segment_and_attribute_to_groups_GPU(float* out_group_values, size_t cap_groups, uint32_t vol_tex, const md_grid_t* vol_grid, const float* point_xyzr, const uint32_t* point_group_idx, size_t num_points);
98+
// - vol_tex: The texture handle to the volume
99+
// - vol_grid: The grid defining the volume
100+
void md_gto_voronoi_segment_GPU(float* out_values, const float* point_xyzr, size_t num_points, uint32_t vol_tex, const md_grid_t* vol_grid);
103101

104102
// Evaluates GTOs over a grid
105103
// - out_grid_values: The grid to write the evaluated values to, should have length 'grid->dim[0] * grid->dim[1] * grid->dim[2]'

src/shaders/eval_gto_density.comp

Lines changed: 69 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
#version 430 core
22

3-
#extension GL_KHR_shader_subgroup_basic : enable
4-
#extension GL_KHR_shader_subgroup_ballot : enable
5-
#extension GL_KHR_shader_subgroup_arithmetic : enable
6-
73
struct PGTO {
84
float coeff;
95
float alpha;
@@ -48,7 +44,7 @@ layout(std140, binding = 0) uniform UniformBlock {
4844

4945
layout(binding = 0) writeonly restrict uniform image3D out_vol;
5046

51-
#if 0
47+
#if 1
5248
float safe_pow(float base, uint exponent) {
5349
switch(exponent) {
5450
case 0u: return 1.0;
@@ -60,7 +56,11 @@ float safe_pow(float base, uint exponent) {
6056
return b2 * b2;
6157
}
6258
default: {
63-
return 1.0;
59+
float v = 1.0;
60+
for (uint e = 0u; e < exponent; ++e) {
61+
v *= base;
62+
}
63+
return v;
6464
}
6565
}
6666
}
@@ -96,7 +96,11 @@ uint pack_offset_length(uint offset, uint length) {
9696
return (length << 24u) | offset;
9797
}
9898

99-
#define WG_SIZE 512
99+
#define WG_X 8
100+
#define WG_Y 8
101+
#define WG_Z 8
102+
103+
#define WG_SIZE (WG_X * WG_Y * WG_Z)
100104
#define TILE_SIZE 32
101105
#define TILE_AREA (TILE_SIZE*TILE_SIZE)
102106
#define MAX_SCREENED_CGTOS 2048
@@ -119,26 +123,28 @@ uint get_index(uint i, uint j, uint N) {
119123
return row_offset + (col - row);
120124
}
121125

122-
// Populate the D GSM from the global D_matrix
126+
// Fill D_tile from global D_matrix
123127
void fill_D_tile(uint tile_i, uint tile_j) {
124-
uint tid = gl_LocalInvocationIndex;
125-
if (tid >= TILE_SIZE) return;
128+
const uint tid = gl_LocalInvocationIndex;
126129

127130
uint baseI = tile_i * TILE_SIZE;
128131
uint baseJ = tile_j * TILE_SIZE;
129132

130-
uint idxI = baseI + tid;
131-
uint gi = (idxI < MAX_SCREENED_CGTOS) ? screened_cgtos[idxI] : INVALID_CGTO_IDX;
133+
for (uint index = tid; index < TILE_AREA; index += WG_SIZE) {
134+
uint row = index / TILE_SIZE;
135+
uint col = index % TILE_SIZE;
136+
137+
uint idxI = baseI + row;
138+
uint gi = (idxI < MAX_SCREENED_CGTOS) ? screened_cgtos[idxI] : INVALID_CGTO_IDX;
132139

133-
for (uint col = 0; col < TILE_SIZE; ++col) {
134140
uint idxJ = baseJ + col;
135141
uint gj = (idxJ < MAX_SCREENED_CGTOS) ? screened_cgtos[idxJ] : INVALID_CGTO_IDX;
136142

137143
float value = 0.0;
138144
if (gi != INVALID_CGTO_IDX && gj != INVALID_CGTO_IDX) {
139145
value = D_matrix[get_index(gi, gj, D_matrix_dim)];
140146
}
141-
D_tile[tid][col] = value;
147+
D_tile[row][col] = value;
142148
}
143149
}
144150

@@ -162,29 +168,27 @@ void fill_cgtos_tile(uint tile_number, vec3 model_aabb_min, vec3 model_aabb_max)
162168
uint src_end = cgto_offset[global_cgto_idx + 1u];
163169
uint src_len = src_end - src_beg;
164170

165-
if (src_beg != src_end) {
166-
// Reserve contiguous space for this CGTO's PGTOs
167-
uint dst = atomicAdd(num_pgtos, src_len);
168-
169-
cgto_center = cgto_xyzr[global_cgto_idx].xyz;
170-
cgto_pgto_off = dst;
171-
172-
vec3 model_xyz = vec3(world_to_model * vec4(cgto_center, 1.0));
173-
vec3 d = clamp(model_xyz, model_aabb_min, model_aabb_max) - model_xyz;
174-
float d2 = dot(d, d);
171+
// Reserve contiguous space for this CGTO's PGTOs
172+
uint dst = atomicAdd(num_pgtos, src_len);
175173

176-
// Copy PGTOs
177-
for (uint k = 0; k < src_len; ++k) {
178-
float r = pgto_radius[src_beg + k];
179-
// Cull based on radius
180-
if (d2 < r * r) {
181-
PGTO g;
182-
g.coeff = pgto_coeff[src_beg + k];
183-
g.alpha = pgto_alpha[src_beg + k];
184-
g.ijkl = pgto_ijkl[src_beg + k];
185-
pgtos_tile[cgto_pgto_off + cgto_pgto_len] = g;
186-
cgto_pgto_len++;
187-
}
174+
cgto_center = cgto_xyzr[global_cgto_idx].xyz;
175+
cgto_pgto_off = dst;
176+
177+
vec3 model_xyz = vec3(world_to_model * vec4(cgto_center, 1.0));
178+
vec3 d = clamp(model_xyz, model_aabb_min, model_aabb_max) - model_xyz;
179+
float d2 = dot(d, d);
180+
181+
// Copy PGTOs
182+
for (uint k = 0; k < src_len; ++k) {
183+
float r = pgto_radius[src_beg + k];
184+
// Cull based on radius
185+
if (d2 < r * r) {
186+
PGTO g;
187+
g.coeff = pgto_coeff[src_beg + k];
188+
g.alpha = pgto_alpha[src_beg + k];
189+
g.ijkl = pgto_ijkl[src_beg + k];
190+
pgtos_tile[cgto_pgto_off + cgto_pgto_len] = g;
191+
cgto_pgto_len++;
188192
}
189193
}
190194
}
@@ -205,16 +209,33 @@ void eval_phis(out float out_phi[TILE_SIZE], vec3 coord) {
205209
uint pgto_len = cgtos_tile[i].pgto_len;
206210

207211
vec3 center = cgtos_tile[i].coord;
208-
vec3 d = coord - center;
212+
vec3 d = coord - center;
209213
float r2 = dot(d, d);
210-
214+
215+
// Precompute powers once per CGTO
216+
float dx2 = d.x * d.x;
217+
float dy2 = d.y * d.y;
218+
float dz2 = d.z * d.z;
219+
211220
float phi = 0.0;
212221
for (uint j = pgto_off; j < pgto_off + pgto_len; ++j) {
213222
PGTO pgto = pgtos_tile[j];
214223
uvec4 ijkl = unpack_ijkl(pgto.ijkl);
215-
float fx = safe_pow(d.x, ijkl.x);
216-
float fy = safe_pow(d.y, ijkl.y);
217-
float fz = safe_pow(d.z, ijkl.z);
224+
// Use ternary for common cases (no divergence, compiles to select)
225+
float fx = (ijkl.x == 0u) ? 1.0 :
226+
(ijkl.x == 1u) ? d.x :
227+
(ijkl.x == 2u) ? dx2 :
228+
(ijkl.x == 3u) ? dx2 * d.x : dx2 * dx2;
229+
230+
float fy = (ijkl.y == 0u) ? 1.0 :
231+
(ijkl.y == 1u) ? d.y :
232+
(ijkl.y == 2u) ? dy2 :
233+
(ijkl.y == 3u) ? dy2 * d.y : dy2 * dy2;
234+
235+
float fz = (ijkl.z == 0u) ? 1.0 :
236+
(ijkl.z == 1u) ? d.z :
237+
(ijkl.z == 2u) ? dz2 :
238+
(ijkl.z == 3u) ? dz2 * d.z : dz2 * dz2;
218239
phi += pgto.coeff * fx * fy * fz * exp(-pgto.alpha * r2);
219240
}
220241
out_phi[i] = phi;
@@ -225,9 +246,10 @@ float symmetric_contraction(float phi[TILE_SIZE], float D[TILE_SIZE][TILE_SIZE])
225246
float result = 0.0;
226247
for (uint i = 0; i < TILE_SIZE; ++i) {
227248
float ai = phi[i];
228-
result += D[i][i] * ai * ai; // Diagonal
249+
result = fma(D[i][i] * ai, ai, result); // Diagonal
250+
ai = 2.0 * ai;
229251
for (uint j = i + 1; j < TILE_SIZE; ++j) {
230-
result += 2.0 * D[i][j] * ai * phi[j]; // Off-diagonal
252+
result = fma(D[i][j] * ai, phi[j], result); // Off-diagonal
231253
}
232254
}
233255
return result;
@@ -236,15 +258,15 @@ float symmetric_contraction(float phi[TILE_SIZE], float D[TILE_SIZE][TILE_SIZE])
236258
float contraction(float phi_mu[TILE_SIZE], float phi_nu[TILE_SIZE], float D[TILE_SIZE][TILE_SIZE]) {
237259
float result = 0.0;
238260
for (uint i = 0; i < TILE_SIZE; ++i) {
239-
float ai = phi_mu[i];
261+
float ai = 2.0 * phi_mu[i];
240262
for (uint j = 0; j < TILE_SIZE; ++j) {
241-
result += 2.0 * D[i][j] * ai * phi_nu[j];
263+
result = fma(D[i][j] * ai, phi_nu[j], result); // Off-diagonal
242264
}
243265
}
244266
return result;
245267
}
246268

247-
layout (local_size_x = 8, local_size_y = 8, local_size_z = 8) in;
269+
layout (local_size_x = WG_X, local_size_y = WG_Y, local_size_z = WG_Z) in;
248270
void main() {
249271
uint tid = gl_LocalInvocationIndex;
250272
if (tid == 0) {
@@ -256,7 +278,6 @@ void main() {
256278
vec3 model_aabb_max = vec3((gl_WorkGroupID.xyz + uvec3(1,1,1)) * gl_WorkGroupSize.xyz) * step.xyz;
257279
// Stage 1: Screening. Prune CGTOs to identify which are relevant for region
258280
{
259-
// Stream matches directly; avoid large per-thread stacks and subgroup prefix sums
260281
for (uint i = tid; i < D_matrix_dim; i += WG_SIZE) {
261282
vec4 cgto = cgto_xyzr[i];
262283
if (cgto.w == 0.0) continue;
@@ -291,6 +312,7 @@ void main() {
291312
screened_cgtos[i] = INVALID_CGTO_IDX;
292313
}
293314
}
315+
barrier();
294316
}
295317

296318
float phi_tile_mu[TILE_SIZE]; // evaluated φ_μ(r) in registers
@@ -330,8 +352,6 @@ void main() {
330352
rho += symmetric_contraction(phi_tile_mu, D_tile);
331353

332354
for (uint tile_j = tile_i + 1; tile_j < num_tiles; ++tile_j) {
333-
barrier();
334-
335355
// OFF DIAGONAL TILE
336356
if (tid == 0) {
337357
num_pgtos = 0;

0 commit comments

Comments
 (0)