参考代码:https://github.com/Adlik/yolov5
https://cloud.tencent.com/developer/article/2160509
YOLOv5 模型之间的知识蒸馏(教师模型与学生模型结构相同)。
配置参数
# Command-line options added for distillation.
# Teacher checkpoint and architecture definition.
parser.add_argument(
    '--t_weights',
    type=str,
    default='./weights/yolov5s.pt',
    help='initial teacher model weights path',
)
parser.add_argument(
    '--t_cfg',
    type=str,
    default='models/yolov5s.yaml',
    help='teacher model.yaml path',
)
# Boolean switches selecting what is distilled; both are off unless passed.
parser.add_argument(
    '--d_output',
    action='store_true',
    default=False,
    help='if true, only distill outputs',
)
parser.add_argument(
    '--d_feature',
    action='store_true',
    default=False,
    help='if true, distill both feature and output layers',
)
加载学生模型与教师模型:
# Build the student model from `weights` and then the teacher model from `t_weights`.
# NOTE(review): indentation was lost in this excerpt, so the nesting under
# `if pretrained:` / `with ...:` cannot be recovered from here (in particular, whether
# the teacher section sits inside the `if`). Restore it from the training script before running.
check_suffix(weights, '.pt') # check weights
pretrained = weights.endswith('.pt')
if pretrained:
with torch_distributed_zero_first(LOCAL_RANK):
weights = attempt_download(weights) # download if not found locally
ckpt = torch.load(weights, map_location=device) # load checkpoint
# Architecture comes from `cfg` when given, otherwise from the yaml stored with the checkpoint model.
model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
# Keys containing 'anchor' are excluded when a cfg or hyp anchors value is supplied and we are not resuming.
exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
# presumably keeps only checkpoint entries that match the new model, minus `exclude` - confirm against intersect_dicts
csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect
# strict=False: model keys absent from `csd` keep their freshly initialised values.
model.load_state_dict(csd, strict=False) # load
LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report
# Teacher-model loading is added here.
# Teacher model
# NOTE(review): this message is logged before the teacher checkpoint is actually read,
# and it reports t_cfg rather than t_weights.
LOGGER.info(f'Loaded teacher model {t_cfg}') # report
# NOTE(review): t_weights is loaded directly, without the check_suffix / attempt_download
# steps applied to `weights` above.
t_ckpt = torch.load(t_weights, map_location=device) # load checkpoint
# The teacher is built with the same nc and hyp anchors as the student.
t_model = Model(t_cfg or t_ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)
# `exclude` and `csd` are rebound below, overwriting the student's values from above.
# NOTE(review): with a non-empty t_cfg and no resume, the teacher's 'anchor' entries are
# not taken from its checkpoint - confirm this is intended.
exclude = ['anchor'] if (t_cfg or hyp.get('anchors')) and not resume else [] # exclude keys
csd = t_ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
csd = intersect_dicts(csd, t_model.state_dict(), exclude=exclude) # intersect
t_model.load_state_dict(csd, strict=False) # load
# NOTE(review): nothing in this excerpt puts t_model in eval mode or disables its
# gradients - confirm that happens elsewhere in the script.
损失函数:
# Student detection loss (the original note: loss scaled by batch_size).
s_loss, loss_items = compute_loss(pred, targets.to(device))

# The output-distillation term is always computed, with weight 10.
d_outputs_loss = compute_distillation_output_loss(pred, t_pred, model, d_weight=10)

# Total loss: output distillation plus student loss ...
loss = d_outputs_loss + s_loss
if opt.d_feature:
    # ... plus the feature-distillation term (weight 0.1) when --d_feature is set.
    d_feature_loss = compute_distillation_feature_loss(s_f, t_f, model, f_weight=0.1)
    loss = loss + d_feature_loss