make_indexes

scio.scores.utils.make_indexes(all_samples, all_groups=None, *, n_groups=None, metric, squeeze=True)

Prepare multiple search indexes.

For now, only faiss.IndexFlatL2 and faiss.IndexFlatIP are supported. Note that the values returned when querying these faiss indexes are, respectively, the squared Euclidean distance and the inner product, which is a similarity measure, not a distance. In the latter case, the returned neighbors are still ordered from the “closest” (most similar) to the furthest (least similar).
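The ordering convention for the two metrics can be illustrated with a minimal brute-force sketch. It uses only the standard library and does not call faiss or scio; the helper names `squared_l2`, `inner_product` and `search`, as well as the `"l2"` / `"ip"` strings, are hypothetical and chosen for illustration only.

```python
def squared_l2(a, b):
    # Squared Euclidean distance: no square root is taken.
    return sum((x - y) ** 2 for x, y in zip(a, b))


def inner_product(a, b):
    # Inner product: a similarity measure, larger means more similar.
    return sum(x * y for x, y in zip(a, b))


def search(samples, query, k, metric):
    """Brute-force k-nearest search returning (scores, ids), closest first."""
    if metric == "l2":
        scored = [(squared_l2(s, query), i) for i, s in enumerate(samples)]
        scored.sort(key=lambda t: t[0])  # smaller squared distance = closer
    elif metric == "ip":
        scored = [(inner_product(s, query), i) for i, s in enumerate(samples)]
        scored.sort(key=lambda t: -t[0])  # larger similarity = closer
    else:
        raise ValueError(f"unknown metric: {metric!r}")
    top = scored[:k]
    return [score for score, _ in top], [idx for _, idx in top]


samples = [(0.0, 0.0), (1.0, 0.0), (3.0, 4.0)]
query = (1.0, 1.0)

l2_scores, l2_ids = search(samples, query, 3, "l2")
ip_scores, ip_ids = search(samples, query, 3, "ip")

print(l2_scores, l2_ids)  # ascending squared distances
print(ip_scores, ip_ids)  # descending inner products
```

In both cases the first returned neighbor is the "closest" one, but the scores increase along the result for the squared Euclidean distance and decrease for the inner product.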

Parameters:
  • all_samples (tuple[Tensor, ...] | Tensor) – If a Tensor, treated as (all_samples,). Tuple of tensors of shape (n_samples, *sample_shape) of common length n_samples. The samples from which to build search indexes, treated as vector samples.

  • all_groups (tuple[Tensor, ...] | Tensor, optional) – If not provided, treated as torch.zeros(n_samples). If a Tensor, treated as (all_groups,). Tuple of tensors of shape (n_samples,). The group every sample belongs to. Values must be nonnegative integers.

  • n_groups (int, optional) – If all_groups is not provided, treated as 1. Else it is required to be an int. Group values in range(n_groups) are considered. Empty indexes are possible (e.g. when there are no samples for a given group value).

  • metric (IndexMetricLike) – See IndexMetric.

  • squeeze (bool) – Whether to apply the postprocessing squeezing steps described in indexes. Defaults to True.
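How group values select the samples behind each index can be sketched as follows. This is a standard-library illustration, not the library's implementation: `partition_by_group` is a hypothetical helper, plain lists stand in for tensors, and silently ignoring group values outside range(n_groups) is an assumption drawn from the wording of n_groups above.

```python
def partition_by_group(samples, groups, n_groups):
    """Split samples into one bucket per group value in range(n_groups)."""
    if len(samples) != len(groups):
        raise ValueError("samples and groups must have the same length")
    buckets = [[] for _ in range(n_groups)]
    for sample, group in zip(samples, groups):
        # Assumption: only group values in range(n_groups) are considered.
        if 0 <= group < n_groups:
            buckets[group].append(sample)
    return tuple(tuple(bucket) for bucket in buckets)


samples = ["a", "b", "c", "d"]
groups = [0, 2, 0, 5]  # the value 5 falls outside range(3)

buckets = partition_by_group(samples, groups, n_groups=3)
print(buckets)
```

Here group 1 receives no samples, which corresponds to the empty index case mentioned for n_groups.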

Returns:

indexes (Nested tuples of Index, optionally squeezed) – Every search index. When all_groups and all_samples are both tuples, there are len(all_groups) * len(all_samples) * n_groups indexes. From outermost to innermost, the nested tuples run along all_groups, all_samples and group values. Finally, if squeeze is True, the following postprocessing is applied: if all_groups or all_samples is a Tensor, its respective tuple level is squeezed; furthermore, if all_groups was not provided, the tuple level corresponding to group values is also squeezed.
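The nesting and squeezing rules can be mimicked with placeholder labels instead of real Index objects. The sketch below reflects one reading of the description above, not the library's code: `nested_layout` and its keyword flags are hypothetical, and each `(g, s, v)` label stands in for the index built from the g-th groups tensor, the s-th samples tensor and the group value v.

```python
def nested_layout(n_group_sets, n_sample_sets, n_groups, *,
                  groups_is_tensor=False, samples_is_tensor=False,
                  groups_provided=True, squeeze=True):
    """Return nested tuples of (g, s, v) labels standing in for indexes."""
    # Outermost to innermost: all_groups, all_samples, group values.
    layout = tuple(
        tuple(
            tuple((g, s, v) for v in range(n_groups))
            for s in range(n_sample_sets)
        )
        for g in range(n_group_sets)
    )
    if not squeeze:
        return layout
    if not groups_provided:
        # No all_groups: drop the innermost (group value) level.
        layout = tuple(
            tuple(per_sample[0] for per_sample in per_group)
            for per_group in layout
        )
    if samples_is_tensor:
        # all_samples given as a single Tensor: drop the samples level.
        layout = tuple(per_group[0] for per_group in layout)
    if groups_is_tensor or not groups_provided:
        # all_groups given as a single Tensor (or omitted): drop the outer level.
        layout = layout[0]
    return layout


# Two groups tensors, three samples tensors, four group values.
full = nested_layout(2, 3, 4)
# A single samples Tensor and no all_groups: one bare label remains.
single = nested_layout(1, 1, 1, samples_is_tensor=True, groups_provided=False)
# A tuple of three samples tensors and no all_groups: one label per tensor.
per_layer = nested_layout(1, 3, 1, groups_provided=False)
# Same as `single`, but with squeezing disabled.
unsqueezed = nested_layout(1, 1, 1, samples_is_tensor=True,
                           groups_provided=False, squeeze=False)

print(len(full), len(full[0]), len(full[0][0]))
print(single, per_layer, unsqueezed)
```

With squeeze disabled, the result always keeps three levels of nesting, which may be easier to handle in generic code that does not know in advance how the arguments were passed.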

Hint

Think of all_samples as layer activations stacked in a tuple along layers, for example.

Think of all_groups as (true_labels, pred_labels) for example.

Raises:

ValueError – If n_groups is not an int although all_groups is provided.

Note

For accurate output type specification, please refer to the associated source stub.