圖片檢測 (Mask R-CNN)

# -*- coding: UTF-8 -*-

## https://github.com/matterport/Mask_RCNN
## https://github.com/matterport/Mask_RCNN/releases/download/v2.0/mask_rcnn_coco.h5

from mrcnn.config import Config
from mrcnn import model as modellib
from mrcnn import visualize
import numpy as np
import colorsys
import imutils
import random
import cv2
import os

model_path = "mask_rcnn_coco.h5"
labels_path = "coco_labels.txt"
image_path = "image.jpg"

# Read the class labels: one name per line, and the line index is the class
# ID used to look names up later (line 0 of the labels file is "BG").
# A context manager guarantees the file handle is closed.
with open(labels_path, encoding="utf-8") as labels_file:
    CLASS_NAMES = labels_file.read().strip().split("\n")
print(CLASS_NAMES)

# Assign each label a colour: evenly spaced hues at full saturation/value,
# converted to RGB floats in [0, 1], then shuffled with a fixed seed so the
# colours are reproducible (and, presumably, so adjacent class IDs do not get
# near-identical hues).
hsv = [(i / len(CLASS_NAMES), 1, 1.0) for i in range(len(CLASS_NAMES))]
COLORS = [colorsys.hsv_to_rgb(*c) for c in hsv]
random.seed(42)
random.shuffle(COLORS)

# 配置
class SimpleConfig(Config):
    """Inference configuration for the pretrained COCO Mask R-CNN weights.

    Uses one GPU with one image per GPU. The class count comes from the
    labels file loaded into CLASS_NAMES, background entry included.
    """

    # Name identifying this configuration.
    NAME = "coco_inference"

    # Hardware / batching: one image on one GPU.
    IMAGES_PER_GPU = 1
    GPU_COUNT = 1

    # Total number of classes, taken directly from the loaded label list.
    NUM_CLASSES = len(CLASS_NAMES)
    
config = SimpleConfig()

# Build the Mask R-CNN model in inference mode and load the pretrained
# COCO weights; model_dir is set to the current working directory.
print("[INFO] 讀取  Mask R-CNN model...")
model = modellib.MaskRCNN(mode="inference", config=config,
    model_dir=os.getcwd())
# by_name=True matches stored weights to layers by layer name.
model.load_weights(model_path, by_name=True)

# Read the input image. cv2.imread does not raise on a missing or unreadable
# file -- it returns None -- so fail here with a clear message instead of an
# opaque error from cv2.cvtColor.
image = cv2.imread(image_path)
if image is None:
    raise FileNotFoundError("Cannot read image file: {}".format(image_path))
# OpenCV loads BGR; convert to RGB before handing the image to the model.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Scale to 512 px wide (imutils derives the height from the aspect ratio).
image = imutils.resize(image, width=512)

# Run detection on the single image; [0] selects the result for that image.
print("[INFO] 使用Mask R-CNN進行預測...")
r = model.detect([image], verbose=1)[0]
num_detections = r["rois"].shape[0]

# Overlay each instance mask on the (still RGB) image.
for i in range(num_detections):
    classID = r["class_ids"][i]
    mask = r["masks"][:, :, i]
    # Reversed because the box colour below is handed to OpenCV unreversed
    # after the image is converted back to BGR; reversing here makes each
    # mask render in the same colour as its box.
    color = COLORS[classID][::-1]

    image = visualize.apply_mask(image, mask, color, alpha=0.5)

# Convert back to BGR for OpenCV drawing and display.
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

# Draw a bounding box and a "label: score" caption for every detection.
for i in range(num_detections):
    # Use plain Python ints: OpenCV's point parsing can reject NumPy scalar
    # types depending on the OpenCV version.
    (startY, startX, endY, endX) = (int(v) for v in r["rois"][i])
    classID = r["class_ids"][i]
    label = CLASS_NAMES[classID]
    score = r["scores"][i]
    # COLORS holds floats in [0, 1]; OpenCV drawing expects 0-255 integers.
    color = [int(c) for c in np.array(COLORS[classID]) * 255]

    cv2.rectangle(image, (startX, startY), (endX, endY), color, 2)
    text = "{}: {:.3f}".format(label, score)
    print(text)
    # Place the caption above the box unless that would run off the top edge.
    y = startY - 10 if startY - 10 > 10 else startY + 10
    cv2.putText(image, text, (startX, y), cv2.FONT_HERSHEY_SIMPLEX,
        0.6, color, 2)

cv2.imshow("Output", image)
cv2.waitKey()
# Release the HighGUI window once a key has been pressed.
cv2.destroyAllWindows()

labels (contents of coco_labels.txt, one class name per line; line order = class ID):

BG
person
bicycle
car
motorcycle
airplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
couch
potted plant
bed
dining table
toilet
tv
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章