dropout实现过程

1、dropout可以用来防止过拟合
pytorch中实现如下:

m = nn.Dropout(p=0.2)
input = torch.randn(2, 5)
print()
output = m(input)
print(input)
print(output)

输出如下
在这里插入图片描述
实际上,dropout不只mask掉某个位置的数,而且还将保留的数进行缩放,缩放比例为
p1p{\frac{p}{1-p}}
这里计算出来缩放比例为1.25。-0.1673*1.25=0.209125
dropout的作用在于对某个字(seq维度上)一共有384个特征,随机剔除几个特征,并放大剩余特征。

2、Dropout(x)的后续作用
Dropout: [batch_size, seq, hidden_dim]
W: [hidden_dim, 10]
在这里插入图片描述
上图中Dropout(x)维度为[1, 2, 4],对图中Dropout(x)的第一行来说,W的第二行数据失效,第二行类似。因为seq维度不一定是独立的。
因为每个batch中的数据两两之间可以认为是独立的,对Dropout(x)每一行来说,都相当于独立训练一个分类器。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章