knn_label_count

scio.scores.utils.knn_label_count(index, labels, n_classes, k, query, *, self_query=False)

Count the labels of the k nearest neighbors for a batch of queries.

Parameters:
  • index (Index) – Search index.

  • labels (Tensor) – Labels of the reference samples used to build the index. Shape (n_reference,).

  • n_classes (int) – Number of classes.

  • k (int) – Number of neighbors to look up.

  • query (Tensor) – Query samples, which need not be flattened. Shape (n_query, *sample_shape).

  • self_query (bool) – See Index.search(). Requires one additional reference sample in the index. Defaults to False.
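Once the neighbor indices are known, the counting step reduces to a per-row histogram of neighbor labels. The sketch below is illustrative only and is not scio's implementation: it uses NumPy arrays in place of Tensors, assumes the (n_query, k) matrix of neighbor indices has already been retrieved, and the helper name count_neighbor_labels is hypothetical.

```python
import numpy as np


def count_neighbor_labels(neighbor_idx: np.ndarray, labels: np.ndarray, n_classes: int) -> np.ndarray:
    """Hypothetical helper: per-row class counts for a (n_query, k) neighbor-index matrix."""
    neighbor_labels = labels[neighbor_idx]  # (n_query, k) class ids of the neighbors
    n_query = neighbor_labels.shape[0]
    # Shift row i's labels by i * n_classes so one bincount yields all row histograms.
    offsets = np.arange(n_query)[:, None] * n_classes
    flat = (neighbor_labels + offsets).ravel()
    counts = np.bincount(flat, minlength=n_query * n_classes)
    return counts.reshape(n_query, n_classes)


labels = np.array([0, 1, 1, 2, 0])
neighbor_idx = np.array([[0, 4, 1],
                         [1, 2, 3]])
print(count_neighbor_labels(neighbor_idx, labels, n_classes=3).tolist())
# → [[2, 1, 0], [0, 2, 1]]
```

Each row sums to k, since every neighbor contributes exactly one count to its class.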

Returns:

counts (Tensor) – Class counts among the k nearest neighbors. Shape (n_query, n_classes). Filled with nan if k + self_query > index.ntotal.
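The return contract, including the nan case, can be illustrated with a brute-force stand-in. This is a minimal sketch under stated assumptions, not the library's code: a Euclidean exhaustive search over a NumPy reference array replaces the Index, the function name knn_label_count_sketch is hypothetical, and self_query=True is assumed to mean that each query is itself stored in the index, so one extra neighbor is fetched and the nearest one (the query itself) is dropped. That reading is inferred from the "one additional reference sample" requirement and the k + self_query condition.

```python
import numpy as np


def knn_label_count_sketch(reference, labels, n_classes, k, query, *, self_query=False):
    """Hypothetical brute-force stand-in: Euclidean kNN label counts."""
    n_total = reference.shape[0]
    n_query = query.shape[0]
    n_hits = k + int(self_query)  # assumption: one extra hit to discard when self_query
    if n_hits > n_total:
        # Too few reference samples: return an all-nan array of the documented shape.
        return np.full((n_query, n_classes), np.nan)
    # Flatten samples, since queries need not be flattened on input.
    ref = reference.reshape(n_total, -1)
    q = query.reshape(n_query, -1)
    dists = ((q[:, None, :] - ref[None, :, :]) ** 2).sum(axis=-1)  # (n_query, n_total)
    order = np.argsort(dists, axis=1, kind="stable")[:, :n_hits]
    if self_query:
        order = order[:, 1:]  # assumption: the nearest hit is the query itself
    # Float dtype so that the nan result above and normal counts share one dtype.
    counts = np.zeros((n_query, n_classes))
    for row, idx in enumerate(order):
        counts[row] = np.bincount(labels[idx], minlength=n_classes)
    return counts


reference = np.array([[0.0], [1.0], [10.0], [11.0]])
labels = np.array([0, 0, 1, 1])
query = np.array([[0.2], [10.4]])
print(knn_label_count_sketch(reference, labels, 2, 2, query).tolist())
# → [[2.0, 0.0], [0.0, 2.0]]
```

With k=4 and self_query=True, the sketch needs 5 hits from only 4 reference samples and returns a (2, 2) array of nan.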

Raises:
  • ValueError – If labels.shape != (index.ntotal,).

  • ValueError – If (labels >= n_classes).any().
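The two documented error conditions amount to simple input checks performed before any search. The sketch below mirrors them with NumPy; the helper name check_labels and its n_total argument (standing in for index.ntotal) are hypothetical, not part of scio.

```python
import numpy as np


def check_labels(labels: np.ndarray, n_total: int, n_classes: int) -> None:
    """Hypothetical validator mirroring the documented ValueError conditions."""
    # One label per reference sample stored in the index.
    if labels.shape != (n_total,):
        raise ValueError(f"labels.shape {labels.shape} != ({n_total},)")
    # Every label must be a valid class id below n_classes.
    if (labels >= n_classes).any():
        raise ValueError(f"labels must be < n_classes={n_classes}")


check_labels(np.array([0, 1, 2]), n_total=3, n_classes=3)  # passes silently
try:
    check_labels(np.array([0, 1, 3]), n_total=3, n_classes=3)
except ValueError as err:
    print(err)
```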