skip-gram模型:使用中心词来预测上下文单词。
1. 涉及到的三个重要输入参数:
- batch_size
- num_skip
- skip_window
2.输入 参数之间的关系——两个必要条件
- batch_size % num_skips == 0(整除关系)
- num_skips <= 2 * skip_window
3. 输出参数
- batch: 由双向队列buffer的中心词构成(即模型训练时的输入X)
- labels: 由中心词的上下文单词构成(即模型训练时期望的输出Y)
4. 三个输入参数的意义
如图所表示:
5. code & 实例
code参考来源:使用tensorflow实现word2vec中文词向量的训练
def generate_batch(batch_size, num_skips, skip_window):
# data_index有两个重要的设置:
# *第一是设置为全局变量,使得每次新生成一个batch时,buffer里面的数据都不会重复
# *第二是更新方式不是简单的+1, 而是使用data_index = (data_index + 1) % len(data)
# 目的是使得数据遍历结束一遍之后又从头开始遍历。
global data_index
#--------------两个非常重要的参数规则判断----------------------
assert batch_size % num_skips == 0
#如果num_skips > 2 * skip_window,也即num_skips > span-1 则在下面的循环中,无法生成满足条件的随机数,陷入死循环中。
assert num_skips <= 2 * skip_window
batch = np.ndarray(shape=(batch_size), dtype=np.int32)
print('初始化的batch', batch)
labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)
print('初始化的labels', labels)
span = 2 * skip_window + 1 # [ skip_window target skip_window ]
#双向队列,使用 deque(maxlen=N) 构造函数会新建一个固定大小的队列。
#当新的元素加入并且这个队列已满的时候, 最老的元素会自动被移除掉。
buffer = collections.deque(maxlen=span)
#根据skip_windows的大小,把2 * skip_window + 1个数据放入缓存中
for _ in range(span):
buffer.append(data[data_index])
print('buffer: ', buffer)
data_index = (data_index + 1) % len(data)
print ('data_index', data_index)
print('--------------初始化的数据buffer准备完毕--------------------------')
#共有batch_size // num_skips个unique中心值
for i in range(batch_size // num_skips):
target = skip_window # target label at the center of the buffer
# 用来标记不能选择作为label的词(包括中心词本身和已经选择过一次的词),
# 如果已经存在了,则需要再次随机生成,直到没有出现过
targets_to_avoid = [skip_window]
#针对每一个中心词,寻找上下文的输出
for j in range(num_skips):
while target in targets_to_avoid:
print('当前', j, '轮的target值:', target, '已经存在,需要重新随机生成')
target = random.randint(0, span - 1)#生成一个随机数
print('-----新随机生成的target值为', target, '\n')
print('-------------随机的target生成成功,开始准备数据------------\n')
targets_to_avoid.append(target)
batch[i * num_skips + j] = buffer[skip_window]
labels[i * num_skips + j, 0] = buffer[target]
#print(batch[i * num_skips + j])
#print(labels[i * num_skips + j, 0])
print('当前的batch\n', batch)
print('当前的labels\n', labels)
buffer.append(data[data_index])
print('当前的buffer\n', buffer)
data_index = (data_index + 1) % len(data)
print('最后的batch\n',batch)
print('最后的labels\n', labels)
return batch, labels
#--------------------test-----------------------------
global data_index
data_index = 0
global data
# 这里的data实际上是真正的单词索引(基于字典生成的)
data = [1,2,3,4,5,6,7,8,9,11,21,31,41,50,51,52,56,60, 61,62,66,77,88,99,111,112,123,234]
batch, labels = generate_batch(batch_size=9, num_skips=3, skip_window=5)
实例输出:
初始化的batch [ 5 9 4 2 3 11 31 11 6]
初始化的labels [[7]
[7]
[7]
[8]
[8]
[8]
[9]
[9]
[9]]
buffer: deque([1], maxlen=11)
data_index 1
buffer: deque([1, 2], maxlen=11)
data_index 2
buffer: deque([1, 2, 3], maxlen=11)
data_index 3
buffer: deque([1, 2, 3, 4], maxlen=11)
data_index 4
buffer: deque([1, 2, 3, 4, 5], maxlen=11)
data_index 5
buffer: deque([1, 2, 3, 4, 5, 6], maxlen=11)
data_index 6
buffer: deque([1, 2, 3, 4, 5, 6, 7], maxlen=11)
data_index 7
buffer: deque([1, 2, 3, 4, 5, 6, 7, 8], maxlen=11)
data_index 8
buffer: deque([1, 2, 3, 4, 5, 6, 7, 8, 9], maxlen=11)
data_index 9
buffer: deque([1, 2, 3, 4, 5, 6, 7, 8, 9, 11], maxlen=11)
data_index 10
buffer: deque([1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 21], maxlen=11)
data_index 11
--------------初始化的数据buffer准备完毕--------------------------
当前 0 轮的target值: 5 已经存在,需要重新随机生成
-----新随机生成的target值为 5
当前 0 轮的target值: 5 已经存在,需要重新随机生成
-----新随机生成的target值为 9
-------------随机的target生成成功,开始准备数据------------
当前 1 轮的target值: 9 已经存在,需要重新随机生成
-----新随机生成的target值为 7
-------------随机的target生成成功,开始准备数据------------
当前 2 轮的target值: 7 已经存在,需要重新随机生成
-----新随机生成的target值为 10
-------------随机的target生成成功,开始准备数据------------
当前的batch
[ 6 6 6 2 3 11 31 11 6]
当前的labels
[[11]
[ 8]
[21]
[ 8]
[ 8]
[ 8]
[ 9]
[ 9]
[ 9]]
当前的buffer
deque([2, 3, 4, 5, 6, 7, 8, 9, 11, 21, 31], maxlen=11)
当前 0 轮的target值: 5 已经存在,需要重新随机生成
-----新随机生成的target值为 1
-------------随机的target生成成功,开始准备数据------------
当前 1 轮的target值: 1 已经存在,需要重新随机生成
-----新随机生成的target值为 8
-------------随机的target生成成功,开始准备数据------------
当前 2 轮的target值: 8 已经存在,需要重新随机生成
-----新随机生成的target值为 4
-------------随机的target生成成功,开始准备数据------------
当前的batch
[ 6 6 6 7 7 7 31 11 6]
当前的labels
[[11]
[ 8]
[21]
[ 3]
[11]
[ 6]
[ 9]
[ 9]
[ 9]]
当前的buffer
deque([3, 4, 5, 6, 7, 8, 9, 11, 21, 31, 41], maxlen=11)
当前 0 轮的target值: 5 已经存在,需要重新随机生成
-----新随机生成的target值为 9
-------------随机的target生成成功,开始准备数据------------
当前 1 轮的target值: 9 已经存在,需要重新随机生成
-----新随机生成的target值为 9
当前 1 轮的target值: 9 已经存在,需要重新随机生成
-----新随机生成的target值为 8
-------------随机的target生成成功,开始准备数据------------
当前 2 轮的target值: 8 已经存在,需要重新随机生成
-----新随机生成的target值为 4
-------------随机的target生成成功,开始准备数据------------
当前的batch
[6 6 6 7 7 7 8 8 8]
当前的labels
[[11]
[ 8]
[21]
[ 3]
[11]
[ 6]
[31]
[21]
[ 7]]
当前的buffer
deque([4, 5, 6, 7, 8, 9, 11, 21, 31, 41, 50], maxlen=11)
最后的batch
[6 6 6 7 7 7 8 8 8]
最后的labels
[[11]
[ 8]
[21]
[ 3]
[11]
[ 6]
[31]
[21]
[ 7]]