本篇论文来自2019ICML的一篇动态架构的持续学习论文,论文地址点这里

一. 介绍

在学习一系列学习任务时,DNN会经历所谓的“灾难性遗忘”问题,在接受新任务训练后,它们通常会“忘记”以前学习过的任务。在本文中,提出了一个从学习到成长的框架,它明确地将模型结构的学习与模型参数的估计分离开来。搜索结构后,然后估计模型参数。我们发现:1)明确的持续结构学习可以更有效地利用任务之间的参数,从而为不同的任务带来更好的性能和合理的结构;2) 与具有相似模型复杂度的其他基线方法相比,将结构和参数学习分离可以显著减少灾难性遗忘。

二. Learn-to-grow 的框架

2.1 持续学习中的问题定义

考虑一个序列的 N N N个任务,表示为 T = \mathbf{T}= T= ( T 1 , T 2 , … , T N ) \left(T_1, T_2, \ldots, T_N\right) (T1​,T2​,…,TN​). 每个任务 T t T_t Tt​ 都有一个训练集 D train  ( t ) = { ( x i ( t ) , y i ( t ) ) ; i = 1 , ⋯ , n t } \mathcal{D}_{\text {train }}^{(t)}=\left\{\left(x_i^{(t)}, y_i^{(t)}\right) ; i=1, \cdots, n_t\right\} Dtrain (t)​={(xi(t)​,yi(t)​);i=1,⋯,nt​}其中 n t n_t nt​表示为样本的数量。 让 D train  = ∪ t = 1 N D train  ( t ) \mathcal{D}_{\text {train }}=\cup_{t=1}^N \mathcal{D}_{\text {train }}^{(t)} Dtrain ​=∪t=1N​Dtrain (t)​ 表示为所有任务的训练集。同样,我们定义 D test  ( t ) \mathcal{D}_{\text {test }}^{(t)} Dtest (t)​为任务 T t T_t Tt​的测试集. 使用 f ( ⋅ ; Θ t ) f\left(\cdot ; \Theta_t\right) f(⋅;Θt​) 表示为需要学习的模型,其中 Θ t \Theta_t Θt​表示为正在当前学习任务 T t T_t Tt​的所有参数集合。 模型将不断地学习到任务(1到 N N N),并且学习过后的任务数据不会再出现。持续学习的主要工作是最大化模型 f ( ⋅ ; Θ t ) f\left(\cdot ; \Theta_t\right) f(⋅;Θt​) 在任务 T t T_t Tt​ 上的表现同时最小化模型在之前 学习中的遗忘。理想情况下,我们希望在持续学习环境中最小化以下目标函数:
L ( Θ N ; D train  ) = ∑ t = 1 N L t ( Θ t ; D train  ( t ) ) L t ( Θ t ; D train  ( t ) ) = 1 n t ∑ i = 1 n t ℓ t ( f ( x i ( t ) ; Θ t ) , y i ( t ) ) \begin{aligned} &\mathcal{L}\left(\Theta_N ; \mathcal{D}_{\text {train }}\right)=\sum_{t=1}^N \mathcal{L}_t\left(\Theta_t ; \mathcal{D}_{\text {train }}^{(t)}\right) \\ &\mathcal{L}_t\left(\Theta_t ; \mathcal{D}_{\text {train }}^{(t)}\right)=\frac{1}{n_t} \sum_{i=1}^{n_t} \ell_t\left(f\left(x_i^{(t)} ; \Theta_t\right), y_i^{(t)}\right) \end{aligned} ​L(ΘN​;Dtrain ​)=t=1∑N​Lt​(Θt​;Dtrain (t)​)Lt​(Θt​;Dtrain (t)​)=nt​1​i=1∑nt​​ℓt​(f(xi(t)​;Θt​),yi(t)​)​
其中 ℓ t \ell_t ℓt​表示为任务 T t T_t Tt​上的损失。由于无法看到所有的数据集,因此上述的目标函数无法直接计算和优化。主要的挑战是保持 ∑ t ′ = 1 t − 1 L t ′ ( Θ t ′ ; D train  ( t ′ ) ) \sum_{t^{\prime}=1}^{t-1} \mathcal{L}_{t^{\prime}}\left(\Theta_{t^{\prime}} ; \mathcal{D}_{\text {train }}^{\left(t^{\prime}\right)}\right) ∑t′=1t−1​Lt′​(Θt′​;Dtrain (t′)​)不要改变太多。
下图展示的是目前的一些方法,左边的是基于正则化的方法,中间表示的是会给每个任务增加一些新的额外参数,右边的是本文的方法。

2.2 方法架构

在本文的方法中,参数是不断地扩展的: Θ t = Θ t − 1 ∪ θ t \Theta_t=\Theta_{t-1} \cup \theta_t Θt​=Θt−1​∪θt​,但是扩展的参数不是固定的,而是根据新任务来决定是否重新利用旧的参数。使用 s t ( Θ t ) s_t(\Theta_t) st​(Θt​)表示任务 T t T_t Tt​任务特定模型,那么损失函数可以修改为:
L t ( s t ( Θ t ) ) = 1 n t ∑ i = 1 n t ℓ t ( f ( x i ( t ) ; s t ( Θ t ) ) , y i ( t ) ) \mathcal{L}_t\left(s_t\left(\Theta_t\right)\right)=\frac{1}{n_t} \sum_{i=1}^{n_t} \ell_t\left(f\left(x_i^{(t)} ; s_t\left(\Theta_t\right)\right), y_i^{(t)}\right) Lt​(st​(Θt​))=nt​1​i=1∑nt​​ℓt​(f(xi(t)​;st​(Θt​)),yi(t)​)
现在,在学习所有任务时,明确考虑了结构。优化上述式子中更新的损失函数时,需要根据结构 s t s_t st​确定最佳参数。这种损失可以从两个方面看。可以将其解释为从“超级网络”中选择特定于任务的网络,该网络具有使用 s t s_t st​的参数 Θ \Theta Θ,或者对于每个任务,我们使用参数 s t ( Θ t ) s_t(\Theta_t) st​(Θt​)训练一个新模型。这两种观点之间有细微的差别。前者对总模型尺寸有限制,而后者则没有。因此,在后者的最坏情况下,随着任务数量的增加,模型大小将线性增长。这将导致一个问题——为不同的任务训练完全不同的模型,不再是持续学习!为了解决这个问题,我们提出以下惩罚损失函数:
L t ( s t ( Θ t ) ) = 1 n t ∑ i = 1 n t ℓ t ( f ( x i ( t ) ; s t ( Θ t ) ) , y i ( t ) ) + β t R t s ( s t ) + λ t R t p ( Θ t ) \begin{aligned} \mathcal{L}_t\left(s_t\left(\Theta_t\right)\right)=& \frac{1}{n_t} \sum_{i=1}^{n_t} \ell_t\left(f\left(x_i^{(t)} ; s_t\left(\Theta_t\right)\right), y_i^{(t)}\right)+\\ & \beta_t \mathcal{R}_t^s\left(s_t\right)+\lambda_t \mathcal{R}_t^p\left(\Theta_t\right) \end{aligned} Lt​(st​(Θt​))=​nt​1​i=1∑nt​​ℓt​(f(xi(t)​;st​(Θt​)),yi(t)​)+βt​Rts​(st​)+λt​Rtp​(Θt​)​
其中 β t > 0 , λ t ≥ 0 \beta_t>0, \lambda_t \geq 0 βt​>0,λt​≥0为超参数, R t s \mathcal{R}_t^s Rts​和 R t p \mathcal{R}_t^p Rtp​分别表示网络结构和模型参数的正则化器。例如,在优化模型参数时,可以对 R t p \mathcal{R}_t^p Rtp​使用 ℓ 2 \ell_2 ℓ2​正则化。和 R t s \mathcal{R}_t^s Rts​可以像 ( log ⁡ ) (\log ) (log)参数个数一样简单。这样,参数总数从上面有界,从而避免了退化情况。

三. 方法实现

(本文的方法用到了神经网络的结构搜索DARTs,并且在文中说明的时候很模糊,如果你对DARTs不是很了解的话,先去看一看相关视频,这里推荐一个视频,点这里,讲的很好并且在简介有详细的论文地址和代码地址!!!!!!)

3.1 结构优化

考虑一个网络有 L L L层共享的层以及一层任务特定层(最后一层),一个超网络 S \mathcal{S} S包括所有的任务特定层和不断增加的共享层。
我们的目标是在每一层中找到最优的解,每一层的候选项有三种:“reuse”, “adaptation"以及"new”。“reuse"表示直接使用之前的任务相关参数,“adaptation"表示为添加一小部分的额外参数,“new"则是使用一个完全新的参数(和这层一样)。我们定义在第 l l l层的超网络的大小为 ∣ S l ∣ |\mathcal{S}^l| ∣Sl∣,那么当前层的搜索空间 C l C_l Cl​为 2 ∣ S l + 1 ∣ 2|\mathcal{S}^l+1| 2∣Sl+1∣,这是因为共有 ∣ S l ∣ |\mathcal{S}^l| ∣Sl∣个"reuse”, ∣ S l ∣ |\mathcal{S}^l| ∣Sl∣个"adapataion"以及1个“new”。因此,整个的搜索空间为 ∏ l L C l \prod_l^L C_l ∏lL​Cl​。这里的一个潜在问题是,在最坏的情况下,搜索空间可能会随着任务的数量呈指数级增长。解决这个问题的一种方法是限制可能选择的总数,并维护一个优先级队列来学习选项。在我们所有的实验中,我们都不认为这是必要的。
和DARTS相似,需要将这些选择变为连续的,我们为每一个选择添加一个参数 α \alpha α,接着使用Softmax来进行计算: x l + 1 = ∑ c = 1 C l exp ⁡ ( α c l ) ∑ c ′ = 1 C l exp ⁡ ( α c ′ l ) g c l ( x l ) x_{l+1}=\sum_{c=1}^{C_l} \frac{\exp \left(\alpha_c^l\right)}{\sum_{c^{\prime}=1}^{C_l} \exp \left(\alpha_{c^{\prime}}^l\right)} g_c^l\left(x_l\right) xl+1​=∑c=1Cl​​∑c′=1Cl​​exp(αc′l​)exp(αcl​)​gcl​(xl​),这里 g c l g^l_c gcl​表示为在第 l l l层的第 c c c个选项:
g c l ( x l ) = { S c l ( x l ) if  c ≤ ∣ S l ∣ S c l ( x l ) + γ c − ∣ S l ∣ l ( x l ) if  ∣ S l ∣ < c ≤ 2 ∣ S l ∣ , o l ( x l ) if  c = 2 ∣ S l ∣ + 1 g_c^l\left(x_l\right)= \begin{cases}S_c^l\left(x_l\right) & \text { if } c \leq\left|\mathcal{S}^l\right| \\ S_c^l\left(x_l\right)+\gamma_{c-\left|\mathcal{S}^l\right|}^l\left(x_l\right) & \text { if }\left|\mathcal{S}^l\right|<c \leq 2\left|\mathcal{S}^l\right|, \\ o^l\left(x_l\right) & \text { if } c=2\left|\mathcal{S}^l\right|+1\end{cases} gcl​(xl​)=⎩ ⎨ ⎧​Scl​(xl​)Scl​(xl​)+γc−∣Sl∣l​(xl​)ol(xl​)​ if c≤∣ ∣​Sl∣ ∣​ if ∣ ∣​Sl∣ ∣​<c≤2∣ ∣​Sl∣ ∣​, if c=2∣ ∣​Sl∣ ∣​+1​
这里 γ \gamma γ表示为“adaption”操作, o o o表示为"new"操作,经过不断计算后就可以提取一个权重 α = { α l } \alpha=\{\alpha^l\} α={αl}。经过搜寻后,我们选择最大 α c l \alpha^l_c αcl​的对应操作, c l = arg ⁡ max ⁡ α l c_l=\arg \max \alpha^l cl​=argmaxαl。
这里采用DARTS的训练策略,将训练数据集 D train ( t ) \mathcal{D}_{\text {train}}^{(t)} Dtrain(t)​拆分为两个子集:一个用于NAS的验证子集和一个用于参数估计的训练子集。我们使用验证损失 L v a l L_{v al} Lval​来更新体系结构权值 α \alpha α,而参数则由训练损失 L train L_{\text {train}} Ltrain​来估计。在搜索过程中,体系结构权重和参数会交替更新。因为它是一个嵌套的双层优化问题,原始的DARTS提供了一个二阶逼近,以实现更精确的优化。
最后,稍微介绍一下“reuse”, “adaptation”,“new”操作的具体过程。假设在CNN模型中,我们选择3x3的卷积核,对于"reuse”,直接使用之前的参数进行训练,对于"adaptation”,我们在层上方添加一个并行的1x1的卷积,而"new"则直接初始化新的参数。
(本文没有找到代码,之后找到我会进行分析的)

Learn to Grow: A Continual Structure Learning Framework for Overcoming Catastrophic Forgetting论文阅读相关推荐

  1. A Novel Two-stage Separable Deep Learning Framework for Practical Blind Watermarking论文阅读

    A Novel Two-stage Separable Deep Learning Framework for Practical Blind Watermarking Abstract 数字水印是一 ...

  2. [论文分享] Overcoming Catastrophic Forgetting in Incremental Few-Shot Learning by Finding Flat Minima

    我又来给大家分享PAPER了!!! 今天给大家分享的这篇论文是NIPS' 2021的一篇Few-Shot增量学习(FSCIL)文章,这篇文章通过固定backbone和prototype得到一个简单的b ...

  3. Capture, Learning, and Synthesisof 3D Speaking styles论文阅读笔记 VOCA

    Capture, Learning, and Synthesisof 3D Speaking Styles论文阅读笔记 摘要 制作了一个4D面部(3D mesh 序列 + 同步语音)数据集:29分钟, ...

  4. 《Continual lifelong learning with neural networks : A review》阅读笔记

    增量学习更多的起源于认知科学神经科学对记忆和遗忘机制的研究,Continual lifelong learning with neural networks : A review综述探讨了增量学习在生 ...

  5. Learning Multiview 3D point Cloud Registration论文阅读笔记

    Learning multiview 3D point cloud registration Abstract 提出了一种全新的,端到端的,可学习的多视角三维点云配准算法. 多视角配准往往需要两个阶段 ...

  6. 【DiMP】Learning Discriminative Model Prediction for Tracking论文阅读

    Learning Discriminative Model Prediction for Tracking 论文地址 写在前面 又是MD大神的一个作品,发现MD大神也把Siamese的框架用起来了,而 ...

  7. 一种用于加密流分类的多模态深度学习框架A Novel Multimodal Deep Learning Framework for Encrypted Traffic Classification论文

    一.背景 l 网络应用程序流量被加密 l 基于传统有效载荷交通分类方法和基于端口的流量分类方法不在有效 l 已有的模型不能用于更细粒度的操作 二.pean介绍 概括 PEAN模型是一种软件架构模式,它 ...

  8. XDL: An Industrial Deep Learning Framework for High-dimensional Sparse Data 论文笔记

    本文的github地址: https://github.com/alibaba/x-deeplearning X-Deep Learning(简称XDL)于2018年12月由阿里巴巴开源,是面向高维稀 ...

  9. Deep Meta Learning for Real-Time Target-Aware Visual Tracking 论文阅读

    这篇文章是韩国的一个组做的,一直没中, 直到19年中了ICCV,据说是第一篇将元学习引入目标跟踪的文章,用的架构是siamese网络的架构,但是在模型在线更新的时候使用了meta-learning的思 ...

最新文章

  1. Sharding-JDBC教程:Spring Boot整合Sharding-JDBC实现数据分表+读写分离
  2. 你想要的生物信息知识全在这——生信宝典文章目录
  3. HTML图片热区map area的用法(转)
  4. angular http demo
  5. 填充磁盘空间的工具和方法
  6. 这不是商业互吹,是学习的宝藏
  7. VBA 用 Environ 获取系统环境变量
  8. HashMap死锁原因及替代方案
  9. oracle group by sql,Oracle SQL GROUP BY“不是GROUP BY表达式”的帮助
  10. HTMLParser使用
  11. 《02333软件工程课后习题答案-2011版 王立福》
  12. matlab建立遗传算法,Matlab遗传算法(一)
  13. 投影仪与计算机连接方式,【投影网教程】投影仪连接电脑的方法
  14. 教育网Linux下赶mule
  15. 中级微观经济学:Chap 9 购买和销售
  16. JPBC库应用之BLS签名
  17. HTTP,TCP,UDP常见端口对照表大全
  18. 提高生活、学习、工作效率的方法——时间管理Vs个人管理
  19. Winform从入门到精通(17)——PictureBox(史上最全)
  20. ccf论文分级_论文等级如何划分

热门文章

  1. 华为od统一考试B卷【敏感字段加密】JavaScript 实现
  2. 【深度学习】资源:最全的 Pytorch 资源大全
  3. 商用密码产品(密码模块)-密码模块接口
  4. 回溯法 之 马周游(马跳日)问题
  5. java建立一个小小留言板
  6. 微信砍价源码 php_微信砍价助力系统 v1.0
  7. Java中一种计算Hash值的算法——SHA-256
  8. mysql 中文_让MySQL支持中文
  9. 阿里云平台与MQTTX软件通信
  10. Google Earth 常用操作