FM

参数数量和时间复杂度优化

当我们使用一阶原始特征和二阶组合特征来刻画样本的时候,会得到如下式子:

y ^ = w 0 + ∑ i = 1 n w i x i + ∑ i = 1 n − 1 ∑ j = i + 1 n w i j x i x j \hat{y}=w_{0}+\sum_{i=1}^{n} w_{i} x_{i}+\sum_{i=1}^{n-1} \sum_{j=i+1}^{n} w_{i j} x_{i} x_{j} y^=w0+i=1nwixi+i=1n1j=i+1nwijxixj

x i x_i xix j x_j xj 分别表示两个不同的特征取值,对于 n n n 维的特征来说,这样的二阶组合特征一共有 n ( n − 1 ) 2 \frac{n(n-1)}{2} 2n(n1) 种,也就意味着我们需要同样数量的权重参数。但是由于现实场景中的特征是高维稀疏的,导致 n n n 非常大,比如上百万,这里两两特征组合的特征量级 C n 2 C_n^2 Cn2 ,所带来的参数量就是一个天文数字。对于一个上百亿甚至更多参数空间的模型来说,我们需要海量训练样本才可以保证完全收敛。这是非常困难的。

FM解决这个问题的方法非常简单,它不再是简单地为交叉之后的特征对设置参数,而是设置了一种计算特征参数的方法。

FM模型引入了新的矩阵 V V V ,它是一个 n × k n \times k n×k 的二维矩阵。这里的 k k k 是超参,一般不会很大,比如16、32之类。对于特征每一个维度 x i x_i xi ,我们都可以找到一个表示向量 v i ∈ R k v_i \in R^k viRk 。从NLP的角度来说,就是为每个特征学习一个embedding。原先的参数量从 O ( n 2 ) O(n^2) O(n2) 降低到了 O ( k × n ) O(k \times n) O(k×n) 。ALBERT论文的因式分解思想跟这个非常相似: O ( V × H ) ⋙ O ( V × E + E × H ) O(V \times H) \ggg O(V \times E + E \times H) O(V×H)O(V×E+E×H)

有了 V V V 矩阵,上式就可以改写成如下形式:
y ^ = w 0 + ∑ i = 1 n w i x i + ∑ i = 1 n − 1 ∑ j = 1 n v i T v j x i x j \hat{y}=w_{0}+\sum_{i=1}^{n} w_{i} x_{i}+\sum_{i=1}^{n-1} \sum_{j=1}^{n} v_{i}^{T} v_{j} x_{i} x_{j} y^=w0+i=1nwixi+i=1n1j=1nviTvjxixj
k k k 足够大的时候,即 k = n k = n k=n ,那么就有 W = V W = V W=V 。在实际的应用场景当中,我们并不需要设置非常大的K,因为特征矩阵往往非常稀疏,我们可能没有足够多的样本来训练这么大量的参数,并且限制K也可以一定程度上提升FM模型的泛化能力

此外这样做还有一个好处就是有利于模型训练,因为对于有些稀疏的特征组合来说,我们所有的样本当中可能都是空的。比如在刚才的例子当中用户A和电影B的组合,可能用户A在电影B上就没有过任何行为,那么这个数据就是空的,我们也不可能训练出任何参数来。但是引入了 V V V 之后,虽然这两项缺失,但是我们针对用户A和电影B分别训练出了向量参数,我们用这两个向量参数点乘,就得到了这个交叉特征的系数。

虽然我们将模型的参数降低到了 O ( k × n ) O(k \times n) O(k×n) ,但预测一条样本所需要的时间复杂度仍为 O ( k × n 2 ) O(k \times n^2) O(k×n2) ,这仍然是不可接受的。所以对它进行优化也是必须的,并且这里的优化非常简单,可以直接通过数学公式的变形推导得到:
∑ i = 1 n ∑ j = i + 1 n v i T v j x i x j = 1 2 ∑ i = 1 n ∑ j = 1 n v i T v j x i x j − 1 2 ∑ i = 1 n v i T v j x i x j = 1 2 ( ∑ i = 1 n ∑ j = 1 n ∑ f = 1 k v i , f v j , f x i x j − ∑ i = 1 n ∑ f = 1 k v i , f v i , f x i x i ) = 1 2 ∑ f = 1 k ( ( ∑ i = 1 n v i , f x i ) ( ∑ j = 1 n v j , f x j ) − ∑ i = 1 n v i , f 2 x i 2 ) = 1 2 ∑ f = 1 k ( ( ∑ i = 1 n v i , f x i ) 2 − ∑ i = 1 n v i , f 2 x i 2 ) \begin{aligned} \sum_{i=1}^{n} \sum_{j=i+1}^{n} v_{i}^{T} v_{j} x_{i} x_{j} &=\frac{1}{2} \sum_{i=1}^{n} \sum_{j=1}^{n} v_{i}^{T} v_{j} x_{i} x_{j}-\frac{1}{2} \sum_{i=1}^{n} v_{i}^{T} v_{j} x_{i} x_{j} \\ &=\frac{1}{2}\left(\sum_{i=1}^{n} \sum_{j=1}^{n} \sum_{f=1}^{k} v_{i, f} v_{j, f} x_{i} x_{j}-\sum_{i=1}^{n} \sum_{f=1}^{k} v_{i, f} v_{i, f} x_{i} x_{i}\right) \\ &=\frac{1}{2} \sum_{f=1}^{k}\left(\left(\sum_{i=1}^{n} v_{i, f} x_{i}\right)\left(\sum_{j=1}^{n} v_{j, f} x_{j}\right)-\sum_{i=1}^{n} v_{i, f}^{2} x_{i}^{2}\right) \\ &=\frac{1}{2} \sum_{f=1}^{k}\left(\left(\sum_{i=1}^{n} v_{i, f} x_{i}\right)^{2}-\sum_{i=1}^{n} v_{i, f}^{2} x_{i}^{2}\right) \end{aligned} i=1nj=i+1nviTvjxixj=21i=1nj=1nviTvjxixj21i=1nviTvjxixj=21i=1nj=1nf=1kvi,fvj,fxixji=1nf=1kvi,fvi,fxixi=21f=1k((i=1nvi,fxi)(j=1nvj,fxj)i=1nvi,f2xi2)=21f=1k(i=1nvi,fxi)2i=1nvi,f2xi2

FM模型预测的时间复杂度优化到了 O ( k × n ) O(k \times n) O(k×n) .

模型训练

优化过后的式子如下:
y ^ = w 0 + ∑ i = 1 n w i x i + 1 2 ∑ f = 1 k ( ( ∑ i = 1 n v i , f x i ) 2 − ∑ i = 1 n v i , f 2 x i 2 ) \hat{y}=w_{0}+\sum_{i=1}^{n} w_{i} x_{i}+\frac{1}{2} \sum_{f=1}^{k}\left(\left(\sum_{i=1}^{n} v_{i, f} x_{i}\right)^{2}-\sum_{i=1}^{n} v_{i, f}^{2} x_{i}^{2}\right) y^=w0+i=1nwixi+21f=1k(i=1nvi,fxi)2i=1nvi,f2xi2
针对FM模型我们一样可以使用梯度下降算法来进行优化。既然要使用梯度下降,那么我们就需要写出模型当中所有参数的偏导,主要分为三个部分:

  • w 0 w_0 w0 : ∂ θ ∂ w 0 = 1 \frac{\partial \theta}{\partial w_{0}}=1 w0θ=1
  • ∑ i = 1 n w i x i \sum_{i=1}^{n} w_{i} x_{i} i=1nwixi : ∂ 0 ∂ w i = x i \frac{\partial 0}{\partial w_{i}}=x_{i} wi0=xi
  • 1 2 ∑ f = 1 k ( ( ∑ i = 1 n v i , f x i ) 2 − ∑ i = 1 n v i , f 2 x i 2 ) \frac{1}{2} \sum_{f=1}^{k}\left(\left(\sum_{i=1}^{n} v_{i, f} x_{i}\right)^{2}-\sum_{i=1}^{n} v_{i, f}^{2} x_{i}^{2}\right) 21f=1k((i=1nvi,fxi)2i=1nvi,f2xi2) : ∂ y ^ ∂ v i , f = 1 2 ( 2 x i ( ∑ j = 1 n v j , f x j ) − 2 v i , f x i 2 ) = x i ∑ j = 1 n v j , f x j − v i , f x i 2 \frac{\partial \hat{y}}{\partial v_{i, f}} = \frac{1}{2} (2x_i (\sum_{j=1}^{n} v_{j, f} x_{j}) - 2v_{i,f} x_i^2) = x_{i} \sum_{j=1}^{n} v_{j, f} x_{j}-v_{i, f} x_{i}^{2} vi,fy^=21(2xi(j=1nvj,fxj)2vi,fxi2)=xij=1nvj,fxjvi,fxi2

综合如下:
∂ y ^ ∂ θ = { 1 , if  θ is  w 0 x i , if  θ is  w i x i ∑ j = 1 n v j , f x j − v i , f x i 2 if  θ is  v i , f \frac{\partial \hat{y}}{\partial \theta}= \begin{cases}1, & \text { if } \theta \text { is } w_{0} \\ x_{i}, & \text { if } \theta \text { is } w_{i} \\ x_{i} \sum_{j=1}^{n} v_{j, f} x_{j}-v_{i, f} x_{i}^{2} & \text { if } \theta \text { is } v_{i, f}\end{cases} θy^=1,xi,xij=1nvj,fxjvi,fxi2ifθisw0ifθiswiifθisvi,f
由于 ∑ j = 1 n v j , f x j \sum_{j=1}^n v_{j,f} x_j j=1nvj,fxj 是可以提前计算好存储起来的,因此我们对所有参数的梯度计算也都能在 O ( 1 ) O(1) O(1) 时间复杂度内完成。

拓展到 d d d

参照刚才的公式,可以写出FM模型推广到d维的方程:
y ^ = w 0 + ∑ i = 1 n w i x i + ∑ l = 2 d ∑ i 1 = 1 n − l + 1 ⋯ ∑ i l = i l − 1 + 1 n ( Π j − 1 l x i j ) ( ∑ f = 1 k Π j = 1 l v i j , f l ) \hat{y}=w_{0}+\sum_{i=1}^{n} w_{i} x_{i}+\sum_{l=2}^{d} \sum_{i_1=1}^{n-l+1} \cdots \sum_{i_{l}=i_{l-1}+1}^{n}\left(\Pi_{j-1}^{l} x_{i_{j}}\right)\left(\sum_{f=1}^{k} \Pi_{j=1}^{l} v_{i_{j}, f}^{l}\right) y^=w0+i=1nwixi+l=2di1=1nl+1il=il1+1n(Πj1lxij)f=1kΠj=1lvij,fl
d = 3 d=3 d=3 为例,上式为:
y ^ = w 0 + ∑ i = 1 n w i x i + ∑ i = 1 n − 1 ∑ j = i + 1 n x i x j ( ∑ t = 1 k v i , t v j , t ) + ∑ i = 1 n − 2 ∑ j = i + 1 n − 1 ∑ l = j + 1 n x i x j x l ( ∑ t = 1 k v i , t v j , t v l , t ) \hat{y}=w_{0}+\sum_{i=1}^{n} w_{i} x_{i} + \sum_{i=1}^{n-1} \sum_{j=i+1}^{n} x_{i} x_{j}\left(\sum_{t=1}^{k} v_{i, t} v_{j, t}\right)+\sum_{i=1}^{n-2} \sum_{j=i+1}^{n-1} \sum_{l=j+1}^{n} x_{i} x_{j} x_{l}\left(\sum_{t=1}^{k} v_{i, t} v_{j, t} v_{l, t}\right) y^=w0+i=1nwixi+i=1n1j=i+1nxixj(t=1kvi,tvj,t)+i=1n2j=i+1n1l=j+1nxixjxl(t=1kvi,tvj,tvl,t)
它的复杂度是 O ( k × n d ) O(k \times n^d) O(k×nd) 。当 d = 2 d=2 d=2 的时候,我们通过一系列变形将它的复杂度优化到了 O ( k × n ) O(k \times n) O(k×n) 。而当 d > 2 d > 2 d>2 的时候,没有很好的优化方法,而且三重特征的交叉往往没有意义,并且会过于稀疏,所以我们一般情况下只会使用 d = 2 d=2 d=2 的情况。

最佳实践

import torch
from torch import nnndim = len(feature_names)  # 原始特征数量
k = 4class FM(nn.Module):def __init__(self, dim, k):super(FM, self).__init__()self.dim = dimself.k = kself.w = nn.Linear(self.dim, 1, bias=True)# 初始化V矩阵self.v = nn.Parameter(torch.rand(self.dim, self.k) / 100)def forward(self, x):linear = self.w(x)# 二次项quadradic = 0.5 * torch.sum(torch.pow(torch.mm(x, self.v), 2) - torch.mm(torch.pow(x, 2), torch.pow(self.v, 2)))# 套一层sigmoid转成分类模型,也可以不加,就是回归模型return torch.sigmoid(linear + quadradic)fm = FM(ndim, k)
loss_fn = nn.BCELoss()
optimizer = torch.optim.SGD(fm.parameters(), lr=0.005, weight_decay=0.001)
iteration = 0
epochs = 10for epoch in range(epochs):fm.train()for X, y in data_iter:output = fm(X)l = loss_fn(output.squeeze(dim=1), y)optimizer.zero_grad()l.backward()optimizer.step()iteration += 1        if iteration % 200 == 199:with torch.no_grad():fm.eval()output = fm(X_eva_tensor)l = loss_fn(output.squeeze(dim=1), y_eva_tensor)acc = ((torch.round(output).long() == y_eva_tensor.view(-1, 1).long()).sum().float().item()) / 1024print('Epoch: {}, iteration: {}, loss: {}, acc: {}'.format(epoch, iteration, l.item(), acc))fm.train()

DeepFM

y ^ = sigmoid ⁡ ( y F M + y D N N ) \hat{y}=\operatorname{sigmoid}\left(y_{F M}+y_{D N N}\right) y^=sigmoid(yFM+yDNN)

FM

该组件就是在计算FM:
y F M = ⟨ w , x ⟩ + ∑ j 1 = 1 d ∑ j 2 = j 1 + 1 d ⟨ V i , V j ⟩ x j 1 ⋅ x j 2 y_{F M}=\langle w, x\rangle+\sum_{j_{1}=1}^{d} \sum_{j_{2}=j_{1}+1}^{d}\left\langle V_{i}, V_{j}\right\rangle x_{j_{1}} \cdot x_{j_{2}} yFM=w,x+j1=1dj2=j1+1dVi,Vjxj1xj2
注意不是: w 0 + ∑ i = 1 n w i x i + 1 2 ∑ f = 1 k ( ( ∑ i = 1 n v i , f x i ) 2 − ∑ i = 1 n v i , f 2 x i 2 ) w_{0}+\sum_{i=1}^{n} w_{i} x_{i}+\frac{1}{2} \sum_{f=1}^{k}\left(\left(\sum_{i=1}^{n} v_{i, f} x_{i}\right)^{2}-\sum_{i=1}^{n} v_{i, f}^{2} x_{i}^{2}\right) w0+i=1nwixi+21f=1k((i=1nvi,fxi)2i=1nvi,f2xi2)

  • 每个 F i e l d Field Field 是one-hot形式,黄色的圆表示 1 1 1 ,蓝色的代表 0 0 0
  • 连接黄色圆的黑线就是在做: ⟨ w , x ⟩ \langle w, x\rangle w,x
  • 连接embedding的红色线就是在做: ∑ j 1 = 1 d ∑ j 2 = j 1 + 1 d ⟨ V i , V j ⟩ x j 1 ⋅ x j 2 \sum_{j_{1}=1}^{d} \sum_{j_{2}=j_{1}+1}^{d}\left\langle V_{i}, V_{j}\right\rangle x_{j_{1}} \cdot x_{j_{2}} j1=1dj2=j1+1dVi,Vjxj1xj2

DNN

DNN部分比较简单,但它是与FM部分共享Embedding的。


参考

  • 原创 | 想做推荐算法?先把FM模型搞懂再说
  • DeepFM模型CTR预估理论与实战
  • 深度推荐模型之DeepFM
  • 吃透论文——推荐算法不可不看的DeepFM模型

FM DeepFM相关推荐

  1. Wide Deep、DeepFM系列算法原理与优缺点对比

    一.Wide & Deep模型 Wide & Deep Learning 模型的核心思想是结合广义线性模型的记忆能力(memorization)和深度前馈神经网络模型的泛化能力(gen ...

  2. pytorch 测试每一类_DeepFM全方面解析(附pytorch源码)

    写在前面 最近看了DeepFM这个模型.把我学习的思路和总结放上来给大家和未来的自己做个参考和借鉴.文章主要希望能串起学习DeepFM的各个环节,梳理整个学习思路.以"我"的角度浅 ...

  3. 搜索推荐系统实战:进化篇

    搜索推荐系统实战篇-中篇 一切源于炼丹笔记,我只是敲了敲代码. 搜索推荐系统实战:起始篇 搜索推荐系统实战:进化篇 搜索推荐系统实战:终极奥秘 此处我们假设模型训练的大的框架已经固定,同时数据采样的方 ...

  4. 推荐系统学习(一)推荐系统分类与基本流程

    文章目录 推荐系统框架 1. 推荐系统分类 2. 推荐系统基本流程 推荐系统框架 1. 推荐系统分类 基于统计学 个性化推荐:一切即标签 基于推荐原则的分类 基于相似度 基于知识 基于模型 基于数据源 ...

  5. NLP模型集锦----pynlp

    github 地址 目录 1.Introduction 2.Our Model 2.1 CTR 2.1.1 Models List 2.1.2 Convolutional Click Predicti ...

  6. 特征交互新路线|阿里 Co-action Network论文解读

    最近看到阿里的新工作在公众号上突然流行起来,自己也没忍住去认真拜读了一下,确实是好文.按照自己的理解对论文做了粗浅的解读. 这篇文章主要介绍周国睿大佬的新工作:CAN: Revisiting Feat ...

  7. 推荐系统中的召回和排序

    在推荐系统中一般会分为召回和排序两个阶段: 召回 召回的目标是从千万级甚至亿级的候选中召回几千个item,召回一般由多路组成,每一路会有不同的侧重点(优化目标),如在广告中成熟期广告和冷启动广告分为两 ...

  8. 盘点智能风控中的机器学习技术

    前言 生命里面碰到了很多愿意无偿帮助我,教导我的同事和领导.他们有的给我技术上的帮助,有的给我工作上的宽容,有的给我自由发挥的机会.也许未来所完成的每一件事情,都是他们力量汇聚的结果.而我所做的,不过 ...

  9. 60分钟吃掉三杀模型FiBiNET

    神经网络的结构设计有3个主流的高级技巧: 1,高低融合 (将高层次特征与低层次特征融合,提升特征维度的丰富性和多样性,像人一样同时考虑整体和细节) 2,权值共享 (一个权值矩阵参与多个不同的计算,降低 ...

最新文章

  1. 十年程序员的告诫:千万不要重写代码!
  2. SuperSocket+unity 网络笔记
  3. jenkins 插件目录_三十二张图告诉你如何用Jenkins构建SpringBoot
  4. maya崩溃自动保存路径_maya 使用swig将插件编译成pyd,无缝使用内置数据实现加速计算模块...
  5. IDA动态调试Android的DEX文件
  6. Cuckoo Hashing
  7. maven不引入parent_Maven从入门到放弃
  8. Android报错:FAILED:_nl_intern_locale_data: ?? ‘cnt < (sizeof (_nl_value_type_LC_TIME)
  9. java web 中Integer.valueof()与integer.parseint()
  10. 实现二叉树各种遍历算法
  11. softmgr主程序_SoftMgrBase.dll
  12. 华工计算机工图答案,华南理工 网络画法几何及工程制图-课程习题集答案
  13. ora 01033 解决
  14. html中怎样写渐变色代码,html颜色渐变代码 怎样用css实现网页背景颜色渐变
  15. 海天老师 资深TTT/思维训练专家
  16. 爬虫学习:爬取京东图书
  17. 2022 Gartner RPA魔力象限发布,两家国产厂商入选,超自自动化成重点
  18. Window自带的定时自动执行程序
  19. Java后端技术框架
  20. 计算机通信原理知识点,《计算机通信原理与技术》.pdf

热门文章

  1. 幼儿抽象逻辑思维举例_2岁多的孩子,需要锻炼逻辑思维吗?
  2. CAD中ColorIndex索引对应的颜色及RGB值
  3. 【学习方式】开源项目
  4. 【Linux】C语言缓冲区、缓冲区的实现
  5. 华为云桌面随时随地,开启云上办公
  6. 工具 好用的一些windows工具,包括git、录屏、ps、navicat等等,后续会持续更新
  7. 20_clickhouse,硬件管理与优化(cpu,内存,网络,存储,操作系统配置),profile管理,Quotas设置,约束管理,查询权限,用户管理配置等
  8. hadoop实现求共同好友
  9. java 增长的极限_下列关于《增长的极限》报告中的论述说法正确的是()。
  10. 凭什么同窗好友Java开发都是三年,他能进大厂,工资还是我的双倍?