- Sparse Conv Tensor: similar to a hybrid torch.sparse_coo_tensor, with two differences: (1) SparseConvTensor has only one dense dimension, and (2) the indices of SparseConvTensor are transposed. See the torch documentation for more details.
- Sparse Convolution: equivalent to performing a dense convolution on the SparseConvTensor after converting it to dense, except that Sparse Convolution only runs the calculation on valid data.
- Submanifold Convolution (SubMConv): like Sparse Convolution, but the indices stay the same. Imagine copying the input's spatial structure to the output, iterating over those output coordinates, looking up the input coordinates through the convolution rule, and applying the convolution ONLY at those output coordinates.
- features: [N, num_channels] tensor.
- indices: [N, (batch_idx + x + y + z)] coordinate tensor with a batch axis. Note that the xyz order of the coordinates MUST match the spatial shape and conv params such as kernel size.
import spconv.pytorch as spconv
features = ...  # your features with shape [N, num_channels]
indices = ...  # your indices/coordinates with shape [N, ndim + 1]; the batch index must be in indices[:, 0]
spatial_shape = ...  # spatial shape of your sparse tensor; spatial_shape[i] is the extent of indices[:, 1 + i]
batch_size = ...  # batch size of your sparse tensor
x = spconv.SparseConvTensor(features, indices, spatial_shape, batch_size)
x_dense_NCHW = x.dense()  # convert the sparse tensor to a dense NCHW tensor
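To make the layout above concrete, here is a minimal pure-Python sketch of what converting to dense amounts to for a 2D tensor: each row of features is scattered to the position named by the matching row of indices. The `densify` helper is hypothetical and only illustrates the data layout; it is not how spconv implements `dense()`.

```python
def densify(features, indices, spatial_shape, batch_size):
    """Scatter [N, C] features into a dense [B, C, H, W] nested list (2D case)."""
    num_channels = len(features[0])
    height, width = spatial_shape
    # Start from an all-zero dense tensor in channels-first layout.
    dense = [[[[0.0] * width for _ in range(height)]
              for _ in range(num_channels)]
             for _ in range(batch_size)]
    # Row n of `indices` is (batch_idx, i, j); row n of `features` holds its channels.
    for feat, (b, i, j) in zip(features, indices):
        for c, value in enumerate(feat):
            dense[b][c][i][j] = value
    return dense


features = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]  # N=3, num_channels=2
indices = [[0, 0, 1], [0, 2, 0], [1, 1, 1]]          # batch index first
dense = densify(features, indices, spatial_shape=(3, 2), batch_size=2)
```

Positions that no row of indices refers to stay zero, which is why only the N valid rows need to be stored in the sparse form.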
import spconv.pytorch as spconv
from torch import nn

class ExampleNet(nn.Module):
    def __init__(self, shape):
        super().__init__()
        self.net = spconv.SparseSequential(
            spconv.SparseConv3d(32, 64, 3),  # just like nn.Conv3d, but groups are not supported
            nn.BatchNorm1d(64),  # non-spatial layers can be used directly in SparseSequential
            nn.ReLU(),
            spconv.SubMConv3d(64, 64, 3, indice_key="subm0"),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            # submanifold convolutions with the same indice_key share their indices,
            # which saves index generation time
            spconv.SubMConv3d(64, 64, 3, indice_key="subm0"),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            spconv.SparseConvTranspose3d(64, 64, 3, 2),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            spconv.ToDense(),  # convert the spconv tensor to a dense tensor in NCHW format
            nn.Conv3d(64, 64, 3),
            nn.BatchNorm3d(64),  # the tensor is dense 5D here, so BatchNorm3d instead of BatchNorm1d
            nn.ReLU(),
        )
        self.shape = shape

    def forward(self, features, coors, batch_size):
        coors = coors.int()  # unlike torch, this library only accepts int coordinates
        x = spconv.SparseConvTensor(features, coors, self.shape, batch_size)
        return self.net(x)  # .dense()
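The network above mixes SparseConv3d and SubMConv3d. As a rough illustration of how the two differ in which output coordinates become active, the following pure-Python sketch tracks coordinates only (no features or weights). It assumes stride 1 and no padding for the regular sparse convolution; the helper names are hypothetical and not part of spconv.

```python
from itertools import product


def sparse_conv_output_sites(active, kernel_size, spatial_shape):
    """Active outputs of a regular sparse conv (stride 1, no padding):
    every output position whose kernel window covers at least one active input."""
    ndim = len(spatial_shape)
    out_shape = [s - kernel_size + 1 for s in spatial_shape]
    out = set()
    for coord in active:
        # Output o covers inputs o .. o + kernel_size - 1, so o = coord - offset.
        for offset in product(range(kernel_size), repeat=ndim):
            o = tuple(c - k for c, k in zip(coord, offset))
            if all(0 <= v < n for v, n in zip(o, out_shape)):
                out.add(o)
    return out


def subm_conv_output_sites(active):
    """Submanifold conv: the output keeps exactly the input's active sites."""
    return set(active)


active = {(2, 2)}  # a single active pixel in a 5x5 grid
regular = sparse_conv_output_sites(active, kernel_size=3, spatial_shape=(5, 5))
subm = subm_conv_output_sites(active)
```

In this sketch the regular convolution turns one active site into nine, while the submanifold variant keeps one. Stacking regular sparse convolutions therefore tends to reduce sparsity layer by layer, and submanifold layers that see the same coordinates can reuse one set of indices.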
Inverse sparse convolution is the "inverse" of a sparse convolution: the output of the inverse convolution contains the same indices as the input of the corresponding sparse convolution.
WARNING: SparseInverseConv isn't equivalent to SparseConvTranspose. SparseConvTranspose is equivalent to ConvTranspose in PyTorch, but SparseInverseConv isn't.
Inverse convolution is usually used in semantic segmentation.
class ExampleNet(nn.Module):
    def __init__(self, shape):
        super().__init__()
        self.net = spconv.SparseSequential(
            spconv.SparseConv3d(32, 64, 3, 2, indice_key="cp0"),
            # the kernel size must be provided so the weight can be created
            spconv.SparseInverseConv3d(64, 32, 3, indice_key="cp0"),
        )
        self.shape = shape

    def forward(self, features, coors, batch_size):
        coors = coors.int()
        x = spconv.SparseConvTensor(features, coors, self.shape, batch_size)
        return self.net(x)
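The index behaviour described above can be sketched in one dimension with plain Python sets. This is an illustration under simplifying assumptions (1D, no padding, coordinates only), and every helper name, including the `cache` dictionary that stands in for the shared indice_key, is hypothetical rather than spconv's implementation.

```python
cache = {}  # stands in for the indices stored under an indice_key


def conv_sites(active, kernel_size, stride, length, key):
    """Active outputs of a strided 1D sparse conv; remembers its input sites."""
    cache[key] = set(active)
    out_len = (length - kernel_size) // stride + 1
    out = set()
    for i in active:
        for k in range(kernel_size):
            # Output o covers inputs o*stride .. o*stride + kernel_size - 1.
            o, rem = divmod(i - k, stride)
            if rem == 0 and 0 <= o < out_len:
                out.add(o)
    return out


def inverse_conv_sites(key):
    """Inverse conv: the output sites are the stored input sites of the forward conv."""
    return set(cache[key])


def transposed_conv_sites(active, kernel_size, stride):
    """Transposed conv: each active site spreads over a full kernel footprint."""
    return {o * stride + k for o in active for k in range(kernel_size)}


down = conv_sites({1, 6}, kernel_size=3, stride=2, length=9, key="cp0")
restored = inverse_conv_sites("cp0")
spread = transposed_conv_sites(down, kernel_size=3, stride=2)
```

Here `restored` equals the original `{1, 6}`, while `spread` contains eight sites. That difference is why the inverse form suits segmentation, where per-voxel outputs have to line up with the input voxels.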
See example/mnist_sparse. torch.cuda.amp is supported.
- Convert point cloud to voxel
The voxel generator in spconv generates indices in ZYX order, while its parameters are given in XYZ order.
The generated indices don't include a batch axis; you need to add it yourself.
See examples/voxel_gen.py for examples.
import numpy as np
import torch
from spconv.pytorch.utils import PointToVoxel, gather_features_by_pc_voxel_id

# this generator produces ZYX indices.
gen = PointToVoxel(
    vsize_xyz=[0.1, 0.1, 0.1],
    coors_range_xyz=[-80, -80, -2, 80, 80, 6],
    num_point_features=3,
    max_num_voxels=5000,
    max_num_points_per_voxel=5)
pc = np.random.uniform(-10, 10, size=[1000, 3])
pc_th = torch.from_numpy(pc)
voxels, coords, num_points_per_voxel = gen(pc_th, empty_mean=True)
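The two conventions mentioned above (XYZ parameters but ZYX indices, and the missing batch axis) can be sketched in pure Python. The floor-based index computation is an assumption about how a typical voxelizer bins points, and both helper names are hypothetical; neither is spconv's own code.

```python
import math


def point_to_zyx_index(point_xyz, vsize_xyz, coors_range_xyz):
    """Return the (z, y, x) voxel index of one point, or None if out of range."""
    mins, maxs = coors_range_xyz[:3], coors_range_xyz[3:]
    if not all(lo <= p < hi for p, lo, hi in zip(point_xyz, mins, maxs)):
        return None
    # Bin in XYZ (the parameter order), then reverse to ZYX (the index order).
    idx_xyz = [math.floor((p - lo) / v)
               for p, lo, v in zip(point_xyz, mins, vsize_xyz)]
    return tuple(reversed(idx_xyz))


def add_batch_axis(coords_per_sample):
    """Prepend the sample number to every coordinate and concatenate all samples."""
    batched = []
    for batch_idx, coords in enumerate(coords_per_sample):
        batched.extend([batch_idx, *c] for c in coords)
    return batched


zyx = point_to_zyx_index([0.05, -79.95, 5.95],
                         vsize_xyz=[0.1, 0.1, 0.1],
                         coors_range_xyz=[-80, -80, -2, 80, 80, 6])
batched = add_batch_axis([[zyx], [(1, 2, 3), (4, 5, 6)]])
```

After this step each row has the form [batch_idx, z, y, x], so the spatial shape passed to SparseConvTensor has to be given in ZYX order as well.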
If you want a label for every point of your point cloud, use another function to get pc_voxel_id, then gather features from the semantic segmentation result:
voxels, coords, num_points_per_voxel, pc_voxel_id = gen.generate_voxel_with_id(pc_th, empty_mean=True)
seg_features = YourSegNet(...)
# if voxel id is invalid (point out of range, or no space left in a voxel)
# features will be zero.
point_features = gather_features_by_pc_voxel_id(seg_features, pc_voxel_id)
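The gather step can be pictured as a per-point lookup into the per-voxel results. The sketch below is a pure-Python stand-in, not the library function, and it assumes that an invalid voxel id is encoded as -1; that encoding is an assumption made for illustration.

```python
def gather_by_voxel_id(voxel_features, pc_voxel_id, invalid_id=-1):
    """For each point, copy the features of its voxel; invalid ids give zeros."""
    num_channels = len(voxel_features[0])
    return [
        list(voxel_features[v]) if v != invalid_id else [0.0] * num_channels
        for v in pc_voxel_id
    ]


voxel_features = [[0.1, 0.9], [0.8, 0.2]]  # one row per voxel
pc_voxel_id = [1, -1, 0, 1]                # one entry per point
point_features = gather_by_voxel_id(voxel_features, pc_voxel_id)
```

Points that share a voxel receive identical features, and points with no voxel receive a zero row, matching the behaviour described in the comments above.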