pytorch中bilinear的理解
直接参考官方文档,x1和x2是两个输入,A是参数矩阵,如下表达式
但仔细看实现发现这个表达式并不是简单连乘的关系。假设x1(shape是b,n)和x2(shape是b,m)是二维,那么A是个三维tensor(shape是a,n,m)。具体实现时,A先拆成a个(n,m)形状的tensor,x1分别与之矩阵乘后再点乘(公式里的两次乘法),得到了a个(b,m)形状的tensor,然后在axis=1的维度上求和,最终输出成(b,a)的shape,用爱因斯坦表示法就是bn,anm,bm->ba,下面上一下code:
import torch
import torch.nn as nn
import numpy as npl = torch.ones(2,5)
A = torch.ones(3,5,4)
r = torch.ones(2,4)
print(torch.einsum('bn,anm,bm->ba', l, A, r))
print(torch.nn.functional.bilinear(l,r,A))x = torch.ones(2,5)
w = torch.ones(3,5)
print(torch.einsum('ij,kj->ik', x,w))
print(torch.nn.functional.linear(x,w))print('learn nn.Bilinear')
m = nn.Bilinear(5, 4, 3)output = m(l, r)
print(output.size())
arr_output = output.data.cpu().numpy()weight = m.weight.data.cpu().numpy()
bias = m.bias.data.cpu().numpy()
x1 = l.data.cpu().numpy()
x2 = r.data.cpu().numpy()
print(x1.shape, weight.shape, x2.shape, bias.shape)
y = np.zeros((x1.shape[0], weight.shape[0]))
for k in range(weight.shape[0]):buff = np.dot(x1, weight[k])buff = buff * x2buff = np.sum(buff, axis=1)y[:, k] = buff
y += bias
dif = y - arr_output
print(np.mean(np.abs(dif.flatten())))
pytorch中bilinear的理解相关推荐
- pytorch中repeat()函数理解
pytorch中repeat()函数理解 最近在学习过程中遇到了repeat()函数的使用,这里记录一下自己对这个函数的理解. 情况1:repeat参数个数与tensor维数一致时 a = torch ...
- pytorch 中 contiguous() 函数理解
pytorch 中 contiguous() 函数理解 文章目录 pytorch 中 contiguous() 函数理解 引言 使用 contiguous() 后记 文章抄自 Pytorch中cont ...
- Pytorch中的contiguous理解
最近遇到这个函数,但查的中文博客里的解释貌似不是很到位,这里翻译一下stackoverflow上的回答并加上自己的理解. 在pytorch中,只有很少几个操作是不改变tensor的内容本身,而只是重新 ...
- Pytorch中contiguous()函数理解
引言 在pytorch中,只有很少几个操作是不改变tensor的内容本身,而只是重新定义下标与元素的对应关系的.换句话说,这种操作不进行数据拷贝和数据的改变,变的是元数据. 会改变元数据的操作是: n ...
- Pytorch中dim的理解
dim的定义 dim 表示维度 x = torch.randn(2, 3, 3)print(x) print(x.size()) print(x.dim()) 输出: tensor([[[-1.694 ...
- pytorch中unsqueeze()函数理解
unsqueeze()函数起升维的作用,参数表示在哪个地方加一个维度. 在第一个维度(中括号)的每个元素加中括号 0表示在张量最外层加一个中括号变成第一维. 直接看例子: import torch i ...
- pytorch中数组维度的理解
pytorch中数组维度理解与numpy中类似,pytorch中维度用dim表示,numpy中用axis表示 这里主要想说下维度的变化. dim = x ,表示在第x为上进行操作,那个维度会发生变化. ...
- pytorch中的nn.Bilinear
参考:pytorch中的nn.Bilinear的计算原理详解 代码实现 使用numpy实现Bilinear(来自参考资料): print('learn nn.Bilinear') m = nn.Bil ...
- pytorch中网络loss传播和参数更新理解
相比于2018年,在ICLR2019提交论文中,提及不同框架的论文数量发生了极大变化,网友发现,提及tensorflow的论文数量从2018年的228篇略微提升到了266篇,keras从42提升到56 ...
最新文章
- Web Components 入门实例教程
- app中传递java数据_Java实现app接口和Socket消息传递(6)servlet映射并返回Json数据
- java连接Redis数据库
- 端口复用和重映射--STM32F103
- python爬虫框架scrapy实例详解_python爬虫框架scrapy实例详解
- c++2010修复不了_汽车凹痕太小修复不了?汽车无痕修复是骗局还是技术不行?...
- iOS开发CAAnimation详解
- redis类型 tp5_tp5配置使用redis笔记!
- Git在dev分支获取master分支最新代码
- 外媒点赞,浪潮存储为何能入围全球最佳主存储供应商
- mybatis mysql连接时区_MySQL时区的查看和设置
- 基于Arduino的智能泡茶机(1)——机械系机械创新比赛总结技术点与不足处
- 杭州电子科技大学全国计算机排名,杭电排名为什么比211还高,杭州电子科技大学是211吗...
- 年终奖变期权,曝字节跳动将开启员工期权兑换
- 企业实现统一身份认证的作用和好处有哪些?(图文并茂)
- Mac os下时间戳转换
- 小学美术计算机教案模板,小学美术教案模板
- 线性代数笔记18——投影矩阵和最小二乘
- 本地连接云服务器mysql数据库出现Access denied的解决方法
- 课时05 Octave教程(Octave Tutorial)
热门文章
- 艺术与工程技术的交叉碰撞
- 艺术对于学计算机来说有用吗,人工智能都能画画了,学艺术还有什么用?
- 一强悍老婆给老公的100条幸福条约
- Notepad++去掉回车
- linux用户行为审计
- 在python中数据的输出用哪个函数名_在Python中,数据的输出用哪个函数名
- 帮我DIY一台10000左右的游戏电脑
- 修复登录接口仿抽奖助手小程序源码-支持商家认证多种开奖方式
- module ‘glm‘ has no attribute ‘vec3‘
- BZOJ5217: [Lydsy2017省队十连测]航海舰队 FFT