下面介紹訓練和測試代碼。訓練代碼主要包括graph構建、加載預訓練模型、訓練中的數據讀取,以及保存相關日誌和模型文件等內容;測試代碼的主要部分是模型預測結果格式的轉換。
其他相關的部分請見:
YOLO代碼解析(1) 代碼總覽與使用
YOLO代碼解析(2) 數據處理
YOLO代碼解析(3) 模型和損失函數
YOLO代碼解析(4) 訓練和測試代碼
訓練相關代碼:yolo_solver.py
def _train(self):
    """Create the optimizer and the op that minimizes the loss.

    Takes no arguments besides ``self``; everything it needs is read from
    instance attributes: ``self.learning_rate``, ``self.moment``,
    ``self.total_loss`` and ``self.global_step``.

    Returns:
        train_op: the op produced by ``apply_gradients``; run it to perform
        one training step.
    """
    # Momentum optimization.
    optimizer = tf.train.MomentumOptimizer(self.learning_rate, self.moment)
    grads_and_vars = optimizer.compute_gradients(self.total_loss)
    # The two explicit steps here could also be written as a single
    # optimizer.minimize(self.total_loss, global_step=self.global_step).
    return optimizer.apply_gradients(grads_and_vars,
                                     global_step=self.global_step)
def construct_graph(self):
    """Build the training graph and store its tensors/ops on ``self``.

    Sets ``self.global_step``, the three input placeholders
    (``self.images``, ``self.labels``, ``self.objects_num``),
    ``self.predicts``, ``self.total_loss`` and ``self.train_op``, and
    registers a scalar summary named 'loss'.
    """
    # Step counter, excluded from the trainable variables.
    self.global_step = tf.Variable(0, trainable=False)

    # (1) Network inputs used during training.
    image_shape = (self.batch_size, self.height, self.width, 3)
    label_shape = (self.batch_size, self.max_objects, 5)
    # NOTE(review): the original wrote `(self.batch_size)`, which is a plain
    # int rather than a 1-tuple; the same value is passed here unchanged.
    objects_num_shape = self.batch_size
    self.images = tf.placeholder(tf.float32, image_shape)
    self.labels = tf.placeholder(tf.float32, label_shape)
    self.objects_num = tf.placeholder(tf.int32, objects_num_shape)

    # (2) Inference: per the original notes, the output is a tensor of shape
    # (N, cell_size, cell_size, class_num + box_num * 5).
    self.predicts = self.net.inference(self.images)

    # (3) Loss, exported as a scalar summary as well.
    self.total_loss = self.net.loss(self.predicts, self.labels,
                                    self.objects_num)
    tf.summary.scalar('loss', self.total_loss)

    self.train_op = self._train()
def solve(self):
    """Run the training loop.

    Initializes all variables, restores the pretrained weights from
    ``self.pretrain_path``, then runs ``self.max_iterators`` training steps.
    Progress is printed every 10 steps, summaries are written to
    ``self.train_dir`` every 100 steps and a checkpoint is saved there every
    5000 steps.

    The session and the summary writer are released even when training
    aborts (failed restore, NaN loss, interrupted run), so no session is
    leaked and buffered events are flushed to disk.

    Raises:
        AssertionError: if the loss becomes NaN.
    """
    # One saver for the pretrained variables (restore only), one for the
    # variables that are checkpointed during training.
    pretrained_saver = tf.train.Saver(self.net.pretrained_collection,
                                      write_version=1)
    checkpoint_saver = tf.train.Saver(self.net.trainable_collection,
                                      write_version=1)

    # Variable initialization and merged summaries.
    init = tf.global_variables_initializer()
    summary_op = tf.summary.merge_all()

    with tf.Session() as sess:
        sess.run(init)
        # Load the pretrained model on top of the fresh initialization.
        pretrained_saver.restore(sess, self.pretrain_path)

        # Event file writer.
        summary_writer = tf.summary.FileWriter(self.train_dir, sess.graph)
        try:
            for step in range(self.max_iterators):
                start_time = time.time()

                # Fetch one batch of training data.
                np_images, np_labels, np_objects_num = self.dataset.batch()
                feed_dict = {
                    self.images: np_images,
                    self.labels: np_labels,
                    self.objects_num: np_objects_num,
                }
                _, loss_value = sess.run([self.train_op, self.total_loss],
                                         feed_dict=feed_dict)
                duration = time.time() - start_time

                assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

                if step % 10 == 0:
                    num_examples_per_step = self.dataset.batch_size
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = float(duration)
                    format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f sec/batch)')
                    print(format_str % (datetime.now(), step, loss_value,
                                        examples_per_sec, sec_per_batch))
                    sys.stdout.flush()

                if step % 100 == 0:  # write the event file
                    summary_str = sess.run(summary_op, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, step)

                if step % 5000 == 0:  # save a checkpoint
                    checkpoint_saver.save(sess, self.train_dir + '/model.ckpt',
                                          global_step=step)
        finally:
            # Flush and close the event file regardless of how the loop ended.
            summary_writer.close()
測試相關代碼:demo.py
# 對網絡給出的預測結果做處理
def process_predicts(predicts, num_classes=20, boxes_per_cell=2, image_size=448):
    """Turn the raw network output into the single most confident box.

    Only the first sample of the batch is used. Its last axis is laid out as
    ``num_classes`` class probabilities, then ``boxes_per_cell`` objectness
    confidences, then ``boxes_per_cell * 4`` box coordinates stored per box
    as (x_center, y_center, w, h).

    The box centre predicted by the network is an offset from the top-left
    corner of its grid cell, normalized by the cell size; width and height
    are normalized by the image size. Both are converted back to pixels.

    This detection is deliberately simple: it keeps only the box/class pair
    with the largest ``confidence * class probability``. A full detector
    would score every class separately and prune candidates with NMS.

    Args:
        predicts: array of shape (N, grid_rows, grid_cols,
            num_classes + boxes_per_cell * 5). The grid size is taken from
            the array itself.
        num_classes: number of object classes (default 20).
        boxes_per_cell: number of boxes predicted per grid cell (default 2).
        image_size: side length in pixels of the square network input
            (default 448).

    Returns:
        Tuple ``(xmin, ymin, xmax, ymax, class_num)``: box corners in pixels
        of the network input image and the index of the winning class.

    Raises:
        ValueError: if ``predicts`` is not 4-D or its last axis does not
            match ``num_classes + boxes_per_cell * 5``.
    """
    predicts = np.asarray(predicts)
    if predicts.ndim != 4:
        raise ValueError(
            'predicts must be 4-D (N, rows, cols, depth), got shape %s'
            % (predicts.shape,))

    grid_rows, grid_cols, depth = predicts.shape[1:]
    expected_depth = num_classes + boxes_per_cell * 5
    if depth != expected_depth:
        raise ValueError(
            'last axis of predicts has size %d, expected %d '
            '(num_classes + boxes_per_cell * 5)' % (depth, expected_depth))

    conf_end = num_classes + boxes_per_cell
    p_classes = predicts[0, :, :, :num_classes]   # class probabilities
    C = predicts[0, :, :, num_classes:conf_end]   # probability a box holds an object
    coordinate = predicts[0, :, :, conf_end:]     # predicted box coordinates
    print(predicts.shape)

    # Add singleton axes so the product broadcasts to
    # (rows, cols, boxes_per_cell, num_classes).
    p_classes = np.reshape(p_classes, (grid_rows, grid_cols, 1, num_classes))
    C = np.reshape(C, (grid_rows, grid_cols, boxes_per_cell, 1))

    # P = objectness probability * class probability
    P = C * p_classes
    print(P.shape)

    # Locate the (row, col, box, class) entry with the highest score.
    index = np.unravel_index(np.argmax(P), P.shape)
    class_num = index[3]

    coordinate = np.reshape(coordinate,
                            (grid_rows, grid_cols, boxes_per_cell, 4))
    xcenter, ycenter, w, h = coordinate[index[0], index[1], index[2], :]

    # Restore the centre: undo the cell offset, then undo the normalization.
    xcenter = (index[1] + xcenter) * (image_size / float(grid_cols))
    ycenter = (index[0] + ycenter) * (image_size / float(grid_rows))

    # Restore width and height to pixels.
    w = w * image_size
    h = h * image_size

    xmin = xcenter - w / 2.0
    ymin = ycenter - h / 2.0
    xmax = xmin + w
    ymax = ymin + h

    return xmin, ymin, xmax, ymax, class_num
# Demo: run the tiny YOLO network on a single image and draw the best box.
common_params = {'image_size': 448, 'num_classes': 20, 'batch_size':1}
net_params = {'cell_size': 7, 'boxes_per_cell':2, 'weight_decay': 0.0005}

# Network, input placeholder and output tensor.
net = YoloTinyNet(common_params, net_params, test=True)
image = tf.placeholder(tf.float32, (1, 448, 448, 3))
predicts = net.inference(image)

# Read the image. cv2.imread returns None instead of raising when the file
# is missing or cannot be decoded, so fail early with a clear message.
np_img = cv2.imread('cat.jpg')
if np_img is None:
    raise IOError("could not read image 'cat.jpg'")
height, width, channels = np_img.shape
print(height, width, channels)

# Preprocess: resize to the network input size, convert BGR -> RGB and map
# pixel values to [-1, 1].
resized_img = cv2.resize(np_img, (448, 448))
np_img = cv2.cvtColor(resized_img, cv2.COLOR_BGR2RGB)
np_img = np_img.astype(np.float32)
np_img = np_img / 255.0 * 2 - 1
np_img = np.reshape(np_img, (1, 448, 448, 3))

# Load the model and run a forward pass; the session is closed even if
# restoring or running fails.
with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, 'models/pretrain/yolo_tiny.ckpt')
    np_predict = sess.run(predicts, feed_dict={image: np_img})

# Convert the prediction to a pixel box and draw it on the resized image
# (still in BGR, so (0, 0, 255) is red).
xmin, ymin, xmax, ymax, class_num = process_predicts(np_predict)
class_name = classes_name[class_num]
cv2.rectangle(resized_img, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (0, 0, 255))
cv2.putText(resized_img, class_name, (int(xmin), int(ymin)),
            cv2.FONT_HERSHEY_DUPLEX, 1.5, (0, 0, 255))
cv2.imwrite('cat_out.jpg', resized_img)
其他沒有提到的部分代碼請見完整代碼;
另外對代碼中涉及到的一些TensorFlow的函數的使用做了一個簡單的整理,詳見tensorflow函數。