這個代碼其實是別人寫的pytorch的實現:GitHub
code2seq復現
數據
test|reset test,Nm0|MarkerExpr|Mth|Void1,void test,Nm0|MarkerExpr|Mth|Nm2,METHOD_NAME void,Void1|Mth|Nm2,METHOD_NAME
數據按行存,通過空格分隔開。其中,第一項test|reset
是方法名,用豎線|
分隔爲subtoken
,其餘的項是AST PATH
。
AST PATH
由三個部分組成,通過逗號,
分隔開。
第一項和第三項是AST PATH
開始的token
和結束的token
,通過豎線|
分隔爲subtoken
。
第二項是AST PATH
中的結點,也通過豎線|
分隔。
with open(path, 'r') as f:
target, *syntax_path = f.readline().split(' ')
target = target.split('|')
for path in syntax_path
terminal1, ast_path, terminal2 = path.split(',')
terminal1 = terminal1.split('|')
ast_path = ast_path.split('|')
terminal2 = terminal2.split('|')
其他注意,空格,換行符之類的
Data Loader
相比於NMT
的數據(如英譯中),code2seq
的數據項要多一些,
每次需要取AST nodes
,兩類subtoken
,以及方法名target
爲了性能,方便GPU加速訓練做了很多其他的事情
- 記錄
AST
結點數據原始長度 - 填充數據到一樣的長度
- 按原始長度從大到小排序
- 填充數據的壓縮與解壓
最後給出的數據是
(B*k, l)
大小的subtoken
(l, B*k)
大小的nodes
其中,B
是batch_size
,k
是每個樣本有多少個AST PATH
,l
是最大的(subtoken or nodes)
長度
注意:B*k
僅僅是估算,實際上應該是
路徑表示
AST路徑是由結點組成的,每個結點由矩陣表示。
用雙向LSTM對結點進行編碼,dropout爲0.5。
# 定義層 total_nodes是有多少種AST結點
embedding_node = nn.Embedding(total_nodes, 128)
lstm = nn.LSTM(128, 128, bidirectional=True, dropout=0.5)
# 使用層 輸入的結點是batch_N (l, B*k) 其原始長度爲lengths_N (B)
# (l, B*k) -> (l, B*k, 128) -> ??? -> (B*k, 256)
# output: ()
# hidden: (num_layers*2, B*k, 128)
encode_N = embedding_node(batch_N)
packed_N = pack_padded_sequence(encode_N, lengths_N)
output, (hidden, _) = lstm(packed_N)
# (2, B*k, 128) -> (2, B*k, 128) -> (B*k, 2, 128) -> (B*k, 1, 256)
hidden = hidden[-2:,:,:]
hidden = hideen.transpose(0, 1)
hidden = hidden.contiguous().view(B*k, 1, -1)
encode_N = hidden.squeeze(1)
#
公式中有將雙向的隱藏層連結起來,pytorch的lstm實現中,是默認連着的
token表示
每個AST路徑加開始和結束的終止符,這些終止符也被當做token。token會被進一步分割爲subtoken,類似NMT
中的byte-pair encoding
。最後,把所有子token的矩陣表示加起來,作爲token的表示。
# 定義層 total_subtoken表示有多少種subtoken
embedding_subtoken = nn.Embedding(total_subtoken, 128)
# 使用層 輸入的subtoken爲batch_S, batch_E
# (B*k, l) -> (B*k, l, 128) -> (B*k, 128)
encode_S = embedding_subtoken(batch_S)
encode_S = encode_S.sum(1)
encode_E = embedding_subtoken(batch_E)
encode_E = encode_E.sum(1)
聯合表示
把路徑表示和token表示拼接,送進FC中
其中,是終止結點到其關聯值的映射;是一個的矩陣。
# 定義層
fc = nn.Linear(128 * (2 + 1 + 1), 128)
# 使用層
# (B*k, 256) (B*k, 128) -> (B*k, 512) -> (B*k, 128)
encode_SNE = torch.cat([encode_N, encode_S, encode_E], dim=1)
encode_SNE = fc(encode_SNE)
encode_SNE = torch.tanh(encode_SNE)
decoder初始狀態
對聯合表示取平均
# 拆 -> tuple B * (k, 128) -> list B * (1, 128) -> (1, B, 128)
# encode_SNE (B*k, 512)
# lenghts_k (B) 每個樣本的長度
output_bag = torch.split(encode_SNE, lengths_k, dim=0)
hidden_0 = [ob.mean(0).unsqueeze(dim=0) for ob in output_bag]
hidden_0 = torch.cat(hidden_0, dim=0).unsqueeze(dim=0)
注意力(tf源碼用的LuongAttention
)
# 公式是樣本級的
# encoder_output_bag: tuple batch_size * (k, hidden_size)
# hidden: (1, batch_size, hidden_size)
# lengths_k: list batch_size
# tuple batch_size * (k, hidden_size) -> (batch_size * k, hidden_size)
e_out = torch.cat(encoder_output_bag, dim=0)
# W_a 是權重矩陣: (hidden_size, hidden_size)
# (batch_size * k, hidden_size) (hidden_size, hidden_size)
# -> (batch_size * k, hidden_size)
ha = torch.einsum('ij,jk->ik', e_out, W_a) # 即 W_a * z
# (batch_size * k, hidden_size) -> tuple batch_size * (k, hidden_size)
ha = torch.split(ha, lenths_k, dim=0)
# (1, batch_size, hidden_size)->(batch_size, 1, hidden_size)
# -> tuple batch_size * (1, hidden_size)
hd = hidden.transpose(0, 1)
hd = torch.unbind(hd, dim=0)
# list batch_size * (k) 可以看作矩陣乘向量,行相加,k=1
at = [F.softmax(torch.einsum('ij,kj->i', _ha, _hd), dim=0)
for _ha, _hd in zip(ha, hd)] # a = h * (W_a * z)
# list batch_size * (1, hidden_size)
ct = [torch.einsum('i,ij->j', a, e).unsqueeze(0)
for a, e in zip(at, encoder_output_bag)] # c = sum_i a_i * z_i
# (1, batch_size, hidden_size)
ct = torch.cat(ct, dim=0).unsqueeze(0)
注:
- 輸入側重複的字母意味着這個維度的數據相乘,乘積組成輸出
- 輸出側省略的字母意味着這個維度的數據求和
最後我在服務器上跑的結果:
code2seq Dictionaries loaded.
code2seq vocab_size_subtoken: 73908
code2seq vocab_size_nodes: 325
code2seq vocab_size_target: 11320
code2seq num_examples : 691974
code2seq Epoch 1: train_loss: 13.64 train_f1: 0.4324 valid_loss: 17.03 valid_f1: 0.3187
code2seq Epoch 2: train_loss: 10.86 train_f1: 0.5313 valid_loss: 16.98 valid_f1: 0.3368
code2seq Epoch 3: train_loss: 10.08 train_f1: 0.5573 valid_loss: 17.33 valid_f1: 0.3428
code2seq Epoch 4: train_loss: 9.62 train_f1: 0.5725 valid_loss: 17.43 valid_f1: 0.3440
code2seq Epoch 5: train_loss: 9.30 train_f1: 0.5826 valid_loss: 17.97 valid_f1: 0.3592
code2seq Epoch 6: train_loss: 9.05 train_f1: 0.5910 valid_loss: 18.15 valid_f1: 0.3516
code2seq Epoch 7: train_loss: 8.85 train_f1: 0.5979 valid_loss: 18.49 valid_f1: 0.3622
code2seq Epoch 8: train_loss: 8.69 train_f1: 0.6033 valid_loss: 19.15 valid_f1: 0.3612
code2seq Epoch 9: train_loss: 8.54 train_f1: 0.6082 valid_loss: 19.42 valid_f1: 0.3553
code2seq Epoch 10: train_loss: 8.43 train_f1: 0.6125 valid_loss: 19.81 valid_f1: 0.3605
code2seq Epoch 11: train_loss: 8.31 train_f1: 0.6165 valid_loss: 20.03 valid_f1: 0.3608
code2seq Epoch 12: train_loss: 8.19 train_f1: 0.6201 valid_loss: 20.18 valid_f1: 0.3603
code2seq Epoch 13: train_loss: 8.12 train_f1: 0.6232 valid_loss: 20.17 valid_f1: 0.3587
code2seq Epoch 14: train_loss: 8.02 train_f1: 0.6264 valid_loss: 20.48 valid_f1: 0.3634
code2seq Epoch 15: train_loss: 7.95 train_f1: 0.6292 valid_loss: 20.67 valid_f1: 0.3558
code2seq Epoch 16: train_loss: 7.89 train_f1: 0.6316 valid_loss: 20.51 valid_f1: 0.3623
code2seq Epoch 17: train_loss: 7.83 train_f1: 0.6338 valid_loss: 20.69 valid_f1: 0.3562
code2seq Epoch 18: train_loss: 7.79 train_f1: 0.6359 valid_loss: 20.85 valid_f1: 0.3582
code2seq Epoch 19: train_loss: 7.72 train_f1: 0.6376 valid_loss: 20.75 valid_f1: 0.3596
code2seq Epoch 20: train_loss: 7.68 train_f1: 0.6394 valid_loss: 20.87 valid_f1: 0.3612
最高37和論文裏的42還是有不小的差距,當然,某些細節和參數不一樣,也會影響到結果