Supervised Contrastive 损失函数详解

发布时间:2024年01月22日

在这里插入图片描述
有什么不对的及时指出,共同学习进步。(●’?’●)

"""
Author: Yonglong Tian (yonglong@mit.edu)
Date: May 07, 2020
"""
from __future__ import print_function

import torch
import torch.nn as nn


class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""
    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode#设置对比的模式有one和all两种,代表对比一个channel还是所有,个人理解
        self.base_temperature = base_temperature #设置的温度

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        device = (torch.device('cuda')#设置设备
                  if features.is_cuda
                  else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:# batch_size, channel,H,W,平铺变成batch_size, channel, (H,W)
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:#只能存在一个
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:#如果两个都没有就是无监督对比损失,mask就是一个单位阵
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:#有标签,就把他变成mask
            labels = labels.contiguous().view(-1, 1)#contiguous深拷贝,与原来的labels没有关系,展开成一列,这样的话能够计算mask,否则labels一维的话labels.T是他本身捕获发生转置
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask =  torch.eq(labels, labels.T).float().to(device)#label和label的转置比较,感觉应该是广播机制,让label和label.T都扩充了然后进行比较,相同的是1,不同是0.
            #这里就是由label形成mask,mask(i,j)代表第i个数据和第j个数据的关系,如果两个类别相同就是1, 不同就是0
        else:
            mask = mask.float().to(device)#有mask就直接用mask,mask也是代表两个数据之间的关系

        contrast_count = features.shape[1]#对比数是channel的个数
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)#把feature按照第1维拆开,然后在第0维上cat,(batch_size*channel,h*w..)#后面就是展开的feature的维度
        #这个操作就和后面mask.repeat对上了,这个操作是第一个数据的第一维特征+第二个数据的第一维特征+第三个数据的第一维特征这样排列的与mask对应
        if self.contrast_mode == 'one':#如果mode=one,比较feature中第1维中的0号元素(batch, h*w)
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':#all就(batch*channel, h*w)
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),#两个相乘获得相似度矩阵,乘积值越大代表越相关
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)#计算其中最大值
        logits = anchor_dot_contrast - logits_max.detach()#减去最大值,都是负的了,指数就小于等于1

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)#repeat它就是把mask复制很多份
        # mask-out self-contrast cases
        logits_mask = torch.scatter(#生成一个mask形状的矩阵除了对角线上的元素是0,其他位置都是1, 不会对自身进行比较
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask#定义其中的相似度
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))#softmax

        # compute mean of log-likelihood over positive
        # modified to handle edge cases when there is no positive pair
        # for an anchor point. 
        # Edge case e.g.:- 
        # features of shape: [4,1,...]
        # labels:            [0,1,1,2]
        # loss before mean:  [nan, ..., ..., nan] 
        mask_pos_pairs = mask.sum(1)#mask的和
        mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs)#满足返回1,不满足返回mask_pos_pairs.保证数值稳定
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos#类似蒸馏temperature温度越高,分布曲线越平滑不易陷入局部最优解,温度低,分布陡峭
        loss = loss.view(anchor_count, batch_size).mean()#计算平均

        return loss

使用的化就是下面这段:

loss = criterion(features, labels)
文章来源:https://blog.csdn.net/goodenough5/article/details/135757746
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。