skshapes.tasks.registration.Registration
- class skshapes.tasks.registration.Registration(*, model, loss, optimizer=None, regularization_weight=1, n_iter=10, verbose=0, gpu=True, debug=False)
Bases: object
Registration class.
This class implements the registration between two shapes. It is initialized with a model, a loss and, optionally, an optimizer. The registration is performed by calling the fit method with the source and target shapes as arguments. The transform method can then be used to transform the source shape, or a new shape, using the parameter learned during registration.
The optimization criterion is the sum of the fidelity term and the regularization term, the latter weighted by the regularization_weight parameter:
$$ \text{loss}(\theta) = \text{fid}(\text{Morph}(\theta, \text{source}), \text{target}) + \text{regularization\_weight} \times \text{reg}(\theta) $$
The fidelity term \(\text{fid}\) is given by the loss object, the regularization term \(\text{reg}\) and the morphing \(\text{Morph}\) are given by the model.
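As a minimal sketch of how this criterion combines its terms, assume a toy setting that is not part of skshapes: shapes are lists of 2D points in known correspondence, the hypothetical model `morph` translates every point by θ = (tx, ty), the hypothetical fidelity `fid` is the sum of squared point-to-point distances, and the hypothetical regularization `reg` is the squared norm of θ.

```python
# Toy illustration of: loss(theta) = fid(Morph(theta, source), target) + w * reg(theta)
# morph, fid and reg are hypothetical stand-ins, not skshapes objects.

def morph(theta, points):
    """Hypothetical model: translate every 2D point by theta = (tx, ty)."""
    tx, ty = theta
    return [(x + tx, y + ty) for x, y in points]

def fid(moved, target):
    """Hypothetical fidelity: sum of squared distances between corresponding points."""
    return sum((px - qx) ** 2 + (py - qy) ** 2 for (px, py), (qx, qy) in zip(moved, target))

def reg(theta):
    """Hypothetical regularization: squared Euclidean norm of the parameter."""
    return sum(t * t for t in theta)

def criterion(theta, source, target, regularization_weight=1.0):
    """Fidelity of the morphed source plus the weighted regularization term."""
    return fid(morph(theta, source), target) + regularization_weight * reg(theta)

source = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
target = [(1.0, 0.0), (2.0, 0.0), (1.0, 1.0)]  # source shifted by (1, 0)

print(criterion((0.0, 0.0), source, target))  # fid = 3.0, reg = 0.0 -> 3.0
print(criterion((1.0, 0.0), source, target))  # fid = 0.0, reg = 1.0 -> 1.0
print(criterion((1.0, 0.0), source, target, regularization_weight=0.5))  # 0.5
```

The exact alignment θ = (1, 0) is not penalty-free, which is why a positive regularization_weight trades some fidelity for a smaller parameter.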
- Parameters:
  - model (BaseModel) – a model object (from skshapes.morphing).
  - loss (BaseLoss) – a loss object (from skshapes.loss).
  - optimizer (LBFGS | Adam | Adagrad | SGD | None) – an optimizer object (from skshapes.optimization).
  - regularization_weight (int | float) – the weight of the regularization term in the criterion: fidelity + regularization_weight * regularization.
  - n_iter (int) – number of iterations of the optimization loop.
  - verbose (int) – if positive, print the losses after each iteration of the optimization loop.
  - gpu (bool) – if available, perform intensive numerical computations on an NVIDIA GPU with a CUDA backend.
  - debug – if True, information is stored during the optimization process.
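To illustrate how n_iter, regularization_weight and verbose interact, here is a sketch of an optimization loop on the same kind of toy problem (a 2D translation between corresponding point lists). It uses plain gradient descent with a hypothetical step size `lr` as a stand-in for an SGD-like optimizer; it is not the skshapes implementation, and `fit_translation` is a hypothetical name.

```python
def fit_translation(source, target, regularization_weight=1.0, n_iter=10, lr=0.1, verbose=0):
    """Hypothetical gradient descent on a 2D translation parameter.

    Minimizes sum ||p_i + theta - q_i||^2 + regularization_weight * ||theta||^2.
    lr is a hypothetical step size playing the role of an optimizer setting.
    """
    theta = [0.0, 0.0]  # start from zeros
    for it in range(n_iter):
        # residuals r_i = (p_i + theta) - q_i for each corresponding pair
        rx = [px + theta[0] - qx for (px, _), (qx, _) in zip(source, target)]
        ry = [py + theta[1] - qy for (_, py), (_, qy) in zip(source, target)]
        fidelity = sum(x * x for x in rx) + sum(y * y for y in ry)
        regularization = theta[0] ** 2 + theta[1] ** 2
        loss = fidelity + regularization_weight * regularization
        if verbose > 0:
            print(f"iter {it}: loss={loss:.4f} fid={fidelity:.4f} reg={regularization:.4f}")
        # analytical gradient of the criterion with respect to theta
        gx = 2 * sum(rx) + 2 * regularization_weight * theta[0]
        gy = 2 * sum(ry) + 2 * regularization_weight * theta[1]
        theta = [theta[0] - lr * gx, theta[1] - lr * gy]
    return theta

source = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
target = [(1.0, 0.0), (2.0, 0.0), (1.0, 1.0)]  # source shifted by (1, 0)

theta = fit_translation(source, target, regularization_weight=1.0, n_iter=50)
print(theta)  # close to [0.75, 0.0]: the penalty pulls the shift below 1
theta_free = fit_translation(source, target, regularization_weight=0.0, n_iter=50)
print(theta_free)  # close to [1.0, 0.0]
```

For this toy problem the minimizer has the closed form θ = Σ(qᵢ − pᵢ) / (n + regularization_weight), so increasing regularization_weight shrinks the learned translation toward zero; too small an n_iter stops the loop before it gets there.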
Examples
```python
import skshapes as sks

# source and target are shape objects (from skshapes.shapes), defined beforehand
model = sks.RigidMotion()
loss = sks.OptimalTransportLoss()
optimizer = sks.SGD(lr=0.1)
registration = sks.Registration(model=model, loss=loss, optimizer=optimizer)
registration.fit(source=source, target=target)
transformed_source = registration.transform(source=source)
# Access the parameter
parameter = registration.parameter_
# Access the loss
loss = registration.loss_
# Access the fidelity term
fidelity = registration.fidelity_
# Access the regularization term
regularization = registration.regularization_
```
More examples can be found in the [gallery](../../../generated/gallery/#registration).
- __init__(*, model, loss, optimizer=None, regularization_weight=1, n_iter=10, verbose=0, gpu=True, debug=False)
Methods
- __init__(*, model, loss[, optimizer, ...])
- fit(*, source, target[, initial_parameter]) – Fit the registration between the source and target shapes.
- fit_transform(*, source, target[, ...]) – Fit the registration and apply it to the source shape.
- transform(*, source) – Apply the registration to a new shape.
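The fit / transform / fit_transform split follows the familiar estimator pattern: fit learns and stores parameter_ and returns self, transform reuses the stored parameter on any shape, and fit_transform chains the two on the source. The sketch below mirrors that contract with a hypothetical `ToyTranslationRegistration` class that, unlike skshapes, replaces the iterative optimizer with the closed-form minimizer of the toy translation criterion.

```python
class ToyTranslationRegistration:
    """Hypothetical estimator mirroring the fit / transform / fit_transform contract.

    The learned parameter is the closed-form minimizer of
    sum ||p_i + theta - q_i||^2 + w * ||theta||^2, i.e. theta = sum(q_i - p_i) / (n + w).
    """

    def __init__(self, *, regularization_weight=1.0):
        self.regularization_weight = regularization_weight

    def fit(self, *, source, target):
        denom = len(source) + self.regularization_weight
        self.parameter_ = (
            sum(q[0] - p[0] for p, q in zip(source, target)) / denom,
            sum(q[1] - p[1] for p, q in zip(source, target)) / denom,
        )
        self.transformed_shape_ = self.transform(source=source)
        return self  # returning self allows chaining

    def transform(self, *, source):
        # apply the already-learned translation; works on any list of 2D points
        tx, ty = self.parameter_
        return [(x + tx, y + ty) for x, y in source]

    def fit_transform(self, *, source, target):
        return self.fit(source=source, target=target).transform(source=source)


source = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
target = [(1.0, 0.0), (2.0, 0.0), (1.0, 1.0)]  # source shifted by (1, 0)

toy = ToyTranslationRegistration(regularization_weight=0.0)
moved = toy.fit_transform(source=source, target=target)
print(moved == target)  # True: an unpenalized translation matches exactly
print(toy.transform(source=[(10.0, 10.0)]))  # [(11.0, 10.0)]
```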
- fit(*, source, target, initial_parameter=None)
Fit the registration between the source and target shapes.
After calling this method, the registration’s parameter can be accessed with the parameter_ attribute, the transformed shape with the transformed_shape_ attribute, and the list of successive shapes produced during the registration process with the path_ attribute.
- Parameters:
  - source (shape_object) – a shape object (from skshapes.shapes).
  - target (polydata_type | image_type) – a shape object (from skshapes.shapes).
  - initial_parameter (Float32[Tensor, '*_'] | None) – an initial parameter tensor for the optimization process. If None, the parameter is initialized with zeros. Defaults to None.
- Raises:
DeviceError – if the source and target shapes are not on the same device.
- Returns:
self
- Return type:
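The two documented preconditions of fit, the device check and the zero default for initial_parameter, can be sketched as follows. `ToyShape`, its `device` strings, `prepare_fit` and `parameter_size` are hypothetical, and `DeviceError` here is a local stand-in; in skshapes the parameter is a tensor whose shape depends on the model.

```python
from dataclasses import dataclass


class DeviceError(Exception):
    """Hypothetical local stand-in for the DeviceError documented above."""


@dataclass
class ToyShape:
    points: list
    device: str = "cpu"  # hypothetical device label, e.g. "cpu" or "cuda:0"


def prepare_fit(source, target, initial_parameter=None, parameter_size=2):
    """Hypothetical checks performed before the optimization loop starts."""
    if source.device != target.device:
        raise DeviceError(f"source is on {source.device} but target is on {target.device}")
    if initial_parameter is None:
        # documented default: start the optimization from zeros
        initial_parameter = [0.0] * parameter_size
    return list(initial_parameter)


shape_cpu = ToyShape(points=[(0.0, 0.0)], device="cpu")
shape_gpu = ToyShape(points=[(0.0, 0.0)], device="cuda:0")
print(prepare_fit(shape_cpu, shape_cpu))  # [0.0, 0.0]
print(prepare_fit(shape_cpu, shape_cpu, initial_parameter=[0.5, -0.5]))  # [0.5, -0.5]
try:
    prepare_fit(shape_cpu, shape_gpu)
except DeviceError as err:
    print("DeviceError:", err)
```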
- fit_transform(*, source, target, initial_parameter=None)
Fit the registration and apply it to the source shape.
- Return type:
- transform(*, source)
Apply the registration to a new shape.
- Parameters:
  - source (polydata_type | image_type) – the shape to transform.
- Returns:
the transformed shape.
- Return type:
shape_type