Contents

  • Preface
  • 1. Contrastive loss [1]
  • 2. Triplet loss [2]
    • Offline and online triplet mining
    • References
  • 3. Lifted Structure Loss
  • 4. N-pairs loss [4]
  • 5. Multi-similarity (MS) loss
  • References

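Preface

A family of DML (deep metric learning) methods is called pair-based: their objectives can be defined in terms of the pairwise similarities within a mini-batch. Examples include Contrastive loss, Triplet loss, Lifted Structure loss, N-pairs loss, and Multi-similarity loss.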
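
All of the losses below start from the same two ingredients: a pairwise similarity (or distance) matrix over the mini-batch, and masks saying which pairs are positive (same class) and which are negative. A minimal NumPy sketch of both, assuming integer class labels and cosine similarity; the helper names `pair_masks` and `cosine_similarity_matrix` are hypothetical, chosen for illustration:

```python
import numpy as np


def pair_masks(labels):
    """Boolean masks of positive and negative pairs inside one mini-batch.

    positive[i, j] is True when samples i and j share a label and i != j;
    negative[i, j] is True when their labels differ.
    """
    labels = np.asarray(labels)
    same = labels[:, None] == labels[None, :]             # (B, B) label-equality matrix
    positive = same & ~np.eye(len(labels), dtype=bool)    # drop self-pairs on the diagonal
    negative = ~same
    return positive, negative


def cosine_similarity_matrix(feats):
    """S[i, j] = cosine similarity between embedding i and embedding j."""
    feats = np.asarray(feats, dtype=float)
    normed = feats / np.linalg.norm(feats, axis=1, keepdims=True)  # L2-normalize each row
    return normed @ normed.T


labels = [0, 0, 1, 1, 2]
positive, negative = pair_masks(labels)
print(positive.sum(), negative.sum())  # 4 16 (ordered pairs)
```

Each loss in this article is, in essence, a different way of combining the entries of S (or of a distance matrix) selected by these two masks.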


1. Contrastive loss [1]

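The paper [1] proposes a method for learning a similarity metric from data. The method suits recognition and verification tasks with the following properties:
(1) the number of classes is very large;
(2) some classes are not seen during training;
(3) each class has very few training samples.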

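Siamese networks usually apply the Contrastive Loss to pairs of inputs. For a positive pair, the distance between the two output feature vectors should be as small as possible; for a negative pair, it should be as large as possible. However, once E_W > m, such an easy negative pair is no longer penalized (it contributes zero loss).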
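
To make the margin behaviour concrete, here is a minimal NumPy sketch of the per-pair loss in the squared-hinge form used by the PyTorch code later in this section (no ½ factor). The function name `contrastive_loss_per_pair` is hypothetical. A negative pair whose distance already exceeds m contributes exactly zero:

```python
import numpy as np


def contrastive_loss_per_pair(x1, x2, y, margin=2.0):
    """Per-pair contrastive loss (squared-hinge form, no 1/2 factor).

    y = 0 marks a similar pair and y = 1 a dissimilar pair, following the
    label convention of this section.
    """
    d = np.linalg.norm(np.asarray(x1, float) - np.asarray(x2, float), axis=1)  # E_W per pair
    y = np.asarray(y, dtype=float)
    similar_term = (1.0 - y) * d ** 2                        # pull similar pairs together
    dissimilar_term = y * np.maximum(margin - d, 0.0) ** 2   # push dissimilar pairs out to the margin
    return similar_term + dissimilar_term


x1 = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
x2 = [[1.0, 0.0], [1.0, 0.0], [3.0, 0.0]]
y = [0, 1, 1]  # a similar pair, a hard negative pair, an easy negative pair (distance 3 > margin 2)
print(contrastive_loss_per_pair(x1, x2, y))  # [1. 1. 0.]
```

The hinge max(0, m - E_W) is what makes the easy negative disappear from the gradient: only negatives inside the margin keep pushing the embedding.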

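Following the derivation in [1], the overall loss over P training pairs is:

$$L(W) = \sum_{i=1}^{P} L\left(W, (Y, X_1, X_2)^i\right), \qquad L\left(W, (Y, X_1, X_2)^i\right) = (1 - Y)\, L_G\left(E_W(X_1, X_2)^i\right) + Y\, L_I\left(E_W(X_1, X_2)^i\right)$$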

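Here W denotes the network weights and Y is the pair label: Y = 0 if X1 and X2 belong to the same class, and Y = 1 if they belong to different classes. E_W(X1, X2) = ||G_W(X1) - G_W(X2)|| is a scalar "energy function", where G_W maps an input to its feature vector. L_G is the partial loss for similar (genuine) pairs and L_I the partial loss for dissimilar (impostor) pairs.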


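L_G and L_I are chosen so that minimizing L lowers the energy of similar pairs and raises the energy of dissimilar pairs. In [1] this is stated as a condition with a margin m > 0: E_W(X1, X2) + m < E_W(X1, X2'), where (X1, X2) is a similar pair and (X1, X2') is a dissimilar pair.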

The code is as follows:

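```python
import torch
import torch.nn.functional as F


# Custom ContrastiveLoss
class ContrastiveLoss(torch.nn.Module):
    """Contrastive loss function.

    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    label = 0 for a similar pair, label = 1 for a dissimilar pair.
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        # Euclidean distance per pair, shape (batch_size,).
        # label must be a float tensor of the same shape; with keepdim=True the
        # distance would be (batch_size, 1) and silently broadcast against it.
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean(
            (1 - label) * torch.pow(euclidean_distance, 2)
            + label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
        )
        return loss_contrastive
```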

2. Triplet loss [2]

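Triplet loss minimizes the distance between an anchor and a positive of the same identity, and maximizes the distance between the anchor and a negative of a different identity.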

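We want the following constraint to hold:

$$\left\| f(x_i^a) - f(x_i^p) \right\|_2^2 + \alpha < \left\| f(x_i^a) - f(x_i^n) \right\|_2^2, \quad \forall \left(x_i^a, x_i^p, x_i^n\right) \in \mathcal{T} \tag{1}$$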

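where f(x) is the embedding of input x, α is the margin enforced between positive and negative pairs, and T is the set of all possible triplets in the training set, with cardinality N. Accordingly, the triplet loss to be minimized is:

$$L = \sum_{i}^{N} \left[ \left\| f(x_i^a) - f(x_i^p) \right\|_2^2 - \left\| f(x_i^a) - f(x_i^n) \right\|_2^2 + \alpha \right]_+$$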
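
A minimal NumPy sketch of this sum over a few given triplets, assuming the embeddings f(x) are already computed. The function name and α = 0.2 are illustrative choices. It also shows that an easy triplet, one that already satisfies Eq. (1), contributes nothing:

```python
import numpy as np


def triplet_loss(anchor, positive, negative, alpha=0.2):
    """Sum of hinge terms [||a - p||^2 - ||a - n||^2 + alpha]_+ over N triplets."""
    a, p, n = (np.asarray(t, dtype=float) for t in (anchor, positive, negative))
    d_ap = np.sum((a - p) ** 2, axis=1)                 # squared anchor-positive distances
    d_an = np.sum((a - n) ** 2, axis=1)                 # squared anchor-negative distances
    per_triplet = np.maximum(d_ap - d_an + alpha, 0.0)  # [.]_+ hinge
    return per_triplet.sum(), per_triplet


anchor = [[0.0, 0.0], [0.0, 0.0]]
positive = [[1.0, 0.0], [0.0, 1.0]]
negative = [[2.0, 0.0], [0.0, 1.05]]
total, per_triplet = triplet_loss(anchor, positive, negative)
print(per_triplet)  # first triplet is easy (0), second is active (about 0.0975)
```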

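Generating all possible triplets would produce many triplets that are easily satisfied (i.e. that already fulfil the constraint in Eq. (1)). These triplets do not contribute to training and slow down convergence, since they are still passed through the network. The key is to select hard triplets, which are active and can therefore help improve the model.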

Offline and online triplet mining

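  1. Offline: generate triplets every n steps, using the most recent network checkpoint and computing the argmin and argmax on a subset of the data.
  2. Online: generate triplets by selecting hard positive/negative exemplars from within a mini-batch.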
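
Online mining can be done in several ways. The PyTorch code below picks the hardest positive and the hardest negative for each anchor ("batch hard"); another option discussed in the same Hermans et al. paper is "batch all", which forms every valid triplet in the mini-batch. A minimal NumPy sketch of batch all, assuming Euclidean distances and averaging only over the active (non-zero) triplets; the function name and margin value are illustrative:

```python
import numpy as np


def batch_all_triplet_loss(embeddings, labels, margin=0.3):
    """'Batch all' online mining: every valid (a, p, n) triplet in the mini-batch.

    Returns (mean loss over active triplets, number of active triplets, number of valid triplets).
    """
    x = np.asarray(embeddings, dtype=float)
    labels = np.asarray(labels)
    dist = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=2)  # (B, B) Euclidean distances
    same = labels[:, None] == labels[None, :]
    not_self = ~np.eye(len(labels), dtype=bool)
    # valid[a, p, n]: p has a's label (p != a) and n has a different label
    valid = (same & not_self)[:, :, None] & (~same)[:, None, :]
    # losses[a, p, n] = d(a, p) - d(a, n) + margin, clipped at zero
    losses = np.maximum(dist[:, :, None] - dist[:, None, :] + margin, 0.0)
    losses = np.where(valid, losses, 0.0)
    num_active = int((losses > 1e-12).sum())
    num_valid = int(valid.sum())
    mean_loss = losses.sum() / num_active if num_active > 0 else 0.0
    return mean_loss, num_active, num_valid


emb = [[0.0], [0.5], [0.7], [1.5]]
labels = [0, 0, 1, 1]
print(batch_all_triplet_loss(emb, labels))  # 5 of the 8 valid triplets are active
```

Averaging over active triplets only keeps the loss from being diluted as more and more triplets become easy during training.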

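Below is triplet loss code with online mining (for each anchor, the hardest positive and the hardest negative in the mini-batch are selected):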

```python
import torch
import torch.nn as nn


class TripletLoss(nn.Module):
    """Triplet loss with hard positive/negative mining.

    Reference:
        Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
    Imported from `<https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py>`_.

    Args:
        margin (float, optional): margin for triplet. Default is 0.3.
    """

    def __init__(self, margin=0.3):
        super(TripletLoss, self).__init__()
        self.margin = margin
        # https://pytorch.org/docs/1.2.0/nn.html?highlight=marginrankingloss#torch.nn.MarginRankingLoss
        # MarginRankingLoss(x1, x2, y) = max(0, -y * (x1 - x2) + margin). With x1 = d_an,
        # x2 = d_ap and y = 1 this is max(0, d_ap - d_an + margin): the loss is zero once
        # the negative is farther than the positive by at least the margin.
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, inputs, targets):
        """
        Args:
            inputs (torch.Tensor): feature matrix with shape (batch_size, feat_dim).
            targets (torch.LongTensor): ground truth labels with shape (batch_size,).
        """
        n = inputs.size(0)  # batch_size

        # Compute pairwise distance, replace by the official when merged
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b
        dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
        dist = dist + dist.t()
        dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability

        # For each anchor, find the hardest positive and negative
        mask = targets.expand(n, n).eq(targets.expand(n, n).t())
        dist_ap, dist_an = [], []
        for i in range(n):
            dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
            dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
        dist_ap = torch.cat(dist_ap)
        dist_an = torch.cat(dist_an)

        # Compute ranking hinge loss
        y = torch.ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)
        return loss
```
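
The distance matrix in `forward` avoids materializing all pairwise differences by expanding ||a - b||² = ||a||² + ||b||² - 2a·b. Round-off can make the result slightly negative, hence the clamp before the square root. A NumPy sketch that checks this expansion against the direct computation (both function names are illustrative):

```python
import numpy as np


def pairwise_dist_expanded(x):
    """Euclidean distances via ||a||^2 + ||b||^2 - 2 a.b, mirroring the expand/addmm_ steps above."""
    x = np.asarray(x, dtype=float)
    sq = np.sum(x ** 2, axis=1, keepdims=True)   # (n, 1) squared norms
    d2 = sq + sq.T - 2.0 * x @ x.T               # (n, n) squared distances, may be slightly < 0
    return np.sqrt(np.clip(d2, 1e-12, None))     # clamp before sqrt, like dist.clamp(min=1e-12)


def pairwise_dist_direct(x):
    """Reference: explicit differences between every pair of rows."""
    x = np.asarray(x, dtype=float)
    return np.linalg.norm(x[:, None, :] - x[None, :, :], axis=2)


x = np.random.default_rng(0).normal(size=(6, 4))
print(np.abs(pairwise_dist_expanded(x) - pairwise_dist_direct(x)).max())  # small; the largest gap comes from the clamp on the diagonal
```

The expanded form costs one matrix multiplication instead of a (B, B, D) tensor of differences, which matters for large batches or high-dimensional features.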

References

PyTorch TripletMarginLoss (triplet loss)


3. Lifted Structure Loss

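The idea behind Lifted Structure loss [3] is that, for a positive pair, we do not distinguish which sample is the anchor and which is the positive. Instead, each sample of the positive pair is required to be farther than a given threshold from every negative sample. This makes full use of all samples in the mini-batch and mines all the pairs it contains.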

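The loss for each batch is defined as:

$$J = \frac{1}{2|\hat{P}|} \sum_{(i,j) \in \hat{P}} \max\left(0, J_{i,j}\right)^2, \qquad J_{i,j} = \log\left( \sum_{(i,k) \in \hat{N}} \exp\{\alpha - D_{i,k}\} + \sum_{(j,l) \in \hat{N}} \exp\{\alpha - D_{j,l}\} \right) + D_{i,j}$$

where P̂ and N̂ are the sets of positive and negative pairs in the batch, D_{i,j} is the Euclidean distance between the embeddings of samples i and j, and α is the margin.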
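
A direct, loop-based NumPy sketch of this batch loss, assuming Euclidean distances between embeddings and counting each unordered positive pair once. It uses a max-shifted log-sum-exp for numerical stability. The function name is hypothetical, and α = 1 corresponds to `neg_margin=1` in the code below:

```python
import numpy as np


def lifted_structure_loss(embeddings, labels, alpha=1.0):
    """Smooth lifted structured loss for one mini-batch (loop version for clarity)."""
    x = np.asarray(embeddings, dtype=float)
    labels = np.asarray(labels)
    D = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=2)  # D[i, j] = ||x_i - x_j||
    neg = labels[:, None] != labels[None, :]                   # negative-pair mask
    total, num_pos = 0.0, 0
    for i in range(len(labels)):
        for j in range(i + 1, len(labels)):  # each unordered positive pair (i, j) once
            if labels[i] != labels[j]:
                continue
            # alpha - D for the negatives of i and for the negatives of j
            terms = np.concatenate([alpha - D[i][neg[i]], alpha - D[j][neg[j]]])
            if terms.size == 0:  # skip pairs with no negatives in the batch
                continue
            m = terms.max()
            J_ij = m + np.log(np.sum(np.exp(terms - m))) + D[i, j]  # stable log-sum-exp + D_ij
            total += max(J_ij, 0.0) ** 2
            num_pos += 1
    return total / (2 * num_pos) if num_pos > 0 else 0.0


print(lifted_structure_loss([[0.0], [0.0], [1.0]], [0, 0, 1]))   # log(2)^2 / 2, about 0.240
print(lifted_structure_loss([[0.0], [0.0], [10.0]], [0, 0, 1]))  # 0.0, the negative is far beyond the margin
```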

Code (from the pytorch-metric-learning library):

```python
import torch

from pytorch_metric_learning.losses.generic_pair_loss import GenericPairLoss
from pytorch_metric_learning.utils import common_functions as c_f
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu


class LiftedStructureLoss(GenericPairLoss):
    def __init__(self, neg_margin=1, pos_margin=0, **kwargs):
        super().__init__(mat_based_loss=False, **kwargs)
        self.neg_margin = neg_margin
        self.pos_margin = pos_margin
        self.add_to_recordable_attributes(
            list_of_names=["pos_margin", "neg_margin"], is_stat=False
        )

    def _compute_loss(self, pos_pairs, neg_pairs, indices_tuple):
        a1, p, a2, _ = indices_tuple
        dtype = pos_pairs.dtype

        if len(a1) > 0 and len(a2) > 0:
            pos_pairs = pos_pairs.unsqueeze(1)
            # n_per_p[i, k] = 1 if the anchor of negative pair k is one of the
            # two samples of positive pair i
            n_per_p = c_f.to_dtype(
                (a2.unsqueeze(0) == a1.unsqueeze(1))
                | (a2.unsqueeze(0) == p.unsqueeze(1)),
                dtype=dtype,
            )
            neg_pairs = neg_pairs * n_per_p
            keep_mask = ~(n_per_p == 0)

            remaining_pos_margin = self.distance.margin(pos_pairs, self.pos_margin)
            remaining_neg_margin = self.distance.margin(self.neg_margin, neg_pairs)

            neg_pairs_loss = lmu.logsumexp(
                remaining_neg_margin, keep_mask=keep_mask, add_one=False, dim=1
            )

            loss_per_pos_pair = neg_pairs_loss + remaining_pos_margin
            loss_per_pos_pair = torch.relu(loss_per_pos_pair) ** 2
            loss_per_pos_pair /= (
                2  # divide by 2 since each positive pair will be counted twice
            )
            return {
                "loss": {
                    "losses": loss_per_pos_pair,
                    "indices": (a1, p),
                    "reduction_type": "pos_pair",
                }
            }
        return self.zero_losses()


class GeneralizedLiftedStructureLoss(GenericPairLoss):
    # The 'generalized' lifted structure loss shown on page 4
    # of the "in defense of triplet loss" paper
    # https://arxiv.org/pdf/1703.07737.pdf
    def __init__(self, neg_margin=1, pos_margin=0, **kwargs):
        super().__init__(mat_based_loss=True, **kwargs)
        self.neg_margin = neg_margin
        self.pos_margin = pos_margin
        self.add_to_recordable_attributes(
            list_of_names=["pos_margin", "neg_margin"], is_stat=False
        )

    def _compute_loss(self, mat, pos_mask, neg_mask):
        remaining_pos_margin = self.distance.margin(mat, self.pos_margin)
        remaining_neg_margin = self.distance.margin(self.neg_margin, mat)

        pos_loss = lmu.logsumexp(
            remaining_pos_margin, keep_mask=pos_mask.bool(), add_one=False
        )
        neg_loss = lmu.logsumexp(
            remaining_neg_margin, keep_mask=neg_mask.bool(), add_one=False
        )
        return {
            "loss": {
                "losses": torch.relu(pos_loss + neg_loss),
                "indices": c_f.torch_arange_from_size(mat),
                "reduction_type": "element",
            }
        }
```

4. N-pairs loss [4]

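Triplet loss pulls one positive pair together while pushing one negative pair apart. As a result, each selected sample pair only attends to a single negative, and the loss lacks the ability to distinguish the anchor from samples of all the other classes.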

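To address this, N-pair loss [4] uses multiple negatives: for each positive pair, samples from all the other classes are taken as negatives and combined with the anchor to form negative pairs. If the mini-batch contains N classes, each positive pair (x_i, x_i^+) is associated with N-1 negative pairs. The (N+1)-tuples are usually not built in advance; instead they are constructed on the fly from the same mini-batch during training.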
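
With one pair per class, the N-pair loss for anchor i is log(1 + Σ_{j≠i} exp(f_i·f_j⁺ - f_i·f_i⁺)), which equals a softmax cross-entropy over the logits f_i·f_j⁺ with target class i. A NumPy sketch that computes both forms and checks that they agree, assuming exactly one pair per class and leaving out the L2 regularizer used in the code below; both function names are illustrative:

```python
import numpy as np


def npair_loss(anchors, positives):
    """Mean over anchors of log(1 + sum_{j != i} exp(f_i.f_j+ - f_i.f_i+))."""
    f = np.asarray(anchors, dtype=float)
    f_pos = np.asarray(positives, dtype=float)
    logits = f @ f_pos.T                        # logits[i, j] = f_i . f_j^+
    diff = logits - np.diag(logits)[:, None]    # subtract the positive logit of each row
    off_diag = ~np.eye(len(f), dtype=bool)      # the N - 1 negatives of each anchor
    return np.mean(np.log1p(np.sum(np.exp(diff) * off_diag, axis=1)))


def npair_loss_as_softmax(anchors, positives):
    """Same value, written as softmax cross-entropy whose target for row i is column i."""
    logits = np.asarray(anchors, dtype=float) @ np.asarray(positives, dtype=float).T
    m = logits.max(axis=1, keepdims=True)
    log_probs = logits - m - np.log(np.sum(np.exp(logits - m), axis=1, keepdims=True))
    return -np.mean(np.diag(log_probs))


rng = np.random.default_rng(0)
anchors, positives = rng.normal(size=(4, 3)), rng.normal(size=(4, 3))
print(npair_loss(anchors, positives), npair_loss_as_softmax(anchors, positives))  # equal up to round-off
```

This equivalence is why the PyTorch implementation below can be written with a cross-entropy call: when every label in the batch is distinct, its soft target matrix reduces to the identity.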

Related code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def cross_entropy(logits, target, size_average=True):
    # Cross-entropy against a soft target distribution (each row of target sums to 1)
    if size_average:
        return torch.mean(torch.sum(- target * F.log_softmax(logits, -1), -1))
    else:
        return torch.sum(torch.sum(- target * F.log_softmax(logits, -1), -1))


class NpairLoss(nn.Module):
    """the multi-class n-pair loss"""

    def __init__(self, l2_reg=0.02):
        super(NpairLoss, self).__init__()
        self.l2_reg = l2_reg

    def forward(self, anchor, positive, target):
        batch_size = anchor.size(0)
        # target[i, j] = 1 / (number of samples sharing label i) if labels i and j match, else 0
        target = target.view(target.size(0), 1)
        target = (target == torch.transpose(target, 0, 1)).float()
        target = target / torch.sum(target, dim=1, keepdim=True).float()

        # logit[i, j] = f_i . f_j^+
        logit = torch.matmul(anchor, torch.transpose(positive, 0, 1))
        loss_ce = cross_entropy(logit, target)
        # L2 penalty on the embedding norms
        l2_loss = torch.sum(anchor**2) / batch_size + torch.sum(positive**2) / batch_size

        loss = loss_ce + self.l2_reg * l2_loss * 0.25
        return loss
```

5. Multi-similarity (MS) loss


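S: Self-similarity is computed from the pair itself and is the most important similarity. A negative pair with a larger cosine similarity means that the two samples, although from different classes, are harder to tell apart. Such pairs are regarded as hard negative pairs: they are more informative and more useful for learning discriminative features. Contrastive loss and Binomial Deviance Loss are based on this criterion. In case-1 of [5], the weights of the three negative pairs increase as the negative samples move closer to the anchor.

N: Negative relative similarity is computed by considering the relationship with neighbouring negative pairs. In case-2, the relative similarity decreases even though the self-similarity stays the same: nearby negative samples move closer, which increases the self-similarity of those pairs and thereby lowers the relative similarity of the current pair. Lifted Structure Loss is based on this similarity.

P: Positive relative similarity also takes the relationship with other positive pairs into account. In case-3, as those positive samples move closer to the anchor, the relative similarity of the current pair decreases, so its weight decreases as well. Triplet loss is based on this similarity.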
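
The S and N criteria can be read off the negative part of the MS loss for one anchor, (1/β) log(1 + Σ_{k∈N_i} exp(β(S_ik - λ))): its derivative with respect to S_ij is the pair weight w_ij = exp(β(S_ij - λ)) / (1 + Σ_k exp(β(S_ik - λ))). The weight grows with the pair's own similarity (S) and shrinks when other negatives get closer (N). A NumPy sketch that checks this numerically; λ = 0.5 follows `self.thresh` in the code below, while β = 40 and the function names are illustrative assumptions:

```python
import numpy as np


def ms_negative_term(neg_sims, beta=40.0, lam=0.5):
    """Negative half of the MS loss for one anchor."""
    s = np.asarray(neg_sims, dtype=float)
    return np.log1p(np.sum(np.exp(beta * (s - lam)))) / beta


def ms_negative_weights(neg_sims, beta=40.0, lam=0.5):
    """Closed-form weights w_ij = dL/dS_ij of the term above."""
    s = np.asarray(neg_sims, dtype=float)
    e = np.exp(beta * (s - lam))
    return e / (1.0 + e.sum())


def numerical_grad(f, s, h=1e-6):
    """Central finite differences, used only to check the closed form."""
    s = np.asarray(s, dtype=float)
    grad = np.zeros_like(s)
    for k in range(len(s)):
        step = np.zeros_like(s)
        step[k] = h
        grad[k] = (f(s + step) - f(s - step)) / (2 * h)
    return grad


sims = [0.45, 0.55, 0.6]
print(ms_negative_weights(sims))  # more similar (harder) negatives get larger weights
print(ms_negative_weights([0.5, 0.4])[0], ms_negative_weights([0.5, 0.6])[0])  # first weight drops when the other negative gets closer
```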

MS loss [5] works in two steps: (1) informative pairs are first sampled using Similarity-P; (2) the selected pairs are then weighted using Similarity-S and Similarity-N jointly.
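
The two steps can be traced on a single anchor with hand-picked similarities. The sketch below mirrors the per-anchor body of the PyTorch code that follows (mining with the margin ε = 0.1, then the two log-sum-exp terms with λ = 0.5); α = 2 and β = 40 are illustrative values standing in for `SCALE_POS` and `SCALE_NEG`, and the function name is hypothetical:

```python
import numpy as np


def ms_loss_one_anchor(pos_sims, neg_sims, alpha=2.0, beta=40.0, lam=0.5, eps=0.1):
    """MS loss for one anchor given its positive and negative cosine similarities."""
    pos = np.asarray(pos_sims, dtype=float)
    neg = np.asarray(neg_sims, dtype=float)
    # Step 1 (mining with Similarity-P): keep negatives that come within eps of the least
    # similar positive, and positives that fall within eps of the most similar negative.
    neg_kept = neg[neg + eps > pos.min()]
    pos_kept = pos[pos - eps < neg.max()]
    if neg_kept.size == 0 or pos_kept.size == 0:
        return 0.0, pos_kept, neg_kept
    # Step 2 (weighting with Similarity-S and -N): soft weighting through log-sum-exp terms.
    pos_loss = np.log1p(np.sum(np.exp(-alpha * (pos_kept - lam)))) / alpha
    neg_loss = np.log1p(np.sum(np.exp(beta * (neg_kept - lam)))) / beta
    return pos_loss + neg_loss, pos_kept, neg_kept


loss, pos_kept, neg_kept = ms_loss_one_anchor([0.9, 0.6], [0.55, 0.2, 0.65])
print(pos_kept, neg_kept)  # [0.6] [0.55 0.65]: the easy positive 0.9 and easy negative 0.2 are dropped
```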

Related code:

```python
import torch
import torch.nn as nn


class MultiSimilarityLoss(nn.Module):
    def __init__(self, cfg):
        super(MultiSimilarityLoss, self).__init__()
        self.thresh = 0.5  # lambda in the MS loss
        self.margin = 0.1  # epsilon used for pair mining
        self.scale_pos = cfg.LOSSES.MULTI_SIMILARITY_LOSS.SCALE_POS  # alpha
        self.scale_neg = cfg.LOSSES.MULTI_SIMILARITY_LOSS.SCALE_NEG  # beta

    def forward(self, feats, labels):
        # feats = features extracted from backbone model for images
        # labels = ground truth classes corresponding to images
        batch_size = feats.size(0)
        # Since feats are L2-normalized vectors, multiplying them by their own transpose
        # yields a (batch_size x batch_size) similarity matrix whose entry (i, j) is the
        # similarity between the i-th and j-th embeddings; row 0 holds the similarities
        # of embedding 0 to all embeddings in the batch.
        sim_mat = torch.matmul(feats, torch.t(feats))

        epsilon = 1e-5
        loss = list()

        for i in range(batch_size):  # the i-th embedding is the anchor
            # all positive pairs: embeddings sharing the anchor's ground-truth label
            pos_pair_ = sim_mat[i][labels == labels[i]]
            # remove the pair of the anchor with itself (similarity ~1)
            pos_pair_ = pos_pair_[pos_pair_ < 1 - epsilon]
            # all negative pairs: embeddings whose label differs from the anchor's
            neg_pair_ = sim_mat[i][labels != labels[i]]
            # skip anchors with no positive or no negative in the batch
            if len(pos_pair_) < 1 or len(neg_pair_) < 1:
                continue

            # mine hard negatives; adding the margin of 0.1 also keeps negatives lying
            # just on the boundary, which would be missed otherwise
            neg_pair = neg_pair_[neg_pair_ + self.margin > torch.min(pos_pair_)]
            # mine hard positives in the same way, with a margin of 0.1
            pos_pair = pos_pair_[pos_pair_ - self.margin < torch.max(neg_pair_)]

            # continue calculating the loss only if both hard pos and hard neg are present
            if len(neg_pair) < 1 or len(pos_pair) < 1:
                continue

            # weighting step
            pos_loss = 1.0 / self.scale_pos * torch.log(
                1 + torch.sum(torch.exp(-self.scale_pos * (pos_pair - self.thresh)))
            )
            neg_loss = 1.0 / self.scale_neg * torch.log(
                1 + torch.sum(torch.exp(self.scale_neg * (neg_pair - self.thresh)))
            )
            # losses as described in the equation
            loss.append(pos_loss + neg_loss)

        if len(loss) == 0:
            return torch.zeros([], requires_grad=True)

        loss = sum(loss) / batch_size
        return loss
```

References

[1]: Sumit Chopra, Raia Hadsell, and Yann LeCun. Learning a similarity metric discriminatively, with application to face verification. In CVPR, 2005, vol. 1, pp. 539-546. doi: 10.1109/CVPR.2005.202.

[2]: Florian Schroff, Dmitry Kalenichenko, and James Philbin. FaceNet: A unified embedding for face recognition and clustering. In CVPR, 2015, pp. 815-823.

[3]: Hyun Oh Song, Yu Xiang, Stefanie Jegelka, and Silvio Savarese. Deep metric learning via lifted structured feature embedding. In CVPR, 2016.

[4]: Kihyuk Sohn. Improved deep metric learning with multi-class n-pair loss objective. In NeurIPS, 2016.

[5]: Xun Wang, Xintong Han, Weilin Huang, Dengke Dong, and Matthew R Scott. Multi-similarity loss with general pair weighting for deep metric learning. In CVPR, 2019.
