cdopt.nn.utils.set_constraints
set_constraints.set_constraint_dissolving
set_constraint_dissolving(module, attr_name, manifold_class = euclidean_torch, weight_var_transfer = None, manifold_args = {}, penalty_param = 0)
Set manifold constraints on the attribute attr_name of the module module.
Parameters:
module (torch.nn.Module) – the module on which the manifold constraints are set.
attr_name (str) – the name of the attribute of module to be constrained.
manifold_class (cdopt.manifold) – a manifold class from cdopt.manifold_torch. Defaults to euclidean_torch.
weight_var_transfer (Callable) – the function that determines how the variables of manifold_class are transformed into the attribute attr_name of module.
manifold_args (dict) – arguments passed when instantiating manifold_class.
penalty_param (float) – the value of the penalty parameter.
Returns:
The module with manifold constraints on its attribute attr_name.
Note
The set_constraint_dissolving() function is built on the torch.nn.utils.parametrize functions. Therefore, the returned module belongs to the ParametrizationList module in PyTorch, which is different from the predefined neural layers in cdopt.nn.
The parametrizations registered on module are accessible under module.parametrizations, and the variables of the manifold are accessible under module.parametrizations[attr_name].original. Moreover, the constraint dissolving mapping \(\mathcal{A}\) can be accessed at module.parametrizations[attr_name].A, the constraints at module.parametrizations[attr_name].C, the quadratic penalty term at module.parametrizations[attr_name].quad_penalty(), and the penalty parameter at module.parametrizations[attr_name].penalty_param.
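To make the access pattern in the note concrete, here is a minimal sketch that uses only torch.nn.utils.parametrize and does not call cdopt. The class SphereDissolving is a hypothetical stand-in for a cdopt manifold, and its mapping A(x) = 2x / (1 + ||x||^2), its constraint c(x) = ||x||^2 - 1, and the scaling of the penalty are assumptions chosen for illustration. In plain PyTorch the parametrization module sits at index 0 of the ParametrizationList, whereas the note above describes cdopt as exposing A, C, quad_penalty() and penalty_param directly on module.parametrizations[attr_name].

```python
import torch
import torch.nn as nn
from torch.nn.utils import parametrize


class SphereDissolving(nn.Module):
    """Hypothetical stand-in for a cdopt manifold (unit sphere), for illustration only."""

    def __init__(self, penalty_param: float = 0.0):
        super().__init__()
        self.penalty_param = penalty_param

    def A(self, x: torch.Tensor) -> torch.Tensor:
        # Assumed constraint dissolving mapping: A(x) = 2x / (1 + ||x||^2).
        # Points already on the unit sphere are mapped to themselves.
        return 2.0 * x / (1.0 + x.pow(2).sum())

    def C(self, x: torch.Tensor) -> torch.Tensor:
        # Constraint c(x) = ||x||^2 - 1, which vanishes exactly on the unit sphere.
        return x.pow(2).sum() - 1.0

    def quad_penalty(self, x: torch.Tensor) -> torch.Tensor:
        # Quadratic penalty on the constraint violation (scaling is an assumption).
        return self.penalty_param * self.C(x).pow(2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # parametrize calls forward(original) each time the attribute is read.
        return self.A(x)


torch.manual_seed(0)
layer = nn.Linear(3, 1, bias=False)
parametrize.register_parametrization(
    layer, "weight", SphereDissolving(penalty_param=0.1)
)

plist = layer.parametrizations["weight"]  # a ParametrizationList
original = plist.original                 # unconstrained variable (nn.Parameter)
sphere_map = plist[0]                     # the SphereDissolving instance

# layer.weight is now computed as A(original) on every access.
penalty = sphere_map.quad_penalty(original)
loss = layer(torch.ones(2, 3)).sum() + penalty
loss.backward()  # gradients flow back to the unconstrained variable
```

The optimizer then updates `original` rather than `layer.weight`, and the penalty term is added to the training loss, which is how the quadratic penalty and the penalty parameter described above would typically be used.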