1. 載入模型與網絡文件，並進行網絡初始化設定
string CfgPath="...\\name.cfg";
string WeightsPath="...\\name.weights";
network net = load_network((char*)CfgPath.c_str(), (char*)WeightsPath.c_str(), 0);///加載網絡,加載成功,終端會打印出網絡各層
set_batch_network(net, 1);///設置網絡batch, 預測時batch必須設置爲1
2. 輸入數據預處理
Mat frame=imread("testpath");
Mat resizeImg;
resize(frame, resizeImg, Size(net->w, net->h));///opencv API
size_t resizeSize = net->w*net->h * 3 * sizeof(float);
float* inputSizeImg = (float*)malloc(resizeSize);
imgConvert(resizeImg, inputSizeImg);
/// Converts an 8-bit interleaved (HWC) cv::Mat into a planar (CHW) float buffer scaled to [0, 1].
/// @param img  8-bit image with any channel count; rows may be padded (non-continuous Mat).
/// @param dst  output buffer holding at least img.rows * img.cols * img.channels() floats;
///             channel k, row i, column j is written to dst[k*h*w + i*w + j].
void imgConvert(const cv::Mat& img, float* dst)
{
    // The raw uchar reads below are only meaningful for 8-bit images.
    CV_Assert(img.depth() == CV_8U);
    const int h = img.rows;
    const int w = img.cols;
    const int c = img.channels();
    for (int k = 0; k < c; ++k) {
        float* plane = dst + static_cast<size_t>(k) * w * h;
        for (int i = 0; i < h; ++i) {
            // Use the per-row pointer instead of img.data + i*w*c so that Mats whose
            // rows are padded (ROIs, non-continuous matrices) are read correctly.
            const uchar* row = img.ptr<uchar>(i);
            for (int j = 0; j < w; ++j) {
                plane[i * w + j] = static_cast<float>(row[j * c + k] / 255.);
            }
        }
    }
}
3. 網絡推理
// IoU threshold for non-maximum suppression (NOT a confidence threshold; confidence
// filtering uses confidenceThreshold, which is defined outside this section).
float nms = 0.35;
// Number of classes the model predicts.
// NOTE(review): presumably must match "classes" in the .cfg — confirm.
int classes = 2;
// Forward pass. network_predict consumes the planar float buffer, not the cv::Mat.
network_predict(net, inputSizeImg);
int nboxes = 0;  // set by get_network_boxes to the number of returned detections
// The input was stretched with a plain resize, not letterboxed. Passing the network's own
// size as the "image" size keeps darknet's letterbox box correction from distorting the
// boxes (it becomes an identity). relative=1 keeps coordinates normalized to [0,1] for
// scaling by the frame size later. 0.5 is the hierarchical threshold
// (NOTE(review): presumably only relevant to tree-structured models).
detection *dets = get_network_boxes(net, net->w, net->h, confidenceThreshold, 0.5, nullptr, 1, &nboxes);
if (nms > 0)
{
    do_nms_sort(dets, nboxes, classes, nms);  // non-maximum suppression
}
4. 計算輸出檢測框
vector<cv::Rect>boxes;
boxes.clear();
for (int i = 0; i < nboxes; i++)
{
bool flag = 0;
int className;
for (int j = 0; j < classes; j++)
{
if (dets[i].prob[j] > confidenceThreshold)
{
if (!flag)
{
flag = 1;
className = j;
}
}
}
if (flag)
{
int left = (dets[i].bbox.x - dets[i].bbox.w / 2.)*frame.cols;
int right = (dets[i].bbox.x + dets[i].bbox.w / 2.)*frame.cols;
int top = (dets[i].bbox.y - dets[i].bbox.h / 2.)*frame.rows;
int bot = (dets[i].bbox.y + dets[i].bbox.h / 2.)*frame.rows;
if (left < 0)
left = 0;
if (right > frame.cols - 1)
right = frame.cols - 1;
if (top < 0)
top = 0;
if (bot > frame.rows - 1)
bot = frame.rows - 1;
Rect box(left, top, fabs(left - right), fabs(top - bot));///opencv Rect類
boxes.push_back(box);
}
}
free_detections(dets, nboxes);
free(resizeImg);
if (boxes.size() < 1)
{
return -1;
}