Pairwise Image Registration

This tutorial is an introduction to using the deepali library for the spatial alignment of two images.

Let’s begin with the common imports used throughout this tutorial.

import os
from typing import Any, Callable, Optional, Tuple, Type, cast

from IPython.utils import io
from matplotlib.figure import Figure
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import torch
from torch import Tensor, optim
try:
    from deepali.core.environ import cuda_visible_devices
except ImportError:
    if not os.getenv("COLAB_RELEASE_TAG"):
        raise
    !git clone https://github.com/BioMedIA/deepali.git && pip install ./deepali
    from deepali.core.environ import cuda_visible_devices

from deepali import spatial
from deepali.losses import functional as L

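Choose the CUDA device to use, if available. By default, this tutorial runs on the CPU.

# An empty value hides all CUDA devices so that the CPU is used; set to "0" to use the first GPU
os.environ["CUDA_VISIBLE_DEVICES"] = ""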

# Use first device specified in CUDA_VISIBLE_DEVICES if CUDA is available
device = torch.device("cuda:0" if torch.cuda.is_available() and cuda_visible_devices() else "cpu")

Dataset

The images used in this tutorial are from the public MNIST dataset available through torchvision.

from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

with io.capture_output() as captured:  # type: ignore[reportPrivateImportUsage]
    mnist = MNIST(root="data", download=True, transform=ToTensor())

Next, we define a utility function named imshow for displaying the example images used in this tutorial.

def imshow(
    image: Tensor,
    label: Optional[str] = None,
    ax: Optional[plt.Axes] = None,
    **kwargs,
) -> None:
    r"""Render image data in last two tensor dimensions using matplotlib.pyplot.imshow().

    Args:
        image: Image tensor of shape ``(..., H, W)``.
        ax: Figure axes to render the image in. If ``None``, a new figure is created.
        label: Image label to display in the axes title.
        kwargs: Keyword arguments to pass on to ``matplotlib.pyplot.imshow()``.
            When ``ax`` is ``None``, can contain ``figsize`` to specify the size of
            the figure created for displaying the image.

    """
    if ax is None:
        figsize = kwargs.pop("figsize", (4, 4))
        _, ax = plt.subplots(figsize=figsize)
    kwargs["cmap"] = kwargs.get("cmap", "gray")
    ax.imshow(image.reshape((-1,) + image.shape[-2:])[0].cpu().numpy(), **kwargs)
    if label:
        ax.set_title(label, fontsize=16, y=1.04)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

We can now display a random selection of MNIST images, ten per digit, from which we select an example pair of images to register.

digits = tuple(range(10))

n_rows = len(digits)
n_cols = 10

rng = torch.Generator().manual_seed(0)
perm = torch.randperm(len(mnist), generator=rng)

samples: dict[int, list[Tensor]] = {digit: [] for digit in digits}
for i in perm:
    image, digit = mnist[i]
    samples[digit].append(image)
    if all(len(samples[digit]) >= n_cols for digit in digits):
        break
samples = {digit: images[:n_cols] for digit, images in samples.items()}

_, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols, n_rows), tight_layout=True)
for row, digit in zip(axes, digits):
    for ax, image in zip(row, samples[digit]):
        imshow(image, ax=ax)
target = samples[9][0].float()
source = samples[9][2].float()

_, axes = plt.subplots(1, 2, figsize=(8, 4), tight_layout=True)

imshow(target, "target", ax=axes[0])
imshow(source, "source", ax=axes[1])

Let’s now try to align the source image of our chosen digit with the target image.

First, we define a spatial transformation model, which determines the type of spatial transformation that we want to apply. As the two chosen examples differ mainly by a rotation, we select an EulerRotation from the deepali.spatial library. In 2D, this transformation has a single parameter, namely the rotation angle. Each spatial transformation is defined with respect to the normalized coordinates used by torch.nn.functional.grid_sample(). Every spatial transform with base type SpatialTransform maps points given in these normalized coordinates to normalized coordinates with respect to the same domain, i.e., the domain and codomain of the normalized coordinate map are the same. As these normalized coordinates range from -1 to 1 across the extent of each image side, the origin of this coordinate system is at the image center. The rotation is therefore about the image center.
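The following minimal sketch makes the normalized coordinate convention concrete. It maps a pixel index along one image side with n samples to [-1, 1], following the two conventions documented for torch.nn.functional.grid_sample() (its align_corners flag), and rotates a point about the origin of the normalized coordinate system. The helper names are hypothetical and not part of deepali; the description above, where [-1, 1] spans the extent of each image side, corresponds to align_corners=False.

```python
import math


def index_to_normalized(i: float, n: int, align_corners: bool = False) -> float:
    """Map a pixel index along an image side with n samples to normalized coordinates.

    align_corners=False: -1 and 1 are the outer edges of the first and last pixel.
    align_corners=True: -1 and 1 are the centers of the first and last pixel.
    """
    if align_corners:
        return 2.0 * i / (n - 1) - 1.0
    return (2.0 * i + 1.0) / n - 1.0


def rotate(point: tuple[float, float], angle_deg: float) -> tuple[float, float]:
    """Rotate a 2D point about the origin using the standard 2x2 rotation matrix."""
    a = math.radians(angle_deg)
    x, y = point
    return (math.cos(a) * x - math.sin(a) * y, math.sin(a) * x + math.cos(a) * y)


n = 28  # MNIST image side length
print(index_to_normalized(0, n))            # first pixel center, slightly inside -1
print(index_to_normalized((n - 1) / 2, n))  # image center: 0.0
print(rotate((0.0, 0.0), 30.0))             # the center is a fixed point of the rotation
print(rotate((1.0, 0.0), 90.0))             # approximately (0.0, 1.0)
```

The image center maps to the normalized origin, which any rotation in these coordinates leaves in place.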

For images sampled on a regular grid, we use the Grid class defined by deepali.core to convert between different coordinate systems. The spatial transform also uses an instance of Grid to define its domain and codomain. Commonly, this is the sampling grid of the fixed target image, though it need not be; it could also be a subregion of the image, or a reference space to which both images are mapped symmetrically. For medical images, the Grid defines the mapping between image element indices, world coordinates, and normalized coordinates. The position of the image in the world is defined by Grid.center() (or Grid.origin()), its orientation by the direction cosine matrix Grid.direction(), and the extent of each image element in world units (e.g., millimeters) by Grid.spacing().
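The index-to-world mapping described above is commonly written as world = origin + D · diag(spacing) · index, where D is the direction cosine matrix. The plain Python sketch below spells this out for 2D points; the function names are hypothetical, indices are given in (x, y) order, and it is only meant to illustrate the relation, not to reproduce how deepali's Grid implements it.

```python
def index_to_world(index, origin, spacing, direction):
    """Map 2D image element indices (i, j) along (x, y) to world coordinates.

    world = origin + direction @ (spacing * index), where ``direction`` is a 2x2
    direction cosine matrix given as a tuple of rows.
    """
    scaled = [s * i for s, i in zip(spacing, index)]
    return tuple(o + sum(direction[r][c] * scaled[c] for c in range(2)) for r, o in enumerate(origin))


def world_to_index(point, origin, spacing, direction):
    """Inverse map; the direction matrix is orthonormal, so its inverse is its transpose."""
    offset = [p - o for p, o in zip(point, origin)]
    rotated = [sum(direction[r][c] * offset[r] for r in range(2)) for c in range(2)]
    return tuple(v / s for v, s in zip(rotated, spacing))


identity = ((1.0, 0.0), (0.0, 1.0))
origin, spacing = (-13.5, -13.5), (1.0, 1.0)  # MNIST grid attributes printed below
print(index_to_world((0, 0), origin, spacing, identity))        # (-13.5, -13.5)
print(index_to_world((13.5, 13.5), origin, spacing, identity))  # (0.0, 0.0), the image center
```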

For MNIST, we simply use a world coordinate system whose axes and unit spacing coincide with those of the image, with its origin at the image center.

from deepali.core import Grid

grid = Grid(shape=target.shape[1:])

print("size:     ", list(grid.size()))
print("origin:   ", grid.origin().tolist())
print("center:   ", grid.center().tolist())
print("spacing:  ", grid.spacing().tolist())
print("direction:", grid.direction().tolist())
size:      [28, 28]
origin:    [-13.5, -13.5]
center:    [0.0, 0.0]
spacing:   [1.0, 1.0]
direction: [[1.0, 0.0], [0.0, 1.0]]
Images in world space

Traditionally, image registration aims to align images within the world coordinate system. This is especially important when aligning images from different sources, or images with point clouds (e.g., from depth sensors). In deep learning applications, the alignment is often carried out with respect to the normalized coordinates, and the image-to-world map is ignored. However, when exporting a spatial transform for visualization or for use with classic registration tools such as ANTs, ITK, MIRTK, NiftyReg, and others, we need to recover the correct world information and convert the spatial transform to a world coordinate map. For this purpose, the Grid and, more specifically, the FlowField tensor type of the deepali.data library come in handy; we will cover this in another tutorial.
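As a small illustration of why this conversion matters, consider a linear map A (for example, a rotation matrix) expressed in normalized coordinates. Under the simplifying assumptions of an identity direction matrix, a grid centered at the world origin, and the align_corners=False convention, world coordinates are x_world = S x_norm with S = diag(size * spacing / 2), so the same map in world coordinates is S A S^-1. The sketch below, with hypothetical helper names, shows that this is again a rotation only when the image extent is equal along both axes; it does not describe how deepali itself converts transforms.

```python
import math


def rotation_matrix(angle_deg: float) -> list[list[float]]:
    a = math.radians(angle_deg)
    return [[math.cos(a), -math.sin(a)], [math.sin(a), math.cos(a)]]


def normalized_to_world_matrix(A, size, spacing):
    """Return S A S^-1 with S = diag(size * spacing / 2); size and spacing in (x, y) order."""
    s = [n * d / 2 for n, d in zip(size, spacing)]
    return [[s[r] * A[r][c] / s[c] for c in range(2)] for r in range(2)]


def is_orthogonal(M, tol: float = 1e-9) -> bool:
    """True if M^T M is the identity, i.e., M is a rotation (or reflection)."""
    return all(
        abs(sum(M[k][i] * M[k][j] for k in range(2)) - (1.0 if i == j else 0.0)) <= tol
        for i in range(2)
        for j in range(2)
    )


R = rotation_matrix(30.0)  # rotation expressed in normalized coordinates
square = normalized_to_world_matrix(R, size=(28, 28), spacing=(1.0, 1.0))
wide = normalized_to_world_matrix(R, size=(56, 28), spacing=(1.0, 1.0))
print(is_orthogonal(square))  # True: equal extent along x and y
print(is_orthogonal(wide))    # False: not a rotation in world space
```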


Grid center vs. origin

Notice that grid stores the position of the image center as an attribute, from which the image origin, i.e., the world coordinates of the sampling point with all-zero indices, is computed using the grid spacing and direction. This simplifies grid resizing operations, including increasing (e.g., Grid.upsample()) and decreasing (e.g., Grid.downsample()) the number of image sampling points along each spatial grid dimension, because the center stays fixed while the origin changes with the spacing. We will go into more detail on the different coordinate systems and the relations between them in a separate tutorial on medical image volumes.
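The relation between center and origin can be sketched as origin = center - D · diag(spacing) · (size - 1) / 2. The hypothetical helper below assumes an identity direction matrix for brevity and a simplified downsampling rule (half the number of samples, twice the spacing) that need not match Grid.downsample() exactly. It shows that the center can stay fixed while the origin has to be recomputed.

```python
def origin_from_center(center, spacing, size):
    """origin = center - spacing * (size - 1) / 2 per axis (identity direction assumed)."""
    return tuple(c - s * (n - 1) / 2 for c, s, n in zip(center, spacing, size))


center = (0.0, 0.0)
full = origin_from_center(center, spacing=(1.0, 1.0), size=(28, 28))
half = origin_from_center(center, spacing=(2.0, 2.0), size=(14, 14))
print(full)  # (-13.5, -13.5), matching the origin printed for the MNIST grid
print(half)  # (-13.0, -13.0): same center, different origin
```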


Grid size vs. shape

The Grid properties are defined with respect to the world coordinate axes in the order x, y, etc. The Grid.shape property, however, is in reverse order to match the ordering of the spatial dimensions in the corresponding image data tensor of shape (C, ..., Y, X), where C is the number of image channels and ..., Y, X are the sizes of the image, and thus of the sampling grid, along the respective spatial dimensions. Grid.size() and the size argument of the Grid init function specify the size of each spatial dimension in the original order, i.e., x, y, etc., whereas torch.Tensor.size() uses the ordering of Grid.shape. To avoid confusing the two orderings, it is advised to prefer torch.Tensor.shape over torch.Tensor.size(), which is also consistent with numpy.ndarray.shape and Grid.shape.
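A hypothetical helper that converts an image tensor shape (C, ..., Y, X) into the corresponding grid size (X, Y, ...) illustrates the two orderings; a non-square shape makes the reversal visible.

```python
def grid_size_from_tensor_shape(shape):
    """(C, ..., Y, X) tensor shape -> (X, Y, ...) grid size; drops the channel dimension."""
    return tuple(reversed(shape[1:]))


tensor_shape = (1, 20, 28)  # one channel, 20 rows (Y), 28 columns (X)
print(tensor_shape[1:])                           # (20, 28): ordering of Grid.shape
print(grid_size_from_tensor_shape(tensor_shape))  # (28, 20): ordering of Grid.size()
```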

Gradient descent

For this toy example with a single parameter, grid search is a viable option. However, the number of objective function evaluations grows exponentially with the number of parameters of the spatial transform, and grid search moreover requires a discretization of the parameter space. Instead, we can use the same gradient-based optimizers that PyTorch provides for fitting neural networks to a training set to optimize the parameters of the spatial transform, and thereby the image alignment.
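A back-of-the-envelope count illustrates the exponential growth: with k candidate values per parameter, an exhaustive grid search requires k^d objective evaluations for d parameters. The choice of 36 values per parameter and the example parameter counts are illustrative.

```python
def grid_search_evaluations(values_per_param: int, num_params: int) -> int:
    """Number of objective function evaluations of an exhaustive grid search."""
    return values_per_param**num_params


# 36 candidate values per parameter, e.g., rotation angles in steps of 10 degrees
for name, num_params in [("2D rotation", 1), ("3D rotation", 3), ("3D rigid", 6)]:
    print(f"{name:12s} {grid_search_evaluations(36, num_params):>16,d}")
```

The cost of a gradient descent step, in contrast, does not grow exponentially with the number of parameters, at the price of only finding a local optimum.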

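# Image dissimilarity measure used as objective (assumption: mean squared error)
sim_loss = torch.nn.functional.mse_loss

target_batch = target.unsqueeze(0)  # (N, C, Y, X)
source_batch = source.unsqueeze(0)  # (N, C, Y, X)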

rotation = spatial.EulerRotation(grid)
transformer = spatial.ImageTransformer(rotation)
optimizer = optim.Adam(transformer.parameters(), lr=1e-2)

iterations = 100

bar_format = "{l_bar}{bar}{rate_fmt}{postfix}"
for _ in (pbar := tqdm(range(iterations), bar_format=bar_format)):
    warped_batch = transformer(source_batch)
    loss = sim_loss(warped_batch, target_batch)
    angle_deg = rotation.angles().rad2deg().item()
    pbar.set_postfix(dict(loss=loss.item(), angle=angle_deg))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

with torch.inference_mode():
    warped: Tensor = transformer(source_batch)[0]

fig, axes = plt.subplots(1, 3, figsize=(12, 4), tight_layout=True)

imshow(target, "target", ax=axes[0])
imshow(warped, "warped", ax=axes[1])
imshow(source, "source", ax=axes[2])

Unfortunately, for this pair of images and the chosen optimizer and hyperparameters, such as learning rate and momentum, the registration resulted in a suboptimal solution compared to our previous grid search result. What happened?

Gradient descent is a local optimization technique and is sensitive to the initialization. Furthermore, the nearly binary images, which consist largely of homogeneous background, together with the chosen similarity measure, provide an insufficient signal (i.e., a narrow capture range) about the direction in which to adjust the rotation at each gradient step in order to minimize the loss and thereby maximize alignment. Increasing or decreasing the angle by a small amount changes the objective value by a similar amount, so it is difficult to judge which direction is better to follow.
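Sensitivity to the initialization can be reproduced with a toy one-dimensional objective that has two minima. The polynomial and step size below are arbitrary illustrative choices rather than a registration loss: plain gradient descent ends up in whichever minimum lies downhill of the starting point, even though one of them is clearly worse.

```python
def f(x: float) -> float:
    """Toy objective: local minimum near x = 0.96, global minimum near x = -1.04."""
    return x**4 - 2 * x**2 + 0.3 * x


def grad_f(x: float) -> float:
    return 4 * x**3 - 4 * x + 0.3


def gradient_descent(x0: float, lr: float = 0.05, steps: int = 200) -> float:
    x = x0
    for _ in range(steps):
        x -= lr * grad_f(x)
    return x


x_right = gradient_descent(0.5)  # starts right of the local maximum near x = 0.075
x_left = gradient_descent(-0.5)  # starts left of it
print(f"start  0.5 -> x = {x_right:.3f}, f = {f(x_right):.3f}")
print(f"start -0.5 -> x = {x_left:.3f}, f = {f(x_left):.3f}")
```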

Multi-scale registration

In traditional optimization-based image registration, this is overcome by a multi-scale or multi-resolution optimization scheme. For this, we first downsample the images a number of times and find the optimal rotation between the lower resolution images, whose image elements have a larger spatial extent. The solution found at a lower resolution then initializes the optimization at the next higher resolution. Gaussian blurring is commonly applied to generate a multi-scale representation of the images. Blurring before subsampling also suppresses aliasing, in accordance with the Nyquist-Shannon sampling theorem. The resulting pyramid of differently sized images is referred to as a Gaussian image pyramid. In deep learning based image registration, a multi-scale representation of the input images is usually computed by the neural network itself.
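The effect of blurring on the capture range can be reproduced in one dimension. In this sketch, which is our own illustration rather than deepali code, a translation by an integer number of samples stands in for the rotation angle, and the mean squared error stands in for the similarity measure. For two narrow, non-overlapping boxes the error is the same for every shift that keeps them apart, so it carries no directional information; after Gaussian blurring with a large standard deviation it decreases steadily as the shift approaches the solution.

```python
import math


def box(n: int, start: int, width: int) -> list[float]:
    """Binary 1D signal that is 1 on [start, start + width) and 0 elsewhere."""
    return [1.0 if start <= i < start + width else 0.0 for i in range(n)]


def shift(signal: list[float], t: int) -> list[float]:
    """Translate by t samples with zero padding: out[i] = signal[i - t]."""
    n = len(signal)
    return [signal[i - t] if 0 <= i - t < n else 0.0 for i in range(n)]


def gaussian_blur(signal: list[float], sigma: float) -> list[float]:
    """Convolve with a normalized Gaussian kernel truncated at 3 sigma (zero padding)."""
    radius = int(3 * sigma)
    offsets = range(-radius, radius + 1)
    weights = [math.exp(-0.5 * (k / sigma) ** 2) for k in offsets]
    total = sum(weights)
    n = len(signal)
    return [
        sum(w / total * signal[i + k] for k, w in zip(offsets, weights) if 0 <= i + k < n)
        for i in range(n)
    ]


def mse(a: list[float], b: list[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


target = box(128, start=40, width=5)
source = box(128, start=80, width=5)  # aligned with target by the shift t = -40
shifts = [0, -10, -20, -30, -40]

raw = [mse(shift(source, t), target) for t in shifts]
blurred_target = gaussian_blur(target, sigma=8.0)
blurred_source = gaussian_blur(source, sigma=8.0)
smooth = [mse(shift(blurred_source, t), blurred_target) for t in shifts]

print("raw:    ", [round(e, 4) for e in raw])     # identical until the boxes overlap
print("blurred:", [round(e, 4) for e in smooth])  # decreases steadily towards t = -40
```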

Sampling grid pyramid

Notice that the Grid.pyramid() function can be used to construct the sampling grids for the different scales of the image pyramid. The corresponding Image.pyramid() function of the Image tensor type of the deepali.data library generates the data tensors of the different pyramid scales and assigns the respective sampling grid to the image at each scale.


Multi-scale vs. multi-resolution

Note that, in the case of a rigid transformation such as a rotation, multi-scale optimization would already be achieved by a multi-scale representation of the images obtained by blurring them with consecutively larger Gaussian kernels (or repeatedly with the same kernel), without also reducing the image size. In non-rigid registration, the images are downsampled mainly for two reasons: a) to reduce the resolution of the image deformation field, and b) to reduce the computational cost at lower scales. In deepali.spatial, non-rigid transformations such as DisplacementFieldTransform and FreeFormDeformation have a stride parameter, which can be used to reduce the resolution of the non-rigid transformation without changing the image sampling grid. Here, we use a standard multi-resolution image pyramid with downsampling.


from deepali.data import Image

levels = 3

target_pyramid = Image(target, grid).pyramid(levels)
source_pyramid = Image(source, grid).pyramid(levels)

fig, axes = plt.subplots(2, levels, figsize=(4 * levels, 8), tight_layout=True)

for (level, tgt), src in zip(target_pyramid.items(), source_pyramid.values()):
    imshow(tgt, f"target, level {level}", ax=axes[0, level])
    imshow(src, f"source, level {level}", ax=axes[1, level])

With these image pyramids, we can implement a sequential multi-resolution registration as follows.

ImagePyramid = dict[int, Image]
LossFunction = Callable[[Tensor, Tensor, spatial.SpatialTransform], Tensor | dict[str, Tensor]]
TransformCls = str | Type[spatial.SpatialTransform]
TransformArg = TransformCls | Tuple[TransformCls, dict[str, Any]]
OptimizerCls = str | Type[optim.Optimizer]
OptimizerArg = OptimizerCls | Tuple[OptimizerCls, dict[str, Any]]


def image_pyramid(
    image: Tensor | Image | ImagePyramid,
    levels: int,
    grid: Optional[Grid] = None,
    device: Optional[torch.device] = None,
) -> ImagePyramid:
    r"""Construct image pyramid from image tensor."""
    if isinstance(image, dict):
        pyramid = {}
        for level, im in image.items():
            if type(level) is not int:
                raise TypeError("Image pyramid keys must be int")
            if level >= levels:
                break
            if type(im) is Tensor:
                im = Image(im, grid)
            if not isinstance(im, Image):
                raise TypeError("Image pyramid values must be deepali.data.Image or torch.Tensor")
            im = cast(Image, im.float().to(device))
            pyramid[level] = im
        if len(pyramid) < levels:
            raise ValueError(f"Expected image pyramid with {levels} levels, but only got {len(pyramid)} levels")
    else:
        if not isinstance(image, Image):
            image = Image(image, grid)
        image = cast(Image, image.float().to(device))
        pyramid = image.pyramid(levels)
    return pyramid


def init_transform(transform: TransformArg, grid: Grid, device: Optional[torch.device] = None) -> spatial.SpatialTransform:
    r"""Auxiliary function to create spatial transform."""
    if isinstance(transform, tuple):
        cls, args = transform
    else:
        cls = transform
        args = {}
    if isinstance(cls, str):
        spatial_transform = spatial.new_spatial_transform(cls, grid, **args)
    else:
        spatial_transform = cls(grid, **args)
    return spatial_transform.to(device).train()


def init_optimizer(optimizer: OptimizerArg, transform: spatial.SpatialTransform) -> optim.Optimizer:
    r"""Auxiliary function to initialize optimizer."""
    if isinstance(optimizer, tuple):
        cls, args = optimizer
    else:
        cls = optimizer
        args = {}
    if isinstance(cls, str):
        cls = getattr(optim, cls)
    if not issubclass(cls, optim.Optimizer):
        raise TypeError("'optimizer' must be a torch.optim.Optimizer")
    return cls(transform.parameters(), **args)


def multi_resolution_registration(
    target: Tensor | Image | ImagePyramid,
    source: Tensor | Image | ImagePyramid,
    loss_fn: LossFunction,
    transform: TransformArg,
    optimizer: OptimizerArg,
    iterations: int | list[int] = 100,
    levels: int = 3,
    device: Optional[str | int | torch.device] = None,
) -> spatial.SpatialTransform:
    r"""Multi-resolution pairwise image registration."""
    if device is None:
        if isinstance(target, dict):
            device = next(iter(target.values())).device
        else:
            device = target.device
    device = torch.device(f"cuda:{device}" if type(device) is int else device)
    target = image_pyramid(target, levels=levels, device=device)
    levels = len(target)
    source = image_pyramid(source, levels=levels, device=device)
    model = init_transform(transform, target[levels - 1].grid(), device=device)
    bar_format = "{l_bar}{bar}{rate_fmt}{postfix}"
    if isinstance(iterations, int):
        iterations = [iterations]
    iterations = list(iterations)
    iterations += [iterations[-1]] * (levels - len(iterations))
    for level, steps in zip(reversed(range(levels)), iterations):
        model.grid_(target[level].grid())
        target_batch = target[level].batch().tensor()
        source_batch = source[level].batch().tensor()
        transformer = spatial.ImageTransformer(model)
        # Named opt to avoid shadowing the torch.optim module within this function
        opt = init_optimizer(optimizer, model)
        for _ in (pbar := tqdm(range(steps), bar_format=bar_format)):
            warped_batch: Tensor = transformer(source_batch)
            loss = loss_fn(warped_batch, target_batch, model)
            if isinstance(loss, Tensor):
                loss = dict(loss=loss)
            pbar.set_description(f"Level {level}")
            pbar.set_postfix({k: v.item() for k, v in loss.items()})
            opt.zero_grad()
            loss["loss"].backward()
            opt.step()
    return model.eval()
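The bookkeeping at the start of multi_resolution_registration() determines how many gradient steps are taken at each level: the first entry of iterations applies to the coarsest level, and the last entry is repeated for any remaining finer levels. The stand-alone sketch below mirrors that logic (the helper name is ours) so that the schedule can be checked without running a registration.

```python
def iteration_schedule(iterations, levels):
    """(level, steps) pairs in the coarse-to-fine order used by multi_resolution_registration()."""
    if isinstance(iterations, int):
        iterations = [iterations]
    iterations = list(iterations)
    # Repeat the last entry for any remaining (finer) levels
    iterations += [iterations[-1]] * (levels - len(iterations))
    return list(zip(reversed(range(levels)), iterations))


print(iteration_schedule(100, levels=3))                   # [(2, 100), (1, 100), (0, 100)]
print(iteration_schedule([200, 200, 300, 100], levels=4))  # schedule of the circle-to-C example
```

Note that zip() silently drops entries of iterations beyond the number of levels.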

Let’s see what we now obtain for the rigid registration of our two digits using this registration function:

transform = multi_resolution_registration(
    target=target,
    source=source,
    transform=spatial.EulerRotation,
    optimizer=(optim.Adam, {"lr": 1e-2}),
    loss_fn=lambda a, b, _: sim_loss(a, b),
    device=device,
)
transform = transform.cpu()

with torch.inference_mode():
    transformer = spatial.ImageTransformer(transform)
    warped: Tensor = transformer(source)

fig, axes = plt.subplots(1, 3, figsize=(12, 4), tight_layout=True)

imshow(target, "target", ax=axes[0])
imshow(warped, "warped", ax=axes[1])
imshow(source, "source", ax=axes[2])

Great! The multi-resolution gradient descent resulted in a solution similar to that of the previous exhaustive grid search.

Non-rigid registration

Multi-resolution optimization becomes even more important when the spatial transform is non-rigid. Apart from linearly interpolated dense vector fields, a free-form deformation based on a cubic B-spline parameterization is commonly employed in medical image registration. This non-rigid transformation model has the advantage that its first and second order derivatives can be computed analytically. Using our previously defined multi_resolution_registration() function, we can instead optimize a non-rigid deformation such as a spatial.FreeFormDeformation (FFD).
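The analytic derivatives follow from the piecewise-polynomial form of the uniform cubic B-spline. The minimal 1D sketch below evaluates one spline segment from its four control point coefficients using the standard uniform cubic B-spline basis and its analytic derivative, and compares the latter with a finite-difference estimate. It is an illustration only; a 2D or 3D free-form deformation uses tensor products of this basis over its control point grid.

```python
def bspline_basis(t: float) -> tuple[float, float, float, float]:
    """Uniform cubic B-spline basis functions B0, ..., B3 at local coordinate t in [0, 1]."""
    return (
        (1 - t) ** 3 / 6,
        (3 * t**3 - 6 * t**2 + 4) / 6,
        (-3 * t**3 + 3 * t**2 + 3 * t + 1) / 6,
        t**3 / 6,
    )


def bspline_basis_derivative(t: float) -> tuple[float, float, float, float]:
    """Analytic first derivatives dB0/dt, ..., dB3/dt."""
    return (
        -((1 - t) ** 2) / 2,
        (3 * t**2 - 4 * t) / 2,
        (-3 * t**2 + 2 * t + 1) / 2,
        t**2 / 2,
    )


def evaluate(coeffs, t, basis):
    """Weighted sum of the four control point coefficients that influence local coordinate t."""
    return sum(c * b for c, b in zip(coeffs, basis(t)))


coeffs = (0.0, 2.0, -1.0, 0.5)  # arbitrary control point values, e.g., displacements
t, h = 0.3, 1e-5
analytic = evaluate(coeffs, t, bspline_basis_derivative)
numeric = (evaluate(coeffs, t + h, bspline_basis) - evaluate(coeffs, t - h, bspline_basis)) / (2 * h)
print(sum(bspline_basis(t)))  # partition of unity: 1 up to rounding
print(analytic, numeric)      # agree up to the finite-difference error
```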

def loss_fn(
    w_curvature: float = 0,
    w_diffusion: float = 0,
    w_bending: float = 0,
) -> Callable[[Tensor, Tensor, spatial.SpatialTransform], dict[str, Tensor]]:
    r"""Construct loss function for free-form deformation (FFD) based image registration.

    Args:
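        w_curvature: Weight of curvature, i.e., sum of unmixed second order derivatives.
        w_diffusion: Weight of diffusion, i.e., sum of first order derivatives.
        w_bending: Weight of bending energy, i.e., sum of second order derivatives. For a
            ``spatial.BSplineTransform``, it is computed from the spline coefficients.
            When the spatial transform is parameterized by velocities, the penalties are
            computed for the dense velocity vector field instead of the displacement field.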

    Returns:
        Loss function which takes as input a registered image pair, and the spatial transform
        used to register the images. The loss function evaluates the alignment of the images
        based on a similarity term and optional regularization terms (transform penalties).

    """

    def loss(
        warped: Tensor,
        target: Tensor,
        transform: spatial.SpatialTransform,
    ) -> dict[str, Tensor]:
        terms: dict[str, Tensor] = {}
        # Similarity term
        sim = sim_loss(warped, target)
        terms["sim"] = sim
        loss = sim
        # Regularization terms
        # v_or_u: dense velocity or displacement vector field, respectively.
        v_or_u = getattr(transform, "v", getattr(transform, "u", None))
        assert v_or_u is not None
        if w_curvature > 0:
            curvature = L.curvature_loss(v_or_u)
            loss = curvature.mul(w_curvature).add(loss)
            terms["curv"] = curvature
        if w_diffusion > 0:
            diffusion = L.diffusion_loss(v_or_u)
            loss = diffusion.mul(w_diffusion).add(loss)
            terms["diff"] = diffusion
        if w_bending > 0:
            if isinstance(transform, spatial.BSplineTransform):
                params = transform.params
                assert isinstance(params, Tensor)
                bending = L.bspline_bending_loss(params)
            else:
                bending = L.bending_loss(v_or_u)
            loss = bending.mul(w_bending).add(loss)
            terms["be"] = bending
        return {"loss": loss, **terms}

    return loss


transform = multi_resolution_registration(
    target=target_pyramid,
    source=source_pyramid,
    transform=("FFD", {"stride": 2}),
    optimizer=("Adam", {"lr": 1e-2}),
    loss_fn=loss_fn(w_bending=1e-5),
    device=device,
)
transform = transform.cpu()

with torch.inference_mode():
    transformer = spatial.ImageTransformer(transform)
    warped: Tensor = transformer(source)

fig, axes = plt.subplots(1, 3, figsize=(12, 4), tight_layout=True)

imshow(target, "target", ax=axes[0])
imshow(warped, "warped", ax=axes[1])
imshow(source, "source", ax=axes[2])

In addition to the choice of parameterization of the spatial transform (i.e., a vector field given by a cubic B-spline with control points at every second image grid point), we added a loss term (bspline_bending_loss()) that penalizes bending of the spline. By setting the weight w_bending of this term to higher values, we can make the spatial transform stiffer and allow for less severe deformations. To assess the quality of the computed deformation, we can visualize the deformation of the image grid. Because of the small size of the MNIST images, we generate a higher resolution grid image for this visualization, which we then deform using an ImageTransformer whose input (source) and output (target) grids match this higher resolution.

from deepali.core import functional as U


grid_highres = grid.resize(512)
grid_image = U.grid_image(grid_highres, num=1, stride=8, inverted=True)
grid_transformer = spatial.ImageTransformer(transform, grid_highres, padding="zeros")

with torch.inference_mode():
    warped_grid: Tensor = grid_transformer(grid_image)

fig, axes = plt.subplots(1, 2, figsize=(8, 4), tight_layout=True)

imshow(grid_image, "source grid", ax=axes[0])
imshow(warped_grid, "warped grid", ax=axes[1])

Although the deformed source image matches the target image well, the deformation itself is not well behaved. This can be remedied by increasing the weight of the bending loss (w_bending).
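The effect of w_bending can be illustrated with a discrete 1D stand-in for the bending energy: the mean squared second-order finite difference of a displacement profile. This sketch is not L.bspline_bending_loss(), which the loss function above applies to the B-spline coefficients of the transform; it only shows the qualitative behavior. Affine displacements are not penalized at all, and for the same amplitude, rapidly oscillating displacements are penalized far more than smooth ones.

```python
import math


def bending_energy_1d(u: list[float], spacing: float = 1.0) -> float:
    """Mean squared second-order central difference of a sampled 1D displacement profile."""
    second = [(u[i - 1] - 2 * u[i] + u[i + 1]) / spacing**2 for i in range(1, len(u) - 1)]
    return sum(d * d for d in second) / len(second)


n = 28
profiles = {
    "affine": [0.1 * i + 2.0 for i in range(n)],                  # scaling plus translation
    "smooth": [math.sin(2 * math.pi * i / n) for i in range(n)],  # one period across the image
    "wiggly": [math.sin(2 * math.pi * i / 4) for i in range(n)],  # period of four samples
}
for name, u in profiles.items():
    print(f"{name:7s} {bending_energy_1d(u):.6f}")
```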

transform = multi_resolution_registration(
    target=target_pyramid,
    source=source_pyramid,
    transform=(spatial.FreeFormDeformation, {"stride": 2}),
    optimizer=(optim.Adam, {"lr": 1e-2}),
    loss_fn=loss_fn(w_bending=1e-3),
    device=device,
)
transform = transform.cpu()

with torch.inference_mode():
    image_transformer = spatial.ImageTransformer(transform)
    grid_transformer = spatial.ImageTransformer(transform, grid_highres, padding="zeros")
    warped_grid: Tensor = grid_transformer(grid_image)
    warped: Tensor = image_transformer(source)

fig, axes = plt.subplots(1, 4, figsize=(16, 4), tight_layout=True)

imshow(target, "target", ax=axes[0])
imshow(warped, "warped", ax=axes[1])
imshow(source, "source", ax=axes[2])
imshow(warped_grid, "deformation", ax=axes[3])

Diffeomorphic registration

In the previous non-rigid registration example, we used a free-form deformation model as our spatial transformation. Although a good choice of the control point spacing (stride) and bending regularization weight (w_bending) led to a smooth and invertible spatial transform, it is sometimes desirable that invertibility, with both the forward and the backward map being differentiable, is guaranteed by the chosen transformation model itself. Typically, such a diffeomorphic coordinate map is parameterized by either a time-varying or a stationary velocity vector field (SVF). The latter is more common in recent registration literature because of its computational efficiency. The dense displacement field generated by a stationary velocity field is given by the exponential map of this vector field, which can be computed efficiently with the scaling and squaring algorithm.
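Scaling and squaring exploits the group property exp(v) = exp(v / 2^N)^(2^N): the velocity field is scaled down by 2^N, treated as a small displacement, and the resulting map is composed with itself N times. The 1D sketch below, in plain Python with hypothetical helper names and not deepali's implementation, applies this to the linear velocity field v(x) = a (x - c), whose exponential map is known in closed form, phi(x) = c + e^a (x - c). It also checks that exp(-v) approximately inverts exp(v), which is why negating the velocity yields the inverse map.

```python
import math


def interpolate(values: list[float], x: float) -> float:
    """Linearly interpolate samples on the grid 0, 1, ..., n-1, clamping x to this range."""
    x = min(max(x, 0.0), len(values) - 1.0)
    i = min(int(math.floor(x)), len(values) - 2)
    w = x - i
    return (1.0 - w) * values[i] + w * values[i + 1]


def exp_velocity(v: list[float], steps: int = 8) -> list[float]:
    """Displacement field u of exp(v) on a 1D grid, computed by scaling and squaring."""
    u = [vi / 2**steps for vi in v]  # scaling: exp(v / 2^N) is close to id + v / 2^N
    for _ in range(steps):           # squaring: phi <- phi o phi, i.e., u(x) <- u(x) + u(x + u(x))
        u = [ui + interpolate(u, i + ui) for i, ui in enumerate(u)]
    return u


n, c, a = 101, 50.0, 0.5
v = [a * (i - c) for i in range(n)]  # linear velocity field v(x) = a (x - c)
u = exp_velocity(v)
u_inv = exp_velocity([-vi for vi in v])  # exp(-v) approximates the inverse map

x = 60  # 10 samples from c, far enough from the clamped boundary
print(u[x], 10 * (math.exp(a) - 1))  # scaling and squaring vs. closed-form displacement
print(v[x])                          # using v directly as displacement would give 5.0
y = x + u[x]                         # apply the forward map ...
print(y + interpolate(u_inv, y))     # ... then the inverse map: approximately 60
```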

The spatial.StationaryVelocityFieldTransform (SVF) and spatial.StationaryVelocityFreeFormDeformation (SVFFD) are diffeomorphic transformation models based on such a stationary velocity field. We can use the SpatialTransform.inverse() function to obtain the inverse deformation. With this, we can define the following auxiliary function to plot the result of a diffeomorphic registration, where the source image is deformed by the computed forward transform and the target image by its inverse.

def invertible_registration_figure(
    target: Tensor,
    source: Tensor,
    transform: spatial.SpatialTransform,
) -> Figure:
    r"""Create figure visualizing result of diffeomorphic registration.

    Args:
        target: Fixed target image.
        source: Moving source image.
        transform: Invertible spatial transform, i.e., must implement ``SpatialTransform.inverse()``.

    Returns:
        Instance of ``matplotlib.pyplot.Figure``.

    """
    device = transform.device

    highres_grid = transform.grid().resize(512)
    grid_image = U.grid_image(highres_grid, num=1, stride=8, inverted=True, device=device)

    with torch.inference_mode():
        inverse = transform.inverse()

        source_transformer = spatial.ImageTransformer(transform)
        target_transformer = spatial.ImageTransformer(inverse)

        source_grid_transformer = spatial.ImageTransformer(transform, highres_grid, padding="zeros")
        target_grid_transformer = spatial.ImageTransformer(inverse, highres_grid, padding="zeros")

        warped_source: Tensor = source_transformer(source.to(device))
        warped_target: Tensor = target_transformer(target.to(device))

        warped_source_grid: Tensor = source_grid_transformer(grid_image)
        warped_target_grid: Tensor = target_grid_transformer(grid_image)

    fig, axes = plt.subplots(2, 3, figsize=(12, 8), tight_layout=True)

    imshow(target, "target", ax=axes[0, 0])
    imshow(warped_source, "warped source", ax=axes[0, 1])
    imshow(warped_source_grid, "forward deformation", ax=axes[0, 2])

    imshow(source, "source", ax=axes[1, 0])
    imshow(warped_target, "warped target", ax=axes[1, 1])
    imshow(warped_target_grid, "inverse deformation", ax=axes[1, 2])

    return fig

We use the previously implemented multi_resolution_registration() helper to optimize a diffeomorphic transformation. Here, we select the SVFFD with a control point at every other image grid point. Note that the stride is given in number of grid points. This way, the resolution of the control point grid is defined by the resolution of the target image at the current level of the multi-resolution optimization.
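Because the stride is counted in grid points, the physical control point spacing follows the image resolution of the current level. Assuming, for illustration, that each pyramid level doubles the grid spacing of the previous one (a common convention for image pyramids, not verified here against Grid.pyramid()), a hypothetical helper gives the control point spacing in units of the original MNIST pixels as stride * 2^level:

```python
def control_point_spacing(stride: int, base_spacing: float, level: int) -> float:
    """Control point spacing in world units, assuming the grid spacing doubles per level."""
    return stride * base_spacing * 2**level


for level in reversed(range(3)):  # coarse-to-fine, as in multi_resolution_registration()
    spacing = control_point_spacing(stride=2, base_spacing=1.0, level=level)
    print(f"level {level}: control point every {spacing:g} pixels of the original image")
```

Under this assumption, coarse levels can only represent smooth, large-scale displacements, while finer levels add local detail.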

transform = multi_resolution_registration(
    target=target_pyramid,
    source=source_pyramid,
    transform=(spatial.StationaryVelocityFreeFormDeformation, {"stride": 2}),
    optimizer=(optim.Adam, {"lr": 1e-2}),
    loss_fn=loss_fn(w_bending=1e-3),
    device=device,
)

The visualization shows good alignment both between the deformed source image and the target image, and between the target image deformed by the inverse transform and the source image.

_ = invertible_registration_figure(target, source, transform)

Circle to C deformation

A classic example for demonstrating diffeomorphic registration is the “circle to C” problem, where an image of a circle is deformed to align it with an image of the letter “C”. This example was used in the DARTEL paper by John Ashburner, one of the seminal works that first introduced this parameterization for use in medical image registration. We can generate a similar example, though with a more open and thus more challenging target “C” shape, using the functions cshape_image() and circle_image() of the deepali.core library. Gaussian blurring is applied to smooth the edges of these synthetic binary images.

fig, axes = plt.subplots(1, 2, figsize=(8, 4))

circle_c_grid = Grid((128, 128))
target_cshape = U.cshape_image(circle_c_grid, radius=50, sigma=2, dtype=torch.float)
source_circle = U.circle_image(circle_c_grid, radius=50, sigma=2, dtype=torch.float)

imshow(target_cshape, "c-shape", ax=axes[0])
imshow(source_circle, "circle", ax=axes[1])
circle_to_c_transform = multi_resolution_registration(
    target=target_cshape[0],
    source=source_circle[0],
    loss_fn=loss_fn(w_diffusion=1e-3),
    transform=spatial.StationaryVelocityFreeFormDeformation,
    optimizer=(optim.Adam, {"lr": 1e-2}),
    iterations=[200, 200, 300, 100],
    levels=4,
    device=device,
)
_ = invertible_registration_figure(target_cshape, source_circle, circle_to_c_transform)