Triplet Loss的动机

一个好的特征提取器,应该尽可能的做到同类别样本映射出来的特征会聚集在一起,而不同类别的样本映射出来的特征应该要相互远离。

为了达到这个目标,Triplet Loss显式的在Loss里面要求:不同类别之间的距离至少要超过同类别之间距离的某个阈值。如果能够做到这一点,那么类内距和类间距之间差就有一个明显的鸿沟,那么也可以达到上面提到的目标。

Triplet Loss的定义

Triplet Loss里面包含若干三元组:

  • 锚点 anchor
  • 正例 positive
  • 负例 negative

要求:锚点和正例是处于相同的类别,锚点和负例处于不同的类别。

a,p,n都不是原始样本,而是原始样本被神经网络做特征提取后的得到的特征向量。即:a=f(xa),b=f(xb),c=f(xc)a=f(x_a), b=f(x_b), c=f(x_c)a=f(xa​),b=f(xb​),c=f(xc​)。f(⋅)f(·)f(⋅)是神经网络特征提取器。

对于一个三元组triplet (a,p,n),它的triplet loss写作:
L=max(d(a,p)−d(a,n)+margin,0)L=max(d(a,p)- d(a,n)+margin, 0)L=max(d(a,p)−d(a,n)+margin,0),其中d(x,y)d(x,y)d(x,y) 是自定义的距离函数。

这个东西写的还是很直观的,它想表达的意思为:

  1. 如果 d(a,p)−d(a,n)+margin>0d(a,p)- d(a,n)+margin>0d(a,p)−d(a,n)+margin>0,那么loss就是 d(a,p)−d(a,n)+margind(a,p)- d(a,n)+margind(a,p)−d(a,n)+margin,否则就是0。0的时候就没有产生实际loss,就不会有梯度,意味着模型无需优化。
  2. 当 d(a,p)−d(a,n)+margin>0d(a,p)- d(a,n)+margin>0d(a,p)−d(a,n)+margin>0,有 d(a,n)−d(a,p)<margind(a,n) - d(a,p) < margind(a,n)−d(a,p)<margin,此时锚点和负例之间的距离和锚点与正例之间的距离之差还没有超过阈值,于是就要会产生LOSS。
  3. 又因为优化的目标是让loss越小越好,于是模型就会千方百计的优化fff,使得 d(a,p)−d(a,n)+margind(a,p)−d(a,n)+margind(a,p)−d(a,n)+margin 越小越好,直到d(a,p)−d(a,n)+margind(a,p)−d(a,n)+margind(a,p)−d(a,n)+margin 小于等于0,就不优化了。

Triplet Loss在pytorch里面的实现

__author__ = 'dk'
import torch
import torch as th
from torch.nn import functional as F
from torch import nnclass TripletLoss(nn.Module):"""Triplet loss with hard positive/negative mining.Reference:Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.Imported from `<https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py>`_.Args:margin (float, optional): margin for triplet. Default is 0.3."""def __init__(self, margin=0.3, batch_size=128, view_num=3):super(TripletLoss, self).__init__()self.margin = marginself.ranking_loss = nn.MarginRankingLoss(margin=margin)def forward(self, inputs,targets = None):"""Args:inputs (torch.Tensor): feature matrix with shape (batch_size, feat_dim).targets (torch.LongTensor): ground truth labels with shape (num_classes)."""if targets == None:targets = self.targetsn = inputs.size(0)# Compute pairwise distance, replace by the official when mergeddist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)dist = dist + dist.t()dist.addmm_(1, -2, inputs, inputs.t())dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability# For each anchor, find the hardest positive and negativemask = targets.expand(n, n).eq(targets.expand(n, n).t())dist_ap, dist_an = [], []for i in range(n):dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))dist_ap = torch.cat(dist_ap)dist_an = torch.cat(dist_an)# Compute ranking hinge lossy = torch.ones_like(dist_an)return self.ranking_loss(dist_an, dist_ap, y)

这是我在网上随便找到的衣服triplet loss实现,这个是基于l2范数实现的,写的比较隐蔽。
这个triplet loss是接受n个样本的特征作为inputs,然后返回最难优化的hard example。

我们解析这几句关键的,为了方便起见,记Inputs为VVV,里面的第i样本为 viv_ivi​。

        # Compute pairwise distance, replace by the official when mergeddist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)

计算输入的各个样本自己的l2范数。

        dist = dist + dist.t()

将每个样本自己的l2范数加上别人的l2范数。此时dist是个nxn的矩阵,dist[i,j]表示第i样本与第j样本的l2范数之和,也就是:dist[i,j]=vi2+vj2dist[i,j]=v^2_i + v^2_jdist[i,j]=vi2​+vj2​

        dist.addmm_(1, -2, inputs, inputs.t())

这个式子展开就是:A=dist−2×VVTA=dist -2 \times VV^TA=dist−2×VVT
于是A[i,j]=vi2+vj2−2vivj=(vi−vj)2A[i,j]=v^2_i+v^2_j-2v_iv_j=(v_i-v_j)^2A[i,j]=vi2​+vj2​−2vi​vj​=(vi​−vj​)2,妙啊。这就是等价于想把每个样本相互减,然后计算差值向量的l2范数。

        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability

上面这句话就是在开方,这没有啥好说的。

        # For each anchor, find the hardest positive and negativemask = targets.expand(n, n).eq(targets.expand(n, n).t())

上面这句话的作用是:mask[i,j]mask[i,j]mask[i,j]表示i个样本和第j个样本的label是否相同,相同为True,不同为false, 注意这是一个bool的tensor矩阵。

        dist_ap, dist_an = [], []for i in range(n):dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))

dist[i][mask[i]] 表示把与第i样本具有相同标签那些样本的dist拿出来,max()是取最大的那个。
dist[i][mask[i] == 0].min()表示把与第i个样本不同标签的样本的dist拿出来,取距离最小的那个。

于是dist_ap保存了那些距离各自锚点最远的同类标签。
dist_an保存距离各自锚点最近的异类标签。这类样本都是难分的。这其实是个hardest triplet loss。

        # Compute ranking hinge lossy = torch.ones_like(dist_an)return self.ranking_loss(dist_an, dist_ap, y)

这个ranking_loss的公式为 loss(x1,x2,y)=max⁡(0,−y∗(x1−x2)+margin)\text{loss}(x1, x2, y) = \max(0, -y * (x1 - x2) + \text{margin})loss(x1,x2,y)=max(0,−y∗(x1−x2)+margin)
代入进去就是:loss(dist_an,dist_ap,1)=max⁡(0,dist_ap−dist_an+margin)\text{loss}(dist\_{an}, dist\_{ap}, \bold{1}) = \max(0, dist\_ap-dist\_an + \text{margin})loss(dist_an,dist_ap,1)=max(0,dist_ap−dist_an+margin)

更通用的实现

那如果我们想实现一个可以选择LpL_pLp​ 范数也不是硬编码为2范数的triplet loss如何实现呢?

__author__ = 'dk'
import torch
import torch as th
from torch.nn import functional as F
from torch import nnclass TripletLoss(nn.Module):"""Triplet loss with hard positive/negative mining.Reference:Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.Imported from `<https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py>`_.Args:margin (float, optional): margin for triplet. Default is 0.3."""def __init__(self, margin=0.3, batch_size=128, view_num=3, p=2):super(TripletLoss, self).__init__()self.margin = marginself.p = pself.ranking_loss = nn.MarginRankingLoss(margin=margin)self.targets =  torch.cat([torch.arange(batch_size) for i in range(view_num)], dim=0)def forward(self, inputs,targets = None):"""Args:inputs (torch.Tensor): feature matrix with shape (batch_size, feat_dim).targets (torch.LongTensor): ground truth labels with shape (num_classes)."""if targets == None:targets = self.targetsn = inputs.size(0)# Compute pairwise distance, replace by the official when mergeddist = []for i in range(n):dist.append(inputs[i] - inputs)dist = torch.stack(dist)dist = torch.linalg.norm(dist,ord=self.p,dim=2)# For each anchor, find the hardest positive and negativemask = targets.expand(n, n).eq(targets.expand(n, n).t())dist_ap, dist_an = [], []for i in range(n):dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))dist_ap = torch.cat(dist_ap)dist_an = torch.cat(dist_an)# Compute ranking hinge lossy = torch.ones_like(dist_an)return self.ranking_loss(dist_an, dist_ap, y)

两种方法用时测试:


if __name__ == '__main__':import time,tqdminputs = torch.randn(128,256)step = 1000tripletloss1 = TripletLoss(batch_size=64,view_num=2)loss1 = 0from SSL.triplet_loss import TripletLoss as TripletLoss2tripletloss2 = TripletLoss2(batch_size=64,view_num=2)loss2 = 0s = time.time()for i in tqdm.trange(step):loss1 += tripletloss1(inputs)e1=time.time()for i in tqdm.trange(step):loss2 += tripletloss2(inputs)e2 =time.time()print('1: {0}s, result:{1}'.format(e1-s, loss1/ step))print('1: {0}s, result:{1}'.format(e2-e1, loss2/ step))

tripletloss1就是网上的方法,tripletloss2是我们实现的。结果:

方法1: 52.5317645072937s, result:2.367180824279785
方法2: 15.61170482635498s, result:2.367180824279785

可以发现我们的方法更快,究其原因,方法1里面矩阵的乘法及其耗时。

Triplet Loss原理及实现相关推荐

  1. 三元组损失 Triplet Loss及其梯度

    Triplet Loss及其梯度 Triplet Loss及其梯度_jcjx0315的博客-CSDN博客 Triplet Loss简介 我这里将Triplet Loss翻译为三元组损失,其中的三元也就 ...

  2. triplet loss损失函数

    知识点来源于知乎链接 triplet loss的原理 损失函数的公式L=max(d(a,p)-d(a,n)+margin,0) a: anchor,p: positive, 与a是同一类别的样本:n: ...

  3. 一文理解Ranking Loss/Margin Loss/Triplet Loss

    点击蓝字  关注我们 作者丨土豆@知乎 来源丨https://zhuanlan.zhihu.com/p/158853633 本文已获授权,未经作者许可,不得二次转载. 前言 Ranking loss在 ...

  4. Person Re-Identification by Multi-Channel Parts-Based CNN with Improved Triplet Loss Function

    作者:西安交大的De Cheng, Yihong Gong, Sanping Zhou, Jinjun Wang, Nanning Zheng 主要贡献: 贡献有两个,一个是改进的网络结构,一个是改进 ...

  5. 机器学习笔记:triplet loss

    1 Triplet loss Triplet Loss,即三元组损失,其中的三元是Anchor.Negative.Positive. 通过Triplet Loss的学习后使得Positive元和Anc ...

  6. CV之FRec之ME/LF:人脸识别中常用的模型评估指标/损失函数(Triplet Loss、Center Loss)简介、使用方法之详细攻略

    CV之FRec之ME/LF:人脸识别中常用的模型评估指标/损失函数(Triplet Loss.Center Loss)简介.使用方法之详细攻略 目录 T1.Triplet Loss 1.英文原文解释 ...

  7. triplet loss 在深度学习中主要应用在什么地方?有什么明显的优势?

    作者:罗浩.ZJU 链接:https://www.zhihu.com/question/62486208/answer/199117070 来源:知乎 著作权归作者所有.商业转载请联系作者获得授权,非 ...

  8. 车辆搜索 -使用triplet loss 训练车辆识别模型

    最近读了LEARNING A REPRESSION NETWORK FOR PRECISE VEHICLE SEARCH 论文,将从中所了解的信息在此记录下来. 背景及模型介绍 此论文主要是讨论车辆的 ...

  9. 三元组损失(Triplet loss)

    来源:Coursera吴恩达深度学习课程 在人脸识别中,我们希望学习"输入两张人脸图片,然后输出相似度"的函数d,然后Siamese 网络(Siamese network)实现了这 ...

最新文章

  1. 关于eclpse java项目与tomcat jdk版本不一致的解决方法
  2. linux的grup文件,Linux /boot/grub/grub.conf(GRUB配置文件)内容详解
  3. docker yum php mysql_Centos下 使用Docker, 配置PHP+Nginx+Mysql(多PHP版本)
  4. Codeforces Round #220 (Div. 2)
  5. 计算机图形学方向投稿国外期刊
  6. 37 SD配置-销售凭证设置-分配项目类别
  7. 啥?不用安装Jre,SpringBoot项目也可以打包exe应用程序运行!
  8. Django 实现第三方账号登录网站
  9. 解决krpano全景视频在QQ浏览器、安卓不能正常播放的问题
  10. 32位oracle10,『三思笔记』-- Solaris10下安装32位Oracle10g -- Solaris 10下安装ORACLE10G
  11. 数据清洗有哪些方法?
  12. Ubuntu虚拟机如何与主机复制粘贴?
  13. c语言对数组取反,C语言中按逆取反是什么意思
  14. vue 中的el表达式_解释el页面数据表达式
  15. 加拿大Introspect I3C 协议分析仪(Analyzer)及训练器(Exerciser)
  16. Vue前端模板框架--vue-admin-template
  17. Velodyne 32E pcap包GPS时间戳解析
  18. Bootstrap 下拉菜单(Dropdown)插件
  19. 了解自动驾驶 从ADAS开始
  20. 在线图片格式转换为 psd png

热门文章

  1. 【后端】10进制与进制转换以及斐波那契数列第N位的JAVA小练习
  2. ORA-38706: Cannot turn on FLASHBACK DATABASE logging.ORA-38709: Recovery Area is not enabled.
  3. php中while的用法,PHP丨PHP基础知识之流程控制WHILE循环「理论篇」
  4. Google Sync手机在线同步工具使用指南
  5. java assembly_Java技术--maven的assembly插件打包(依赖包归档)
  6. tiktok如何运营
  7. python里clear和copy_python (集合和深浅拷贝)
  8. 使用灵曜内网穿透 免费实现外网访问内网Vue
  9. 剩余时间,倒计时毫秒时间戳转换为时间格式HH:mm:ss时间差计算
  10. JavaScript 矩形碰撞检测