keras 受限玻尔兹曼机_受限玻尔兹曼机及实现
# 实现受限玻尔兹曼机,暂仅考虑可视层、隐藏神经元取值均为二进制的情况
import numpyasnp
import os
def sigmoid(z):return 1 / (1 + np.exp(-z))classRBM:
def __init__(self, n_visible, n_hidden):
self.n_visible=n_visible #可见层节点数量
self.n_hidden=n_hidden #隐藏层节点数量
self.bias_a=np.zeros(self.n_visible) #可视层偏移量
self.bias_b=np.zeros(self.n_hidden) #隐藏层偏移量
self.weights= np.random.normal(0, 0.01, size=(self.n_visible, self.n_hidden))
self.n_sample=None
# 编码,即基于v计算h的条件概率:p(h=1|v)
def encode(self, v):return sigmoid(self.bias_b +v @ self.weights)
# 解码(重构):即基于h计算v的条件概率:p(v=1|h)
def decode(self, h):return sigmoid(self.bias_a +h @ self.weights.T)
# gibbs采样, 返回max_cd采样后的v以及h值
def gibbs_sample(self, v0, max_cd):
v=v0for _ inrange(max_cd):
# 首先根据输入样本对每个隐藏层神经元采样。二项分布采样,决定神经元是否激活
ph=self.encode(v)
h= np.random.binomial(1, ph, (self.n_sample, self.n_hidden))
# 根据采样后隐藏层神经元取值对每个可视层神经元采样
pv=self.decode(h)
# print(h)
# print(pv)
# print(max_cd)
# os.system("pause")
v= np.random.binomial(1, pv, (self.n_sample, self.n_visible))returnv
# 根据Gibbs采样得到的可视层取值(解码或重构),更新参数
def update(self, v0, v_cd, eta):
# print(v0)
# os.system("pause")
ph=self.encode(v0)
ph_cd=self.encode(v_cd)
# print(v0.T)
# print(ph)
# os.system("pause")
self.weights+= eta * (v0.T @ ph -v_cd.T @ ph) # 更新连接权重参数
self.bias_b+= eta * np.mean(ph - ph_cd, axis=0) # 更新隐藏层偏移量b
self.bias_a+= eta * np.mean(v0 - v_cd, axis=0) # 更新可视层偏移量areturn
"""训练主函数,采用对比散度算法(CD算法)更新参数
:param data: 训练数据集, (n_sample, n_input)
:param max_step: 最大迭代步数
:param max_cd: 采样步数
:param eta: 学习率
:return:""" def fit(self, data, max_step=100, max_cd=2, eta=0.1):
assert data.shape[1] == self.n_visible, "输入数据维度与可视层神经元数目不相等"self.n_sample= data.shape[0]for i inrange(max_step):
#v_cd 是反采样后的输入样本
v_cd=self.gibbs_sample(data, max_cd)
self.update(data, v_cd, eta)
error= np.sum((data - v_cd) ** 2) / self.n_sample / self.n_visible * 100
if not i % 100: # 将重构后的样本与原始样本对比计算误差
print("可视层状态误差比例:{0}%".format(round(error, 2)))returndef predict(self, v):
# 输入训练数据,预测隐藏层输出
ph= self.encode(v)[0]
states= ph >=np.random.rand(len(ph))return states.astype(int)if __name__ == '__main__':
rbm_model= RBM(n_visible=6, n_hidden=2)
train_data= np.array([[1, 1, 1, 0, 0, 0], [1, 0, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0],
[0, 0, 1, 1, 1, 0], [0, 0, 1, 1, 0, 0], [0, 0, 1, 1, 1, 0]])
rbm_model.fit(train_data, max_step=1000, max_cd=1, eta=0.1)
print(rbm_model.weights, rbm_model.bias_a, rbm_model.bias_b)
user= np.array([[0, 0, 0, 1, 1, 0]])
print(rbm_model.predict(user))"""该数据的含义:
每个样本对应一个用户对6部电影的评分,简化为0(不好看)和1(好看),
6部电影分别属于奥斯卡获奖影片和奇幻影片,对应两个潜在因子,即2个隐藏层神经元,
据此可以判定用户的电影喜好类别。"""
keras 受限玻尔兹曼机_受限玻尔兹曼机及实现相关推荐
- tensorflow玻尔兹曼机_受限玻尔兹曼机(Restricted Boltzmann Machine)
受限玻尔兹曼机(Restricted Boltzmann Machine) 1. 生成模型 2. 参数学习 3. 对比散度学习算法 由于受限玻尔兹曼机的特殊结构,因此可以使用一种比吉布斯采样更有效 的 ...
- keras 受限玻尔兹曼机_受限玻尔兹曼机 代码
备注:这个python代码需要用到psyco包(安装困难),psyco包目前只有python2 32位版本. 在windows 64+python 3环境下,如果下载psyco的源代码安装,比较麻烦. ...
- tensorflow玻尔兹曼机_受限玻尔兹曼机(RBM)与其在Tensorflow的实现
Deep Learning with TensorFlow IBM Cognitive Class ML0120EN Module 4 - Restricted Boltzmann Machine 简 ...
- keras 受限玻尔兹曼机_目前深度学习的模型有哪几种,适用于哪些问题?
深度学习的模型有很多, 目前开发者最常用的深度学习模型与架构包括 CNN.DBN.RNN.RNTN.自动编码器.GAN 等.雷锋网搜集整理了涉及以上话题的精品文章,供初学者参考,加速深度学习新手入门. ...
- keras 受限玻尔兹曼机_深度学习之受限玻尔兹曼机
1.什么是受限玻尔兹曼机 玻尔兹曼机是一大类的神经网络模型,但是在实际应用中使用最多的则是受限玻尔兹曼机(RBM). 受限玻尔兹曼机(RBM)是一个随机神经网络(即当网络的神经元节点被激活时会有随机行 ...
- 玻尔兹曼机和受限玻尔兹曼机
文章目录 Boltzmann Machines起源 Boltzmann Machines的结构 Boltzmann Machines的搜索问题 Boltzmann Machines的学习问题 不含隐藏 ...
- 玻尔兹曼机BM 受限玻尔兹曼机 RBM
玻尔兹曼机是一种基于能量的模型 结构只有2层,浅层和隐层 RBM被认为是受限的,因为没有两个节点,同层共享一个连接 玻尔兹曼机是一种随机神经网络,借鉴了模拟退火思想. 玻尔兹曼机的网络模型与BP神经网 ...
- tensorflow玻尔兹曼机_资源 | 10种深度学习算法的TensorFlow实现
原标题:资源 | 10种深度学习算法的TensorFlow实现 选自 Github 作者:blackecho 机器之心编译 参与:吴攀 这个 repository 是使用 TensorFlow 库实现 ...
- 网络受限_受限人工神经网络对幸福的追求
网络受限 动机 (Motivation) Dust always finds its way to the forlorn places, the long forgotten shelves, th ...
最新文章
- 使用Uboot启动内核并挂载NFS根文件系统
- ElasticSearch之Centos7下安装
- 1720: 交换瓶子
- Ubuntu 16.04下使用Wine安装Xshell 4和Xftp 4
- 绑定控件中%#Eval()%和%=变量%的执行顺序
- gan 总结 数据增强_吴恩达Deeplearning.ai国庆上新:GAN专项课程
- 06-linux下Elasticsearch安装 设置Elasticsearch
- android 多个语音合成,android实现语音合成
- 有关ArrayList增加Map引发的一个BUG
- ARM汇编程序设计之--数据搬移
- maven package 打包报错 Failed to execute goal
- 英文论文发表必备干货!SCI投稿7个阶段经典邮件模板,请拿走
- 【自动化测试selenium】
- RGB与YCbCr颜色空间的转换
- python读取nii文件、nii.gz文件
- Nginx/PHP安装
- system占用cpu过高查找问题思路
- 信息系统项目管理师教程读书笔记(八)
- Java应用程序的运行机制(介绍)
- php生成手机桌面快捷方式,php生成网页桌面快捷方式