pointnet train函數第二十七句 for epoch in range(MAX_EPOCH):

        for epoch in range(MAX_EPOCH):
            log_string('**** EPOCH %03d ****' % (epoch))
            sys.stdout.flush()
             
            train_one_epoch(sess, ops, train_writer)
            eval_one_epoch(sess, ops, test_writer)
            
            # Save the variables to disk.
            if epoch % 10 == 0:
                save_path = saver.save(sess, os.path.join(LOG_DIR, "model.ckpt"))
                log_string("Model saved in file: %s" % save_path)

當前工程cls的MAX_EPOCH設置爲250,當前的epoch是爲了增加樣本數量,因爲樣本有限,所以需要每個epoch打亂一次訓練樣本,以此來增加訓練樣本的總數

train_one_epoch(sess, ops, train_writer)

這句則是具體train過程函數,具體實現如下

def train_one_epoch(sess, ops, train_writer):
    """ ops: dict mapping from string to tf ops """
    is_training = True
    
    # Shuffle train files
    train_file_idxs = np.arange(0, len(TRAIN_FILES))
    np.random.shuffle(train_file_idxs)
    
    for fn in range(len(TRAIN_FILES)):
        log_string('----' + str(fn) + '-----')
        print(TRAIN_FILES[train_file_idxs[fn]])
        current_data, current_label = provider.loadDataFile(TRAIN_FILES[train_file_idxs[fn]])
        print("current_data shape")
        print(current_data.shape)
        print("current_label shape")
        print(current_label.shape)
        current_data = current_data[:,0:NUM_POINT,:]
        current_data, current_label, _ = provider.shuffle_data(current_data, np.squeeze(current_label))            
        current_label = np.squeeze(current_label)
        
        file_size = current_data.shape[0]
        print("current_data.shape[0]:")
        print(current_data.shape[0])
        num_batches = file_size // BATCH_SIZE
        print("num_batches,BATCH_SIZE")
        print(num_batches)
        print(BATCH_SIZE)
        total_correct = 0
        total_seen = 0
        loss_sum = 0
       
        for batch_idx in range(num_batches):
            start_idx = batch_idx * BATCH_SIZE
            end_idx = (batch_idx+1) * BATCH_SIZE
            
            # Augment batched point clouds by rotation and jittering
            rotated_data = provider.rotate_point_cloud(current_data[start_idx:end_idx, :, :])
            jittered_data = provider.jitter_point_cloud(rotated_data)
            feed_dict = {ops['pointclouds_pl']: jittered_data,
                         ops['labels_pl']: current_label[start_idx:end_idx],
                         ops['is_training_pl']: is_training,}
            summary, step, _, loss_val, pred_val = sess.run([ops['merged'], ops['step'],
                ops['train_op'], ops['loss'], ops['pred']], feed_dict=feed_dict)
            train_writer.add_summary(summary, step)
            pred_val = np.argmax(pred_val, 1)
            correct = np.sum(pred_val == current_label[start_idx:end_idx])
            total_correct += correct
            total_seen += BATCH_SIZE
            loss_sum += loss_val
        
        log_string('mean loss: %f' % (loss_sum / float(num_batches)))
        log_string('accuracy: %f' % (total_correct / float(total_seen)))

 

這裏第一句is_training是用來設置is_training_pl tensor

第二句,TRAIN_FILES可以看看定義裏面,是從train_files.txt裏面讀取point數據的路徑,內容如下

data/modelnet40_ply_hdf5_2048/ply_data_train0.h5
data/modelnet40_ply_hdf5_2048/ply_data_train1.h5
data/modelnet40_ply_hdf5_2048/ply_data_train2.h5
data/modelnet40_ply_hdf5_2048/ply_data_train3.h5
data/modelnet40_ply_hdf5_2048/ply_data_train4.h5

 調用了

provider.getDataFiles,產生一個array裏面放的是上面h5文件名,第二句則根據array的長度產生對應的一個index的array

第三句是將這個array順序打亂,從而在每次epoch有不同的索引序列,在第四句for循環中讀取data的時候point樣本不同,達到增加訓練數據的目的

然後看for循環內部,第一步是讀取point data以及label data,參考pointnet provider.loadDataFile讀取之後shape爲三維batchsize,pointnum,xyz,的current_data因爲每個h5文件中的point的pointnum不一定都是1024因此,需要對pointnum這一維進行處理,然後繼續對點雲順序打亂順序繼續增加樣本多樣性

current_data, current_label, _ = provider.shuffle_data(current_data, np.squeeze(current_label))

參考pointnet shuffle_data(data, labels)

其中"_"爲index np array點雲處理過程中,這裏的index特別重要,所以要及時記錄

current_label = np.squeeze(current_label)

繼續把current_label 中維度爲1的數據去除掉

file_size = current_data.shape[0]

        num_batches = file_size // BATCH_SIZE

首先獲取h5文件中的file_size即一個文件中有多少個point模型。因爲我們的網絡處理單位是batchsize*pointnum爲一次,因此需要計算一個文件需要多少次bacthsize計算才能訓練完畢一個文件

即num_batches

total_correct = 0
        total_seen = 0
        loss_sum = 0

分別爲總的分類準確的個數

總的訓練模型個數

總的損失值

獲取到數據之後根據num_batches進行loop每個loop一個batchsize的模型數進入下面的循環

for batch_idx in range(num_batches):

因爲current_data裏面第一維是filesize,是h5文件中總的point點雲模型的array,每次讀取一個batchsie個,當前設置的是32個,因此需要每次循環更新起始index以及終止index,取出一個batchsize的point模型數據。

start_idx = batch_idx * BATCH_SIZE
            end_idx = (batch_idx+1) * BATCH_SIZE

取出數據之後perovider進行處理,

rotated_data = provider.rotate_point_cloud(current_data[start_idx:end_idx, :, :])
            jittered_data = provider.jitter_point_cloud(rotated_data)

參考pointnet provider.rotate_point_cloud provider.jitter_point_cloud,作用是對point cloud進行旋轉平移,旋轉角度以及平移距離爲隨機的,產生更多的訓練樣本

feed_dict = {ops['pointclouds_pl']: jittered_data,
                         ops['labels_pl']: current_label[start_idx:end_idx],
                         ops['is_training_pl']: is_training,}

此處的代碼是將我們之前的placeholder產生的rensor進行feedict才能在gragh中進行運算

summary, step, _, loss_val, pred_val = sess.run([ops['merged'], ops['step'],
                ops['train_op'], ops['loss'], ops['pred']], feed_dict=feed_dict)

這一句即是運行我們前二十六句建立起來的由各種tensor組成的op鏈接起來的gragh圖,運行完即訓練完了一次,返回summary可視化數據,step訓練次數,_ 點雲索引,loss數據,pred正向數據

pred_val = np.argmax(pred_val, 1)
            correct = np.sum(pred_val == current_label[start_idx:end_idx])
            total_correct += correct
            total_seen += BATCH_SIZE
            loss_sum += loss_val
log_string('mean loss: %f' % (loss_sum / float(num_batches)))
        log_string('accuracy: %f' % (total_correct / float(total_seen)))

 

訓練完一個文件內的數據,計算一次平均loss以及精確度

train完一次,接着進行test數據進行預測並且得到預測準確率參考pointnet def eval_one_epoch(sess, ops, test_writer)

至此分類訓練代碼解讀完畢後續會持續更新更正,然後用pytorch+open3d實現一遍。敬請期待

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章