diff --git a/keras/backend/mxnet_backend.py b/keras/backend/mxnet_backend.py index 206121bd7533..ce08f74a440f 100644 --- a/keras/backend/mxnet_backend.py +++ b/keras/backend/mxnet_backend.py @@ -2908,8 +2908,14 @@ def pool2d(x, pool_size, strides=(1, 1), if dim_ordering == 'default': dim_ordering = image_dim_ordering() x = _preprocess_convnd_input(x, dim_ordering) + + padding = (0, 0) + if border_mode == 'same': + padding, _, out_size = _preprocess_border_mode(border_mode, x.shape, pool_size, strides, (1,1)) + border_mode = 'valid' + s = mx.sym.Pooling(data=x.symbol, kernel=pool_size, pool_type=pool_mode, pooling_convention=border_mode, - stride=strides) + pad=padding, stride=strides) out = _postprocess_convnd_output(KerasSymbol(s), dim_ordering) return out @@ -2932,8 +2938,14 @@ def pool3d(x, pool_size, strides=(1, 1, 1), border_mode='valid', if dim_ordering == 'default': dim_ordering = image_dim_ordering() x = _preprocess_convnd_input(x, dim_ordering) + + padding = (0, 0, 0) + if border_mode == 'same': + padding, _, _ = _preprocess_border_mode(border_mode, x.shape, pool_size, strides, (1,1,1)) + border_mode = 'valid' + s = mx.sym.Pooling(data=x.symbol, kernel=pool_size, pool_type=pool_mode, pooling_convention=border_mode, - stride=strides) + pad=padding, stride=strides) out = _postprocess_convnd_output(KerasSymbol(s), dim_ordering) return out