Pytorch里的CrossEntropyLoss详解


在使用Pytorch时经常碰见这些函数cross_entropyCrossEntropyLoss, log_softmax, LogSoftmax。看得我头大,所以整理本文以备日后查阅。

首先要知道上面提到的这些函数一部分是来自于torch.nn,而另一部分则来自于torch.nn.functional(常缩写为F)。二者函数的区别可参见 知乎:torch.nn和funtional函数区别是什么?

下面是对与cross entropy有关的函数做的总结:

torch.nn torch.nn.functional (F)
CrossEntropyLoss cross_entropy
LogSoftmax log_softmax
NLLLoss nll_loss

下面将主要介绍torch.nn.functional中的函数为主,torch.nn中对应的函数其实就是对F里的函数进行包装以便管理变量等操作。

在介绍cross_entropy之前先介绍两个基本函数:

log_softmax

这个很好理解,其实就是logsoftmax合并在一起,同时执行。

nll_loss

该函数的全称是negative log likelihood loss,函数表达式为
$$
f(x,class)=−x[class]f(x,class)=−x[class]
$$
例如假设$x=[1,2,3],class=2$,那么$f(x,class)=−x[2]=−3$。

cross_entropy

交叉熵的计算公式为:
$$
cross_entropy=-\sum_{k=1}^{N}\left(p_{k} * \log q_{k}\right)
$$
其中$p$表示真实值,在这个公式中是one-hot形式;$q$是预测值,在这里假设已经是经过softmax后的结果了。

仔细观察可以知道,因为$p$的元素不是0就是1,而且又是乘法,所以很自然地我们如果知道1所对应的index,那么就不用做其他无意义的运算了。所以在PyTorch代码中target不是以one-hot形式表示的,而是直接用scalar表示。所以交叉熵的公式(m表示真实类别)可变形为:
$$
cross_entropy=-\sum_{k=1}^{N}\left(p_{k} * \log q_{k}\right)=-log , q_m
$$
仔细看看,是不是就是等同于log_softmaxnll_loss两个步骤。

所以Pytorch中的F.cross_entropy会自动调用上面介绍的log_softmaxnll_loss来计算交叉熵,其计算方式如下:
$$
\operatorname{loss}(x, \text {class})=-\log \left(\frac{\exp (x[\operatorname{class}])}{\sum_{j} \exp (x[j])}\right)
$$
代码示例:

input = torch.randn(3, 5, requires_grad=True)
target = torch.randint(5, (3,), dtype=torch.int64)
loss = F.cross_entropy(input, target)
loss.backward()

文章作者: CarlYoung
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 CarlYoung !
 上一篇
深度学习中的logits是什么? 深度学习中的logits是什么?
在深度学习编码的过程中,常常会遇见一些变量名叫做logits,这个logits到底指代了一个什么东西呢?查阅资料之后,我在Google的machine learning文档中找到了定义: LogitsThe vector of raw (
2021-04-02
下一篇 
Focal Loss——从直觉到实现 Focal Loss——从直觉到实现
Focal Loss——从直觉到实现问题(Why?)做机器学习分类问题,难免遇到Biased-Data-Problem,例如 CV的目标检测问题: 绝大多数检测框里都是 backgroud NLP的异常文本检测: 绝大多数文本都是 nor
2021-03-29
  目录