batched_grad

scio.scores.utils.batched_grad(outputs, inputs, *, retain_graph=False)

Compute gradients of batched outputs with respect to the corresponding batched inputs.

Parameters:
  • outputs (Tensor) – Shape (n_samples,).

  • inputs (Tensor) – Shape (n_samples, *sample_shape).

  • retain_graph (bool) – Passed to torch.autograd.grad. If False (the default), the graph used to compute the gradients is freed after the call.
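To illustrate what retain_graph controls, the snippet below uses plain torch.autograd.grad directly (it does not call scio). The toy tensors x and y are assumptions made for the example. With retain_graph=True the graph survives and can be differentiated again; after a call with the default False, a further call through the same graph raises a RuntimeError.

```python
import torch

# Toy batch: 3 samples, each of shape (2,).
x = torch.randn(3, 2, requires_grad=True)
# One scalar output per sample, shape (n_samples,).
y = (x ** 2).sum(dim=1)

# First call keeps the graph alive so it can be differentiated again.
(g1,) = torch.autograd.grad(y.sum(), x, retain_graph=True)
# Second call uses the default retain_graph=False, so the graph is freed afterwards.
(g2,) = torch.autograd.grad(y.sum(), x)

# A third call fails: the tensors saved for backward have been released.
try:
    torch.autograd.grad(y.sum(), x)
    second_call_failed = False
except RuntimeError:
    second_call_failed = True
```

Passing retain_graph=True is therefore only needed when the same outputs must be differentiated more than once.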

Returns:

out (Tensor) – Batched gradients with respect to inputs, with the same shape as inputs.
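The input/output contract above can be illustrated with a minimal sketch. The function batched_grad_sketch is hypothetical and is not scio's implementation; it assumes that each outputs[i] depends only on inputs[i], in which case differentiating outputs.sum() once yields every per-sample gradient at the same time. The example compares it against a slow per-sample loop.

```python
import torch


def batched_grad_sketch(
    outputs: torch.Tensor, inputs: torch.Tensor, *, retain_graph: bool = False
) -> torch.Tensor:
    """Hypothetical re-implementation for illustration, not scio's code.

    Assumes outputs[i] depends only on inputs[i], so that the gradient of
    sum(outputs) with respect to inputs[i] equals d outputs[i] / d inputs[i].
    """
    if outputs.shape != inputs.shape[:1]:
        raise ValueError("outputs must have shape (n_samples,)")
    (grads,) = torch.autograd.grad(outputs.sum(), inputs, retain_graph=retain_graph)
    return grads  # same shape as inputs


# Example: 4 samples with sample_shape (3, 2) and one scalar output per sample.
torch.manual_seed(0)
inputs = torch.randn(4, 3, 2, requires_grad=True)
outputs = (inputs.sin() * inputs).flatten(1).sum(dim=1)  # shape (4,)

# Keep the graph so the reference loop below can reuse it.
grads = batched_grad_sketch(outputs, inputs, retain_graph=True)

# Reference: differentiate each output on its own and keep only its sample's row.
reference = torch.stack([
    torch.autograd.grad(outputs[i], inputs, retain_graph=True)[0][i]
    for i in range(inputs.shape[0])
])
```

If the outputs mix information across samples (for example through batch-level normalization), the gradient of the sum no longer equals the per-sample gradient, so this shortcut would not apply.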