引入Redis
在集成Netty之後,爲了提高效率,我打算將消息存儲在Redis緩存系統中,本節將介紹Redis在項目中的引入,以及前端界面的開發。
引入Redis後,完整代碼鏈接。
想要直接得到訓練了13000步的聊天機器人可以直接下載鏈接中
這三個文件,以及詞彙表文件
然後直接運行連接中的py腳本進行測試即可。
最終實現效果如下:
在Netty中引入Redis
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.time.LocalDateTime;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.util.concurrent.GlobalEventExecutor;
import redis.clients.jedis.Jedis;
public class ChatHandler
extends SimpleChannelInboundHandler<TextWebSocketFrame>{
private static ChannelGroup clients=
new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);
@Override
protected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame msg) throws Exception {
System.out.println("channelRead0...");
//連接redis
Jedis jedis=new Jedis("localhost");
System.out.println("連接成功...");
System.out.println("服務正在運行:"+jedis.ping());
//得到用戶輸入的消息,需要寫入文件/緩存中,讓AI進行讀取
String content=msg.text();
if(content==null||content=="") {
System.out.println("content 爲null");
return ;
}
System.out.println("接收到的消息:"+content);
//寫入緩存中
jedis.set("user_say", content+":user");
Thread.sleep(1000);
//讀取AI返回的內容
String AIsay=null;
while(AIsay=="no"||AIsay==null) {
//從緩存中讀取AI回覆的內容
AIsay=jedis.get("ai_say");
String [] arr=AIsay.split(":");
AIsay=arr[0];
}
//讀取後馬上向緩存中寫入
jedis.set("ai_say", "no");
//沒有說,或者還沒說
if(AIsay==null||AIsay=="") {
System.out.println("AIsay==null||AIsay==\"\"");
return;
}
System.out.println("AI說:"+AIsay);
clients.writeAndFlush(
new TextWebSocketFrame(
"AI_PigPig在"+LocalDateTime.now()
+"說:"+AIsay));
}
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
System.out.println("add...");
clients.add(ctx.channel());
}
@Override
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
System.out.println("客戶端斷開,channel對應的長id爲:"
+ctx.channel().id().asLongText());
System.out.println("客戶端斷開,channel對應的短id爲:"
+ctx.channel().id().asShortText());
}
}
在Python中引入Redis
with tf.Session() as sess:#打開作爲一次會話
# 恢復前一次訓練
ckpt = tf.train.get_checkpoint_state('.')#從檢查點文件中返回一個狀態(ckpt)
#如果ckpt存在,輸出模型路徑
if ckpt != None:
print(ckpt.model_checkpoint_path)
model.saver.restore(sess, ckpt.model_checkpoint_path)#儲存模型參數
else:
print("沒找到模型")
r.set('user_say','no')
#測試該模型的能力
while True:
line='no'
#從緩存中進行讀取
while line=='no':
line=r.get('user_say').decode()
#print(line)
list1=line.split(':')
if len(list1)==1:
input_string='no'
else:
input_string=list1[0]
r.set('user_say','no')
# 退出
if input_string == 'quit':
exit()
if input_string != 'no':
input_string_vec = []#輸入字符串向量化
for words in input_string.strip():
input_string_vec.append(vocab_en.get(words, UNK_ID))#get()函數:如果words在詞表中,返回索引號;否則,返回UNK_ID
bucket_id = min([b for b in range(len(buckets)) if buckets[b][0] > len(input_string_vec)])#保留最小的大於輸入的bucket的id
encoder_inputs, decoder_inputs, target_weights = model.get_batch({bucket_id: [(input_string_vec, [])]}, bucket_id)
#get_batch(A,B):兩個參數,A爲大小爲len(buckets)的元組,返回了指定bucket_id的encoder_inputs,decoder_inputs,target_weights
_, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)
#得到其輸出
outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]#求得最大的預測範圍列表
if EOS_ID in outputs:#如果EOS_ID在輸出內部,則輸出列表爲[,,,,:End]
outputs = outputs[:outputs.index(EOS_ID)]
response = "".join([tf.compat.as_str(vocab_de[output]) for output in outputs])#轉爲解碼詞彙分別添加到回覆中
print('AI-PigPig > ' + response)#輸出回覆
#向緩存中進行寫入
r.set('ai_say',response+':AI')
下一節將講述通信規則的制定,以規範應用程序。