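Summary: Traditional compressed sensing methods are usually slow at reconstruction. Combining deep learning with compressed sensing can speed up reconstruction considerably. The Learned Iterative Shrinkage and Thresholding Algorithm (LISTA) is probably the earliest method to use deep learning for compressed sensing. This post briefly summarizes LISTA, gives concrete implementations of ISTA in Python (NumPy) and of LISTA in PyTorch, and uses both to reconstruct simulated sparse signals.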
References
[1] A Fast Iterative Shrinkage-Thresholding Algorithm
[2] Learning Fast Approximations of Sparse Coding
[3] Proximal gradient descent and three simple examples: soft thresholding, hard thresholding, and ReLU.
Contents
- Learning Fast Approximations of Sparse Coding: paper notes
- ISTA implementation in Python
- LISTA implementation in PyTorch
- A sparse reconstruction example
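1. Learning Fast Approximations of Sparse Coding: paper notes
Abstract: In sparse coding, an input vector is reconstructed as a linear combination of a small number of basis vectors. Sparse coding has become a popular way to extract features from data. For a given input vector, sparse coding minimizes a quadratic reconstruction error plus an L1 penalty on the code. This process is often too slow for practical applications such as real-time pattern recognition. The paper proposes two fast algorithms that produce approximate sparse codes, which can be used for feature extraction or to initialize an exact iterative algorithm. The main idea is to train a nonlinear feed-forward predictor with a specific architecture and a fixed depth so that it gives the best possible approximation of the sparse code. Only LISTA is covered here; the paper's other method is not.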
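Method: The ISTA pseudocode uses the following notation: X is the measurement, W_d is the dictionary matrix, Z is the sparse code, α (written a in the code) is the sparsity penalty coefficient, and L is the Lipschitz constant.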
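For reference, the objective and the iteration behind the pseudocode can be written out as below. This is a reconstruction from the notation above and from the Python implementation later in this post, not a copy of the original figure; h_θ is my shorthand for the element-wise shrinkage (soft-thresholding) function, and L is assumed to exceed the largest eigenvalue of W_d^T W_d.

```latex
\begin{aligned}
E(Z) &= \tfrac{1}{2}\,\lVert X - W_d Z \rVert_2^2 + \alpha\,\lVert Z \rVert_1, \\
Z^{(k+1)} &= h_{\alpha/L}\!\left( Z^{(k)} - \tfrac{1}{L}\, W_d^{\top}\bigl( W_d Z^{(k)} - X \bigr) \right),
\qquad Z^{(0)} = 0, \\
\bigl[ h_{\theta}(V) \bigr]_i &= \operatorname{sign}(V_i)\,\max\bigl( \lvert V_i \rvert - \theta,\; 0 \bigr).
\end{aligned}
```

Each iteration is a gradient step on the quadratic term with step size 1/L, followed by shrinkage with threshold α/L.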
The algorithm can be derived with the proximal gradient descent method [3]. Adding momentum to ISTA extends it to FISTA [1] with little extra work.
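To make the momentum remark concrete, here is a minimal NumPy sketch of FISTA. It is my own illustration, not code from [1]: the names `soft_threshold` and `fista` are hypothetical, the momentum schedule is the standard t-sequence, and the random test problem at the end is made up for the demo.

```python
import numpy as np

def soft_threshold(v, theta):
    """Element-wise shrinkage: sign(v) * max(|v| - theta, 0)."""
    return np.sign(v) * np.maximum(np.abs(v) - theta, 0.0)

def fista(X, W_d, alpha, L, max_iter=500):
    """FISTA sketch for 0.5 * ||X - W_d Z||^2 + alpha * ||Z||_1."""
    Z = np.zeros(W_d.shape[1])
    Y = Z.copy()   # extrapolated point
    t = 1.0        # momentum coefficient
    for _ in range(max_iter):
        # Ordinary ISTA step, but taken from the extrapolated point Y
        Z_next = soft_threshold(Y - W_d.T @ (W_d @ Y - X) / L, alpha / L)
        t_next = (1.0 + np.sqrt(1.0 + 4.0 * t * t)) / 2.0
        # Momentum: extrapolate along the direction of the last update
        Y = Z_next + ((t - 1.0) / t_next) * (Z_next - Z)
        Z, t = Z_next, t_next
    return Z

def objective(X, W_d, Z, alpha):
    return 0.5 * np.sum((X - W_d @ Z) ** 2) + alpha * np.sum(np.abs(Z))

# Small made-up problem: 20 measurements, 50 unknowns, 3 nonzeros
rng = np.random.default_rng(0)
W_d = rng.standard_normal((20, 50))
Z_true = np.zeros(50)
Z_true[[3, 17, 42]] = [5.0, -3.0, 2.0]
X = W_d @ Z_true
L = 1.01 * np.linalg.norm(W_d, 2) ** 2   # just above the largest eigenvalue of W_d^T W_d
Z_hat = fista(X, W_d, 0.1, L)
obj_zero = objective(X, W_d, np.zeros(50), 0.1)
obj_hat = objective(X, W_d, Z_hat, 0.1)
print(obj_zero, obj_hat)
```

The only difference from ISTA is that the gradient step starts from `Y` rather than from the previous iterate; setting `Y = Z_next` in every iteration gives plain ISTA back.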
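The block diagrams of ISTA and LISTA are as follows. The upper diagram is ISTA, with the same symbols as in the pseudocode. The lower diagram is LISTA: it consists entirely of fully connected layers, uses ISTA's shrinkage function as the activation, and its matrices W_e and S are learned through training.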
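One way to see how the two diagrams relate: if the LISTA matrices are set to W_e = W_d^T / L and S = I - W_d^T W_d / L instead of being learned, the LISTA recursion produces the same iterates as ISTA started from zero. The sketch below checks this numerically; the function names and the small random problem are mine and only serve the illustration.

```python
import numpy as np

def shrinkage(v, theta):
    return np.sign(v) * np.maximum(np.abs(v) - theta, 0.0)

def ista_steps(X, W_d, alpha, L, n_steps):
    """n_steps ISTA iterations starting from Z = 0."""
    Z = np.zeros(W_d.shape[1])
    for _ in range(n_steps):
        Z = shrinkage(Z - W_d.T @ (W_d @ Z - X) / L, alpha / L)
    return Z

def lista_forward(X, W_e, S, theta, n_steps):
    """LISTA-style recursion: Z <- h(W_e X + S Z), with n_steps shrinkage layers."""
    B = W_e @ X
    Z = shrinkage(B, theta)
    for _ in range(n_steps - 1):
        Z = shrinkage(B + S @ Z, theta)
    return Z

rng = np.random.default_rng(0)
n, m, alpha = 8, 20, 0.1
W_d = rng.standard_normal((n, m))
X = rng.standard_normal(n)
L = 1.01 * np.linalg.norm(W_d, 2) ** 2

# ISTA-derived values for the two matrices (no training involved)
W_e = W_d.T / L
S = np.eye(m) - W_d.T @ W_d / L

Z_ista = ista_steps(X, W_d, alpha, L, 10)
Z_lista = lista_forward(X, W_e, S, alpha / L, 10)
print(np.max(np.abs(Z_ista - Z_lista)))
```

Training then moves W_e and S away from these values so that a fixed, small number of layers gets closer to the sparse code than the same number of ISTA iterations would.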
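The code below makes ISTA and LISTA more concrete.
2. ISTA implementation in Python
The symbols are the same as in the pseudocode above.
The code defines the shrinkage function and the ISTA iteration, which returns the sparse solution together with the reconstruction errors. The worked example is given at the end, together with the one for LISTA.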
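import numpy as np

def shrinkage(x, theta):
    """Soft-thresholding: sign(x) * max(|x| - theta, 0), element-wise."""
    return np.multiply(np.sign(x), np.maximum(np.abs(x) - theta, 0))

def ista(X, W_d, a, L, max_iter, eps):
    """Run ISTA; return the sparse code and the per-iteration reconstruction errors."""
    # Work on plain arrays so that `@` is always a matrix product
    X = np.asarray(X, dtype=float).reshape(-1, 1)
    W_d = np.asarray(W_d, dtype=float)
    # L must exceed the largest eigenvalue of W_d^T W_d (symmetric, so eigvalsh applies)
    assert L > np.max(np.linalg.eigvalsh(W_d.T @ W_d))
    W_e = W_d.T / L
    recon_errors = []
    Z_old = np.zeros((W_d.shape[1], 1))
    Z_new = Z_old
    for i in range(max_iter):
        temp = W_d @ Z_old - X
        Z_new = shrinkage(Z_old - W_e @ temp, a / L)
        # Stop once an iteration no longer changes the code
        if np.sum(np.abs(Z_new - Z_old)) <= eps:
            break
        Z_old = Z_new
        recon_error = np.linalg.norm(X - W_d @ Z_new, 2) ** 2
        recon_errors.append(recon_error)
    return Z_new, recon_errors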
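As a quick sanity check on the shrinkage function: it should return the minimizer of 0.5 * (z - x)^2 + theta * |z|, which is what the proximal-gradient view in [3] relies on. The sketch below compares it against a brute-force grid search for a few scalar inputs; the grid range and resolution are arbitrary choices of mine.

```python
import numpy as np

def shrinkage(x, theta):
    return np.multiply(np.sign(x), np.maximum(np.abs(x) - theta, 0))

theta = 1.0
grid = np.linspace(-5.0, 5.0, 100001)   # candidate values of z, spacing 1e-4

for x in [-3.0, -0.5, 0.0, 0.5, 3.0]:
    # Scalar proximal objective evaluated on the whole grid
    values = 0.5 * (grid - x) ** 2 + theta * np.abs(grid)
    z_grid = grid[np.argmin(values)]
    z_closed = shrinkage(x, theta)
    print(f"x={x:+.1f}  grid minimizer={z_grid:+.4f}  shrinkage={z_closed:+.4f}")
```

Inputs with magnitude below theta are set to zero, and larger ones are pulled toward zero by theta, which is where the sparsity of the ISTA solution comes from.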
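3. LISTA implementation in PyTorch
The implementation follows the block diagram above. The network has only two learnable parameters, W and S, but they are applied repeatedly within a single forward pass; the number of repetitions is set by max_iter in the code.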
import numpy as np
import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class LISTA(nn.Module):
    def __init__(self, n, m, W_e, max_iter, L, theta):
        """
        # Arguments
            n: int, dimension of the measurement
            m: int, dimension of the sparse signal
            W_e: tensor, dictionary of shape (n, m)
            max_iter: int, number of unrolled iterations
            L: float, Lipschitz constant
            theta: float, shrinkage threshold (fixed, not learned)
        """
        super(LISTA, self).__init__()
        self._W = nn.Linear(in_features=n, out_features=m, bias=False)
        self._S = nn.Linear(in_features=m, out_features=m, bias=False)
        self.shrinkage = nn.Softshrink(theta)
        self.theta = theta
        self.max_iter = max_iter
        self.A = W_e
        self.L = L

    # weights initialization based on the dictionary
    def weights_init(self):
        A = self.A.cpu().numpy()
        L = self.L
        # ISTA values: S = I - A^T A / L and W = A^T / L
        S = torch.from_numpy(np.eye(A.shape[1]) - (1/L)*np.matmul(A.T, A))
        S = S.float().to(device)
        W = torch.from_numpy((1/L)*A.T)
        W = W.float().to(device)
        self._S.weight = nn.Parameter(S)
        self._W.weight = nn.Parameter(W)

    def forward(self, y):
        b = self._W(y)   # W y is shared by every iteration
        x = self.shrinkage(b)
        # max_iter shrinkage layers in total, consistent with the max_iter == 1 case
        for _ in range(self.max_iter - 1):
            x = self.shrinkage(b + self._S(x))
        return x

def train_lista(Y, dictionary, a, L, max_iter=30):
    n, m = dictionary.shape
    n_samples = Y.shape[0]
    batch_size = 128
    steps_per_epoch = n_samples // batch_size

    # convert the data into tensors
    Y = torch.from_numpy(Y).float().to(device)
    W_d = torch.from_numpy(dictionary).float().to(device)

    # pass max_iter through instead of hard-coding 30
    net = LISTA(n, m, W_d, max_iter=max_iter, L=L, theta=a/L)
    net = net.float().to(device)
    net.weights_init()

    # build the optimizer and criterion
    learning_rate = 1e-2
    criterion1 = nn.MSELoss()
    criterion2 = nn.L1Loss()
    all_zeros = torch.zeros(batch_size, m).to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)

    loss_list = []
    for epoch in range(100):
        # reshuffle the samples at the start of every epoch
        index_samples = np.random.choice(a=n_samples, size=n_samples, replace=False, p=None)
        Y_shuffle = Y[index_samples]
        for step in range(steps_per_epoch):
            Y_batch = Y_shuffle[step*batch_size:(step+1)*batch_size]
            optimizer.zero_grad()

            # get the outputs
            X_h = net(Y_batch)
            Y_h = torch.mm(X_h, W_d.T)

            # compute the loss: reconstruction error plus L1 penalty on the code
            loss1 = criterion1(Y_h, Y_batch)
            loss2 = a * criterion2(X_h, all_zeros)
            loss = loss1 + loss2

            loss.backward()
            optimizer.step()

            # store a plain Python float so no graph is kept alive
            loss_list.append(loss.item())
    return net, loss_list
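Written out, the loss that train_lista minimizes for one mini-batch is the following. The formula is my reading of the code above, assuming the default mean reduction of nn.MSELoss and nn.L1Loss; B is the batch size (128), n and m are the measurement and code dimensions, and \hat{x}_b is the network output for measurement y_b.

```latex
\mathcal{L}(W, S)
  = \frac{1}{B\,n} \sum_{b=1}^{B} \bigl\lVert y_b - W_d\, \hat{x}_b \bigr\rVert_2^2
  \;+\; \frac{a}{B\,m} \sum_{b=1}^{B} \bigl\lVert \hat{x}_b \bigr\rVert_1,
\qquad
\hat{x}_b = \mathrm{LISTA}_{W,S}(y_b).
```

No ground-truth sparse codes appear in this loss: the network is trained only to reproduce the measurements through the dictionary while keeping its output sparse. Because both terms are element-wise means, the effective weight of the L1 term relative to the quadratic term differs from the ISTA objective by a factor that depends on n and m.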
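4. A sparse reconstruction example
LISTA needs an extra training step compared with ISTA.
LISTA is trained first.
The sparse signal has dimension 1000, the measurement has dimension 256, the sparsity is 5, and 5000 training samples are used.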
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import orth

# dimensions of the sparse signal, measurement and sparsity
m, n, k = 1000, 256, 5
# number of training examples
N = 5000

# generate dictionary (rows of Phi are orthonormalized)
Psi = np.eye(m)
Phi = np.random.randn(n, m)
Phi = np.transpose(orth(np.transpose(Phi)))
W_d = np.dot(Phi, Psi)
print(W_d.shape)

# generate sparse signal Z and measurement X
Z = np.zeros((N, m))
X = np.zeros((N, n))
for i in range(N):
    index_k = np.random.choice(a=m, size=k, replace=False, p=None)
    Z[i, index_k] = 5 * np.random.randn(k)
    X[i] = np.dot(W_d, Z[i, :])
print(X.shape)
print(X[0].shape)

# train LISTA with a = 0.1 and L = 2
net, err_list = train_lista(X, W_d, 0.1, 2)
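The choice L = 2 works here because of how the dictionary is built: after orthonormalization the rows of W_d are orthonormal, so the largest eigenvalue of W_d^T W_d is 1. The sketch below checks this on a freshly generated dictionary. Note one assumption: it uses np.linalg.qr in place of scipy.linalg.orth to stay NumPy-only. Both return an orthonormal basis of the same subspace, although the bases themselves differ.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 1000, 256

Phi = rng.standard_normal((n, m))
# Assumption: QR stands in for scipy.linalg.orth (same subspace, different basis)
Q, _ = np.linalg.qr(Phi.T)   # Q has shape (m, n) with orthonormal columns
W_d = Q.T                    # so the rows of W_d are orthonormal

gram_rows = W_d @ W_d.T                     # should be the n x n identity
largest_eig = np.linalg.norm(W_d, 2) ** 2   # largest eigenvalue of W_d^T W_d
print(np.allclose(gram_rows, np.eye(n)), largest_eig)
```

Any L above that eigenvalue satisfies the assertion inside ista; a value closer to it gives a larger step size 1/L.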
Training yields a network that reconstructs sparse solutions. Below it is compared with ISTA on a test signal.
# Test stage
# generate sparse signal Z and measurement X
Z = np.zeros((1, m))
X = np.zeros((1, n))
for i in range(1):
    index_k = np.random.choice(a=m, size=k, replace=False, p=None)
    Z[i, index_k] = 5 * np.random.randn(k)
    X[i] = np.dot(W_d, Z[i, :])

# LISTA: a single forward pass, no gradients needed
with torch.no_grad():
    Z_recon = net(torch.from_numpy(X).float().to(device))
Z_recon = Z_recon.cpu().numpy()

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(X[0])
plt.subplot(2, 1, 2)
plt.plot(Z[0], label='real')
plt.plot(Z_recon[0], '.-', label='LISTA')

# ISTA: the measurement is passed as a column vector
# (np.asmatrix replaces np.mat, which was removed in NumPy 2.0)
Z_ista, recon_errors = ista(np.asmatrix(X).T, np.asmatrix(W_d), 0.1, 2, 1000, 0.00001)
plt.plot(np.asarray(Z_ista).ravel(), '--', label='ISTA')
plt.legend()
plt.show()
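The result is shown in the figure below; in this example both methods recover the sparse solution well. The upper subplot is the observed signal and the lower subplot is the sparse signal.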
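To put a number on "recovers well" rather than judging by eye, one option is the reconstruction SNR in decibels. The helper below is a hypothetical addition of mine, and the two vectors are made-up stand-ins for a true sparse signal and its reconstruction, not results from the experiment above.

```python
import numpy as np

def recon_snr_db(z_true, z_hat):
    """Reconstruction SNR in dB: 10 * log10(||z||^2 / ||z - z_hat||^2)."""
    z_true = np.asarray(z_true, dtype=float).ravel()
    z_hat = np.asarray(z_hat, dtype=float).ravel()
    noise = np.sum((z_true - z_hat) ** 2)   # must be nonzero
    return 10.0 * np.log10(np.sum(z_true ** 2) / noise)

# Made-up example: a 3-sparse signal and an estimate that is off by 0.01 everywhere
z_true = np.zeros(1000)
z_true[[10, 200, 999]] = [5.0, -3.0, 2.0]
z_hat = z_true + 0.01

snr = recon_snr_db(z_true, z_hat)
print(f"{snr:.2f} dB")
```

In the experiment, the same function could be applied to Z[0] and each of the two reconstructions, and averaged over many test signals to compare ISTA and LISTA more fairly than a single plot allows.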