訓練並保存模型
def train_savemodel():
    """Train a skip-gram Word2Vec model on the corpus and save it to disk.

    Reads the corpus found under the module-level ``directory`` via
    ``PathLineSentences`` and writes two artifacts:

    * the full model (``model.save``) to ``modelpath``, reloadable with
      ``Word2Vec.load``;
    * the word vectors only, in binary word2vec format, to ``wv_path``,
      reloadable with ``KeyedVectors.load_word2vec_format``.

    Returns:
        The trained ``Word2Vec`` model, so callers can use it immediately
        instead of reloading it from ``modelpath``.
    """
    model = Word2Vec(
        PathLineSentences(directory),
        size=400,        # embedding dimensionality (gensim < 4.0 keyword; renamed vector_size in 4.0)
        window=5,        # max distance between the current and predicted word
        min_count=5,     # ignore words whose total frequency is below this
        workers=multiprocessing.cpu_count(),  # one worker thread per CPU core
        sg=1,            # use the skip-gram algorithm
        hs=1,            # use hierarchical softmax
        negative=0,      # disable negative sampling (hierarchical softmax is used instead)
    )
    model.save(modelpath)
    model.wv.save_word2vec_format(wv_path, binary=True)
    return model
加載模型
def loadmodel():
    """Load and return the full Word2Vec model saved at ``modelpath``.

    This is the complete model object written by ``model.save``, as opposed
    to the bare word vectors stored at ``wv_path``.
    """
    return Word2Vec.load(modelpath)
加載 wv：如果不需要再次訓練模型，只需恢復 wv 即可。wv 去除了模型中的權重參數和損失等訓練細節，可以直接用於查詢。
def loadwv():
    """Load only the word vectors stored in binary word2vec format at ``wv_path``.

    Sufficient when the model does not need further training: the result can
    be queried directly (e.g. for similar words).
    """
    return KeyedVectors.load_word2vec_format(wv_path, binary=True)
打印詞彙表中的詞以及它的相近詞（只打印前 10 個）
# Print the first few vocabulary words together with their nearest neighbours.
# most_similar() compares a word against every vector in the vocabulary, so
# only query the words that are actually printed instead of the whole vocab
# (the previous version queried every word and then printed just 10).
# NOTE(review): `md` is assumed to be a loaded Word2Vec model (e.g. from
# loadmodel()); it is not defined in this snippet.
NUM_PREVIEW_WORDS = 10
most_similars_pre = {
    word: md.wv.most_similar(word)
    for word in md.wv.index2word[:NUM_PREVIEW_WORDS]
}
for word, neighbours in most_similars_pre.items():
    print(word, neighbours)
使用TSNE降維(這個操作很費時間),打印關係圖:
def reduce_dimensions(model):
    """Project every vocabulary vector of *model* to 2-D with t-SNE.

    t-SNE over the whole vocabulary is slow on a large corpus.

    Args:
        model: A trained Word2Vec model (gensim 3.x API: ``model.wv.vocab``).

    Returns:
        ``(x_vals, y_vals, labels)``: lists of the first and second t-SNE
        coordinates, and a numpy array of the words, all aligned by position.
    """
    num_dimensions = 2  # 2-D output so the result can be drawn as a scatter plot

    words = list(model.wv.vocab)
    # Convert to an array once (the previous version did this twice).
    vectors = np.asarray([model.wv[word] for word in words])
    labels = np.asarray(words)

    # Reduce dimensionality with t-SNE; fixed random_state for reproducible layouts.
    tsne = TSNE(n_components=num_dimensions, random_state=0)
    vectors = tsne.fit_transform(vectors)

    x_vals = [v[0] for v in vectors]
    y_vals = [v[1] for v in vectors]
    return x_vals, y_vals, labels
def plot_with_plotly(x_vals, y_vals, labels, plot_in_notebook=True):
    """Draw the 2-D points as text labels using plotly's offline mode.

    Args:
        x_vals, y_vals: Point coordinates.
        labels: Text shown at each point.
        plot_in_notebook: When True, render inline in a notebook via
            ``iplot``; otherwise write ``word-embedding-plot.html`` via ``plot``.
    """
    from plotly.offline import init_notebook_mode, iplot, plot
    import plotly.graph_objs as go

    figure_data = [go.Scatter(x=x_vals, y=y_vals, mode='text', text=labels)]

    if not plot_in_notebook:
        plot(figure_data, filename='word-embedding-plot.html')
        return

    init_notebook_mode(connected=True)
    iplot(figure_data, filename='word-embedding-plot')
def plot_with_matplotlib(x_vals, y_vals, labels, num_labels=25):
    """Scatter-plot the 2-D points with matplotlib and annotate a random subset.

    Args:
        x_vals, y_vals: Point coordinates, indexable by position.
        labels: Text for each point, aligned with ``x_vals``/``y_vals``.
        num_labels: How many randomly chosen points to annotate (default 25).
            Capped at the number of points, so fewer than ``num_labels``
            points no longer makes ``random.sample`` raise ``ValueError``.
    """
    import matplotlib.pyplot as plt
    import random

    # A private generator seeded with 0 picks the same points as the former
    # random.seed(0) + random.sample(...), without resetting the global RNG.
    rng = random.Random(0)

    plt.figure(figsize=(12, 12))
    plt.scatter(x_vals, y_vals)

    sample_size = min(num_labels, len(labels))
    for i in rng.sample(range(len(labels)), sample_size):
        plt.annotate(labels[i], (x_vals[i], y_vals[i]))
    plt.show()
# Reduce the embeddings to 2-D (slow: t-SNE over the whole vocabulary) and plot.
# NOTE(review): `md` is assumed to be the loaded Word2Vec model; it is not
# defined in this snippet.
x_vals, y_vals, labels = reduce_dimensions(md)

# get_ipython is only defined when running under IPython/Jupyter. A NameError
# means a plain interpreter, so fall back to matplotlib; catching only that
# avoids masking unrelated errors.
try:
    get_ipython()
except NameError:
    plot_function = plot_with_matplotlib
else:
    plot_function = plot_with_plotly

plot_function(x_vals, y_vals, labels)
這張圖跑了半個小時才生成。