文章目录

  • 1. 数据集处理
  • 2. 网络与损失函数
  • 3. 代码如下:

本来是想做检测图像的相似度的,偶然见到这篇文章。于是写下了这篇博文。
本文参考于: github and PyTorch 中文网人脸相似度对比

关于Siamese网络 请查看,或者查看 。就是两个共享参数的CNN。每次的输入是一对图像+1个label,共3个值。注意label=0或1(又称正负样本),表示输入的两张图片match(匹配、同一个人)或no-match(不匹配、非同一人)。 下图是Siamese基本结构

1. 数据集处理

数据采用的是AT&T人脸数据。共40个人,每个人有10张脸。
数据集下载:https://files.cnblogs.com/files/king-lps/att_faces.zip
解压后文件夹下共40个文件夹,每个文件夹里有10张pgm图片。

2. 网络与损失函数

显然该网络前向传播是两张图同时输入进行。
损失函数公式:m为容忍度,DwD_wDw​ 为两张图之间的欧式距离。
loss=(1−Y)12Dw2+(Y)12{max(0,m−Dw)}2\rm loss = (1-Y) \frac{1}{2}D_w^2 + (Y) \frac{1}{2}\{max(0,m-D_w)\}^2 loss=(1−Y)21​Dw2​+(Y)21​{max(0,m−Dw​)}2

Dw={Gw(X1)−Gw(X2)}2\rm D_w = \sqrt{\{G_w(X_1) - G_w(X_2)\}^2}Dw​={Gw​(X1​)−Gw​(X2​)}2​

3. 代码如下:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
from torch.autograd import Variable
import os
import random
import linecache
import numpy as np
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import PIL.ImageOps
import matplotlib.pyplot as pltclass Config():root = 'E:/pytorch_AI_learning/att_faces'txt_root = 'E:/pytorch_AI_learning/att_faces/train.txt'train_batch_size = 32train_number_epochs = 30# Helper functions
def imshow(img, text=None, should_save=False):npimg = img.numpy()plt.axis("off")if text:plt.text(75, 8, text, style='italic', fontweight='bold',bbox={'facecolor': 'white', 'alpha': 0.8, 'pad': 10})plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()def show_plot(iteration, loss):plt.plot(iteration, loss)plt.show()def convert(train=True):if (train):f = open(Config.txt_root, 'w')data_path = Config.root + '/'if (not os.path.exists(data_path)):os.makedirs(data_path)for i in range(40):for j in range(10):img_path = data_path + 's' + str(i + 1) + '/' + str(j + 1) + '.pgm'f.write(img_path + ' ' + str(i) + '\n')f.close()convert(True)  # 生成train.txt文件# ready the dataset, Not use ImageFolder as the author did
class MyDataset(Dataset):def __init__(self, txt, transform=None, target_transform=None, should_invert=False):self.transform = transformself.target_transform = target_transformself.should_invert = should_invertself.txt = txtdef __getitem__(self, index):line = linecache.getline(self.txt, random.randint(1, self.__len__()))line.strip('\n')img0_list = line.split()should_get_same_class = random.randint(0, 1)if should_get_same_class:while True:img1_list = linecache.getline(self.txt, random.randint(1, self.__len__())).strip('\n').split()if img0_list[1] == img1_list[1]:breakelse:img1_list = linecache.getline(self.txt, random.randint(1, self.__len__())).strip('\n').split()img0 = Image.open(img0_list[0])img1 = Image.open(img1_list[0])img0 = img0.convert("L")img1 = img1.convert("L")if self.should_invert:img0 = PIL.ImageOps.invert(img0)img1 = PIL.ImageOps.invert(img1)if self.transform is not None:img0 = self.transform(img0)img1 = self.transform(img1)return img0, img1, torch.from_numpy(np.array([int(img1_list[1] != img0_list[1])], dtype=np.float32))def __len__(self):fh = open(self.txt, 'r')num = len(fh.readlines())fh.close()return num""" =======Visualising some of the data==========
train_data = MyDataset(txt = Config.txt_root, transform=transforms.ToTensor(),target_transform=transforms.Compose([transforms.Resize((100,100)),transforms.ToTensor()],))
train_loader = DataLoader(dataset=train_data, batch_size=8, shuffle=True)
it = iter(train_loader)
p1, p2, label = it.next()
example_batch = it.next()
concatenated = torch.cat((example_batch[0],example_batch[1]),0)
imshow(torchvision.utils.make_grid(concatenated))
print(example_batch[2].numpy())
"""# Neural Net Definition, Standard CNNs
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimclass SiameseNetwork(nn.Module):def __init__(self):super(SiameseNetwork, self).__init__()self.cnn1 = nn.Sequential(nn.ReflectionPad2d(1),nn.Conv2d(1, 4, kernel_size=3),nn.ReLU(inplace=True),nn.BatchNorm2d(4),nn.Dropout2d(p=.2),nn.ReflectionPad2d(1),nn.Conv2d(4, 8, kernel_size=3),nn.ReLU(inplace=True),nn.BatchNorm2d(8),nn.Dropout2d(p=.2),nn.ReflectionPad2d(1),nn.Conv2d(8, 8, kernel_size=3),nn.ReLU(inplace=True),nn.BatchNorm2d(8),nn.Dropout2d(p=.2),)self.fc1 = nn.Sequential(nn.Linear(8 * 100 * 100, 500),nn.ReLU(inplace=True),nn.Linear(500, 500),nn.ReLU(inplace=True),nn.Linear(500, 5))def forward_once(self, x):output = self.cnn1(x)output = output.view(output.size()[0], -1)output = self.fc1(output)return outputdef forward(self, input1, input2):output1 = self.forward_once(input1)output2 = self.forward_once(input2)return output1, output2# Custom Contrastive Loss
class ContrastiveLoss(torch.nn.Module):"""Contrastive loss function.Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf"""def __init__(self, margin=2.0):super(ContrastiveLoss, self).__init__()self.margin = margindef forward(self, output1, output2, label):euclidean_distance = F.pairwise_distance(output1, output2)loss_contrastive = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) +(label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))return loss_contrastive# Training
train_data = MyDataset(txt=Config.txt_root, transform=transforms.Compose([transforms.Resize((100, 100)), transforms.ToTensor()]), should_invert=False)
train_dataloader = DataLoader(dataset=train_data, shuffle=True, num_workers=2, batch_size=Config.train_batch_size)net = SiameseNetwork().cuda()
criterion = ContrastiveLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0005)counter = []
loss_history = []
iteration_number = 0if __name__ == '__main__':for epoch in range(0, Config.train_number_epochs):for i, data in enumerate(train_dataloader, 0):img0, img1, label = dataimg0, img1, label = Variable(img0).cuda(), Variable(img1).cuda(), Variable(label).cuda()output1, output2 = net(img0, img1)optimizer.zero_grad()loss_contrastive = criterion(output1, output2, label)loss_contrastive.backward()optimizer.step()if i % 10 == 0:print("Epoch:{},  Current loss {}\n".format(epoch, loss_contrastive.data.item()))iteration_number += 10counter.append(iteration_number)loss_history.append(loss_contrastive.data.item())show_plot(counter, loss_history)torch.save(net.state_dict(), './SiameseNet.pth')

PyTorch之—Siamese网络相关推荐

  1. pytorch空间变换网络

    pytorch空间变换网络 本文将学习如何使用称为空间变换器网络的视觉注意机制来扩充网络.可以在DeepMind paper 有关空间变换器网络的内容. 空间变换器网络是对任何空间变换的差异化关注的概 ...

  2. MPASNET:用于视频场景中无监督深度人群分割的运动先验感知SIAMESE网络

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 小白导读 论文是学术研究的精华和未来发展的明灯.小白决心每天为大家 ...

  3. siamese网络_CVPR 2019手写签名认证的逆鉴别网络

    点击我爱计算机视觉标星,更快获取CVML新技术 本文简要介绍CVPR2019论文"Inverse Discriminative Networks for Handwritten Signat ...

  4. 4.3 Siamese 网络-深度学习第四课《卷积神经网络》-Stanford吴恩达教授

    ←上一篇 ↓↑ 下一篇→ 4.2 One-Shot 学习 回到目录 4.4 Triplet 损失 Siamese 网络 (Siamese Network) 上个视频中你学到的函数 ddd 的作用就是输 ...

  5. MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)...

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...

  6. pytorch贝叶斯网络_贝叶斯神经网络:2个在TensorFlow和Pytorch中完全连接

    pytorch贝叶斯网络 贝叶斯神经网络 (Bayesian Neural Net) This chapter continues the series on Bayesian deep learni ...

  7. Siamese网络(孪生神经网络)详解

    SiameseFC Siamese网络(孪生神经网络) 本文参考文章: Siamese背景 Siamese网络解决的问题 要解决什么问题? 用了什么方法解决? 应用的场景: Siamese的创新 Si ...

  8. 智慧交通day04-特定目标车辆追踪02:Siamese网络+单样本学习

    1.Siamese网络 Siamese network就是"连体的神经网络",神经网络的"连体"是通过共享权值来实现的,如下图所示.共享权值意味着两边的网络权重 ...

  9. 深度学习笔记(43) Siamese网络

    深度学习笔记(43) Siamese网络 1. Siamese网络 2. 建立人脸识别系统 3. 训练网络 1. Siamese网络 深度学习笔记(42) 人脸识别 提到的函数ddd的作用: 输入两张 ...

最新文章

  1. 使用UPnP来穿透NAT使内网接口对外网可见
  2. 软件架构设计——解释器模式
  3. 6. A Deeper Understanding of Deep Learning
  4. wordpress常用插件打包 百度搜索推送插件+sitemap生成等
  5. winscp连接windows_winscp登陆云主机,winscp登陆云主机如何登陆,教程详情
  6. xftp无法链接Linux
  7. 【UVA129】Krypton Factor(回溯+在回溯法的基础上判断一个字符串是否有相邻的重复子串(后缀))
  8. iNavFlight之MSP DJI协议分析
  9. 计算机英语固定词组搭配,英语短语搭配,英语中穿的五种用法及搭配
  10. OpenCV 对比度增强
  11. 零基础CSS入门教程(28)–CSS导航栏实例
  12. 中科院自动化所 模式识别国家重点实验室(NLPR)
  13. android 广播的插件化
  14. Ae 入门系列之十:效果和动画预设
  15. 领导的艺术:工作里怎么样做,才是包容
  16. Error: @vitejs/plugin-vue requires vue (>=3.2.13) or @vue/compiler-sfc to be present in the dependen
  17. Raft算法详细介绍
  18. DSP/ARM+FPGA运动控制器定制 精雕机数据机床
  19. Java加密算法—对称加密(DES、AES)
  20. 基于Qt的音乐播放器(三)通过酷狗音乐的api接口,返回json格式歌曲信息(播放地址,歌词,图片)

热门文章

  1. MySQL如何删除表中一行数据
  2. mysql基础10(SQL逻辑查询语句执行顺序)
  3. linux添加压缩文件tar,在linux中使用tar创建与解压文件
  4. html5跨域通信之postMessage
  5. 【其它】Nook HD刷机
  6. 海思开发:海思上对 relu6、hswish、h-sigmoid 移植的探索
  7. java项目开发实战--使用ssm框架开发众筹网站(IDEA版)
  8. C语言顺序栈简单实现
  9. LXC的安装与配置使用
  10. 深度分析智能垃圾回收站 智慧社区|智能垃圾分类|智慧环卫