Keras vs PyTorch vs Caffe:CNN实现对比
作者|PRUDHVI VARMA 编译|VK 来源|Analytics Indiamag
在当今世界,人工智能已被大多数商业运作所应用,而且由于先进的深度学习框架,它非常容易部署。这些深度学习框架提供了高级编程接口,帮助我们设计深度学习模型。使用深度学习框架,它通过提供内置的库函数来减少开发人员的工作,从而使我们能够更快更容易地构建模型。
在本文中,我们将构建相同的深度学习框架,即在Keras、PyTorch和Caffe中对同一数据集进行卷积神经网络图像分类,并对所有这些方法的实现进行比较。最后,我们将看到PyTorch构建的CNN模型如何优于内置Keras和Caffe的同行。
本文涉及的主题
如何选择深度学习框架。
Keras的优缺点
PyTorch的优缺点
Caffe的优缺点
在Keras、PyTorch和Caffe实现CNN模型。
选择深度学习框架
在选择深度学习框架时,有一些指标可以找到最好的框架,它应该提供并行计算、良好的运行模型的接口、大量内置的包,它应该优化性能,同时也要考虑我们的业务问题和灵活性,这些是我们在选择深度学习框架之前要考虑的基本问题。让我们比较三个最常用的深度学习框架Keras、Pytorch和Caffe。
Keras
Keras是一个开源框架,由Google工程师Francois Chollet开发,它是一个深度学习框架,我们只需编写几行代码,就可以轻松地使用和评估我们的模型。
如果你不熟悉深度学习,Keras是初学者最好的入门框架,Keras对初学者十分友好,并且易于与python一起工作,并且它有许多预训练模型(VGG、Inception等)。不仅易于学习,而且它支持Tensorflow作为后端。
使用Keras的局限性
Keras需要改进一些特性
我们需要牺牲速度来换取它的用户友好性
有时甚至使用gpu也需要很长时间。
使用Keras框架的实际实现
在下面的代码片段中,我们将导入所需的库。
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K
超参数:
batch_size = 128
num_classes = 10
epochs = 12
img_rows, img_cols = 28, 28
(x_train, y_train), (x_test, y_test) = mnist.load_data()
在下面的代码片段中,我们将构建一个深度学习模型,其中包含几个层,并分配优化器、激活函数和损失函数。
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),activation='relu',input_shape=input_shape))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))
model.compile(loss=keras.losses.categorical_crossentropy,optimizer=keras.optimizers.Adam(),metrics=['accuracy'])
在下面的代码片段中,我们将训练和评估模型。
model.fit(x_train, y_train,batch_size=batch_size,epochs=epochs,verbose=1,validation_data=(x_test, y_test))
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
PyTorch
PyTorch是一个由Facebook研究团队开发的开源框架,它是深度学习模型的一种实现,它提供了python环境提供的所有服务和功能,它允许自动微分,有助于加速反向传播过程,PyTorch提供了各种模块,如torchvision,torchaudio,torchtext,可以灵活地在NLP中工作,计算机视觉。PyTorch对于研究人员比开发人员更灵活。
PyTorch的局限性
PyTorch在研究人员中比在开发人员中更受欢迎。
它缺乏生产力。
使用PyTorch框架实现
安装所需的库
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data.dataloader as dataloader
import torch.optim as optim
from torch.utils.data import TensorDataset
from torchvision import transforms
from torchvision.datasets import MNIST
在下面的代码片段中,我们将加载数据集并将其拆分为训练集和测试集。
train = MNIST('./data', train=True, download=True, transform=transforms.Compose([transforms.ToTensor(),
]), )
test = MNIST('./data', train=False, download=True, transform=transforms.Compose([transforms.ToTensor(),
]), )
dataloader_args = dict(shuffle=True, batch_size=64,num_workers=1, pin_memory=True)
train_loader = dataloader.DataLoader(train, **dataloader_args)
test_loader = dataloader.DataLoader(test, **dataloader_args)
train_data = train.train_data
train_data = train.transform(train_data.numpy())
在下面的代码片段中,我们将构建我们的模型,并设置激活函数和优化器。
class Model(nn.Module):def __init__(self):super(Model, self).__init__()self.fc1 = nn.Linear(784, 548)self.bc1 = nn.BatchNorm1d(548) self.fc2 = nn.Linear(548, 252)self.bc2 = nn.BatchNorm1d(252)self.fc3 = nn.Linear(252, 10) def forward(self, x):a = x.view((-1, 784))b = self.fc1(a)b = self.bc1(b)b = F.relu(b)b = F.dropout(b, p=0.5) b = self.fc2(b)b = self.bc2(b)b = F.relu(b)b = F.dropout(b, p=0.2)b = self.fc3(b)out = F.log_softmax(b)return out
model = Model()
model.cuda()
optimizer = optim.SGD(model.parameters(), lr=0.001)
在下面的代码片段中,我们将训练我们的模型,在训练时,我们将指定损失函数,即交叉熵。
model.train()
losses = []
for epoch in range(12):for batch_idx, (data,data_1) in enumerate(train_loader):data,data_1 = Variable(data.cuda()), Variable(target.cuda())optimizer.zero_grad()y_pred = model(data) loss = F.cross_entropy(y_pred, target)losses.append(loss.data[0])loss.backward()optimizer.step()if batch_idx % 100 == 1:print('\r Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset), 100. * batch_idx / len(train_loader),loss.data[0]), end='') print()
#评估模型evaluate=Variable(test_loader.dataset.test_data.type_as(torch.FloatTensor())).cuda()
output = model(evaluate)
predict = output.data.max(1)[1]
pred = pred.eq(evaluate.data)
accuracy = pred.sum()/pred.size()[0]
print('Accuracy:', accuracy)
Caffe
Caffe(Convolutional Architecture for Fast Feature Embedding)是Yangqing Jia开发的开源深度学习框架。该框架支持人工智能领域的研究人员和工业应用。
大部分开发者使用Caffe是因为它的速度,它使用一个NVIDIA K40 GPU每天可以处理6000万张图像。Caffe有很多贡献者来更新和维护框架,而且与深度学习的其他领域相比,Caffe在计算机视觉模型方面工作得很好。
Caffe的局限性
Caffe没有更高级别的API,所以很难做实验。
在Caffe中,为了部署我们的模型,我们需要编译源代码。
安装Caffe
!apt install -y caffe-tools-cpu
导入所需的库
import os
import numpy as np
import math
import caffe
import lmdb
在下面的代码片段中,我们将指定硬件环境。
os.environ["GLOG_minloglevel"] = '2'
CAFFE_ROOT="/caffe"
os.chdir(CAFFE_ROOT)
USE_GPU = True
if USE_GPU:caffe.set_device(0)caffe.set_mode_gpu()
else:caffe.set_mode_cpu()
caffe.set_random_seed(1)
np.random.seed(24)
在下面的代码片段中,我们将定义有助于数据转换的image_generator和batch_generator 。
def image_generator(db_path):db_handle = lmdb.open(db_path, readonly=True) with db_handle.begin() as db:cur = db.cursor() for _, value in cur: datum = caffe.proto.caffe_pb2.Datum()datum.ParseFromString(value) int_x = caffe.io.datum_to_array(datum) x = np.asfarray(int_x, dtype=np.float32) tyield x - 128 def batch_generator(shape, db_path):gen = image_generator(db_path)res = np.zeros(shape) while True: for i in range(shape[0]):res[i] = next(gen) yield res
在下面的代码片段中,我们将给出MNIST数据集的路径。
num_epochs = 0
iter_num = 0
db_path = "content/mnist/mnist_train_lmdb"
db_path_test = "content/mnist/mnist_test_lmdb"
base_lr = 0.01
gamma = 1e-4
power = 0.75for epoch in range(num_epochs):print("Starting epoch {}".format(epoch))input_shape = net.blobs["data"].data.shapefor batch in batch_generator(input_shape, db_path):iter_num += 1net.blobs["data"].data[...] = batchnet.forward()for name, l in zip(net._layer_names, net.layers):for b in l.blobs:b.diff[...] = net.blob_loss_weights[name]net.backward()learning_rate = base_lr * math.pow(1 + gamma * iter_num, - power)for l in net.layers:for b in l.blobs:b.data[...] -= learning_rate * b.diffif iter_num % 50 == 0:print("Iter {}: loss={}".format(iter_num, net.blobs["loss"].data))if iter_num % 200 == 0:print("Testing network: accuracy={}, loss={}".format(*test_network(test_net, db_path_test)))
使用下面的代码片段,我们将获得最终的准确性。
print("Training finished after {} iterations".format(iter_num))
print("Final performance: accuracy={}, loss={}".format(*test_network(test_net, db_path_test)))
结论
在本文中,我们演示了使用三个著名框架:Keras、PyTorch和Caffe实现CNN图像分类模型的。我们可以看到,PyTorch开发的CNN模型在精确度和速度方面都优于在Keras和Caffe开发的CNN模型。
作为一个初学者,我一开始使用Keras,这对于初学者是一个非常简单的框架,但它的应用是有限的。但是PyTorch和Caffe在速度、优化和并行计算方面是非常强大的框架。
原文链接:https://analyticsindiamag.com/keras-vs-pytorch-vs-caffe-comparing-the-implementation-of-cnn/
欢迎关注磐创AI博客站: http://panchuang.net/
sklearn机器学习中文官方文档: http://sklearn123.com/
欢迎关注磐创博客资源汇总站: http://docs.panchuang.net/
Keras vs PyTorch vs Caffe:CNN实现对比相关推荐
- 【深度学习】Keras vs PyTorch vs Caffe:CNN实现对比
作者 | PRUDHVI VARMA 编译 | VK 来源 | Analytics Indiamag 在当今世界,人工智能已被大多数商业运作所应用,而且由于先进的深度学习框架,它非常容易部署.这些深度 ...
- 深度学习框架对决篇:Keras VS PyTorch
来源:机器之心 参与:杜伟.一鸣 Keras和PyTorch之争由来已久.一年前,机器之心就曾做过此方面的探讨:<Keras vs PyTorch:谁是「第一」深度学习框架?>.现在PyT ...
- TensorFlow和Caffe、MXNet、Keras等其他深度学习框架的对比
2019独角兽企业重金招聘Python工程师标准>>> TensorFlow和Caffe.MXNet.Keras等其他深度学习框架的对比 博客分类: 深度学习 Google 近日发布 ...
- PyTorch、TensorFlow最新版本对比,2021年了你选谁?
选自towards data science 作者:Mostafa Ibrahim 机器之心编译 编辑:陈萍 PyTorch(1.8)和Tensorflow(2.5)最新版本比较. 自深度学习重新获得 ...
- Keras vs PyTorch:谁是第一深度学习框架?
「第一个深度学习框架该怎么选」对于初学者而言一直是个头疼的问题.本文中,来自 deepsense.ai 的研究员给出了他们在高级框架上的答案.在 Keras 与 PyTorch 的对比中,作者还给出了 ...
- Keras vs PyTorch,哪一个更适合做深度学习?
选自Medium 作者:Karan Jakhar 机器之心编译 参与:小舟.魔王 如何选择工具对深度学习初学者是个难题.本文作者以 Keras 和 Pytorch 库为例,提供了解决该问题的思路. 当 ...
- 【前沿】何恺明大神ICCV2017最佳论文Mask R-CNN的Keras/TensorFlow/Pytorch 代码实现
我们提出了一个概念上简单.灵活和通用的用于目标实例分割(object instance segmentation)的框架.我们的方法能够有效地检测图像中的目标,同时还能为每个实例生成一个高质量的分割掩 ...
- 2_初学者快速掌握主流深度学习框架Tensorflow、Keras、Pytorch学习代码(20181211)
初学者快速掌握主流深度学习框架Tensorflow.Keras.Pytorch学习代码 一.TensorFlow 1.资源地址: 2.资源介绍: 3.配置环境: 4.资源目录: 二.Keras 1.资 ...
- python cnn_使用python中pytorch库实现cnn对mnist的识别
使用python中pytorch库实现cnn对mnist的识别 1 环境:Anaconda3 64bit https://www.anaconda.com/download/ 2 环境:pycharm ...
最新文章
- 查看显卡显存_3d渲染需要多大显存比较合适?显存在渲染中的作用
- vue下拉框值改变_vue select下拉框绑定值不跟着变问题
- Java多线程和并发(三),Thread类和Runnable接口
- 第一章MCS-51单片机结构,单片机原理、接口及应用
- 【LINUX/UNIX网络编程】之使用消息队列,信号量和命名管道实现的多进程服务器(多人群聊系统)...
- leetcode147 对链表进行插入排序
- c语言 bool_程序的数据要放到哪里呢?|C语言第二篇
- 消除ie上的:为了有利于保护安全性,IE已限制此网页运行可以访问计算机的脚本或 ActiveX 控件...
- 图灵奖得主华人高徒发布首款AI芯片!64位RISC-V、高度可编程,低功耗
- CentOS 7下安装集群Zookeeper-3.4.9
- 标准库Allocator的简易实现(二)
- 三相滤波器怎么接线_三相电源滤波器作用 详解三相电源滤波器
- 对潇潇暮雨洒江天,一番洗清秋。渐霜风凄紧,关河冷落,残照当楼。是处红衰翠减,苒苒物华休。唯有长江水,无语东流。不忍登高临远,望故乡渺邈,归思难收。叹年来踪迹,何事苦淹留?想佳人,妆楼颙望,误几回、天际
- APP离线后,通过SystemClock.elapsedRealtime()校正时间
- 利用js快速完成大学生新生安全教育课程
- 微信小程录制视频上传服务器,微信小程序-从相册获取图片,视频使用相机拍照,录像上传+服务器nodejs版接收-微信小程序视频上传功能-微信小程序视频上传...
- 论文公式编辑比较麻烦,试试截图快速识别并编辑公式!
- (附源码)springboot中北创新创业官网 毕业设计 271443
- Spring之JDBC
- 恶意代码机理与防护笔记
热门文章
- 数据结构(二叉树相关、满、完全二叉树、霍夫曼树、排序方法及时间复杂度总结、)笔记-day11
- matplotlib绘制3D图形时使x轴、y轴、z轴的比例相等
- QT中级(6)基于QT的文件传输工具(2)
- 「小程序」——————————swiper引入图片不显示解决方案
- STC15系列读取DS18B20温度传感器串口显示代码
- Oracle 释放表高水位线(HWM)的五种方法
- 时钟系统:CPU为啥需要时钟;此时钟非彼时钟,时钟到底是啥玩意
- copyproperties(copyproperties用法)
- 吉他指弹入门——贝斯(walking bass)
- 偷服务器的空调维修工人,偷学!8个空调维修工不愿公开的技巧