knn_label_count
- scio.scores.utils.knn_label_count(index, labels, n_classes, k, query, *, self_query=False)
Count labels of neighbors for batched queries.
- Parameters:
  - index (Index) – Search index.
  - labels (Tensor) – Labels of the reference samples used to build index. Shape (n_reference,).
  - n_classes (int) – Number of classes.
  - k (int) – Number of neighbors to look up.
  - query (Tensor) – The query samples, not necessarily flattened. Shape (n_query, *sample_shape).
  - self_query (bool) – See Index.search(). Requires one additional reference sample in index. Defaults to False.
- Returns:
  counts (Tensor) – Class counts amongst the k nearest neighbors. Shape (n_query, n_classes). Full of nan if k + self_query > index.ntotal.
- Raises:
  - ValueError – If labels.shape != (index.ntotal,).
  - ValueError – If (labels >= n_classes).any().
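The counting rule can be illustrated with a brute-force stand-in. This is a minimal sketch, not scio's implementation: the helper knn_label_count_sketch is hypothetical, it takes raw reference arrays instead of an Index, uses NumPy instead of torch, measures squared Euclidean distance, and assumes that self_query=True means each query is itself stored among the references, so its nearest hit is dropped.

```python
import numpy as np


def knn_label_count_sketch(reference, labels, n_classes, k, query, *, self_query=False):
    """Brute-force stand-in for knn_label_count (hypothetical helper, not scio's code).

    Assumptions: squared Euclidean distance, and with self_query=True each query
    is itself one of the reference samples, so its closest hit is discarded.
    """
    # Flatten every sample to a vector, so query may have shape (n_query, *sample_shape).
    reference = np.asarray(reference, dtype=float).reshape(len(reference), -1)
    query = np.asarray(query, dtype=float).reshape(len(query), -1)
    n_query, n_total = len(query), len(reference)
    n_lookup = k + int(self_query)
    if n_lookup > n_total:
        # Mirrors the documented behaviour: all-nan counts when too few references exist.
        return np.full((n_query, n_classes), np.nan)
    # Pairwise squared distances, shape (n_query, n_total).
    dists = ((query[:, None, :] - reference[None, :, :]) ** 2).sum(axis=-1)
    # Indices of the n_lookup closest references, nearest first (stable for ties).
    nearest = np.argsort(dists, axis=1, kind="stable")[:, :n_lookup]
    if self_query:
        nearest = nearest[:, 1:]  # assume the first hit is the query itself
    neighbor_labels = np.asarray(labels)[nearest]  # shape (n_query, k)
    # Compare against every class index, then sum over the k neighbors.
    counts = (neighbor_labels[..., None] == np.arange(n_classes)).sum(axis=1)
    return counts.astype(float)
```

The exhaustive distance matrix keeps the sketch short but needs memory proportional to n_query × n_reference; a dedicated search index exists to avoid exactly that cost. The float return type leaves room for the all-nan case.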
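The two documented ValueError conditions amount to simple checks on labels. The sketch below is hypothetical: the helper check_knn_label_inputs and its ntotal argument (standing in for index.ntotal) are assumptions, and NumPy arrays stand in for torch tensors.

```python
import numpy as np


def check_knn_label_inputs(labels, n_classes, ntotal):
    """Hypothetical helper mirroring the two documented ValueError conditions.

    ntotal stands in for index.ntotal, the number of samples in the index.
    """
    labels = np.asarray(labels)
    # One label per reference sample in the index.
    if labels.shape != (ntotal,):
        raise ValueError(f"labels must have shape ({ntotal},), got {labels.shape}")
    # Every label must be a valid class index below n_classes.
    if (labels >= n_classes).any():
        raise ValueError(f"all labels must be smaller than n_classes={n_classes}")
```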