"""Tools to get a skin surface from an MRI.
Example workflow, starting from
- an MRI volume `vol_array`, as a numpy array
- an associated 4x4 `affine` transform, as a numpy array
- a desired spherical coordinate system `origin` location inside the volume (e.g. the sonication target)
foreground_mask_array = compute_foreground_mask(vol_array)
foreground_mask_vtk_image = vtk_img_from_array_and_affine(foreground_mask_array, affine)
skin_mesh = create_closed_surface_from_labelmap(foreground_mask_vtk_image)
skin_interpolator = spherical_interpolator_from_mesh(skin_mesh, origin)
"""
from __future__ import annotations
from typing import Callable, List, Tuple
import numpy as np
import skimage.filters
import skimage.measure
import vtk
from packaging.version import parse
from scipy.interpolate import LinearNDInterpolator
from scipy.ndimage import distance_transform_edt
from vtk.util.numpy_support import numpy_to_vtk
from openlifu.geo import cartesian_to_spherical
[docs]
def apply_affine_to_polydata(affine:np.ndarray, polydata:vtk.vtkPolyData) -> vtk.vtkPolyData:
"""Apply an affine transform to a vtkPolyData."""
affine_vtkmat = vtk.vtkMatrix4x4()
for i in range(4):
for j in range(4):
affine_vtkmat.SetElement(i, j, affine[i, j])
affine_vtktransform = vtk.vtkTransform()
affine_vtktransform.SetMatrix(affine_vtkmat)
transform_filter = vtk.vtkTransformPolyDataFilter()
transform_filter.SetTransform(affine_vtktransform)
transform_filter.SetInputData(polydata)
transform_filter.Update()
return transform_filter.GetOutput()
[docs]
def take_largest_connected_component(mask: np.ndarray) -> np.ndarray:
"""Given a boolean image array (or any integer numpy array), return a mask of the largest connected component."""
mask_labeled = skimage.measure.label(mask)
connected_component_info = skimage.measure.regionprops(mask_labeled)
largest_connected_componet_label = connected_component_info[np.argmax([rp.area for rp in connected_component_info])].label
return (mask_labeled == largest_connected_componet_label)
[docs]
def compute_foreground_mask(
vol_array : np.ndarray,
closing_radius : float = 9.,
lower_quantile_for_otsu_threshold : float = 0.02,
upper_quantile_for_otsu_threshold : float = 0.99,
) -> np.ndarray:
"""Given a 3D image array, return a boolean mask representing the "foreground."
Args:
vol_array: a 3D image array of shape (H,W,D)
closing_readius: the radius of the ball used in the morphological closing operation
lower_quantile_for_otsu_threshold: a number from 0 to 1. Before otsu thresholding,
values below this quantile are omitted from the histogram as outliers.
upper_quantile_for_otsu_threshold: a number from 0 to 1. Before otsu thresholding,
values above this quantile are omitted from the histogram as outliers.
Returns: a boolean array of shape (H,W,D) representing a foreground mask
This is essentially a port of the BRAINSTools automated foreground masking algorithm.
- Original algorithm documentation: https://slicer.readthedocs.io/en/latest/user_guide/modules/brainsroiauto.html
- Original algorithm code: https://github.com/BRAINSia/BRAINSTools/tree/7c37d9e8c238f66f8a83f997d9c9bb659c494c90/BRAINSROIAuto
The algorithm roughly works as follows:
- step 1: otsu thresholding
- step 2: keep only the largest connected component
- step 3: morphological closing
- step 4: hole filling
The default values of the parameters have been observed to work well for mm-spaced brain MRIs.
"""
# step 1: otsu-threshold the image to create an initial foreground mask.
threshold_lower, threshold_upper = np.quantile(
vol_array,
[lower_quantile_for_otsu_threshold,upper_quantile_for_otsu_threshold]
)
threshold_foreground = skimage.filters.threshold_otsu(
vol_array[(vol_array >= threshold_lower) & (vol_array <= threshold_upper)]
)
foreground_mask = vol_array >= threshold_foreground
# step 2: keep only the largest connected component to throw out spurious bits.
foreground_mask = take_largest_connected_component(foreground_mask)
# step 3: do a morphological closing.
# while this does fill some holes, that's not the main point since step 4 already fills holes.
# the point of this step is rather to clean up and smooth out the skin surface of small cavities.
pad_width = int(closing_radius+2) # pad to avoid the situation where dilation hits the boundary
foreground_mask_padded = np.pad(foreground_mask, pad_width, mode='constant')
background_edt = distance_transform_edt(~foreground_mask_padded)
foreground_dilated = background_edt <= closing_radius
foreground_dilated_edt = distance_transform_edt(foreground_dilated)
foreground_closed = foreground_dilated_edt >= closing_radius
# crop to undo the padding above
h,w,d = foreground_mask.shape
p = pad_width
foreground_mask = foreground_closed[p:p+h,p:p+w,p:p+d]
# step 4: take the complement of the largest connected component of the current background.
# the background mask at this point contains the "actual background" and possibly also some
# holes that are inside the foreground region. the largest connected component of this background
# mask is considered to be the "actual background."" this step therefore serves to fill
# any remaining holes in the foreground mask.
# this step is analogous to the seeded flood fill in the original algorithm:
# https://github.com/BRAINSia/BRAINSTools/blob/7c37d9e8c238f66f8a83f997d9c9bb659c494c90/BRAINSCommonLib/itkLargestForegroundFilledMaskImageFilter.hxx#L255-L302
foreground_mask = ~take_largest_connected_component(~foreground_mask)
return foreground_mask
[docs]
def vtk_img_from_array_and_affine(vol_array:np.ndarray, affine:np.ndarray) -> vtk.vtkImageData:
""" Convert a numpy (array, affine) pair into a vtkImageData.
Args:
vol_array: a 3D image numpy array with float type data
affine: a numpy array of shape (4,4) representing the affine matrix of the 3D image.
Returns: vtkImageData with a copy of vol_array as the underlying image data,
and with origin, spacing, and direction matrix set according to the affine matrix.
Since a vtkImageData is intended to represent image data on a structured grid with *orthogonal* axes,
the upper-left 3x3 submatrix of the affine matrix should be an orthogonal matrix. There will be no error
if it isn't, since the "direction matrix" of a vtkImageData can be set to be non-orthogonal -- it just isn't
the intended usage of vtkImageData and could be misinterpreted by downstream vtk filters.
Maintain a reference to vol_array, so that it is not garbage collected
(which could leave the vtkImageData pointing to invalid memory -- see vtk.util.numpy_support.numpy_to_vtk documentation).
"""
matrix_3x3 = affine[:3, :3]
origin = affine[:3, 3]
spacing = np.linalg.norm(matrix_3x3, axis=0)
direction_matrix = matrix_3x3 / spacing[np.newaxis,:]
vtk_img = vtk.vtkImageData()
vtk_img.SetDimensions(vol_array.shape)
vtk_img.SetOrigin(origin.tolist())
vtk_img.SetSpacing(spacing.tolist())
direction_matrix_vtk = vtk.vtkMatrix3x3()
for i in range(3):
for j in range(3):
direction_matrix_vtk.SetElement(i, j, direction_matrix[i, j])
vtk_img.SetDirectionMatrix(direction_matrix_vtk)
vol_array_flat = vol_array.transpose((2,1,0)).ravel(order='C')
vol_array_vtk = numpy_to_vtk(num_array=vol_array_flat, deep=True, array_type=vtk.VTK_FLOAT)
vtk_img.GetPointData().SetScalars(vol_array_vtk)
return vtk_img
[docs]
def affine_from_vtk_image_data(vtk_img:vtk.vtkImageData) -> np.ndarray:
"""Get a 4x4 affine matrix out of a vtkImageData, a partial reverse to `vtk_img_from_array_and_affine`"""
origin = np.array(vtk_img.GetOrigin())
spacing = np.array(vtk_img.GetSpacing())
direction_vtk = vtk_img.GetDirectionMatrix()
direction = np.eye(3)
for i in range(3):
for j in range(3):
direction[i, j] = direction_vtk.GetElement(i, j)
affine = np.eye(4, dtype=float)
affine[:3, :3] = direction @ np.diag(spacing)
affine[:3, 3] = origin
return affine
[docs]
def create_closed_surface_from_labelmap(
binary_labelmap:vtk.vtkImageData,
decimation_factor:float=0.,
smoothing_factor:float=0.5
) -> vtk.vtkPolyData:
""" Create a surface mesh vtkPolyData from a binary labelmap vtkImageData.
Args:
binary_labelmap: input vtkImageData binary labelmap
decimation_factor: 0.0 for no decimation, 1.0 for maximum reduction.
smoothing_factor: 0.0 for no smoothing, 1.0 for maximum smoothing.
Returns:
vtkPolyData: the resulting surface mesh
The algorithm here is based on the labelmap-to-closed-surface algorithm in 3D Slicer:
https://github.com/Slicer/Slicer/blob/677932127c73a6c78654d4afd9458a655a4eef63/Libs/vtkSegmentationCore/vtkBinaryLabelmapToClosedSurfaceConversionRule.cxx#L246-L476
"""
affine = None # Only needed if vtk version is less than 9.3.0
if parse(vtk.__version__) < parse("9.3.0"):
# In these older versions of vtk, the labelmap would not work.
affine = affine_from_vtk_image_data(binary_labelmap)
binary_labelmap.SetOrigin([0,0,0])
binary_labelmap.SetSpacing([1,1,1])
direction_matrix_vtk = vtk.vtkMatrix3x3()
direction_matrix_vtk.Identity()
binary_labelmap.SetDirectionMatrix(direction_matrix_vtk)
# step 1: pad by 1 pixel all around with 0s, to ensure that the surface is still closed
# even if the labelmap runs up against the image boundary.
padder = vtk.vtkImageConstantPad()
padder.SetInputData(binary_labelmap)
extent = binary_labelmap.GetExtent()
padder.SetOutputWholeExtent(
extent[0] - 1, extent[1] + 1,
extent[2] - 1, extent[3] + 1,
extent[4] - 1, extent[5] + 1,
)
padder.Update()
padded_labelmap = padder.GetOutput()
# step 1: extract surface
flying_edges = vtk.vtkDiscreteFlyingEdges3D()
flying_edges.SetInputData(padded_labelmap)
flying_edges.ComputeGradientsOff()
flying_edges.ComputeNormalsOff()
flying_edges.Update()
surface_mesh = flying_edges.GetOutput()
# step 2: decimation
if decimation_factor > 0.0:
decimator = vtk.vtkDecimatePro()
decimator.SetInputData(surface_mesh)
decimator.SetFeatureAngle(60)
decimator.SplittingOff()
decimator.PreserveTopologyOn()
decimator.SetMaximumError(1)
decimator.SetTargetReduction(decimation_factor)
decimator.Update()
surface_mesh = decimator.GetOutput()
# step 3: smoothing
if smoothing_factor > 0.0:
smoother = vtk.vtkWindowedSincPolyDataFilter()
smoother.SetInputData(surface_mesh)
# map smoothing factor to passband and iterations, copying the approach taken by Slicer
passband = pow(10.0, -4.0 * smoothing_factor)
num_iterations = 20 + int(smoothing_factor * 40)
smoother.SetNumberOfIterations(num_iterations)
smoother.SetPassBand(passband)
smoother.BoundarySmoothingOff()
smoother.FeatureEdgeSmoothingOff()
smoother.NonManifoldSmoothingOn()
smoother.NormalizeCoordinatesOn()
smoother.Update()
surface_mesh = smoother.GetOutput()
# step 4: compute normals
normals = vtk.vtkPolyDataNormals()
normals.SetInputData(surface_mesh)
normals.ConsistencyOn()
normals.SplittingOff()
normals.Update()
surface_mesh = normals.GetOutput()
if parse(vtk.__version__) < parse("9.3.0"):
# In these older versions of vtk, the labelmap internal affine transform is not used correctly,
# so we manually apply the transform after the fact
surface_mesh = apply_affine_to_polydata(
affine,
surface_mesh,
)
# Some scalars can get tacked on by the above processing for some reason, so remove those in case they are present
surface_mesh.GetPointData().SetScalars(None)
return surface_mesh
[docs]
def spherical_interpolator_from_mesh(
surface_mesh: vtk.vtkPolyData,
origin: Tuple[float, float, float] = (0.,0.,0.),
xyz_direction_columns: np.ndarray | None = None,
dist_tolerance: float = 0.0001
) -> Callable[[float, float], float]:
"""Create a spherical interpolator from a vtkPolyData.
Here a "spherical interpolator" is a function that maps angles from a spherical coordinate system
to r values (radial spherical coordinate values) by interpolating over a set of known values.
It's essentially a "spherical plotter."
Args:
surface_mesh: The mesh containing the points to be interpolated over
origin: The origin of the spherical coordinate system
xyz_direction_columns: A matrix of shape (3,3) the columns of which are unit vectors that describe
the cartesian x,y,z axis directions on which to base the spherical coordinate system. For example
the spherical azimuthal angle is the polar angle of the projection of the point into the x-y-plane, etc.
See the documentation on `spherical_to_cartesian` and `cartesian_to_spherical` for a complete description
of how the spherical angles relate to the x, y, and z axes. If not provided, the xyz_direction_columns will
be an identity matrix, which means that the coordinates in which surface_mesh is given will directly be
interpreted as the x,y,z upon which a spherical coordinate system will be based.
dist_tolerance: A vertex of the surface_mesh will only be included if it is the furthest point from the origin
that is on the mesh along the ray emanating from the origin and passing through the vertex. The
dist_tolerance is the threshold for determining whether an intersection of the ray with the mesh
counts as being a distinct further out point from the vertex.
Returns:
A spherical interpolator, which is a callable that maps (theta,phi) pairs of spherical coordinates (phi being azimuthal)
to r values (radial spherical coordinate values). The angles are in radians.
Summary of the algorithm:
- Transform the input mesh based on the desired origin and orientation of the spherical coordinate system.
- We will gather some points into a set $S$. For each point $P$ on the mesh consider the ray $\\vec{OP}$
from the origin through $P$ and look at all the intersections of this ray $\\vec{OP}$ with the mesh.
If none of those intersections are further out from the origin than $P$ is, then we put $P$ into our set $S$.
- Using the spherical coordinates of the points in $S$, build a `scipy.interpolate.LinearNDInterpolator`
that interpolates spherical $r$ values from the spherical $(\\theta,\\phi)$ values.
- Problem: All the gathered $(\\theta,\\phi)$ values are likely strictly inside the square $[0,\\pi]\\times[-\\pi,\\pi]$,
and `LinearNDInterpolator` does not _extrapolate_, and so angles close to the "seams" of the spherical coordinate
system (the boundaries of that square) generate NaNs through the interpolator. The solution used here is to first
clone the gathered points with appropriate angular shifts so as to cover those seams, and then give that larger set
of points to the interpolator.
- Return the interpolator.
"""
if xyz_direction_columns is None:
xyz_direction_columns = np.eye(3, dtype=float)
xyz_affine = np.eye(4)
xyz_affine[:3,:3] = xyz_direction_columns
xyz_affine[:3,3] = origin
# Now xyz_affine is a coordinate transformation matrix that transforms from the xyz system to the coord system of the vtkPolyData
# We want to apply the inverse to the vtkPolyData
xyz_affine_inverse = np.linalg.inv(xyz_affine)
xyz_affine_inverse_vtkmat = vtk.vtkMatrix4x4()
xyz_affine_inverse_vtkmat.DeepCopy(xyz_affine_inverse.ravel())
xyz_inverse_transform = vtk.vtkTransform()
xyz_inverse_transform.SetMatrix(xyz_affine_inverse_vtkmat)
transform_filter = vtk.vtkTransformPolyDataFilter()
transform_filter.SetTransform(xyz_inverse_transform)
transform_filter.SetInputData(surface_mesh)
transform_filter.Update()
surface_mesh_transformed = transform_filter.GetOutput()
spherical_coords_on_mesh : List[Tuple[float,float,float]] = []
points = surface_mesh_transformed.GetPoints()
# The farthest point from the origin is this far out:
r_max = np.max([np.sqrt(np.sum(np.array(points.GetPoint(i))**2)) for i in range(points.GetNumberOfPoints())])
sqdist_tolerance = dist_tolerance**2
locator = vtk.vtkCellLocator() # Tried vtkOBBTree and it seems vtkCellLocator is much faster for this application
locator.SetDataSet(surface_mesh_transformed)
locator.BuildLocator()
for i in range(points.GetNumberOfPoints()):
point = np.array(points.GetPoint(i))
point_r_squared = np.sum(point**2)
# A point that is distance 2*r_max from the origin along the same ray as `point`
# We will check for intersections along the line segment from `point` to `distant_point_along_same_ray_as_point`
# The distance 2*r_max is chosen just to ensure that the line segment captures any possible intersection in the infinite
# ray emanating from `point` outward
distant_point_along_same_ray_as_point = (2*r_max/np.sqrt(point_r_squared)) * point
intersection_points = vtk.vtkPoints()
cell_ids = vtk.vtkIdList()
locator.IntersectWithLine(
point, # p1
distant_point_along_same_ray_as_point, # p2
0., # tol
intersection_points, # points
cell_ids, # cellIds
)
point_is_the_furthest_out = True
for j in range(intersection_points.GetNumberOfPoints()):
intersection_point = np.array(intersection_points.GetPoint(j))
sqdist = np.sum((point-intersection_point)**2) # squared distance from point to intersection point
if sqdist > sqdist_tolerance:
point_is_the_furthest_out = False
break
if point_is_the_furthest_out:
spherical_coords_on_mesh.append(cartesian_to_spherical(*point)) # append the (r, theta, phi) triple
spherical_coords_on_mesh = np.array(spherical_coords_on_mesh)
# We clone the points with a +/- 2pi translation in the phi (azimuthal) coordinate, creating 3 times as many points
# This will help the LinearNDInterpolator to better handle phi values as they wrap around
spherical_coords_on_mesh = np.concatenate(
[
spherical_coords_on_mesh,
spherical_coords_on_mesh + np.array([0.,0.,2*np.pi]).reshape((1,3)), # add 2pi to phi coordinate
spherical_coords_on_mesh - np.array([0.,0.,2*np.pi]).reshape((1,3)), # subtract 2pi from phi coordinate
],
axis=0
)
# We clone the points with a pi translation in the phi (azimuthal) coordinate and suitable flips in theta,
# creating another 3 times as many points. This will help the LinearNDInterpolator to better
# handle theta values close to the poles (theta=0 and theta=pi).
spherical_coords_on_mesh = np.concatenate(
[
spherical_coords_on_mesh,
# theta |--> -theta, phi |--> phi+pi, introduces negative theta values
(spherical_coords_on_mesh * np.array([1.,-1.,1.]).reshape((1,3))) + np.array([0.,0.,np.pi]).reshape((1,3)),
# theta |--> 2pi-theta, phi |--> phi+pi, introduces theta values greater than pi
(spherical_coords_on_mesh * np.array([1.,-1.,1.]).reshape((1,3))) + np.array([0.,2*np.pi,np.pi]).reshape((1,3)),
],
axis=0
)
interpolator = LinearNDInterpolator(
points = spherical_coords_on_mesh[:,[1,2]], # The (theta, phi) spherical coordinates
values = spherical_coords_on_mesh[:,0], # The r spherical coordinates
)
return interpolator