直接参考官方文档,x1和x2是两个输入,A是参数矩阵,如下表达式

但仔细看实现发现这个表达式并不是简单连乘的关系。假设x1(shape是b,n)和x2(shape是b,m)是二维,那么A是个三维tensor(shape是a,n,m)。具体实现时,A先拆成a个(n,m)形状的tensor,x1分别与之矩阵乘后再点乘(公式里的两次乘法),得到了a个(b,m)形状的tensor,然后在axis=1的维度上求和,最终输出成(b,a)的shape,用爱因斯坦表示法就是bn,anm,bm->ba,下面上一下code:

import torch
import torch.nn as nn
import numpy as npl = torch.ones(2,5)
A = torch.ones(3,5,4)
r = torch.ones(2,4)
print(torch.einsum('bn,anm,bm->ba', l, A, r))
print(torch.nn.functional.bilinear(l,r,A))x = torch.ones(2,5)
w = torch.ones(3,5)
print(torch.einsum('ij,kj->ik', x,w))
print(torch.nn.functional.linear(x,w))print('learn nn.Bilinear')
m = nn.Bilinear(5, 4, 3)output = m(l, r)
print(output.size())
arr_output = output.data.cpu().numpy()weight = m.weight.data.cpu().numpy()
bias = m.bias.data.cpu().numpy()
x1 = l.data.cpu().numpy()
x2 = r.data.cpu().numpy()
print(x1.shape, weight.shape, x2.shape, bias.shape)
y = np.zeros((x1.shape[0], weight.shape[0]))
for k in range(weight.shape[0]):buff = np.dot(x1, weight[k])buff = buff * x2buff = np.sum(buff, axis=1)y[:, k] = buff
y += bias
dif = y - arr_output
print(np.mean(np.abs(dif.flatten())))

pytorch中bilinear的理解相关推荐

  1. pytorch中repeat()函数理解

    pytorch中repeat()函数理解 最近在学习过程中遇到了repeat()函数的使用,这里记录一下自己对这个函数的理解. 情况1:repeat参数个数与tensor维数一致时 a = torch ...

  2. pytorch 中 contiguous() 函数理解

    pytorch 中 contiguous() 函数理解 文章目录 pytorch 中 contiguous() 函数理解 引言 使用 contiguous() 后记 文章抄自 Pytorch中cont ...

  3. Pytorch中的contiguous理解

    最近遇到这个函数,但查的中文博客里的解释貌似不是很到位,这里翻译一下stackoverflow上的回答并加上自己的理解. 在pytorch中,只有很少几个操作是不改变tensor的内容本身,而只是重新 ...

  4. Pytorch中contiguous()函数理解

    引言 在pytorch中,只有很少几个操作是不改变tensor的内容本身,而只是重新定义下标与元素的对应关系的.换句话说,这种操作不进行数据拷贝和数据的改变,变的是元数据. 会改变元数据的操作是: n ...

  5. Pytorch中dim的理解

    dim的定义 dim 表示维度 x = torch.randn(2, 3, 3)print(x) print(x.size()) print(x.dim()) 输出: tensor([[[-1.694 ...

  6. pytorch中unsqueeze()函数理解

    unsqueeze()函数起升维的作用,参数表示在哪个地方加一个维度. 在第一个维度(中括号)的每个元素加中括号 0表示在张量最外层加一个中括号变成第一维. 直接看例子: import torch i ...

  7. pytorch中数组维度的理解

    pytorch中数组维度理解与numpy中类似,pytorch中维度用dim表示,numpy中用axis表示 这里主要想说下维度的变化. dim = x ,表示在第x为上进行操作,那个维度会发生变化. ...

  8. pytorch中的nn.Bilinear

    参考:pytorch中的nn.Bilinear的计算原理详解 代码实现 使用numpy实现Bilinear(来自参考资料): print('learn nn.Bilinear') m = nn.Bil ...

  9. pytorch中网络loss传播和参数更新理解

    相比于2018年,在ICLR2019提交论文中,提及不同框架的论文数量发生了极大变化,网友发现,提及tensorflow的论文数量从2018年的228篇略微提升到了266篇,keras从42提升到56 ...

最新文章

  1. Web Components 入门实例教程
  2. app中传递java数据_Java实现app接口和Socket消息传递(6)servlet映射并返回Json数据
  3. java连接Redis数据库
  4. 端口复用和重映射--STM32F103
  5. python爬虫框架scrapy实例详解_python爬虫框架scrapy实例详解
  6. c++2010修复不了_汽车凹痕太小修复不了?汽车无痕修复是骗局还是技术不行?...
  7. iOS开发CAAnimation详解
  8. redis类型 tp5_tp5配置使用redis笔记!
  9. Git在dev分支获取master分支最新代码
  10. 外媒点赞,浪潮存储为何能入围全球最佳主存储供应商
  11. mybatis mysql连接时区_MySQL时区的查看和设置
  12. 基于Arduino的智能泡茶机(1)——机械系机械创新比赛总结技术点与不足处
  13. 杭州电子科技大学全国计算机排名,杭电排名为什么比211还高,杭州电子科技大学是211吗...
  14. 年终奖变期权,曝字节跳动将开启员工期权兑换
  15. 企业实现统一身份认证的作用和好处有哪些?(图文并茂)
  16. Mac os下时间戳转换
  17. 小学美术计算机教案模板,小学美术教案模板
  18. 线性代数笔记18——投影矩阵和最小二乘
  19. 本地连接云服务器mysql数据库出现Access denied的解决方法
  20. 课时05 Octave教程(Octave Tutorial)

热门文章

  1. 艺术与工程技术的交叉碰撞
  2. 艺术对于学计算机来说有用吗,人工智能都能画画了,学艺术还有什么用?
  3. 一强悍老婆给老公的100条幸福条约
  4. Notepad++去掉回车
  5. linux用户行为审计
  6. 在python中数据的输出用哪个函数名_在Python中,数据的输出用哪个函数名
  7. 帮我DIY一台10000左右的游戏电脑
  8. 修复登录接口仿抽奖助手小程序源码-支持商家认证多种开奖方式
  9. module ‘glm‘ has no attribute ‘vec3‘
  10. BZOJ5217: [Lydsy2017省队十连测]航海舰队 FFT