Dictionary Learning Accelerated by JIT

In this part, we use a simple dictionary learning example to illustrate that optimization in CDOpt can be accelerated by the just-in-time (JIT) compilation provided by the JAX package.

Problem Description

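Suppose we are given data \(\{y_i\}_{i = 1,\ldots,m} \subset \mathbb{R}^n\) generated as \(y_i = Q z_i\), where \(Q \in \mathbb{R}^{n\times n}\) is a fixed unknown orthogonal matrix and the entries of each \(z_i\) follow an iid Bernoulli-Gaussian distribution with parameter \(\theta \in (0,1)\), i.e., each entry is nonzero with probability \(\theta\) and, when nonzero, is drawn from the standard Gaussian distribution. The goal is to recover \(Z = [z_1, \ldots, z_m]^\top\) and \(Q\) from the given data \(Y = [y_1, \ldots, y_m]^\top \in \mathbb{R}^{m\times n}\).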
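The data model can be made concrete with a short NumPy sketch. It assumes, purely for illustration, that \(Q\) is drawn as a random orthogonal matrix from the QR factorization of a Gaussian matrix; the helper name `generate_dictionary_data` is hypothetical and not part of CDOpt. Because the rows of \(Y\) are \(y_i^\top = z_i^\top Q^\top\), the whole data matrix is \(Y = Z Q^\top\).

```python
import numpy as np

def generate_dictionary_data(m, n, theta, seed=0):
    """Hypothetical helper: sample Y = Z Q^T, i.e. y_i = Q z_i.

    Assumes Q is a random orthogonal matrix and Z has iid Bernoulli-Gaussian
    entries z = b * g with b ~ Bernoulli(theta) and g ~ N(0, 1).
    """
    rng = np.random.default_rng(seed)
    # Random orthogonal Q from the QR factorization of a Gaussian matrix;
    # flipping column signs by sign(diag(R)) keeps Q orthogonal.
    Q, R = np.linalg.qr(rng.standard_normal((n, n)))
    Q = Q * np.sign(np.diag(R))
    # Bernoulli(theta) support mask times standard Gaussian values.
    mask = rng.random((m, n)) <= theta
    Z = rng.standard_normal((m, n)) * mask
    # Row i of Y is y_i^T = (Q z_i)^T = z_i^T Q^T.
    Y = Z @ Q.T
    return Y, Q, Z

Y, Q, Z = generate_dictionary_data(m=1000, n=10, theta=0.3)
print(Y.shape)                            # (1000, 10)
print(np.allclose(Q.T @ Q, np.eye(10)))   # True
```

The fraction of nonzero entries in `Z` should be close to \(\theta\), which is a quick way to check that the Bernoulli mask is drawn with the intended probability.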

Based on the \(\ell_4\)-norm maximization model proposed in [1,2], we can consider the following optimization problem,

\[\begin{split} \begin{aligned} \min_{X = [x_1, \ldots, x_n] \in \mathbb{R}^{n\times n}} \quad & f(X) := - \sum_{1\leq i\leq m,\, 1\leq j\leq n} (y_i^\top x_j)^4\\ \text{s.t.} \quad & X^\top X = I_n. \end{aligned} \end{split}\]

This problem is nonconvex because of the orthogonality constraint \(X^\top X = I_n\). The feasible set of this constraint is the Stiefel manifold (for square \(X\), it coincides with the orthogonal group), so the problem can be regarded as a smooth optimization problem over the Stiefel manifold.
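Since \((YX)_{ij} = y_i^\top x_j\), the objective is \(f(X) = -\sum_{i,j} (YX)_{ij}^4\), and differentiating entrywise gives the Euclidean gradient \(\nabla f(X) = -4\, Y^\top (YX)^{\circ 3}\), where \((\cdot)^{\circ 3}\) denotes the entrywise cube. The sketch below checks this closed form against JAX's automatic differentiation on small random data; here the objective takes `Y` as an explicit argument, unlike the closure used later in the tutorial.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def obj_fun(X, Y):
    # f(X) = -sum_{i,j} (y_i^T x_j)^4, where (Y @ X)[i, j] = y_i^T x_j
    return -jnp.sum((Y @ X) ** 4)

def grad_closed_form(X, Y):
    # Euclidean gradient -4 Y^T (Y X)^{o3}, with the cube taken entrywise
    return -4.0 * Y.T @ (Y @ X) ** 3

key_y, key_x = jax.random.split(jax.random.PRNGKey(0))
Y = jax.random.normal(key_y, (50, 5))
X = jax.random.normal(key_x, (5, 5))

g_auto = jax.grad(obj_fun)(X, Y)   # differentiates w.r.t. the first argument X
g_closed = grad_closed_form(X, Y)
print(jnp.allclose(g_auto, g_closed))  # True
```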

Importing modules

We first import all the necessary modules for this optimization problem.

import cdopt 
import numpy as np
import scipy as sp
from scipy.stats import norm
from scipy.sparse import csr_matrix
import time
import jax
import jax.numpy as jnp
import jax.random as random
# Use double precision in JAX (JAX defaults to 32-bit floats).
jax.config.update("jax_enable_x64", True)

Generating data

We then set the problem sizes and generate the data.

JAX automatically places arrays on a GPU or TPU when one is available and otherwise falls back to the CPU, as the warning printed below indicates.

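n = 10        # dimension of the problem
m = 10*n**2   # number of samples
theta = 0.3   # sparsity level (probability that an entry is nonzero)

# Bernoulli-Gaussian data: standard Gaussian entries, each kept with probability theta.
# Q is taken to be the identity here, so the rows of Y are the sparse codes z_i themselves.
Y = jnp.asarray(norm.ppf(np.random.rand(m,n)) * (np.random.rand(m,n) <= theta))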
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

Set functions and problems

Then we set the objective function and the Stiefel manifold.

def obj_fun(X):
    # f(X) = -sum_{i,j} (y_i^T x_j)^4, since (Y @ X)[i, j] = y_i^T x_j
    return -jnp.sum((Y @ X) ** 4)

M = cdopt.manifold_jax.stiefel_jax((n,n))   # The Stiefel manifold.
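To make the constraint concrete, the following NumPy sketch builds a point on the Stiefel manifold from a reduced QR factorization and measures how far a matrix is from satisfying \(X^\top X = I\). Both helpers are hypothetical illustrations: CDOpt's `M.Init_point()` and `M.Feas_eval()` play similar roles in this tutorial, but their exact definitions (for instance, which norm `Feas_eval` uses) may differ from the Frobenius-norm choice assumed here.

```python
import numpy as np

def random_stiefel_point(n, p, seed=0):
    # Orthonormal columns from the reduced QR factorization of a Gaussian matrix
    rng = np.random.default_rng(seed)
    X, _ = np.linalg.qr(rng.standard_normal((n, p)))
    return X

def stiefel_infeasibility(X):
    # Frobenius-norm violation ||X^T X - I_p||_F of the constraint X^T X = I_p
    p = X.shape[1]
    return np.linalg.norm(X.T @ X - np.eye(p), "fro")

X0 = random_stiefel_point(10, 10)
print(stiefel_infeasibility(X0) < 1e-12)       # True
print(stiefel_infeasibility(2.0 * X0) > 1.0)   # True
```

For the scaled point \(2X_0\), one has \((2X_0)^\top(2X_0) - I = 3I\), so the violation is \(3\sqrt{10}\).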

Describe the optimization problem

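The optimization problem can be described by the manifold and the objective function alone. All other components, such as gradients and Hessian-vector products, are computed automatically by the automatic differentiation provided by JAX. The enable_jit flag controls whether CDOpt JIT-compiles these functions; below we build the same problem with and without it for comparison.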

# Describe the optimization problem and set the penalty parameter beta (JIT enabled).
problem_test = cdopt.core.problem(M, obj_fun, beta = 500, enable_jit= True)
# The same problem without JIT compilation, used as a baseline for timing.
problem_nojit = cdopt.core.problem(M, obj_fun, beta = 500, enable_jit= False)
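The penalty parameter beta enters the constraint dissolving function that CDOpt minimizes in place of the constrained problem. As a minimal sketch, assume the form commonly used for the Stiefel manifold, \(h(X) = f(\mathcal{A}(X)) + \frac{\beta}{4}\|X^\top X - I\|_F^2\) with \(\mathcal{A}(X) = X\left(\tfrac{3}{2}I - \tfrac{1}{2}X^\top X\right)\). This is an assumption for illustration and has not been checked against CDOpt's internals; the factory `make_cdf` is hypothetical. On the manifold, \(\mathcal{A}(X) = X\) and the penalty vanishes, so \(h\) agrees with \(f\) there.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def make_cdf(f, beta):
    """Hypothetical constraint dissolving function for the Stiefel manifold.

    Assumed form: h(X) = f(A(X)) + beta/4 * ||X^T X - I||_F^2,
    with A(X) = X (3/2 I - 1/2 X^T X).
    """
    def h(X):
        p = X.shape[1]
        XtX = X.T @ X
        A = X @ (1.5 * jnp.eye(p) - 0.5 * XtX)
        penalty = 0.25 * beta * jnp.sum((XtX - jnp.eye(p)) ** 2)
        return f(A) + penalty
    return h

Y = jax.random.normal(jax.random.PRNGKey(0), (200, 6))
f = lambda X: -jnp.sum((Y @ X) ** 4)
h = make_cdf(f, beta=500.0)
h_grad = jax.jit(jax.grad(h))   # unconstrained gradient of h via autodiff, JIT-compiled

# On the manifold (X^T X = I): A(X) = X and the penalty is zero, so h = f.
X_feas, _ = jnp.linalg.qr(jax.random.normal(jax.random.PRNGKey(1), (6, 6)))
print(jnp.allclose(h(X_feas), f(X_feas)))  # True
```

Because \(h\) is defined on all of \(\mathbb{R}^{n\times n}\), any unconstrained solver (such as the SciPy L-BFGS-B call later in this tutorial) can be applied to it.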

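We first compare the computation time of the gradient of the constraint dissolving function. In the run below, JIT compilation reduces the mean time per gradient evaluation from about 265 µs to about 114 µs, roughly a 2.3× speedup (these timings come from a CPU-only run and will vary with hardware).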

X_test = M.Init_point()
%timeit -n 100 -r 3 problem_test.cdf_grad(X_test)
114 µs ± 37.9 µs per loop (mean ± std. dev. of 3 runs, 100 loops each)
%timeit -n 100 -r 3 problem_nojit.cdf_grad(X_test)
265 µs ± 67.1 µs per loop (mean ± std. dev. of 3 runs, 100 loops each)
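The same effect can be reproduced in plain JAX, independently of CDOpt. This sketch compares `jax.grad` of the objective run without whole-function compilation against `jax.jit(jax.grad(...))`; we assume, but have not verified, that CDOpt's `enable_jit=True` relies on `jax.jit` in a similar way. The warm-up call excludes the one-off compilation cost, and `block_until_ready()` is needed because JAX dispatches work asynchronously.

```python
import time
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

n, m = 10, 1000
Y = jax.random.normal(jax.random.PRNGKey(0), (m, n))

def obj_fun(X):
    return -jnp.sum((Y @ X) ** 4)

grad_eager = jax.grad(obj_fun)          # runs op-by-op, no whole-function compilation
grad_jit = jax.jit(jax.grad(obj_fun))   # compiled once by XLA, then reused

def time_per_call(fn, X, repeats=100):
    fn(X).block_until_ready()           # warm-up: triggers compilation for grad_jit
    start = time.perf_counter()
    for _ in range(repeats):
        fn(X).block_until_ready()       # wait for each asynchronous result
    return (time.perf_counter() - start) / repeats

X = jnp.eye(n)
print(f"eager: {time_per_call(grad_eager, X) * 1e6:.1f} us per call")
print(f"jit:   {time_per_call(grad_jit, X) * 1e6:.1f} us per call")
```

Both versions return the same gradient; only the per-call overhead differs, and the measured numbers depend on the machine.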

Apply optimization solvers

After describing the optimization problem, we can directly obtain the function value, gradient, and Hessian-vector product of the constraint dissolving function from the cdopt.core.problem class.

# The vectorized function value, gradient and Hessian-vector product of the constraint dissolving function. Their inputs are NumPy 1D arrays, and their outputs are floats or NumPy 1D arrays.
cdf_fun_np = problem_test.cdf_fun_vec_np   
cdf_grad_np = problem_test.cdf_grad_vec_np 
cdf_hvp_np = problem_test.cdf_hvp_vec_np


# Apply the limited-memory BFGS solver (L-BFGS-B) from scipy.optimize.minimize
from scipy.optimize import minimize
Xinit = problem_test.Xinit_vec_np  # set initial point
# Optimize by the L-BFGS method. ftol=0 disables the relative-decrease stopping test,
# so the solver stops mainly on the gradient tolerance gtol.
t_start = time.time()
out_msg = minimize(cdf_fun_np, Xinit, method='L-BFGS-B', jac=cdf_grad_np,
                   options={'maxcor': 10, 'ftol': 0, 'gtol': 1e-06})
t_end = time.time() - t_start

# Statistics
feas = M.Feas_eval(M.v2m(M.array2tensor(out_msg.x)))   # Feasibility
stationarity = np.linalg.norm(out_msg['jac'],2)   # stationarity

result_lbfgs = [out_msg['fun'], out_msg['nit'], out_msg['nfev'],stationarity,feas, t_end]

# print results
print('Solver   fval         iter   f_eval   stationarity   feasibility    time (s)')
print('& L-BFGS & {:.2e}  & {:}  & {:}    & {:.2e}     & {:.2e}     & {:.2f} \\\\'.format(*result_lbfgs))
Solver   fval         iter   f_eval   stationarity   feasibility    time (s)
& L-BFGS & -1.88e+04  & 63  & 70    & 2.61e-04     & 1.67e-08     & 0.05 \\
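To judge whether the solution actually recovers the dictionary, one can compare it with \(Q\): if \(X = QP\) for a signed permutation matrix \(P\), every column of \(Q^\top X\) has exactly one entry of magnitude 1. The sketch below uses a hypothetical metric, one minus the largest magnitude in each column of \(Q^\top X\) averaged over columns, which is zero for orthogonal \(X\) exactly when \(X\) is a signed permutation of \(Q\). For the data generated in this tutorial, \(Q = I_n\), so the metric could be applied to the solution `out_msg.x` after reshaping it to an \(n \times n\) NumPy array.

```python
import numpy as np

def signed_permutation_error(X, Q):
    """Hypothetical recovery metric: 0 iff orthogonal X equals Q times a signed permutation."""
    M = np.abs(Q.T @ X)
    # Each column of Q^T X is a unit vector; its largest magnitude is 1 only for +/- e_k.
    return float(np.mean(1.0 - M.max(axis=0)))

rng = np.random.default_rng(0)
n = 10
Q, _ = np.linalg.qr(rng.standard_normal((n, n)))
# Signed permutation: permute the rows of I, then flip column signs at random.
P = np.eye(n)[rng.permutation(n)] * rng.choice([-1.0, 1.0], size=n)
print(signed_permutation_error(Q @ P, Q) < 1e-12)   # True
```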

References

  1. Zhai Y, Yang Z, Liao Z, et al. Complete Dictionary Learning via L4-Norm Maximization over the Orthogonal Group[J]. J. Mach. Learn. Res., 2020, 21(165): 1-68.

  2. Hu X, Liu X. An efficient orthonormalization-free approach for sparse dictionary learning and dual principal component pursuit[J]. Sensors, 2020, 20(11): 3041.