介绍

本文提出了一种注意力层+强化学习的训练模型,以解决TSP、VRP、OP、PCTSP等路径问题。文章致力于使用相同的超参数,解决多种路径问题。文中采用了贪心算法作为基线,相较于值函数效果更好。

注意力模型

文中定义了Attention Model以解决TSP问题,针对其它问题,不需要改变模型,只需要修改输入、掩码、解码上下文等参量。模型采用编码-解码结构,编码器生成所有输入节点的嵌入,解码器依次生成输入节点的序列π。以下都以TSP问题举例:

编码器

本文中的编码器部分与Transformer架构中的编码器类似,但不使用位置编码。编码器结点输入维度是2,经过一个线性网络将特征维度扩展到128维;之后经过N个子层得到输出。其中,每个子层都是由一个8头注意力层和一个全连接层组成,每层都采用了残差连接,经过了批归一化得到输出。
输入:
x (batch_size, graph_size, embed_dim)
输出:
h (batch_size, graph_size, embed_dim)结点嵌入
h.mean (batch_size, embed_dim)图嵌入

*
结点编码类class GraphAttentionEncoder:*

class GraphAttentionEncoder(nn.Module):def __init__(self,n_heads,embed_dim,n_layers,node_dim=None,normalization='batch',feed_forward_hidden=512):super(GraphAttentionEncoder, self).__init__()# To map input to embedding spaceself.init_embed = nn.Linear(node_dim, embed_dim) if node_dim is not None else Noneself.layers = nn.Sequential(*(MultiHeadAttentionLayer(n_heads, embed_dim, feed_forward_hidden, normalization)for _ in range(n_layers)))def forward(self, x, mask=None):assert mask is None, "TODO mask not yet supported!"# Batch multiply to get initial embeddings of nodesh = self.init_embed(x.view(-1, x.size(-1))).view(*x.size()[:2], -1) if self.init_embed is not None else xh = self.layers(h)return (h,  # (batch_size, graph_size, embed_dim)h.mean(dim=1),  # average to get embedding of graph, (batch_size, embed_dim))

子层MHA+FF class MultiHeadAttentionLayer:

class MultiHeadAttentionLayer(nn.Sequential):def __init__(self,n_heads,embed_dim,feed_forward_hidden=512,normalization='batch',):super(MultiHeadAttentionLayer, self).__init__(SkipConnection(MultiHeadAttention(n_heads,input_dim=embed_dim,embed_dim=embed_dim)),Normalization(embed_dim, normalization),SkipConnection(nn.Sequential(nn.Linear(embed_dim, feed_forward_hidden),nn.ReLU(),nn.Linear(feed_forward_hidden, embed_dim)) if feed_forward_hidden > 0 else nn.Linear(embed_dim, embed_dim)),Normalization(embed_dim, normalization))

解码器

输出:
cost (batch_size) 总路径
_log_p (batch_size, graph_size) 结点输出概率和
pi (batch_size, graph_size) 游走序列
解码器是由两层attention层组成,先经过一层多头注意力层,再经过一层单头注意力层得到相关性分数logits。
结点经过编码器后进行解码,解码过程中使用了一个上下文结点c来表示解码上下文。上下文结点是由编码过程中得到的图嵌入、序列中第一个结点嵌入、序列中上一步添加的结点嵌入三者经过线性变换作为query,编码后的结点通过线性变换作为key、value:
特别的,选取第一个结点时,采用两个可学习的参数代替第一个结点和上一步结点。上下文结点嵌入定义如下:
这里的[ · · ]组合文中说是三个向量的连接,但是代码中不是这样,代码中是将图嵌入经过一个线性网络后与两个结点嵌入连接后经过线性网络相加得到:

query = fixed.context_node_projected + self.project_step_context(self._get_parallel_step_context(fixed.node_embeddings, state))

特别说明:两次注意力机制中的glimpse key、glimpse value、logit key是通过一个线性网络self.project_node_embeddings = nn.Linear(embedding_dim, 3 * embedding_dim, bias=False) 实现的。
经过一次8头注意力网络后得到新的上下文结点 h c ( N + 1 ) h^{(N+1)}_{c} hc(N+1),经过线性变换后得到单头注意力层的query,经过单头注意力机制后得到相关性分数logits:
这里的注意力层没有采用跳跃连接、批归一化、全连接层操作,直接得到相关性分数。
def _one_to_many_logits得到每一次经过两层注意力层后的相关性分数

def _one_to_many_logits(self, query, glimpse_K, glimpse_V, logit_K, mask):batch_size, num_steps, embed_dim = query.size()key_size = val_size = embed_dim // self.n_heads# Compute the glimpse, rearrange dimensions so the dimensions are (n_heads, batch_size, num_steps, 1, key_size)glimpse_Q = query.view(batch_size, num_steps, self.n_heads, 1, key_size).permute(2, 0, 1, 3, 4)# Batch matrix multiplication to compute compatibilities (n_heads, batch_size, num_steps, graph_size)compatibility = torch.matmul(glimpse_Q, glimpse_K.transpose(-2, -1)) / math.sqrt(glimpse_Q.size(-1))if self.mask_inner:assert self.mask_logits, "Cannot mask inner without masking logits"compatibility[mask[None, :, :, None, :].expand_as(compatibility)] = -math.inf# Batch matrix multiplication to compute heads (n_heads, batch_size, num_steps, val_size)heads = torch.matmul(torch.softmax(compatibility, dim=-1), glimpse_V)# Project to get glimpse/updated context node embedding (batch_size, num_steps, embedding_dim)glimpse = self.project_out(heads.permute(1, 2, 3, 0, 4).contiguous().view(-1, num_steps, 1, self.n_heads * val_size))# Now projecting the glimpse is not needed since this can be absorbed into project_out# final_Q = self.project_glimpse(glimpse)final_Q = glimpse# Batch matrix multiplication to compute logits (batch_size, num_steps, graph_size)# logits = 'compatibility'logits = torch.matmul(final_Q, logit_K.transpose(-2, -1)).squeeze(-2) / math.sqrt(final_Q.size(-1))# From the logits compute the probabilities by clipping, masking and softmaxif self.tanh_clipping > 0:logits = torch.tanh(logits) * self.tanh_clippingif self.mask_logits:logits[mask] = -math.infreturn logits, glimpse.squeeze(-2)

通过采样/贪心策略选择结点,得到序列
def _inner输出每一步概率与最终结点序列

def _inner(self, input, embeddings):outputs = []sequences = []state = self.problem.make_state(input)# Compute keys, values for the glimpse and keys for the logits once as they can be reused in every stepfixed = self._precompute(embeddings)batch_size = state.ids.size(0)# Perform decoding stepsi = 0while not (self.shrink_size is None and state.all_finished()):if self.shrink_size is not None:unfinished = torch.nonzero(state.get_finished() == 0)if len(unfinished) == 0:breakunfinished = unfinished[:, 0]# Check if we can shrink by at least shrink_size and if this leaves at least 16# (otherwise batch norm will not work well and it is inefficient anyway)if 16 <= len(unfinished) <= state.ids.size(0) - self.shrink_size:# Filter statesstate = state[unfinished]fixed = fixed[unfinished]log_p, mask = self._get_log_p(fixed, state)# Select the indices of the next nodes in the sequences, result (batch_size) longselected = self._select_node(log_p.exp()[:, 0, :], mask[:, 0, :])  # Squeeze out steps dimensionstate = state.update(selected)# Now make log_p, selected desired output size by 'unshrinking'if self.shrink_size is not None and state.ids.size(0) < batch_size:log_p_, selected_ = log_p, selectedlog_p = log_p_.new_zeros(batch_size, *log_p_.size()[1:])selected = selected_.new_zeros(batch_size)log_p[state.ids[:, 0]] = log_p_selected[state.ids[:, 0]] = selected_# Collect output of stepoutputs.append(log_p[:, 0, :])sequences.append(selected)i += 1# Collected lists, return Tensorreturn torch.stack(outputs, 1), torch.stack(sequences, 1)  # (512,20,20)  (512,20)

训练

  1. 定义model
  2. 初始化baseline:baseline = RolloutBaseline(model, problem, opts) 构建了数据集(10000,20,2),batch_size=1024,设置搜索方式为贪心,初始化baseline参数
  3. 训练:
    for epoch in range(n_epochs–100):
    train_epoch:
    ----构建数据集(1280000,20,2),batch_size=512;设置搜索方式为采样
    ----for batch in training_dataloader:
    ----train_batch:
    --------分别计算baseline和model通过attention后的cost(512,1)
    --------计算平均损失
    ​--------optimizer优化
    ​----对比二者的cost,通过baseline.epoch_callback()判断是否更新参数
  • 代码中初始化baseline和训练过程中数据集的大小是不一样的
  • 每经过一个epoch,才判断baseline是否更新(不是batch_size)

实验

AM模型相比于其它深度学习模型,准确率更佳。
对比注意力网络和指针网络,注意力网络表现更好;相同网络中Rollout基线表现更好。

特别感谢:https://zhuanlan.zhihu.com/p/375218972
完整代码:https://github.com/wouterkool/attention-learn-to-route

【论文笔记+代码解读】《ATTENTION, LEARN TO SOLVE ROUTING PROBLEMS!》相关推荐

  1. 【ML4CO基础】Attention, learn to solve routing problems(Wouter Kool, 2018)

    Attention, learn to solve routing problems! Paper: Kool W, Van Hoof H, Welling M. Attention, learn t ...

  2. 文献阅读10期:ATTENTION, LEARN TO SOLVE ROUTING PROBLEMS!

    [ 文献阅读·路径规划 ] ATTENTION, LEARN TO SOLVE ROUTING PROBLEMS! [1] 推荐理由:这篇应该不用多说了,ATTENTION模型做路径规划,算是一篇Mi ...

  3. Exploiting Shared Representations for Personalized Federated Learning 论文笔记+代码解读

    论文地址点这里 一. 介绍 联邦学习中由于各个客户端上数据异构问题,导致全局训练模型无法适应每一个客户端的要求.作者通过利用客户端之间的共同代表来解决这个问题.具体来说,将数据异构的联邦学习问题视为并 ...

  4. Memory-Associated Differential Learning论文及代码解读

    Memory-Associated Differential Learning论文及代码解读 论文来源: 论文PDF: Memory-Associated Differential Learning论 ...

  5. EGNet: Edge Guidance Network for Salient Object Detection 论文及代码解读

    EGNet: Edge Guidance Network for Salient Object Detection 论文及代码解读 注:本文原创作者为Jia-Xing Zhao, Jiang-Jian ...

  6. VGAE(Variational graph auto-encoders)论文及代码解读

    一,论文来源 论文pdf Variational graph auto-encoders 论文代码 github代码 二,论文解读 理论部分参考: Variational Graph Auto-Enc ...

  7. GAN for NLP (论文笔记及解读

    GAN 自从被提出以来,就广受大家的关注,尤其是在计算机视觉领域引起了很大的反响."深度解读:GAN模型及其在2016年度的进展"[1]一文对过去一年GAN的进展做了详细介绍,十分 ...

  8. 一文详解单目VINS论文与代码解读目录

    本文旨在对前一阶段学习vins-mono开源框架的总结.结合暑假秋招之前报名的深蓝学院的<从零开始手写VIO>课程,本文从VIO原理以及开源代码分析两部分进行详细介绍.PS:提升代码能力最 ...

  9. TCN论文及代码解读总结

    前言:传统的时序处理,普遍采用RNN做为基础网络模型,如其变体LSTM.GRU.BPTT等.但是在处理使用LSTM时时序的卷积神经网络 目录 论文及代码链接 一.论文解读 1. 摘要 2.引言(摘) ...

最新文章

  1. 每天五分钟linux(8)-cp
  2. Unicode/not set/multi-byte/部分常用函数
  3. 科大星云诗社动态20201203
  4. JZOJ 1240. Fibonacci sequence
  5. mysql课程设计案例_JAVA中MySQL建立连接
  6. Java运算符优先级和表达式及数据类型转换
  7. 双侧检验的p值和单侧检验_假设检验03----假设检验的步骤
  8. kotlin内联函数_Kotlin内联函数,参数化
  9. thread和threadLocal之间的关系
  10. python+django+vue某小区物业管理系统
  11. c语言二级考试程序设计题怎么运行,2017计算机二级C语言上机考试技巧
  12. 冯诺依曼元胞计算机,冯诺依曼元胞自动机
  13. 文言文代码算什么?跟着九章算术学Python编程才厉害
  14. 2022年固定资产管理系统的概况
  15. 交易系统开发(十二)——QuickFIX官方文档
  16. 免费ICP域名备案查接口
  17. 绑定挂载mount --bind介绍
  18. 电脑发出很大的嗡嗡声_跟踪嗡嗡声的十大方法
  19. 揭秘郭台铭兄弟开店计划 苹果中国渠道裂变
  20. 7. print的应用(3):格式化输出之format模式

热门文章

  1. php输出今天明天后天的代码,js获取日期:前天、昨天、今天、明天、后天
  2. Microsoft JScript 运行时错误: 缺少对象
  3. 二三线城市企业如何构建培养“数字化人才梯队”
  4. 几何画板 国际正版 英文 国际版序列号
  5. 十分钟搞懂基-2 FFT原理及编程思想
  6. 【漫画】TCP断开连接为什么是四次挥手,不是二次挥手/三次挥手?
  7. 小程序云开发上传及使用图片
  8. Spring常见错误 - Bean构造注入报空指针异常
  9. 如何把删除的文件恢复
  10. 3D游戏建模学多久能工作?