

class cdopt.linen.Dense_cdopt(features, use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None, manifold_class = euclidean_jax, weight_var_transfer = <function>, manifold_args )

A linear transformation applied over the last dimension of the input.


  • features: the number of output features.

  • use_bias: whether to add a bias to the output (default: True).

  • dtype: the dtype of the computation (default: infer from input and params).

  • param_dtype: the dtype passed to parameter initializers (default: float32).

  • precision: numerical precision of the computation; see jax.lax.Precision for details.

  • kernel_init: initializer function for the weight matrix.

  • bias_init: initializer function for the bias.

  • manifold_class: the manifold class for the weight matrix. Default: euclidean_jax.

  • manifold_args: additional keyword arguments that help define the manifold constraints.

  • weight_var_transfer (callable): the function that maps the weights (3D tensor) to the shape of the manifold's variables.
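
To make "applied over the last dimension" concrete, here is a minimal dependency-free sketch of the arithmetic. The helper name dense_last_dim and the nested-list representation are illustrative assumptions, not part of cdopt or Flax. The kernel layout (in_features, features) follows the usual Flax convention and is assumed to carry over to Dense_cdopt.

```python
def dense_last_dim(x, kernel, bias=None):
    """Compute y = x @ kernel + bias over the last axis of a nested list.

    Hypothetical helper for illustration only; leading (batch) axes are
    left untouched, and only the innermost vectors are transformed.
    """
    # Recurse through leading axes until a 1-D vector is reached.
    if x and isinstance(x[0], list):
        return [dense_last_dim(row, kernel, bias) for row in x]
    in_features, features = len(kernel), len(kernel[0])
    if len(x) != in_features:
        raise ValueError("last dimension of x must match kernel rows")
    out = [sum(x[i] * kernel[i][j] for i in range(in_features))
           for j in range(features)]
    if bias is not None:  # corresponds to use_bias=True
        out = [o + b for o, b in zip(out, bias)]
    return out


x = [[[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]]]        # shape (1, 2, 3)
kernel = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # shape (3, 2): features=2
bias = [0.5, -0.5]                              # shape (2,)
y = dense_last_dim(x, kernel, bias)
print(y)  # [[[4.5, 4.5], [0.5, 0.5]]]
```

The output keeps the leading shape (1, 2) and replaces the last dimension of size 3 with features=2.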


class cdopt.linen.Conv_cdopt(features, kernel_size, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=True, mask=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None, manifold_class = euclidean_jax, weight_var_transfer = <function>, manifold_args )

Convolution module wrapping lax.conv_general_dilated. It uses the channels-last convention, i.e. NHWC for a 2D convolution and NDHWC for a 3D convolution. Note: this differs from the input convention used by lax.conv_general_dilated, which puts the spatial dimensions last.


  • features: the number of output features.

  • manifold_class: the manifold class for the weight matrix. Default: euclidean_jax.

  • manifold_args: additional keyword arguments that help define the manifold constraints.

  • weight_var_transfer (callable): the function that maps the weights (3D tensor) to the shape of the manifold's variables.
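
As a rough illustration of the channels-last convention, the sketch below computes the output shape of a convolution from features, kernel_size, strides, and padding. The function conv_output_shape is a hypothetical helper, not a cdopt or JAX API. It assumes the standard 'SAME' and 'VALID' size rules and ignores input_dilation and explicit padding sequences.

```python
def conv_output_shape(input_shape, features, kernel_size, strides=1,
                      padding="SAME", kernel_dilation=1):
    """Output shape for a channels-last input of shape (N, *spatial, C).

    Hypothetical helper for illustration; covers only string padding modes.
    """
    batch, *spatial, _in_channels = input_shape

    def per_dim(value):
        # Broadcast a scalar setting to one value per spatial dimension.
        return (value,) * len(spatial) if isinstance(value, int) else tuple(value)

    out = []
    for size, k, s, d in zip(spatial, per_dim(kernel_size),
                             per_dim(strides), per_dim(kernel_dilation)):
        if padding == "SAME":
            out.append(-(-size // s))          # ceil(size / s)
        elif padding == "VALID":
            effective_k = (k - 1) * d + 1      # kernel extent after dilation
            out.append(max((size - effective_k) // s + 1, 0))
        else:
            raise ValueError(f"unsupported padding: {padding!r}")
    # Channels stay last: the input channel count is replaced by `features`.
    return (batch, *out, features)


nhwc = conv_output_shape((8, 32, 32, 3), features=16,
                         kernel_size=(3, 3), strides=2)
ndhwc = conv_output_shape((2, 10, 12, 14, 4), features=8,
                          kernel_size=(3, 3, 3), padding="VALID")
print(nhwc)   # (8, 16, 16, 16)
print(ndhwc)  # (2, 8, 10, 12, 8)
```

In both cases the batch dimension comes first and the feature dimension last, matching the NHWC and NDHWC layouts described above.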


class cdopt.linen.ConvTranspose_cdopt(features, kernel_size, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=True, mask=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None, manifold_class = euclidean_jax, weight_var_transfer = <function>, manifold_args )

Transposed convolution module wrapping lax.conv_transpose.


  • features: the number of output features.

  • manifold_class: the manifold class for the weight matrix. Default: euclidean_jax.

  • manifold_args: additional keyword arguments that help define the manifold constraints.

  • weight_var_transfer (callable): the function that maps the weights (3D tensor) to the shape of the manifold's variables.
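
A transposed convolution typically upsamples its input, which the following shape sketch tries to make visible. The helper conv_transpose_output_shape is hypothetical and not part of cdopt or JAX. It assumes a channels-last input (not stated for this class in this section) and the usual string padding rules for transposed convolutions: 'SAME' multiplies each spatial size by the stride, and 'VALID' adds max(kernel - stride, 0) on top of that. Kernel dilation is ignored.

```python
def conv_transpose_output_shape(input_shape, features, kernel_size,
                                strides=1, padding="SAME"):
    """Output shape of a transposed convolution on (N, *spatial, C) input.

    Hypothetical helper for illustration; assumes channels-last layout
    and only the string padding modes 'SAME' and 'VALID'.
    """
    batch, *spatial, _in_channels = input_shape

    def per_dim(value):
        # Broadcast a scalar setting to one value per spatial dimension.
        return (value,) * len(spatial) if isinstance(value, int) else tuple(value)

    out = []
    for size, k, s in zip(spatial, per_dim(kernel_size), per_dim(strides)):
        if padding == "SAME":
            out.append(size * s)                   # exact upsampling by stride
        elif padding == "VALID":
            out.append(size * s + max(k - s, 0))   # kernel overhang is kept
        else:
            raise ValueError(f"unsupported padding: {padding!r}")
    return (batch, *out, features)


up_same = conv_transpose_output_shape((8, 16, 16, 32), features=3,
                                      kernel_size=(3, 3), strides=2)
up_valid = conv_transpose_output_shape((8, 16, 16, 32), features=3,
                                       kernel_size=(3, 3), strides=2,
                                       padding="VALID")
print(up_same)   # (8, 32, 32, 3)
print(up_valid)  # (8, 33, 33, 3)
```

With strides=2 and 'SAME' padding, a 16x16 feature map becomes 32x32, which is the common use of this layer in decoder-style networks.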