from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import (
Annotated,
Any,
Callable,
Dict,
List,
Sequence,
Tuple,
)
import numpy as np
import vtk
from openlifu.geo import (
cartesian_to_spherical,
cartesian_to_spherical_vectorized,
spherical_coordinate_basis,
spherical_to_cartesian,
spherical_to_cartesian_vectorized,
)
from openlifu.seg.skinseg import (
apply_affine_to_polydata,
compute_foreground_mask,
create_closed_surface_from_labelmap,
spherical_interpolator_from_mesh,
vtk_img_from_array_and_affine,
)
from openlifu.util.annotations import OpenLIFUFieldData
from openlifu.util.dict_conversion import DictMixin
from openlifu.util.units import getunitconversion
# Module-level logger shared by the virtual fitting routines.
log = logging.getLogger("VirtualFit")

# Orthogonal matrix taking RAS (Right-Anterior-Superior) components to
# ASL (Anterior-Superior-Left) components: A <- A, S <- S, L <- -R.
ras2asl_3x3 = np.array(
    [
        [0, 1, 0],
        [0, 0, 1],
        [-1, 0, 0],
    ],
    dtype=float,
)

# The rows above are orthonormal, so the transpose is the inverse map (ASL -> RAS).
asl2ras_3x3 = ras2asl_3x3.T
[docs]
@dataclass
class VirtualFitOptions(DictMixin):
    """Parameters to configure the `virtual_fit` algorithm.

    The terms 'pitch' and 'yaw' used here refer to the following target-centric angular coordinates in patient space:
        pitch: The angle between the anterior axis through the target and the ray from the target to the projection of
            a given point into the anterior-superior plane.
        yaw: The angle between the anterior-superior plane through the target and the ray from the target to a given point.

    Another way to describe them in terms of standard spherical coordinates centered at the target in ASL (anterior-superior-left) space:
        pitch: The azimuthal spherical coordinate.
        yaw: 90 degrees minus the polar spherical coordinate.
    """

    units: Annotated[str, OpenLIFUFieldData("Length units", "The units of length used in the length attributes of this class")] = "mm"
    """The units of length used in the length attributes of this class"""

    transducer_steering_center_distance: Annotated[float, OpenLIFUFieldData("Steering center distance", "Distance from the transducer origin axially to the center of the steering zone in the units `units`")] = 50.
    """Distance from the transducer origin axially to the center of the steering zone in the units `units`"""

    steering_limits: Annotated[Tuple[Tuple[float, float], Tuple[float, float], Tuple[float, float]],
        OpenLIFUFieldData("Steering limits", "Steering bounds along each axis from the transducer origin, in the units `units`")] = ((-50, 50), (-50, 50), (-50, 50))
    """Steering bounds, as one (min, max) pair per axis, in the units `units`"""

    pitch_range: Annotated[Tuple[float, float], OpenLIFUFieldData("Pitch range (deg)", "Range of pitches to include in the transducer fitting search grid, in degrees")] = (-10, 150)
    """Range of pitches to include in the transducer fitting search grid, in degrees"""

    pitch_step: Annotated[float, OpenLIFUFieldData("Pitch step size (deg)", "Pitch step size when forming the transducer fitting search grid, in degrees")] = 5
    """Pitch step size when forming the transducer fitting search grid, in degrees"""

    yaw_range: Annotated[Tuple[float, float], OpenLIFUFieldData("Yaw range (deg)", "Range of yaws to include in the transducer fitting search grid, in degrees")] = (-65, 65)
    """Range of yaws to include in the transducer fitting search grid, in degrees"""

    yaw_step: Annotated[float, OpenLIFUFieldData("Yaw step size (deg)", "Yaw step size when forming the transducer fitting search grid, in degrees")] = 5
    """Yaw step size when forming the transducer fitting search grid, in degrees"""

    planefit_dyaw_extent: Annotated[float, OpenLIFUFieldData("Plane fit yaw extent", "Left and right extents of the point grid to be used for plane fitting along the local yaw axes, in units of `units`")] = 15
    """Left and right extents of the point grid to be used for plane fitting along the local yaw axes,
    in units of `units`. The plane fitting point grid will be twice this size, since this is left
    and right extents. (Note that this has units of length, not angle!)"""

    planefit_dyaw_step: Annotated[float, OpenLIFUFieldData("Plane fit yaw step", "Local yaw axis step size to use when constructing plane fitting grids. In spatial units of `units`")] = 3
    """Local yaw axis step size to use when constructing plane fitting grids. In spatial units of `units`."""

    planefit_dpitch_extent: Annotated[float, OpenLIFUFieldData("Plane fit pitch extent", "Left and right extents of the point grid to be used for plane fitting along the local pitch axes, in spatial units of `units`")] = 15
    """Left and right extents of the point grid to be used for plane fitting along the local pitch axes,
    in spatial units of `units`. The plane fitting point grid will be twice this size, since this is left
    and right extents."""

    planefit_dpitch_step: Annotated[float, OpenLIFUFieldData("Plane fit pitch step", "Local pitch axis step size to use when constructing plane fitting grids. In spatial units of `units`")] = 3
    """Local pitch axis step size to use when constructing plane fitting grids. In spatial units of `units`."""

    def to_units(self, target_units: str) -> VirtualFitOptions:
        """Do unit conversion and return a version of this VirtualFitOptions that uses
        `target_units` as the units for all attributes that have units of length.

        Angular attributes (pitch/yaw ranges and steps, in degrees) are copied unchanged.
        This instance is not modified.
        """
        conversion_factor = getunitconversion(from_unit=self.units, to_unit=target_units)
        return VirtualFitOptions(
            units=target_units,
            transducer_steering_center_distance=conversion_factor * self.transducer_steering_center_distance,
            steering_limits=tuple(map(tuple, conversion_factor * np.array(self.steering_limits))),
            pitch_range=self.pitch_range,
            pitch_step=self.pitch_step,
            yaw_range=self.yaw_range,
            yaw_step=self.yaw_step,
            planefit_dyaw_extent=conversion_factor * self.planefit_dyaw_extent,
            planefit_dyaw_step=conversion_factor * self.planefit_dyaw_step,
            planefit_dpitch_extent=conversion_factor * self.planefit_dpitch_extent,
            planefit_dpitch_step=conversion_factor * self.planefit_dpitch_step,
        )

    @staticmethod
    def from_dict(parameter_dict: Dict[str, Any]) -> VirtualFitOptions:  # Override DictMixin here
        """Build a VirtualFitOptions from a dictionary of constructor keyword arguments.

        Sequence-valued entries (`pitch_range`, `yaw_range`, `steering_limits`) are converted
        to (nested) tuples. Keys that are absent fall back to the dataclass defaults.

        Args:
            parameter_dict: Mapping of attribute names to values. It is not modified.

        Returns: The constructed VirtualFitOptions.
        """
        # Work on a shallow copy so the caller's dictionary is left untouched.
        parameters = dict(parameter_dict)
        for range_key in ("pitch_range", "yaw_range"):
            if range_key in parameters:
                parameters[range_key] = tuple(parameters[range_key])
        if "steering_limits" in parameters:
            parameters["steering_limits"] = tuple(map(tuple, parameters["steering_limits"]))
        return VirtualFitOptions(**parameters)
[docs]
def compute_skin_mesh_from_volume(
    volume_array: np.ndarray,
    volume_affine_RAS: np.ndarray,
) -> vtk.vtkPolyData:
    """Build a closed skin surface mesh from an image volume.

    A foreground mask is computed from `volume_array`, wrapped as a VTK image
    together with `volume_affine_RAS`, and then converted to a closed surface.

    Args:
        volume_array: The image volume to segment.
        volume_affine_RAS: Affine transform associated with `volume_array`.

    Returns: The closed surface mesh produced from the foreground mask.
    """
    log.info("Computing foreground mask...")
    mask_image = vtk_img_from_array_and_affine(
        compute_foreground_mask(volume_array),
        volume_affine_RAS,
    )

    log.info("Creating closed surface from labelmap...")
    return create_closed_surface_from_labelmap(mask_image)
[docs]
@dataclass
class VirtualFitDebugInfo:
    """Container for diagnostic data produced alongside a `virtual_fit` result."""

    skin_mesh: vtk.vtkPolyData
    """Skin surface mesh the virtual fit was run against."""

    spherically_interpolated_mesh: vtk.vtkPolyData
    """Mesh visualizing the spherical interpolator used during virtual fitting."""

    search_points: np.ndarray
    """(N,3) array with the coordinates of every candidate point that was tried."""

    plane_normals: np.ndarray
    """(N,3) array with the normal of the plane fitted at each entry of `search_points`."""

    steering_dists: np.ndarray
    """(N,) array with the steering distance computed for each entry of `search_points`."""

    in_bounds: np.ndarray
    """(N,) boolean array telling, for each entry of `search_points`, whether the target was
    found to be in bounds for that candidate transducer placement."""
[docs]
def sphere_from_interpolator(
    interpolator: Callable[[float, float], float],
    theta_res: int = 50,
    phi_res: int = 50,
) -> vtk.vtkPolyData:
    """Turn a spherical interpolator into a mesh so that it can be inspected visually.

    A radius-1 sphere is tessellated and each vertex is rescaled by the value the
    interpolator returns for that vertex's (theta, phi). Debugging utility.

    Args:
        interpolator: Callable taking (theta, phi) and returning a radius.
        theta_res: Theta resolution passed to the VTK sphere source.
        phi_res: Phi resolution passed to the VTK sphere source.

    Returns: The deformed sphere, after being run through a normals filter.
    """
    source = vtk.vtkSphereSource()
    source.SetRadius(1.0)
    source.SetThetaResolution(theta_res)
    source.SetPhiResolution(phi_res)
    source.Update()

    mesh = source.GetOutput()
    vertices = mesh.GetPoints()
    for index in range(vertices.GetNumberOfPoints()):
        vertex = np.array(vertices.GetPoint(index))
        # Only the angles of the vertex matter; its radius is replaced by the interpolated one.
        _, theta, phi = cartesian_to_spherical(*vertex)
        vertices.SetPoint(index, interpolator(theta, phi) * vertex)

    normals = vtk.vtkPolyDataNormals()
    normals.SetInputData(mesh)
    normals.Update()
    return normals.GetOutput()
[docs]
def run_virtual_fit(
    units: str,
    target_RAS: Sequence[float],
    standoff_transform: np.ndarray,
    options: VirtualFitOptions,
    volume_array: np.ndarray | None = None,
    volume_affine_RAS: np.ndarray | None = None,
    skin_mesh: vtk.vtkPolyData | None = None,
    include_debug_info: bool = False,
    progress_callback: Callable[[int, str], None] | None = None,
) -> List[np.ndarray] | Tuple[List[np.ndarray], VirtualFitDebugInfo]:
    """Run patient-specific "virtual fitting" algorithm, suggesting a series of candidate transducer
    transforms for optimal sonication of a given target.

    Provide either a `volume_array` and `volume_affine_RAS`, or a `skin_mesh`.

    Args:
        units: The spatial units of the RAS space into which volume_affine_RAS maps
        target_RAS: A 3D point, in the coordinates and units of `volume_affine_RAS` (the `units` argument)
        standoff_transform: See the documentation of `create_standoff_transform` or
            `Transducer.standoff_transform` for the meaning of this. Here it should be provided in the
            units `units`. The method `Transducer.get_standoff_transform_in_units` is useful for getting this.
        options : Virtual fitting algorithm configuration. See the `VirtualFitOptions` documentation.
        volume_array: A 3D volume MRI
        volume_affine_RAS: A 4x4 affine transform that maps `volume_array` into RAS space with certain units
        skin_mesh: Optional pre-computed closed surface mesh. If provided, `volume_array` and
            `volume_affine_RAS` can be omitted. The provided skin mesh should be in RAS space, with units
            being the provided `units` arg. The function `compute_skin_mesh_from_volume` can be used to pre-compute
            a skin mesh.
        include_debug_info: Whether to include debugging info in the return value. Disabled by default because some of the debugging
            info takes some time to compute.
        progress_callback: An optional function that will be called to report progress. The function should accept two arguments:
            an integer progress value from 0 to 100 followed by a string message describing the step currently being worked on.

    Returns: A list of transducer transform candidates sorted starting from the best-scoring one
        (smallest steering distance first; only candidates for which the target is within the steering
        limits are included). The transforms map transducer space into LPS space, and they are in the
        same units as the RAS space of `volume_affine_RAS` (aka the `units` argument).
        If `include_debug_info` is True, a tuple `(transforms, debug_info)` is returned instead, where
        `debug_info` is a `VirtualFitDebugInfo`. Its per-search-point arrays keep their zero/False
        initial values at search points that were skipped due to degenerate geometry.

    Raises:
        ValueError: If `skin_mesh` is None and `volume_array` or `volume_affine_RAS` is also None.
    """
    if progress_callback is None:
        def progress_callback(progress_percent: int, step_description: str):  # noqa: ARG001
            pass  # Define it to be a no-op if no progress_callback was provided.

    progress_callback(0, "Starting virtual fit")

    # Express all virtual fit options in the units of volume_affine_RAS, i.e. the physical space of the volume
    options = options.to_units(units)
    pitch_range = options.pitch_range
    pitch_step = options.pitch_step
    yaw_range = options.yaw_range
    yaw_step = options.yaw_step
    transducer_steering_center_distance = options.transducer_steering_center_distance
    steering_limits = options.steering_limits
    planefit_dyaw_extent = options.planefit_dyaw_extent
    planefit_dyaw_step = options.planefit_dyaw_step
    planefit_dpitch_extent = options.planefit_dpitch_extent
    planefit_dpitch_step = options.planefit_dpitch_step

    if skin_mesh is None:
        if volume_array is None or volume_affine_RAS is None:
            raise ValueError("Both `volume_array` and `volume_affine_RAS` must be provided if `skin_mesh` is None.")
        log.info("Computing skin mesh...")
        progress_callback(0, "Computing skin mesh")
        skin_mesh = compute_skin_mesh_from_volume(volume_array, volume_affine_RAS)
    else:
        log.info("Using provided skin mesh.")

    log.info("Building skin interpolator...")
    progress_callback(5, "Building skin interpolator")
    skin_interpolator = spherical_interpolator_from_mesh(
        surface_mesh=skin_mesh,
        origin=target_RAS,
        xyz_direction_columns=asl2ras_3x3,  # surface mesh was in RAS, so here spherical coordinates are placed on ASL space
    )

    # Useful transforms to and from the skin_interpolator ASL space and between RAS and LPS
    # Note that ASL is a left-handed coordinate system while RAS and LPS are right-handed.
    interpolator2ras = np.eye(4)
    interpolator2ras[:3, :3] = asl2ras_3x3
    interpolator2ras[:3, 3] = target_RAS
    ras_lps_swap = np.diag([-1., -1, 1, 1])
    interpolator2lps = ras_lps_swap @ interpolator2ras

    # Useful arrays for vectorized comparisons
    steering_mins = np.array([sl[0] for sl in steering_limits], dtype=float)  # shape (3,). It is the lat,ele,ax steering min
    steering_maxs = np.array([sl[1] for sl in steering_limits], dtype=float)  # shape (3,). It is the lat,ele,ax steering max

    log.info("Searching through candidate transducer poses...")
    progress_callback(50, "Searching through poses")

    # Construct search grid
    theta_sequence = np.arange(90 - yaw_range[-1], 90 - yaw_range[0], yaw_step)
    phi_sequence = np.arange(pitch_range[0], pitch_range[-1], pitch_step)
    theta_grid, phi_grid = np.meshgrid(theta_sequence, phi_sequence, indexing="ij")  # each has shape (number of thetas, number of phis)
    num_thetas, num_phis = theta_grid.shape
    num_search_points = num_thetas * num_phis
    thetas = theta_grid.reshape(num_search_points)
    phis = phi_grid.reshape(num_search_points)

    # Things that will be computed over the search grid
    transducer_poses = np.empty((num_search_points, 4, 4), dtype=float)
    in_bounds = np.zeros(shape=num_search_points, dtype=bool)
    steering_dists = np.zeros(shape=num_search_points, dtype=float)

    # Additional debugging info that will be computed over the search grid
    points_asl = np.zeros((num_search_points, 3), dtype=float)  # search grid points in ASL coordinates
    normals_asl = np.zeros((num_search_points, 3), dtype=float)  # normal vector of the plane that is fitted at each point, in ASL coordinates

    # Loop-invariant quantities, built once rather than once per search point.
    # Plane fitting offsets along the local theta-hat and phi-hat directions (these are lengths, not angles):
    dtheta_sequence = np.arange(-planefit_dyaw_extent, planefit_dyaw_extent + planefit_dyaw_step, planefit_dyaw_step)
    dphi_sequence = np.arange(-planefit_dpitch_extent, planefit_dpitch_extent + planefit_dpitch_step, planefit_dpitch_step)
    dtheta_grid, dphi_grid = np.meshgrid(dtheta_sequence, dphi_sequence, indexing='ij')
    dtheta_offsets = dtheta_grid[..., np.newaxis]  # shape (num dthetas, num dphis, 1)
    dphi_offsets = dphi_grid[..., np.newaxis]  # shape (num dthetas, num dphis, 1)
    # Location of the steering zone center in transducer coordinates (lat, ele, ax)
    steering_center_XYZ = np.array([0., 0., transducer_steering_center_distance])
    # Target in homogeneous LPS coordinates
    target_LPS_homogeneous = ras_lps_swap @ np.array([*target_RAS, 1.0])

    for i in range(num_search_points):
        theta_rad, phi_rad = thetas[i] * np.pi / 180, phis[i] * np.pi / 180

        # Cartesian coordinate location of the point at which we are fitting a plane
        point = np.array(spherical_to_cartesian(skin_interpolator(theta_rad, phi_rad), theta_rad, phi_rad))

        # Build plane fitting grid in the spherical coordinate basis theta-phi plane, which we will later project back onto the skin surface
        r_hat, theta_hat, phi_hat = spherical_coordinate_basis(theta_rad, phi_rad)
        planefit_points_unprojected_cartesian = (
            point.reshape((1, 1, 3))
            + dtheta_offsets * theta_hat.reshape(1, 1, 3)  # shape (num dthetas, num dphis, 3)
            + dphi_offsets * phi_hat.reshape(1, 1, 3)  # shape (num dthetas, num dphis, 3)
        )  # shape (num dthetas, num dphis, 3)
        planefit_points_unprojected_spherical = cartesian_to_spherical_vectorized(
            planefit_points_unprojected_cartesian
        )  # shape (num dthetas, num dphis, 3)
        skin_projected_r_values = skin_interpolator(planefit_points_unprojected_spherical[..., 1:])  # shape (num dthetas, num dphis) # TODO adjust docstrings to demand a *vectorizable* spherical interpolator
        planefit_points_cartesian = spherical_to_cartesian_vectorized(  # Could instead renormalize planefit_points_unprojected_cartesian, not sure if it would give a speedup versus this
            np.stack([
                skin_projected_r_values,  # New r values after projection to skin
                planefit_points_unprojected_spherical[..., 1],  # Same old theta values
                planefit_points_unprojected_spherical[..., 2],  # Same old phi values
            ], axis=-1)
        )

        # Fit the best plane to these points among the planes that pass through `point`. Here we find the normal vector to the plane.
        # Tuple unpacking is used (rather than attribute access on the result) so that this works
        # whether `svd` returns a plain tuple or a named tuple.
        _, _, right_singular_vectors = np.linalg.svd(
            planefit_points_cartesian.reshape(-1, 3) - point.reshape(1, 3),
            full_matrices=False,  # we don't need the left-singular vectors anyway, so this speeds things up
        )
        plane_normal = right_singular_vectors[-1]  # The right-singular vector corresponding to the smallest singular value

        # Transducer axial axis: Parallel to plane_normal, but points towards rather than away from the origin.
        plane_normal_norm = np.linalg.norm(plane_normal)
        if plane_normal_norm < 1e-10:
            continue  # Bad geometry at this location, so it's not a virtual fit candidate
        transducer_z = - np.sign(np.dot(plane_normal, point)) * plane_normal / plane_normal_norm

        # Transducer elevational axis: Phi-hat, but then with its component along transducer_z eliminated. This orients the transducer "up" if this were forehead, for example.
        transducer_y = phi_hat - np.dot(phi_hat, transducer_z) * transducer_z
        transducer_y_norm = np.linalg.norm(transducer_y)
        if transducer_y_norm < 1e-10:
            continue  # Bad geometry at this location, so it's not a virtual fit candidate
        transducer_y = transducer_y / transducer_y_norm

        # Transducer lateral axis, here simply the only remaining choice to keep it a left handed coordinate system
        # (ASL is left-handed, so the transducer axes must be left-handed to make for an orientation-preserving transducer transform)
        transducer_x = np.cross(transducer_z, transducer_y)

        transducer_transform = np.array(
            [
                [*transducer_x, 0],
                [*transducer_y, 0],
                [*transducer_z, 0],
                [*point, 1],
            ],
            dtype=float
        ).transpose()

        # The transform moves the transducer into the ASL skin interpolator space.
        # We want a transform that moves the transducer into LPS space, and we also want to apply the standoff transform
        transducer_transform = interpolator2lps @ transducer_transform @ standoff_transform

        # Target in transducer coordinates (lat, ele, ax)
        target_XYZ = (np.linalg.inv(transducer_transform) @ target_LPS_homogeneous)[:3]

        # Target in "steering space", where the origin is the center of the steering zone.
        target_steering_space = target_XYZ - steering_center_XYZ
        steering_distance: float = float(np.linalg.norm(target_steering_space))

        # Check whether the target is in the steering range
        target_in_bounds: bool = bool(np.all((steering_mins < target_steering_space) & (target_steering_space < steering_maxs)))

        # Finally, fill out the arrays we have been building in this loop
        transducer_poses[i] = transducer_transform
        steering_dists[i] = steering_distance
        in_bounds[i] = target_in_bounds
        points_asl[i] = point
        normals_asl[i] = transducer_z

    # Keep only the in-bounds candidates, ordered by increasing steering distance.
    # (Skipped search points never set in_bounds, so their uninitialized poses are never returned.)
    sorted_transforms = [
        x[0] for x in sorted(zip(transducer_poses[in_bounds], steering_dists[in_bounds]), key=lambda x: x[1])
    ]

    log.info("Virtual fitting complete.")

    if include_debug_info:
        log.info("Generating debug meshes...")
        progress_callback(80, "Generating debug meshes")
        interpolator_mesh: vtk.vtkPolyData = sphere_from_interpolator(skin_interpolator, theta_res=100, phi_res=100)

        # A few things are in ASL coordinates, so we transform it to RAS space so that they are in the same coordinates as skin_mesh.
        interpolator_mesh = apply_affine_to_polydata(interpolator2ras, interpolator_mesh)
        points_asl_homogenized = np.concatenate([points_asl.T, np.ones((1, num_search_points))], axis=0)  # shape (4,num_search_points)
        points_ras = (interpolator2ras @ points_asl_homogenized)[:3].T  # back to shape (num_search_points,3)
        normals_ras = (asl2ras_3x3 @ (normals_asl.T)).T

        # After transforming the interpolator_mesh, the normals can end up flipped, so we fix it just in case
        normals_filter = vtk.vtkPolyDataNormals()
        normals_filter.SetInputData(interpolator_mesh)
        normals_filter.Update()
        interpolator_mesh = normals_filter.GetOutput()

        progress_callback(100, "Complete")
        return (
            sorted_transforms,
            VirtualFitDebugInfo(
                skin_mesh=skin_mesh,
                spherically_interpolated_mesh=interpolator_mesh,
                search_points=points_ras,
                plane_normals=normals_ras,
                steering_dists=steering_dists,
                in_bounds=in_bounds,
            ),
        )

    progress_callback(100, "Complete")
    return sorted_transforms