學習筆記CB010:遞歸神經網絡、LSTM、自動抓取字幕

遞歸神經網絡可存儲記憶神經網絡,LSTM是其中一種,在NLP領域應用效果不錯。

遞歸神經網絡(RNN),時間遞歸神經網絡(recurrent neural network),結構遞歸神經網絡(recursive neural network)。時間遞歸神經網絡神經元間連接構成有向圖,結構遞歸神經網絡利用相似神經網絡結構遞歸構造更復雜深度網絡。兩者訓練屬同一算法變體。

時間遞歸神經網絡。傳統神經網絡FNN(Feed-Forward Neural Networks),前向反饋神經網絡。RNN引入定向循環,神經元爲節點組成有向環,可表達前後關聯關係。隱藏層節點間構成全連接,一個隱藏層節點輸出可作另一個隱藏層節點或自己的輸入。U、V、W是變換概率矩陣,x是輸入,o是輸出。RNN關鍵是隱藏層,隱藏層捕捉序列信息,記憶能力。RNN中U、V、W參數共享,每一步都在做相同事情,輸入不同,降低參數個數和計算量。RNN在NLP應用較多,語言模型在已知已出現詞情況下預測下一個詞概率,是時序模型,下一個詞出現取決於前幾個詞,對應RNN隱藏層間內部連接。

RNN的訓練方法。用BP誤差反向傳播算法更新訓練參數。從輸入到輸出經過步驟不確定,利用時序方式做前向計算,假設x表示輸入值,s表示輸入x經過U矩陣變換後值,h表示隱藏層激活值,o表示輸出層值, f表示隱藏層激活函數,g表示輸出層激活函數。當t=0時,輸入爲x0, 隱藏層爲h0。當t=1時,輸入爲x1, s1 = Ux1+Wh0, h1 = f(s1), o1 = g(Vh1)。當t=2時,s2 = Ux2+Wh1, h2 = f(s2), o2 = g(Vh2)。st = Uxt + Wh(t-1), ht = f(st), ot = g(Vht)。h=f(現有的輸入+過去記憶總結),對RNN記憶能力全然體現。
UVW變換概率矩陣,x輸入,s xU矩陣變換後值,f隱藏層激活函數,h隱藏層激活值,g輸出層激活函數,o輸出。時間、輸入、變換(輸入、前隱藏)、隱藏(變換)、輸出(隱藏)。輸出(隱藏(變換(時間、輸入、前隱藏)))。反向修正參數,每一步輸出o和實際o值誤差,用誤差反向推導,鏈式求導求每層梯度,更新參數。

LSTM(Long Short Tem Momery networks)。RNN存在長序列依賴(Long-Term Dependencies)問題。下一個詞出現概率和非常久遠之前詞有關,考慮到計算量,限制依賴長度。http://colah.github.io/posts/2015-08-Understanding-LSTMs 。傳統RNN示意圖,只包含一個隱藏層,tanh爲激發函數,“記憶”體現在t滑動窗口,有多少個t就有多少記憶。

LSTM設計,神經網絡層(權重係數和激活函數,σ表示sigmoid激活函數,tanh表示tanh激活函數),矩陣運算(矩陣乘或矩陣加)。歷史信息傳遞和記憶,調大小閥門(乘以一個0到1之間係數),第一個sigmoid層計算輸出0到1之間係數,作用到×門,這個操作表達上一階段傳遞過來的記憶保留多少,忘掉多少。忘掉記憶多少取決上一隱藏層輸出h{t-1}和本層的輸入x{t}。上一層輸出h{t-1}和本層的輸入x{t}得出新信息,存到記憶。計算輸出值Ct部分tanh神經元和計算比例係數sigmoid神經元(sigmoid取值範圍是[0,1]作比例係數,tanh取值範圍[-1,1]作一個輸出值)。隱藏層輸出h計算,考慮當前全部信息(上一時序隱藏層輸出、本層輸入x和當前整體記憶信息),本單元狀態部分C通過tanh激活並做一個過濾(上一時序輸出值和當前輸入值通過sigmoid激活係數)。一句話詞是不同時序輸入x,在某一時間t出現詞A概率可LSTM計算,詞A出現概率取決前面出現過詞,取決前面多少個詞不確定,LSTM存儲記憶信息C,得出較接近概率。

聊天機器人是範問答系統。

語料庫獲取。範問答系統,一般從互聯網收集語料信息,比如百度、谷歌,構建問答對組成語料庫。語料庫分成多訓練集、開發集、測試集。問答系統訓練在一堆答案裏找一個正確答案模型。訓練過程不把所有答案都放到一個向量空間,做分組,在語料庫裏採集樣本,收集每一個問題對應500個答案集合,500個裏面有正向樣本,隨機選些負向樣本,突出正向樣本作用。

基於CNN系統設計,sparse interaction(稀疏交互),parameter sharing(參數共享),equivalent respresentation(等價表示),適合自動問答系統答案選擇模型訓練。

通用訓練方法。訓練時獲取問題詞向量Vq(詞向量可用google word2vec訓練,和一個正向答案詞向量Va+,和一個負向答案詞向量Va-, 比較問題和兩個答案相似度,兩個相似度差值大於一個閾值m更新模型參數,在候選池裏選答案,小於m不更新模型。參數更新,梯度下降、鏈式求導。測試數據,計算問題和候選答案cos距離,相似度最大是正確答案預測。

神經網絡結構設計。HL hide layer隱藏層,激活函數z = tanh(Wx+B),CNN 卷積層,P 池化層,池化步長 1,T tanh層,P+T輸出是向量表示,最終輸出兩個向量cos相似度。HL或CNN連起來表示共享相同權重。CNN輸出維數取決做多少卷積特徵。論文《Applying Deep Learning To Answer Selection- A Study And An Open Task》。

深度學習運用到聊天機器人中,1. 神經網絡結構選擇、組合、優化。2. 自然語言處理,機器識別詞向量。3. 相似或匹配關係考慮相似度計算,典型方法 cos距離。4. 文本序列全局信息用CNN或LSTM。5. 精度不高可加層。6. 計算量過大,參數共享和池化。

聊天機器人學習,需要海量聊天語料庫。美劇字幕。外文電影或電視劇字幕文件是天然聊天語料,對話比較多美劇最佳。字幕庫網站www.zimuku.net。

自動抓取字幕。抓取器代碼(https://github.com/warmheartli/ChatBotCourse)。在subtitle下創建目錄result,scrapy.Request
方法調用時增加傳參 dont_filter=True:

# coding:utf-8

import sys
import importlib
importlib.reload(sys)

import scrapy
from subtitle_crawler.items import SubtitleCrawlerItem

class SubTitleSpider(scrapy.Spider):
    name = "subtitle"
    allowed_domains = ["zimuku.net"]
    start_urls = [
            "http://www.zimuku.net/search?q=&t=onlyst&ad=1&p=20",
            "http://www.zimuku.net/search?q=&t=onlyst&ad=1&p=21",
            "http://www.zimuku.net/search?q=&t=onlyst&ad=1&p=22",
    ]

    def parse(self, response):
        hrefs = response.selector.xpath('//div[contains(@class, "persub")]/h1/a/@href').extract()
        for href in hrefs:
            url = response.urljoin(href)
            request = scrapy.Request(url, callback=self.parse_detail, dont_filter=True)
            yield request

    def parse_detail(self, response):
        url = response.selector.xpath('//li[contains(@class, "dlsub")]/div/a/@href').extract()[0]
        print("processing: ", url)
        request = scrapy.Request(url, callback=self.parse_file, dont_filter=True)
        yield request

    def parse_file(self, response):
        body = response.body
        item = SubtitleCrawlerItem()
        item['url'] = response.url
        item['body'] = body
        return item

# -*- coding: utf-8 -*-

class SubtitleCrawlerPipeline(object):
    def process_item(self, item, spider):
        url = item['url']
        file_name = url.replace('/','_').replace(':','_')+'.rar'
        fp = open('result/'+file_name, 'wb+')
        fp.write(item['body'])
        fp.close()
        return item

ls result/|head -1 , ls result/|wc -l , du -hs result/ 。

字幕文件解壓,linux直接執行unzip file.zip。linux解壓rar文件,http://www.rarlab.com/download.htm 。wget http://www.rarlab.com/rar/rarlinux-x64-5.4.0.tar.gz 。tar zxvf rarlinux-x64-5.4.0.tar.gz
./rar/unrar 。解壓命令,unrar x file.rar 。linux解壓7z文件,http://downloads.sourceforge.net/project/p7zip 下載源文件,解壓執行make編譯 bin/7za可用,用法 bin/7za x file.7z。

程序和腳本在https://github.com/warmheartli/ChatBotCourse 。第一步:爬取影視劇字幕。第二步:壓縮格式分類。文件多無法ls、文件名帶特殊字符、文件名重名誤覆蓋、擴展名千奇百怪,python腳本mv_zip.py:

import glob
import os
import fnmatch
import shutil
import sys

def iterfindfiles(path, fnexp):
    for root, dirs, files in os.walk(path):
        for filename in fnmatch.filter(files, fnexp):
            yield os.path.join(root, filename)

i=0
for filename in iterfindfiles(r"./input/", "*.ZIP"):
    i=i+1
    newfilename = "zip/" + str(i) + "_" + os.path.basename(filename)
    print(filename + " <===> " + newfilename)
    shutil.move(filename, newfilename)
    #sys.exit(-1)

擴展名根據壓縮文件修改.rar、.RAR、.zip、.ZIP。第三步:解壓。根據操作系統下載不同解壓工具,建議unrar和unzip,腳本來實現批量解壓:

i=0; for file in `ls`; do mkdir output/${i}; echo "unzip $file -d output/${i}";unzip -P abc $file -d output/${i} > /dev/null; ((i++)); done
i=0; for file in `ls`; do mkdir output/${i}; echo "${i} unrar x $file output/${i}";unrar x $file output/${i} > /dev/null; ((i++)); done

第四步:srt、ass、ssa字幕文件分類整理。字幕文件類型srt、lrc、ass、ssa、sup、idx、str、vtt。第五步:清理目錄。自動清理空目錄腳本clear_empty_dir.py :

import glob
import os
import fnmatch
import shutil
import sys

def iterfindfiles(path, fnexp):
    for root, dirs, files in os.walk(path):
        if 0 == len(files) and len(dirs) == 0:
            print(root)
            os.rmdir(root)

iterfindfiles(r"./input/", "*.srt")

第六步:清理非字幕文件。批量刪除腳本del_file.py :

import glob
import os
import fnmatch
import shutil
import sys

def iterfindfiles(path, fnexp):
    for root, dirs, files in os.walk(path):
        for filename in fnmatch.filter(files, fnexp):
            yield os.path.join(root, filename)

for suffix in ("*.mp4", "*.txt", "*.JPG", "*.htm", "*.doc", "*.docx", "*.nfo", "*.sub", "*.idx"):
    for filename in iterfindfiles(r"./input/", suffix):
        print(filename)
        os.remove(filename)

第七步:多層解壓縮。第八步:捨棄剩餘少量文件。無擴展名、特殊擴展名、少量壓縮文件,總體不超過50M。第九步:編碼識別與轉碼。utf-8、utf-16、gbk、unicode、iso8859,統一utf-8,get_charset_and_conv.py :

import chardet
import sys
import os

if __name__ == '__main__':
    if len(sys.argv) == 2:
        for root, dirs, files in os.walk(sys.argv[1]):
            for file in files:
                file_path = root + "/" + file
                f = open(file_path,'r')
                data = f.read()
                f.close()
                encoding = chardet.detect(data)["encoding"]
                if encoding not in ("UTF-8-SIG", "UTF-16LE", "utf-8", "ascii"):
                    try:
                        gb_content = data.decode("gb18030")
                        gb_content.encode('utf-8')
                        f = open(file_path, 'w')
                        f.write(gb_content.encode('utf-8'))
                        f.close()
                    except:
                        print("except:", file_path)

第十步:篩選中文。extract_sentence_srt.py :

# coding:utf-8
import chardet
import os
import re

cn=ur"([u4e00-u9fa5]+)"
pattern_cn = re.compile(cn)
jp1=ur"([u3040-u309F]+)"
pattern_jp1 = re.compile(jp1)
jp2=ur"([u30A0-u30FF]+)"
pattern_jp2 = re.compile(jp2)

for root, dirs, files in os.walk("./srt"):
    file_count = len(files)
    if file_count > 0:
        for index, file in enumerate(files):
            f = open(root + "/" + file, "r")
            content = f.read()
            f.close()
            encoding = chardet.detect(content)["encoding"]
            try:
                for sentence in content.decode(encoding).split('n'):
                    if len(sentence) > 0:
                        match_cn =  pattern_cn.findall(sentence)
                        match_jp1 =  pattern_jp1.findall(sentence)
                        match_jp2 =  pattern_jp2.findall(sentence)
                        sentence = sentence.strip()
                        if len(match_cn)>0 and len(match_jp1)==0 and len(match_jp2) == 0 and len(sentence)>1 and len(sentence.split(' ')) < 10:
                            print(sentence.encode('utf-8'))
            except:
                continue

第十一步:字幕中句子提取。

# coding:utf-8
import chardet
import os
import re

cn=ur"([u4e00-u9fa5]+)"
pattern_cn = re.compile(cn)
jp1=ur"([u3040-u309F]+)"
pattern_jp1 = re.compile(jp1)
jp2=ur"([u30A0-u30FF]+)"
pattern_jp2 = re.compile(jp2)

for root, dirs, files in os.walk("./ssa"):
    file_count = len(files)
    if file_count > 0:
        for index, file in enumerate(files):
            f = open(root + "/" + file, "r")
            content = f.read()
            f.close()
            encoding = chardet.detect(content)["encoding"]
            try:
                for line in content.decode(encoding).split('n'):
                    if line.find('Dialogue') == 0 and len(line) < 500:
                        fields = line.split(',')
                        sentence = fields[len(fields)-1]
                        tag_fields = sentence.split('}')
                        if len(tag_fields) > 1:
                            sentence = tag_fields[len(tag_fields)-1]
                        match_cn =  pattern_cn.findall(sentence)
                        match_jp1 =  pattern_jp1.findall(sentence)
                        match_jp2 =  pattern_jp2.findall(sentence)
                        sentence = sentence.strip()
                        if len(match_cn)>0 and len(match_jp1)==0 and len(match_jp2) == 0 and len(sentence)>1 and len(sentence.split(' ')) < 10:
                            sentence = sentence.replace('N', '')
                            print(sentence.encode('utf-8'))
            except:
                continue

第十二步:內容過濾。過濾特殊unicode字符、關鍵詞、去除字幕樣式標籤、html標籤、連續特殊字符、轉義字符、劇集信息:

# coding:utf-8
import sys
import re
import chardet

if __name__ == '__main__':
    #illegal=ur"([u2000-u2010]+)"
    illegal=ur"([u0000-u2010]+)"
    pattern_illegals = [re.compile(ur"([u2000-u2010]+)"), re.compile(ur"([u0090-u0099]+)")]
    filters = ["字幕", "時間軸:", "校對:", "翻譯:", "後期:", "監製:"]
    filters.append("時間軸:")
    filters.append("校對:")
    filters.append("翻譯:")
    filters.append("後期:")
    filters.append("監製:")
    filters.append("禁止用作任何商業盈利行爲")
    filters.append("http")
    htmltagregex = re.compile(r'<[^>]+>',re.S)
    brace_regex = re.compile(r'{.*}',re.S)
    slash_regex = re.compile(r'\w',re.S)
    repeat_regex = re.compile(r'[-=]{10}',re.S)
    f = open("./corpus/all.out", "r")
    count=0
    while True:
        line = f.readline()
        if line:
            line = line.strip()

            # 編碼識別,不是utf-8就過濾
            gb_content = ''
            try:
                gb_content = line.decode("utf-8")
            except Exception as e:
                sys.stderr.write("decode error:  ", line)
                continue

            # 中文識別,不是中文就過濾
            need_continue = False
            for pattern_illegal in pattern_illegals:
                match_illegal = pattern_illegal.findall(gb_content)
                if len(match_illegal) > 0:
                    sys.stderr.write("match_illegal error: %sn" % line)
                    need_continue = True
                    break
            if need_continue:
                continue

            # 關鍵詞過濾
            need_continue = False
            for filter in filters:
                try:
                    line.index(filter)
                    sys.stderr.write("filter keyword of %s %sn" % (filter, line))
                    need_continue = True
                    break
                except:
                    pass
            if need_continue:
                continue

            # 去掉劇集信息
            if re.match('.*第.*季.*', line):
                sys.stderr.write("filter copora %sn" % line)
                continue
            if re.match('.*第.*集.*', line):
                sys.stderr.write("filter copora %sn" % line)
                continue
            if re.match('.*第.*幀.*', line):
                sys.stderr.write("filter copora %sn" % line)
                continue

            # 去html標籤
            line = htmltagregex.sub('',line)

            # 去花括號修飾
            line = brace_regex.sub('', line)

            # 去轉義
            line = slash_regex.sub('', line)

            # 去重複
            new_line = repeat_regex.sub('', line)
            if len(new_line) != len(line):
                continue

            # 去特殊字符
            line = line.replace('-', '').strip()

            if len(line) > 0:
                sys.stdout.write("%sn" % line)
            count+=1
        else:
            break
    f.close()
    pass

參考資料:

《Python 自然語言處理》

http://www.shareditor.com/blogshow?blogId=103

http://www.shareditor.com/blogshow?blogId=104

http://www.shareditor.com/blogshow?blogId=105

http://www.shareditor.com/blogshow?blogId=112

歡迎推薦上海機器學習工作機會,我的微信:qingxingfengzi

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章