文章目录

  • triplet loss
  • triplet hard loss

triplet loss

官方文档:
torch.nn — PyTorch master documentation

关于三元损失,出自论文:

FaceNet: A Unified Embedding for Face Recognition and Clustering
FaceNet: A Unified Embedding for Face Recognition and Clustering(论文阅读笔记)

三元损失的介绍很多,本站上搜一下就可以找到,比如:

Triplet Loss 和 Center Loss详解和pytorch实现
Triplet-Loss原理及其实现、应用

看下图:

  • 训练集中随机选取一个样本:Anchor(a)
  • 再随机选取一个和Anchor属于同一类的样本:Positive(p)
  • 再随机选取一个和Anchor属于不同类的样本:Negative(n)

这样<a, p, n>就构成了一个三元组。

学习目标是让Positive和Anchor之间的距离 D ( a , p ) D(a,p) D(a,p) 尽可能的小,Negative和Anchor之间的距离 D ( a , n ) D(a,n) D(a,n) 尽可能的大:

∥ f ( x i a ) − f ( x i p ) ∥ 2 2 + α < ∥ f ( x i a ) − f ( x i n ) ∥ 2 2 (1) \left\|f\left(x_{i}^{a}\right)-f\left(x_{i}^{p}\right)\right\|_{2}^{2}+\alpha<\left\|f\left(x_{i}^{a}\right)-f\left(x_{i}^{n}\right)\right\|_{2}^{2} \tag{1} f(xia)f(xip)22+α<f(xia)f(xin)22(1)

∀ ( f ( x i a ) , f ( x i p ) , f ( x i n ) ) ∈ T \forall\left(f\left(x_{i}^{a}\right), f\left(x_{i}^{p}\right), f\left(x_{i}^{n}\right)\right) \in \mathcal{T} (f(xia),f(xip),f(xin))T

优化目标:

L = ∑ i N [ ∥ f ( x i a ) − f ( x i p ) ∥ 2 2 − ∥ f ( x i a ) − f ( x i n ) ∥ 2 2 + α ] + (2) L = \sum_{i}^{N}\left[\left\|f\left(x_{i}^{a}\right)-f\left(x_{i}^{p}\right)\right\|_{2}^{2}-\left\|f\left(x_{i}^{a}\right)-f\left(x_{i}^{n}\right)\right\|_{2}^{2}+\alpha\right]_{+} \tag{2} L=iN[f(xia)f(xip)22f(xia)f(xin)22+α]+(2)

距离用欧式距离度量, + + +表示[ ∗ ∗ ∗ *** ]内的值大于零的时候,取该值为损失值,而[ ∗ ∗ ∗ *** ]内的值小于零的时候,损失值则为零。也可以这么表示:

L = max ⁡ ( D ( a , p ) − D ( a , n ) + α , 0 ) (3) L=\max (D(a, p)-D(a, n)+\alpha, 0) \tag{3} L=max(D(a,p)D(a,n)+α,0)(3)

其中 α 迫使positive pairs (a, p) 和 negative pairs (a, n) 之间有一个margin(α)。 T \mathcal{T} T是训练集中所有可能的三元组的集合。

关于三元组,可以分为:

  • easy tripletsL = 0 L = 0 L=0 的情况(不产生loss), D ( a , p ) + α < D ( a , n ) D(a, p)+\alpha<D(a, n) D(a,p)+α<D(a,n),类内距离小,类间距离大,显然无需优化。
  • hard tripletsD ( a , n ) < D ( a , p ) D(a, n)<D(a, p) D(a,n)<D(a,p),类间距离比类内距离还要小,较难优化,是重点照顾对象。
  • semi-hard tripletsD ( a , p ) < D ( a , n ) < D ( a , p ) + α D(a, p)<D(a, n)<D(a, p) + \alpha D(a,p)<D(a,n)<D(a,p)+α,类内距离和类间距离很接近,但是存在一个margin(α),比较容易优化。


更多内容可以看这儿Triplet-Loss原理及其实现、应用

PyTorch中的Triplet-Loss接口:

CLASS torch.nn.TripletMarginLoss(margin=1.0, p=2.0, eps=1e-06, swap=False, size_average=None, reduce=None, reduction='mean')

参数:

  • margin (float) – 默认为1
  • p (int) – norm degree,默认为2
  • swap (bool) – The distance swap is described in detail in the paper Learning shallow convolutional feature descriptors with triplet losses by V. Balntas, E. Riba et al. 默认为False
  • size_average (bool) – Deprecated
  • reduce (bool) – Deprecated
  • reduction (string) – 指定返回各损失值(none),批损失均值(mean),批损失和(sum),默认返回批损失均值(mean)

使用示例:
输入tensor的尺寸:(N, D),N为批量大小,D为张量维度
输出:为标量, 如果reduction为 ‘none’,则shape为(N),即N个标量;否则为1个标量

anchor = torch.randn(20, 20, requires_grad=True)
positive = torch.randn(20, 20, requires_grad=True)
negative = torch.randn(20, 20, requires_grad=True)torch.nn.functional.triplet_margin_loss(anchor, positive, negative,reduction='none')
>>>
tensor([1.0158, 0.0975, 2.1613, 1.4658, 0.7332, 1.5604, 1.0034, 0.3777, 0.1616,0.7618, 0.9989, 0.0000, 3.4407, 1.0938, 0.3333, 0.0000, 0.0000, 0.4422,1.1857, 1.7083], grad_fn=<ClampMinBackward>)torch.nn.functional.triplet_margin_loss(anchor, positive, negative,reduction='mean')
>>>
tensor(0.9271, grad_fn=<MeanBackward0>)
# 官方例子
triplet_loss = torch.nn.TripletMarginLoss(margin=1.0, p=2)
anchor = torch.randn(20, 20, requires_grad=True)
positive = torch.randn(20, 20, requires_grad=True)
negative = torch.randn(20, 20, requires_grad=True)output = triplet_loss(anchor, positive, negative)
output.backward()

triplet hard loss

我们再回过头来看(1)式的优化函数:
∥ f ( x i a ) − f ( x i p ) ∥ 2 2 + α < ∥ f ( x i a ) − f ( x i n ) ∥ 2 2 \left\|f\left(x_{i}^{a}\right)-f\left(x_{i}^{p}\right)\right\|_{2}^{2}+\alpha<\left\|f\left(x_{i}^{a}\right)-f\left(x_{i}^{n}\right)\right\|_{2}^{2} f(xia)f(xip)22+α<f(xia)f(xin)22

∀ ( f ( x i a ) , f ( x i p ) , f ( x i n ) ) ∈ T \forall\left(f\left(x_{i}^{a}\right), f\left(x_{i}^{p}\right), f\left(x_{i}^{n}\right)\right) \in \mathcal{T} (f(xia),f(xip),f(xin))T
这个约束条件需要在所有的三元组上面都成立,但是如果严格按照这个约束,那么三元组集合 T \mathcal{T} T可能会相当大,需要穷举所有的三元组:

【深度学习论文笔记】FaceNet: A Unified Embedding for Face Recognition and Clustering
在1000个人,每人有20张图片的情况下, T = 1000 ∗ 20 ∗ 20 ∗ 999 \mathcal{T} = 1000*20*20*999 T=10002020999,也即 O ( T ) = N 2 O(T) = N^2 O(T)=N2,显然穷举不太现实,所以常用的办法就是选取部分进行训练,也就是选取困难样本对hard triplets)进行训练。

(可以这么想, T \mathcal{T} T包含许多easy triplets(满足(1)式的约束),这些easy triplets对训练helpless,而且会使收敛更慢,因为它们仍然需要前向计算。所以需要选择hard triplets

那么hard triplets怎么选?

给定一张人脸图片(Anchor):

  • 挑选一个hard positive:另外19张图像中,跟它最不相似的图片
    argmax ⁡ x i p ∥ f ( x i a ) − f ( x i p ) ∥ 2 2 \operatorname{argmax}_{x_{i}^{p}}\left\|f\left(x_{i}^{a}\right)-f\left(x_{i}^{p}\right)\right\|_{2}^{2} argmaxxipf(xia)f(xip)22
  • 挑选一个hard negative:另外20*999张图像中,跟它最为相似的图片
    argmin ⁡ x i n ∥ f ( x i a ) − f ( x i n ) ∥ 2 2 \operatorname{argmin}_{x_{i}^{n}}\left\|f\left(x_{i}^{a}\right)-f\left(x_{i}^{n}\right)\right\|_{2}^{2} argminxinf(xia)f(xin)22

而挑选方法也有两种:offlineonline

这儿介绍实际采用的online方法:通过在一个mini-batch中选择hard positive/negative 样本来实现。具体的解释可以参照论文以及参考文档。

下面贴一个PyTorch的triplet hard loss实现():
关于代码的解析可以看pytorch triphard代码理解


class TripletLoss(nn.Module):"""Triplet loss with hard positive/negative mining.Reference:Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.Imported from `<https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py>`_.Args:margin (float, optional): margin for triplet. Default is 0.3."""def __init__(self, margin=0.3,global_feat, labels):super(TripletLoss, self).__init__()self.margin = marginself.ranking_loss = nn.MarginRankingLoss(margin=margin)def forward(self, inputs, targets):"""Args:inputs (torch.Tensor): feature matrix with shape (batch_size, feat_dim).targets (torch.LongTensor): ground truth labels with shape (num_classes)."""n = inputs.size(0)  # batch_size# Compute pairwise distance, replace by the official when mergeddist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)dist = dist + dist.t()dist.addmm_(1, -2, inputs, inputs.t())dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability# For each anchor, find the hardest positive and negativemask = targets.expand(n, n).eq(targets.expand(n, n).t())dist_ap, dist_an = [], []for i in range(n):dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))dist_ap = torch.cat(dist_ap)dist_an = torch.cat(dist_an)# Compute ranking hinge lossy = torch.ones_like(dist_an)loss = self.ranking_loss(dist_an, dist_ap, y)return loss

PyTorch TripletMarginLoss(三元损失)相关推荐

  1. 【Pytorch神经网络理论篇】 35 GaitSet模型:步态识别思路+水平金字塔池化+三元损失

    代码: [Pytorch神经网络实战案例]28 GitSet模型进行步态与身份识别(CASIA-B数据集)_LiBiGor的博客-CSDN博客1 CASIA-B数据集本例使用的是预处理后的CASIA- ...

  2. 三元损失“In Defense of the Triplet Loss for Person Re-Identification”

    更全面的阅读记录可以参考这篇博客:https://blog.csdn.net/xuluohongshang/article/details/78965580 背景描述 提出了一个三元损失的变形用于行人 ...

  3. 不在pytorch中的损失的函数

    https://zhuanlan.zhihu.com/p/60747096 https://blog.csdn.net/fuwenyan/article/details/79657738

  4. Pytorch 加权BCE损失

    bce_loss = nn.BCELoss(reduction='none') # 默认选项是mean, 设置为none后会返回一个和target一样尺寸的tensor, 每个位置的数字对应网络输出和 ...

  5. 深度度量学习(DML)中pair-based方法中的loss

    文章目录 前言 一.Constrative loss[1] 二.Triplet loss[2] Offline and online triplet mining 参考 三.Lifted Struct ...

  6. pytorch深度学习和入门实战(四)神经网络的构建和训练

    目录 1.前言 2.神经网络概述 2.1 核心组件包括: 2.2 核心过程 3.构建神经网络模型 3.1构建网络层(Layer ➨ Model) 3.2 torch.nn.Sequential的3大使 ...

  7. 【Pytorch神经网络实战案例】29 【代码汇总】GitSet模型进行步态与身份识别(CASIA-B数据集)

    1 GaitSet_DataLoader.py import numpy as np # 引入基础库 import os import torch.utils.data as tordata from ...

  8. 【Pytorch神经网络实战案例】28 GitSet模型进行步态与身份识别(CASIA-B数据集)

    1 CASIA-B数据集 本例使用的是预处理后的CASIA-B数据集, 数据集下载网址如下. http://www.cbsr.ia.ac.cn/china/Gait%20Databases%20cH. ...

  9. 使用PyTorch进行手写数字识别,在20 k参数中获得99.5%的精度。

    In this article we'll build a simple convolutional neural network in PyTorch and train it to recogni ...

  10. 深度篇——人脸识别(一)  ArcFace 论文 翻译

    返回主目录 返回 人脸识别 目录 下一章:深度篇--人脸识别(二) 人脸识别代码 insight_face_pro 项目讲解 目录内容: 深度篇--人脸识别(一) ArcFace 论文 翻译 深度篇- ...

最新文章

  1. C++/C++11中用于定义类型别名的两种方法:typedef和using
  2. eclipse编写wordcount提交spark运行
  3. toj 4613 Number of Battlefields
  4. 3D MRI brain tumor segmentation using autoencoder regularization
  5. pyspider—爬取下载图片
  6. Transformer在图像复原领域的降维打击!ETH提出SwinIR:各项任务全面领先
  7. 如何防止社工钓鱼——软件伪造
  8. 慎用 JSON.stringify
  9. 电脑wifi距离测试软件,wifi测速工具
  10. linux下解压bin文件怎么打开方式,安卓手机如何打开.bin文件?
  11. 最好的网盘--主流网盘大比拼
  12. ccy测试影响因子版270ms
  13. JAVA练习——集合练习题(HashSet,TreeSet)产生随机数不能重复,去掉重复元素,将集合中重复元素去掉,字符串倒序输出,倒序输出整数,倒序排列对象
  14. 晚安西南-----地破实验
  15. 无网络环境安装docker
  16. java商城加入购物车接口实现_商城系统购物车功能分析实现
  17. java毕业设计坝上长尾鸡养殖管理系统Mybatis+系统+数据库+调试部署
  18. Sentinel-高可用流量管理框架
  19. 拓嘉恒业:拼多多开店条件分享
  20. Java数据结构之中缀表达式转后缀表达式

热门文章

  1. 怎样使用车载信息服务器,云平台在车载信息上的应用
  2. Qt pro 文件中路径设置 生成可执行文件路径
  3. 强烈推荐:一文洞悉Python必备50种算法
  4. 爬虫入门实战(标价400的单子-1)
  5. 电子科技大学成都学院 计算机考试,2018年电子科技大学成都学院高职单招考试数学真题及答案...
  6. Vue中computed多个值互相计算,Vue多个字段之间互相计算,vue多个值互相引用计算
  7. 售卖机控制板开发,轻松实现线下售卖和线上运营
  8. 关于前端应用表现层抽象--学习笔记
  9. airplay协议简述
  10. 山科c语言简单数值计算方法,山科大c语言编程.doc