Deep Learning: Deep Compressed Sensing - From ISTA to LISTA, with a PyTorch Implementation

Abstract: Traditional compressed sensing methods are usually slow at reconstruction. Combining deep learning with compressed sensing can greatly speed up the reconstruction. The Learned Iterative Shrinkage and Thresholding Algorithm (LISTA) is arguably the earliest deep-learning approach to solving compressed sensing problems. This post briefly summarizes LISTA and gives concrete implementations of ISTA (in Python) and LISTA (in PyTorch), then uses both to reconstruct a simulated sparse signal.

References

[1] A. Beck and M. Teboulle, "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse Problems".

[2] K. Gregor and Y. LeCun, "Learning Fast Approximations of Sparse Coding".

[3] Proximal gradient descent, with three simple examples: soft thresholding, hard thresholding, and ReLU.

Contents

  1. Reading Notes on Learning Fast Approximations of Sparse Coding
  2. ISTA: Python Implementation
  3. LISTA: PyTorch Implementation
  4. A Sparse Reconstruction Example

1. Reading Notes on Learning Fast Approximations of Sparse Coding

Abstract: In sparse coding, an input vector is reconstructed as a linear combination of sparse basis vectors. Sparse coding has become a popular way to extract features from data. For a given input, sparse coding minimizes a quadratic reconstruction error together with an L1 penalty on the code. This process is often too slow for practical applications such as real-time pattern recognition. The paper proposes two fast algorithms that produce approximate sparse codes, which can be used for feature extraction or to initialize exact iterative algorithms. The main idea is to train a non-linear, feed-forward predictor with a specific architecture and fixed depth to approximate the optimal sparse code. This post only covers LISTA, not the other method.

Method: The ISTA pseudocode is shown below, where $X$ is the measurement, $W_d$ is the dictionary matrix, $Z$ is the sparse code, $\alpha$ is the sparsity regularization coefficient, and $L$ is the Lipschitz constant.
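With this notation, each ISTA iteration takes a gradient step on the data term followed by soft-thresholding; reconstructed here from the Python code in Section 2, the update is

$Z_{k+1} = h_{\alpha/L}\left(Z_k - \frac{1}{L} W_d^T (W_d Z_k - X)\right)$, where $h_\theta(x) = \mathrm{sign}(x)\,\max(|x| - \theta, 0)$.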

The algorithm can be derived via proximal gradient descent [3]. Adding momentum to ISTA readily extends it to FISTA [1].
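For reference, the momentum step used by FISTA [1] is $t_{k+1} = \frac{1 + \sqrt{1 + 4t_k^2}}{2}$ and $Y_{k+1} = Z_k + \frac{t_k - 1}{t_{k+1}}(Z_k - Z_{k-1})$ with $t_1 = 1$; the shrinkage update is then evaluated at $Y_{k+1}$ instead of $Z_k$.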
[Figure: ISTA pseudocode]
The block diagrams of ISTA and LISTA are shown below. The top diagram is ISTA, using the same symbols as the pseudocode; the bottom diagram is LISTA, which consists entirely of fully connected layers, uses ISTA's shrinkage function as the activation, and learns both $W$ and $S$ through training.
[Figure: block diagrams of ISTA (top) and LISTA (bottom)]
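In equations, the LISTA iteration implemented below is $Z_{k+1} = h_\theta(W X + S Z_k)$. At initialization, $W = \frac{1}{L} W_d^T$ and $S = I - \frac{1}{L} W_d^T W_d$ (this matches weights_init in the code in Section 3), after which both $W$ and $S$ are learned from data.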
The code below makes ISTA and LISTA more concrete.

2. ISTA: Python Implementation

The notation matches the pseudocode above.

Below we define the shrinkage function and the ISTA iteration; the iteration returns the sparse solution together with the reconstruction errors. The numerical example is given at the end, together with LISTA.

import numpy as np

def shrinkage(x, theta):
    # soft-thresholding: sign(x) * max(|x| - theta, 0)
    return np.multiply(np.sign(x), np.maximum(np.abs(x) - theta, 0))

def ista(X, W_d, a, L, max_iter, eps):
    # X and W_d are expected as np.matrix, so * below is matrix multiplication

    # L must upper-bound the largest eigenvalue of W_d^T W_d
    eig = np.linalg.eigvalsh(W_d.T * W_d)
    assert L > np.max(eig)
    
    W_e = W_d.T / L

    recon_errors = []
    Z_old = np.zeros((W_d.shape[1], 1))
    for i in range(max_iter):
        # gradient step on the data-fidelity term, then soft-thresholding
        temp = W_d * Z_old - X
        Z_new = shrinkage(Z_old - W_e * temp, a / L)
        if np.sum(np.abs(Z_new - Z_old)) <= eps: break
        Z_old = Z_new
        recon_error = np.linalg.norm(X - W_d * Z_new, 2) ** 2
        recon_errors.append(recon_error)
        
    return Z_new, recon_errors
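As a quick sanity check (a small example added here, not part of the algorithm), the soft-thresholding operator zeroes entries below the threshold and shrinks the rest toward zero:

x_demo = np.array([-2.0, -0.5, 0.0, 0.3, 1.5])
print(shrinkage(x_demo, 1.0))   # expected: [-1. -0.  0.  0.  0.5]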

3. LISTA: PyTorch Implementation

The implementation follows the block diagram above. The network has only two learnable parameters, W and S, but applies them repeatedly within a single forward pass; the number of inner iterations is set by max_iter in the code.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class LISTA(nn.Module):
    def __init__(self, n, m, W_e, max_iter, L, theta):
        """
        # Arguments
            n: int, dimensions of the measurement
            m: int, dimensions of the sparse signal
            W_e: array, dictionary
            max_iter: int, max number of internal iterations
            L: float, Lipschitz constant
            theta: float, shrinkage threshold
        """
        
        super(LISTA, self).__init__()
        self._W = nn.Linear(in_features=n, out_features=m, bias=False)
        self._S = nn.Linear(in_features=m, out_features=m, bias=False)
        self.shrinkage = nn.Softshrink(theta)
        self.theta = theta
        self.max_iter = max_iter
        self.A = W_e
        self.L = L
        
    # weights initialization based on the dictionary
    def weights_init(self):
        A = self.A.cpu().numpy()
        L = self.L
        S = torch.from_numpy(np.eye(A.shape[1]) - (1/L)*np.matmul(A.T, A))
        S = S.float().to(device)
        W = torch.from_numpy((1/L)*A.T)
        W = W.float().to(device)
        
        self._S.weight = nn.Parameter(S)
        self._W.weight = nn.Parameter(W)


    def forward(self, y):
        # initial estimate: x = h_theta(W y)
        x = self.shrinkage(self._W(y))

        if self.max_iter == 1:
            return x

        # recurrent updates: x = h_theta(W y + S x)
        for _ in range(self.max_iter):
            x = self.shrinkage(self._W(y) + self._S(x))

        return x

def train_lista(Y, dictionary, a, L, max_iter=30):
    
    n, m = dictionary.shape
    n_samples = Y.shape[0]
    batch_size = 128
    steps_per_epoch = n_samples // batch_size
    
    # convert the data into tensors
    Y = torch.from_numpy(Y)
    Y = Y.float().to(device)
    
    W_d = torch.from_numpy(dictionary)
    W_d = W_d.float().to(device)

    net = LISTA(n, m, W_d, max_iter=max_iter, L=L, theta=a/L)
    net = net.float().to(device)
    net.weights_init()

    # build the optimizer and criterion
    learning_rate = 1e-2
    criterion1 = nn.MSELoss()
    criterion2 = nn.L1Loss()
    all_zeros = torch.zeros(batch_size, m).to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate,  momentum=0.9)

    loss_list = []
    for epoch in range(100):
        index_samples = np.random.choice(a=n_samples, size=n_samples, replace=False, p=None)
        Y_shuffle = Y[index_samples]
        for step in range(steps_per_epoch):
            Y_batch = Y_shuffle[step*batch_size:(step+1)*batch_size]
            optimizer.zero_grad()
    
            # get the outputs
            X_h = net(Y_batch)
            Y_h = torch.mm(X_h, W_d.T)
    
            # compute the loss: reconstruction MSE plus an L1 sparsity penalty
            loss1 = criterion1(Y_batch.float(), Y_h.float()) 
            loss2 = a * criterion2(X_h.float(), all_zeros.float())
            loss = loss1 + loss2
            
            loss.backward()
            optimizer.step()  
    
            loss_list.append(loss.item())
            
    return net, loss_list
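Before moving on to the full example, here is a minimal shape check of the LISTA module (the dimensions and names below are arbitrary, purely for illustration):

# hypothetical small example: n = 4 measurements, m = 8 sparse coefficients
W_d_demo = torch.randn(4, 8).to(device)
demo = LISTA(n=4, m=8, W_e=W_d_demo, max_iter=3, L=2.0, theta=0.05).to(device)
demo.weights_init()
y_demo = torch.randn(10, 4).to(device)   # a batch of 10 measurements
print(demo(y_demo).shape)                # expected: torch.Size([10, 8])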

4. A Sparse Reconstruction Example

For LISTA, an extra training stage is required, so we train the LISTA network first.

The sparse signal has dimension 1000, the measurement has dimension 256, the sparsity level is 5, and 5000 training samples are used.

import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import orth

# dimensions of the sparse signal, measurement and sparsity
m, n, k = 1000, 256, 5
# number of training examples
N = 5000

# generate dictionary
Psi = np.eye(m)
Phi = np.random.randn(n, m)
Phi = np.transpose(orth(np.transpose(Phi)))
W_d = np.dot(Phi, Psi)
print(W_d.shape)

# generate sparse signal Z and measurement X
Z = np.zeros((N, m))
X = np.zeros((N, n))
for i in range(N):
    index_k = np.random.choice(a=m, size=k, replace=False, p=None)
    Z[i, index_k] = 5 * np.random.randn(k, 1).reshape([-1,])
    X[i] = np.dot(W_d, Z[i, :])

print(X.shape)
print(X[0].shape)

# train the LISTA network
net, err_list = train_lista(X, W_d, 0.1, 2)
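The returned err_list records the loss at every optimization step; a minimal sketch for checking convergence:

# plot the per-step training loss returned by train_lista
plt.figure()
plt.plot(err_list)
plt.xlabel('step')
plt.ylabel('training loss')
plt.show()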

Training yields a network that reconstructs the sparse solution; we now compare it with ISTA on a test signal.

# Test stage
# generate sparse signal Z and measurement X
Z = np.zeros((1, m))
X = np.zeros((1, n))
for i in range(1):
    index_k = np.random.choice(a=m, size=k, replace=False, p=None)
    Z[i, index_k] = 5 * np.random.randn(k, 1).reshape([-1,])
    X[i] = np.dot(W_d, Z[i, :])

Z_recon = net(torch.from_numpy(X).float().to(device))
Z_recon = Z_recon.detach().cpu().numpy()
plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(X[0])
plt.subplot(2, 1, 2)
plt.plot(Z[0], label='real')
plt.plot(Z_recon[0], '.-', label='LISTA')

# ISTA reconstruction of the same test signal
Z_recon, recon_errors = ista(np.mat(X).T, np.mat(W_d), 0.1, 2, 1000, 0.00001)
plt.subplot(2, 1, 2)
plt.plot(Z_recon, '--', label='ISTA')
plt.legend()
plt.show()
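For a quantitative comparison, a reconstruction SNR can also be reported; this sketch uses the standard definition $10\log_{10}(\|Z\|^2 / \|Z - \hat{Z}\|^2)$ (the helper recon_snr is added here for illustration):

# reconstruction SNR in dB: 10 * log10(||Z||^2 / ||Z - Z_hat||^2)
def recon_snr(z_true, z_hat):
    err = np.linalg.norm(z_true - z_hat) ** 2
    return 10 * np.log10(np.linalg.norm(z_true) ** 2 / err)

print('ISTA reconstruction SNR: %.2f dB' % recon_snr(Z[0], np.asarray(Z_recon).ravel()))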

The results are shown below; both methods recover the sparse solution well. The top panel shows the measurement signal, and the bottom panel shows the sparse signal.
[Figure: measurement signal (top); true sparse signal with LISTA and ISTA reconstructions (bottom)]
