Tensorflow YOLO代碼解析(4)

下面介紹訓練和測試代碼。訓練代碼主要包括graph的構建、預訓練模型的加載、訓練中的數據讀取,以及相關日誌和模型文件的保存等內容;測試代碼的主要部分則是模型預測結果的格式轉換。

其他相關的部分請見:
YOLO代碼解析(1) 代碼總覽與使用
YOLO代碼解析(2) 數據處理
YOLO代碼解析(3) 模型和損失函數
YOLO代碼解析(4) 訓練和測試代碼

訓練相關代碼:yolo_solver.py

def _train(self):
    """Build the op that performs one optimization step on ``self.total_loss``.

    Uses a Momentum optimizer configured with ``self.learning_rate`` and
    ``self.moment``. The loss and step counter are read from ``self`` (set up
    by ``construct_graph``); this method takes no arguments besides ``self``.

    Returns:
      The op returned by ``opt.apply_gradients``. Running it applies one
      gradient update and increments ``self.global_step``.
    """
    # Momentum SGD.
    opt = tf.train.MomentumOptimizer(self.learning_rate, self.moment)

    # Gradients are computed and applied in two explicit steps. This is
    # equivalent to
    #   opt.minimize(self.total_loss, global_step=self.global_step)
    # but leaves a place to inspect or modify `grads` before they are applied.
    grads = opt.compute_gradients(self.total_loss)
    apply_gradient_op = opt.apply_gradients(grads, global_step=self.global_step)

    return apply_gradient_op

  def construct_graph(self):
    """Build the training graph: input placeholders, inference, loss, train op.

    Sets ``self.global_step``, ``self.images``, ``self.labels``,
    ``self.objects_num``, ``self.predicts``, ``self.total_loss`` and
    ``self.train_op``, and registers a 'loss' scalar summary.
    """
    # Non-trainable step counter; the optimizer built in _train() increments it.
    self.global_step = tf.Variable(0, trainable=False)

    # (1) Network inputs for training.
    # A batch of RGB images: (batch_size, height, width, 3).
    self.images = tf.placeholder(tf.float32, (self.batch_size, self.height, self.width, 3))
    # Ground-truth objects per image, padded to max_objects, 5 values each
    # (presumably 4 box coordinates + class id -- confirm against the dataset).
    self.labels = tf.placeholder(tf.float32, (self.batch_size, self.max_objects, 5))
    # One int per image; presumably the count of valid (non-padding) rows in
    # `labels`. The trailing comma makes the shape an explicit 1-tuple.
    self.objects_num = tf.placeholder(tf.int32, (self.batch_size,))

    # (2) Inference: maps the image batch to a
    # (N, cell_size, cell_size, class_num + box_num * 5) tensor.
    self.predicts = self.net.inference(self.images)

    # (3) Loss, also exported to TensorBoard as the 'loss' scalar.
    self.total_loss = self.net.loss(self.predicts, self.labels, self.objects_num)
    tf.summary.scalar('loss', self.total_loss)

    self.train_op = self._train()

  def solve(self):
    """Run the full training loop.

    Initializes all variables, restores the pretrained weights from
    ``self.pretrain_path``, then runs ``self.max_iterators`` training steps on
    batches from ``self.dataset``.
    - Progress is printed every 10 steps.
    - TensorBoard summaries are written to ``self.train_dir`` every 100 steps.
    - The variables in ``self.net.trainable_collection`` are checkpointed every
      5000 steps and once more after the final step.

    Raises:
      AssertionError: if the training loss becomes NaN.
    """
    # saver1 restores only the pretrained variables; saver2 checkpoints the
    # trainable ones.
    saver1 = tf.train.Saver(self.net.pretrained_collection, write_version=1)
    saver2 = tf.train.Saver(self.net.trainable_collection, write_version=1)

    init = tf.global_variables_initializer()
    summary_op = tf.summary.merge_all()
    ckpt_prefix = self.train_dir + '/model.ckpt'
    format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f sec/batch)')

    # The context manager closes the session even if training raises.
    with tf.Session() as sess:
      sess.run(init)
      # Load the pretrained weights over the freshly initialized variables.
      saver1.restore(sess, self.pretrain_path)

      # Event file writer for TensorBoard.
      summary_writer = tf.summary.FileWriter(self.train_dir, sess.graph)
      try:
        for step in range(self.max_iterators):
          start_time = time.time()
          # Fetch one training batch; the same feed is reused for summaries.
          np_images, np_labels, np_objects_num = self.dataset.batch()
          feed_dict = {self.images: np_images,
                       self.labels: np_labels,
                       self.objects_num: np_objects_num}

          _, loss_value = sess.run([self.train_op, self.total_loss],
                                   feed_dict=feed_dict)
          duration = time.time() - start_time

          # Explicit raise instead of `assert` so the check survives
          # `python -O`; AssertionError is kept for existing handlers.
          if np.isnan(loss_value):
            raise AssertionError('Model diverged with loss = NaN')

          if step % 10 == 0:
            examples_per_sec = self.dataset.batch_size / duration
            sec_per_batch = float(duration)
            print(format_str % (datetime.now(), step, loss_value,
                                examples_per_sec, sec_per_batch))
            sys.stdout.flush()
          if step % 100 == 0:  # write the TensorBoard event file
            summary_str = sess.run(summary_op, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str, step)
          if step % 5000 == 0:  # periodic checkpoint
            saver2.save(sess, ckpt_prefix, global_step=step)

        # The periodic save misses the last step unless it is a multiple of
        # 5000, so save the final weights explicitly.
        last_step = self.max_iterators - 1
        if last_step >= 0 and last_step % 5000 != 0:
          saver2.save(sess, ckpt_prefix, global_step=last_step)
      finally:
        # Flush pending events to disk and release the event file.
        summary_writer.close()

測試相關代碼:demo.py

def process_predicts(predicts, num_classes=20, cell_size=7, boxes_per_cell=2,
                     image_size=448):
    """Decode the single highest-scoring detection from raw YOLO output.

    The last axis of ``predicts`` is laid out as
    ``[class probabilities (num_classes) | box confidences (boxes_per_cell) |
    box coordinates (boxes_per_cell * 4)]``. With the defaults this is
    20 + 2 + 8 = 30 channels on a 7x7 grid. Only the first image of the batch
    is decoded.

    Each (cell, box, class) is scored as box confidence * class probability,
    and the best one is returned. This is a simplification: a full detector
    would handle every class separately and use non-maximum suppression (NMS)
    to drop redundant boxes.

    Args:
      predicts: array of shape
        (N, cell_size, cell_size, num_classes + boxes_per_cell * 5).
      num_classes: number of object classes.
      cell_size: number of grid cells along each side of the image.
      boxes_per_cell: number of boxes predicted by each cell.
      image_size: side length, in pixels, of the square network input.

    Returns:
      (xmin, ymin, xmax, ymax, class_num): box corners in pixels of the
      image_size x image_size input, and the index of the predicted class.
    """
    conf_end = num_classes + boxes_per_cell
    p_classes = predicts[0, :, :, :num_classes]     # class probabilities
    C = predicts[0, :, :, num_classes:conf_end]     # per-box object confidence
    coordinate = predicts[0, :, :, conf_end:]       # per-box coordinates

    # Broadcast (S, S, 1, classes) against (S, S, B, 1) so that
    # P[row, col, box, cls] = confidence(box) * class_probability(cls).
    p_classes = np.reshape(p_classes, (cell_size, cell_size, 1, num_classes))
    C = np.reshape(C, (cell_size, cell_size, boxes_per_cell, 1))
    P = C * p_classes

    # Locate the highest-scoring (cell, box, class) combination.
    row, col, box, class_num = np.unravel_index(np.argmax(P), P.shape)

    coordinate = np.reshape(coordinate, (cell_size, cell_size, boxes_per_cell, 4))
    xcenter, ycenter, w, h = coordinate[row, col, box, :]

    # The predicted center is an offset from the cell's top-left corner,
    # normalized by the cell width: add the cell index, then scale to pixels.
    cell_px = image_size / float(cell_size)
    xcenter = (col + xcenter) * cell_px
    ycenter = (row + ycenter) * cell_px
    # Width and height are normalized by the image size.
    w = w * image_size
    h = h * image_size

    xmin = xcenter - w / 2.0
    ymin = ycenter - h / 2.0
    xmax = xmin + w
    ymax = ymin + h
    return xmin, ymin, xmax, ymax, class_num


common_params = {'image_size': 448, 'num_classes': 20, 'batch_size':1}
net_params = {'cell_size': 7, 'boxes_per_cell':2, 'weight_decay': 0.0005}

# Side length of the square network input, used for every resize/reshape below.
image_size = common_params['image_size']

# Build the network: an input placeholder for one image and the output tensor.
net = YoloTinyNet(common_params, net_params, test=True)
image = tf.placeholder(tf.float32, (1, image_size, image_size, 3))
predicts = net.inference(image)

# Read the input image. cv2.imread returns None (instead of raising) when the
# file is missing or unreadable, so fail early with a clear message.
np_img = cv2.imread('cat.jpg')
if np_img is None:
    raise IOError("could not read image 'cat.jpg'")
height, width, channels = np_img.shape
print(height, width, channels)

# Preprocess: resize to the network input size, convert BGR -> RGB, and map
# pixel values from [0, 255] to [-1, 1].
resized_img = cv2.resize(np_img, (image_size, image_size))
np_img = cv2.cvtColor(resized_img, cv2.COLOR_BGR2RGB)
np_img = np_img.astype(np.float32)
np_img = np_img / 255.0 * 2 - 1
np_img = np.reshape(np_img, (1, image_size, image_size, 3))

# Restore the model and run a forward pass; the session is closed even if
# restoring or inference fails.
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, 'models/pretrain/yolo_tiny.ckpt')
    np_predict = sess.run(predicts, feed_dict={image: np_img})

# Decode the best box and draw it on the resized image. The coordinates are
# in the resized (image_size x image_size) frame, not the original image's.
xmin, ymin, xmax, ymax, class_num = process_predicts(np_predict)
class_name = classes_name[class_num]
cv2.rectangle(resized_img, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (0, 0, 255))
cv2.putText(resized_img, class_name, (int(xmin), int(ymin)), cv2.FONT_HERSHEY_DUPLEX, 1.5, (0, 0, 255))
cv2.imwrite('cat_out.jpg', resized_img)

其他沒有提到的部分代碼請見完整代碼
另外,對代碼中涉及的一些TensorFlow函數的用法做了簡單整理,詳見「tensorflow函數」一文。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章