cs231n 课程学习 二

cs231n 课程资源:Stanford University CS231n: Convolutional Neural Networks for Visual Recognition

我的 github 作业:FinCreWorld/cs231n: The assigments of cs231n (github.com)

第二章 线性分类器

本章介绍线性分类器,其包括两个主要的组成部分评价函数(score function)以及损失函数(loss function)

  • 评价函数:针对样例输入,输出其在不同分类上的分数,该分数为分类的依据
  • 损失函数:量化预测分数与真实标签(ground truth labels)的差距

最后我们将通过优化评价函数的参数,将损失函数最小化,从而将分类问题转化为优化问题。

一 构造评价函数

定义评价函数 f:RD→RKf:R^D\to R^Kf:RD→RK,构造从 DDD 维数据点到 KKK 维分类分数的映射。构造线性映射
f(xi,W,b)=Wxi+bf(x_i,W,b)=Wx_i+b f(xi​,W,b)=Wxi​+b
其中 xix_ixi​ 为 D×1D\times1D×1 维数据点,权重 WWW 为 K×DK\times DK×D 维矩阵,bbb 为 K×1K\times1K×1 维偏差向量。

  • WxiWx_iWxi​ 实质上是并行的 KKK 个分类器对 xix_ixi​ 进行预测,WWW 每一行都可以看做一个独立的分类器
  • 训练过程中,输入数据 (xi,yi)(x_i, y_i)(xi​,yi​) 是不可变的,我们的目的获取到合适的参数 W,bW,bW,b,使得针对输入数据,我们能够在正确的分类上获取到较高的分数
  • 该模型训练时间较长,预测时只需要将数据代入公式即可,因此预测时间较短

我们可以通过为每个数据最后一位添 1 的方式,将 bbb 合并入 WWW 中,即
KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲ f(x_i,W,b)&=(W…

二 线性分类器理解

模式匹配的角度

xix_ixi​ 实际上对应了一幅图片中所有像素点以及色彩通道的数值。对于不同的物体,其图像像素分布不同,颜色也不同。对于轮船,可能蓝色居多,对于鸟类,可能较多出现在图像的边缘。

过更改权重 WWW,我们可以更加侧重于采样某些位置的像素点以及某些颜色的像素点,如果某一幅图片在这些像素点中具有较高的值,那么该图像在该分类器中就会取得较高的分数。

如上图所示,权重 WWW 第一行用于分类出小猫,第二行用于分类出小狗,第三行用于分类出轮船,我们可以看出,小猫分类器更侧重于 xi1,xi4x_i^1,x_i^4xi1​,xi4​(这些位置权重为 0.2),更不喜欢 xi2x_i^2xi2​,当数据点 xi1,xi4x_i^1,x_i^4xi1​,xi4​ 有较大值,xi2x_i^2xi2​ 有较小值时,其在小猫类别上的得分就更高。

因此,权重 WWW 的每一行均为一个模板(template),我们将这些模板与输入图像进行匹配,学习到每个分类最佳的匹配模板。

上图就是我们学习到模板,可以看到 planeship 中蓝色居多,同时不同的图像称对称分布,这是因为我们的输入图像拥有不同的朝向。

进一步的,我们可以从最近邻分类器的角度来理解,我们学习到了 WWW,就相当于学习到了 KKK 种每个类别对应的模板,对于给定的图像,我们找出距离测试样例最近的模板,该模板对应的分类即为测试样例的分类。

高维空间的角度

我们可以把每个 DDD 维数据样例 xix_ixi​看做高维空间的一个点,每一个权重 WiW_iWi​ 决定了一个超平面,一侧是其所属类,一侧是其他类。

三 损失函数

我们通过损失函数来衡量预测结果与实际标签之间的不匹配程度,很自然的,如果预测结果与实际标签不匹配,那么损失函数的值就应该高,如果预测正确,损失函数的值就应该较低。

多分类支持向量机(Multiclass Support Vector Machine)

SVM 要求对于正确类别的预测分数至少要比错误类别的分类分数高 Δ\DeltaΔ。给定样例 xix_ixi​,我们计算出分类结果 f(xi,W)f(x_i,W)f(xi​,W),分类结果为 K×1K\times1K×1 维向量,f(xi,W)jf(x_i,W)^jf(xi​,W)j 为对于第 jjj 类的分类分数,我们令 sj=f(xi,W)js_j=f(x_i,W)^jsj​=f(xi​,W)j 代替,正确分类标签为 yiy_iyi​(注,如果有 KKK 个类别,那么 yi∈{0−9}y_i\in\{0-9\}yi​∈{0−9}),我们定义样例 xix_ixi​ 的预测损失为
Li=∑j≠yimax⁡(0,sj−syi+Δ)L_i=\sum_{j\neq y_i}\max{(0,s_j-s_{y_i}+\Delta)} Li​=j​=yi​∑​max(0,sj​−syi​​+Δ)
如果我们使用线性分类器 f(xi,W)=Wxif(x_i,W)=Wx_if(xi​,W)=Wxi​,那么我们可以写出如下式子
Li=∑j≠yimax⁡(0,wjxi−wyixi+Δ)L_i=\sum_{j\neq y_i}\max{(0,w_jx_i-w_{y_i}x_i+\Delta)} Li​=j​=yi​∑​max(0,wj​xi​−wyi​​xi​+Δ)
WWW 并不唯一,假设我们学习到了一个 WWW 使得任意 xix_ixi​ 都有 Li=0L_i=0Li​=0,那么对于 W′=λW(λ>1)W'=\lambda W(\lambda>1)W′=λW(λ>1) 来说, 也满足所有损失值为 0 的条件。因此我们需要对 WWW 进行限制,我们可以为损失函数增加正则惩罚项 R(W)。

最后我们得到了一个完整的损失函数,包括数据损失(data loss)和正则损失(regularization loss)
L=1N∑iLi⏟data loss+λR(W)⏟regularization lossL=\underbrace{\frac{1}{N}\sum_{i}L_i}_{\text{data loss}}+\underbrace{\lambda R(W)}_{\text{regularization loss}} L=data lossN1​i∑​Li​​​+regularization lossλR(W)​​
我们通常定义 R(W)R(W)R(W) 为 WWW 的 L2 范数,即 R(W)=∑k∑lWk,l2R(W)=\sum_k\sum_l W_{k,l}^2R(W)=∑k​∑l​Wk,l2​ 限制权值的增大,表示我们喜欢更小的权值。同时正则项的引入能够增加泛化精度,因为我们避免了某几个权重拥有较大的值。损失函数的展开形式为
L=1N∑i∑j≠yimax⁡(0,f(xi;W)j−f(xi;W)yi+Δ)+λ∑k∑lWk,lL=\frac{1}{N}\sum_i\sum_{j\neq y_i}\max{(0,f(x_i;W)_j-f(x_i;W)_{y_i}+\Delta)}+\lambda\sum_k\sum_l W_{k,l} L=N1​i∑​j​=yi​∑​max(0,f(xi​;W)j​−f(xi​;W)yi​​+Δ)+λk∑​l∑​Wk,l​

Softmax 分类器

对于样例 xix_ixi​,其分数 f(xi,W)f(x_i,W)f(xi​,W) 保持不变,但是我们将分数 fff 看做不规整的 log⁡\loglog 概率(unnormalized log probabilities),并且采用了交叉熵的形式,有
Li=−log⁡(efyi∑jefj)L_i=-\log{(\frac{e^{f_{y_i}}}{\sum{_j}e^{f_j}})} Li​=−log(∑j​efj​efyi​​​)

这里的 fjf_jfj​ 与上面 svm 的 sjs_jsj​ 表示含义一致

注意到,我们称 fj(x)=exj∑kexkf_j(x)=\frac{e^{x_j}}{\sum_ke^{x_k}}fj​(x)=∑k​exk​exj​​ 函数为 softmax 函数,该函数用于将向量 xxx 的所有分量上的值压缩到 0−10-10−1,并且总和为 111。

Softmax 分类器的解释
  • 信息论视角

    对于真实的事件分布 ppp 和估计的事件分布 qqq 之间的交叉熵定义为
    H(p,q)=−∑xp(x)log⁡q(x)H(p,q)=-\sum_{x}p(x)\log{q(x)} H(p,q)=−x∑​p(x)logq(x)
    Softmax 分类器用于最小化估计事件的概率与真实事件发生概率之间的交叉熵。估计事件概率为q=efyi∑jefjq=\frac{e^{f_{y_i}}}{\sum_je^{f_j}}q=∑j​efj​efyi​​​,而实际事件发生概率为 p=[0,...,1,...,0]p=[0,...,1,...,0]p=[0,...,1,...,0],其中 pyi=0p_{y_i}=0pyi​​=0。KL 散度用于衡量两种时间分布之间的距离,而 H(p,q)=H(p)+DKL(p∣∣q)H(p,q)=H(p)+D_{KL}(p||q)H(p,q)=H(p)+DKL​(p∣∣q),事件 ppp 的熵H(p)=0H(p)=0H(p)=0,因此有H(p,q)=DKL(p∣∣q)H(p,q)=D_{KL}(p||q)H(p,q)=DKL​(p∣∣q),将事件发生概率交叉熵之后可得 H(p,q)=DKL(p∣∣q)=−log⁡(efyi∑jefj)H(p,q)=D_{KL}(p||q)=-\log{(\frac{e^{f_{y_i}}}{\sum{_j}e^{f_j}})}H(p,q)=DKL​(p∣∣q)=−log(∑j​efj​efyi​​​)

  • 概率论视角

    详情参考聊一聊机器学习的MLE和MAP:最大似然估计和最大后验估计 - 知乎 (zhihu.com)

保持数据不溢出

由于 softmax 分类器中存在大量的指数操作,因此需要一些小技巧防止溢出,有
efyi∑jefj=CefyiC∑jefj=efyi+log⁡C∑jefj+log⁡C\frac{e^{f_{y_i}}}{\sum_je^{f_j}}=\frac{Ce^{f_{y_i}}}{C\sum_je^{f_j}}=\frac{e^{f_{y_i}+\log{C}}}{\sum_je^{f_j+\log{C}}} ∑j​efj​efyi​​​=C∑j​efj​Cefyi​​​=∑j​efj​+logCefyi​​+logC​
我们使 log⁡C=−max⁡fj\log{C}=-\max{f_j}logC=−maxfj​ 可以避免溢出

cs231n 课程学习 二相关推荐

  1. 转载CS231n课程学习笔记

    CS231n课程学习笔记 CS231n网易云课堂链接 CS231n官方笔记授权翻译总集篇发布 - 智能单元 - 知乎专栏 https://zhuanlan.zhihu.com/p/21930884 C ...

  2. 深度学习总结——CS231n课程深度学习(机器视觉相关)笔记整理

    深度学习笔记整理 说明 基本知识点一:模型的设置(基本) 1. 激活函数的设置 2. 损失函数的设置 (1) 分类问题 (2) 属性问题 (3) 回归问题 3. 正则化方式的设置 (1) 损失函数添加 ...

  3. 深度学习初学者推荐怎么在本地完成CS231n课程作业-配置环境

    近期学习cs231n课程,并准备做作业,整理一下整个过程以防忘记.也许会出一个系列. 课程推荐: 喜欢看视频的可看下面两个链接之一: 1.https://cloud.tencent.com/edu/l ...

  4. 武汉大学-黄如花-信息检索课程学习笔记二

    武汉大学-黄如花-信息检索课程学习笔记二 一.信息检索基本方法 1.布尔逻辑检索 2.临近检索 3.短语检索(精确检索) 4.截词检索 5.字段限制检索 6.区分大小写的检索 二.多种检索方法的综合运 ...

  5. CS231n课程笔记翻译:图像分类笔记(下)

    译者注:本文翻译自斯坦福CS231n课程笔记image classification notes,课程教师Andrej Karpathy授权翻译.本篇教程由杜客进行翻译,ShiqingFan和巩子嘉进 ...

  6. CS231n课程笔记翻译3:线性分类笔记

    译者注 :本文 智能单元 首发,译自斯坦福CS231n课程笔记 Linear Classification Note ,课程教师 Andrej Karpathy 授权翻译.本篇教程由 杜客 翻译完成, ...

  7. CS231n课程笔记翻译9:卷积神经网络笔记

    译者注:本文翻译自斯坦福CS231n课程笔记ConvNet notes,由课程教师Andrej Karpathy授权进行翻译.本篇教程由杜客和猴子翻译完成,堃堃和李艺颖进行校对修改. 原文如下 内容列 ...

  8. 斯坦福CS231N深度学习与计算机视觉

    又一重磅再获翻译授权,斯坦福CS231N深度学习与计算机视觉 https://zhuanlan.zhihu.com/intelligentunit http://toutiao.com/i631103 ...

  9. 如何在本地完成CS231n课程作业

    最近开始学习斯坦福大学的CS231n课程,课程地址:网易云课堂,只有中文字幕,现在学完了1-7课时,准备着手做一下第一次作业,但是第一次接触不免有些手忙脚乱,自己探索了半天,准备写一个教程给和我一样的 ...

最新文章

  1. php引用数据库实例,PHP单例模式实例,连接数据库对类的引用
  2. STM32F103外部晶振由8M变为12M
  3. xml simpleXML_load_file(), simpleXML_load_string()
  4. 阿里研究员:警惕软件复杂度困局
  5. 浅析Java的“克隆”方法[zt]
  6. python开源考试_Github 上 10 个值得学习的 Springboot 开源项目
  7. 复合索引字段的排序对搜素的影响
  8. 一文掌握GaussDB(DWS) SQL进阶技能:全文检索
  9. 学python必须得英语精通吗_“学习python必须精通的几个模块“
  10. java cropper_cropper 使用总结
  11. 第三方支付的发展趋势及优势
  12. 基于微信校园跑腿小程序毕业设计设计与实现毕设参考
  13. 普罗米修斯prometheus
  14. 怼天怼地的马斯克道歉了?
  15. 【干货】百度自动化运维是怎么做的?
  16. cisco路由器各接口模块代表的含义是什么
  17. access制作卡片_(2020年编辑)Access入门教程大全
  18. 蓝桥杯官网python组基础练习-基础1-5
  19. requests库(正则提取)爬取千图网
  20. 有关细粒度图像分析(Fine-Grained Image Analysis)

热门文章

  1. JS基础语法,if分支
  2. 韩顺平SQL,雇员系统表.txt
  3. 安装facenet环境及N卡GPU驱动
  4. 指针 字符串的复制(函数)
  5. iOS界面设计,12个优秀案例激发你的灵感
  6. Linux网络编程-很全面
  7. 为何说Android ViewDragHelper是神器 (二)
  8. 这才是互联网赚钱的正确姿势!你学会了吗?
  9. 记录:百度前端技术学院任务笔记(一)
  10. Web自动化测试06