使用mmdetection训练HTC

配置文件

配置解读

测试

直接使用会报这个错:

TypeError: type object got multiple values for keyword argument 'groups'

说是字典里的对应值多了一个,查配置文件,发现有两个groups,值是一样的,看别的配置文件只有上面的一个,注释掉下面的试了一下,就可以了
在这里插入图片描述
测试结果:
在这里插入图片描述


训练

1.根据自己的数据集的类改动配置文件参考这个:配置改动

注意:一定要看好学习率的设置,16GPU是0.02,单卡或双卡设置0.0025左右,一定要小,一开始我双卡设的0.02,loss从个位数飙到五位数

2.因为我只需要检测bbox和mask,所以在配置文件的model字典里将semantic相关的都注释了
在这里插入图片描述
并且在train_pineline里设置with_seg=False
在这里插入图片描述

只设置with_seg是没有用的,最后会计算seg的loss,然而没有label,就报错,所以要加前面的注释

3.修改评估内容:
在这里插入图片描述

4.在这里取消TensorboardLoggerHook的注释就能在命令行看到loss
在这里插入图片描述
但是我的程序会弹出很多警告,做如下操作就能不显示警告:

import warnings
warnings.filterwarnings('ignore')

调试后再次运行出现:

Address alredy in use

这说明刚才停止的程序的内存没有释放,执行下面命令:

 ps aux|grep user_name|grep python

然后用kill -9 ID来释放内存。

这样就能顺利训练了:
在这里插入图片描述

使用预训练权重

参考

def main():
    #gen coco pretrained weight
    import torch
    num_classes = 11
    model_coco = torch.load("../checkpoints/cascade_rcnn_r50_fpn_1x_20190501-3b6211ab.pth") # weight
    for key, value in model_coco["state_dict"].items():
        print(key)

    model_coco["state_dict"]["bbox_head.0.fc_cls.weight"] = \
    model_coco["state_dict"]["bbox_head.0.fc_cls.weight"][:num_classes, :]

    model_coco["state_dict"]["bbox_head.1.fc_cls.weight"] = \
    model_coco["state_dict"]["bbox_head.1.fc_cls.weight"][:num_classes, :]

    model_coco["state_dict"]["bbox_head.2.fc_cls.weight"] = \
    model_coco["state_dict"]["bbox_head.2.fc_cls.weight"][:num_classes, :]

    model_coco["state_dict"]["bbox_head.0.fc_cls.bias"] = \
    model_coco["state_dict"]["bbox_head.0.fc_cls.bias"][:num_classes]

    model_coco["state_dict"]["bbox_head.1.fc_cls.bias"] = \
    model_coco["state_dict"]["bbox_head.1.fc_cls.bias"][:num_classes]

    model_coco["state_dict"]["bbox_head.2.fc_cls.bias"] = \
    model_coco["state_dict"]["bbox_head.2.fc_cls.bias"][:num_classes]
    # save new model
    torch.save(model_coco, "cascade_rcnn_r50_fpn_1x_coco_pretrained_weights_classes_%d.pth" % num_classes)
if __name__ == "__main__":
    main()

如果是其他的模型,可以使用这个工具,可以看到网络的参数,比如要从80类变成1个类,找后面的参数个数是81的,然后记住名字,用上面的代码修改

注意事项

这里的resize一定要注意!
在这里插入图片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章