如何將MMrotate的識別結果轉換爲dota和fair1m格式

問題來源,在使用mmrotate的過程中,需要能夠對識別的結果進行推斷,結果發現缺乏相關功能:
From the demo i know show_result_pyplot can plot the inferred results, I would like to ask how to convert inferred results to DOTA format, is there a related function? Or do you need to handle the result directly?
thanks!
Hi @jsxyhelu, we did not provide the corresponding script to the user. You need to convert the results to DOTA format by yourself. Welcome to submit your script to help more people.
 
那麼就自己來設計實現相關功能:
 
一、數據格式:
mmrotate的輸出格式爲:

分別爲: x, y, w, h, theta, score.

目標格式爲Dota採用 txt 文件存放,

其中一個標註框對應爲:  x1、 y1、   x2、 y2、 x3、 y3、 x4、 y4、 classname、diffcult 。注意這裏沒有歸一化處理 

二、批量處理和保存

首先將result保存下來

import os
import numpy as np
 
src_label_root = '/root/mmrotate/demo/ssdd_tiny/images/'
dst_label_root = '/root/mmrotate/demo/ssdd_tiny/dst/'
!mkdir '/root/mmrotate/demo/ssdd_tiny/dst/'
 
 
model.cfg = cfg
for i, src_label_name in enumerate(os.listdir(src_label_root)):
    src_label_path = os.path.join(src_label_root,src_label_name) #輸入地址
    dst_label_path = os.path.join(dst_label_root,os.path.splitext(src_label_name)[0]+".txt")
    img = mmcv.imread(src_label_path)
    result = inference_detector(model, img)
    np.savetxt(dst_label_path, result[0], delimiter=',')
    print(dst_label_path)

而後進行格式轉換,對於單通道圖片來說爲:

def rota( x, y, w, h, a):  # 旋轉中心點,旋轉中心點,框的w,h,旋轉角
    center_x1 = x
    center_y1 = y
    x1, y1 = x - w / 2, y - h / 2  # 旋轉前左上
    x2, y2 = x + w / 2, y - h / 2  # 旋轉前右上
    x3, y3 = x + w / 2, y + h / 2  # 旋轉前右下
    x4, y4 = x - w / 2, y + h / 2  # 旋轉前左下
    px1 = (x1 - center_x1) * math.cos(a) - (y1 - center_y1) * math.sin(a) + center_x1  # 旋轉後左上
    py1 = (x1 - center_x1) * math.sin(a) + (y1 - center_y1) * math.cos(a) + center_y1
    px2 = (x2 - center_x1) * math.cos(a) - (y2 - center_y1) * math.sin(a) + center_x1  # 旋轉後右上
    py2 = (x2 - center_x1) * math.sin(a) + (y2 - center_y1) * math.cos(a) + center_y1
    px3 = (x3 - center_x1) * math.cos(a) - (y3 - center_y1) * math.sin(a) + center_x1  # 旋轉後右下
    py3 = (x3 - center_x1) * math.sin(a) + (y3 - center_y1) * math.cos(a) + center_y1
    px4 = (x4 - center_x1) * math.cos(a) - (y4 - center_y1) * math.sin(a) + center_x1  # 旋轉後左下
    py4 = (x4 - center_x1) * math.sin(a) + (y4 - center_y1) * math.cos(a) + center_y1
 
    return px1, py1, px2, py2, px3, py3, px4, py4  # 旋轉後的四個點,左上,右上,右下,左下

def mmrotate2dota(src_img_root, src_label_root, dst_label_root,class_map,score_thr=0.3):
    not_have_img = []
    if not os.path.exists(dst_label_root):
        os.makedirs(dst_label_root)
    # 遍歷所有txt文件
    for i, src_label_name in enumerate(os.listdir(src_label_root)):
        src_label_path = os.path.join(src_label_root,src_label_name) #輸入地址
        dst_label_path = os.path.join(dst_label_root,src_label_name) #輸出地址
        dst_label_list = []          ## 空列表
        with open(src_label_path, 'r') as fr:
            txtlines = fr.readlines()   #原始數據
        for line in txtlines:
            oneline = line.strip().split(",")    
            x = float(oneline[0])
            y = float(oneline[1])
            w = float(oneline[2])
            h = float(oneline[3])
            a = float(oneline[4])
            score = float(oneline[5])
            px1, py1, px2, py2, px3, py3, px4, py4 = rota(x,y,w,h,a)
            #目標格式爲  x1、y1、x2、y2、x3、y3、x4、y4、 classname、diffcult
            dstline = str(px1)+" "+ str(py1)+" "+ str(px2)+" "+ str(py2)+" "+ str(px3)+" "+ str(py3)+" "+ str(px4)+" "+ str(py4)+" "+ str(class_map['0'])+ "1"
            if(score >= score_thr):
                dst_label_list.append(dstline)
        with open(dst_label_path,'w') as fw:
            fw.writelines([line+'\n' for line in dst_label_list]) #添加換行
        print(dst_label_path)
    print('convert done')

得到初步的對比結果,目視是正確的

 

使用Dota自己的工具進行標繪(Dota_devKit)

 

具體

查看  https://www.kaggle.com/code/jsxyhelu2019/ddd-mmrotate-result2dota

三、獲得批量處理結果

當前的結果處理的只是一種類型的,在處理批量數據的時候是有不同的。

而且轉換的過程中存在錯誤,需要進行修正。

通過模仿現有的例子,能夠獲得讀取現有pt,執行推斷的結果。

它的內容是這樣來組織的:

 

一共37個array,每一個都是推測出來的位置。

這樣的話在寫下來的過程中,就需要編碼了。

而且在推斷的過程中,就是需要使用

from mmrotate.apis import inference_detector_by_patches
img = 'demo/dota_demo.jpg'
result = inference_detector_by_patches(model, img, [1024], [824], [1.0], 0.1)
 
def inference_detector_by_patches(model,
                                  img,
                                  sizes,
                                  steps,
                                  ratios,
                                  merge_iou_thr,
                                  bs=1):
    """inference patches with the detector.
    Split huge image(s) into patches and inference them with the detector.
    Finally, merge patch results on one huge image by nms.
    Args:
        model (nn.Module): The loaded detector.
        img (str | ndarray or): Either an image file or loaded image.
        sizes (list): The sizes of patches.
        steps (list): The steps between two patches.
        ratios (list): Image resizing ratios for multi-scale detecting.
        merge_iou_thr (float): IoU threshold for merging results.
        bs (int): Batch size, must greater than or equal to 1.
    Returns:
        list[np.ndarray]: Detection results.
    "
""

所以最後,單個寫:

# Use the detector to do inference
dst = []
from mmrotate.apis import inference_detector_by_patches
img = '/home/helu/workstation/Fair1m/fair1M_jpg_train_split_1280_200/images/1__1__0___0.jpg'
result =  inference_detector_by_patches(model, img, [1024], [824], [1.0], 0.1)
for index,typeresult in enumerate(result):
    if(typeresult.size!=0):
        for lineresult in typeresult:
            lineresult = np.append(lineresult,  np.float32(index))
            dst.append(lineresult)
            #print(index)
print(dst)
#show_result_pyplot(model, img, result, score_thr=0.3)

批量處理,獲得Dota的結果

 

 

test_image_root = '/home/helu/workstation/Fair1m/fair1M_jpg_test_tiny/images/'
test_result_root = '/home/helu/workstation/Fair1m/fair1M_jpg_test_tiny/labelTxt/'
dst = []
from mmrotate.apis import inference_detector_by_patches
for i, test_image_name in enumerate(os.listdir(test_image_root)):
    dst = []
    test_image_path = os.path.join(test_image_root,test_image_name) #輸入地址
    dst_label_path = os.path.join(test_result_root,os.path.splitext(test_image_name)[0]+".txt")
    img = mmcv.imread(test_image_path)
    result =  inference_detector_by_patches(model, img, [1024], [824], [1.0], 0.1)
    for index,typeresult in enumerate(result):
        if(typeresult.size!=0):
            for lineresult in typeresult:
                lineresult = np.append(lineresult,  np.float32(index))
                dst.append(lineresult)
        np.savetxt(dst_label_path, dst, delimiter=',')

 

 

全部代碼爲 https://files.cnblogs.com/files/blogs/758212/MMRotat_infer.rar?t=1682895717&download=true

 

 

 

需要進行進一步的修改,或者數據轉換也可以。

 

轉換爲Fair1m數據格式並上分,30 epoch獲得這個值

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章