Note
Go to the end to download the full example code
LDDMM with normalized kernel
This notebook illustrates the benefit of normalizing the cometric in the LDDMM model. We first consider the registration of two spheres that differ by a translation, and then a more complex registration task involving both a translation and a deformation.
Without normalization, the carpooling artifact occurs: the sphere is contracted, then translated, and finally expanded. In this situation, even if the morphed shape matches the target, the intermediate shapes are not meaningful and the extrapolation is not reliable.
Normalizing the kernel adds regularization to the morphing, which prevents the carpooling artifact and, to some extent, improves the extrapolation.
Options for normalization are:
“rows”: normalize the rows of the kernel
“columns”: normalize the columns of the kernel
“both”: normalize both the rows and the columns of the kernel (for square kernels, algorithm 5.7 in https://www.jeanfeydy.com/geometric_data_analysis.pdf)
Further explanation can be found from p. 177 onwards in https://www.jeanfeydy.com/geometric_data_analysis.pdf.
from time import time
import pyvista as pv
import torch
import skshapes as sks
# sphinx_gallery_thumbnail_number = 9
Load data
# Rendering options shared by every PyVista mesh in this example.
plot_kwargs = dict(
    smooth_shading=True,
    pbr=True,
    metallic=0.7,
    roughness=0.6,
)

# Fixed camera for the sphere experiments (PyVista camera_position triplet:
# position, focal point, view-up).
cpos = [
    (1.6256104086078755, -9.701422233882411, 1.3012755902068773),
    (1.191160019984921, 0.01901107976782581, -0.0052552929581526076),
    (0.006053690112347382, 0.13347614338229413, 0.9910335372649167),
]

# Decimate two default spheres to 200 points. The decimation fitted on the
# source is reused on the target; the target is then shifted by 2 along x.
decimation = sks.Decimation(n_points=200)
source = decimation.fit_transform(sks.Sphere())
target = decimation.transform(sks.Sphere())
target.points = target.points + torch.tensor([2, 0.0, 0.0])

# Preview the source (teal) and the translated target (red).
plotter = pv.Plotter()
for shape, color in ((source, "teal"), (target, "red")):
    plotter.add_mesh(shape.to_pyvista(), color=color, opacity=0.8, **plot_kwargs)
plotter.camera_position = cpos
plotter.show()
data:image/s3,"s3://crabby-images/2c18f/2c18f0eca6b210a78cb05d04abfd0ce55c2594a3" alt="plot lddmm 1 normalization"
LDDMM without normalization
# Baseline: Gaussian-kernel extrinsic deformation with 4 time steps and
# scale 0.3, without any kernel normalization.
model = sks.ExtrinsicDeformation(n_steps=4, kernel="gaussian", scale=0.3)
loss = sks.L2Loss()
task = sks.Registration(
    model=model,
    loss=loss,
    optimizer=sks.LBFGS(),
    n_iter=3,
    regularization_weight=1e-1,
    verbose=True,
)

start = time()
task.fit(source=source, target=target)
print("Elapsed time: ", time() - start)

# Overlay every shape of the fitted path on the faded source and target.
path = task.path_
plotter = pv.Plotter()
plotter.add_mesh(source.to_pyvista(), color="teal", opacity=0.2, **plot_kwargs)
plotter.add_mesh(target.to_pyvista(), color="red", opacity=0.2, **plot_kwargs)
for step_shape in path:
    plotter.add_mesh(
        step_shape.to_pyvista(), color="tan", opacity=0.8, **plot_kwargs
    )
plotter.camera_position = cpos
plotter.show()
data:image/s3,"s3://crabby-images/2c427/2c4276b46d9b1358737f8d3d2eb076bc9743976e" alt="plot lddmm 1 normalization"
Initial loss : 8.04e+02
= 8.04e+02 + 0.1 * 0.00e+00 (fidelity + regularization_weight * regularization)
Loss after 1 iteration(s) : 1.84e+00
= 1.42e+00 + 0.1 * 4.16e+00 (fidelity + regularization_weight * regularization)
Loss after 2 iteration(s) : 5.25e-01
= 1.21e-01 + 0.1 * 4.04e+00 (fidelity + regularization_weight * regularization)
Loss after 3 iteration(s) : 4.40e-01
= 3.99e-02 + 0.1 * 4.00e+00 (fidelity + regularization_weight * regularization)
Elapsed time: 69.92670607566833
Extrapolation
# Extrapolate the fitted deformation outside [0, 1]. Integrate backward to
# t = -1 with the current 4 steps, then forward to t = 2 with 8 steps
# (presumably so both directions share the same step size).
back = model.morph(
    shape=source, parameter=task.parameter_, final_time=-1.0, return_path=True
).path
model.n_steps = 8
forward = model.morph(
    shape=source, parameter=task.parameter_, final_time=2.0, return_path=True
).path

# Reverse the backward path so it runs from t = -1 to 0, then append the
# forward path minus its first frame (the shared starting shape).
path = back[::-1] + forward[1:]

# Render one GIF frame per shape, always showing the faded source and target.
plotter = pv.Plotter()
plotter.open_gif("lddmm_no_normalization.gif", fps=4)
for frame in path:
    plotter.clear_actors()
    plotter.add_mesh(source.to_pyvista(), color="teal", opacity=0.2, **plot_kwargs)
    plotter.add_mesh(target.to_pyvista(), color="red", opacity=0.2, **plot_kwargs)
    plotter.add_mesh(frame.to_pyvista(), color="tan", opacity=0.8, **plot_kwargs)
    plotter.camera_position = cpos
    plotter.write_frame()
plotter.close()
data:image/s3,"s3://crabby-images/4eec5/4eec5195db04da3b9606c068724ce62547b733a1" alt="plot lddmm 1 normalization"
Normalizing the rows of the kernel
# Same Gaussian deformation model as the baseline, but with the rows of the
# kernel normalized. No regularization term and a single LBFGS iteration.
model_norm = sks.ExtrinsicDeformation(
    n_steps=4, kernel="gaussian", scale=0.3, normalization="rows"
)
task_norm = sks.Registration(
    model=model_norm,
    loss=loss,
    optimizer=sks.LBFGS(),
    n_iter=1,
    regularization_weight=0.0,
    verbose=True,
)

start = time()
task_norm.fit(source=source, target=target)
print("Elapsed time: ", time() - start)

# Overlay every shape of the fitted path on the faded source and target.
path = task_norm.path_
plotter = pv.Plotter()
plotter.add_mesh(source.to_pyvista(), color="teal", opacity=0.2, **plot_kwargs)
plotter.add_mesh(target.to_pyvista(), color="red", opacity=0.2, **plot_kwargs)
for step_shape in path:
    plotter.add_mesh(
        step_shape.to_pyvista(), color="tan", opacity=0.8, **plot_kwargs
    )
plotter.camera_position = cpos
plotter.show()
data:image/s3,"s3://crabby-images/c7211/c721183a638b1d633c900cf315d1470e06b91553" alt="plot lddmm 1 normalization"
Initial loss : 8.04e+02
= 8.04e+02 + 0 (fidelity + regularization_weight * regularization)
Loss after 1 iteration(s) : 1.67e-03
= 1.67e-03 + 0 (fidelity + regularization_weight * regularization)
Elapsed time: 41.609856843948364
Extrapolation
# Extrapolate the row-normalized deformation: backward to t = -1 with the
# current 4 steps, then forward to t = 2 with 8 steps.
back = model_norm.morph(
    shape=source,
    parameter=task_norm.parameter_,
    final_time=-1.0,
    return_path=True,
).path
model_norm.n_steps = 8
forward = model_norm.morph(
    shape=source,
    parameter=task_norm.parameter_,
    final_time=2.0,
    return_path=True,
).path

# Time-ordered path from t = -1 to t = 2. Forward's first frame (the shared
# starting shape) is dropped to avoid a duplicate.
path = back[::-1] + forward[1:]

# Write the animation: faded source/target plus the current frame.
plotter = pv.Plotter()
plotter.open_gif("lddmm_normalization.gif", fps=4)
for frame in path:
    plotter.clear_actors()
    plotter.add_mesh(source.to_pyvista(), color="teal", opacity=0.2, **plot_kwargs)
    plotter.add_mesh(target.to_pyvista(), color="red", opacity=0.2, **plot_kwargs)
    plotter.add_mesh(frame.to_pyvista(), color="tan", opacity=0.8, **plot_kwargs)
    plotter.camera_position = cpos
    plotter.write_frame()
plotter.close()
data:image/s3,"s3://crabby-images/80223/8022325c8b96909c362b037a1d9fc327fa1c8d02" alt="plot lddmm 1 normalization"
Normalizing both rows and columns of the kernel
# Gaussian deformation model with both rows and columns of the kernel
# normalized (see the "both" option described at the top of the example).
# No regularization term and a single LBFGS iteration, as in the rows case.
model_norm = sks.ExtrinsicDeformation(
    n_steps=4,
    kernel="gaussian",
    scale=0.3,
    normalization="both",
)
task_norm = sks.Registration(
    model=model_norm,
    loss=loss,
    optimizer=sks.LBFGS(),
    n_iter=1,
    regularization_weight=0.0,
    verbose=True,
)

start = time()
task_norm.fit(source=source, target=target)
elapsed_time = time() - start
# Report the fit time, consistent with the other registrations in this example.
print("Elapsed time: ", elapsed_time)

# Overlay every shape of the fitted path on the faded source and target.
path = task_norm.path_
plotter = pv.Plotter()
plotter.add_mesh(source.to_pyvista(), color="teal", opacity=0.2, **plot_kwargs)
plotter.add_mesh(target.to_pyvista(), color="red", opacity=0.2, **plot_kwargs)
for step_shape in path:
    plotter.add_mesh(
        step_shape.to_pyvista(), color="tan", opacity=0.8, **plot_kwargs
    )
plotter.camera_position = cpos
plotter.show()
data:image/s3,"s3://crabby-images/a1cbe/a1cbea48ba07bf6f2d368aeb0b65cfb7d51ec0b2" alt="plot lddmm 1 normalization"
Initial loss : 8.04e+02
= 8.04e+02 + 0 (fidelity + regularization_weight * regularization)
Loss after 1 iteration(s) : 4.98e-06
= 4.98e-06 + 0 (fidelity + regularization_weight * regularization)
Extrapolation
# Extrapolate the rows+columns-normalized deformation: backward to t = -1
# with the current 4 steps, then forward to t = 2 with 8 steps.
back = model_norm.morph(
    shape=source,
    parameter=task_norm.parameter_,
    final_time=-1.0,
    return_path=True,
).path
model_norm.n_steps = 8
forward = model_norm.morph(
    shape=source,
    parameter=task_norm.parameter_,
    final_time=2.0,
    return_path=True,
).path

# Time-ordered path from t = -1 to t = 2, without duplicating the shared
# starting shape.
path = back[::-1] + forward[1:]

plotter = pv.Plotter()
# Dedicated file name: "lddmm_normalization.gif" is the row-normalized
# animation written earlier in this example, and reusing it would
# overwrite that result.
plotter.open_gif("lddmm_normalization_both.gif", fps=4)
for frame in path:
    plotter.clear_actors()
    plotter.add_mesh(source.to_pyvista(), color="teal", opacity=0.2, **plot_kwargs)
    plotter.add_mesh(target.to_pyvista(), color="red", opacity=0.2, **plot_kwargs)
    plotter.add_mesh(frame.to_pyvista(), color="tan", opacity=0.8, **plot_kwargs)
    plotter.camera_position = cpos
    plotter.write_frame()
plotter.close()
data:image/s3,"s3://crabby-images/e35bf/e35bf9894b52c299ca799a8ca10e2de29268cfc8" alt="plot lddmm 1 normalization"
Example with a more complex shape
# Second experiment: two cactus meshes that differ by a deformation and a
# translation, registered with a rows+columns-normalized Gaussian kernel.
n_steps = 3

# Rendering options, re-declared so this section reads on its own.
plot_kwargs = {
    "smooth_shading": True,
    "pbr": True,
    "metallic": 0.7,
    "roughness": 0.6,
}
# Camera framing the cactus meshes.
cpos = [
    (3.6401575998373183, -1.183408993703478, 1.0915912440258628),
    (0.7463583722710609, 0.762569822371006, 0.48035204596817493),
    (-0.1745415166347431, 0.04933887578777028, 0.9834129012306287),
]

# NOTE(review): the original comment "5 - 8" is unexplained; possibly other
# cactus indices worth trying. Confirm with the author.
source = sks.PolyData("../test_data/cactus/cactus3.ply")
target = sks.PolyData("../test_data/cactus/cactus11.ply")
# Shift the target out of place. The out-of-place update and the
# `torch.tensor` factory avoid mutating the loaded points tensor in place
# and match the sphere translation earlier in the example.
target.points = target.points + torch.tensor([0.5, 0.5, 0.0])

# Decimate both meshes to 500 points, reusing the decimation fitted on the
# source for the target.
decimation = sks.Decimation(n_points=500)
source = decimation.fit_transform(source)
target = decimation.transform(target)

model = sks.ExtrinsicDeformation(
    n_steps=n_steps,
    kernel="gaussian",
    scale=0.1,
    normalization="both",
)
loss = sks.L2Loss()
Interpolation
# Fit the cactus registration with a single LBFGS iteration and a small
# regularization weight.
task = sks.Registration(
    model=model,
    loss=loss,
    optimizer=sks.LBFGS(),
    n_iter=1,
    verbose=True,
    regularization_weight=1e-3,
)

start = time()
task.fit(source=source, target=target)
print("Elapsed time: ", time() - start)

# Draw the interpolated path first (faint), then source (teal) and target (red).
path = task.path_
plotter = pv.Plotter()
for frame in path:
    plotter.add_mesh(frame.to_pyvista(), color="tan", opacity=0.3, **plot_kwargs)
for shape, color in ((source, "teal"), (target, "red")):
    plotter.add_mesh(shape.to_pyvista(), color=color, opacity=0.5, **plot_kwargs)
plotter.camera_position = cpos
plotter.show()
data:image/s3,"s3://crabby-images/c100a/c100a70fc5a230bcdd9fb2cd4e5b75fe3103c3dc" alt="plot lddmm 1 normalization"
Initial loss : 2.50e+02
= 2.50e+02 + 0.001 * 0.00e+00 (fidelity + regularization_weight * regularization)
Loss after 1 iteration(s) : 1.35e-01
= 9.99e-03 + 0.001 * 1.25e+02 (fidelity + regularization_weight * regularization)
Elapsed time: 15.22480583190918
Extrapolation
# Extrapolate the cactus deformation: backward to t = -1 with `n_steps`
# steps, then forward to t = 2 with twice as many steps.
# Only `.path` is used, so the regularization is not requested (it was
# previously computed and discarded).
back = model.morph(
    shape=source,
    parameter=task.parameter_,
    return_path=True,
    final_time=-1.0,
).path
model.n_steps = 2 * n_steps
forward = model.morph(
    shape=source,
    parameter=task.parameter_,
    return_path=True,
    final_time=2.0,
).path

# Time-ordered path from t = -1 to t = 2, without duplicating the shared
# starting shape.
path = back[::-1] + forward[1:]

plotter = pv.Plotter()
for frame in path:
    plotter.add_mesh(frame.to_pyvista(), color="tan", opacity=0.3, **plot_kwargs)
plotter.add_mesh(source.to_pyvista(), color="teal", opacity=0.5, **plot_kwargs)
plotter.add_mesh(target.to_pyvista(), color="red", opacity=0.5, **plot_kwargs)
plotter.camera_position = cpos
plotter.show()
data:image/s3,"s3://crabby-images/a3931/a3931023efb6a7b886b171485851970cb5d882f8" alt="plot lddmm 1 normalization"
[(3.6401575998373183, -1.183408993703478, 1.0915912440258628),
(0.7463583722710609, 0.762569822371006, 0.48035204596817493),
(-0.17454151663474313, 0.049338875787770284, 0.9834129012306289)]
Total running time of the script: (3 minutes 27.775 seconds)