第一句子网 - 唯美句子、句子迷、好句子大全
第一句子网 > softmax分类器_Softmax 理解

softmax分类器_Softmax 理解

时间:2023-11-19 12:32:10

相关推荐

softmax分类器_Softmax 理解

Softmax深入理解[译] - AIUAI​

Pytorch的交叉熵nn.CrossEntropyLoss在训练阶段,里面是内置了softmax操作的,因此只需要喂入原始的数据结果即可,不需要在之前再添加softmax层。

Softmax 是非线性函数,主要用于 multi-class classification 任务中分类器的输出端.

import numpy as npdef softmax(x):exp_x = np.exp(x) # ps 下面结果中的e+01 是科学计数法 e+01 = 10#print(exp_x) # [2.20264658e+04 7.38905610e+00 2.35385267e+17 5.45981500e+01]sum_exp_x = np.sum(exp_x)sm_x = exp_x/sum_exp_xreturn sm_xx = np.array([10, 2, 40, 4])print(np.exp(1)) # 2.718281828459045print(softmax(x))# [9.35762297e-14 3.13913279e-17 1.00000000e+00 2.31952283e-16]

2.Softmax 的数值稳定性

#softmax 的概率值可以看出,当元素值范围非常大时,容易出现数值不稳定性.# 比如,修改上面向量的第三个元素值为 10000,并重新计算 softmaxx = np.array([10, 2, 10000, 4])print(softmax(x))#[0.0, 0.0, nan, 0.0]#nan 表示 not-a-number,往往出现在过拟合(overflow) 和 欠拟和(underflow) 中. #但是,Softmax 为什么会输出这样的结果呢?是不能得到向量的概率分布吗?# 答案: 一个非常大的数值的指数会是非常、非常大的值,如 , 导致过拟合.

def softmax(x):max_x = np.max(x) # 最大值exp_x = np.exp(x - max_x)sum_exp_x = np.sum(exp_x)sm_x = exp_x/sum_exp_xreturn sm_xx = np.array([10, 2, 10000, 4])print(softmax(x)) #[0., 0., 1., 0.]# 可以看出,nan 问题解决了

3. Log Softmax

Softmax 计算的一个关键评估显示了指数计算和除法计算的模式. 是否可以简化这些计算呢?可以通过优化 log softmax 来代替. 其具有如下更优的特点:[1] - 数值稳定性[2] - log softmax 的梯度计算为加法计算,因为log(a/b)=log(a)-log(b)[3] - 除法和乘法计算被转换成加法,更少的计算量和计算成本[4] - log 函数是单调递增函数,可以更好的利用该特点.softmax 和 log softmax 的计算:x = np.array([10, 2, 10000, 4])print(softmax(x)) #[0., 0., 1., 0.]print(np.log(softmax(x))) #[-inf, -inf, 0., -inf]#回到数值稳定性问题,实际上,log softmax 数值欠拟合问题: 为什么会这样? ---> 在对每个元素计算 log 计算时, log(0)是未定义的.

4. Log-Softmax 变形

import numpy as npdef logsoftmax(x, recover_probs=True):# LogSoftMax Implementationmax_x = np.max(x)exp_x = np.exp(x - max_x)sum_exp_x = np.sum(exp_x)log_sum_exp_x = np.log(sum_exp_x)max_plus_log_sum_exp_x = max_x + log_sum_exp_xlog_probs = x - max_plus_log_sum_exp_x # 避免了数值不稳定,减少了计算量# Recover probsif recover_probs:exp_log_probs = np.exp(log_probs)sum_log_probs = np.sum(exp_log_probs)probs = exp_log_probs / sum_log_probsreturn probsreturn log_probsx = np.array([10, 2, 10000, 4])print(logsoftmax(x,recover_probs=False)) # [-9990. -9998. 0. -9996.]print(logsoftmax(x, recover_probs=True)) # [0. 0. 1. 0.]

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。