/**
 * Intersection-over-union of two axis-aligned boxes.
 *
 * Each box is passed as four scalars in the order left, top, right, bottom
 * (box "a" first, then box "b").
 *
 * Returns 0.0f when the overlap rectangle has zero area; otherwise the
 * overlap area divided by (area(a) + area(b) - overlap area).
 */
static __device__ float box_iou(
    float aleft, float atop, float aright, float abottom,
    float bleft, float btop, float bright, float bbottom)
{
    // Extent of the overlap rectangle along each axis, clamped at zero so
    // that boxes which do not overlap on an axis contribute no area.
    const float overlap_w = max(min(aright, bright) - max(aleft, bleft), 0.0f);
    const float overlap_h = max(min(abottom, bbottom) - max(atop, btop), 0.0f);
    const float overlap_area = overlap_w * overlap_h;

    if (overlap_area != 0.0f)
    {
        // Individual box areas, each side clamped at zero.
        const float area_a = max(0.0f, aright - aleft) * max(0.0f, abottom - atop);
        const float area_b = max(0.0f, bright - bleft) * max(0.0f, bbottom - btop);

        // Union = a + b - overlap (the overlap is counted once in each box).
        return overlap_area / (area_a + area_b - overlap_area);
    }

    // No overlapping area: skip the division entirely.
    return 0.0f;
}
/**
 * Suppression pass over a packed box buffer. Indexing is one-dimensional
 * (blockDim.x * blockIdx.x + threadIdx.x) and each thread handles the single
 * record whose index equals its global thread index.
 *
 * Buffer layout as read by this kernel:
 *   bboxes[0]               number of stored records (float, truncated to int)
 *   bboxes[1 + k * stride]  first float of record k, where stride is
 *                           FACE_BOX_ELEMENT, PLATE_BOX_ELEMENT or
 *                           NUM_BOX_ELEMENT depending on detect_type
 *
 * Record fields touched here:
 *   [0..3]  left, top, right, bottom (forwarded to box_iou in that order)
 *   [4]     ranking value compared with >= (presumably the confidence
 *           score -- confirm against the code that fills the buffer)
 *   [5]     must be equal for two records to compete (presumably the class
 *           label -- confirm against the code that fills the buffer)
 *   [6]     written as 0 when this thread's record is suppressed
 *           (presumably a keep flag -- confirm against the consumer)
 *
 * A record is suppressed when another record with the same [5] value and a
 * [4] value that is not lower overlaps it with IoU strictly greater than
 * `threshold`. On equal [4] values only records with a larger index can
 * suppress this one.
 *
 * At most min((int)bboxes[0], max_objects) records are processed.
 */
static __global__ void nms_kernel(float *bboxes, int max_objects, float threshold, ObjectDetectType_e detect_type)
{
    const int self = blockDim.x * blockIdx.x + threadIdx.x;
    const int box_total = min(static_cast<int>(bboxes[0]), max_objects);
    if (self >= box_total)
        return;

    // Floats per record for the requested detector type.
    const int stride = (detect_type == DETECT_FACE)    ? FACE_BOX_ELEMENT
                       : (detect_type == DETECT_PLATE) ? PLATE_BOX_ELEMENT
                                                       : NUM_BOX_ELEMENT;

    float *mine = bboxes + 1 + self * stride;
    const float my_score = mine[4];
    const float my_label = mine[5];

    for (int other = 0; other < box_total; ++other)
    {
        // Never compare a record with itself.
        if (other == self)
            continue;

        const float *rival = bboxes + 1 + other * stride;

        // Only records carrying the same [5] value compete.
        if (my_label != rival[5])
            continue;

        // A rival with a lower [4] value cannot suppress this record.
        const float rival_score = rival[4];
        if (!(rival_score >= my_score))
            continue;

        // Tie on [4]: the record with the smaller index yields, so skip
        // rivals that come earlier in the buffer.
        if (rival_score == my_score && other < self)
            continue;

        const float overlap = box_iou(
            mine[0], mine[1], mine[2], mine[3],
            rival[0], rival[1], rival[2], rival[3]);

        if (overlap > threshold)
        {
            // Mark this record as suppressed; nothing more to check.
            mine[6] = 0;
            return;
        }
    }
}