mindspore.mint.nn.KLDivLoss
- class mindspore.mint.nn.KLDivLoss(reduction='mean', log_target=False)
Computes the Kullback-Leibler divergence between the input and the target.
For tensors of the same shape \(x\) and \(y\), KLDivLoss is computed as follows:
\[L(x, y) = y \cdot (\log y - x)\]
Then,
\[\begin{split}\ell(x, y) = \begin{cases} L(x, y), & \text{if reduction} = \text{'none';}\\ \operatorname{mean}(L(x, y)), & \text{if reduction} = \text{'mean';}\\ \operatorname{sum}(L(x, y)) / x.\operatorname{shape}[0], & \text{if reduction} = \text{'batchmean';}\\ \operatorname{sum}(L(x, y)), & \text{if reduction} = \text{'sum'.} \end{cases}\end{split}\]
where \(x\) represents input, \(y\) represents target, and \(\ell(x, y)\) represents the output.
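To make the reductions concrete, the element-wise term and the reduction modes can be reproduced with plain NumPy. This is a minimal sketch (assuming input already holds log-probabilities and target holds probabilities, i.e. log_target=False), not the operator's implementation:

>>> import numpy as np
>>> x = np.array([[0.5, 0.5], [0.4, 0.6]])  # input (treated as log-probabilities)
>>> y = np.array([[0., 1.], [1., 0.]])      # target (probabilities)
>>> # element-wise term L(x, y); entries with y == 0 contribute 0 by convention
>>> L = np.where(y > 0, y * (np.log(np.where(y > 0, y, 1.0)) - x), 0.0)
>>> print(L.mean())              # reduction='mean'
-0.225
>>> print(L.sum())               # reduction='sum'
-0.9
>>> print(L.sum() / x.shape[0])  # reduction='batchmean'
-0.45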
Note
The output aligns with the mathematical definition of Kullback-Leibler divergence only when reduction is set to 'batchmean'.
- Parameters
reduction (str, optional) - Specifies the reduction to be applied to the output. Its value must be one of 'none', 'mean', 'batchmean' or 'sum'. Default: 'mean'.
log_target (bool, optional) - Specifies whether target is passed in the log space. Default: False.
- Inputs:
input (Tensor) - The input Tensor. The data type must be float16, float32 or bfloat16 (bfloat16 is only supported by Atlas A2 training series products).
target (Tensor) - The target Tensor which has the same type as input. The shapes of target and input should be broadcastable.
- Outputs:
Tensor, with the same dtype as input. If reduction is 'none', the output has the same shape as the broadcast result of input and target. Otherwise, it is a scalar Tensor.
- Raises
TypeError – If input or target is not a Tensor.
TypeError – If dtype of input or target is not float16, float32 or bfloat16.
TypeError – If dtype of target is not the same as that of input.
ValueError – If reduction is not one of 'none', 'mean', 'sum', 'batchmean'.
ValueError – If shapes of target and input cannot be broadcast.
- Supported Platforms:
Ascend
Examples
>>> import mindspore as ms
>>> from mindspore import mint
>>> import numpy as np
>>> input = ms.Tensor(np.array([[0.5, 0.5], [0.4, 0.6]]), ms.float32)
>>> target = ms.Tensor(np.array([[0., 1.], [1., 0.]]), ms.float32)
>>> loss = mint.nn.KLDivLoss(reduction='mean', log_target=False)
>>> output = loss(input, target)
>>> print(output)
-0.225
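A sketch of the 'batchmean' mode mentioned in the Note, reusing the tensors from the example above; per the formula, the expected value is sum(L) / x.shape[0] = -0.9 / 2, so the output should be approximately -0.45:

>>> loss_bm = mint.nn.KLDivLoss(reduction='batchmean', log_target=False)
>>> print(loss_bm(input, target))
-0.45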