nDCG笔记及在spark中的实现
目录
- 0. 前言
- 1. 原理
- 2. 步骤
- 2.1 计算CG
- 2.2 计算DCG
- 2.3 计算nDCG
- 3. 本地代码实现
- 3.1 自己编写代码
- 3.2 使用sklearn.metrics.ndcg_score
- 3.3 两种代码的速度比较
- 3.4 两种代码之间的误差
- 4. spark代码实现
0. 前言
之前在排序项目中用到了nDCG,其离线指标对模型训练及上线具有一定的参考意义,在此做一个总结。因为计算是在集群的hive表进行的,所以加上了spark上的计算代码。
1. 原理
- 高度相关的文档在搜索引擎结果列表的前面显示时更有用。
- 高度相关的文档比不相关的文档更有用,勉强相关的文件比不相关的文件更有用
2. 步骤
有以下例子:
位置(iii) | 预测值排序 | 真实的相关性分(relirel_ireli) | 折损值(log2(i+1)log_2(i+1)log2(i+1)) | 折损后的相关性(reli/log2(i+1)rel_i/log_2(i+1)reli/log2(i+1)) |
---|---|---|---|---|
1 | D1 | 3 | 1 | 3 |
2 | D2 | 2 | 1.585 | 1.262 |
3 | D3 | 3 | 2 | 1.5 |
4 | D4 | 0 | 2.322 | 0 |
5 | D5 | 1 | 2.585 | 0.387 |
6 | D6 | 2 | 2.807 | 0.712 |
2.1 计算CG
CG(Cumulative Gain)是结果列表中所有 items 的相关性得分的总和。CGk\text{CG}_kCGk是前 kkk 个 items 的相关性得分的总和。
CGk=∑i=1kreli\text{CG}_k=\sum_{i=1}^{k}rel_i CGk=i=1∑kreli
比如上面表格中的 CG1=3\text{CG}_1=3CG1=3, CG2=5\text{CG}_2=5CG2=5, CG3=8\text{CG}_3=8CG3=8…
缺点:CG 忽略了位次的重要性,比如CG3=8\text{CG}_3=8CG3=8的序列有{(3,3,2), (3,2,3), (2,3,3)},但最优的序列是将最大值都排在前面的序列,如(3,3,2)。
2.2 计算DCG
DCG(Discounted CG),折损累计收益。因为CG的缺点是不能区分位次,所以将位次作为折损,位次越靠后,折损越大,所以DCG的计算为:
DCGk=∑i=1krelilog2(i+1)\text{DCG}_k=\sum_{i=1}^{k}{\frac {rel_i} {log_2(i+1)}} DCGk=i=1∑klog2(i+1)reli
2.3 计算nDCG
通常在排序中,nDCG(normalized DCG)是使用最多的。即对DCG进行归一化,归一化就是理想的排序结果,即相关性最大的排在前面,其DCG成为IDCG(Ideal DCG)。
nDCGk=DCGIDCG\text{nDCG}_k={\frac {DCG} {IDCG}} nDCGk=IDCGDCG
3. 本地代码实现
3.1 自己编写代码
def get_dcg(y_true, y_scores, k):'''@pred_arr: 预测顺序的相关度,ndarray,shape=(n,1)@gt_arr: 实际顺序相关度,ndarray,shape=(n,1)@k: 要计算的最大位置,int'''arr = [x[1] for x in sorted(zip(y_scores, y_true), reverse=True)[:k]]weights = np.power(np.log2(range(2,len(arr)+2)), -1)dcg = np.sum(arr*weights)return dcgdef get_ndcg(y_true, y_scores, k):dcg = get_dcg(y_true, y_scores, k)idcg = get_dcg(sorted(y_true), np.arange(len(y_true)), k)return dcg/idcg
3.2 使用sklearn.metrics.ndcg_score
使用方式可参考我的博文sklearn.metrics模块重要API的原理与应用总结 的“DGC和nDCG”一节。
3.3 两种代码的速度比较
随机生成不同数量的样本,查看两个函数的运行时长。
m = 100 # 样本量
y_true = np.random.randint(0,1000000,m).reshape(1,-1)
y_scores = np.random.randint(0,1000000,m).reshape(1,-1)k = 10 # 前k个最高排名的ndcg
ndcg1 = get_ndcg(y_true[0], y_scores[0], k=k)
ndcg2 = ndcg_score(y_true, y_scores, k=k)# 在jupyter中执行
"""
%%timeit
get_ndcg(y_true[0], y_scores[0], k=k)%%timeit
ndcg_score(y_true, y_scores, k=k)
"""
样本量(m) | 自己的代码 | sklearn |
---|---|---|
10000 | 22.7 ms | 5.61 ms |
5000 | 10.5 ms | 3.06 ms |
3000 | 5.99 ms | 1.93 ms |
2000 | 3.79 ms | 1.44 ms |
1000 | 1.81 ms | 0.95 ms |
500 | 896 µs | 711 µs |
400 | 717 µs | 639 µs |
350 | 631 µs | 618 µs |
300 | 547 µs | 588 µs |
250 | 462 µs | 558 µs |
100 | 212 µs | 458 µs |
可以看到,当样本数量较大时,使用sklearn提供的函数速度较。值得注意的是,自己写的代码还没有优化,与复杂度与数据量大致成线性增长,可以使用前k大个数算法进行优化,进一步缩短时长。
3.4 两种代码之间的误差
两种代码之间的误差主要来自于对相同y_scores的值不同的排序位次导致的,比如:
y_scores = [0.9, 0.5, 0.6, 0.9, 0.9]
y_true = [7, 4, 1, 0, 0]
根据y_scores,y_true的排序方式及相应的ndcg为:
相关性排序 | Value |
---|---|
[7,0,0,1,4] | 0.896 |
[0,7,0,1,4] | 0.638 |
[0,0,7,1,4] | 0.547 |
如果y_score保留小数点后6位的话,两个方法的误差将在10−610^{-6}10−6数量级。
4. spark代码实现
(1) import 所需要的库
import os
os.environ["PYSPARK_PYTHON"] = "/usr/bin/python3.6.5"
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
spark=SparkSession.builder.appName("appname_202211301352").enableHiveSupport().getOrCreate()
from pyspark.sql.window import Window
import numpy as np
import pandas as pd
(2) 创建spark DataFrame
y_true = np.random.random(10000).reshape(1,-1)
y_score = np.random.random(10000).reshape(1,-1)d1 = {"query":["三亚", "成都", "西安", "舟山"]*2500, 'target':y_true[0], 'score':y_score[0]}
dfA = spark.createDataFrame(pd.DataFrame(d1))
dfA.show(8)"""
输出:
+-----+-------------------+-------------------+
|query| target| score|
+-----+-------------------+-------------------+
| 三亚| 0.5639595701241265| 0.5871782502891383|
| 成都| 0.3831924220224111|0.21391983705716255|
| 西安| 0.8023402614068417| 0.2101881628533363|
| 舟山| 0.8276250390471849| 0.1300697976081132|
| 三亚| 0.8631646566534542|0.09625928182309162|
| 成都| 0.6041016696702783|0.19275281212842987|
| 西安| 0.8804418925939791| 0.6438387067088728|
| 舟山|0.43794448416652887| 0.6674715678386118|
+-----+-------------------+-------------------+
"""
(3) 计算nDCG
def get_ndcg(dfA, k=10):"""dfA: spark DataFrame, 需要包含target和score两个字段"""# 分别求出每个query按target和score进行排序的位次window1 = Window.partitionBy("query").orderBy([F.col('target').desc()])window2 = Window.partitionBy("query").orderBy([F.col('score').desc()])dfB = dfA.withColumn("target_rank", F.row_number().over(window1)) \.withColumn("score_rank", F.row_number().over(window2))# 每个query下的dcg和ndcgdf_dcg = dfB.filter(F.col("score_rank")<=k).groupby("query").agg(F.sum(F.col("target")/F.log2(F.col("score_rank")+1)).alias("dcg"))df_idcg= dfB.filter(F.col("target_rank")<=k).groupby("query").agg(F.sum(F.col("target")/F.log2(F.col("target_rank")+1)).alias("idcg"))df_dcg = df_dcg.withColumnRenamed("query", "query1")cond=[df_dcg.query1==df_idcg.query]df_ndcg = df_dcg.join(df_idcg, on=cond, how='inner') \.select("query", (F.col("dcg")/F.col("idcg")).alias("ndcg"))return df_ndcgdf_ndcg = get_ndcg(dfA, k=10)
df_ndcg.show()"""
输出:
+-----+------------------+
|query| ndcg|
+-----+------------------+
| 三亚|0.6060123512248438|
| 成都|0.4464340922149389|
| 舟山|0.5963748392641535|
| 西安|0.6608559450094559|
+-----+------------------+"""
完。
nDCG笔记及在spark中的实现相关推荐
- <极客时间:零基础入门Spark> 学习笔记(持续更新中...)
看的是极客时间的课,讲得很不错 零基础入门 Spark (geekbang.org) 基础知识 01 Spark:从"大数据的Hello World"开始 准备工作 IDEA安装S ...
- spark中stage的划分与宽依赖/窄依赖(转载+自己理解/整理)
[1]宽依赖和窄依赖,这是Spark计算引擎划分Stage的根源所在,遇到宽依赖,则划分为多个stage,针对每个Stage,提交一个TaskSet: 上图:一张网上的图: (个人笔记,rdd中有多个 ...
- Spark中如何使用矩阵运算间接实现i2i
本文主要包含以下几部分: 1.背景 2.Spark支持的数据类型 2.1 Local Vector(本地向量) 2.2 Labeled point(带标签的点) 2.3 Local Matrix(本地 ...
- 尚硅谷大数据技术Spark教程-笔记01【Spark(概述、快速上手、运行环境、运行架构)】
视频地址:尚硅谷大数据Spark教程从入门到精通_哔哩哔哩_bilibili 尚硅谷大数据技术Spark教程-笔记01[Spark(概述.快速上手.运行环境.运行架构)] 尚硅谷大数据技术Spark教 ...
- Spark中mapToPair和flatMapToPair的区别【附示例源码及运行结果】
本文重点介绍 Spark 中 [mapToPair]和[flatMapToPair]的区别,请继续看到尾部,后续有示例说明,会理解更加清晰. 函数原型 1.JavaPairRDD<K2,V2&g ...
- Spark中的内存计算是什么?
由于计算的融合只发生在 Stages 内部,而 Shuffle 是切割 Stages 的边界,因此一旦发生 Shuffle,内存计算的代码融合就会中断. 在 Spark 中,内存计算有两层含义: 第一 ...
- Java查询spark中生成的文件_java+spark-sql查询excel
Spark官网下载Spark 下载Windows下Hadoop所需文件winutils.exe 同学们自己网上找找吧,这里就不上传了,其实该文件可有可无,报错也不影响Spark运行,强迫症可以下载,本 ...
- Spark中Task,Partition,RDD、节点数、Executor数、core数目(线程池)、mem数
Spark中Task,Partition,RDD.节点数.Executor数.core数目的关系和Application,Driver,Job,Task,Stage理解 from:https://bl ...
- Spark中常用的算法
Spark中常用的算法: 3.2.1 分类算法 分类算法属于监督式学习,使用类标签已知的样本建立一个分类函数或分类模型,应用分类模型,能把数据库中的类标签未知的数据进行归类.分类在数据挖掘中是一项重要 ...
最新文章
- JavaScript一步一步:JavaScript 对象和HTML DOM 对象
- 和rna用什么鉴定_RNA-seq:测序原理之文库构建
- boost::coroutine2模块实现相同的边缘的测试程序
- django-oscar接入paypal的时候提示Error 10001 - Internal Error
- Android左右声道控制软件,Android左右声道的控制
- [收藏]REST -维基百科
- Annotation版本的HelloWorld
- 【BZOJ - 3993】星际战争(网络流最大流+二分)
- MySQL 创建用户
- 我的内容管理系统(CMS)寻找历程 -- Mambo出鞘,谁与争锋?
- 区块链教程Fabric1.0源代码分析流言算法Gossip服务端一兄弟连区块链教程
- 计算机小写换大写函数,Excel函数公式应用:小写数字转换成人民币大写9种方法-excel技巧-电脑技巧收藏家...
- 科学计算机中溢出是指,算术溢出
- matlab 扫雷小游戏
- fabs linux头文件,fabs(c语言fabs函数用法求精度)
- python3手动配置环境变量
- Rebbitmq-3-SpringBoot整合
- 1 海康视觉平台VisionMaster 上手系列: 开篇
- JAVA使用OUTLOOK发送邮件[451 5.7.3 STARTTLS is required to send mail]
- Ubuntu 16.04安装Zimbra邮件服务器
热门文章
- Android插件化的思考——仿QQ一键换肤,思考比实现更重要!
- 汤姆大叔的深入理解JavaScript读后感三(设计模式篇)
- 动态数组_栈的应用之十进制与十六进制的转换
- 图片预加载的几种方式
- mysql 连表查询_mysql数据库之联表查询
- JSP环境美容服务公司网站
- WordPress是什么?我也想用 WordPress~
- linux-文件授权命令chmod
- vn的可变数据类型_casting - 是否有任何编程语言可以禁止对返回类型进行类转换? - SO中文参考 - www.soinside.com...
- 如何通过快照进行数据备份?