直接上代码:

import torch, random
import torch.nn as nn
import torch.optim as optimtorch.manual_seed(42)class RBFN(nn.Module):"""以高斯核作为径向基函数"""def __init__(self, centers, n_out=3):""":param centers: shape=[center_num,data_dim]:param n_out:"""super(RBFN, self).__init__()self.n_out = n_outself.num_centers = centers.size(0) # 隐层节点的个数self.dim_centure = centers.size(1) # self.centers = nn.Parameter(centers)# self.beta = nn.Parameter(torch.ones(1, self.num_centers), requires_grad=True)self.beta = torch.ones(1, self.num_centers)*10# 对线性层的输入节点数目进行了修改self.linear = nn.Linear(self.num_centers+self.dim_centure, self.n_out, bias=True)self.initialize_weights()# 创建对象时自动执行def kernel_fun(self, batches):n_input = batches.size(0)  # number of inputsA = self.centers.view(self.num_centers, -1).repeat(n_input, 1, 1)B = batches.view(n_input, -1).unsqueeze(1).repeat(1, self.num_centers, 1)C = torch.exp(-self.beta.mul((A - B).pow(2).sum(2, keepdim=False)))return Cdef forward(self, batches):radial_val = self.kernel_fun(batches)class_score = self.linear(torch.cat([batches, radial_val], dim=1))return class_scoredef initialize_weights(self, ):"""网络权重初始化:return:"""for m in self.modules():if isinstance(m, nn.Conv2d):m.weight.data.normal_(0, 0.02)m.bias.data.zero_()elif isinstance(m, nn.ConvTranspose2d):m.weight.data.normal_(0, 0.02)m.bias.data.zero_()elif isinstance(m, nn.Linear):m.weight.data.normal_(0, 0.02)m.bias.data.zero_()def print_network(self):num_params = 0for param in self.parameters():num_params += param.numel()print(self)print('Total number of parameters: %d' % num_params)# centers = torch.rand((5,8))
# rbf_net = RBFN(centers)
# rbf_net.print_network()
# rbf_net.initialize_weights()if __name__ =="__main__":data = torch.tensor([[0.25, 0.75], [0.75,0.75], [0.25,0.5], [0.5,0.5],[0.75,0.5],[0.25,0.25],[0.75,0.25],[0.5,0.125],[0.75,0.125]], dtype=torch.float32)label = torch.tensor([[-1,1,-1],[1,-1,-1],[-1,-1,1],[-1,-1,1],[-1,-1,1],[1,-1,-1],[-1,1,-1],[-1,1,-1],[1,-1,-1]], dtype=torch.float32)print(data.size())centers = data[0:8,:]rbf = RBFN(centers,3)params = rbf.parameters()loss_fn = torch.nn.MSELoss()optimizer = torch.optim.SGD(params,lr=0.1,momentum=0.9)for i in range(10000):optimizer.zero_grad()y = rbf.forward(data)loss = loss_fn(y,label)loss.backward()optimizer.step()print(i,"\t",loss.data)# 加载使用y = rbf.forward(data)print(y.data)print(label.data)

说明:代码在https://goodgoodstudy.blog.csdn.net/article/details/105756137上进行了小修改(原代码应该是错的),并加了一个自己的实验。

pytorch 实现RBF网络相关推荐

  1. pytorch空间变换网络

    pytorch空间变换网络 本文将学习如何使用称为空间变换器网络的视觉注意机制来扩充网络.可以在DeepMind paper 有关空间变换器网络的内容. 空间变换器网络是对任何空间变换的差异化关注的概 ...

  2. 多层感知机MLP、RBF网络、Hopfield网络、自组织映射神经网络、神经网络算法地图

    多层感知机MLP.RBF网络.Hopfield网络.自组织映射神经网络.神经网络算法地图 目录

  3. RBF网络——核心思想:把向量从低维m映射到高维P,低维线性不可分的情况到高维就线性可分了...

    RBF网络能够逼近任意的非线性函数,可以处理系统内的难以解析的规律性,具有良好的泛化能力,并有很快的学习收敛速度,已成功应用于非线性函数逼近.时间序列分析.数据分类.模式识别.信息处理.图像处理.系统 ...

  4. MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)...

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...

  5. pytorch贝叶斯网络_贝叶斯神经网络:2个在TensorFlow和Pytorch中完全连接

    pytorch贝叶斯网络 贝叶斯神经网络 (Bayesian Neural Net) This chapter continues the series on Bayesian deep learni ...

  6. 使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记

    使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记 https://www.bilibili.com/video/BV1rq4y1w7xM?spm_id_from=33 ...

  7. 实战:使用Pytorch搭建分类网络(肺结节假阳性剔除)

    实战:使用Pytorch搭建分类网络(肺结节假阳性剔除) 阅前可看: 实战:使用yolov3完成肺结节检测(Luna16数据集)及肺实质分割 其中的脚本资源getMat.py文件是对肺结节进行切割. ...

  8. (《机器学习》完整版系列)第5章 神经网络——5.2 RBF网络(单层RBF就可解决异或问题)与ART网络(实现“自适应谐振”)

    单层RBF神经网络就可解决异或问题. ART网络已发展出一个了一个算法族,需要理解它是如何实现"自适应谐振"的. RBF网络 径向基函数网络RBF如图5.3所示,此图为缩略图,即一 ...

  9. matlab建立rbf网络,大家看一下,这组数据Matlab如何构建RBF神经网络!!!!

    大家看一下,这组数据如何构建RBF神经网络!!!! 大家看一下如何编写RBF网络,前4例是输入,后2列是输出.共有19组数据,谁会呀!!谢谢. m_data=[0.000000 -96.688193 ...

最新文章

  1. linux命令grep如何使用,Linux下如何使用grep搜索文本
  2. .NET简谈网络系统大局观
  3. sqlserver清除缓存(转载)
  4. JMetro 5.5版发布
  5. linux gpio设备驱动程序,嵌入式Linux设备驱动开发之:GPIO驱动程序实例-嵌入式系统-与非网...
  6. 图像处理与图像识别笔记(三)图像增强1
  7. 修改itunes备份路径的方法(奇奇怪怪的文件堆积C盘,别让文件成为最后的稻草哦)
  8. Application应用框架思考(三) 之[插件机制]
  9. idea非活动变更列表中的文件被修改,IntellJ IDEA中的更改列表是什么?一个变化列表比较什么?寻求准确的解释...
  10. python pandas中文手册-Pandas速查手册中文版(转)
  11. Microsoft Word 教程「3」,如何在 Word 中创建项目符号列表、显示字数统计?
  12. win10使用VMware Workstations安装CentOS
  13. 基于解决sci和ei等外文思维顺序问题的辅助软件分析
  14. Flash常见问题与解答
  15. android遥控器适配
  16. 扫福活动开始,你的公众号图文排版也要“福”气满满
  17. seurat使用笔记(数据处理、PCA、聚类)
  18. 直播系统定制开发中安卓直播间websocket协议破解还原
  19. sqlserver中window身份验证跟sql server身份验证的区别
  20. 大小端交换的程序_数据库在小程序云开发中的应用

热门文章

  1. oracle一个表拆成多个表,oracle拆分函数,将字符串拆分成多行多字段表数据
  2. android sim卡命令,Android常用命令
  3. android开发 解析 b5,张绍文android开发高手课读书笔记4-启动优化篇
  4. java删除sql表中记录,您将如何维护SQL表中的历史记录?
  5. win10系统崩溃怎么修复_新手怎么重装系统win10
  6. python 交集_Python设置交集
  7. python pow_Python pow()
  8. numpy矩阵乘法_NumPy矩阵乘法
  9. mockito 静态方法_Mockito模拟静态方法– PowerMock
  10. UIView的setNeedsLayout, layoutIfNeeded 和 layoutSubviews