目录

  • 1 前言
  • 2 原理
  • 3 coding
    • 3.1 加载数据集
    • 3.2 搭建网络
    • 3.3 训练及测试
    • 3.4 对比
    • 3.5 可视化特征分布
  • 4 总结
  • 附录

1 前言

  • 好几天前我就想做关于域自适应的文章。可能你觉得我不务正业【手动狗头】,但是别急,这个还真挺有意思。还能加深自己对网络的理解,所以这几天,我一直在肝这方面的内容。
  • 好的,什么叫做域自适应?举个例子哈,有些人微信头像喜欢换成动漫的自己,当然了,作为好朋友,你肯定认得出动漫的他。这是因为我们视觉系统比较强大,偷偷做了域自适应。梦回正文,单纯的神经网络却认不出来。你训练网络时使用的数据全是现实生活中的朋友照片,那网络就只能认出现实生活中的他,因为现实生活的照片服从同一种分布。
  • 这时你发现,训练好的网络识别现实生活中的人脸杠杠的,同一个人的动漫人脸却无能为力了,换句话,准确率很低。域自适应就是来解决这样一种问题,即源域(训练集)数据与目标域(测试集)数据分布不一致。尽管如此,网络还是能正确分类。
  • 看到这里,是不是觉得还挺有意思的。关键这还挺有应用价值的,要不然训练出个人工智障,那不要贻笑大方。

2 原理

这里放张图,马上我就要实现它(激动.jpg)。

上图是个什么鬼东西(刚接触的领域,折磨我千百遍,以示尊敬),我也懒得解释它,这里采用知乎大佬的反击。

背景

想必大家对GAN都不陌生,GAN是基于对抗的生成网络,主要目标是生成与训练集分布一致的数据。而在迁移学习领域,对抗也是一种常用的方式,如Ganin[1]的论文,使用的网络结构如上图,由三部分组成:特征映射网络 Gf(x;θf)G_f(x;\theta_f)Gf(x;θf)标签分类网络 Gf(z;θy)G_f(z;\theta_y)Gf(z;θy)和域判别网络 Gd(z;θd)G_d(z;\theta_d)Gd(z;θd)

其中,source domain的数据是有标签的,target domain的数据是无标签的。GfG_fGf 将source和target domain的数据都映射到一个特征空间 ZZZ上, GyG_yGy 预测标签 yyyGdG_dGd 预测数据来自于target还是source domain。所以流入 GyG_yGy 的是带标签的source数据,流入 GdG_dGd 的是不带标签的source和target的数据。

GfG_fGf : 将数据映射到feature space,使 GyG_yGy 能分辨出source domain数据的label,GdG_dGd 分辨不出数据来自source domain还是target domain。

GyG_yGy : 对feature space的source domain数据进行分类,尽可能分出正确的label。

GdG_dGd : 对feature space的数据进行领域分类,尽量分辨出数据来自于哪一个domain。

最终,希望 GfG_fGfGdG_dGd 博弈的结果是source和target domain的数据在feature space上分布已经很一致, GdG_dGd 无法区分。于是,可以愉快的用 GyG_yGy 来分类target domain的数据啦。

  • 看懂的扣1,没看懂的扣 limx→0xsinxlim_{x\rightarrow 0} \frac{x}{sinx}limx0sinxx 的值。
  • 没关系,看不懂看下面的代码----------(省略一万字,文章效果嘛)更看不懂【狗头保命】。

3 coding

  • OK,今天的重头戏终于来到。接下来我会讲一下实验步骤:
  1. 加载mnist与m-mnist数据集(如下示意图)
  2. 搭建网络(难点)
  3. 训练,并测试模型对于这两个不同分布数据集的准确率
  4. 测试在没有使用domain adaptation(域自适应)的情况下,普通网络对这两种数据集的准确率
  5. 使用域自适应与没有使用,这两者在feature层输出的分布可视化区别

3.1 加载数据集

  • 由于m-mnist数据集并没有封装好,需要封装成类。这里不是重点,只是为了方便加载数据。
  • mnist(源域)数据集可以网络下载不用管,m-mnist(目标域)的数据集这里下载:download
  • 下载到本地后解压,将其放在 ./data 文件夹下。路径也能自己更改,看下面代码。
  • 这一步跑成功了,后面不是问题,bro。
import torch
import torchvision.transforms as tvtf
import torchvision as tv
import os
import time
import torch.optim as optim
import torch.nn as nn
import numpy as np
from torch.autograd import Function
import warnings
import torch.utils.data as data
from PIL import Image
import os# 当然,你可以将其以py文件的方式导入,这样会清爽很多
class MNISTMDataset(data.Dataset):def __init__(self, data_root, data_list, transform=None):self.root = data_rootself.transform = transformf = open(data_list, 'r')data_list = f.readlines()f.close()self.n_data = len(data_list)self.img_paths = []self.img_labels = []for data in data_list:self.img_paths.append(data[:-3])self.img_labels.append(data[-2])def __getitem__(self, item):img_paths, labels = self.img_paths[item], self.img_labels[item]imgs = Image.open(os.path.join(self.root, img_paths)).convert('RGB')if self.transform is not None:imgs = self.transform(imgs)labels = int(labels)return imgs, labelsdef __len__(self):return self.n_data# 如上只是加载数据集,无需太过重视
warnings.filterwarnings('ignore')
image_size = 28
batch_size = 10
lr = 1e-3
n_epochs = 1
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')tf_source = tvtf.Compose([tvtf.Resize(image_size),tvtf.ToTensor(),tvtf.Normalize(mean=(0.1307,), std=(0.3081,))
])
tf_target = tvtf.Compose([tvtf.Resize(image_size),tvtf.ToTensor(),tvtf.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
start = time.time()
ds_source = tv.datasets.MNIST(root='./data', train=True, transform=tf_source, download=True)
dl_source = torch.utils.data.DataLoader(ds_source, batch_size)ds_target = MNISTMDataset(os.path.join('./data', 'mnist_m', 'mnist_m_train'),os.path.join('./data', 'mnist_m', 'mnist_m_train_labels.txt'),transform=tf_target)
# 使用dataloader加载器,边训练边加载图片,避免内存不够。较为灵活
dl_target = torch.utils.data.DataLoader(ds_target, batch_size)
end = time.time()
print ('loading about {} seconds'.format(end-start))

3.2 搭建网络

# Autograd Function objects are what record operation history on tensors,
# and define formulas for the forward and backprop.
class GradientReversalFn(Function):@staticmethoddef forward(ctx, x, alpha):# Store context for backpropctx.alpha = alpha# Forward pass is a no-opreturn x.view_as(x)@staticmethoddef backward(ctx, grad_output):# Backward pass is just to -alpha the gradientoutput = grad_output.neg() * ctx.alpha# Must return same number as inputs to forward()return output, Noneclass DACNN(nn.Module):def __init__(self):super().__init__()self.feature_extractor = nn.Sequential(nn.Conv2d(3, 64, kernel_size=5),nn.BatchNorm2d(64), nn.MaxPool2d(2),nn.ReLU(True),nn.Conv2d(64, 50, kernel_size=5),nn.BatchNorm2d(50), nn.Dropout2d(), nn.MaxPool2d(2),nn.ReLU(True),)self.class_classifier = nn.Sequential(nn.Linear(50 * 4 * 4, 100), nn.BatchNorm1d(100), nn.Dropout2d(),nn.ReLU(True),nn.Linear(100, 100), nn.BatchNorm1d(100),nn.ReLU(True),nn.Linear(100, 10),nn.LogSoftmax(dim=1),)self.domain_classifier = nn.Sequential(nn.Linear(50 * 4 * 4, 100), nn.BatchNorm1d(100),nn.ReLU(True),nn.Linear(100, 2),nn.LogSoftmax(dim=1),)def forward(self, x, grl_lambda=1.0):# Handle single-channel input by expanding (repeating) the singleton dimentionx = x.expand(x.data.shape[0], 3, image_size, image_size)features = self.feature_extractor(x)features = features.view(-1, 50 * 4 * 4)reverse_features = GradientReversalFn.apply(features, grl_lambda)class_pred = self.class_classifier(features)domain_pred = self.domain_classifier(reverse_features)return class_pred, domain_predmodel = DACNN().to(device)
  • 注意到 GradientReversalFn 类,前向传播它不起任何作用,反向传播它会把来自domain classifier梯度值改成相反数。由于网络求导使用链式法则,GradientReversalFn 类之前的网络梯度,全部受影响,变成对应的相反数。
  • 滑天下之大稽。这样做不就与想要正确分类domain的目标相反了吗?非也,论文也就是这里体现了GAN的思想。作者希望网络 GfG_fGfGdG_dGd 对抗起来。最终 GdG_dGd 不能分类出数据来自哪个域,也就足以说明不管是源域数据还是目标域数据经过网络到达feature f ,它们的分布会尽可能一致。(如果分布不一致的话,GdG_dGd 就有能力分辨出来了阿)

3.3 训练及测试

训练

# Setup optimizer as usual
optimizer = optim.Adam(model.parameters(), lr)# Two losses functions this time
loss_fn_class = torch.nn.NLLLoss()
loss_fn_domain = torch.nn.NLLLoss()dl_source = torch.utils.data.DataLoader(ds_source, batch_size)
dl_target = torch.utils.data.DataLoader(ds_target, batch_size)# We'll train the same number of batches from both datasets
# max_batches = min(len(dl_source), len(dl_target))
max_batches = 5000for epoch_idx in range(n_epochs):print(f'Epoch{epoch_idx+1:04d}/{n_epochs:04d}', end='\n=================\n')dl_source_iter = iter(dl_source)dl_target_iter = iter(dl_target)for batch_idx in range(max_batches):optimizer.zero_grad()# Training progress and GRL lambdap = float(batch_idx + epoch_idx * max_batches) / (n_epochs * max_batches)grl_lambda = 2. / (1. + np.exp(-10 * p)) - 1# Train on source domainX_s, y_s = next(dl_source_iter)y_s_domain = torch.zeros(batch_size, dtype=torch.long) # generate source domain labelsclass_pred, domain_pred = model(X_s.to(device), grl_lambda)loss_s_label = loss_fn_class(class_pred, y_s.to(device))# a = 1loss_s_domain = loss_fn_domain(domain_pred, y_s_domain.to(device))# Train on target domainX_t, _ = next(dl_target_iter) # ignore target domain class labels!y_t_domain = torch.ones(batch_size, dtype=torch.long) # generate target domain labels_, domain_pred = model(X_t.to(device), grl_lambda)loss_t_domain = loss_fn_domain(domain_pred, y_t_domain.to(device))loss = loss_t_domain + loss_s_domain + loss_s_labelloss.backward()optimizer.step()if batch_idx % 1000 == 0:print(f'[{batch_idx+1}/{max_batches}] 'f'class_loss:{loss_s_label.item():.4f}' f's_domain_loss:{loss_s_domain.item():.4f}'f't_domain_loss:{loss_t_domain.item():.4f}' f'grl_lambda:{grl_lambda:.3f}')
  • mnist(源域)数据经过特征映射网络 GfG_fGf 分别进入域判别网络GdG_dGd 与标签分类网络 GyG_yGy ,m-mnist(目标域)数据经过特征映射网络 GfG_fGf 只进入域判别网络 GdG_dGd ,使用交叉熵依次计算损失值(这里有三个损失值)。对损失值之和反向传播,不断优化网络参数即可。

测试

ds_test_source = tv.datasets.MNIST(root='./data', train=False, transform=tf_source, download=True)
dl_test_source = torch.utils.data.DataLoader(ds_source, batch_size)ds_test_target = MNISTMDataset(os.path.join('./data', 'mnist_m', 'mnist_m_test'),os.path.join('./data', 'mnist_m', 'mnist_m_test_labels.txt'),transform=tf_target)
# 使用dataloader加载器,边训练边加载图片,避免内存不够。较为灵活
dl_test_target = torch.utils.data.DataLoader(ds_target, batch_size)def test(model,test_loader):total_cnt = 0length = 0.0for i,(data,target) in enumerate(test_loader):class_pred, domain_pred = model(data.to(device))pred = class_pred.max(1)[1]cnt = torch.sum(pred==target.to(device))total_cnt = total_cnt + cntlength = iif i == 3000/batch_size:breakprint ('total correct is {} and total data is 3000'.format(total_cnt))print ('corrent rate is {}'.format(total_cnt/(length*batch_size)))
test(model,dl_test_source)
test(model,dl_test_target)
  • 如下结果,源域数据准确率90%,目标域62.5%。还行吧,毕竟只跑了一个epoch。

Output:

3.4 对比

  • 如上测试的是在使用了无监督的domain adaptation策略的结果,下面我会使用普通的方式训练网络,以观察使用策略之后,是否真的提高了识别目标域的准确度。
class SCNN(nn.Module):def __init__(self):super().__init__()self.feature_extractor = nn.Sequential(nn.Conv2d(3, 64, kernel_size=5),nn.BatchNorm2d(64), nn.MaxPool2d(2),nn.ReLU(True),nn.Conv2d(64, 50, kernel_size=5),nn.BatchNorm2d(50), nn.Dropout2d(), nn.MaxPool2d(2),nn.ReLU(True),)self.class_classifier = nn.Sequential(nn.Linear(50 * 4 * 4, 100), nn.BatchNorm1d(100), nn.Dropout2d(),nn.ReLU(True),nn.Linear(100, 100), nn.BatchNorm1d(100),nn.ReLU(True),nn.Linear(100, 10),nn.LogSoftmax(dim=1),)def forward(self, x, grl_lambda=1.0):# Handle single-channel input by expanding (repeating) the singleton dimentionx = x.expand(x.data.shape[0], 3, image_size, image_size)features = self.feature_extractor(x)features = features.view(-1, 50 * 4 * 4)class_pred = self.class_classifier(features)return class_predsimple_model = SCNN().to(device)
optimizer = optim.Adam(simple_model.parameters(), lr)
loss_fn_class = torch.nn.NLLLoss()def train(model,dl_source):for i,(data,target) in enumerate(dl_source):class_pred = model(data.to(device))loss_s_label = loss_fn_class(class_pred, target.to(device))optimizer.zero_grad()loss_s_label.backward()optimizer.step()if i % 1000 == 0:print ('arrive {}-th image data and the loss is {}'.format(i*batch_size,loss_s_label))
train(simple_model,dl_source)def test(model,test_loader):total_cnt = 0length = 0.0for i,(data,target) in enumerate(test_loader):class_pred = model(data.to(device))pred = class_pred.max(1)[1]cnt = torch.sum(pred==target.to(device))total_cnt = total_cnt + cntlength = iif i == 3000/batch_size:breakprint ('total correct is {} and total data is 3000'.format(total_cnt))print ('corrent rate is {}'.format(total_cnt/(length*batch_size)))test(simple_model,dl_test_source)
test(simple_model,dl_test_target)
  • 不使用策略,源域的识别是93%的准确率,目标域是43%的准确率。
  • 对比结果:使用了策略之后,目标域提升了将近20的百分点,源域没有明显变化。

Output:

3.5 可视化特征分布

  • 这一部分其实很有意思。
  • 特征映射网络 GfG_fGf 输出feature f 是个高维的向量,我们怎么才能可视化为图上的一个点呢?换个思路,使用 t-SNE 将其降维成2维或者3维的向量,这样不就能在图上可视化了吗。理论可行,开始时间。
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import matplotlib as cmdl_source_train_1000 = torch.utils.data.DataLoader(ds_source, 1000)
dl_target_train_1000 = torch.utils.data.DataLoader(ds_target, 1000)
source_train_1000 , _ = next(iter(dl_source_train_1000))
target_train_1000 , _ =next(iter(dl_target_train_1000))
def plot_distribution(source_train_1000,target_train_1000):# Create a two dimensional t-SNE projection of the embeddingstsne = TSNE(2, verbose=1)if source_train_1000.is_cuda:source_train_1000 = source_train_1000.cpu().detach().numpy()else:source_train_1000 = source_train_1000.detach().numpy()if target_train_1000.is_cuda:target_train_1000 = target_train_1000.cpu().detach().numpy()else:target_train_1000 = target_train_1000.detach().numpy()tsne_source = tsne.fit_transform(source_train_1000.reshape(1000,-1))tsne_target = tsne.fit_transform(target_train_1000.reshape(1000,-1))# Plot those points as a scatter plot and label them based on the pred labelsfig, ax = plt.subplots(figsize=(8,8))num_categories = 2for lab in range(num_categories):if lab ==0:ax.scatter(tsne_source[:,0],tsne_source[:,1], c='red', label = 'source' ,alpha=0.5)else:ax.scatter(tsne_target[:,0],tsne_target[:,1], c='blue', label = 'target' ,alpha=0.5)ax.legend(fontsize='large', markerscale=2)plt.show()
plot_distribution(source_train_1000,target_train_1000)
  • 如下图是mnist与m-mnist数据集未经过处理的分布。明显,两者之间不具备规律性,属于不同分布

Output:

from mpl_toolkits import mplot3d
source_train_1000 = source_train_1000.reshape(-1,1,28,28)
target_train_1000 = target_train_1000.reshape(-1,3,28,28)
# batch_size 1000设置过大,gpu内存分配不够,故将其转为cpu状态运行
model = model.to(torch.device('cpu'))
feature_source_train_1000 = model.feature_extractor(source_train_1000.expand(1000, 3, 28, 28))
feature_target_train_1000 = model.feature_extractor(target_train_1000.expand(1000, 3, 28, 28))
# Create a two dimensional t-SNE projection of the embeddings
plot_distribution(feature_source_train_1000,feature_target_train_1000)
  • 爱了爱了。如下,你敢说源域与目标域的分布不一致?起码很相似了。这也就是我们的 特征映射网络 GfG_fGf 做的事情。

Output:

4 总结

  • 如果你和我一样是个小菜鸡,我相信你能从该篇文章中找到研究神经网络的新方法,即从分布上理解网络。
  • 呼应前文不务正业。我觉得多看看其他领域文章能开阔思路,也能提升自己对网络的理解。后续我还会开始强化学习,图神经网络以及NLP领域的学习。说起来,我产生这样的想法皆因,看论文的过程中,发现别人很多想法与思路大多借鉴其他领域。我也不知道这样好不好,但是现在才研一,未来再做回复。
  • 不过,需牢记一点,踏踏实实。一篇好论文还是需要认真研读+代码复现,体会别人解决问题的方法,如果这个问题交给你,你会怎么做。这就是主动学习与被动学习的区别。

附录

参考资料

https://towardsdatascience.com/visualizing-feature-vectors-embeddings-using-pca-and-t-sne-ef157cea3a42

https://zhuanlan.zhihu.com/p/50710267

https://nbviewer.jupyter.org/github/vistalab-technion/cs236605-tutorials/blob/master/tutorial6/tutorial6-TL_DA.ipynb

yeah,I‘m a real man

域自适应实战coding相关推荐

  1. 【深度域自适应】二、利用DANN实现MNIST和MNIST-M数据集迁移训练

    前言 在前一篇文章[深度域自适应]一.DANN与梯度反转层(GRL)详解中,我们主要讲解了DANN的网络架构与梯度反转层(GRL)的基本原理,接下来这篇文章中我们将主要复现DANN论文Unsuperv ...

  2. 目标检测的渐进域自适应,优于最新SOTA方法

    作者 | Han-Kai Hsu.Chun-Han Yao.Yi-Hsuan Tsai.Wei-Chih Hung.Hung-Yu Tseng.Maneesh Singh.Ming-Hsuan Yan ...

  3. 超越最新无监督域自适应方法,研究人员提轻量CNN新架构OSNet

    作者 | Kaiyang Zhou, Xiatian Zhu, Yongxin Yang, Andrea Cavallaro, and Tao Xiang 译者 | TroyChang 编辑 | Ja ...

  4. MS-DAYOLO来了!多尺度域自适应的YOLO,恶劣天气也看得见!

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 Multiscale Domain Adaptive YOLO for Cross-Domain Ob ...

  5. 近期必读的9篇CVPR 2019【域自适应(Domain Adaptation)】相关论文和代码

    [导读]最近小编推出CVPR2019图卷积网络.CVPR2019生成对抗网络.[可解释性],CVPR视觉目标跟踪,CVPR视觉问答,医学图像分割,图神经网络的推荐相关论文,反响热烈.最近,Domain ...

  6. CVPR 2022 | 利用域自适应思想,北大、字节跳动提出新型弱监督物体定位框架

    ©作者 | 朱磊 来源 | 机器之心 将弱监督物体定位看作图像与像素特征域间的域自适应任务,北大.字节跳动提出新框架显著增强基于图像级标签的弱监督图像定位性能. 物体定位作为计算机视觉的基本问题,可以 ...

  7. 迁移学习之域自适应理论简介(Domain Adaptation Theory)

    ©作者 | 江俊广 单位 | 清华大学 研究方向 | 迁移学习 本文主要介绍域自适应(Domain Adaptation)最基本的学习理论,全文不涉及理论的证明,主要是对部分理论的发展脉络的梳理,以及 ...

  8. 从近年CVPR看域自适应立体匹配

    ©PaperWeekly 原创 · 作者|张承灏 单位|中科院自动化所硕士生 研究方向|深度估计 深度立体匹配(deep stereo matching)算法能够取得较好的性能,一是来源于卷积神经网络 ...

  9. PHP-RSA加密跨域通讯实战

    PHP-RSA加密跨域通讯实战 AUTH:PHILO EMAIL:lijianying12 at gmail.com 基于POST GET 的http通讯虽然非常成熟,但是很容易被人监听. 并且如果使 ...

最新文章

  1. android+3e错误,Android 错误
  2. 大脑进化追不上社会文化:化石和脱氧核糖核酸证明人类大脑进化比社会慢
  3. 结构体 CString QString 成员赋值出错
  4. ccd相机好修吗_「CCD购买指南 」CCD废片大公开
  5. 客服会话 小程序 如何发起_小程序、公众号、App三者如何融合布局?这里有一份避坑指南...
  6. 从使用传统Web框架到切换到Spring Boot后的总结
  7. ESS控制台发布新功能:创建多实例规格的伸缩配置
  8. 关于Apt注解实践与总结【包含20篇博客】
  9. linux初级之总结复习
  10. 设计模式(10)——迭代器模式
  11. jenkins2 pipeline高级
  12. 使用transferTo方法转换MultipartFile(处理NoSuchFileException异常)
  13. Oracle、mysql产品性能优化总结
  14. PHP 把ofd格式文件转PDF,打开OFD格式文件及将OFD格式文件转换成PDF文件
  15. 如何实现微信二维码支付功能???
  16. java.lang.IllegalArgumentException: Index for header ‘XXX‘ is 1 but CSVRecord only has 1 value
  17. 计算机病毒手动查杀,电脑中毒了怎么办 如何手动彻底查杀病毒【解决方法】...
  18. Springboot 精髓
  19. 奥塔在线:Centos下使用RPM方式安装JDK1.8
  20. 微信H5纯签约 提示“发起签约页面非法”

热门文章

  1. String数组的使用
  2. /dev/ttyUSB0 permission denied 解决方案
  3. 5G 移动通信的硬件验证平台 高频段传输 新型多天线传输 同时同频全双工TDD
  4. vtk教程第一章介绍
  5. PHP+MySQL+LayUI分页查询显示
  6. 燃气蒸汽发生器助力酿酒企业更好把控啤酒加工温度
  7. nginx装逼语录待补充
  8. 我们和他们,究竟谁是傻X? | 华尔街黑历史(一)
  9. 第八章 注意力机制与外部记忆
  10. 495. 提莫攻击 有一个叫 “提莫” 的英雄,他的攻击可以让敌方英雄艾希(编者注:寒冰射手)进入中毒状态