Focal Loss——从直觉到实现
问题(Why?)
做机器学习分类问题,难免遇到Biased-Data-Problem,例如
- CV的目标检测问题: 绝大多数检测框里都是 backgroud
- NLP的异常文本检测: 绝大多数文本都是 normal
对此,以下套路可以缓解:
- 升/降采样, 或者调整样本权重
- 换个更鲁棒的loss函数 ,或者加正则
- 集成模型: Bagging,RandomForest …
- 利于外部先验知识:预训练 + 微调
- 多任务联合学习(multi-task,joint learning)
- … (以上概念纯属经验总结,既不完备也不互斥)
今天要聊的就是一种针对该问题精心设计的loss函数——Focal Loss
。
现状
先来回顾一下常用的 BinaryCrossEntropyLoss
公式如下:
不难看出,CE是个“笨学生”。
考前复习的时候,他不会划重点,对所有知识点 “一视同仁”。
如果教科书上有100道例题,包括: 90道加减乘除 + 10道 三角函数。CE同学就会吭哧吭哧的“平均用力”反复练习这100道例题,结果可想而知——他会精通那90道个位数加减乘除题目,然后其他题目基本靠蒙。那10道他不会的题,往往还是分值高的压轴题。
嗯,大概就是这么个症状。
解决办法
机智如你,想必已经有办法了 —— 给他指个方向,别再“平均用力”就好了
方法一、分科复习
每个【科目】的难度是不同的; 你要花 30%的精力在四则运算,70%的精力在三角函数。 — 老师告诉CE同学 第一个技巧
对应到公式中,就是针对每个类别赋予不同的权重,即下述$\alpha_{t}$:
这是个简单粗暴有效的办法。
方法二、刷题战术
每道【题目】的难度是不同的; 你要根据以往刷类似题时候的正确率来合理分配精力。
— 老师告诉CE同学 第二个技巧
观察CE中的$p_{t}$,它反映了模型对这个样本的识别能力(即 “这个知识点掌握得有多好”);显然,对于$p_t$越大的样本,我们越要打压它对loss的贡献。
FL是这么定义的:
这里有个超参数$\gamma$; 直观来看,$\gamma$越大,打压越重。如下图所示:
- 横轴是$p_t$, 纵轴是$FL(p_t)$
- 总体来说,所有曲线都是单调下降的,即 “掌握越好的知识点越省力”
- 当 $\gamma = 0$ 时,FL退化成CE,即蓝色线条
- 当 $\gamma$ 很大时,线条逐步压低到绿色位置,即各样本对于总loss的贡献受到打压;中间靠右区段承压尤其明显
方法三、综合上述两者
代码
Keras实现
from keras import backend as K
def focal_loss(alpha=0.75, gamma=2.0):
""" 参考 https://blog.csdn.net/u011583927/article/details/90716942 """
def focal_loss_fixed(y_true, y_pred):
# y_true 是个一阶向量, 下式按照加号分为左右两部分
# 注意到 y_true的取值只能是 0或者1 (假设二分类问题),可以视为“掩码”
# 加号左边的 y_true*alpha 表示将 y_true中等于1的槽位置为标量 alpha
# 加号右边的 (ones-y_true)*(1-alpha) 则是将等于0的槽位置为 1-alpha
ones = K.ones_like(y_true)
alpha_t = y_true*alpha + (ones-y_true)*(1-alpha)
# 类似上面,y_true仍然视为 0/1 掩码
# 第1部分 `y_true*y_pred` 表示 将 y_true中为1的槽位置为 y_pred对应槽位的值
# 第2部分 `(ones-y_true)*(ones-y_pred)` 表示 将 y_true中为0的槽位置为 (1-y_pred)对应槽位的值
# 第3部分 K.epsilon() 避免后面 log(0) 溢出
p_t = y_true*y_pred + (ones-y_true)*(ones-y_pred) + K.epsilon()
# 就是公式的字面意思
focal_loss = -alpha_t * K.pow((ones-p_t),gamma) * K.log(p_t)
return focal_loss_fixed
model = ...
model.compile(..., loss=focal_loss(gamma=3, alpha=0.5))
PyTorch二分类
class BCEFocalLoss(torch.nn.Module):
def __init__(self, gamma=2, alpha=0.25, reduction='mean'):
super(BCEFocalLoss, self).__init__()
self.gamma = gamma
self.alpha = alpha
self.reduction = reduction
def forward(self, predict, target):
pt = torch.sigmoid(predict) # sigmoide获取概率
#在原始ce上增加动态权重因子,注意alpha的写法,下面多类时不能这样使用
loss = - self.alpha * (1 - pt) ** self.gamma * target * torch.log(pt)
- (1 - self.alpha) * pt ** self.gamma * (1 - target) * torch.log(1 - pt)
if self.reduction == 'mean':
loss = torch.mean(loss)
elif self.reduction == 'sum':
loss = torch.sum(loss)
return loss
PyTorch多分类
class MultiCEFocalLoss(torch.nn.Module):
def __init__(self, class_num, gamma=2, alpha=None, reduction='mean'):
super(MultiCEFocalLoss, self).__init__()
if alpha is None:
self.alpha = Variable(torch.ones(class_num, 1))
else:
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
def forward(self, predict, target):
pt = F.softmax(predict, dim=1) # softmmax获取预测概率
class_mask = F.one_hot(target, 5) #获取target的one hot编码
ids = target.view(-1, 1)
alpha = self.alpha[ids.data.view(-1)] # 注意,这里的alpha是给定的一个list(tensor
#),里面的元素分别是每一个类的权重因子
probs = (pt * class_mask).sum(1).view(-1, 1) # 利用onehot作为mask,提取对应的pt
log_p = probs.log()
# 同样,原始ce上增加一个动态权重衰减因子
loss = -alpha * (torch.pow((1 - probs), self.gamma)) * log_p
if self.reduction == 'mean':
loss = loss.mean()
elif self.reduction == 'sum':
loss = loss.sum()
return loss
onehot
也可以用下面三行代码自己实现onehot
ids = target.view(-1, 1)
onehot =torch.zeros_like(P)
onehot.scatter_(1, ids.data, 1.)
调参经验
- $\alpha \in (0,1)$反映了“方法一、分科复习”时,各科目的难度比率;
- 二分类场景下,类似于正例的
sample_weight
概念,可以按照样本占比,适度加权 - e.g. 设有5条正例、95条负例,则建议 $\alpha = 0.95$
- 取 $\alpha = 0.5$ 相当于关掉该功能
- 二分类场景下,类似于正例的
- $\gamma \in [0,+\infty)$ 反映了 “方法二、刷题战术”时,对于难度的区分程度
- 取 $\gamma = 0$ 相当于关掉该功能; 即不考虑难度区别,一视同仁
- $\gamma$ 越大,则越重视难度,即专注于比较困难的样本。建议在 $(0.5,10.0)$ 范围尝试
总结
- 机器学习分类问题中,各类别样本数差距悬殊是很常见的情况;这会干扰模型效果
- 通过将CrossEntropyLoss替换为综合版的FocalLoss,可以有效缓解上述问题
- 具体思路就是引入两个额外的变量来区分对待每个样本
- $\alpha_t$根据类别加权
- $(1-p_t)^{\gamma}$根据难度加权
- 代码实现很简单、调参也不复杂,详见上文