雙向聯想記憶網絡的原理與python簡單實現

原理

聯想記憶功能分爲自聯想和異聯想,異聯想也稱爲雙向聯想記憶,簡寫爲BAM。BAM存儲器可存儲兩組矢量,若有如下N維矢量A與P維矢量B:
A=[a0,a1,,aN1]{1,1}NA=[a_0,a_1,\dots,a_{N-1}]\in \{-1,1\}^N
B=[b0,b1,,bP1]{1,1}PB=[b_0,b_1,\dots,b_{P-1}]\in \{-1,1\}^P
構成M對矢量(As,Bs)(A_s, B_s), s = 0,1,…, M-1,將它們存入BAM存儲器即可進行由A到B或由B到A的雙向聯想,即給定A(或B)可經聯想得到對應的標準樣本B(或A),當有噪聲或缺損時,聯想功能可使樣本對復原。

人工神經網絡力圖實現這種功能。Kosko的BAM網絡就是其中的一種。如圖所示,與矢量A相應的一層有N個節點,另一層對應矢量B,有P個節點,兩層間雙向連接。假定B到A的傳輸爲正向,正向的權矩陣爲W,反之,A到B爲反向傳輸,權矩陣爲WTW^T

如果輸入矢量由上層加入,且相應於網絡中B的穩定狀態,則經W之作用產生A穩定狀態。同理,如果輸入矢量在下層,且相應於網絡中A的穩定狀態,經WTW^T之作用產生B穩定狀態。

當任意矢量輸入時,網絡要經若干次迭代計算演變至穩態,過程爲
WB(t)A(t+1)WTA(t+1)B(t+2)WB(t+2)A(t+3)WB(t)\rightarrow A(t+1)\\ W^TA(t+1)\rightarrow B(t+2)\\ WB(t+2)\rightarrow A(t+3)\\ \cdots
直到A,B爲穩態,演變過程結束。
網絡學習遵從Hebb規則,若給定M個雙極性矢量對
(A0,B0),(A1,B1),(AM1,BM1)(A_0,B_0),(A_1,B_1)…,(A_{M-1},B_{M-1})
則正、反向權矩陣爲
W=s=0M1AsBsTW=\sum_{s=0}^{M-1}A_sB_s^T
根據轉置的性質:
WT=s=0M1BsAsTW^T=\sum_{s=0}^{M-1}B_sA_s^T
如果BAM網絡神經元函數閡值爲0,則稱爲齊次BAMⅥ網終,其能量函
數爲
E(A,B)=12ATWB12BTWTA=ATWBE(A, B)=-\frac{1}{2}A^TWB--\frac{1}{2}B^TW^TA=-A^TWB
若神經元非線性函數爲f,則述齊次BAM動態特性的差分方程爲

  • 正向聯想(B→A)
    ai(t+1)=f[j=1Pwijbj(t)]a_i(t+1)=f\left[\sum_{j=1}^Pw_{ij}b_j(t)\right]
  • 反向聯想(A→B)
    bj(t+1)=f[i=1Nwijai(t+1)]b_j(t+1)=f\left[\sum_{i=1}^Nw_{ij}a_i(t+1)\right]

處理步驟

所處理的數據爲:
A1=[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]T\mathbf{A}_{1}=[1,-1,1,-1,1,-1,1,-1,1,-1,1,-1,1,-1,1]^{T}
A2=[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]T\mathbf{A}_{2}=[1,1,-1,-1,1,1,-1,-1,1,1,-1,-1,1,1,-1]^{T}
A3=[1,1,1,1,1,1,1,1,1,1,1,1,1]T\mathbf{A}_{3}=[1,1,1,-1,-1,-1,1,1,-1,-1,1,1,1]^{T}
A4=[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]T\mathbf{A}_{4}=[1,1,1,1,-1,-1,-1,-1,1,1,1,1,-1,-1,-1]^{T}
B1=[1,1,1,1,1,1,1,1,1,1]TB2=[1,1,1,1,1,1,1,1,1,1]TB3=[1,1,1,1,1,1,1,1,1,1]TB4=[1,1,1,1,1,1,1,1,1,1]T\begin{aligned} \mathbf{B}_{1} &=[1,1,1,1,-1,-1,-1,-1,1,1]^{T} \\ \mathbf{B}_{2} &=[1,1,1,-1,-1,-1,1,1,1,-1]^{T} \\ \mathbf{B}_{3} &=[1,1,-1,-1,1,1,-1,-1,1,1]^{T} \\ \mathbf{B}_{4} &=[1,-1,1,-1,1,-1,1,-1,1,-1]^{T} \end{aligned}

  1. 選擇齊次 Koske型BAM網終編制程序實現聯想記憶。設神經元非線性函數f爲硬限幅函數,即f(x)=sgn(x);當x=0時,神經元的輸出維持不變。
  2. 根據Hebb規則算出網終的連接權矩陣:
    W=[422202024020042020222000202422240020202202224202002000202422022202420420402420224222020240022202420402220202002400202022240020202202220202000222024204]W=\begin{bmatrix} 4 & 2 & 2 & -2 & 0 & -2 & 0 & -2 & 4 & 0 \\ 2 & 0 & 0 & -4 & 2 & 0 & 2 & 0 & 2 & -2 \\ 2 & 0 & 0 & 0 & 2 & 0 & -2 & -4 & 2 & 2 \\ -2 & -4 & 0 & 0 & 2 & 0 & 2 & 0 & -2 & -2 \\ 0 & 2 & 2 & 2 & -4 & -2 & 0 & 2 & 0 & 0 \\ -2 & 0 & 0 & 0 & -2 & 0 & 2 & 4 & -2 & -2 \\ 0 & 2 & -2 & 2 & 0 & 2 & -4 & -2 & 0 & 4 \\ -2 & 0 & -4 & 0 & 2 & 4 & -2 & 0 & -2 & 2 \\ 4 & 2 & 2 & -2 & 0 & -2 & 0 & -2 & 4 & 0 \\ 0 & -2 & 2 & -2 & 0 & -2 & 4 & 2 & 0 & -4 \\ 0 & -2 & 2 & 2 & 0 & -2 & 0 & -2 & 0 & 0 \\ -2 & -4 & 0 & 0 & 2 & 0 & 2 & 0 & -2 & -2 \\ 2 & 4 & 0 & 0 & -2 & 0 & -2 & 0 & 2 & 2 \\ 0 & 2 & -2 & -2 & 0 & 2 & 0 & 2 & 0 & 0 \\ 0 & 2 & -2 & 2 & 0 & 2 & -4 & -2 & 0 & 4 \end{bmatrix}
    四對矢量所對應的穩定狀態的能量值:[-158 -142 -158 -146]
  3. 驗證網絡的聯想能力。
    選擇標準樣本AiA_i輸入網終進行迭代運算直至網絡穩定,觀察上下兩層的狀態是否爲(Ai,Bi)(A_i,B_i)
    經過權值和符號函數的非線性作用,進行一次迭代之後,迭代成功沒有偏差。
    Ai=sgn(WBi),i=1,2,3,4Bi=sgn(WTAi),i=1,2,3,4A_i=sgn(W*B_i),i=1,2,3,4\\ B_i=sgn(W^T*A_i),i=1,2,3,4
  4. 驗證網絡的抗噪能力。
    隨機選取標準矢量A1A_1若干位(如2位)取反形成畸變矢量
    A1=[111111111111111]TA_1'=\begin{bmatrix}1&-1&1&-1&1&-1&-1&-1&1&1&1&-1&1&-1&1\end{bmatrix}^T
    將其輸入網絡迭代至穩態,觀察到結果:
    A1=A1,B1=B1A_1'=A_1,B_1'=B_1
    經過兩次迭代期間能量變化爲:
    [-118, -158]
    網絡能量E在迭代過程中是減小的。
  5. 噪聲大小對聯想能力的影響。
    統計個A矢量分別在1~3位取反時的聯想正確率。記錄所得結果產觀察聯想正確率和取反位數之間的關係
噪聲對網絡聯想正確率的影響:
取反位數 1 2 3
A1 100% 93.4% 81.6%
A2 100% 96.4% 80.4%
A3 100% 90.5% 76.2%
A4 100% 90.15% 75.1%
噪聲對網絡聯想次數的影響:
取反位數 1 2 3
A1 2 1.988 2.091
A2 2 2.0585 2.26
A3 2 2.0055 2.035
A4 2 1.9905 2.0655
  1. 比較正、反向聯想的抗噪能力。統計4個B矢量分別在1-3位取反時的正確率
噪聲對網絡聯想正確率的影響:
取反位數 1 2 3
B1 100% 70.2% 31.4%
B2 80.8% 61.1% 27.1%
B3 90.25% 69.9% 23.55%
B4 78.8% 56.5% 26.35%
噪聲對網絡聯想次數的影響:
取反位數 1 2 3
B1 2. 2.1605 2.219
B2 1.808 2.0555 2.252
B3 1.9025 1.9975 2.2105
B4 1.788 1.9355 2.165
  1. 僞穩定狀態。
    除以上四對矢量對應的穩定狀態外,尋找2個僞穩定狀態並分別計算所對應的能量值。
    第一個樣本所對應的僞穩定性:
    A1=[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],B1=[1,1,1,1,1,1,1,1,1,1]A2=[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],B2=[1,1,1,0,1,1,1,1,1,1]A_1'=[1,-1,1,-1,1,-1,1,1,1,-1,-1,-1,1,1,1],B_1'=[1,1,-1,1,-1,1,-1,-1,1,1]\\ A_2'=[1,1,1,-1,-1,-1,1,-1,1,-1,1,-1,1,-1,1],B_2'=[1,1,1,0,1,-1,-1,-1,1,1]
    此時對應的能量值分別爲:
    E1=146E2=130E_1=-146\\ E_2=-130
    而第一個樣本所對應的能量值爲-158。

思考

  1. 在步驟4中觀察網絡能量E是如何變化的?
    從步驟4中可以看出,網絡能量隨着迭代次數的增加而減小。這是因爲雙向聯想網絡是Hopfield網絡的一種,所以也具有Hopfeld網絡特性。而對於標準Hopfield網絡其能量函數E對時間的導數爲負值。
    dEdt=iCi(dVidt)2dfi1(Vi)dVi0\frac{\mathrm{d}E}{\mathrm{d}t}=-\sum_i C_i \left(\frac{\mathrm{d}V_i}{\mathrm{d}t}\right)^2 \frac{\mathrm{d}f_i^{-1}(V_i)}{\mathrm{d}V_i}\leq 0
    也就是說網絡息是朝着能量函數E減小的方問運動,且達到穩態時取極小值。

  2. 如果我們想要“擦除”存儲矢量對中的某對(Aa,B3},應如何調整網
    絡?

    雙向聯想記憶網絡特性主要體現在權值W,如果想要擦除存儲矢量中的某對(Ai,Bi)(A_i,B_i),只需對權值進行調整。
    因爲權值爲
    W=s=0M1AsBsTW=\sum_{s=0}^{M-1}A_sB_s^T
    所以爲了儘量減小運算量,更新權值爲:
    W=WAiBiTW'=W-A_iB_i^T

  3. 通過總結笫5步和第6步結果,能得岀什麼結論?

  • 隨着網絡噪聲的增大,網絡聯想正確率下降,說明網絡恢復的能力是有限的。
  • 網絡對不同樣本組的記憶力大體相同。
  • 隨着噪聲的增大,網絡達到穩態所需的迭代次數略有增加。

源代碼

所使用語言爲python:

1. 計算網絡

import numpy as np
import random;

A= np.array([[1,-1,1,-1,1,-1,1,-1,1,-1,1,-1,1,-1,1],[1,1,-1,-1,1,1,-1,-1,1,1,-1,-1,1,1,-1],[1,1,1,-1,-1,-1,1,1,1,-1,-1,-1,1,1,1],
                       [1,1,1,1,-1,-1,-1,-1,1,1,1,1,-1,-1,-1]] )
B=np.array([[1,1,1,1,-1,-1,-1,-1,1,1],
[1,1,1,-1,-1,-1,1,1,1,-1],
[1,1,-1,-1,1,1,-1,-1,1,1],
[1,-1,1,-1,1,-1,1,-1,1,-1]])
A=A.T
B=B.T
# 計算權值矩陣
w=np.dot(A,B.T)
# 計算能量函數
e=np.zeros(4)
for i in range(4):
    e[i]=-np.dot(np.dot(A[:,i].T,w),B[:,i])
print(e)

2. 驗證網絡能力:

a0=A[:,0]

r=random.sample(range(len(A)),2)
a=a0.copy()
a[r]=-a[r]
print(a0)
print(a)
temp=np.zeros((15,1),dtype=int)
# temp=0
# print(temp)
tempB=np.sign(np.dot(w.T,a))

# temp=np.array( (15,4),dtype=int )

timeNum=0
errArr=[]
while not((temp==a).all()):
    temp=a.copy()
    b=np.sign(np.dot(w.T,a))
    b[np.where(b==0)]=tempB[np.where(b==0)]
    tempB=b.copy()
    ee=-np.dot(np.dot(a.T,w),b)
    a=np.sign(np.dot(w,b))
    a[np.where(a==0)]=temp[np.where(a==0)]
    errArr.append(ee)
    timeNum+=1
#     print(temp==A)
#     print(tempB==B)
print(timeNum)
print(a)
print(errArr)
  1. 驗證網絡抗噪能力
    驗證對A樣本的抗噪能力:
timeAll=np.zeros((4,3),dtype=int)
correctNum=np.zeros((4,3),dtype=int)
for i in range(4):
    for j in range(1,4):
        a0=A[:,i]
        
        for k in range(2000):
            r=random.sample(range(len(A)),j)
            a=a0.copy()
            a[r]=-a[r]
            temp=np.zeros((15,1),dtype=int)
            # temp=0
            # print(temp)
            tempB=np.sign(np.dot(w.T,a))
            timeNum=0
            errArr=[]
            while not((temp==a).all()):
                temp=a.copy()
                b=np.sign(np.dot(w.T,a))
                b[np.where(b==0)]=tempB[np.where(b==0)]
                tempB=b.copy()
                ee=-np.dot(np.dot(a.T,w),b)
                a=np.sign(np.dot(w,b))
                a[np.where(a==0)]=temp[np.where(a==0)]
                errArr.append(ee)
                timeNum+=1
            #     print(temp==A)
            #     print(tempB==B)
            timeAll[i][j-1]+=timeNum
            if (a0==a).all():
                correctNum[i][j-1]+=1
            #else:
                #print(i,a,b,ee)
print(timeAll/2000.0,correctNum/2000.0*100)

驗證對B樣本的抗噪能力:

timeAll=np.zeros((4,3),dtype=int)
correctNum=np.zeros((4,3),dtype=int)
for i in range(4):
    for j in range(1,4):
        b0=B[:,i]
        for k in range(2000):
            r=random.sample(range(len(B)),j)
            b=b0.copy()
            b[r]=-b[r]
            temp=np.zeros((10,1),dtype=int)
            # temp=0
            # print(temp)
            tempA=np.sign(np.dot(w,b))
#             tempA[np.where(tempB==0)]=1
            timeNum=0
            errArr=[]
            while not((temp==b).all()):
                temp=b.copy()
                a=np.sign(np.dot(w,b))
                a[np.where(a==0)]=tempA[np.where(a==0)]
                tempA=a.copy()
                ee=-np.dot(np.dot(a.T,w),b)
                b=np.sign(np.dot(w.T,a))
                b[np.where(b==0)]=temp[np.where(b==0)]
                errArr.append(ee)
                timeNum+=1
            #     print(temp==A)
            #     print(tempB==B)
            timeAll[i][j-1]+=timeNum
            if (b0==b).all():
                correctNum[i][j-1]+=1
print(timeAll/2000.0,correctNum/2000.0*100)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章