目录

01 前情摘要

1.1 导包

1.2 特征提取以及数据集的建立

02 建立模型

2.1 深度学习框架

2.1.1 网络结构搭建

2.1.2 搭建CNN网络

2.1.3 CNN基础知识

03 CNN模型训练与测试

3.1 模型训练

3.2 预测测试集


01 前情摘要

前面讲解了音频数据的分析以及特征提取等内容,本次任务主要是讲解CNN模型的搭建与训练,由于模型训练需要用到之前的特侦提取等得让,于是在此再贴一下相关代码。

1.1 导包

#基本库
import pandas as pd
import numpy as np
pd.plotting.register_matplotlib_converters()
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import MinMaxScaler#深度学习框架
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Flatten, Dense, MaxPool2D, Dropout
from tensorflow.keras.utils import to_categorical
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
import tensorflow as tf
import tensorflow.keras#音频处理库
import os
import librosa
import librosa.display
import glob 

1.2 特征提取以及数据集的建立

feature = []
label = []
# 建立类别标签,不同类别对应不同的数字。
label_dict = {'aloe': 0, 'burger': 1, 'cabbage': 2,'candied_fruits':3, 'carrots': 4, 'chips':5,'chocolate': 6, 'drinks': 7, 'fries': 8, 'grapes': 9, 'gummies': 10, 'ice-cream':11,'jelly': 12, 'noodles': 13, 'pickles': 14, 'pizza': 15, 'ribs': 16, 'salmon':17,'soup': 18, 'wings': 19}
label_dict_inv = {v:k for k,v in label_dict.items()}

建立提取音频特征的函数

from tqdm import tqdm
def extract_features(parent_dir, sub_dirs, max_file=10, file_ext="*.wav"):c = 0label, feature = [], []for sub_dir in sub_dirs:for fn in tqdm(glob.glob(os.path.join(parent_dir, sub_dir, file_ext))[:max_file]): # 遍历数据集的所有文件# segment_log_specgrams, segment_labels = [], []#sound_clip,sr = librosa.load(fn)#print(fn)label_name = fn.split('/')[-2]label.extend([label_dict[label_name]])X, sample_rate = librosa.load(fn,res_type='kaiser_fast')mels = np.mean(librosa.feature.melspectrogram(y=X,sr=sample_rate).T,axis=0) # 计算梅尔频谱(mel spectrogram),并把它作为特征feature.extend([mels])return [feature, label]
# 自己更改目录
parent_dir = './train_sample/'
save_dir = "./"
folds = sub_dirs = np.array(['aloe','burger','cabbage','candied_fruits','carrots','chips','chocolate','drinks','fries','grapes','gummies','ice-cream','jelly','noodles','pickles','pizza','ribs','salmon','soup','wings'])# 获取特征feature以及类别的label
temp = extract_features(parent_dir,sub_dirs,max_file=100)
temp = np.array(temp)
data = temp.transpose()
# 获取特征
X = np.vstack(data[:, 0])# 获取标签
Y = np.array(data[:, 1])
print('X的特征尺寸是:',X.shape)
print('Y的特征尺寸是:',Y.shape)
X的特征尺寸是: (1000, 128)
Y的特征尺寸是: (1000,)
# 在Keras库中:to_categorical就是将类别向量转换为二进制(只有0和1)的矩阵类型表示
Y = to_categorical(Y)'''最终数据'''
print(X.shape)
print(Y.shape)
(1000, 128)
(1000, 20)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, random_state = 1, stratify=Y)
print('训练集的大小',len(X_train))
print('测试集的大小',len(X_test))
训练集的大小 750
测试集的大小 250
X_train = X_train.reshape(-1, 16, 8, 1)
X_test = X_test.reshape(-1, 16, 8, 1)

02 建立模型

2.1 深度学习框架

Keras 是一个用 Python 编写的高级神经网络 API,它能够以 TensorFlow, CNTK, 或者 Theano 作为后端运行。现在Keras已经和TensorFlow合并,可以通过TensorFlow来调用。

2.1.1 网络结构搭建

Keras 的核心数据结构是 model,一种组织网络层的方式。最简单的模型是 Sequential 顺序模型,它由多个网络层线性堆叠。对于更复杂的结构,你应该使用 Keras 函数式 API,它允许构建任意的神经网络图。

Sequential模型可以直接通过如下方式搭建:

from keras.models import Sequential

model = Sequential()

model = Sequential()

2.1.2 搭建CNN网络

# 输入的大小
input_dim = (16, 8, 1)

2.1.3 CNN基础知识

推荐的资料中,我们推荐大家去看看李宏毅老师的讲的CNN网络这里也附上老师的PPT。

CNN网络的基本架构

卷积神经网络CNN的结构一般包含这几个层:

1)输入层:用于数据的输入

2)卷积层:使用卷积核进行特征提取和特征映射------>可以多次重复使用

3)激励层:由于卷积也是一种线性运算,因此需要增加非线性映射(也就是激活函数)

4)池化层:进行下采样,对特征图稀疏处理,减少数据运算量----->可以多次重复使用

5)Flatten操作:将二维的向量,拉直为一维的向量,从而可以放入下一层的神经网络中

6)全连接层:通常在CNN的尾部进行重新拟合,减少特征信息的损失----->DNN网络

对于Keras操作中,可以简单地使用 .add() ,将需要搭建的神经网络的layer堆砌起来,像搭积木一样:

model.add(Conv2D(64, (3, 3), padding = "same", activation = "tanh", input_shape = input_dim))# 卷积层
model.add(MaxPool2D(pool_size=(2, 2)))# 最大池化
model.add(Conv2D(128, (3, 3), padding = "same", activation = "tanh")) #卷积层
model.add(MaxPool2D(pool_size=(2, 2))) # 最大池化层
model.add(Dropout(0.1))
model.add(Flatten()) # 展开
model.add(Dense(1024, activation = "tanh"))
model.add(Dense(20, activation = "softmax")) # 输出层:20个units输出20个类的概率

如果需要,你还可以进一步地配置你的优化器.complies())。Keras 的核心原则是使事情变得相当简单,同时又允许用户在需要的时候能够进行完全的控制(终极的控制是源代码的易扩展性)。

# 编译模型,设置损失函数,优化方法以及评价标准
model.compile(optimizer = 'adam', loss = 'categorical_crossentropy', metrics = ['accuracy'])

03 CNN模型训练与测试

3.1 模型训练

批量的在之前搭建的模型上训练:

# 训练模型
model.fit(X_train, Y_train, epochs = 90, batch_size = 50, validation_data = (X_test, Y_test))

查看网络的统计信息

model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d (Conv2D)              (None, 16, 8, 64)         640
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 8, 4, 64)          0
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 8, 4, 128)         73856
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 4, 2, 128)         0
_________________________________________________________________
dropout (Dropout)            (None, 4, 2, 128)         0
_________________________________________________________________
flatten (Flatten)            (None, 1024)              0
_________________________________________________________________
dense (Dense)                (None, 1024)              1049600
_________________________________________________________________
dense_1 (Dense)              (None, 20)                20500
=================================================================
Total params: 1,144,596
Trainable params: 1,144,596
Non-trainable params: 0
_________________________________________________________________

3.2 预测测试集

新的数据生成预测

def extract_features(test_dir, file_ext="*.wav"):feature = []for fn in tqdm(glob.glob(os.path.join(test_dir, file_ext))[:]): # 遍历数据集的所有文件X, sample_rate = librosa.load(fn,res_type='kaiser_fast')mels = np.mean(librosa.feature.melspectrogram(y=X,sr=sample_rate).T,axis=0) # 计算梅尔频谱(mel spectrogram),并把它作为特征feature.extend([mels])return feature

保存预测的结果

X_test = extract_features('./test_a/')X_test = np.vstack(X_test)
predictions = model.predict(X_test.reshape(-1, 16, 8, 1))preds = np.argmax(predictions, axis = 1)
preds = [label_dict_inv[x] for x in preds]path = glob.glob('./test_a/*.wav')
result = pd.DataFrame({'name':path, 'label': preds})result['name'] = result['name'].apply(lambda x: x.split('/')[-1])
result.to_csv('submit.csv',index=None)
!ls ./test_a/*.wav | wc -l
2000
!wc -l submit.csv
2001 submit.csv

【语音识别】食物声音识别(四)音频数据特征提取相关推荐

  1. 零基础入门语音识别-食物声音识别[Task 3]

    Task3 食物声音识别之音频数据特征提取 Task1 食物声音识别之Baseline学习 Task2 食物声音识别之赛题数据介绍与分析 1 特征提取背景 在Task2中,我们已经了解了我们需要识别的 ...

  2. 零基础入门语音识别-食物声音识别[Task 1]

    Task1 食物声音识别之Baseline学习 作为零基础入门语音识别的新人赛,本次任务不涉及复杂的声音模型.语言模型,希望大家通过两种baseline的学习能体验到语音识别的乐趣. 任务说明:我们提 ...

  3. 天池学习赛 -【零基础入门语音识别-食物声音识别】Task1 食物声音识别-Baseline【代码详细手写解释】

    文章目录 一.Task1 食物声音识别-Baseline 二.对应解析 三.参考链接 一.Task1 食物声音识别-Baseline 天池对应代码链接 二.对应解析 三.参考链接 tqdm的解释 深度 ...

  4. 使用Sinc卷积从原始音频数据进行轻量级的端到端语音识别

    论文: Lightweight End-to-End Speech Recognition from Raw Audio Data Using Sinc-Convolutions 摘要: 许多端到端自 ...

  5. 【组队学习】【24期】零基础入门语音识别(食物声音识别)

    零基础入门语音识别(食物声音识别) 开源内容: https://github.com/datawhalechina/team-learning-nlp/tree/master/FoodVoiceRec ...

  6. Android 音频开发(四) 如何播放一帧音频数据下

    再看这一篇文章前,如果你是小白,我建议你先看一下Android 音频开发(一) 基础入门篇这一篇.今天继续讲解如何通过Android SDK自带API实现播放一帧音频数据. 我们都知道,Android ...

  7. 音频数据建模全流程代码示例:通过讲话人的声音进行年龄预测

    来源:DeepHub IMBA 本文约6100字,建议阅读10+分钟 本文展示了从EDA.音频预处理到特征工程和数据建模的完整源代码演示. 大多数人都熟悉如何在图像.文本或表格数据上运行数据科学项目. ...

  8. python音频特征提取_使用Python对音频进行特征提取

    写在前面 因为喜欢玩儿音乐游戏,所以打算研究一下如何用深度学习的模型生成音游的谱面.这篇文章主要目的是介绍或者总结一些音频的知识和代码. 恩.如果没玩儿过的话,音乐游戏大概是下面这个样子. 下面进入正 ...

  9. 阿里天池比赛——食物声音识别

    阿里天池比赛--食物声音识别 最近写毕业论文无聊之余,再次参加阿里天池比赛,之前一直做CV,第一次尝试做语音识别,记录一下过程. 策略: 1.梅尔频谱和梅尔倒谱以及混合 2.多模型测试 想玩这个项目的 ...

最新文章

  1. 融合应用11.1.8安装,一步一步的引导
  2. python_day2基本数据类型
  3. C++设计模式——简单工厂模式
  4. 【JFreeChart】JFreeChart—输出柱形图
  5. 使用Jersey跨服务器上传图片 报405 Method Not Allowed错误
  6. 入库成本与目标成本对比报表中我学到的东西
  7. C语言中的文件是什么?
  8. CSS两栏布局之右栏布局
  9. OpenShift 4 - 如何用Machine Config Operator修改集群节点CoreOS的配置
  10. mysql 函数定义常量_php如何定义一个自定义常量
  11. 【操作系统/OS笔记20】打开文件、文件数据块分配、空闲空间管理、多磁盘管理(RAID)、磁盘调度算法概述
  12. Java使用BufferedImage修改图片内容
  13. ubuntu18使用wine安装TIM和微信
  14. Informatic学习总结_day01
  15. fedora 29 使用百度网盘客户端
  16. 利用lcx作端口映射
  17. 如何使用Reviewboard进行代码Review?
  18. FZU - 1759 Problem 1759 Super A^B mod C 欧拉降幂公式
  19. 如何在Google表格中直接使用Google翻译
  20. 纸本书变电子书是很小的事——詹宏志谈数字出版时代

热门文章

  1. 计算机控制系统在地铁应用,浅谈计算机技术在地铁通信系统中的应用
  2. 计算机 云 开发,云计算ppt-【ppt】介绍一种计算机新技术的基本原理、应用和发展情况。(如云计算、物联网、嵌入式软件设计开发等)...
  3. VR交互动画短片《拾梦老人》的开发经历
  4. Abaqus 2016 安装总结
  5. 笔记本计算机死机,一起可笑的笔记本电脑死机故障
  6. 网络和共享中心 服务器运行失败,网络和共享中心显示依赖服务或组无法启动导致无法上网(C15)...
  7. 逻辑回归原理以及推导
  8. 软件测试Mysql题库_软件测试面试常见数据库考题及答案
  9. 华南x79主板u盘装系统教程_英特尔X79主板怎么设置u盘启动
  10. wpsppt加载项在哪里_《wps表格加载项在哪里》 WPS版的EXCEL中 加载宏和数据分析在哪?...