Skip to content

Commit 6a425b5

Browse files
authored
fix WinogradConvolution3 allocation size (LeelaChessZero#864)
1 parent 6028c05 commit 6a425b5

File tree

3 files changed

+5
-9
lines changed

3 files changed

+5
-9
lines changed

src/neural/blas/network_blas.cc

+1 -2
Original file line number | Diff line number | Diff line change
@@ -170,8 +170,7 @@ void BlasComputation::ComputeBlocking() {
170170
std::vector<float> res_buffer3(largest_batch_size * output_channels *
171171
kSquares);
172172

173-
WinogradConvolution3 convolve3(largest_batch_size, max_channels,
174-
output_channels);
173+
WinogradConvolution3 convolve3(largest_batch_size, max_channels);
175174

176175
std::vector<float> policy_buffer(largest_batch_size *
177176
num_policy_input_planes * kSquares);

src/neural/blas/winograd_convolution3.cc

+3 -4
Original file line number | Diff line number | Diff line change
@@ -44,10 +44,9 @@ using ConstEigenMatrixMap =
4444
#endif
4545

4646
WinogradConvolution3::WinogradConvolution3(const size_t max_batch_size,
47-
const size_t max_input_layers,
48-
const size_t max_output_layers)
49-
: V_(max_batch_size * kWinogradTile * max_input_layers * kTiles),
50-
M_(max_batch_size * kWinogradTile * max_output_layers * kTiles) {}
47+
const size_t max_channels)
48+
: V_(max_batch_size * kWinogradTile * max_channels * kTiles),
49+
M_(max_batch_size * kWinogradTile * max_channels * kTiles) {}
5150

5251
void WinogradConvolution3::Forward(const size_t batch_size,
5352
const size_t input_channels,

src/neural/blas/winograd_convolution3.h

+1 -3
Original file line number | Diff line number | Diff line change
@@ -39,9 +39,7 @@ class WinogradConvolution3 {
3939
// The instance will allocate memory resources for the
4040
// largest batch size, and the largest input and output
4141
// layers.
42-
WinogradConvolution3(const size_t max_batch_size,
43-
const size_t max_input_layers,
44-
const size_t max_output_layers);
42+
WinogradConvolution3(const size_t max_batch_size, const size_t max_channels);
4543

4644
// Forward inference, batched.
4745
void Forward(const size_t batch_size, const size_t input_channels,

0 commit comments

Comments
 (0)