cdopt.nn.utils.stateless

stateless.functional_call

cdopt.nn.utils.stateless.functional_call(module, parameters_and_buffers, args, kwargs=None)

Performs a functional call on the module, temporarily replacing the module's parameters and buffers with the provided ones. This is a copy of functional_call from PyTorch >= 1.12 (available there as torch.nn.utils.stateless.functional_call).

Parameters:

  • module (torch.nn.Module) – the module to call

  • parameters_and_buffers (dict of str and Tensor) – a mapping from parameter or buffer names to the tensors that replace them during the module call.

  • args (tuple) – arguments to be passed to the module call

  • kwargs (dict, optional) – keyword arguments to be passed to the module call. Default: None.

Returns:

  • the result of calling module.
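The mechanism described above (swap in the supplied values, call the module, restore the originals) can be illustrated without PyTorch. The following is a minimal pure-Python sketch, not the cdopt or PyTorch implementation: the classes Affine and Net and the helpers _resolve, swapped, and functional_call_sketch are hypothetical, and plain floats stand in for tensors.

```python
from contextlib import contextmanager

# Hypothetical toy stand-ins for torch.nn.Module: plain Python objects whose
# "parameters" are ordinary attributes addressed by dotted names like "fc.weight".
class Affine:
    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        return self.weight * x + self.bias


class Net:
    def __init__(self):
        self.fc = Affine(2.0, 1.0)

    def __call__(self, x, scale=1.0):
        return scale * self.fc(x)


def _resolve(module, dotted_name):
    """Map "fc.weight" to (the object owning the attribute, "weight")."""
    *path, attr = dotted_name.split(".")
    owner = module
    for part in path:
        owner = getattr(owner, part)
    return owner, attr


@contextmanager
def swapped(module, parameters_and_buffers):
    """Temporarily install the given values, then restore the originals."""
    saved = {}
    try:
        for name, value in parameters_and_buffers.items():
            owner, attr = _resolve(module, name)
            saved[name] = getattr(owner, attr)
            setattr(owner, attr, value)
        yield module
    finally:
        # Runs even if the call raises, so the module is never left modified.
        for name, value in saved.items():
            owner, attr = _resolve(module, name)
            setattr(owner, attr, value)


def functional_call_sketch(module, parameters_and_buffers, args, kwargs=None):
    """Call module(*args, **kwargs) with parameters_and_buffers swapped in."""
    if not isinstance(args, tuple):  # convenience in this sketch only
        args = (args,)
    with swapped(module, parameters_and_buffers):
        return module(*args, **(kwargs or {}))


net = Net()
print(functional_call_sketch(net, {"fc.weight": 10.0}, (3.0,)))  # 31.0
print(net(3.0))  # 7.0: the original weight is back in place
print(functional_call_sketch(net, {"fc.bias": 0.0}, (3.0,), {"scale": 2.0}))  # 12.0
```

Restoring inside a finally clause is the key design choice in this sketch: the module's own state is unchanged after the call, even when the forward pass raises.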

stateless.get_quad_penalty_call

cdopt.nn.utils.stateless.get_quad_penalty_call(module, parameters_and_buffers)

Calls cdopt.nn.get_quad_penalty() on the module, temporarily replacing the module's parameters and buffers with the provided ones.

Parameters:

  • module (torch.nn.Module) – the module to call

  • parameters_and_buffers (dict of str and Tensor) – a mapping from parameter or buffer names to the tensors that replace them while the penalty is computed.

Returns:

  • the result of calling cdopt.nn.get_quad_penalty() on the module.
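As a rough illustration, the penalty can be evaluated at substituted parameter values without permanently modifying the module. This is a minimal pure-Python sketch under stated assumptions, not cdopt's implementation. ConstrainedScalar, Net, get_quad_penalty_sketch, swapped, and get_quad_penalty_call_sketch are hypothetical. The toy penalty (weight**2 - 1)**2 is also assumed: it is the squared violation of the constraint weight**2 == 1. The penalty cdopt actually computes depends on each layer's constraint and is not reproduced here.

```python
from contextlib import contextmanager

# Hypothetical constrained layer: its weight should satisfy weight**2 == 1,
# and quad_penalty() returns the squared violation (weight**2 - 1)**2.
class ConstrainedScalar:
    def __init__(self, weight):
        self.weight = weight

    def quad_penalty(self):
        return (self.weight ** 2 - 1.0) ** 2

    def __call__(self, x):
        return self.weight * x


class Net:
    def __init__(self):
        self.layer1 = ConstrainedScalar(1.0)
        self.layer2 = ConstrainedScalar(-1.0)

    def __call__(self, x):
        return self.layer2(self.layer1(x))


def get_quad_penalty_sketch(module):
    """Hypothetical analogue of cdopt.nn.get_quad_penalty(): sum the penalties
    of the direct sub-layers that define quad_penalty()."""
    return sum(layer.quad_penalty() for layer in vars(module).values()
               if hasattr(layer, "quad_penalty"))


@contextmanager
def swapped(module, parameters_and_buffers):
    """Temporarily set attributes named like "layer1.weight" (one level of
    nesting only, for brevity), then restore the originals."""
    saved = {}
    try:
        for name, value in parameters_and_buffers.items():
            owner_name, attr = name.split(".")
            owner = getattr(module, owner_name)
            saved[name] = getattr(owner, attr)
            setattr(owner, attr, value)
        yield module
    finally:
        for name, value in saved.items():
            owner_name, attr = name.split(".")
            setattr(getattr(module, owner_name), attr, value)


def get_quad_penalty_call_sketch(module, parameters_and_buffers):
    """Evaluate the penalty with parameters_and_buffers swapped in."""
    with swapped(module, parameters_and_buffers):
        return get_quad_penalty_sketch(module)


net = Net()
print(get_quad_penalty_sketch(net))  # 0.0: both weights satisfy the constraint
print(get_quad_penalty_call_sketch(net, {"layer1.weight": 2.0}))  # 9.0
print(get_quad_penalty_sketch(net))  # 0.0 again: layer1.weight was restored
```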

stateless.functional_quad_penalty_call

cdopt.nn.utils.stateless.functional_quad_penalty_call(module, parameters_and_buffers, args, kwargs=None)

Calls the module and cdopt.nn.get_quad_penalty() on it together, temporarily replacing the module's parameters and buffers with the provided ones.

Parameters:

  • module (torch.nn.Module) – the module to call

  • parameters_and_buffers (dict of str and Tensor) – a mapping from parameter or buffer names to the tensors that replace them during the module call and the penalty computation.

  • args (tuple) – arguments to be passed to the module call

  • kwargs (dict, optional) – keyword arguments to be passed to the module call. Default: None.

Returns:

  • the results of calling module and of calling cdopt.nn.get_quad_penalty() on it.
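Doing both computations under a single swap means the forward output and the penalty are evaluated at the same substituted parameter values, with one setup and one restore. Below is a minimal pure-Python sketch of that idea, not cdopt's implementation. All class and function names are hypothetical, and floats stand in for tensors. Returning the pair as an (output, penalty) tuple is an assumption; the exact container and order cdopt uses are not stated here.

```python
from contextlib import contextmanager

# Hypothetical constrained layer: penalty is the squared violation of weight**2 == 1.
class ConstrainedScalar:
    def __init__(self, weight):
        self.weight = weight

    def quad_penalty(self):
        return (self.weight ** 2 - 1.0) ** 2

    def __call__(self, x):
        return self.weight * x


class Net:
    def __init__(self):
        self.layer1 = ConstrainedScalar(1.0)
        self.layer2 = ConstrainedScalar(-1.0)

    def __call__(self, x):
        return self.layer2(self.layer1(x))


def get_quad_penalty_sketch(module):
    """Hypothetical analogue of cdopt.nn.get_quad_penalty()."""
    return sum(layer.quad_penalty() for layer in vars(module).values()
               if hasattr(layer, "quad_penalty"))


@contextmanager
def swapped(module, parameters_and_buffers):
    """Temporarily set attributes named like "layer1.weight", then restore them."""
    saved = {}
    try:
        for name, value in parameters_and_buffers.items():
            owner_name, attr = name.split(".")
            owner = getattr(module, owner_name)
            saved[name] = getattr(owner, attr)
            setattr(owner, attr, value)
        yield module
    finally:
        for name, value in saved.items():
            owner_name, attr = name.split(".")
            setattr(getattr(module, owner_name), attr, value)


def functional_quad_penalty_call_sketch(module, parameters_and_buffers, args,
                                        kwargs=None):
    """Run the forward pass and the penalty under one swap, so both see the
    same parameter values. The (output, penalty) return shape is assumed."""
    with swapped(module, parameters_and_buffers):
        output = module(*args, **(kwargs or {}))
        penalty = get_quad_penalty_sketch(module)
    return output, penalty


net = Net()
out, pen = functional_quad_penalty_call_sketch(net, {"layer1.weight": 2.0}, (3.0,))
print(out, pen)  # -6.0 9.0
print(net(3.0))  # -3.0: the original weights are restored
```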