Writing a deformation model
A deformation model specifies the type of deformation to apply to a shape. Examples include
This document is intended to explain how deformation models must be implemented to be integrated into scikit-shapes pipelines.
Structure
A deformation model must inherits from sks.morphing.BaseModel
and following methods are required :
__init__
: class constructor used to define some hyperparametersmorph
: method to apply the deformation to the shape given a parameter. It returns aMorphingOutput
objects.parameter_shape(shape=...)
: method that output a tuple corresponding to the shape of the parameter that must be passed tomorph
initial_parameter(shape=...)
: method that returns a default parameter. For now it is a zerotorch.tensor
defined on the right device and with the right shape.
Usage
Deformation objects are accessed by two ways : by the end-user of scikit-shapes and internally by a task objects (such as Registration
).
End-user interacts with deformation
:
at initialization, by setting hyperparameters
inderectly by accessing MorphingOuput after fitting
task
Task interact with deformation
:
by calling
initial_parameter
by calling
morph(shape, parameter)
by copying the content of
MorphingOutput
Tasks relies on optimization schemes and morph
must be compatible with autograd
the following test must tun without error (it is actually in the test suite)
source, target = ... # Any shapes
loss = ... # Any loss
# Initialize the deformation model
model = deformation_model()
# Get an initial parameter
p = model.inital_parameter(shape=source)
p.requires_grad_(True)
morphed_shape = model.morph(shape=source, parameter=p).morphed_shape
L = loss(morphed_shape, target)
L.backward()
assert p.grad is not None
if torch.cuda.is_available():
p = p.cuda()
try:
model.morph(shape=source, parameter=p).morphed_shape
except ValueError as e:
pass
else:
raise RuntimeError(
"Expected ValueError as parameter and shape not on the same device"
)
Access to shape properties
Some deformation models require some attributes to be available for morphed shape. An example is shape.control_points
in ExtrinsicDeformation
. These attributes should not be added to the arguments of morph
, but instead accessed inside morph
with a clear error message if the attribute cannot be reached and if no default behavior can be defined in this case.
Custom attributes
The output of morph
is a [MorphingOutput
]. It is simply a class with no method and containing by default the following attributes :
morphed_shape
: the shape morphed by the modelpath
: a list of shapes, fromsource
tomorphed_shape
of lengthn_steps + 1
regularization
: the regularization corresponding to the model and the parameter
Other attributes can be written inside .morph()
, for example for RigidMotion
, you can access rotation_matrix
or translation_vector
After fitting, all the attributes of the fitted deformation model are accessible in the registration with the same name as in MorphingOutput
plus an extra “_”:
model = RigidMotion()
r = Registration(model=model, los=..., ...)
r.fit(source=source, target=target)
rotation_matrix = r.rotation_matrix_
translation = r.translation_vector_
output = r.morphed_shape_