一、功能:
1. 读取 label 标签文件，可视化其中的人体关键点。
2. 筛选 COCO keypoints 数据中的行人标注，过滤掉平躺、仅露出人体局部等不符合要求的行人标注。
二、源码实现:
import os
import shutil
import cv2
import numpy as np
import matplotlib.pyplot as plt
def visual_image(image_path, labels, save_path):
    """Draw keypoints and skeleton of every person label onto an image and save it.

    Only label lines whose first field is ``"0"`` are drawn. Fields 5 onward
    are read as ``(x, y, v)`` triples; ``x``/``y`` are scaled by the image
    width/height and ``v`` is ignored. A point whose x or y is 0 (the value
    unlabelled keypoints carry), a multiple of the image size, or negative is
    treated as missing and neither its circle nor any limb touching it is drawn.

    Args:
        image_path: Path of the image to annotate.
        labels: Label lines, e.g. as returned by ``readlines()``.
        save_path: Directory to write into; the output keeps the input file name.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102],
                        [230, 230, 0], [255, 153, 255], [153, 204, 255],
                        [255, 102, 255], [255, 51, 255], [102, 178, 255],
                        [51, 153, 255], [255, 153, 153], [255, 102, 102],
                        [255, 51, 51], [153, 255, 153], [102, 255, 102],
                        [51, 255, 51], [0, 255, 0], [0, 0, 255], [255, 0, 0],
                        [255, 255, 255]])
    # Pairs of 1-based keypoint indices joined by a limb.
    skeleton = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12],
                [7, 13], [6, 7], [6, 8], [7, 9], [8, 10], [9, 11], [2, 3],
                [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7]]
    pose_limb_color = palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]]
    pose_kpt_color = palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]]
    radius = 5

    im = cv2.imread(image_path)
    if im is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"could not read image: {image_path}")
    img_h, img_w = im.shape[0], im.shape[1]
    # basename is portable; splitting on "/" fails for Windows-style paths.
    out_path = os.path.join(save_path, os.path.basename(image_path))

    def is_missing(pt):
        # Unlabelled keypoints are stored as 0, so a zero coordinate (the
        # modulo also catches exact multiples of the image size) or a negative
        # coordinate means "nothing to draw".
        return (pt[0] % img_w == 0 or pt[1] % img_h == 0
                or pt[0] < 0 or pt[1] < 0)

    for item in labels:
        fields = item.split()
        if not fields or fields[0] != "0":
            continue
        keypoints = [float(v) for v in fields][5:]
        points = [
            (int(keypoints[3 * i] * img_w), int(keypoints[3 * i + 1] * img_h))
            for i in range(len(keypoints) // 3)
        ]
        for kid, pt in enumerate(points):
            if is_missing(pt):
                continue
            r, g, b = pose_kpt_color[kid]
            cv2.circle(im, pt, radius, (int(r), int(g), int(b)), -1)
        for sk_id, (start, end) in enumerate(skeleton):
            pos1 = points[start - 1]
            pos2 = points[end - 1]
            if is_missing(pos1) or is_missing(pos2):
                continue
            r, g, b = pose_limb_color[sk_id]
            cv2.line(im, pos1, pos2, (int(r), int(g), int(b)), thickness=2)
    cv2.imwrite(out_path, im)
def check_label(label, image_path, image_size=None):
    """Decide whether a single person label should be kept.

    ``label`` is one label line: ``class cx cy w h`` followed by ``(x, y, v)``
    keypoint triples in normalised coordinates. Keypoints are grouped by index
    (assumed COCO order — TODO confirm against the dataset): 0-4 "head",
    5-8 "shoulder", 11-12 "hip", 13-14 "knee". A group's value is the mean of
    its non-zero y coordinates, or 0.0 when every y is 0 (missing). Image y
    grows downward, so "below" means a larger y.

    The label is rejected when:
      * the box is more than 1.5x wider than tall in pixels;
      * at most one of the four groups is present;
      * the head is missing and the body is not upright: without shoulders the
        knees are not below the hips, otherwise neither hips nor knees are
        below the shoulders;
      * the head is present and the knees or the shoulders (when present) are
        at or above it.

    Args:
        label: One line of a label file.
        image_path: Image the label belongs to; read only to get its size.
        image_size: Optional ``(width, height)``; when given the image is not
            read, which lets callers avoid re-reading it for every label.

    Returns:
        True to keep the label, False to drop it.

    Raises:
        FileNotFoundError: if ``image_size`` is None and the image is unreadable.
    """
    fields = label.split()
    if image_size is None:
        img = cv2.imread(image_path)
        if img is None:
            raise FileNotFoundError(f"could not read image: {image_path}")
        w, h = img.shape[1], img.shape[0]
    else:
        w, h = image_size

    object_w = float(fields[3])
    object_h = float(fields[4])
    # A box much wider than tall (in pixels) is most likely a lying person.
    if object_w * w > 1.5 * object_h * h:
        return False

    key_points = fields[5:]

    def mean_y(start, stop):
        # y is the 2nd value of each (x, y, v) triple; 0 marks an unlabelled point.
        ys = [float(v) for v in key_points[start:stop][1::3] if float(v) != 0.0]
        return sum(ys) / len(ys) if ys else 0.0

    head = mean_y(0, 15)
    shoulder = mean_y(15, 27)
    hip = mean_y(33, 39)
    knee = mean_y(39, 45)

    # Three or more groups missing: only a small fragment of the body is labelled.
    if sum(1 for v in (head, shoulder, hip, knee) if v != 0.0) <= 1:
        return False

    if head == 0.0:
        if shoulder == 0.0:
            # Hips and knees are both present here (checked above).
            return knee > hip
        return knee > shoulder or hip > shoulder

    # Only compare groups that are actually present; a missing group is 0.0
    # and would otherwise always look "above" the head.
    if knee != 0.0 and knee <= head:
        return False
    if shoulder != 0.0 and shoulder <= head:
        return False
    return True
def select_label(imagedir, labeldir, savedir, vis_source_dir=None, vis_change_dir=None):
    """Filter the person labels of every label file and save what is kept.

    For each ``.txt`` file in ``labeldir`` whose ``.jpg`` image exists in
    ``imagedir``:
      * all label lines are visualised into ``vis_source_dir``;
      * person lines (first field ``"0"``) are kept only if ``check_label``
        accepts them; every other line is kept unchanged;
      * the kept lines are visualised into ``vis_change_dir`` and, if any
        remain, written to ``savedir`` under the same file name.

    Args:
        imagedir: Directory containing the ``.jpg`` images.
        labeldir: Directory containing the ``.txt`` label files.
        savedir: Output directory for filtered label files.
        vis_source_dir: Output directory for visualisations of the original
            labels; defaults to ``<savedir>/visual_source``.
        vis_change_dir: Output directory for visualisations of the filtered
            labels; defaults to ``<savedir>/visual_change``.
    """
    # The previous hard-coded '' made os.makedirs('') raise immediately.
    if vis_source_dir is None:
        vis_source_dir = os.path.join(savedir, "visual_source")
    if vis_change_dir is None:
        vis_change_dir = os.path.join(savedir, "visual_change")
    for directory in (savedir, vis_source_dir, vis_change_dir):
        os.makedirs(directory, exist_ok=True)

    def is_person(line):
        # Blank lines yield [] and are therefore not person lines.
        return line.split(maxsplit=1)[:1] == ["0"]

    for label_file in sorted(os.listdir(labeldir)):
        stem, ext = os.path.splitext(label_file)
        if ext.lower() != ".txt":
            continue
        # splitext avoids str.replace rewriting "txt" elsewhere in the name.
        image_path = os.path.join(imagedir, stem + ".jpg")
        if not os.path.exists(image_path):
            continue

        with open(os.path.join(labeldir, label_file), "r") as fselect:
            labels = fselect.readlines()

        visual_image(image_path, labels, vis_source_dir)
        kept = [line for line in labels
                if not is_person(line) or check_label(line, image_path)]
        visual_image(image_path, kept, vis_change_dir)

        if kept:
            with open(os.path.join(savedir, label_file), "w") as fsave:
                fsave.writelines(kept)
if __name__ == "__main__":
    # Local import keeps argparse out of module scope for library users.
    import argparse

    # The original placeholder paths are kept as defaults, so running the
    # script without arguments behaves exactly as before.
    parser = argparse.ArgumentParser(
        description="Filter person keypoint labels and visualise keypoints.")
    parser.add_argument("--imagedir", default='your image_dir path',
                        help="directory containing the .jpg images")
    parser.add_argument("--labeldir", default='your label_dir path',
                        help="directory containing the .txt label files")
    parser.add_argument("--savedir", default='your save_dir path',
                        help="directory for the filtered label files")
    args = parser.parse_args()
    select_label(args.imagedir, args.labeldir, args.savedir)