三元组损失(Triplet Loss)是一种用于学习深度神经网络嵌入的损失函数,它的主要目标是确保在我们的嵌入空间中,来自相同类别的样本更接近彼此,而不同类别的样本更远离彼此。三元组损失(Triplet Loss)常在人脸识别、图像检索等需要计算相似度的任务中使用
三元组损失需要三个样本来计算损失,这三个样本被称为锚(Anchor)、正(Positive)和负(Negative)样本。其中,锚样本是我们关注的样本,正样本与锚样本具有相同的类别标签,负样本与锚样本具有不同的类别标签。
假设我们已经通过神经网络得到了这三个样本在嵌入空间的位置,分别是 A(锚样本),P(正样本)和 N(负样本)。则三元组损失函数的形式为:
L = max(d(A, P) - d(A, N) + margin, 0)
其中,d(A, P) 和 d(A, N) 分别是锚样本与正样本,锚样本与负样本在嵌入空间的距离,"margin"是一个预设定的阈值,用于控制正样本与负样本之间的差异,我们希望锚样本比与负样本的距离比至少比与正样本的距离大。
例如:
我们有三个样本锚样本A, 正样本P, 负样本N。它们分别被一个神经网络映射到一个三维空间,得到的嵌入向量是:
A = [1, 1, 1]P = [1.1, 1.1, 1.1]
N = [2, 2, 2]
我们可以看到,正样本P比锚样本A更接近,而负样本N则比正样本P和锚样本A更远,这就是我们希望的结果。但如果网络没有很好的训练,可能会得到违背这一原则的嵌入,例如负样本N离锚样本A更近,那么这就需要三元组损失来调整网络的权重,使得同类样本更接近,不同类样本更远离。
Triplet Loss三元组损失函数如下:
def triplet_loss(embedding, targets, margin, norm_feat, hard_mining):
r"""Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
Related Triplet Loss theory can be found in paper 'In Defense of the Triplet
Loss for Person Re-Identification'."""
if norm_feat:
dist_mat = cosine_dist(embedding, embedding)
else:
dist_mat = euclidean_dist(embedding, embedding)
# For distributed training, gather all features from different process.
# if comm.get_world_size() > 1:
# all_embedding = torch.cat(GatherLayer.apply(embedding), dim=0)
# all_targets = concat_all_gather(targets)
# else:
# all_embedding = embedding
# all_targets = targets
# 获取相似度矩阵dist_mat的行数,即样本数量
N = dist_mat.size(0)
# 创建两个相同大小的矩阵is_pos和is_neg,分别存储样本之间是否属于相同类别(正样本对)及不同类别(负样本对)
is_pos = targets.view(N, 1).expand(N, N).eq(targets.view(N, 1).expand(N, N).t()).float()
is_neg = targets.view(N, 1).expand(N, N).ne(targets.view(N, 1).expand(N, N).t()).float()
if hard_mining:
dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)
else:
dist_ap, dist_an = weighted_example_mining(dist_mat, is_pos, is_neg)
y = dist_an.new().resize_as_(dist_an).fill_(1)
if margin > 0:
loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=margin)
else:
loss = F.soft_margin_loss(dist_an - dist_ap, y)
# fmt: off
if loss == float('Inf'): loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)
# fmt: on
return loss
对上面代码进行解析:
def triplet_loss(embedding, targets, margin, norm_feat, hard_mining):
定义了一个名为triplet_loss的函数,输入参数为embedding(嵌入特征)、targets(目标标签)、margin(用于增加正负样本之间间距的值)、norm_feat(决定是否对特征进行归一化)以及hard_mining(决定是否启动困难样本挖掘)。
if norm_feat:
dist_mat = cosine_dist(embedding, embedding)
else:
dist_mat = euclidean_dist(embedding, embedding)
判断是否对特征进行归一化,若决定归一化,就用余弦距离度量相似度;若不归一化,则用欧氏距离度量相似度。
cosine_dist(embedding, embedding)是将embedding中的每一个向量与embedding中的每一个向量都计算一遍余弦距离。
假设你的embedding是一个(3, 2)的张量,内容如下:
[[a1, a2],
[b1, b2],
[c1, c2]]
其中,[a1, a2],[b1, b2]和[c1, c2]是这个embedding中的3个向量。
当你执行cosine_dist(embedding, embedding)时,实际上计算的是:
[[cosine_dist([a1, a2], [a1, a2]), cosine_dist([a1, a2], [b1, b2]), cosine_dist([a1, a2], [c1, c2])],
[cosine_dist([b1, b2], [a1, a2]), cosine_dist([b1, b2], [b1, b2]), cosine_dist([b1, b2], [c1, c2])],
[cosine_dist([c1, c2], [a1, a2]), cosine_dist([c1, c2], [b1, b2]), cosine_dist([c1, c2], [c1, c2])]]
这个结果是一个(3, 3)的矩阵,表示embedding中的每一个向量与embedding中的每一个向量之间的余弦距离。
当if norm_feat:这个条件语句为真时,即当我们想对embedding进行归一化处理时,就会使用这种方法计算embedding中所有向量之间的余弦距离。
N = dist_mat.size(0)
is_pos = targets.view(N, 1).expand(N, N).eq(targets.view(N, 1).expand(N, N).t()).float()
is_neg = targets.view(N, 1).expand(N, N).ne(targets.view(N, 1).expand(N, N).t()).float()
创建两个相同大小的矩阵is_pos和is_neg,分别存储样本之间是否属于相同类别(正样本对)及不同类别(负样本对)。
targets.view(N, 1).expand(N, N),得到的结果是:
1 1 1 1
2 2 2 2
1 1 1 1
2 2 2 2
执行targets.view(N, 1).expand(N, N).t(),得到的结果是:
1 2 1 2
1 2 1 2
1 2 1 2
1 2 1 2
当我们用eq()去判断两个矩阵对应位置是否相等时,得到的结果(is_pos)是:
1 0 1 0
0 1 0 1
1 0 1 0
0 1 0 1
对应位置用ne()去判断是否不相等,得到的结果(is_neg)是:
0 1 0 1
1 0 1 0
0 1 0 1
1 0 1 0
if hard_mining:
dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)
else:
dist_ap, dist_an = weighted_example_mining(dist_mat, is_pos, is_neg)
根据是否进行困难样本挖掘,采用不同的挖掘方法获取到每个样本对的距离。
# 对于每个锚点样本,找到最难正样本(最远的具有相同类别标签的样本)和最难负样本(最近的具有不同类别标签的样本)。
def hard_example_mining(dist_mat, is_pos, is_neg):
"""For each anchor, find the hardest positive and negative sample.
Args:
dist_mat: pair wise distance between samples, shape [N, M]
is_pos: positive index with shape [N, M]
is_neg: negative index with shape [N, M]
Returns:
dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
dist_an: pytorch Variable, distance(anchor, negative); shape [N]
p_inds: pytorch LongTensor, with shape [N];
indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
n_inds: pytorch LongTensor, with shape [N];
indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
NOTE: Only consider the case in which all labels have same num of samples,
thus we can cope with all anchors in parallel.
"""
assert len(dist_mat.size()) == 2
# `dist_ap` means distance(anchor, positive)
# both `dist_ap` and `relative_p_inds` with shape [N]
# dist_ap表示锚点样本与正样本之间的距离。通过在距离矩阵和正样本矩阵做逐元素相乘后,取每行(每个锚点)的最大值。
dist_ap, _ = torch.max(dist_mat * is_pos, dim=1)
# `dist_an` means distance(anchor, negative)
# both `dist_an` and `relative_n_inds` with shape [N]
# dist_an表示锚点样本与负样本之间的距离。首先,通过在距离矩阵和负样本矩阵做逐元素相乘后,再将正样本矩阵与大数(1e9)相乘并加到上述结果上,旨在将负样本对里的正样本对的距离设置地非常大。之后取每行的最小值,找出与锚点样本最近且类别不同的样本。
dist_an, _ = torch.min(dist_mat * is_neg + is_pos * 1e9, dim=1)
return dist_ap, dist_an
def weighted_example_mining(dist_mat, is_pos, is_neg):
"""For each anchor, find the weighted positive and negative sample.
Args:
dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
is_pos:
is_neg:
Returns:
dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
dist_an: pytorch Variable, distance(anchor, negative); shape [N]
"""
assert len(dist_mat.size()) == 2
is_pos = is_pos
is_neg = is_neg
# 对于每个锚点样本,找到正样本和负样本的加权距离
dist_ap = dist_mat * is_pos
dist_an = dist_mat * is_neg
# 分别通过softmax函数计算正样本和负样本的权重,注意负样本在计算权重之前要取负数。
weights_ap = softmax_weights(dist_ap, is_pos)
weights_an = softmax_weights(-dist_an, is_neg)
# 计算的是加权距离,将距离与对应的权重相乘,然后对结果进行累加求和,得到最后的加权距离。
dist_ap = torch.sum(dist_ap * weights_ap, dim=1)
dist_an = torch.sum(dist_an * weights_an, dim=1)
return dist_ap, dist_an
y = dist_an.new().resize_as_(dist_an).fill_(1)
创建一个和dist_an相同大小并内容全部为1的向量。
if margin > 0:
loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=margin)
else:
loss = F.soft_margin_loss(dist_an - dist_ap, y)
# fmt: off
if loss == float('Inf'): loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)
# fmt: on
计算最终的三元组损失:
F.margin_ranking_loss函数是用来实现三元组损失的一个实用方法,它接受两组数据和一个目标向量作为输入来计算定制的秩序损失。
dist_ap代表锚点和正样本之间的距离,dist_an代表锚点和负样本之间的距离。y是目标向量,经常被设置为1,表示我们希望dist_an(锚点和负样本之间的距离)比dist_ap(锚点和正样本之间的距离)大。margin是我们希望两者之间的最小差距。