skip-gram模型:使用中心詞來預測上下文單詞。
1. 涉及到的三個重要輸入參數:
- batch_size
- num_skip
- skip_window
2.輸入 參數之間的關係——兩個必要條件
- batch_size % num_skips == 0(整除關係)
- num_skips <= 2 * skip_window
3. 輸出參數
- batch: 由雙向隊列buffer的中心詞構成(即模型訓練時的輸入X)
- labels: 由中心詞的上下文單詞構成(即模型訓練時期望的輸出Y)
4. 三個輸入參數的意義
如圖所表示:
5. code & 實例
code參考來源:使用tensorflow實現word2vec中文詞向量的訓練
def generate_batch(batch_size, num_skips, skip_window):
# data_index有兩個重要的設置:
# *第一是設置爲全局變量,使得每次新生成一個batch時,buffer裏面的數據都不會重複
# *第二是更新方式不是簡單的+1, 而是使用data_index = (data_index + 1) % len(data)
# 目的是使得數據遍歷結束一遍之後又從頭開始遍歷。
global data_index
#--------------兩個非常重要的參數規則判斷----------------------
assert batch_size % num_skips == 0
#如果num_skips > 2 * skip_window,也即num_skips > span-1 則在下面的循環中,無法生成滿足條件的隨機數,陷入死循環中。
assert num_skips <= 2 * skip_window
batch = np.ndarray(shape=(batch_size), dtype=np.int32)
print('初始化的batch', batch)
labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)
print('初始化的labels', labels)
span = 2 * skip_window + 1 # [ skip_window target skip_window ]
#雙向隊列,使用 deque(maxlen=N) 構造函數會新建一個固定大小的隊列。
#當新的元素加入並且這個隊列已滿的時候, 最老的元素會自動被移除掉。
buffer = collections.deque(maxlen=span)
#根據skip_windows的大小,把2 * skip_window + 1個數據放入緩存中
for _ in range(span):
buffer.append(data[data_index])
print('buffer: ', buffer)
data_index = (data_index + 1) % len(data)
print ('data_index', data_index)
print('--------------初始化的數據buffer準備完畢--------------------------')
#共有batch_size // num_skips個unique中心值
for i in range(batch_size // num_skips):
target = skip_window # target label at the center of the buffer
# 用來標記不能選擇作爲label的詞(包括中心詞本身和已經選擇過一次的詞),
# 如果已經存在了,則需要再次隨機生成,直到沒有出現過
targets_to_avoid = [skip_window]
#針對每一箇中心詞,尋找上下文的輸出
for j in range(num_skips):
while target in targets_to_avoid:
print('當前', j, '輪的target值:', target, '已經存在,需要重新隨機生成')
target = random.randint(0, span - 1)#生成一個隨機數
print('-----新隨機生成的target值爲', target, '\n')
print('-------------隨機的target生成成功,開始準備數據------------\n')
targets_to_avoid.append(target)
batch[i * num_skips + j] = buffer[skip_window]
labels[i * num_skips + j, 0] = buffer[target]
#print(batch[i * num_skips + j])
#print(labels[i * num_skips + j, 0])
print('當前的batch\n', batch)
print('當前的labels\n', labels)
buffer.append(data[data_index])
print('當前的buffer\n', buffer)
data_index = (data_index + 1) % len(data)
print('最後的batch\n',batch)
print('最後的labels\n', labels)
return batch, labels
#--------------------test-----------------------------
global data_index
data_index = 0
global data
# 這裏的data實際上是真正的單詞索引(基於字典生成的)
data = [1,2,3,4,5,6,7,8,9,11,21,31,41,50,51,52,56,60, 61,62,66,77,88,99,111,112,123,234]
batch, labels = generate_batch(batch_size=9, num_skips=3, skip_window=5)
實例輸出:
初始化的batch [ 5 9 4 2 3 11 31 11 6]
初始化的labels [[7]
[7]
[7]
[8]
[8]
[8]
[9]
[9]
[9]]
buffer: deque([1], maxlen=11)
data_index 1
buffer: deque([1, 2], maxlen=11)
data_index 2
buffer: deque([1, 2, 3], maxlen=11)
data_index 3
buffer: deque([1, 2, 3, 4], maxlen=11)
data_index 4
buffer: deque([1, 2, 3, 4, 5], maxlen=11)
data_index 5
buffer: deque([1, 2, 3, 4, 5, 6], maxlen=11)
data_index 6
buffer: deque([1, 2, 3, 4, 5, 6, 7], maxlen=11)
data_index 7
buffer: deque([1, 2, 3, 4, 5, 6, 7, 8], maxlen=11)
data_index 8
buffer: deque([1, 2, 3, 4, 5, 6, 7, 8, 9], maxlen=11)
data_index 9
buffer: deque([1, 2, 3, 4, 5, 6, 7, 8, 9, 11], maxlen=11)
data_index 10
buffer: deque([1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 21], maxlen=11)
data_index 11
--------------初始化的數據buffer準備完畢--------------------------
當前 0 輪的target值: 5 已經存在,需要重新隨機生成
-----新隨機生成的target值爲 5
當前 0 輪的target值: 5 已經存在,需要重新隨機生成
-----新隨機生成的target值爲 9
-------------隨機的target生成成功,開始準備數據------------
當前 1 輪的target值: 9 已經存在,需要重新隨機生成
-----新隨機生成的target值爲 7
-------------隨機的target生成成功,開始準備數據------------
當前 2 輪的target值: 7 已經存在,需要重新隨機生成
-----新隨機生成的target值爲 10
-------------隨機的target生成成功,開始準備數據------------
當前的batch
[ 6 6 6 2 3 11 31 11 6]
當前的labels
[[11]
[ 8]
[21]
[ 8]
[ 8]
[ 8]
[ 9]
[ 9]
[ 9]]
當前的buffer
deque([2, 3, 4, 5, 6, 7, 8, 9, 11, 21, 31], maxlen=11)
當前 0 輪的target值: 5 已經存在,需要重新隨機生成
-----新隨機生成的target值爲 1
-------------隨機的target生成成功,開始準備數據------------
當前 1 輪的target值: 1 已經存在,需要重新隨機生成
-----新隨機生成的target值爲 8
-------------隨機的target生成成功,開始準備數據------------
當前 2 輪的target值: 8 已經存在,需要重新隨機生成
-----新隨機生成的target值爲 4
-------------隨機的target生成成功,開始準備數據------------
當前的batch
[ 6 6 6 7 7 7 31 11 6]
當前的labels
[[11]
[ 8]
[21]
[ 3]
[11]
[ 6]
[ 9]
[ 9]
[ 9]]
當前的buffer
deque([3, 4, 5, 6, 7, 8, 9, 11, 21, 31, 41], maxlen=11)
當前 0 輪的target值: 5 已經存在,需要重新隨機生成
-----新隨機生成的target值爲 9
-------------隨機的target生成成功,開始準備數據------------
當前 1 輪的target值: 9 已經存在,需要重新隨機生成
-----新隨機生成的target值爲 9
當前 1 輪的target值: 9 已經存在,需要重新隨機生成
-----新隨機生成的target值爲 8
-------------隨機的target生成成功,開始準備數據------------
當前 2 輪的target值: 8 已經存在,需要重新隨機生成
-----新隨機生成的target值爲 4
-------------隨機的target生成成功,開始準備數據------------
當前的batch
[6 6 6 7 7 7 8 8 8]
當前的labels
[[11]
[ 8]
[21]
[ 3]
[11]
[ 6]
[31]
[21]
[ 7]]
當前的buffer
deque([4, 5, 6, 7, 8, 9, 11, 21, 31, 41, 50], maxlen=11)
最後的batch
[6 6 6 7 7 7 8 8 8]
最後的labels
[[11]
[ 8]
[21]
[ 3]
[11]
[ 6]
[31]
[21]
[ 7]]