# Original source (原文地址):
# https://examples.itk.org/src/registration/common/mutualinformation/mutualinformation#Resample-the-moving-image
# https://examples.itk.org/src/registration/common/mutualinformation/documentation
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import itk
# 2-D registration: both images are read as single-precision float images.
dim = 2
ImageType = itk.Image[itk.F, dim]
FixedImageType = ImageType
MovingImageType = ImageType

# Directory holding the ITK example data. The default is the original local
# ITK 5.3.0 checkout; set ITK_EXAMPLES_DATA to point at another location
# without editing the script.
baseDir = os.environ.get(
    "ITK_EXAMPLES_DATA", r'D:\learn\itk\itk5.3.0\ITK-5.3.0\Examples\Data'
)
# os.path.join instead of concatenating a hard-coded "\" separator, so the
# paths are also valid outside Windows.
fixed_img_path = os.path.join(baseDir, "BrainT1SliceBorder20.png")
moving_img_path = os.path.join(baseDir, "BrainProtonDensitySliceShifted13x17y.png")

# Fail early with a readable message rather than an opaque reader error.
for _path in (fixed_img_path, moving_img_path):
    if not os.path.isfile(_path):
        raise FileNotFoundError(f"ITK example image not found: {_path}")

fixed_img = itk.imread(fixed_img_path, itk.F)
moving_img = itk.imread(moving_img_path, itk.F)
# Pre-process both images the same way: intensity normalization followed by
# Gaussian smoothing (variance 2.0). The smoothed images are what the
# registrations below operate on.
fixed_normalized_image = itk.normalize_image_filter(fixed_img)
fixed_smoothed_image = itk.discrete_gaussian_image_filter(
    fixed_normalized_image, variance=2.0
)
moving_normalized_image = itk.normalize_image_filter(moving_img)
moving_smoothed_image = itk.discrete_gaussian_image_filter(
    moving_normalized_image, variance=2.0
)

# 2x2 preview: raw images on the top row, smoothed images on the bottom row.
_preview_panels = (
    (221, fixed_img, 'fixed_img'),
    (222, moving_img, 'moving_img'),
    (223, fixed_smoothed_image, 'fixed_smoothed_image'),
    (224, moving_smoothed_image, 'moving_smoothed_image'),
)
for _slot, _image, _label in _preview_panels:
    plt.subplot(_slot)
    plt.imshow(itk.GetArrayFromImage(_image), 'gray')
    plt.title(_label)
plt.show()
# ITK class aliases shared by the exhaustive scan and the gradient run.
# Pure translation transform with double-precision parameters.
TransformType = itk.TranslationTransform[itk.D, dim]
# Metric, registration method and interpolator all work on ImageType.
MetricType = itk.MutualInformationImageToImageMetric[ImageType, ImageType]
RegistrationType = itk.ImageRegistrationMethod[ImageType, ImageType]
InterpolatorType = itk.LinearInterpolateImageFunction[ImageType, itk.D]
# Two optimizers: one samples the metric on a grid, the other registers.
ExhaustiveOptimizerType = itk.ExhaustiveOptimizer
OptimizerType = itk.GradientDescentOptimizer
# Plot the MutualInformationImageToImageMetric surface.
# Sample every translation from -20 to +20 along x and along y; the step on
# each axis is window_size / n_steps (20 / 100 = 0.2).
window_size = [20, 20]  # move at most 20 pixels away from the initial position
n_steps = [100, 100]    # number of steps collected along each axis

# Fresh pipeline components for the exhaustive scan.
registrar = RegistrationType.New()
optimizer = ExhaustiveOptimizerType.New()
metric = MetricType.New()
transform = TransformType.New()
interpolator = InterpolatorType.New()

# Mutual-information metric: 100 spatial samples, standard deviation 0.4
# for both the fixed and the moving image.
metric.SetNumberOfSpatialSamples(100)
metric.SetFixedImageStandardDeviation(0.4)
metric.SetMovingImageStandardDeviation(0.4)

# Grid definition: per-axis step counts, and per-axis scales of
# window / steps. The scales are fetched, resized and written back.
optimizer.SetNumberOfSteps(n_steps)
scales = optimizer.GetScales()
scales.SetSize(2)
for axis, (extent, count) in enumerate(zip(window_size, n_steps)):
    scales.SetElement(axis, extent / count)
optimizer.SetScales(scales)

# Wire the smoothed images and the components into the registration method.
registrar.SetFixedImage(fixed_smoothed_image)
registrar.SetMovingImage(moving_smoothed_image)
registrar.SetOptimizer(optimizer)
registrar.SetTransform(transform)
registrar.SetInterpolator(interpolator)
registrar.SetMetric(metric)
registrar.SetFixedImageRegion(fixed_img.GetBufferedRegion())
registrar.SetInitialTransformParameters(transform.GetParameters())
# Collect data describing the parametric surface with an observer:
# surface maps a (tx, ty) translation tuple to the metric value there.
surface = {}


def print_iteration():
    """Store the metric value at the optimizer's current grid position.

    Despite the name, nothing is printed. Each call records one sample in
    ``surface``, keyed by the current parameters as a tuple so the key is
    hashable.
    """
    surface[tuple(optimizer.GetCurrentPosition())] = optimizer.GetCurrentValue()


optimizer.AddObserver(itk.IterationEvent(), print_iteration)
registrar.Update()

# Extremes found by the exhaustive scan, used as markers in the plots below.
max_position = list(optimizer.GetMaximumMetricValuePosition())
min_position = list(optimizer.GetMinimumMetricValuePosition())
max_val = optimizer.GetMaximumMetricValue()
min_val = optimizer.GetMinimumMetricValue()
print(max_position)
print(min_position)
print(max_val)
print(min_val)
# Set up values for the plot: the distinct, sorted sample coordinates along
# each parameter axis (index 0 = tx, index 1 = ty, as in the surface keys).
x_vals = [sorted({position[axis] for position in surface}) for axis in range(2)]
X, Y = np.meshgrid(x_vals[0], x_vals[1])
# np.meshgrid(xs, ys) lays xs along columns and ys along rows, so Z[row][col]
# must be the metric at (tx=xs[col], ty=ys[row]). The key is therefore
# (x, y); the previous (y, x) lookup transposed the surface relative to the
# max/min markers.
Z = np.array([[surface[(x, y)] for x in x_vals[0]] for y in x_vals[1]])

# Plot the surface as a 2D heat map (y inverted to match image coordinates).
fig = plt.figure()
plt.gca().invert_yaxis()
ax = plt.gca()
ax.scatter(X, Y, c=Z, cmap=cm.coolwarm)
ax.plot(max_position[0], max_position[1], "k^")  # best position found
ax.plot(min_position[0], min_position[1], "kv")  # worst position found
plt.show()

# Plot the surface as a 3D surface
fig = plt.figure()
ax = fig.add_subplot(projection="3d")
ax.plot_surface(X, Y, Z, cmap=cm.coolwarm)
plt.show()
# Second pass: an actual registration driven by gradient steps.
n_iterations = 200

# Fresh pipeline components for the optimizing run.
interpolator = InterpolatorType.New()
registrar = RegistrationType.New()
optimizer = OptimizerType.New()
metric = MetricType.New()
transform = TransformType.New()

# Same mutual-information settings as the exhaustive scan.
metric.SetNumberOfSpatialSamples(100)
metric.SetFixedImageStandardDeviation(0.4)
metric.SetMovingImageStandardDeviation(0.4)

# Fixed learning rate and iteration budget; MaximizeOn makes the optimizer
# climb the metric instead of descending it.
optimizer.SetLearningRate(15)
optimizer.SetNumberOfIterations(n_iterations)
optimizer.MaximizeOn()

# Wire the smoothed images and the components into the registration method.
registrar.SetFixedImage(fixed_smoothed_image)
registrar.SetMovingImage(moving_smoothed_image)
registrar.SetOptimizer(optimizer)
registrar.SetTransform(transform)
registrar.SetInterpolator(interpolator)
registrar.SetMetric(metric)
registrar.SetFixedImageRegion(fixed_img.GetBufferedRegion())
registrar.SetInitialTransformParameters(transform.GetParameters())
# Optimizer trajectory: step index -> (tx, ty). Key 0 is the starting point,
# taken from the registrar rather than hard-coded, so it stays correct if
# the initial transform parameters are ever changed.
descent_data = {0: tuple(registrar.GetInitialTransformParameters())}


def log_iteration():
    """Record the optimizer's current position on each IterationEvent.

    The key is ``GetCurrentIteration() + 1`` so key 0 stays reserved for
    the starting position stored above.
    """
    descent_data[optimizer.GetCurrentIteration() + 1] = tuple(optimizer.GetCurrentPosition())


optimizer.AddObserver(itk.IterationEvent(), log_iteration)
registrar.Update()

print(f"Its: {optimizer.GetCurrentIteration()}")
print(f"Final Value: {optimizer.GetValue()}")
print(f"Final Position: {list(registrar.GetLastTransformParameters())}")
# Trajectory in step order. Walk the keys actually recorded rather than
# range(n_iterations). Key 0 is the start and the observer adds keys
# GetCurrentIteration() + 1, so the last step can be key n_iterations. The
# old range dropped that final position and would raise KeyError if fewer
# steps had been recorded.
steps = sorted(descent_data)
x_vals = [descent_data[step][0] for step in steps]
y_vals = [descent_data[step][1] for step in steps]

fig = plt.figure()
# Note: We invert the y-axis to represent the image coordinate system
plt.gca().invert_yaxis()
ax = plt.gca()
# Metric surface from the exhaustive scan as the background.
ax.scatter(X, Y, c=Z, cmap=cm.coolwarm)
# The whole path as one polyline: white crosses joined by lines.
plt.plot(x_vals, y_vals, "wx-")
plt.plot(x_vals[0], y_vals[0], "bo")    # starting position
plt.plot(x_vals[-1], y_vals[-1], "ro")  # final recorded position
plt.plot(max_position[0], max_position[1], "k^")
plt.plot(min_position[0], min_position[1], "kv")
plt.show()
print(max_position)
print(min_position)
# Resample the original (unsmoothed) moving image onto the fixed image's
# grid through `transform`. This assumes the registration above left its
# final parameters in `transform` (NOTE(review): confirm, or set them
# explicitly from registrar.GetLastTransformParameters()). Pixels mapped
# from outside the moving image get the value 100.
fixed_size = fixed_img.GetLargestPossibleRegion().GetSize()
fixed_origin = fixed_img.GetOrigin()
fixed_spacing = fixed_img.GetSpacing()
fixed_direction = fixed_img.GetDirection()

ResampleFilterType = itk.ResampleImageFilter[MovingImageType, FixedImageType]
resample = ResampleFilterType.New(
    Transform=transform,
    Input=moving_img,
    Size=fixed_size,
    OutputOrigin=fixed_origin,
    OutputSpacing=fixed_spacing,
    OutputDirection=fixed_direction,
    DefaultPixelValue=100,
)
resample.Update()

# Side-by-side comparison: fixed image vs. the resampled moving image.
for _slot, _image, _label in (
    (121, fixed_img, 'fixed_img'),
    (122, resample.GetOutput(), 'resample image'),
):
    plt.subplot(_slot)
    plt.imshow(itk.GetArrayFromImage(_image), 'gray')
    plt.title(_label)
plt.show()