iVoid's Blog

Stay hungry, Stay foolish

概率值校正

大多数的分类模型,得到的预测结果仅有定序意义,而不能够定量。很多情况下,仅仅得到一个好的AUC值是远远不够的,我们需要得到一个准确的概率值。这就要求,模型的输出结果从定序上升为定距。 有两种方法可以实现由定序到定距:普拉托变换(Platt Scaling)和保序回归(Isotonic Regression).Platt Scaling的适用条件较为严格,他仅适用于被扭曲的预测结果是sigmoid的模型;Isotonic Regression的适用条件较为宽松,它只要预测结果是单调的。不幸的是,力量需要代价:相比Platt Scaling, Isotonic Regression更容易过拟合,尤其是当训练数据集稀少的时候。

Platt Scaling

普拉托在1999年提出可以通过sigmoid函数来讲SVM的预测结果转化为一个后验的概率值。主要分为三步:

1)sigmoid变换

假设SVM的输出结果为, 为了得到校准之后的概率值,我们对进行变换

其中参数A和B为参数。

2)采用极大似然法求解A和B

假设 为模型的预测值,为真实结果,则对于训练集,极大似然函数为

其中。 为了计算方便,我们对极大似然函数取对数,并将变为,则原问题变为

3) 为了防止过拟合,对y值进行校正

为了防止过拟合,在极大似然函数中,y值并不是简单的0或者1. 假设训练集中有个正样本,个负样本,则普拉托变换采用 代替1和0.

Isotonic Regression

对于给定的训练集,其中为模型的预测值(为正样本的概率),为真实分类。Isotonic Regression寻找变换,使得

Isotonic Regression的一个最为广泛的实现是Pool Adjacent Violators算法,简称PAV算法,算法流程如下图。

PAV算法伪代码

时间复杂度,空间复杂度

下图为PAV算法在15个样本(6个负样本,9个正样本)上的运行示例。

PAV算法运行实例

我们预估得到的概率值必须具有单调性,即Score值越大,预估概率值也应当越大,PAV算法的主要思想就是通过不断合并、调整违反单调性的局部区间,使得最终得到的区间满足单调性。具体过程如下: 1. PAV算法首先将所有样本按照Score的值由大到小降序排列;然后,将每个样本划分为15个独立的区间

此时,每个区间内只有一个样本,对包含负样本(0)的区间赋予概率值0,对包含正样本(1)的区间赋予概率值1;

  1. PAV由底向上()寻找相邻的两个违反单调性的区间,第一对违反单调性的区间出现在两个区间,PAV将这两个区间合并,并采用两个区间中较小的序号为新建区间重新命名,记为,将新的内所有元素的概率值求平均,该区间内每个样本的预估概率值为1/2;

  2. PAV继续向后查找,发现下两个违反单调性的区间是,将这两个区间合并,命名为,内的每个元素的概率值现在变为1/3。

  3. 重复上述动作,直至剩下的所有区间都满足单调性要求。

附上一个我自己用Python实现的PAV。

 1 import sys
 2     class Node:
 3 	def __init__(self, start, end, prob, total_cases, next):
 4         	self.start = start
 5 		self.end = end
 6 		self.prob = prob
 7 		self.total_cases = total_cases
 8 		self.next = next
 9 
10     def main(train_file):
11     '''输入文件格式uid,待校验的概率值(比如正样本概率),真实结果(0或者1)
12     !!!ATTENTION!!! 需要输入文件已经按照待校验的概率值排序'''
13 	head = None
14 	for line in open(train_file,'r'):
15 		vec = line.split('\t')
16 		if not head:
17 			head = Node(float(vec[0]), float(vec[0]), float(vec[1]), 1, None)
18 			last_node = head
19 		else:
20 			last_node.next = Node(float(vec[0]), float(vec[0]), float(vec[1]), 1, None)
21 			last_node = last_node.next
22 
23 	completed = False
24 	while not completed:
25 		completed = True
26 		iter = head
27 		while iter.next:
28 			if iter.prob >= iter.next.prob:
29 				iter.end = iter.next.end
30 				iter.prob = (iter.prob * iter.total_cases + iter.next.prob * iter.next.total_cases) /(iter.total_cases + iter.next.total_cases)
31 				iter.total_cases = iter.total_cases + iter.next.total_cases
32 				iter.next = iter.next.next
33 				completed = False
34 				break
35 			iter = iter.next
36 
37 	iter = head
38 	while iter:
39 		print iter.start, iter.end, iter.prob
40 		iter = iter.next
blog comments powered by Disqus