kldiv

scio.scores.utils.kldiv(inputs, expected)

Kullback-Leibler (KL) divergence for (potentially batched) \(1\)D inputs.

Computes \(D_{\text{KL}}(\text{inputs}\Vert\text{expected})\), where inputs may be batched. Each input sample and expected should be vectors of the same length, in probability space (up to rescaling).
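For reference, the quantity computed can be written out as follows. This is a sketch and not taken from the source: it assumes that "up to rescaling" means each vector is divided by its sum, that the logarithm is natural (consistent with the example value \(1.0986 \approx \ln 3\)), and that terms with \(p_i = 0\) contribute zero. The symbols \(x\), \(e\), \(p\), \(q\) and \(n\) are introduced here for illustration only, with \(x\) one input sample, \(e\) the expected vector and \(n\) the space size.

\[
p_i = \frac{x_i}{\sum_{j=1}^{n} x_j}, \qquad
q_i = \frac{e_i}{\sum_{j=1}^{n} e_j}, \qquad
D_{\text{KL}}(p \Vert q) = \sum_{i=1}^{n} p_i \log\frac{p_i}{q_i}
\]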

This function is essentially a wrapper for intuitive use of torch.nn.functional.kl_div(). Prefer calling torch.nn.functional.kl_div() directly when numerical instability is a concern, as it can operate in \(\text{log}\) space.
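As an illustration of the log-space route mentioned above, here is a minimal sketch that calls torch.nn.functional.kl_div() directly. The logit values and the variable names (input_logits, expected_logits, log_p, log_q) are hypothetical and chosen only for the example. Note the argument order: kl_div(input, target) computes target * (log target - input) pointwise, so the reference distribution goes first.

```python
import torch
import torch.nn.functional as F

# Hypothetical unnormalized log-weights, used only for illustration.
input_logits = torch.tensor([[0.0, 1.0, 2.0], [3.0, 1.0, 0.5]])
expected_logits = torch.tensor([0.5, 1.0, 1.5])

# Normalize in log space instead of dividing probabilities by their sum.
log_p = torch.log_softmax(input_logits, dim=-1)     # plays the role of inputs
log_q = torch.log_softmax(expected_logits, dim=-1)  # plays the role of expected

# D_KL(p || q): pass log_q as `input` and log_p as `target`.
pointwise = F.kl_div(
    log_q.expand_as(log_p), log_p, reduction="none", log_target=True
)
div = pointwise.sum(dim=-1)  # one divergence per input sample
```

With log_target=True, an exact zero probability becomes -inf in log space and can produce nan, so this route is best suited to strictly positive distributions.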

Parameters:
  • inputs (Tensor) – Batched samples, in probability space (up to rescaling). Shape (*batch_shape, space_size).

  • expected (Tensor) – Expectation in probability space (up to rescaling). Shape (space_size,).

Returns:

div (Tensor) – The sample-wise divergence. Shape batch_shape. If expected is invalid (e.g. contains nan), returns all nan. Likewise, each invalid input sample individually yields nan. The returned div.dtype is torch.result_type(inputs, expected) if at least one of inputs or expected is of floating type, torch.float otherwise.

Raises:

ValueError – If expected is a scalar.

Example

>>> inputs = torch.tensor([
...     [0, 10, 20],
...     [0, 2, 2],
...     [0, 0.5, 0.5],
...     [0, 1, 0],
...     [1, 1, 2],
...     [0, -1, 2],
...     [0, 1, torch.nan],
... ])
>>> expected = torch.tensor([0, 1, 2])
>>> kldiv(inputs, expected)
tensor([0.0000, 0.0589, 0.0589, 1.0986,    inf,    nan,    nan])
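
The values above can be re-derived by hand. The following is a minimal sketch of one way to do so, not scio's actual implementation: kldiv_sketch is a hypothetical name, and the sketch assumes that rescaling means dividing by the sum along the last axis. It does not validate expected, so it may differ from kldiv for some invalid expected vectors.

```python
import torch


def kldiv_sketch(inputs: torch.Tensor, expected: torch.Tensor) -> torch.Tensor:
    """Hypothetical re-derivation of the documented behavior; not scio's code."""
    # Pick the result dtype as described under "Returns".
    dtype = torch.result_type(inputs, expected)
    if not dtype.is_floating_point:
        dtype = torch.float
    inputs, expected = inputs.to(dtype), expected.to(dtype)

    # Rescale to probability vectors (assumption: divide by the sum).
    p = inputs / inputs.sum(dim=-1, keepdim=True)
    q = expected / expected.sum()

    # xlogy(x, y) is x * log(y), defined as 0 where x == 0, so zero-probability
    # entries of p contribute nothing, while q == 0 with p > 0 gives inf.
    return (torch.xlogy(p, p) - torch.xlogy(p, q)).sum(dim=-1)


# Integer tensors are accepted and promoted to torch.float.
out = kldiv_sketch(torch.tensor([[0, 2, 2], [0, 1, 0]]), torch.tensor([0, 1, 2]))
```

Under these assumptions, negative or nan entries in a sample propagate to nan through the logarithm, which matches the last two rows of the example.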