A Simple Convolutional U-Net

In this tutorial, you will be guided on how to build a simple sparse convolutional neural network using fVDB. If you were using MinkowskiEngine to tackle sparse 3D data previously, we will also guide you step-by-step to help you smoothly transfer from it and enjoy speed-ups and memory-savings.

In our simplistic U-Net case, we want to build a Res-UNet with four layers, and each layer contains several blocks. First, we import basic fvdb libraries:

import fvdb
import fvdb.nn as fvnn
from fvdb.nn import VDBTensor
import torch

Here fvdb.nn is a namespace similar to torch.nn, containing a broad definition of different neural layers. VDBTensor is a very thin wrapper around a grid (with type GridBatch) and its corresponding feature (with type JaggedTensor), and internally makes sure that the two members align. It also overloads a bunch of operators such as arithmetic computations. Please refer to our API docs to learn more.

We could then build a basic block as follows:

class BasicBlock(torch.nn.Module):
    expansion = 1

    def __init__(self, in_channels: int, out_channels: int, downsample=None, bn_momentum: float = 0.1):
        super().__init__()
        self.conv1 = fvnn.SparseConv3d(in_channels, out_channels, kernel_size=3, stride=1)
        self.norm1 = fvnn.BatchNorm(out_channels, momentum=bn_momentum)
        self.conv2 = fvnn.SparseConv3d(out_channels, out_channels, kernel_size=3, stride=1)
        self.norm2 = fvnn.BatchNorm(out_channels, momentum=bn_momentum)
        self.relu = fvnn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x: VDBTensor):
        residual = x

        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.norm2(out)

        if self.downsample is not None:
          residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

This defines a similar block as MinkowskiEngine:

import MinkowskiEngine as ME


class BasicBlock(torch.nn.Module):
    expansion = 1

    def __init__(self, in_channels: int, out_channels: int, downsample=None, bn_momentum: float = 0.1):
        super().__init__()
        self.conv1 = ME.MinkowskiConvolution(
            in_channels, out_channels, kernel_size=3, stride=1, dilation=1, dimension=3)
        self.norm1 = ME.MinkowskiBatchNorm(out_channels, momentum=bn_momentum)
        self.conv2 = ME.MinkowskiConvolution(
            out_channels, out_channels, kernel_size=3, stride=1, dilation=1, dimension=3)
        self.norm2 = ME.MinkowskiBatchNorm(out_channels, momentum=bn_momentum)
        self.relu = ME.MinkowskiReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.norm2(out)

        if self.downsample is not None:
          residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

All the network layers are fully compatible with torch.nn. The only difference is that they take VDBTensor as input and return a VDBTensor. A full network definition could then be built as:

class FVDBUNetBase(torch.nn.Module):
    LAYERS = (2, 2, 2, 2, 2, 2, 2, 2)
    CHANNELS = (32, 64, 128, 256, 256, 128, 96, 96)
    INIT_DIM = 32
    OUT_TENSOR_STRIDE = 1

    def __init__(self, in_channels, out_channels, D=3):
        super().__init__()

        # Output of the first conv concated to conv6
        self.inplanes = self.INIT_DIM
        self.conv0p1s1 = fvnn.SparseConv3d(in_channels, self.inplanes, kernel_size=5, stride=1, bias=False)
        self.bn0 = fvnn.BatchNorm(self.inplanes)

        self.conv1p1s2 = fvnn.SparseConv3d(self.inplanes, self.inplanes, kernel_size=2, stride=2, bias=False)
        self.bn1 = fvnn.BatchNorm(self.inplanes)

        self.block1 = self._make_layer(BasicBlock, self.CHANNELS[0], self.LAYERS[0])

        self.conv2p2s2 = fvnn.SparseConv3d(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, bias=False)
        self.bn2 = fvnn.BatchNorm(self.inplanes)

        self.block2 = self._make_layer(BasicBlock, self.CHANNELS[1], self.LAYERS[1])

        self.conv3p4s2 = fvnn.SparseConv3d(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, bias=False)

        self.bn3 = fvnn.BatchNorm(self.inplanes)
        self.block3 = self._make_layer(BasicBlock, self.CHANNELS[2], self.LAYERS[2])

        self.conv4p8s2 = fvnn.SparseConv3d(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, bias=False)
        self.bn4 = fvnn.BatchNorm(self.inplanes)
        self.block4 = self._make_layer(BasicBlock, self.CHANNELS[3], self.LAYERS[3])

        self.convtr4p16s2 = fvnn.SparseConv3d(
            self.inplanes, self.CHANNELS[4], kernel_size=2, stride=2, transposed=True, bias=False)
        self.bntr4 = fvnn.BatchNorm(self.CHANNELS[4])

        self.inplanes = self.CHANNELS[4] + self.CHANNELS[2]
        self.block5 = self._make_layer(BasicBlock, self.CHANNELS[4], self.LAYERS[4])
        self.convtr5p8s2 = fvnn.SparseConv3d(
            self.inplanes, self.CHANNELS[5], kernel_size=2, stride=2, transposed=True, bias=False)
        self.bntr5 = fvnn.BatchNorm(self.CHANNELS[5])

        self.inplanes = self.CHANNELS[5] + self.CHANNELS[1]
        self.block6 = self._make_layer(BasicBlock, self.CHANNELS[5], self.LAYERS[5])
        self.convtr6p4s2 = fvnn.SparseConv3d(
            self.inplanes, self.CHANNELS[6], kernel_size=2, stride=2, transposed=True, bias=False)
        self.bntr6 = fvnn.BatchNorm(self.CHANNELS[6])

        self.inplanes = self.CHANNELS[6] + self.CHANNELS[0]
        self.block7 = self._make_layer(BasicBlock, self.CHANNELS[6], self.LAYERS[6])
        self.convtr7p2s2 = fvnn.SparseConv3d(
            self.inplanes, self.CHANNELS[7], kernel_size=2, stride=2, transposed=True, bias=False)
        self.bntr7 = fvnn.BatchNorm(self.CHANNELS[7])

        self.inplanes = self.CHANNELS[7] + self.INIT_DIM
        self.block8 = self._make_layer(BasicBlock, self.CHANNELS[7], self.LAYERS[7])

        self.final = fvnn.SparseConv3d(self.CHANNELS[7], out_channels, kernel_size=1)
        self.relu = fvnn.ReLU(inplace=True)

    def _make_layer(self, block, planes, blocks):
        downsample = None
        if self.inplanes != planes * block.expansion:
            downsample = torch.nn.Sequential(
                fvnn.SparseConv3d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=1
                ),
                fvnn.BatchNorm(planes * block.expansion),
            )
        layers = []
        layers.append(
            BasicBlock(
                self.inplanes, planes,
                downsample=downsample
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(BasicBlock(self.inplanes, planes))

        return torch.nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv0p1s1(x)
        out = self.bn0(out)
        out_p1 = self.relu(out)
        grid1 = out_p1.grid

        out = self.conv1p1s2(out_p1)
        out = self.bn1(out)
        out = self.relu(out)
        out_b1p2 = self.block1(out)
        grid2 = out_b1p2.grid

        out = self.conv2p2s2(out_b1p2)
        out = self.bn2(out)
        out = self.relu(out)
        out_b2p4 = self.block2(out)
        grid4 = out_b2p4.grid

        out = self.conv3p4s2(out_b2p4)
        out = self.bn3(out)
        out = self.relu(out)
        out_b3p8 = self.block3(out)
        grid8 = out_b3p8.grid

        # tensor_stride=16
        out = self.conv4p8s2(out_b3p8)
        out = self.bn4(out)
        out = self.relu(out)
        out = self.block4(out)

        # tensor_stride=8
        out = self.convtr4p16s2(out, out_grid=grid8)
        out = self.bntr4(out)
        out = self.relu(out)

        out = fvdb.jcat([out, out_b3p8], dim=1)
        out = self.block5(out)

        # tensor_stride=4
        out = self.convtr5p8s2(out, out_grid=grid4)
        out = self.bntr5(out)
        out = self.relu(out)

        out = fvdb.jcat([out, out_b2p4], dim=1)
        out = self.block6(out)

        # tensor_stride=2
        out = self.convtr6p4s2(out, out_grid=grid2)
        out = self.bntr6(out)
        out = self.relu(out)

        out = fvdb.jcat([out, out_b1p2], dim=1)
        out = self.block7(out)

        # tensor_stride=1
        out = self.convtr7p2s2(out, out_grid=grid1)
        out = self.bntr7(out)
        out = self.relu(out)

        out = fvdb.jcat([out, out_p1], dim=1)
        out = self.block8(out)

        return self.final(out)

Please note that here, when we apply transposed convolution layers, we additionally introduce the out_grid keyword arguments. This is needed to guide the output domain of the network, because for perception networks, the output grid topology should align with the input topology. Note that fVDB will NOT cache the grids to maintain maximum flexibility.

To perform inference with the network, you could simply create a VDBTensor and feed it into the model:

coords = fvdb.JaggedTensor([
    (torch.randn(10_000, 3, device='cuda')),
    (torch.randn(11_000, 3, device='cuda')),
])

grid = fvdb.gridbatch_from_points(coords)
features = grid.jagged_like(torch.randn(grid.total_voxels, 32, device='cuda'))
sinput = fvnn.VDBTensor(grid, features)

model = FVDBUNetBase(32, 1).to('cuda')
soutput = model(sinput)

The output soutput will carry gradients during training, and you could train the sparse network accordingly. Please find a fully working example at examples/perception_example.py. The same network is implemented using MinkowskiEngine for reference.