Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Batched GEMV Speed #113

Open
wants to merge 20 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Make clra_gemv reduce_impl column-major
Signed-off-by: Shaun Ren <shaun.ren@linux.com>
shaunren committed Jul 11, 2016
commit 3e97b6a88787791dc331f6cdc3d395d73b09d4dd
139 changes: 64 additions & 75 deletions nengo_ocl/clra_gemv.py
Original file line number Diff line number Diff line change
@@ -450,13 +450,13 @@ def reduce_impl(p, items,
raise NotImplementedError()
if p.cl_gamma is not None:
raise NotImplementedError()
if not all(s == 1 for s in p.A.stride1s):
raise NotImplementedError()
if not all(s == 1 for s in p.A.stride0s):
raise NotImplementedError('A must be in column major')

assert p.float_alpha is not None
assert p.float_gamma is not None

cl_gstructure, textconf = p.cl_geometry_and_textconf(items)
cl_gstructure, textconf = p.cl_geometry_and_textconf(items, stride=1)
max_n_dots = max([len(p.geometry[ii].dots) for ii in items])
max_reduce_len = max(max([gg.a_shape1
for gg in p.geometry[ii].dots])
@@ -466,19 +466,18 @@ def reduce_impl(p, items,
# segment means the piece of Y written by a work-group
# group_size is the number of values that we're reducing over

if len(items) < 4:
if group_size is None:
group_size = 32 # XXX
if segment_size is None:
segment_size = min(max_y_len, 2) # XXX
else:
if group_size is None:
group_size = 32 # XXX
if segment_size is None:
segment_size = min(max_y_len, 4) # XXX
# TODO autotune
if group_size is None:
group_size = 4

if segment_size is None:
if len(items) < 8:
segment_size = min(max_y_len, 8) # XXX
else:
segment_size = min(max_y_len, 32) # XXX
g_segments = int(np.ceil(float(max_y_len) / segment_size))
gsize = (group_size, g_segments * segment_size, len(items))
lsize = (group_size, segment_size, 1)
gsize = (g_segments * segment_size, group_size, len(items))
lsize = (segment_size, group_size, 1)

max_reduce_iters = int(np.ceil(float(max_reduce_len) / group_size))
textconf.update({
@@ -489,8 +488,8 @@ def reduce_impl(p, items,
'group_size': group_size,
'local_count': group_size * segment_size,
'max_reduce_len': max_reduce_len,
'N_cutoff': max_reduce_iters * group_size,
'max_n_dots': max_n_dots,
'log2_group_size': int(np.ceil(np.log2(group_size))),
})
if 0:
for k, v in textconf.items():
@@ -509,19 +508,18 @@ def reduce_impl(p, items,
const __global ${Y_in.cl_buf.ctype} *Y_in_data,
__global ${Y.cl_buf.ctype} *Y_data)
{
const int i = get_local_id(0);
const int j = get_local_id(1);
const int m = get_global_id(0);

__local int lstructure[${n_structure_vars}];
% if segment_size > 1:
// we'll cache X in shared memory so we load it only once
// for the whole segment
__local ${X.cl_buf.ctype} lX[${group_size}];
% endif

//Scratch space for the dot products
__local ${Y.cl_buf.ctype}
partialDotProduct[${segment_size}][${group_size}];
sums[${group_size}][${segment_size}];
__local ${Y.cl_buf.ctype}
y_sum_pre[${segment_size}];
const int local_idx = get_local_id(0)
+ get_local_id(1) * get_local_size(0);
const int local_idx = i + j * get_local_size(0);

// load structure
% if local_count < n_structure_vars:
@@ -532,7 +530,7 @@ def reduce_impl(p, items,
lstructure[ii] = gstructure[
get_global_id(2) * ${structure_vars_stride} + ii];
}
% else :
% else:
if (local_idx < ${n_structure_vars})
{
lstructure[local_idx] = gstructure[
@@ -541,25 +539,25 @@ def reduce_impl(p, items,
% endif
barrier(CLK_LOCAL_MEM_FENCE);

if ((get_local_id(0) == 0) && (get_global_id(1) < ${y_len}))
if (j == 0 && m < ${y_len})
{
% if float_beta is not None and float_beta != 0 :
y_sum_pre[get_local_id(1)] = ${float_beta}
* Y_in_data[${y_in_starts} + get_global_id(1)];
y_sum_pre[i] = ${float_beta}
* Y_in_data[${y_in_starts} + m];
% elif cl_beta is not None:
y_sum_pre[get_local_id(1)] = betas[${bb}]
* Y_in_data[${y_in_starts} + get_global_id(1)];
y_sum_pre[i] = betas[${bb}]
* Y_in_data[${y_in_starts} + m];
% else :
y_sum_pre[get_local_id(1)] = 0;
y_sum_pre[i] = 0;
% endif

% if float_gamma is not None and float_gamma != 0:
y_sum_pre[get_local_id(1)] += ${float_gamma};
y_sum_pre[i] += ${float_gamma};
% endif
// printf("betaY + gamma=%f\\n", y_sum_pre[get_local_id(1)]);
// printf("betaY + gamma=%f\\n", y_sum_pre[i]);
}

partialDotProduct[get_local_id(1)][get_local_id(0)] = 0;
sums[j][i] = 0;
% if max_n_dots > 1:
for (int ii = 0;
ii < ${n_dot_products};
@@ -570,57 +568,48 @@ def reduce_impl(p, items,
% endif


for (int nn = get_local_id(0);
nn < ${N_cutoff};
nn += get_local_size(0))
{
// segment_size = ${segment_size}
% if (segment_size == 1):
if ((nn < ${N_i}) && (get_global_id(1) < ${y_len}))
{
partialDotProduct[get_local_id(1)][get_local_id(0)] +=
A_data[${a_starts} + get_global_id(1) * ${a_s0} + nn]
* X_data[${x_starts} + nn];
}
% else:
barrier(CLK_LOCAL_MEM_FENCE);
if ((get_local_id(1) == 0) && (nn < ${N_i}))
{
lX[get_local_id(0)] = X_data[${x_starts} + nn];
}
barrier(CLK_LOCAL_MEM_FENCE);
if ((nn < ${N_i}) && (get_global_id(1) < ${y_len}))

__global ${A.cl_buf.ctype}* a = A_data + ${a_starts} + m;

% if segment_size >= 4:
// we'll cache X in shared memory so we load it only once
// for the whole segment
__local ${X.cl_buf.ctype} lX[${max_reduce_len}];
for (int k = local_idx; k < ${N_i}; k += ${local_count})
lX[k] = X_data[${x_starts} + k];
barrier(CLK_LOCAL_MEM_FENCE);
% endif

if (m < ${y_len}) {
for (int k = get_global_id(1); k < ${N_i}; k += get_global_size(1))
{
partialDotProduct[get_local_id(1)][get_local_id(0)] +=
A_data[${a_starts} + get_global_id(1) * ${a_s0} + nn]
* lX[get_local_id(0)];
// segment_size = ${segment_size}
% if segment_size < 4:
sums[j][i] += a[${a_s1} * k] * X_data[${x_starts} + k];
% else:
sums[j][i] += a[${a_s1} * k] * lX[k];
% endif
}
% endif
}

% if (max_n_dots > 1):
}
% endif

// -- Parallel reduction long work-group dimension 0
for (uint stride = 1;
stride < get_local_size(0);
stride *= 2)
{
// -- Parallel reduction
% for ks in range(log2_group_size - 1, -1, -1):
barrier(CLK_LOCAL_MEM_FENCE);

uint index = 2 * stride * get_local_id(0);
if (index + stride < get_local_size(0))
{
partialDotProduct[get_local_id(1)][index] +=
partialDotProduct[get_local_id(1)][index + stride];
if (j < ${2**ks}) {
sums[j][i] += sums[j + ${2**ks}][i];
% if ks == 0:
if (m < ${y_len}) {
Y_data[${y_offset} + m] = y_sum_pre[i]
+ ${float_alpha} * sums[0][i];
}
% endif
}
}
// barrier(CLK_LOCAL_MEM_FENCE);
if ((get_local_id(0) == 0) && (get_global_id(1) < ${y_len})) {
Y_data[${y_offset} + get_global_id(1)] = y_sum_pre[get_local_id(1)]
+ ${float_alpha} * partialDotProduct[get_local_id(1)][0];
}
% endfor

}
"""