Augmentation

This example shows how to apply data augmentation in conjunction with the pymia.data package. Besides transformations from the pymia.data.augmentation module, transformations from the Python packages batchgenerators and TorchIO are integrated.

Tip

This example is available as a Jupyter notebook at ./examples/augmentation/basic.ipynb and as a Python script at ./examples/augmentation/basic.py.

Note

To be able to run this example:

- Get the example data by executing ./examples/example-data/pull_example_data.py.
- Install the non-pymia requirements: pip install batchgenerators matplotlib torchio

Import the required modules.

[1]:
import batchgenerators.transforms as bg_tfm
import matplotlib.pyplot as plt
import numpy as np
import torchio as tio

import pymia.data.transformation as tfm
import pymia.data.augmentation as augm
import pymia.data.definition as defs
import pymia.data.extraction as extr
If you use TorchIO for your research, please cite the following paper:
Pérez-García et al., TorchIO: a Python library for efficient loading,
preprocessing, augmentation and patch-based sampling of medical images
in deep learning. Credits instructions: https://torchio.readthedocs.io/#credits

We create access to the .h5 dataset by defining: (i) the indexing strategy (indexing_strategy), which defines the chunks of data to be retrieved, and (ii) the information to be extracted (extractor).

[2]:
hdf_file = '../example-data/example-dataset.h5'
indexing_strategy = extr.SliceIndexing()
extractor = extr.DataExtractor(categories=(defs.KEY_IMAGES, defs.KEY_LABELS))
dataset = extr.PymiaDatasource(hdf_file, indexing_strategy, extractor)
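
The datasource behaves like a sequence: its length is the number of chunks (here, slices) defined by the indexing strategy, and indexing it returns a dictionary with the extracted entries. A quick inspection sketch (the printed values depend on the example data):

print(len(dataset))  # number of slices across all subjects
first_sample = dataset[0]
print(first_sample.keys())  # contains defs.KEY_IMAGES and defs.KEY_LABELS, among others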

For reproducibility, set the seed and define a sample index for plotting.

[3]:
seed = 1
np.random.seed(seed)
sample_idx = 55

We can now define the transformations to apply. For reference, we first do not apply any data augmentation.

[4]:
transforms_augmentation = []
transforms_before_augmentation = [tfm.Permute(permutation=(2, 0, 1)), ]  # to have the channel-dimension first
transforms_after_augmentation = [tfm.Squeeze(entries=(defs.KEY_LABELS,)), ]  # get rid of the channel-dimension for the labels
train_transforms = tfm.ComposeTransform(transforms_before_augmentation + transforms_augmentation + transforms_after_augmentation)
dataset.set_transform(train_transforms)
sample = dataset[sample_idx]
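
As a sanity check, we can inspect the shapes: after the Permute, the images are channel-first, and after the Squeeze, the labels carry no channel dimension (the exact in-plane size stems from the example data):

print(sample[defs.KEY_IMAGES].shape)  # e.g. (2, 217, 181): T1- and T2-weighted, channel-first
print(sample[defs.KEY_LABELS].shape)  # e.g. (217, 181): channel dimension squeezed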

pymia augmentation

Let us now use pymia to apply a random 90° rotation and a random mirroring.

[5]:
transforms_augmentation = [augm.RandomRotation90(axes=(-2, -1)), augm.RandomMirror()]
train_transforms = tfm.ComposeTransform(
    transforms_before_augmentation + transforms_augmentation + transforms_after_augmentation)
dataset.set_transform(train_transforms)
sample_pymia = dataset[sample_idx]
/home/fbalsiger/PycharmProjects/pymia/pymia/data/augmentation.py:231: RuntimeWarning: entry "images" has unequal in-plane dimensions (217, 181). Random 90 degree rotation might produce undesired results. Verify the output!
  warnings.warn(f'entry "{entry}" has unequal in-plane dimensions ({sample[entry].shape[self.axes[0]]}, '
/home/fbalsiger/PycharmProjects/pymia/pymia/data/augmentation.py:231: RuntimeWarning: entry "labels" has unequal in-plane dimensions (217, 181). Random 90 degree rotation might produce undesired results. Verify the output!
  warnings.warn(f'entry "{entry}" has unequal in-plane dimensions ({sample[entry].shape[self.axes[0]]}, '
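
The warnings stem from the non-square in-plane size (217, 181): a 90° rotation swaps the two in-plane axes and therefore changes the spatial shape. If this is undesired, one remedy is to pad the slices to square size before augmenting. A minimal sketch in plain NumPy (pad_to_square is a hypothetical helper, not part of pymia):

def pad_to_square(np_entry: np.ndarray) -> np.ndarray:
    # zero-pad the last two (in-plane) dimensions to equal size
    h, w = np_entry.shape[-2:]
    size = max(h, w)
    padding = [(0, 0)] * (np_entry.ndim - 2) + [(0, size - h), (0, size - w)]
    return np.pad(np_entry, padding)

padded = pad_to_square(sample[defs.KEY_IMAGES])  # e.g. (2, 217, 181) -> (2, 217, 217)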

batchgenerators augmentation

Let us now use batchgenerators to apply a random mirroring and a random Gaussian blurring. To use batchgenerators, we create a wrapper class for simple integration into pymia.

[6]:
class BatchgeneratorsTransform(tfm.Transform):
    """Example wrapper for `batchgenerators <https://github.com/MIC-DKFZ/batchgenerators>`_ transformations."""

    def __init__(self, transforms, entries=(defs.KEY_IMAGES, defs.KEY_LABELS)) -> None:
        super().__init__()
        self.transforms = transforms
        self.entries = entries

    def __call__(self, sample: dict) -> dict:
        # unsqueeze samples to add a batch dimension, as required by batchgenerators
        for entry in self.entries:
            if entry not in sample:
                if tfm.raise_error_if_entry_not_extracted:
                    raise ValueError(tfm.ENTRY_NOT_EXTRACTED_ERR_MSG.format(entry))
                continue

            np_entry = tfm.check_and_return(sample[entry], np.ndarray)
            sample[entry] = np.expand_dims(np_entry, 0)

        # apply batchgenerators transforms
        for t in self.transforms:
            sample = t(**sample)

        # squeeze samples back to original format
        for entry in self.entries:
            np_entry = tfm.check_and_return(sample[entry], np.ndarray)
            sample[entry] = np_entry.squeeze(0)

        return sample

transforms_augmentation = [BatchgeneratorsTransform([
    bg_tfm.spatial_transforms.MirrorTransform(axes=(0, 1), data_key=defs.KEY_IMAGES, label_key=defs.KEY_LABELS),
    bg_tfm.noise_transforms.GaussianBlurTransform(blur_sigma=(0.2, 1.0), data_key=defs.KEY_IMAGES, label_key=defs.KEY_LABELS),
])]
train_transforms = tfm.ComposeTransform(
    transforms_before_augmentation + transforms_augmentation + transforms_after_augmentation)
dataset.set_transform(train_transforms)
sample_batchgenerators = dataset[sample_idx]
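
Mirroring and Gaussian blurring do not alter the spatial dimensions, so the augmented sample should keep the shapes of the untransformed one; a quick check:

assert sample_batchgenerators[defs.KEY_IMAGES].shape == sample[defs.KEY_IMAGES].shape
assert sample_batchgenerators[defs.KEY_LABELS].shape == sample[defs.KEY_LABELS].shape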

TorchIO augmentation

Let us now use TorchIO to apply a random flip and a random affine transformation. To use TorchIO, we create a wrapper class for simple integration into pymia.

[7]:
class TorchIOTransform(tfm.Transform):
    """Example wrapper for `TorchIO <https://github.com/fepegar/torchio>`_ transformations."""

    def __init__(self, transforms: list, entries=(defs.KEY_IMAGES, defs.KEY_LABELS)) -> None:
        super().__init__()
        self.transforms = transforms
        self.entries = entries

    def __call__(self, sample: dict) -> dict:
        # unsqueeze samples to be 4-D tensors, as required by TorchIO
        for entry in self.entries:
            if entry not in sample:
                if tfm.raise_error_if_entry_not_extracted:
                    raise ValueError(tfm.ENTRY_NOT_EXTRACTED_ERR_MSG.format(entry))
                continue

            np_entry = tfm.check_and_return(sample[entry], np.ndarray)
            sample[entry] = np.expand_dims(np_entry, -1)

        # apply TorchIO transforms
        for t in self.transforms:
            sample = t(sample)

        # squeeze samples back to original format
        for entry in self.entries:
            np_entry = tfm.check_and_return(sample[entry].numpy(), np.ndarray)
            sample[entry] = np_entry.squeeze(-1)

        return sample

transforms_augmentation = [TorchIOTransform(
    [tio.RandomFlip(axes='LR', flip_probability=1.0, keys=(defs.KEY_IMAGES, defs.KEY_LABELS), seed=seed),
     tio.RandomAffine(scales=(0.9, 1.2), degrees=10, isotropic=False, default_pad_value='otsu',
                      image_interpolation='NEAREST', keys=(defs.KEY_IMAGES, defs.KEY_LABELS), seed=seed),
     ])]
train_transforms = tfm.ComposeTransform(
    transforms_before_augmentation + transforms_augmentation + transforms_after_augmentation)
dataset.set_transform(train_transforms)
sample_torchio = dataset[sample_idx]
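
Note that the keys and seed arguments reflect the TorchIO version used when this example was written; newer TorchIO releases removed them in favor of tio.Subject and global seeding. A minimal sketch of the subject-based usage, assuming a recent TorchIO (the entry names images/labels are illustrative; TorchIO expects 4-D tensors following its (C, W, H, D) convention):

subject = tio.Subject(
    images=tio.ScalarImage(tensor=np.expand_dims(sample[defs.KEY_IMAGES], -1)),
    labels=tio.LabelMap(tensor=np.expand_dims(sample[defs.KEY_LABELS], (0, -1))),
)
flipped = tio.RandomFlip(axes=0, flip_probability=1.0)(subject)
images_flipped = flipped['images'].numpy().squeeze(-1)  # back to channel-first 2-D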

[8]:
# prepare and format the plot
fig, axs = plt.subplots(4, 3, figsize=(9, 12))
axs[0, 0].set_title('T1-weighted')
axs[0, 1].set_title('T2-weighted')
axs[0, 2].set_title('Label')
axs[0, 0].set_ylabel('None')
axs[1, 0].set_ylabel('pymia')
axs[2, 0].set_ylabel('batchgenerators')
axs[3, 0].set_ylabel('TorchIO')
plt.setp(axs, xticks=[], yticks=[])

axs[0, 0].imshow(sample[defs.KEY_IMAGES][0], cmap='gray')
axs[0, 1].imshow(sample[defs.KEY_IMAGES][1], cmap='gray')
axs[0, 2].imshow(sample[defs.KEY_LABELS], cmap='viridis')
axs[1, 0].imshow(sample_pymia[defs.KEY_IMAGES][0], cmap='gray')
axs[1, 1].imshow(sample_pymia[defs.KEY_IMAGES][1], cmap='gray')
axs[1, 2].imshow(sample_pymia[defs.KEY_LABELS], cmap='viridis')
axs[2, 0].imshow(sample_batchgenerators[defs.KEY_IMAGES][0], cmap='gray')
axs[2, 1].imshow(sample_batchgenerators[defs.KEY_IMAGES][1], cmap='gray')
axs[2, 2].imshow(sample_batchgenerators[defs.KEY_LABELS], cmap='viridis')
axs[3, 0].imshow(sample_torchio[defs.KEY_IMAGES][0], cmap='gray')
axs[3, 1].imshow(sample_torchio[defs.KEY_IMAGES][1], cmap='gray')
axs[3, 2].imshow(sample_torchio[defs.KEY_LABELS], cmap='viridis')
[8]:
<matplotlib.image.AxesImage at 0x7f7d44430f40>
_images/examples.augmentation.basic_16_1.png
Figure: T1-weighted, T2-weighted, and label slices (columns); the rows show no augmentation (None) and augmentation with pymia, batchgenerators, and TorchIO.

Visually, the differences between the non-transformed and the transformed images are clearly visible, as are the differences between the transformations of the three Python packages.
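
For training, the datasource with its augmentation transform is then typically consumed batch-wise, e.g., with a PyTorch DataLoader. A hedged sketch, assuming PyTorch is installed and using pymia's PyTorch backend (pymia.data.backends.pytorch):

import torch.utils.data as torch_data
import pymia.data.backends.pytorch as pymia_torch

dataset.set_transform(train_transforms)  # augmentation is applied on the fly per sample
pytorch_dataset = pymia_torch.PytorchDatasetAdapter(dataset)
loader = torch_data.DataLoader(pytorch_dataset, batch_size=2, shuffle=True)
batch = next(iter(loader))  # dictionary with batched, augmented entries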