Task 3-5: Seq2Seq

import torch
import torch.nn as tnn
from mt.load import *
import torch.utils.data as td

english, chinese = load_data()
sen_c, (id2w_c, w2id_c) = split_c(chinese)
sen_e, (id2w_e, w2id_e) = split_e(english)

input_c, pad_c = word2id(sen_c, w2id_c)
output_e, pad_e = word2id(sen_e, w2id_e)


class MyDataset(td.Dataset):
    def __init__(self, *tensors):
        self.tensors = tensors

    def __getitem__(self, index):
        # Index every tensor with the sample index, not the tensor itself.
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].shape[0]
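
This class duplicates what torch.utils.data.TensorDataset already provides; a quick check on toy tensors shows the indexing behaviour:

a = torch.arange(6).view(3, 2)            # 3 samples, 2 features each
b = torch.tensor([0, 1, 2])               # 3 labels
x, y = MyDataset(a, b)[1]
print(x, y)                               # tensor([2, 3]) tensor(1)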


class EncoderRNN(tnn.Module):
    def __init__(self, voc_size, hidden_size):
        super(EncoderRNN, self).__init__()
        self.voc_size = voc_size
        self.hidden_size = hidden_size
        self.embedding = tnn.Embedding(voc_size, hidden_size)
        # Inputs are (batch, seq_len), so the GRU must be batch_first.
        self.gru = tnn.GRU(hidden_size, hidden_size, batch_first=True)

    def forward(self, X):
        X = self.embedding(X)        # (batch, seq_len, hidden_size)
        out, state = self.gru(X)     # out: (batch, seq_len, hidden_size)
        return out, state


class DecoderRNN(tnn.Module):
    def __init__(self, voc_size, hidden_size):
        super(DecoderRNN, self).__init__()
        self.voc_size = voc_size
        self.hidden_size = hidden_size
        self.embedding = tnn.Embedding(voc_size, hidden_size)
        self.gru = tnn.GRU(hidden_size, hidden_size, batch_first=True)
        self.dense = tnn.Linear(hidden_size, voc_size)

    def forward(self, X, state_0):
        X = self.embedding(X)              # (batch, seq_len, hidden_size)
        out, state = self.gru(X, state_0)  # state_0: the encoder's final state
        out = self.dense(out)              # (batch, seq_len, voc_size)
        return out, state


class Encoder2Decoder(tnn.Module):
    def __init__(self, encoder, decoder):
        super(Encoder2Decoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, X, y_0):
        # Encode the source sentence, then decode with teacher forcing:
        # y_0 is the target sequence fed to the decoder.
        _, state = self.encoder(X)
        out, state = self.decoder(y_0, state)
        return out
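
A quick shape check with random ids (the toy sizes and the vocabulary of 100 are my own choices) confirms the wiring:

enc = EncoderRNN(voc_size=100, hidden_size=32)
dec = DecoderRNN(voc_size=100, hidden_size=32)
net = Encoder2Decoder(enc, dec)
src = torch.randint(0, 100, (4, 15))      # batch of 4 padded source sentences
tgt = torch.randint(0, 100, (4, 15))      # decoder input (teacher forcing)
print(net(src, tgt).shape)                # torch.Size([4, 15, 100])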


def mask(X, l):
    # Zero out every position at or beyond each sequence's valid length l.
    maxlen = X.shape[1]
    mask = torch.arange(maxlen, device=X.device)[None, :] < l[:, None]
    X[~mask] = 0
    return X
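
For example, with valid lengths 2 and 3 the trailing positions are zeroed in place:

w = torch.ones(2, 4, dtype=torch.long)
print(mask(w, torch.tensor([2, 3])))
# tensor([[1, 1, 0, 0],
#         [1, 1, 1, 0]])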


class MaskedEntropy(tnn.CrossEntropyLoss):
    # Cross-entropy that ignores PAD positions via a 0/1 weight mask.
    def forward(self, pre, label, valid_len):
        weights = torch.ones_like(label)
        weights = mask(weights, valid_len)      # per-sentence lengths, not a constant
        self.reduction = 'none'                 # keep per-token losses
        # CrossEntropyLoss expects (batch, voc_size, seq_len) for 2-D targets.
        output = super(MaskedEntropy, self).forward(pre.permute(0, 2, 1), label)
        return (output * weights).mean(dim=1)   # per-sentence masked loss
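
A quick check on toy shapes: the second sentence below only counts 5 of its 15 positions, so its masked loss comes out smaller on average:

loss_fn = MaskedEntropy()
logits = torch.randn(2, 15, 100)                   # (batch, seq_len, voc_size)
labels = torch.randint(0, 100, (2, 15))
print(loss_fn(logits, labels, torch.tensor([15, 5])))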


train = MyDataset(input_c, pad_c, output_e, pad_e)
train_iter = td.DataLoader(train, 400, shuffle=True)
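
The training loop itself is still missing (see the note at the end). Below is a minimal sketch under the definitions above; the hidden size, learning rate, and epoch count are my own guesses, and a real setup would prepend a BOS token to the decoder input rather than feeding the target directly:

hidden_size = 256
net = Encoder2Decoder(EncoderRNN(len(id2w_c), hidden_size),
                      DecoderRNN(len(id2w_e), hidden_size))
loss_fn = MaskedEntropy()
optimizer = torch.optim.Adam(net.parameters(), lr=0.005)

for epoch in range(50):
    for X, x_len, y, y_len in train_iter:
        optimizer.zero_grad()
        pre = net(X, y)                  # teacher forcing; no BOS in this vocab
        l = loss_fn(pre, y, y_len).sum()
        l.backward()
        optimizer.step()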

The helper module mt/load.py (imported at the top via from mt.load import *):

import re
import jieba
import collections
import nltk
import torch

def load_data():
    with open('cmn.txt', encoding='utf8') as file:
        read_in = file.read().split('\n')
        english, chinese = [], []
        for sentence in read_in:
            # Locate the first Chinese character; lines without one are skipped.
            match = re.search(u'[\u4e00-\u9fa5]', sentence)
            if match is None:
                continue
            first = match.start()
            # The attribution column that follows the Chinese text starts with 'C'.
            last = re.search('C', sentence[first:-1]).start()
            # str.replace is literal, so collapse whitespace with re.sub instead.
            english.append(re.sub(r'\s+', ' ', sentence[0:first]).strip().lower())
            chinese.append(sentence[first:first + last].replace(' ', '').replace('\t', ''))

        return english, chinese
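
For instance, a line in the format of the Tatoeba cmn.txt dump, "Hi.<TAB>嗨。<TAB>CC-BY 2.0 (France) Attribution: tatoeba.org ...", parses to english "hi." and chinese "嗨。".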


def split_c(data):
    # Tokenise the Chinese sentences with jieba.
    cut_data = [list(jieba.cut(sen)) for sen in data]
    return cut_data, make_dic([i for j in cut_data for i in j])


def split_e(data):
    # Tokenise the English sentences with nltk (requires the 'punkt' data).
    cut_data = [nltk.word_tokenize(sen) for sen in data]
    return cut_data, make_dic([i for j in cut_data for i in j])


def make_dic(data):
    # Keep only words that appear more than once; everything else maps to UNK.
    w2c = collections.Counter(data)
    id2w = [w for w, c in w2c.items() if c > 1]
    id2w.append('UNK')
    id2w.append('PAD')
    w2id = {w: index for index, w in enumerate(id2w)}
    return id2w, w2id
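
On a toy corpus, words that appear only once collapse into UNK:

toy_id2w, toy_w2id = make_dic(['a', 'b', 'a', 'c'])
print(toy_id2w)   # ['a', 'UNK', 'PAD']
print(toy_w2id)   # {'a': 0, 'UNK': 1, 'PAD': 2}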


def pad(data, w2id):
    # Pad with PAD ids (or truncate) so every sentence has length 15.
    l = len(data)
    if l <= 15:
        return data + [w2id['PAD']] * (15 - l), l
    return data[0:15], 15


def word2id(data, w2id):
    # Map tokens to ids (unknown words to UNK), then pad to length 15.
    ids = []          # renamed from `id`, which shadows the builtin
    pad_num = []
    for sen in data:
        line = [w2id.get(i, w2id['UNK']) for i in sen]
        line, num = pad(line, w2id)
        ids.append(torch.tensor(line).view(1, 15))
        pad_num.append(num)
    return torch.cat(ids, dim=0), torch.tensor(pad_num)
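
A round trip on a toy vocabulary shows the UNK mapping, the padding to length 15, and the returned length tensor:

toy_id2w, toy_w2id = make_dic(['hi', 'there', 'hi', 'there'])
X, lens = word2id([['hi', 'there', 'stranger']], toy_w2id)
print(X[0])    # tensor([0, 1, 2, 3, 3, ..., 3]): hi, there, UNK, then PAD
print(lens)    # tensor([3])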

Not finished yet and out of time, so I'm submitting this first; I'll continue writing it shortly.
