word2vec——圖解生成batch數據的code(for skip-gram模型)

skip-gram模型:使用中心詞來預測上下文單詞。

1. 涉及到的三個重要輸入參數:

  • batch_size
  • num_skip
  • skip_window

2.輸入 參數之間的關係——兩個必要條件

  • batch_size % num_skips == 0(整除關係)
  • num_skips <= 2 * skip_window

3. 輸出參數

  • batch:  由雙向隊列buffer的中心詞構成(即模型訓練時的輸入X)
  • labels:  由中心詞的上下文單詞構成(即模型訓練時期望的輸出Y)

4. 三個輸入參數的意義

如圖所表示:

5. code & 實例

code參考來源:使用tensorflow實現word2vec中文詞向量的訓練

def generate_batch(batch_size, num_skips, skip_window):
    
    # data_index有兩個重要的設置:
    #    *第一是設置爲全局變量,使得每次新生成一個batch時,buffer裏面的數據都不會重複
    #    *第二是更新方式不是簡單的+1, 而是使用data_index = (data_index + 1) % len(data)
    #     目的是使得數據遍歷結束一遍之後又從頭開始遍歷。
    global data_index
      
    #--------------兩個非常重要的參數規則判斷----------------------
    assert batch_size % num_skips == 0
    #如果num_skips > 2 * skip_window,也即num_skips > span-1 則在下面的循環中,無法生成滿足條件的隨機數,陷入死循環中。
    assert num_skips <= 2 * skip_window
       
    batch = np.ndarray(shape=(batch_size), dtype=np.int32)
    print('初始化的batch', batch)
    labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)
    print('初始化的labels', labels)
    span = 2 * skip_window + 1  # [ skip_window target skip_window ]
    #雙向隊列,使用 deque(maxlen=N) 構造函數會新建一個固定大小的隊列。
    #當新的元素加入並且這個隊列已滿的時候, 最老的元素會自動被移除掉。
    buffer = collections.deque(maxlen=span)
    
    #根據skip_windows的大小,把2 * skip_window + 1個數據放入緩存中
    for _ in range(span):
        buffer.append(data[data_index])
        print('buffer: ', buffer)
        data_index = (data_index + 1) % len(data)
        print ('data_index', data_index)
        
    print('--------------初始化的數據buffer準備完畢--------------------------')
        
    #共有batch_size // num_skips個unique中心值
    for i in range(batch_size // num_skips):
        target = skip_window  # target label at the center of the buffer
        # 用來標記不能選擇作爲label的詞(包括中心詞本身和已經選擇過一次的詞),
        # 如果已經存在了,則需要再次隨機生成,直到沒有出現過
        targets_to_avoid = [skip_window] 
        #針對每一箇中心詞,尋找上下文的輸出
        for j in range(num_skips):
            while target in targets_to_avoid:
                print('當前', j, '輪的target值:', target, '已經存在,需要重新隨機生成')
                target = random.randint(0, span - 1)#生成一個隨機數
                print('-----新隨機生成的target值爲', target, '\n')
            print('-------------隨機的target生成成功,開始準備數據------------\n')
            targets_to_avoid.append(target)
            batch[i * num_skips + j] = buffer[skip_window]
            labels[i * num_skips + j, 0] = buffer[target]
            #print(batch[i * num_skips + j])
            #print(labels[i * num_skips + j, 0])
        print('當前的batch\n', batch)
        print('當前的labels\n', labels)
        buffer.append(data[data_index])
        print('當前的buffer\n', buffer)
        data_index = (data_index + 1) % len(data)
    print('最後的batch\n',batch)
    print('最後的labels\n', labels)
    return batch, labels

#--------------------test-----------------------------

global data_index
data_index = 0
global data
# 這裏的data實際上是真正的單詞索引(基於字典生成的)
data = [1,2,3,4,5,6,7,8,9,11,21,31,41,50,51,52,56,60, 61,62,66,77,88,99,111,112,123,234]

batch, labels = generate_batch(batch_size=9, num_skips=3, skip_window=5)

實例輸出:

初始化的batch [ 5  9  4  2  3 11 31 11  6]
初始化的labels [[7]
 [7]
 [7]
 [8]
 [8]
 [8]
 [9]
 [9]
 [9]]
buffer:  deque([1], maxlen=11)
data_index 1
buffer:  deque([1, 2], maxlen=11)
data_index 2
buffer:  deque([1, 2, 3], maxlen=11)
data_index 3
buffer:  deque([1, 2, 3, 4], maxlen=11)
data_index 4
buffer:  deque([1, 2, 3, 4, 5], maxlen=11)
data_index 5
buffer:  deque([1, 2, 3, 4, 5, 6], maxlen=11)
data_index 6
buffer:  deque([1, 2, 3, 4, 5, 6, 7], maxlen=11)
data_index 7
buffer:  deque([1, 2, 3, 4, 5, 6, 7, 8], maxlen=11)
data_index 8
buffer:  deque([1, 2, 3, 4, 5, 6, 7, 8, 9], maxlen=11)
data_index 9
buffer:  deque([1, 2, 3, 4, 5, 6, 7, 8, 9, 11], maxlen=11)
data_index 10
buffer:  deque([1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 21], maxlen=11)
data_index 11
--------------初始化的數據buffer準備完畢--------------------------
當前 0 輪的target值: 5 已經存在,需要重新隨機生成
-----新隨機生成的target值爲 5 

當前 0 輪的target值: 5 已經存在,需要重新隨機生成
-----新隨機生成的target值爲 9 

-------------隨機的target生成成功,開始準備數據------------

當前 1 輪的target值: 9 已經存在,需要重新隨機生成
-----新隨機生成的target值爲 7 

-------------隨機的target生成成功,開始準備數據------------

當前 2 輪的target值: 7 已經存在,需要重新隨機生成
-----新隨機生成的target值爲 10 

-------------隨機的target生成成功,開始準備數據------------

當前的batch
 [ 6  6  6  2  3 11 31 11  6]
當前的labels
 [[11]
 [ 8]
 [21]
 [ 8]
 [ 8]
 [ 8]
 [ 9]
 [ 9]
 [ 9]]
當前的buffer
 deque([2, 3, 4, 5, 6, 7, 8, 9, 11, 21, 31], maxlen=11)
當前 0 輪的target值: 5 已經存在,需要重新隨機生成
-----新隨機生成的target值爲 1 

-------------隨機的target生成成功,開始準備數據------------

當前 1 輪的target值: 1 已經存在,需要重新隨機生成
-----新隨機生成的target值爲 8 

-------------隨機的target生成成功,開始準備數據------------

當前 2 輪的target值: 8 已經存在,需要重新隨機生成
-----新隨機生成的target值爲 4 

-------------隨機的target生成成功,開始準備數據------------

當前的batch
 [ 6  6  6  7  7  7 31 11  6]
當前的labels
 [[11]
 [ 8]
 [21]
 [ 3]
 [11]
 [ 6]
 [ 9]
 [ 9]
 [ 9]]
當前的buffer
 deque([3, 4, 5, 6, 7, 8, 9, 11, 21, 31, 41], maxlen=11)
當前 0 輪的target值: 5 已經存在,需要重新隨機生成
-----新隨機生成的target值爲 9 

-------------隨機的target生成成功,開始準備數據------------

當前 1 輪的target值: 9 已經存在,需要重新隨機生成
-----新隨機生成的target值爲 9 

當前 1 輪的target值: 9 已經存在,需要重新隨機生成
-----新隨機生成的target值爲 8 

-------------隨機的target生成成功,開始準備數據------------

當前 2 輪的target值: 8 已經存在,需要重新隨機生成
-----新隨機生成的target值爲 4 

-------------隨機的target生成成功,開始準備數據------------

當前的batch
 [6 6 6 7 7 7 8 8 8]
當前的labels
 [[11]
 [ 8]
 [21]
 [ 3]
 [11]
 [ 6]
 [31]
 [21]
 [ 7]]
當前的buffer
 deque([4, 5, 6, 7, 8, 9, 11, 21, 31, 41, 50], maxlen=11)
最後的batch
 [6 6 6 7 7 7 8 8 8]
最後的labels
 [[11]
 [ 8]
 [21]
 [ 3]
 [11]
 [ 6]
 [31]
 [21]
 [ 7]]

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章