Source code for pymia.data.extraction.indexing

import abc
import typing

import numpy as np

import pymia.data.indexexpression as expr


[docs]class IndexingStrategy(abc.ABC): """Interface for indexing strategies that can be applied to images. .. automethod:: __call__ .. automethod:: __repr__ """
[docs] @abc.abstractmethod def __call__(self, shape: tuple) -> typing.List[expr.IndexExpression]: """Calculate the indexes for a given shape Args: shape (tuple): The shape to determine the indexes for. Returns: list: The list of :class:`.IndexExpression` instances defining the indexes for an image shape. """ pass
[docs] def __repr__(self) -> str: """ Returns: str: Representation of the strategy. Should include attributes such that it uniquely defines the strategy. """ return self.__class__.__name__
[docs]class EmptyIndexing(IndexingStrategy): """An empty indexing strategy. This is useful when a strategy is required but entire images should be extracted.""" def __call__(self, shape) -> typing.List[expr.IndexExpression]: return [expr.IndexExpression()]
[docs]class SliceIndexing(IndexingStrategy): def __init__(self, slice_axis: typing.Union[int, tuple] = 0) -> None: """Strategy to generate a slice-wise indexing. Args: slice_axis (int, tuple): The axis to be sliced. Multi-axis slicing can be achieved by providing a tuple of axes. """ if isinstance(slice_axis, int): slice_axis = (slice_axis, ) self.slice_axis = slice_axis def __call__(self, shape) -> typing.List[expr.IndexExpression]: indexing = [] for axis in self.slice_axis: indexing.extend(expr.IndexExpression(i, axis) for i in range(shape[axis])) return indexing def __repr__(self) -> str: return '{} ({})'.format(self.__class__.__name__, self.slice_axis)
[docs]class VoxelWiseIndexing(IndexingStrategy): def __init__(self, image_dimension: int = 3): """Strategy to generate indices for every voxel of an image. Args: image_dimension (int): The image dimension without the dimension of the voxels itself. """ self.shape = None self.indexing = None self.image_dimension = image_dimension def __call__(self, shape) -> typing.List[expr.IndexExpression]: if self.shape == shape: return self.indexing self.shape = shape # save for later comparison to avoid calculating indices if the shape is equal shape_without_voxel = shape[0:self.image_dimension] indices = np.indices(shape_without_voxel) indices = indices.reshape((indices.shape[0], np.prod(indices.shape[1:]))) indices = indices.transpose() self.indexing = [expr.IndexExpression(idx.tolist()) for idx in indices] return self.indexing
[docs]class PatchWiseIndexing(IndexingStrategy): def __init__(self, patch_shape: tuple, ignore_incomplete=True) -> None: """Strategy to generate indices for patches (sub-volumes) of an image. Args: patch_shape (tuple): The patch shape. ignore_incomplete (bool): If even division of image by patch shape ignore incomplete patch on True. Boundary condition. """ super().__init__() self.patch_shape = patch_shape self.image_dimension = len(patch_shape) self.ignore_incomplete = ignore_incomplete self.prev_shape = None self.prev_indexing = None def __call__(self, shape) -> typing.List[expr.IndexExpression]: if shape == self.prev_shape: return self.prev_indexing shape_without_voxel = shape[:self.image_dimension] index_count = np.divide(shape_without_voxel, self.patch_shape) index_count = np.floor(index_count) if self.ignore_incomplete else np.ceil(index_count) index_count = index_count.astype('int') indices = np.indices(index_count).reshape(index_count.size, -1).T index_ranges = np.stack([indices, indices + 1], axis=-1) index_ranges *= np.asarray(self.patch_shape)[np.newaxis, :, np.newaxis] indexing = [expr.IndexExpression(idx.tolist()) for idx in index_ranges] self.prev_indexing = indexing self.prev_shape = shape return indexing def __repr__(self) -> str: return '{} (patch shape={}, ignore incomplete={})'.format(self.__class__.__name__, self.patch_shape, self.ignore_incomplete)