Inferring with Confidence
This tutorial shows how to quickly set up confidence scores for
inference in a classification setting, using algorithms already
implemented in scio.scores.
Let’s start by preparing a trained model and some (fake) calibration data. Both would naturally be defined by your own use case. For this tutorial, we use a lightweight Tiniest architecture trained on CIFAR10 and hosted on our hub, together with random calibration data.
# These should be defined by your use-case
import torch
sample_shape = (3, 32, 32)
calib_data = torch.rand(10, *sample_shape)
calib_labels = torch.randint(0, 10, (len(calib_data),))  # One class label per calibration sample
net = torch.hub.load("ThalesGroup/scio:hub", "tiniest", trust_repo=True, verbose=False)
net = net.to(calib_data)  # Match the device and dtype of the data
1. Choose an algorithm & define internal representations
Let us choose a confidence score algorithm implemented in scio.scores, say
KNN. The only required parameter for this method
is k, which sets which nearest neighbor (the \(k^{\text{th}}\)) to look
for in latent space. A common rule of thumb is to choose
\(k \approx (n_{\text{calib}})^{0.4}\).
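As a quick illustration of this rule of thumb, the snippet below computes k as the integer part of \((n_{\text{calib}})^{0.4}\) (the same truncation used in the KNN(k=...) call of this tutorial) for a few calibration set sizes. The helper name rule_of_thumb_k is our own, and the sizes other than 10 are arbitrary examples.

```python
def rule_of_thumb_k(n_calib: int) -> int:
    """Return the integer part of n_calib ** 0.4 (hypothetical helper)."""
    return int(n_calib ** 0.4)


# Arbitrary calibration set sizes, from this tutorial's 10 samples upwards.
for n in (10, 1_000, 50_000):
    print(f"n_calib={n:>6} -> k={rule_of_thumb_k(n)}")
```

With the 10 calibration samples used here, this gives k=2, which matches the fitted KNN representation printed in step 2.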
The authors of the KNN paper recommend using the
feature maps (penultimate layer) as internal representations, so we
set this up when creating our Recorder
net, which is responsible for capturing embeddings (see
Diving inside Neural Networks).
from scio.recorder import Recorder
from scio.scores import KNN
score = KNN(k=int(len(calib_data) ** 0.4))
rnet = Recorder(net, input_data=calib_data[[0]])
rnet # Visualize layers
Recorder instance for the following network
============================================================================================================================================
Layer (type (var_name):depth-idx) Input Shape Output Shape Param # Param %
============================================================================================================================================
Tiniest (Tiniest) [1, 3, 32, 32] [1, 10] -- --
├─Conv2d (conv1): 1-1 [1, 3, 32, 32] [1, 48, 32, 32] 1,344 1.38%
├─LayerNorm2d (ln1): 1-2 [1, 48, 32, 32] [1, 48, 32, 32] -- --
│ └─LayerNorm (ln): 2-1 [1, 32, 32, 48] [1, 32, 32, 48] 96 0.10%
├─Block (l1): 1-3 [1, 48, 32, 32] [1, 48, 32, 32] 48 0.05%
│ └─Conv2d (dwconv1): 2-2 [1, 12, 32, 32] [1, 12, 32, 32] 120 0.12%
│ └─Conv2d (dwconv2): 2-3 [1, 12, 32, 32] [1, 12, 32, 32] 600 0.62%
│ └─Conv2d (dwconv3): 2-4 [1, 12, 32, 32] [1, 12, 32, 32] 600 0.62%
│ └─LayerNorm2d (ln): 2-5 [1, 48, 32, 32] [1, 48, 32, 32] -- --
│ │ └─LayerNorm (ln): 3-1 [1, 32, 32, 48] [1, 32, 32, 48] 96 0.10%
│ └─Conv2d (fc1): 2-6 [1, 48, 32, 32] [1, 96, 32, 32] 4,704 4.82%
│ └─Conv2d (fc2): 2-7 [1, 48, 32, 32] [1, 48, 32, 32] 2,352 2.41%
├─Block (l2): 1-4 [1, 48, 32, 32] [1, 48, 32, 32] 48 0.05%
│ └─Conv2d (dwconv1): 2-8 [1, 12, 32, 32] [1, 12, 32, 32] 120 0.12%
│ └─Conv2d (dwconv2): 2-9 [1, 12, 32, 32] [1, 12, 32, 32] 600 0.62%
│ └─Conv2d (dwconv3): 2-10 [1, 12, 32, 32] [1, 12, 32, 32] 600 0.62%
│ └─LayerNorm2d (ln): 2-11 [1, 48, 32, 32] [1, 48, 32, 32] -- --
│ │ └─LayerNorm (ln): 3-2 [1, 32, 32, 48] [1, 32, 32, 48] 96 0.10%
│ └─Conv2d (fc1): 2-12 [1, 48, 32, 32] [1, 96, 32, 32] 4,704 4.82%
│ └─Conv2d (fc2): 2-13 [1, 48, 32, 32] [1, 48, 32, 32] 2,352 2.41%
├─Block (l3): 1-5 [1, 48, 32, 32] [1, 48, 32, 32] 48 0.05%
│ └─Conv2d (dwconv1): 2-14 [1, 12, 32, 32] [1, 12, 32, 32] 120 0.12%
│ └─Conv2d (dwconv2): 2-15 [1, 12, 32, 32] [1, 12, 32, 32] 600 0.62%
│ └─Conv2d (dwconv3): 2-16 [1, 12, 32, 32] [1, 12, 32, 32] 600 0.62%
│ └─LayerNorm2d (ln): 2-17 [1, 48, 32, 32] [1, 48, 32, 32] -- --
│ │ └─LayerNorm (ln): 3-3 [1, 32, 32, 48] [1, 32, 32, 48] 96 0.10%
│ └─Conv2d (fc1): 2-18 [1, 48, 32, 32] [1, 96, 32, 32] 4,704 4.82%
│ └─Conv2d (fc2): 2-19 [1, 48, 32, 32] [1, 48, 32, 32] 2,352 2.41%
├─Conv2d (dsconv): 1-6 [1, 48, 32, 32] [1, 80, 16, 16] 3,920 4.02%
├─LayerNorm2d (ln2): 1-7 [1, 80, 16, 16] [1, 80, 16, 16] -- --
│ └─LayerNorm (ln): 2-20 [1, 16, 16, 80] [1, 16, 16, 80] 160 0.16%
├─Block (l4): 1-8 [1, 80, 16, 16] [1, 80, 16, 16] 80 0.08%
│ └─Conv2d (dwconv1): 2-21 [1, 20, 16, 16] [1, 20, 16, 16] 200 0.21%
│ └─Conv2d (dwconv2): 2-22 [1, 20, 16, 16] [1, 20, 16, 16] 1,000 1.03%
│ └─Conv2d (dwconv3): 2-23 [1, 20, 16, 16] [1, 20, 16, 16] 1,000 1.03%
│ └─LayerNorm2d (ln): 2-24 [1, 80, 16, 16] [1, 80, 16, 16] -- --
│ │ └─LayerNorm (ln): 3-4 [1, 16, 16, 80] [1, 16, 16, 80] 160 0.16%
│ └─Conv2d (fc1): 2-25 [1, 80, 16, 16] [1, 160, 16, 16] 12,960 13.29%
│ └─Conv2d (fc2): 2-26 [1, 80, 16, 16] [1, 80, 16, 16] 6,480 6.64%
├─Block (l5): 1-9 [1, 80, 16, 16] [1, 80, 16, 16] 80 0.08%
│ └─Conv2d (dwconv1): 2-27 [1, 20, 16, 16] [1, 20, 16, 16] 200 0.21%
│ └─Conv2d (dwconv2): 2-28 [1, 20, 16, 16] [1, 20, 16, 16] 1,000 1.03%
│ └─Conv2d (dwconv3): 2-29 [1, 20, 16, 16] [1, 20, 16, 16] 1,000 1.03%
│ └─LayerNorm2d (ln): 2-30 [1, 80, 16, 16] [1, 80, 16, 16] -- --
│ │ └─LayerNorm (ln): 3-5 [1, 16, 16, 80] [1, 16, 16, 80] 160 0.16%
│ └─Conv2d (fc1): 2-31 [1, 80, 16, 16] [1, 160, 16, 16] 12,960 13.29%
│ └─Conv2d (fc2): 2-32 [1, 80, 16, 16] [1, 80, 16, 16] 6,480 6.64%
├─Block (l6): 1-10 [1, 80, 16, 16] [1, 80, 16, 16] 80 0.08%
│ └─Conv2d (dwconv1): 2-33 [1, 20, 16, 16] [1, 20, 16, 16] 200 0.21%
│ └─Conv2d (dwconv2): 2-34 [1, 20, 16, 16] [1, 20, 16, 16] 1,000 1.03%
│ └─Conv2d (dwconv3): 2-35 [1, 20, 16, 16] [1, 20, 16, 16] 1,000 1.03%
│ └─LayerNorm2d (ln): 2-36 [1, 80, 16, 16] [1, 80, 16, 16] -- --
│ │ └─LayerNorm (ln): 3-6 [1, 16, 16, 80] [1, 16, 16, 80] 160 0.16%
│ └─Conv2d (fc1): 2-37 [1, 80, 16, 16] [1, 160, 16, 16] 12,960 13.29%
│ └─Conv2d (fc2): 2-38 [1, 80, 16, 16] [1, 80, 16, 16] 6,480 6.64%
├─AdaptiveAvgPool2d (avgpool): 1-11 [1, 80, 16, 16] [1, 80, 1, 1] -- --
├─Linear (fc): 1-12 [1, 80] [1, 10] 810 0.83%
============================================================================================================================================
Total params: 97,530
Trainable params: 97,530
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 44.73
============================================================================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 9.05
Params size (MB): 0.39
Estimated Total Size (MB): 9.45
============================================================================================================================================
Currently recording: None
============================================================================================================================================
rnet.record((1, 11))  # Record the feature maps (output of avgpool, layer 1-11)
2. Calibrate the score function
To obtain a usable confidence score function, we simply need to
calibrate the score instance using the
fit() method:
score.fit(rnet, calib_data, calib_labels)
KNN[fit](act_norm=2.0, mode='raw', k=2, index_metric='l2')
In our KNN case, for example, this step
internally builds a search index populated with the embeddings of the
calibration samples, enabling efficient queries of the \(k^{\text{th}}\)
nearest neighbor during inference.
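To make the split between calibration and inference concrete, here is a minimal, dependency-free sketch of a KNN-style conformity score in the spirit of the KNN paper: embeddings are L2-normalized, fitting stores the calibration embeddings (a brute-force stand-in for a real search index), and the conformity of a test embedding is minus its distance to the \(k^{\text{th}}\) nearest calibration embedding. The class ToyKNN and its methods are hypothetical and are not scio's implementation; we only assume that settings such as act_norm=2.0 and index_metric='l2' correspond roughly to this normalization and distance.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def l2_normalize(v: Vector) -> List[float]:
    """Scale v to unit L2 norm; zero vectors are returned unchanged."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v] if norm > 0 else list(v)


class ToyKNN:
    """Hypothetical KNN-style score: minus the distance to the k-th nearest calibration embedding."""

    def __init__(self, k: int) -> None:
        self.k = k
        self.bank: List[List[float]] = []

    def fit(self, calib_embeddings: Sequence[Vector]) -> "ToyKNN":
        # The "index" here is just a list of normalized embeddings (brute-force search).
        if len(calib_embeddings) < self.k:
            raise ValueError("need at least k calibration embeddings")
        self.bank = [l2_normalize(e) for e in calib_embeddings]
        return self

    def conformity(self, embedding: Vector, n_classes: int) -> List[float]:
        z = l2_normalize(embedding)
        dists = sorted(math.dist(z, b) for b in self.bank)
        score = -dists[self.k - 1]  # Closer to 0 means more similar to the calibration data
        # The score is class-agnostic, so it is repeated for every class.
        return [score] * n_classes


knn = ToyKNN(k=2).fit([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
print(knn.conformity([2.0, 0.0], n_classes=3))   # Close to the calibration data
print(knn.conformity([-1.0, 0.0], n_classes=3))  # Farther away, hence a lower score
```

Negating the distance makes "higher means more conforming", which would be consistent with the negative conformity values printed in step 3; repeating the score across classes mirrors how a class-agnostic method can still return one value per class.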
3. Infer with confidence!
We can now use the score instance to infer with confidence scores!
test_data = torch.rand(5, *sample_shape)
out, conformity = score(test_data)
The first tensor, out, is (almost) exactly what net(test_data)
would have yielded without confidence scores: the logits for every
sample and every class.
The conformity tensor has the same shape as out: for every sample,
it gives the conformity assigned to every class. Some methods provide
class-specific results; KNN does not, so its values are identical
across classes.
preds = out.argmax(1)  # Predicted class for each sample
confs = conformity[:, 0]  # KNN is constant across classes
for i, (pred, conf) in enumerate(zip(preds, confs, strict=True)):  # strict=True needs Python 3.10+
    print(f"Sample {i}: predicted {pred} with confidence {conf}")
Sample 0: predicted 6 with confidence -0.05027196183800697
Sample 1: predicted 6 with confidence -0.08342856913805008
Sample 2: predicted 6 with confidence -0.05249234288930893
Sample 3: predicted 6 with confidence -0.10877114534378052
Sample 4: predicted 5 with confidence -0.07776926457881927
Note that, as mentioned in Confidence Scores, scores are not necessarily between \(0\) and \(1\): only their order matters.
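Because only the order matters, any strictly increasing transform of the scores leaves rankings (and therefore any decision based on ranking or on a correspondingly moved threshold) unchanged. The sketch below ranks the five conformity values printed above, copied by hand and rounded, then applies an arbitrary increasing transform (a logistic squashing into \((0, 1)\), chosen only for illustration) and checks that the ranking is identical. Squashing into \((0, 1)\) does not turn the scores into calibrated probabilities.

```python
import math

# Conformity values printed above, rounded by hand (index = sample number).
conformity = [-0.0503, -0.0834, -0.0525, -0.1088, -0.0778]


def rank_most_confident_first(scores):
    """Return sample indices sorted from highest to lowest score."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)


def squash(x, scale=50.0):
    """Arbitrary strictly increasing map into (0, 1); the scale is illustrative."""
    return 1.0 / (1.0 + math.exp(-scale * x))


original_order = rank_most_confident_first(conformity)
squashed_order = rank_most_confident_first([squash(s) for s in conformity])
print(original_order)                    # [0, 2, 4, 1, 3]
print(original_order == squashed_order)  # True
```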
If you wish to use these scores to perform OoD Detection and compare different algorithms, read Visualizing & Evaluating OoD Detection algorithms.
[Bonus] Profiling
The timer attribute contains
information about execution times. Querying its
report attribute shows a complete
report over the object’s lifetime.
score.timer.report # Report execution times
ScoreTimer report for
KNN(act_norm=2.0, index_metric='l2',
k=2, mode='raw')
at 0x7e223b376270
┏━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━┓
┃ Operation ┃ # samples ┃ Duration ┃
┡━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━┩
│ inference │ 5 │ 37.49 ms │
│ calibration │ 10 │ 85.50 ms │
╰─────────────┴───────────┴──────────╯
Entries are listed from newest to
oldest
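The report above is produced by scio's own timer. Purely to illustrate the kind of bookkeeping involved, here is a hypothetical, standard-library-only timer (OperationTimer and its methods are our own names, not scio's API) that records the operation name, number of samples and wall-clock duration of each timed block, and lists entries from newest to oldest like the report above. The summed ranges are stand-ins for fit() and inference calls.

```python
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Iterator, List, Tuple


@dataclass
class OperationTimer:
    """Hypothetical timer storing (operation, n_samples, seconds) entries."""

    entries: List[Tuple[str, int, float]] = field(default_factory=list)

    @contextmanager
    def measure(self, operation: str, n_samples: int) -> Iterator[None]:
        start = time.perf_counter()
        try:
            yield
        finally:
            # Record the entry even if the timed block raised.
            self.entries.append((operation, n_samples, time.perf_counter() - start))

    def report(self) -> str:
        lines = [f"{'Operation':<12} {'# samples':>9} {'Duration':>10}"]
        for op, n, secs in reversed(self.entries):  # Newest entry first
            lines.append(f"{op:<12} {n:>9} {secs * 1e3:>7.2f} ms")
        return "\n".join(lines)


timer = OperationTimer()
with timer.measure("calibration", n_samples=10):
    sum(i * i for i in range(100_000))  # Stand-in for fit()
with timer.measure("inference", n_samples=5):
    sum(i * i for i in range(50_000))  # Stand-in for score(test_data)
print(timer.report())
```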