Skip to content

Commit

Permalink
Merge branch 'v13' into bp-8062-v13-cepstrum
Browse files Browse the repository at this point in the history
  • Loading branch information
emcastillo authored Jan 24, 2024
2 parents 4edd71e + e83d9d1 commit 2100b47
Show file tree
Hide file tree
Showing 28 changed files with 676 additions and 245 deletions.
11 changes: 4 additions & 7 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,11 @@ on:
workflow_dispatch:
inputs:
branch:
description: "Source: cupy-release-tools"
type: choice
options:
- main
- v12
description: "Source Branch (e.g., v13)"
default: "main"
release:
description: "Release to Publish (draft/tag, e.g., v13.0.0a1)"
default: "v13.0.0rc9"
description: "Release to Publish (draft/tag, e.g., v14.0.0a1)"
default: "v14.0.0rc9"

jobs:
precheck:
Expand Down
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ Binary packages are also available for Linux and Windows on [Conda-Forge](https:
| --------------------- | --------------------------- | ------------------------------------------------------------- |
| CUDA | x86_64 / aarch64 / ppc64le | `conda install -c conda-forge cupy` |

If you need to use a particular CUDA version (say 11.8), you can do `conda install -c conda-forge cupy cuda-version=11.8`.
If you need a slim installation (without also getting CUDA dependencies installed), you can do `conda install -c conda-forge cupy-core`.

If you need to use a particular CUDA version (say 12.0), you can use the `cuda-version` metapackage to select the version, e.g. `conda install -c conda-forge cupy cuda-version=12.0`.

> [!NOTE]\
> If you encounter any problem with CuPy installed from `conda-forge`, please feel free to report to [cupy-feedstock](https://github.com/conda-forge/cupy-feedstock/issues), and we will help investigate if it is just a packaging issue in `conda-forge`'s recipe or a real issue in CuPy.
Expand Down
2 changes: 1 addition & 1 deletion cupy/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '13.0.0rc1'
__version__ = '13.0.0'
3 changes: 2 additions & 1 deletion cupy/random/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def __getattr__(key):
# users use the class for type annotation.
if key == 'Generator':
# Lazy import libraries depending on cuRAND
from cupy.random._generator_api import Generator
import cupy.random._generator_api
Generator = cupy.random._generator_api.Generator
_cupy.random.Generator = Generator
return Generator
raise AttributeError(f"module '{__name__}' has no attribute '{key}'")
Expand Down
2 changes: 1 addition & 1 deletion cupyx/distributed/_nccl_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def _get_stream(self, stream):
def _get_op(self, op, dtype):
if op not in _nccl_ops:
raise RuntimeError(f'Unknown op {op} for NCCL')
if dtype in 'FD' and op != nccl.NCCL_SUM:
if dtype in 'FD' and op != 'sum':
raise ValueError(
'Only nccl.SUM is supported for complex arrays')
return _nccl_ops[op]
Expand Down
3 changes: 3 additions & 0 deletions cupyx/signal/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from cupyx.signal._acoustics import complex_cepstrum, real_cepstrum # NOQA
from cupyx.signal._acoustics import inverse_complex_cepstrum # NOQA
from cupyx.signal._acoustics import minimum_phase # NOQA
from cupyx.signal._convolution import convolve1d2o # NOQA
from cupyx.signal._convolution import convolve1d3o # NOQA
from cupyx.signal._radartools import pulse_compression, pulse_doppler, cfar_alpha # NOQA
from cupyx.signal._filtering import firfilter, firfilter2, firfilter_zi # NOQA
from cupyx.signal._filtering import freq_shift # NOQA
1 change: 1 addition & 0 deletions cupyx/signal/_convolution/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from cupyx.signal._convolution._convolve import convolve1d2o # NOQA
from cupyx.signal._convolution._convolve import convolve1d3o # NOQA
39 changes: 0 additions & 39 deletions cupyx/signal/_convolution/_convolution_utils.py

This file was deleted.

193 changes: 116 additions & 77 deletions cupyx/signal/_convolution/_convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,97 +23,136 @@
DEALINGS IN THE SOFTWARE.
"""

import cupy as cp
from cupy._core._scalar import get_typename
from cupyx.signal._convolution import _convolution_utils


CONVOLVE1D3O_KERNEL = """
#include <cupy/complex.cuh>
///////////////////////////////////////////////////////////////////////////////
// CONVOLVE 1D3O //
///////////////////////////////////////////////////////////////////////////////
template<typename T>
__global__ void _cupy_convolve1D3O( const T *__restrict__ inp,
const int inpW,
const T *__restrict__ kernel,
const int kerW,
const int kerH,
const int kerD,
const int mode,
T *__restrict__ out,
const int outW ) {
const int tx { static_cast<int>( blockIdx.x * blockDim.x + threadIdx.x ) };
const int stride { static_cast<int>( blockDim.x * gridDim.x ) };
for ( int tid = tx; tid < outW; tid += stride ) {
T temp {};
if ( mode == 0 ) { // Valid
if ( tid >= 0 && tid < inpW ) {
for ( int i = 0; i < kerW; i++ ) {
for ( int j = 0; j < kerH; j++ ) {
for ( int k = 0; k < kerD; k++ ) {
temp += inp[tid + kerW - i - 1] * inp[tid + kerH - j - 1] * inp[tid + kerD - k - 1] * kernel[ (kerH * i + j) * kerD + k ];
}
}
}
}
}
out[tid] = temp;
import cupy


# Elementwise kernel implementing the second-order 1-D convolution used by
# `_convolve1d2o`.
#
# `in1` (signal) and `in2` (filter) are declared `raw`, so the kernel body
# indexes them explicitly instead of receiving one element per thread; `in2`
# is addressed as a flattened W x H array (`H * x + y`).  For each output
# element (loop index `i` supplied by ElementwiseKernel) it accumulates
#
#   out[i] = sum_{x < W, y < H} in1[i + W - x - 1] * in1[i + H - y - 1]
#                               * in2[H * x + y]
#
# NOTE(review): the body performs no bounds checking; it relies on the caller
# sizing `out` so that every `in1` index stays in range -- confirm for
# non-square filters (W != H).
_convolve1d2o_kernel = cupy.ElementwiseKernel(
'raw T in1, raw T in2, int32 W, int32 H', 'T out',
"""
T temp {};
for (int x = 0; x < W; x++) {
for (int y = 0; y < H; y++) {
temp += in1[i + W - x - 1] * in1[i + H - y - 1] * in2[H * x + y];
}
}
out = temp;
""",
"cupy_convolved2o",
)

}
""" # NOQA

CONVOLVE1D3O_MODULE = cp.RawModule(
code=CONVOLVE1D3O_KERNEL, options=('-std=c++11',),
name_expressions=[
'_cupy_convolve1D3O<float>',
'_cupy_convolve1D3O<double>',
'_cupy_convolve1D3O<complex<float>>',
'_cupy_convolve1D3O<complex<double>>',
])
def _convolve1d2o(in1, in2, mode):
    """Compute the 'valid' second-order convolution of ``in1`` with ``in2``.

    Parameters
    ----------
    in1 : cupy.ndarray
        Input signal; only its first axis length is used to size the output.
    in2 : cupy.ndarray
        Second-order filter. Its shape is unpacked into the kernel's
        ``(W, H)`` arguments, so it must have exactly two dimensions.
    mode : str
        Only ``"valid"`` is supported; any other value trips the assertion.

    Returns
    -------
    out : cupy.ndarray
        1-dimensional array of length ``in1.shape[0] - max(in2.shape) + 1``
        whose dtype is the common result type of ``in1`` and ``in2``.
    """
    assert mode == "valid"
    out_dim = in1.shape[0] - max(in2.shape) + 1
    dtype = cupy.result_type(in1, in2)
    # The kernel binds a single type ``T`` to both inputs and to the output,
    # so inputs of different dtypes must be promoted to the common dtype
    # first.  ``copy=False`` avoids a copy when the dtype already matches.
    in1 = in1.astype(dtype, copy=False)
    in2 = in2.astype(dtype, copy=False)
    out = cupy.empty(out_dim, dtype=dtype)
    _convolve1d2o_kernel(in1, in2, *in2.shape, out)
    return out


def _convolve1d3o_gpu(inp, out, ker, mode):
def convolve1d2o(in1, in2, mode='valid', method='direct'):
"""
Convolve a 1-dimensional array with a 2nd order filter.
This results in a second order convolution.
kernel = CONVOLVE1D3O_MODULE.get_function(
f'_cupy_convolve1D3O<{get_typename(out.dtype)}>')
Convolve `in1` and `in2`, with the output size determined by the
`mode` argument.
threadsperblock = (out.shape[0] + 128 - 1) // 128,
blockspergrid = 128,
kernel_args = (
inp,
inp.shape[0],
ker,
*ker.shape,
mode,
out,
out.shape[0],
)
kernel(threadsperblock, blockspergrid, kernel_args)
Parameters
----------
in1 : array_like
First input.
in2 : array_like
Second input (the 2nd order filter). Must be 2-dimensional.
mode : str {'full', 'valid', 'same'}, optional
A string indicating the size of the output:
``full``
The output is the full discrete linear convolution
of the inputs. (Not implemented)
``valid``
The output consists only of those elements that do not
rely on the zero-padding. In 'valid' mode, `in1` must be at least
as long as the largest dimension of `in2`. (Default)
``same``
The output is the same size as `in1`, centered
with respect to the 'full' output. (Not implemented)
method : str {'auto', 'direct', 'fft'}, optional
A string indicating which method to use to calculate the convolution.
def _convolve1d3o(in1, in2, mode):
``direct``
The convolution is determined directly from sums, the definition of
convolution.
``fft``
The Fourier Transform is used to perform the convolution by calling
`fftconvolve`.
``auto``
Automatically chooses direct or Fourier method based on an estimate
of which is faster. (Not implemented; the default is ``direct``.)
val = _convolution_utils._valfrommode(mode)
assert val == _convolution_utils.VALID
Returns
-------
out : ndarray
A 1-dimensional array containing a subset of the discrete linear
convolution of `in1` with `in2`.
See Also
--------
convolve
convolve1d2o
convolve1d3o
# Promote inputs
promType = cp.promote_types(in1.dtype, in2.dtype)
in1 = in1.astype(promType)
in2 = in2.astype(promType)
Examples
--------
Convolution of a 2nd order filter on a 1d signal
out_dim = in1.shape[0] - max(in2.shape) + 1
out = cp.empty(out_dim, dtype=in1.dtype)
>>> import cupy as cp
>>> from cupyx.signal import convolve1d2o
>>> d = 50
>>> a = cp.random.uniform(-1, 1, (200))
>>> b = cp.random.uniform(-1, 1, (d, d))
>>> c = convolve1d2o(a, b)
_convolve1d3o_gpu(in1, out, in2, val)
"""

if in1.ndim != 1:
raise ValueError('in1 should have one dimension')
if in2.ndim != 2:
raise ValueError('in2 should have two dimensions')

if mode in ["same", "full"]:
raise NotImplementedError("Mode == {} not implemented".format(mode))

if method == "direct":
return _convolve1d2o(in1, in2, mode)
else:
raise NotImplementedError("Only Direct method implemented")


# Elementwise kernel implementing the third-order 1-D convolution used by
# `_convolve1d3o`.
#
# `in1` (signal) and `in2` (filter) are declared `raw`, so the kernel body
# indexes them explicitly; `in2` is addressed as a flattened W x H x D array
# (`(H * x + y) * D + z`).  For each output element (loop index `i` supplied
# by ElementwiseKernel) it accumulates
#
#   out[i] = sum_{x < W, y < H, z < D} in1[i + W - x - 1] * in1[i + H - y - 1]
#                                      * in1[i + D - z - 1]
#                                      * in2[(H * x + y) * D + z]
#
# NOTE(review): the body performs no bounds checking; it relies on the caller
# sizing `out` so that every `in1` index stays in range -- confirm for
# non-cubic filters.
_convolve1d3o_kernel = cupy.ElementwiseKernel(
'raw T in1, raw T in2, int32 W, int32 H, int32 D', 'T out',
"""
T temp {};
for (int x = 0; x < W; x++) {
for (int y = 0; y < H; y++) {
for (int z = 0; z < D; z++) {
temp += in1[i + W - x - 1] * in1[i + H - y - 1] *
in1[i + D - z - 1] * in2[(H * x + y) * D + z];
}
}
}
out = temp;
""",
"cupy_convolved3o",
)


def _convolve1d3o(in1, in2, mode):
    """Compute the 'valid' third-order convolution of ``in1`` with ``in2``.

    Parameters
    ----------
    in1 : cupy.ndarray
        Input signal; only its first axis length is used to size the output.
    in2 : cupy.ndarray
        Third-order filter. Its shape is unpacked into the kernel's
        ``(W, H, D)`` arguments, so it must have exactly three dimensions.
    mode : str
        Only ``"valid"`` is supported; any other value trips the assertion.

    Returns
    -------
    out : cupy.ndarray
        1-dimensional array of length ``in1.shape[0] - max(in2.shape) + 1``
        whose dtype is the common result type of ``in1`` and ``in2``.
    """
    assert mode == "valid"
    out_dim = in1.shape[0] - max(in2.shape) + 1
    dtype = cupy.result_type(in1, in2)
    # The kernel binds a single type ``T`` to both inputs and to the output,
    # so inputs of different dtypes must be promoted to the common dtype
    # first.  ``copy=False`` avoids a copy when the dtype already matches.
    in1 = in1.astype(dtype, copy=False)
    in2 = in2.astype(dtype, copy=False)
    out = cupy.empty(out_dim, dtype=dtype)
    _convolve1d3o_kernel(in1, in2, *in2.shape, out)
    return out


Expand Down
2 changes: 2 additions & 0 deletions cupyx/signal/_filtering/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from cupyx.signal._filtering._filtering import firfilter, firfilter2, firfilter_zi # NOQA
from cupyx.signal._filtering._filtering import freq_shift # NOQA
Loading

0 comments on commit 2100b47

Please sign in to comment.