文章目錄
寫一下最近正在做的一個命名實體識別項目,還沒結束,這裏先放一段代碼
main.py
#系統包
import os
import tensorflow as tf
import pickle
#自定義包
import data_loader
import data_utils
import model_utils
flags = tf.app.flags
#訓練相關的
flags.DEFINE_boolean('train',True,'是否開始訓練')
flags.DEFINE_boolean('clean',True,'是否清理文件')
#配置相關
flags.DEFINE_integer('seg_dim',20,'seg embedding size')
flags.DEFINE_integer('word_dim',120,'word embedding')
flags.DEFINE_integer('lstm_dim',120,'Num of hiddem unis in lstm')
flags.DEFINE_string('tag_schema','BIOES','編碼方式')
##訓練相關的e
flags.DEFINE_float('clip',5,'Grandient clip')
flags.DEFINE_float('dropout',0.5,'Dropout rate')
flags.DEFINE_integer('batch_size',120,'batch_size')
flags.DEFINE_float('lr',0.001,'learning rate')
flags.DEFINE_string('optimizer','adam','優化器')
flags.DEFINE_boolean('pre_emb',True,'是否使用預訓練')
flags.DEFINE_integer('max_epoch',100,'最大輪訓次數')
flags.DEFINE_integer('steps_chech',100,'steps per checkpoint')
flags.DEFINE_string('ckpt_path',os.path.join('model','ckpt'),'保存模型的位置')
flags.DEFINE_string('log_file','train_log','訓練過程中日誌')
flags.DEFINE_string('map_file','maps.pkl','存放字典映射以及標籤映射')
flags.DEFINE_string('vocab_file','vocab.json','詞典')
flags.DEFINE_string('config_file','config_file','配置文件')
flags.DEFINE_string('train_file',os.path.join('data','ner.train'),'訓練數據路徑')
flags.DEFINE_string('dev_file',os.path.join('data','ner.dev'),'校驗數據路徑')
flags.DEFINE_string('test_file',os.path.join('data','ner.test'),'測試數據路徑')
FLAGS = tf.app.flags.FLAGS
assert FLAGS.clip < 5.1,'梯度裁剪不能過大'
assert 0 < FLAGS.dropout < 1, 'dropout必須在0和1之間'
assert FLAGS.lr >0,'lr 必須大於0'
assert FLAGS.optimizer in ['adam','sgd','adagrad'],'優化器必須在這三者之間'
def train():
#加載數據
train_sentences = data_loader.load_sentences(FLAGS.train_file)
dev_sentences = data_loader.load_sentences(FLAGS.dev_file)
test_sentences = data_loader.load_sentences(FLAGS.test_file)
#轉換編碼bio轉bioes
data_loader.update_tag_scheme(train_sentences,FLAGS.tag_schema)
data_loader.update_tag_scheme(test_sentences,FLAGS.tag_schema)
data_loader.update_tag_scheme(dev_sentences,FLAGS.tag_schema)
#創建單詞映射
if not os.path.isfile(FLAGS.map_file):
_,word_to_id,id_to_word=data_loader.word_mapping(train_sentences)
_,tag_to_id,id_to_tag = data_loader.tag_mapping(train_sentences)
with open(FLAGS.map_file,"wb") as f:#第一次會走這裏,會創建maps.pkl這個文件
pickle.dump([word_to_id,id_to_word,tag_to_id,id_to_tag],f)
else:
with open(FLAGS.map_file,'rb') as f:#第二次或者以後會走這裏
word_to_id,id_to_word,tag_to_id,id_to_tag = pickle.load(f)
train_data = data_loader.prepare_dataset(train_sentences,word_to_id,tag_to_id)
dev_data = data_loader.prepare_dataset(dev_sentences,word_to_id,tag_to_id)
test_data = data_loader.prepare_dataset(test_sentences,word_to_id,tag_to_id)
print("train_data_num%i,dev_data_num%i,test_data_num%i"%(len(train_data),len(dev_data),len(test_data)))
# config = model_utils.config_model(FLAGS, word_to_id, tag_to_id)
model_utils.make_path(FLAGS)
if os.path.isfile(FLAGS.config_file):#查看config_file是否存在,存在,則load_cinfig
config = model_utils.load_config(FLAGS.config_file)
else:#如果confifile不存在,執行下面語句,先配置,在保存,下一次執行代碼時,就執行上面一句了,直接加載!
config = model_utils.config_model(FLAGS,word_to_id,tag_to_id)
model_utils.save_config(config,FLAGS.config_file)
log_path = os.path.join("log",FLAGS.log_file)
logger = model_utils.get_logger(log_path)
model_utils.print_config(config,logger)
print("hello")
def main(_):
if FLAGS.train:
train()
else:
pass
if __name__ =="__main__":
tf.app.run(main)
model_utils.py
from collections import OrderedDict
import os
import json
import logging
def get_logger(log_file):
"""
定義日誌方法(這個方法是通用的)
:param log_file:
:return:
"""
#創建一個logger的實例
logger = logging.getLogger(log_file)
#設置Logger的全局日誌級別爲DEBUG
logger.setLevel(logging.DEBUG)
#創建一個日誌文件的handler,並且設置日誌級別的DEBUG
fh = logging.FileHandler(log_file)
fh.setLevel(logging.DEBUG)
#創建一個控制檯的handler,並且設置日誌級別爲DEBUG
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
#設置日誌格式
formatter = logging.Formatter("%(asctime)s-%(name)s-%(levelname)s-%(message)s")
#add formatter to ch and fh
ch.setFormatter(formatter)
fh.setFormatter(formatter)
#add ch and fh to logger
logger.addHandler(ch)
logger.addHandler(fh)
return logger
#模型配置
def config_model(FLAGS,word_to_id,tag_to_id):
config = OrderedDict() #有序字典
config['num_words'] = len(word_to_id)
config['word_dim'] = FLAGS.word_dim
config['num_tags'] = len(tag_to_id)
config['seg_dim'] = FLAGS.seg_dim
config['list_dim'] = FLAGS.lstm_dim
config['batch_size'] = FLAGS.batch_size
config['clip'] = FLAGS.clip
config['dropout_keep'] = 1.0 - FLAGS.dropout
config['optimizer'] = FLAGS.optimizer
config['lr'] = FLAGS.lr
config['tag_schema'] = FLAGS.tag_schema
config['pre_emb'] = FLAGS.pre_emb
return config
def make_path(params):
"""
創建文件夾
:param params:
:return:
"""
if not os.path.isdir(params.ckpt_path):
os.makedirs(params.ckpt_path)
if not os.path.isdir('log'):#創建Log文件
os.makedirs('log')
def save_config(config,config_file):
"""
保存配置文件
:param config:
:param config_path:
:return:
"""
with open(config_file,'w',encoding='utf-8')as f:
json.dump(config,f,ensure_ascii=False,indent =4)
def load_config(config_file):
"""
加載配置文件
:param config_file:
:return:
"""
with open(config_file,encoding='utf-8') as f:
return json.load(f)
def print_config(config,logger):#通用方法,打印日誌
"""
打印模型參數
:param config:
:param logger:
:return:
"""
for k,v in config.items():
logger.info("{}:\t{}".format(k.ljust(15),v))
data_loader.py
import codecs
import data_utils
def load_sentences(path):
"""
加載數據集,每行包含一個漢字和一個標記
句子和句子之間是以空格進行分割的
最後返回句子集合
:param path:
:return:
"""
#存放數據集
sentences = []
#臨時存放每一個句子
sentence = []
for line in codecs.open(path,'r',encoding='utf-8'):#這個循環結束,將會將數據添加到sentences
#去掉兩邊空格
line = line.strip()
#首先判斷是不是空,如果是則表示句子和句子之間的分割點
if not line:#如果這行是空格
if len(sentence) > 0:#判斷這個句子裏面是否爲空(長度是否大於0), 在這裏下斷點,利用resume program這個鍵(跳到下個斷點),這裏相當於是從一個句子直接跳至這個句子結束
sentences.append(sentence)#如果大於0,說明句子裏面有東西,將其添加到句子集合裏面
#清空sentence,表示一句話完結
sentence = [] #由於下面還需要用到這個臨時的,故這裏將其清空
else: #如果這行不是空
if line[0] == " ":#判斷第一個是否爲空,如果爲空,表示這不是一個合法的
continue
else:#如果第一個不是空的
word = line.split()#使用空格對字符串進行切分
assert len(word) >= 2
sentence.append(word)#將上面的信息添加到句子裏面
#循環走完,要判斷一個,防止句子沒有進入到句子集合裏面
if len(sentence) > 0:
sentences.append(sentence)
return sentences
def update_tag_scheme(sentences,tag_scheme):
"""
更新爲指定編碼
:param sentences:
:param tag_scheme:
:return:
"""
for i,s in enumerate(sentences):
tags = [w[-1] for w in s ] #取出編碼
if not data_utils.chek_bio(tags):#如果不是bio編碼(做轉換之前校驗一下是否爲我們的BIO編碼)
s_str ="\n".join("".join(w) for w in s)
raise Exception("輸入的句子應爲BIO編碼,請檢查輸入句子%i:\n%s"%(i,s_str))
if tag_scheme == "BIO":
for word,new_tag in zip(s,tags):
word[-1] = new_tag
if tag_scheme == "BIOES":
new_tags =data_utils.bio_to_bioes(tags)
for word, new_tag in zip(s,new_tags):
word[-1] = new_tag
else:
raise Exception("非法目標編碼")
def word_mapping(sentences): #這個函數在NLP領域內經常用到!!比如分類
"""
構建字典
:param sentences:
:return:
"""
#這裏有個列表推導式,這是個重點!!!
word_list = [ [x[0] for x in s]for s in sentences]#將每個句子裏面的word提煉出來,eg:[['相', 'O'], ['比', 'O'], ['之', 'O'], ['下', 'O'], [',', 'O'], ['青', 'B-ORG']],將裏面的"相比之下"這些字提煉出來了
dico = data_utils.create_dico(word_list) #dico裏面存放的是每個單詞以及其對應的次數,
dico['<PAD>'] = 10000001
dico['<UNK>'] = 10000000
word_to_id,id_to_word =data_utils.create_mapping(dico)
return dico,word_to_id,id_to_word
def tag_mapping(sentences): #序列標註的時候會用到
"""
構建標籤字典
:param sentences:
:return:
"""
tag_list = [[ x[1] for x in s] for s in sentences]
dico = data_utils.create_dico(tag_list)
tag_to_id,id_to_tag =data_utils.create_mapping(dico)
return dico,tag_to_id,id_to_tag
def prepare_dataset(sentences,word_to_id,tag_to_id,train = True):
"""
數據預處理,返回list,其實包含:
-word_list
-word_id_list
-word char indexs
-tag_id_list
:param sentences:
:param word_to_id:
:param tag_to_id:
:param train:
:return:
"""
none_index = tag_to_id['O']#字母O
data =[]
for s in sentences:
word_list = [w[0] for w in s]
word_id_list = [ word_to_id [w if w in word_to_id else '<UNK>'] for w in word_list]#遍歷word_list,由於集合裏面不可能包含字典裏面的詞,判斷一下,如果在word_to_id裏面就取w,否則就取UNK(表示不在字典裏面)
segs = data_utils.get_seg_features("".join(word_list))
if train:
tag_id_list = [tag_to_id[w[-1]] for w in s]#tag_to_id[w[-1]]:將tag拿出來
else:
tag_id_list = [none_index for w in s]
data.append([word_list,word_id_list,segs,tag_id_list])
return data
if __name__ =="__main__":
path = "data/ner.dev"
sentences = load_sentences(path)
update_tag_scheme(sentences,"BIOES")
_,word_to_id,id_to_word= word_mapping(sentences)
_,tag_to_id,id_to_tag=tag_mapping(sentences) #_:表示默認值
dev_data = prepare_dataset(sentences,word_to_id,tag_to_id)
data_utils.BatchManager(dev_data,120)
data_utils.py
import jieba
import math
import random
def chek_bio(tags):
"""
檢測輸入的tags是否爲BIO編碼
如果不是bio編碼
那麼錯誤的類型
1)編碼不在BIO中
2)第一個編碼是I
3)當前編碼不是B,前一個編碼不是O
:param tags:
:return:
"""
for i,tag in enumerate(tags):
if tag =='O':#此時爲BIO
continue
tag_list = tag.split("-")
if len(tag_list) != 2 or tag_list[0] not in set(['B','I']):#此時爲非法編碼,分割之後的長度不是2,同時編碼不在B和I中;
return False
if tag_list[0] == 'B':#此時爲合法BIO編碼
continue
elif i == 0 or tags[i-1] == 'O':#如果第一個位置不是B,同時i等於0(I是第一個位置),上一個編碼等於O(字母)
tags[i] ='B' + tag[1:] #如果當前第一個位置不是B,或者當前編碼不是B並且前一個編碼0,則全部轉換成
elif tags[i-1][1:] ==tag[1:]:
#如果當前編碼的後面類型編碼與tags中的前一個編碼中的後面類型編碼相同,則跳過
continue
else:
#如果編碼類型不一致,則重新從B開始
tags[i] = 'B' + tag[1:]
return True
def bio_to_bioes(tags):
"""
把bio編碼轉換成bios
返回新的tags
:param tags:
:return:
"""
new_tags = []
for i ,tag in enumerate(tags):
if tag =='O':
#直接保留,不變化
new_tags.append(tag)
elif tag.split('-')[0] =='B':
#如果tag是以B開頭,那麼我們就要做下面的判斷:
#首先,如果當前tag不是最後一個,並且緊跟着的後一個是I,eg:B-ORG後面是I-ORG
if (i + 1) < len(tags) and tags[i +1].split('-')[0] =='I':#i + 1 < len(tags):如果不是最後一個tag
#直接保留
new_tags.append(tag)
else:#如果是最後一個或者後面一個不是I;eg:B-ORG後面一個也是B-ORG,那麼前面這個B-ORG就會變成S-ORG
#如果是最後一個或者緊跟着的後一個不是I,那麼表示,需要把B換成S表示單字
new_tags.append(tag.replace('B-','S-'))
elif tag.split('-')[0]=='I':
#如果tag是以I開頭,那麼我們需要進行下面的判斷
#首先,如果當前tag不是最後一個,並且緊跟着的一個是I,eg:I-ORG後面還是I-ORG,如果I-ORG後面是O(字母)就不行
if (i + 1)<len(tags) and tags[i+1].split('-')[0]=='I':
#直接保留
new_tags.append(tag)
else:
#如果是最後一個或者I-ORG後面一個不是以I開頭的(也不是以B開頭的),那麼就表示一個詞的結尾,就把I換成E表示一個詞的結尾
new_tags.append(tag.replace('I-','E-'))
else:
raise Exception('非法編碼')
return new_tags
def create_dico(item_list):
"""
對於item_list裏面,每個item,統計item_list中item在item_list的的次數
item:出現的次數
:param item_list:
:return:
"""
assert type(item_list) is list
dico = {}
for items in item_list:
for item in items:
if item not in dico:
dico[item] = 1
else:
dico[item] += 1
return dico
def create_mapping(dico):
"""
創建item to id,id to item
item的排序按照詞典中出現的次數
:param dico:
:return:
"""
sorted_items = sorted(dico.items(),key=lambda x:(-x[1],x[0])) #將dico(字典)裏面的key按照降序排
# sorted_items = sorted(dico.items(),key=lambda x:(x[1],x[0])) #將dico(字典)裏面的key按照升序排
id_to_item = {i:v[0] for i,v in enumerate(sorted_items)} #將字典裏面的每個字進行編號,出現的次數越多,排在越前面,編號越小;這裏的形式爲---》編號: 詞
item_to_id = {v:k for k,v in id_to_item.items()} #這裏的形式爲---》詞:編號
return item_to_id,id_to_item #注意智力的順序
def get_seg_features(words):
"""
利用jieba分詞
採用類似bioes的編碼,0表示單個字成詞,1表示一個詞的開始,2表示一個詞的中間,3表示一個詞的結尾
:param words:
:return:
"""
seg_features = []
word_list = list(jieba.cut(words))
for word in word_list:
if len(word) ==1: #表示單個成詞,就填0
seg_features.append(0)
else:
temp = [2]*len(word)
temp[0] = 1
temp[-1] = 3
seg_features.extend(temp)
return seg_features
class BatchManager(object):
def __init__(self,data,batch_size):
self.batch_data = self.sort_and_pad(data,batch_size)
self.len_data = len(self.batch_data)
def sort_and_pad(self,data,batch_size):
num_batch = int(math.ceil(len(data)/batch_size)) #計算有多少批次
sorted_data = sorted(data,key=lambda x :len(x[0])) #這裏是按照len的升序排
# sorted_data1 = sorted(data,key=lambda x :-len(x[0])) #這裏是按照len(長度)的降序排
batch_data = list()
for i in range(num_batch):
batch_data.append(self.pad_data(sorted_data[i*batch_size :(i+1)*batch_size]))
return batch_data
@staticmethod
def pad_data(data):#數據填充函數
word_list = []
word_id_list =[]
seg_list=[]
tag_id_list =[]
max_length =max( [len(sentence[0]) for sentence in data]) #一批有120(自己設置的)個句子(樣本),最長數據是17(這裏的demo是17,下一批最大是20,每一批的最大值不一樣的!以每批有120個樣本計算,這裏有20批)
for line in data:
words,word_ids,segs,tag_ids = line #單詞,單詞索引,分詞信息(分詞特徵信息),tag索引
padding = [0] *(max_length - len(words)) #需要填充的數據
word_list.append(words + padding)
word_id_list.append(word_ids + padding)
seg_list.append(segs + padding)
tag_id_list.append(tag_ids + padding)
return [word_list,word_id_list,seg_list,tag_id_list]
def iter_batch(self,shuffle=False):
if shuffle:
random.shuffle(self.batch_data)
for idx in range(self.len_data):
yield self.batch_data[idx]
後面代碼可能會做補充和拓展