1.1 鸢尾花数据集介绍

iris数据集是用来给莺尾花做分类的数据集,每个样本包含了花萼长度、花萼宽度、花瓣长度、花瓣宽度四个特征,我们需要建立一个分类器,该分类器可通过样本的四个特征来来判断样本属于山鸢尾(Setosa)、变色鸢尾(Versicolour)还是维吉尼亚鸢尾(Virginica)中的哪一个,选择神经网络进行分类。

1.2 思路流程

  • 导入鸢尾花数据集
  • 对数据集进行切分,分为训练集和测试集
  • 搭建网络模型
  • 训练网络
  • 将所训练出的模型进行保存(准确率大于90%)

1.3 网络模型

采用sigmoid等函数,算激活函数时(指数运算),计算量大,反向传播求误差梯度时,求导涉及除法,计算量相对大,而采用Relu激活函数,整个过程的计算量节省很多,故采用Relu作为激活函数

1.4 实现代码

导入所需要的的模块

import torch
import torch.nn as nn
from sklearn import datasets
from sklearn.model_selection import train_test_split

神经网络类

class Net(nn.Module):def __init__(self,in_num,out_num,hid_num):super(Net,self).__init__()self.network = nn.Sequential(nn.Linear(in_num,hid_num),nn.ReLU(),nn.Linear(hid_num,out_num))self.optimizer = torch.optim.SGD(self.parameters(), lr=0.05)self.loss_func = torch.nn.CrossEntropyLoss()def forward(self,x):return self.network(x)def train(self,x,y):out = self.forward(x)loss = self.loss_func(out,y)self.optimizer.zero_grad()loss.backward()self.optimizer.step()print('loss = %.4f' % loss.item())def test(self,x):return self.forward()

引入数据集,并按照8:2切分训练集和测试集

dataset = datasets.load_iris()
input = torch.FloatTensor(dataset['data'])
label = torch.LongTensor(dataset['target'])
x_train, x_test, y_train, y_test = train_test_split(input, label, test_size=0.2)

如果存在已有训练好的网络则导入,并在总体数据集上测试其准确性

try:print("iris_model exist and have been loaded")mynet = torch.load('iris_model.pkl')output = mynet(input)pred_y = torch.max(output, 1)[1].numpy()sum = 0for i in range(len(label)):if pred_y[i] == label[i]:sum = sum + 1accuracy = float(sum / len(label))print('model accuracy = %d%% (testing on the whole dataset)' % (accuracy * 100))

若不存在训练好的网络则进行训练,直到准确性大于90%后将其保存

except:mynet = Net(4,10,3)accuracy = 0.0while accuracy < 0.9:for i in range (10000):mynet.train(x_train,y_train)output = mynet(x_test)pred_y = torch.max(output, 1)[1].numpy()sum=0for i in range(len(y_test)):if pred_y[i] == y_test[i]:sum=sum+1accuracy = float(sum / len(y_test))torch.save(mynet, 'iris_model.pkl')print(mynet)print("The net have been saved")print('accuracy = %d%%' % (accuracy*100))

鸢尾花识别完整代码

import torch
import torch.nn as nn
from sklearn import datasets
from sklearn.model_selection import train_test_split
class Net(nn.Module):def __init__(self,in_num,out_num,hid_num):super(Net,self).__init__()self.network = nn.Sequential(nn.Linear(in_num,hid_num),nn.ReLU(),nn.Linear(hid_num,out_num))self.optimizer = torch.optim.SGD(self.parameters(), lr=0.05)self.loss_func = torch.nn.CrossEntropyLoss()def forward(self,x):return self.network(x)def train(self,x,y):out = self.forward(x)loss = self.loss_func(out,y)self.optimizer.zero_grad()loss.backward()self.optimizer.step()print('loss = %.4f' % loss.item())def test(self,x):return self.forward()if __name__ == '__main__':dataset = datasets.load_iris()input = torch.FloatTensor(dataset['data'])label = torch.LongTensor(dataset['target'])x_train, x_test, y_train, y_test = train_test_split(input, label, test_size=0.2)try:print("iris_model exist and have been loaded")mynet = torch.load('iris_model.pkl')output = mynet(input)pred_y = torch.max(output, 1)[1].numpy()sum = 0for i in range(len(label)):if pred_y[i] == label[i]:sum = sum + 1accuracy = float(sum / len(label))print('model accuracy = %d%% (testing on the whole dataset)' % (accuracy * 100))except:mynet = Net(4,10,3)accuracy = 0.0while accuracy < 0.9:for i in range (10000):mynet.train(x_train,y_train)output = mynet(x_test)pred_y = torch.max(output, 1)[1].numpy()sum=0for i in range(len(y_test)):if pred_y[i] == y_test[i]:sum=sum+1accuracy = float(sum / len(y_test))torch.save(mynet, 'iris_model.pkl')print(mynet)print("The net have been saved")print('accuracy = %d%%' % (accuracy*100))

github文件链接

鸢尾花数据集分类--神经网络相关推荐

  1. 一层神经网络实现鸢尾花数据集分类

    一层神经网络实现鸢尾花数据集分类 1.数据集介绍 2.程序实现 2.1 数据集导入 2.2 数据集乱序 2.3 数据集划分成永不相见的训练集和测试集 3.4 配成[输入特征,标签]对,每次喂入一小撮( ...

  2. 利用神经网络对鸢尾花数据集分类

    利用神经网络对鸢尾花数据集分类 详细实现代码请见:https://download.csdn.net/download/weixin_43521269/12578696 一.简介 一个人工神经元网络是 ...

  3. (决策树,朴素贝叶斯,人工神经网络)实现鸢尾花数据集分类

    from sklearn.datasets import load_iris # 导入方法类iris = load_iris() #导入数据集iris iris_feature = iris.data ...

  4. Python实现鸢尾花数据集分类问题——基于skearn的SVM(有详细注释的)

    Python实现鸢尾花数据集分类问题--基于skearn的SVM 代码如下: 1 # !/usr/bin/env python2 # encoding: utf-83 __author__ = 'Xi ...

  5. 基于Adaboost实现鸢尾花数据集分类

    写在之前 提交内容分为两大部分: 一为Adaboost算法实现,代码在文件夹<算法实现>中,<提升方法笔记>为个人学习笔记. 二为基于Adaboost模型实现鸢尾花数据集分类, ...

  6. Python实现鸢尾花数据集分类问题——基于skearn的LogisticRegression

    Python实现鸢尾花数据集分类问题--基于skearn的LogisticRegression 一. 逻辑回归 逻辑回归(Logistic Regression)是用于处理因变量为分类变量的回归问题, ...

  7. 用逻辑回归实现鸢尾花数据集分类(1)

    鸢尾花数据集的分类问题指导 -- 对数几率回归(逻辑回归)问题研究 (1) 这一篇Notebook是应用对数几率回归(Logit Regression)对鸢尾花数据集进行品种分类的.首先会带大家探索一 ...

  8. 实验一:鸢尾花数据集分类

    实验一:鸢尾花数据集分类 一.问题描述 利用机器学习算法构建模型,根据鸢尾花的花萼和花瓣大小,区分鸢尾花的品种.实现一个基础的三分类问题. 二.数据集分析 Iris 鸢尾花数据集内包含 3 种类别,分 ...

  9. orange实现逻辑回归_分别用逻辑回归和决策树实现鸢尾花数据集分类

    学习了决策树和逻辑回归的理论知识,决定亲自上手尝试一下.最终导出决策树的决策过程的图片和pdf.逻辑回归部分参考的是用逻辑回归实现鸢尾花数据集分类,感谢原作者xiaoyangerr 注意:要导出为pd ...

  10. 【机器学习】决策树案例二:利用决策树进行鸢尾花数据集分类预测

    利用决策树进行鸢尾花数据集分类预测 2 利用决策树进行鸢尾花数据集分类预测 2.1 导入模块与加载数据 2.2 划分数据 2.3 模型创建与应用 2.4 模型可视化 手动反爬虫,禁止转载: 原博地址 ...

最新文章

  1. python模块学习---cmd
  2. ubuntu 命令整合1
  3. 【转载】PHP 常用的header头部定义汇总
  4. pku acm 2248 addtion chians 解题报告
  5. boost::function模块实现分配器的测试程序
  6. RangeAssignor(范围分区)
  7. android get width单位是什么意思,浅析Android中getWidth()和getMeasuredWidth()的区别
  8. [绝对原创]从VS2003(.net1.1)升级到vs2005(.net2.0)全程跟踪记录
  9. Hadoop实战经验之HDFS故障排除-尚硅谷大数据培训
  10. OpenCV 安装配置 Jupyter Notebook
  11. 邮递员问题java实现_中国邮递员问题算法.PPT
  12. 基于SSM毕业生就业管理系统
  13. web视频(点播/直播)播放器选型
  14. vue-element-admin基础学习
  15. mac下面如何修改只读文件(Read-only file system)
  16. Javascript显示隐藏DIV
  17. matlab 524288,Cannot display summaries of variables with more than 524288 elements. 怎么...
  18. 史上最强ASR非特定人声语音识别模块,完爆LD3320
  19. TCP选项之SACK选项概述
  20. (PHP)程序中如何判断当前用户终端是手机等移动终端

热门文章

  1. 音视频开发系列(10):基于qt的音频推流
  2. 9种常用的数据分析方法
  3. 机器学习实战教程(13篇)
  4. android studio 融云,融云 SDK 是否支持 AndroidX
  5. MongoVUE的基本使用
  6. AD9如何设置原点位置
  7. 使用cmd命令删除服务
  8. YYKit-YYCache分析
  9. 【Altium Designer】如何导出gerber文件
  10. java基于ssm的农产品网上销售系统