Stay hungry, Stay foolish
大多数的分类模型,得到的预测结果仅有定序意义,而不能够定量。很多情况下,仅仅得到一个好的AUC值是远远不够的,我们需要得到一个准确的概率值。这就要求,模型的输出结果从定序上升为定距。 有两种方法可以实现由定序到定距:普拉托变换(Platt Scaling)和保序回归(Isotonic Regression).Platt Scaling的适用条件较为严格,他仅适用于被扭曲的预测结果是sigmoid的模型;Isotonic Regression的适用条件较为宽松,它只要预测结果是单调的。不幸的是,力量需要代价:相比Platt Scaling, Isotonic Regression更容易过拟合,尤其是当训练数据集稀少的时候。
普拉托在1999年提出可以通过sigmoid函数来讲SVM的预测结果转化为一个后验的概率值。主要分为三步:
假设SVM的输出结果为, 为了得到校准之后的概率值,我们对进行变换
其中参数A和B为参数。
假设 为模型的预测值,为真实结果,则对于训练集,极大似然函数为
其中。 为了计算方便,我们对极大似然函数取对数,并将变为,则原问题变为
为了防止过拟合,在极大似然函数中,y值并不是简单的0或者1. 假设训练集中有个正样本,个负样本,则普拉托变换采用 和 代替1和0.
对于给定的训练集,其中为模型的预测值(为正样本的概率),为真实分类。Isotonic Regression寻找变换,使得
Isotonic Regression的一个最为广泛的实现是Pool Adjacent Violators算法,简称PAV算法,算法流程如下图。
时间复杂度,空间复杂度
下图为PAV算法在15个样本(6个负样本,9个正样本)上的运行示例。
我们预估得到的概率值必须具有单调性,即Score值越大,预估概率值也应当越大,PAV算法的主要思想就是通过不断合并、调整违反单调性的局部区间,使得最终得到的区间满足单调性。具体过程如下: 1. PAV算法首先将所有样本按照Score的值由大到小降序排列;然后,将每个样本划分为15个独立的区间
此时,每个区间内只有一个样本,对包含负样本(0)的区间赋予概率值0,对包含正样本(1)的区间赋予概率值1;
PAV由底向上()寻找相邻的两个违反单调性的区间,第一对违反单调性的区间出现在和两个区间,PAV将这两个区间合并,并采用两个区间中较小的序号为新建区间重新命名,记为,将新的内所有元素的概率值求平均,该区间内每个样本的预估概率值为1/2;
PAV继续向后查找,发现下两个违反单调性的区间是和,将这两个区间合并,命名为,内的每个元素的概率值现在变为1/3。
重复上述动作,直至剩下的所有区间都满足单调性要求。
附上一个我自己用Python实现的PAV。
1 import sys
2 class Node:
3 def __init__(self, start, end, prob, total_cases, next):
4 self.start = start
5 self.end = end
6 self.prob = prob
7 self.total_cases = total_cases
8 self.next = next
9
10 def main(train_file):
11 '''输入文件格式uid,待校验的概率值(比如正样本概率),真实结果(0或者1)
12 !!!ATTENTION!!! 需要输入文件已经按照待校验的概率值排序'''
13 head = None
14 for line in open(train_file,'r'):
15 vec = line.split('\t')
16 if not head:
17 head = Node(float(vec[0]), float(vec[0]), float(vec[1]), 1, None)
18 last_node = head
19 else:
20 last_node.next = Node(float(vec[0]), float(vec[0]), float(vec[1]), 1, None)
21 last_node = last_node.next
22
23 completed = False
24 while not completed:
25 completed = True
26 iter = head
27 while iter.next:
28 if iter.prob >= iter.next.prob:
29 iter.end = iter.next.end
30 iter.prob = (iter.prob * iter.total_cases + iter.next.prob * iter.next.total_cases) /(iter.total_cases + iter.next.total_cases)
31 iter.total_cases = iter.total_cases + iter.next.total_cases
32 iter.next = iter.next.next
33 completed = False
34 break
35 iter = iter.next
36
37 iter = head
38 while iter:
39 print iter.start, iter.end, iter.prob
40 iter = iter.next