PyTorch TripletMarginLoss(三元损失)
文章目录
- triplet loss
- triplet hard loss
triplet loss
官方文档:
torch.nn — PyTorch master documentation
关于三元损失,出自论文:
FaceNet: A Unified Embedding for Face Recognition and Clustering
FaceNet: A Unified Embedding for Face Recognition and Clustering(论文阅读笔记)
三元损失的介绍很多,本站上搜一下就可以找到,比如:
Triplet Loss 和 Center Loss详解和pytorch实现
Triplet-Loss原理及其实现、应用
看下图:
- 训练集中随机选取一个样本:Anchor(a)
- 再随机选取一个和Anchor属于同一类的样本:Positive(p)
- 再随机选取一个和Anchor属于不同类的样本:Negative(n)
这样<a, p, n>就构成了一个三元组。
学习目标是让Positive和Anchor之间的距离 D ( a , p ) D(a,p) D(a,p) 尽可能的小,Negative和Anchor之间的距离 D ( a , n ) D(a,n) D(a,n) 尽可能的大:
∥ f ( x i a ) − f ( x i p ) ∥ 2 2 + α < ∥ f ( x i a ) − f ( x i n ) ∥ 2 2 (1) \left\|f\left(x_{i}^{a}\right)-f\left(x_{i}^{p}\right)\right\|_{2}^{2}+\alpha<\left\|f\left(x_{i}^{a}\right)-f\left(x_{i}^{n}\right)\right\|_{2}^{2} \tag{1} ∥f(xia)−f(xip)∥22+α<∥f(xia)−f(xin)∥22(1)
∀ ( f ( x i a ) , f ( x i p ) , f ( x i n ) ) ∈ T \forall\left(f\left(x_{i}^{a}\right), f\left(x_{i}^{p}\right), f\left(x_{i}^{n}\right)\right) \in \mathcal{T} ∀(f(xia),f(xip),f(xin))∈T
优化目标:
L = ∑ i N [ ∥ f ( x i a ) − f ( x i p ) ∥ 2 2 − ∥ f ( x i a ) − f ( x i n ) ∥ 2 2 + α ] + (2) L = \sum_{i}^{N}\left[\left\|f\left(x_{i}^{a}\right)-f\left(x_{i}^{p}\right)\right\|_{2}^{2}-\left\|f\left(x_{i}^{a}\right)-f\left(x_{i}^{n}\right)\right\|_{2}^{2}+\alpha\right]_{+} \tag{2} L=i∑N[∥f(xia)−f(xip)∥22−∥f(xia)−f(xin)∥22+α]+(2)
距离用欧式距离度量, + + +表示[ ∗ ∗ ∗ *** ∗∗∗ ]内的值大于零的时候,取该值为损失值,而[ ∗ ∗ ∗ *** ∗∗∗ ]内的值小于零的时候,损失值则为零。也可以这么表示:
L = max ( D ( a , p ) − D ( a , n ) + α , 0 ) (3) L=\max (D(a, p)-D(a, n)+\alpha, 0) \tag{3} L=max(D(a,p)−D(a,n)+α,0)(3)
其中 α 迫使positive pairs (a, p) 和 negative pairs (a, n) 之间有一个margin(α)。 T \mathcal{T} T是训练集中所有可能的三元组的集合。
关于三元组,可以分为:
easy triplets
: L = 0 L = 0 L=0 的情况(不产生loss), D ( a , p ) + α < D ( a , n ) D(a, p)+\alpha<D(a, n) D(a,p)+α<D(a,n),类内距离小,类间距离大,显然无需优化。hard triplets
: D ( a , n ) < D ( a , p ) D(a, n)<D(a, p) D(a,n)<D(a,p),类间距离比类内距离还要小,较难优化,是重点照顾对象。semi-hard triplets
: D ( a , p ) < D ( a , n ) < D ( a , p ) + α D(a, p)<D(a, n)<D(a, p) + \alpha D(a,p)<D(a,n)<D(a,p)+α,类内距离和类间距离很接近,但是存在一个margin
(α),比较容易优化。
更多内容可以看这儿Triplet-Loss原理及其实现、应用
PyTorch中的Triplet-Loss接口:
CLASS torch.nn.TripletMarginLoss(margin=1.0, p=2.0, eps=1e-06, swap=False, size_average=None, reduce=None, reduction='mean')
参数:
margin
(float) – 默认为1p
(int) – norm degree,默认为2swap
(bool) – The distance swap is described in detail in the paperLearning shallow convolutional feature descriptors with triplet losses
by V. Balntas, E. Riba et al. 默认为Falsesize_average
(bool) – Deprecatedreduce
(bool) – Deprecatedreduction
(string) – 指定返回各损失值(none),批损失均值(mean),批损失和(sum),默认返回批损失均值(mean)
使用示例:
输入tensor的尺寸:(N, D),N为批量大小,D为张量维度
输出:为标量, 如果reduction
为 ‘none’,则shape为(N),即N个标量;否则为1个标量
anchor = torch.randn(20, 20, requires_grad=True)
positive = torch.randn(20, 20, requires_grad=True)
negative = torch.randn(20, 20, requires_grad=True)torch.nn.functional.triplet_margin_loss(anchor, positive, negative,reduction='none')
>>>
tensor([1.0158, 0.0975, 2.1613, 1.4658, 0.7332, 1.5604, 1.0034, 0.3777, 0.1616,0.7618, 0.9989, 0.0000, 3.4407, 1.0938, 0.3333, 0.0000, 0.0000, 0.4422,1.1857, 1.7083], grad_fn=<ClampMinBackward>)torch.nn.functional.triplet_margin_loss(anchor, positive, negative,reduction='mean')
>>>
tensor(0.9271, grad_fn=<MeanBackward0>)
# 官方例子
triplet_loss = torch.nn.TripletMarginLoss(margin=1.0, p=2)
anchor = torch.randn(20, 20, requires_grad=True)
positive = torch.randn(20, 20, requires_grad=True)
negative = torch.randn(20, 20, requires_grad=True)output = triplet_loss(anchor, positive, negative)
output.backward()
triplet hard loss
我们再回过头来看(1)式的优化函数:
∥ f ( x i a ) − f ( x i p ) ∥ 2 2 + α < ∥ f ( x i a ) − f ( x i n ) ∥ 2 2 \left\|f\left(x_{i}^{a}\right)-f\left(x_{i}^{p}\right)\right\|_{2}^{2}+\alpha<\left\|f\left(x_{i}^{a}\right)-f\left(x_{i}^{n}\right)\right\|_{2}^{2} ∥f(xia)−f(xip)∥22+α<∥f(xia)−f(xin)∥22
∀ ( f ( x i a ) , f ( x i p ) , f ( x i n ) ) ∈ T \forall\left(f\left(x_{i}^{a}\right), f\left(x_{i}^{p}\right), f\left(x_{i}^{n}\right)\right) \in \mathcal{T} ∀(f(xia),f(xip),f(xin))∈T
这个约束条件需要在所有的三元组上面都成立,但是如果严格按照这个约束,那么三元组集合 T \mathcal{T} T可能会相当大,需要穷举所有的三元组:
【深度学习论文笔记】FaceNet: A Unified Embedding for Face Recognition and Clustering
在1000个人,每人有20张图片的情况下, T = 1000 ∗ 20 ∗ 20 ∗ 999 \mathcal{T} = 1000*20*20*999 T=1000∗20∗20∗999,也即 O ( T ) = N 2 O(T) = N^2 O(T)=N2,显然穷举不太现实,所以常用的办法就是选取部分进行训练,也就是选取困难样本对(hard triplets
)进行训练。
(可以这么想, T \mathcal{T} T包含许多easy triplets
(满足(1)式的约束),这些easy triplets
对训练helpless,而且会使收敛更慢,因为它们仍然需要前向计算。所以需要选择hard triplets
)
那么hard triplets怎么选?
给定一张人脸图片(Anchor):
- 挑选一个hard positive:另外19张图像中,跟它最不相似的图片
argmax x i p ∥ f ( x i a ) − f ( x i p ) ∥ 2 2 \operatorname{argmax}_{x_{i}^{p}}\left\|f\left(x_{i}^{a}\right)-f\left(x_{i}^{p}\right)\right\|_{2}^{2} argmaxxip∥f(xia)−f(xip)∥22 - 挑选一个hard negative:另外20*999张图像中,跟它最为相似的图片
argmin x i n ∥ f ( x i a ) − f ( x i n ) ∥ 2 2 \operatorname{argmin}_{x_{i}^{n}}\left\|f\left(x_{i}^{a}\right)-f\left(x_{i}^{n}\right)\right\|_{2}^{2} argminxin∥f(xia)−f(xin)∥22
而挑选方法也有两种:offline和online
这儿介绍实际采用的online方法:通过在一个mini-batch中选择hard positive/negative 样本来实现。具体的解释可以参照论文以及参考文档。
下面贴一个PyTorch的triplet hard loss
实现():
关于代码的解析可以看pytorch triphard代码理解
class TripletLoss(nn.Module):"""Triplet loss with hard positive/negative mining.Reference:Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.Imported from `<https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py>`_.Args:margin (float, optional): margin for triplet. Default is 0.3."""def __init__(self, margin=0.3,global_feat, labels):super(TripletLoss, self).__init__()self.margin = marginself.ranking_loss = nn.MarginRankingLoss(margin=margin)def forward(self, inputs, targets):"""Args:inputs (torch.Tensor): feature matrix with shape (batch_size, feat_dim).targets (torch.LongTensor): ground truth labels with shape (num_classes)."""n = inputs.size(0) # batch_size# Compute pairwise distance, replace by the official when mergeddist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)dist = dist + dist.t()dist.addmm_(1, -2, inputs, inputs.t())dist = dist.clamp(min=1e-12).sqrt() # for numerical stability# For each anchor, find the hardest positive and negativemask = targets.expand(n, n).eq(targets.expand(n, n).t())dist_ap, dist_an = [], []for i in range(n):dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))dist_ap = torch.cat(dist_ap)dist_an = torch.cat(dist_an)# Compute ranking hinge lossy = torch.ones_like(dist_an)loss = self.ranking_loss(dist_an, dist_ap, y)return loss
PyTorch TripletMarginLoss(三元损失)相关推荐
- 【Pytorch神经网络理论篇】 35 GaitSet模型:步态识别思路+水平金字塔池化+三元损失
代码: [Pytorch神经网络实战案例]28 GitSet模型进行步态与身份识别(CASIA-B数据集)_LiBiGor的博客-CSDN博客1 CASIA-B数据集本例使用的是预处理后的CASIA- ...
- 三元损失“In Defense of the Triplet Loss for Person Re-Identification”
更全面的阅读记录可以参考这篇博客:https://blog.csdn.net/xuluohongshang/article/details/78965580 背景描述 提出了一个三元损失的变形用于行人 ...
- 不在pytorch中的损失的函数
https://zhuanlan.zhihu.com/p/60747096 https://blog.csdn.net/fuwenyan/article/details/79657738
- Pytorch 加权BCE损失
bce_loss = nn.BCELoss(reduction='none') # 默认选项是mean, 设置为none后会返回一个和target一样尺寸的tensor, 每个位置的数字对应网络输出和 ...
- 深度度量学习(DML)中pair-based方法中的loss
文章目录 前言 一.Constrative loss[1] 二.Triplet loss[2] Offline and online triplet mining 参考 三.Lifted Struct ...
- pytorch深度学习和入门实战(四)神经网络的构建和训练
目录 1.前言 2.神经网络概述 2.1 核心组件包括: 2.2 核心过程 3.构建神经网络模型 3.1构建网络层(Layer ➨ Model) 3.2 torch.nn.Sequential的3大使 ...
- 【Pytorch神经网络实战案例】29 【代码汇总】GitSet模型进行步态与身份识别(CASIA-B数据集)
1 GaitSet_DataLoader.py import numpy as np # 引入基础库 import os import torch.utils.data as tordata from ...
- 【Pytorch神经网络实战案例】28 GitSet模型进行步态与身份识别(CASIA-B数据集)
1 CASIA-B数据集 本例使用的是预处理后的CASIA-B数据集, 数据集下载网址如下. http://www.cbsr.ia.ac.cn/china/Gait%20Databases%20cH. ...
- 使用PyTorch进行手写数字识别,在20 k参数中获得99.5%的精度。
In this article we'll build a simple convolutional neural network in PyTorch and train it to recogni ...
- 深度篇——人脸识别(一) ArcFace 论文 翻译
返回主目录 返回 人脸识别 目录 下一章:深度篇--人脸识别(二) 人脸识别代码 insight_face_pro 项目讲解 目录内容: 深度篇--人脸识别(一) ArcFace 论文 翻译 深度篇- ...
最新文章
- C++/C++11中用于定义类型别名的两种方法:typedef和using
- eclipse编写wordcount提交spark运行
- toj 4613 Number of Battlefields
- 3D MRI brain tumor segmentation using autoencoder regularization
- pyspider—爬取下载图片
- Transformer在图像复原领域的降维打击!ETH提出SwinIR:各项任务全面领先
- 如何防止社工钓鱼——软件伪造
- 慎用 JSON.stringify
- 电脑wifi距离测试软件,wifi测速工具
- linux下解压bin文件怎么打开方式,安卓手机如何打开.bin文件?
- 最好的网盘--主流网盘大比拼
- ccy测试影响因子版270ms
- JAVA练习——集合练习题(HashSet,TreeSet)产生随机数不能重复,去掉重复元素,将集合中重复元素去掉,字符串倒序输出,倒序输出整数,倒序排列对象
- 晚安西南-----地破实验
- 无网络环境安装docker
- java商城加入购物车接口实现_商城系统购物车功能分析实现
- java毕业设计坝上长尾鸡养殖管理系统Mybatis+系统+数据库+调试部署
- Sentinel-高可用流量管理框架
- 拓嘉恒业:拼多多开店条件分享
- Java数据结构之中缀表达式转后缀表达式
热门文章
- 怎样使用车载信息服务器,云平台在车载信息上的应用
- Qt pro 文件中路径设置 生成可执行文件路径
- 强烈推荐:一文洞悉Python必备50种算法
- 爬虫入门实战(标价400的单子-1)
- 电子科技大学成都学院 计算机考试,2018年电子科技大学成都学院高职单招考试数学真题及答案...
- Vue中computed多个值互相计算,Vue多个字段之间互相计算,vue多个值互相引用计算
- 售卖机控制板开发,轻松实现线下售卖和线上运营
- 关于前端应用表现层抽象--学习笔记
- airplay协议简述
- 山科c语言简单数值计算方法,山科大c语言编程.doc