Diving inside Neural Networks

This tutorial provides a short practical overview of the Recorder class, which is designed to ease interactions with the internal states of PyTorch Neural Network objects (torch.nn.Module) during or after inference.

from scio.recorder import Recorder

Wrapping and visualizing your Neural Network

Let us first load an arbitrary Neural Network. We use a lightweight Tiniest architecture trained on CIFAR10 and hosted on our hub. We also prepare the input data used throughout this tutorial.

import torch

inputs = torch.rand(5, 3, 32, 32)  # Random inputs with 5 samples
net = torch.hub.load("ThalesGroup/scio:hub", "tiniest", trust_repo=True, verbose=False)
net = net.to(inputs)  # Match the dtype and device of the inputs

To wrap it into a Recorder Net, rnet, one only needs to specify either an input_size (including the batch dimension) or some input_data. The wrapper then analyzes and stores the control flow of the model, using the torchinfo library.

rnet = Recorder(net, input_data=inputs[[0]])
rnet  # Visualize layers
Recorder instance for the following network
============================================================================================================================================
Layer (type (var_name):depth-idx)        Input Shape               Output Shape              Param #                   Param %
============================================================================================================================================
Tiniest (Tiniest)                        [1, 3, 32, 32]            [1, 10]                   --                             --
├─Conv2d (conv1): 1-1                    [1, 3, 32, 32]            [1, 48, 32, 32]           1,344                       1.38%
├─LayerNorm2d (ln1): 1-2                 [1, 48, 32, 32]           [1, 48, 32, 32]           --                             --
│    └─LayerNorm (ln): 2-1               [1, 32, 32, 48]           [1, 32, 32, 48]           96                          0.10%
├─Block (l1): 1-3                        [1, 48, 32, 32]           [1, 48, 32, 32]           48                          0.05%
│    └─Conv2d (dwconv1): 2-2             [1, 12, 32, 32]           [1, 12, 32, 32]           120                         0.12%
│    └─Conv2d (dwconv2): 2-3             [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─Conv2d (dwconv3): 2-4             [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─LayerNorm2d (ln): 2-5             [1, 48, 32, 32]           [1, 48, 32, 32]           --                             --
│    │    └─LayerNorm (ln): 3-1          [1, 32, 32, 48]           [1, 32, 32, 48]           96                          0.10%
│    └─Conv2d (fc1): 2-6                 [1, 48, 32, 32]           [1, 96, 32, 32]           4,704                       4.82%
│    └─Conv2d (fc2): 2-7                 [1, 48, 32, 32]           [1, 48, 32, 32]           2,352                       2.41%
├─Block (l2): 1-4                        [1, 48, 32, 32]           [1, 48, 32, 32]           48                          0.05%
│    └─Conv2d (dwconv1): 2-8             [1, 12, 32, 32]           [1, 12, 32, 32]           120                         0.12%
│    └─Conv2d (dwconv2): 2-9             [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─Conv2d (dwconv3): 2-10            [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─LayerNorm2d (ln): 2-11            [1, 48, 32, 32]           [1, 48, 32, 32]           --                             --
│    │    └─LayerNorm (ln): 3-2          [1, 32, 32, 48]           [1, 32, 32, 48]           96                          0.10%
│    └─Conv2d (fc1): 2-12                [1, 48, 32, 32]           [1, 96, 32, 32]           4,704                       4.82%
│    └─Conv2d (fc2): 2-13                [1, 48, 32, 32]           [1, 48, 32, 32]           2,352                       2.41%
├─Block (l3): 1-5                        [1, 48, 32, 32]           [1, 48, 32, 32]           48                          0.05%
│    └─Conv2d (dwconv1): 2-14            [1, 12, 32, 32]           [1, 12, 32, 32]           120                         0.12%
│    └─Conv2d (dwconv2): 2-15            [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─Conv2d (dwconv3): 2-16            [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─LayerNorm2d (ln): 2-17            [1, 48, 32, 32]           [1, 48, 32, 32]           --                             --
│    │    └─LayerNorm (ln): 3-3          [1, 32, 32, 48]           [1, 32, 32, 48]           96                          0.10%
│    └─Conv2d (fc1): 2-18                [1, 48, 32, 32]           [1, 96, 32, 32]           4,704                       4.82%
│    └─Conv2d (fc2): 2-19                [1, 48, 32, 32]           [1, 48, 32, 32]           2,352                       2.41%
├─Conv2d (dsconv): 1-6                   [1, 48, 32, 32]           [1, 80, 16, 16]           3,920                       4.02%
├─LayerNorm2d (ln2): 1-7                 [1, 80, 16, 16]           [1, 80, 16, 16]           --                             --
│    └─LayerNorm (ln): 2-20              [1, 16, 16, 80]           [1, 16, 16, 80]           160                         0.16%
├─Block (l4): 1-8                        [1, 80, 16, 16]           [1, 80, 16, 16]           80                          0.08%
│    └─Conv2d (dwconv1): 2-21            [1, 20, 16, 16]           [1, 20, 16, 16]           200                         0.21%
│    └─Conv2d (dwconv2): 2-22            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─Conv2d (dwconv3): 2-23            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─LayerNorm2d (ln): 2-24            [1, 80, 16, 16]           [1, 80, 16, 16]           --                             --
│    │    └─LayerNorm (ln): 3-4          [1, 16, 16, 80]           [1, 16, 16, 80]           160                         0.16%
│    └─Conv2d (fc1): 2-25                [1, 80, 16, 16]           [1, 160, 16, 16]          12,960                     13.29%
│    └─Conv2d (fc2): 2-26                [1, 80, 16, 16]           [1, 80, 16, 16]           6,480                       6.64%
├─Block (l5): 1-9                        [1, 80, 16, 16]           [1, 80, 16, 16]           80                          0.08%
│    └─Conv2d (dwconv1): 2-27            [1, 20, 16, 16]           [1, 20, 16, 16]           200                         0.21%
│    └─Conv2d (dwconv2): 2-28            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─Conv2d (dwconv3): 2-29            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─LayerNorm2d (ln): 2-30            [1, 80, 16, 16]           [1, 80, 16, 16]           --                             --
│    │    └─LayerNorm (ln): 3-5          [1, 16, 16, 80]           [1, 16, 16, 80]           160                         0.16%
│    └─Conv2d (fc1): 2-31                [1, 80, 16, 16]           [1, 160, 16, 16]          12,960                     13.29%
│    └─Conv2d (fc2): 2-32                [1, 80, 16, 16]           [1, 80, 16, 16]           6,480                       6.64%
├─Block (l6): 1-10                       [1, 80, 16, 16]           [1, 80, 16, 16]           80                          0.08%
│    └─Conv2d (dwconv1): 2-33            [1, 20, 16, 16]           [1, 20, 16, 16]           200                         0.21%
│    └─Conv2d (dwconv2): 2-34            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─Conv2d (dwconv3): 2-35            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─LayerNorm2d (ln): 2-36            [1, 80, 16, 16]           [1, 80, 16, 16]           --                             --
│    │    └─LayerNorm (ln): 3-6          [1, 16, 16, 80]           [1, 16, 16, 80]           160                         0.16%
│    └─Conv2d (fc1): 2-37                [1, 80, 16, 16]           [1, 160, 16, 16]          12,960                     13.29%
│    └─Conv2d (fc2): 2-38                [1, 80, 16, 16]           [1, 80, 16, 16]           6,480                       6.64%
├─AdaptiveAvgPool2d (avgpool): 1-11      [1, 80, 16, 16]           [1, 80, 1, 1]             --                             --
├─Linear (fc): 1-12                      [1, 80]                   [1, 10]                   810                         0.83%
============================================================================================================================================
Total params: 97,530
Trainable params: 97,530
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 44.73
============================================================================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 9.05
Params size (MB): 0.39
Estimated Total Size (MB): 9.45
============================================================================================================================================
Currently recording: None
============================================================================================================================================

Tip

For summary customization, refer to torchinfo.summary options. For example, it is possible to bound the depth of the representation tree with depth=2.

Note

In the case of dynamic control flow, refer to the force_static_flow argument of Recorder.
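As a reminder of what dynamic control flow means, the following sketch defines a hypothetical module whose executed layers depend on the input values. It only illustrates the phenomenon with plain PyTorch forward hooks and makes no claim about what force_static_flow does.

```python
import torch
import torch.nn as nn


class Gated(nn.Module):
    """Hypothetical module whose executed layers depend on the input values."""

    def __init__(self):
        super().__init__()
        self.small = nn.Linear(4, 2)
        self.large = nn.Linear(4, 2)

    def forward(self, x):
        # Data-dependent branch: only one of the two layers runs per call
        if x.sum() > 0:
            return self.large(x)
        return self.small(x)


net = Gated()
called = []  # Names of the submodules executed, in call order
for name, module in net.named_children():
    module.register_forward_hook(lambda m, args, out, name=name: called.append(name))

net(torch.ones(1, 4))   # Positive sum: runs "large"
net(-torch.ones(1, 4))  # Negative sum: runs "small"
print(called)  # ['large', 'small']
```

With such a model, an analysis based on a single forward pass may reflect only the branch taken by that particular input.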

In many ways, this wrapper is transparent to the user. For example, one can naturally process data with rnet(inputs).

Selecting layers of interest

The penultimate line of the above summary reports the layers that are currently set to be recorded (stored in rnet.recording). By default after instantiation, there are none.

print(repr(rnet).split("\n")[-2])  # Show penultimate summary line
Currently recording: None

This selection can be changed with rnet.record(), using the depth-idx identifiers from the summary (e.g. 1-9), passed as (depth, idx) tuples. For example, the following specifies that the outputs of the first Block (1-3) and of the penultimate layer (1-11) should be recorded.

rnet.record((1, 3), (1, 11))
print(repr(rnet).split("\n")[-2])  # Show penultimate summary line
Currently recording: 1-3, 1-11
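The summary displays identifiers as depth-idx strings, whereas the call above takes (depth, idx) tuples. When only the text form is at hand (for example, copied from a printed summary), a small helper can convert between the two. The functions parse_layer_id and parse_recording_line below are hypothetical and not part of scio; in a live session, rnet.recording already holds this information.

```python
def parse_layer_id(identifier: str) -> tuple[int, int]:
    """Convert a summary identifier such as "1-11" into a (depth, idx) tuple."""
    depth, idx = identifier.strip().split("-")
    return int(depth), int(idx)


def parse_recording_line(line: str) -> list[tuple[int, int]]:
    """Parse a line such as "Currently recording: 1-3, 1-11"."""
    _, _, listed = line.partition(":")
    listed = listed.strip()
    if listed == "None":
        return []
    return [parse_layer_id(item) for item in listed.split(",")]


print(parse_recording_line("Currently recording: 1-3, 1-11"))  # [(1, 3), (1, 11)]
print(parse_recording_line("Currently recording: None"))  # []
```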

Note

Though not shown in the summary, it is possible to select 0-1 to refer to the entire model.

Warning

torchinfo.summary can only detect torch.nn.Module calls. As such, if a Neural Network uses activation functions, it should call their Module implementation (instead of their functional counterpart) and declare them as an attribute, for them to be visible as a layer in the summary. It is not necessary to declare a different attribute for every activation call (e.g. one self.relu = nn.ReLU() can be used multiple times in forward()).
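The difference can be observed with plain PyTorch forward hooks, which, like the summary, only see torch.nn.Module calls. The two classes and the called_layer_types helper below are hypothetical examples written for this sketch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WithModuleReLU(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        self.fc2 = nn.Linear(4, 2)
        self.relu = nn.ReLU()  # Declared once, reused for every activation call

    def forward(self, x):
        return self.relu(self.fc2(self.relu(self.fc1(x))))


class WithFunctionalReLU(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        self.fc2 = nn.Linear(4, 2)

    def forward(self, x):
        return F.relu(self.fc2(F.relu(self.fc1(x))))  # Not a Module call


def called_layer_types(net, x):
    """Return the class names of the submodules called during one forward pass."""
    calls = []
    handles = [
        m.register_forward_hook(lambda mod, args, out: calls.append(type(mod).__name__))
        for m in net.children()
    ]
    net(x)
    for handle in handles:
        handle.remove()
    return calls


x = torch.rand(1, 4)
print(called_layer_types(WithModuleReLU(), x))  # ['Linear', 'ReLU', 'Linear', 'ReLU']
print(called_layer_types(WithFunctionalReLU(), x))  # ['Linear', 'Linear']
```

Both networks compute the same kind of function, but only the first exposes its activations as layers.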

Capturing internal states

Once the recording layers are set, every forward pass will automatically store the corresponding internal states in the rnet.activations mapping. Its keys are the (depth, idx) 2-tuples.
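For intuition, this kind of capture can be approximated in plain PyTorch with forward hooks. The sketch below is not the Recorder implementation: the stand-in model, the hand-assigned (depth, idx) keys and the make_hook helper are all assumptions made for illustration.

```python
from types import MappingProxyType

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))  # Stand-in model

_store = {}  # Filled by the hooks during each forward pass
activations = MappingProxyType(_store)  # Read-only view of the stored outputs


def make_hook(key):
    def hook(module, args, output):
        _store[key] = output  # Overwritten at every forward pass

    return hook


# Hypothetical (depth, idx) keys, assigned by hand for this sketch
net[0].register_forward_hook(make_hook((1, 1)))
net[1].register_forward_hook(make_hook((1, 2)))

out = net(torch.rand(5, 4))
print({key: tuple(t.shape) for key, t in activations.items()})
```

Printing the shapes rather than the tensors themselves keeps the output short when the recorded activations are large.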

out = rnet(inputs)  # Forward pass, records the activations
rnet.activations
mappingproxy({(1, 3): tensor([[[[-1.3814e+00,  1.8600e+00, -1.2796e+00,  ..., -6.8463e-01,
            2.6701e+00,  4.5864e+00],
          [ 2.8375e+00,  4.6212e+00,  1.1439e+00,  ..., -8.7338e+00,
           -1.1794e+01,  1.4767e+00],
          [ 9.5286e+00,  2.5366e+00, -2.1981e+00,  ...,  3.9953e+00,
            4.5271e-01, -1.0240e+01],
          ...,
          [ 9.5207e+00,  8.0983e+00,  5.6757e+00,  ...,  5.9793e+00,
           -5.4198e+00, -2.8429e+00],
          [ 5.1171e+00, -1.0076e+01, -9.3729e+00,  ...,  6.0934e+00,
            7.5260e-01,  3.1312e+00],
          [ 7.1745e+00,  4.1425e+00,  4.5234e+00,  ..., -3.3186e+00,
            6.1046e+00,  4.4920e+00]],

         [[ 5.0572e+00, -1.3649e+00, -1.4315e+00,  ..., -1.4473e+00,
           -2.9466e+00,  2.4056e+00],
          [ 6.3755e+00,  1.8466e+00, -1.7824e+00,  ...,  1.2950e+01,
            1.2809e+01,  4.0134e-01],
          [ 4.7586e+00, -1.9841e+00,  3.7532e-01,  ..., -1.4661e+00,
           -8.6045e-01,  6.0137e+00],
          ...,
          [-1.3454e+00, -3.7104e+00, -4.0511e+00,  ...,  6.0435e+00,
            1.3153e+01, -5.0027e-01],
          [-8.9703e-01,  1.9036e+01,  9.6374e+00,  ..., -1.5002e+00,
           -2.1892e+00, -2.8683e+00],
          [ 3.3922e-01,  6.0073e-01, -2.4227e+00,  ...,  4.7101e+00,
            4.0546e+00,  4.4392e+00]],

         [[ 1.5309e+00,  5.0459e+00,  1.1431e+00,  ...,  7.4003e+00,
            7.5176e+00,  2.2057e+00],
          [-9.7765e-01, -2.8043e-02,  1.6510e+00,  ..., -3.7375e+00,
           -4.6361e+00,  1.8146e+00],
          [ 2.5030e+00,  2.0673e+00, -2.9156e-01,  ...,  1.6247e+00,
            4.0193e+00, -2.2134e+00],
          ...,
          [-3.5594e-01,  1.6034e+00,  1.8066e+00,  ...,  2.1969e+00,
           -2.5826e+00, -6.8786e-01],
          [ 6.1411e-01, -5.2621e+00, -5.8353e+00,  ..., -1.2481e+00,
           -1.6014e+00,  2.6520e+00],
          [-4.2272e+00, -5.6984e+00, -4.5745e+00,  ..., -6.8844e+00,
           -5.1420e+00, -3.6021e+00]],

         ...,

         [[-5.7322e+00, -4.1594e-01, -2.0133e+00,  ..., -6.5402e-01,
           -3.0100e+00, -7.9647e+00],
          [-6.7694e+00, -2.2939e+00,  4.0918e+00,  ..., -7.6865e+00,
           -6.4666e+00, -2.4642e+00],
          [-1.2591e+01, -2.5190e+00, -2.0889e+00,  ..., -6.8775e+00,
            2.0600e+00, -1.8746e+00],
          ...,
          [-1.0628e+01, -5.7900e+00, -2.6520e+00,  ..., -8.3595e+00,
           -6.4805e+00,  2.5189e+00],
          [-1.3621e+00, -1.2000e+01, -3.7142e+00,  ..., -7.7702e+00,
           -3.8969e-02, -1.0165e+01],
          [-6.2325e+00, -3.2469e+00, -7.9726e-01,  ..., -5.7161e+00,
           -3.9607e+00, -3.4564e+00]],

         [[-2.7474e+00, -1.3268e+00, -3.6357e+00,  ...,  5.6249e-01,
           -3.0163e+00, -1.2353e+00],
          [-3.6210e+00,  1.3234e+00, -1.4321e+00,  ..., -1.7224e+00,
           -2.8337e+00, -1.8857e+00],
          [-1.3323e+01, -1.3397e+00, -3.5190e+00,  ..., -7.8960e+00,
           -2.0793e+00, -4.9954e+00],
          ...,
          [-4.1338e+00, -4.3460e+00, -3.6217e+00,  ..., -1.0785e+01,
           -2.1481e+00, -1.8808e+00],
          [-1.4733e+00, -8.7111e-01, -4.3669e+00,  ..., -9.9289e+00,
           -3.2654e-01, -1.0477e+01],
          [-6.0820e+00,  7.6525e-01, -1.7387e+00,  ..., -3.2933e+00,
           -4.3554e+00, -4.5810e+00]],

         [[-7.2913e+00,  1.6110e+00,  1.4007e+00,  ...,  2.5330e+00,
            4.5218e+00, -1.3940e-01],
          [-1.1641e+01,  1.2107e+00, -2.3018e-01,  ..., -6.2613e+00,
           -7.6314e+00,  6.2233e+00],
          [-6.7200e+00,  1.2862e+00, -1.1969e+00,  ...,  1.8134e-01,
            2.3498e+00, -3.5478e+00],
          ...,
          [ 8.0215e-01,  3.5425e+00,  3.1008e+00,  ..., -7.4718e+00,
           -7.9743e+00,  3.5217e-02],
          [ 7.9663e-03, -1.1282e+01, -6.0575e+00,  ..., -3.5946e+00,
            3.9941e-01, -1.0554e-01],
          [-3.7506e+00,  2.5578e-01,  4.4922e+00,  ..., -3.3814e+00,
           -6.9205e+00, -8.5543e+00]]],


        [[[-2.7997e+00,  4.9612e+00, -1.0930e+01,  ...,  1.4388e+00,
           -9.7566e-01, -4.1645e+00],
          [ 6.4128e+00,  4.3702e+00, -8.5426e+00,  ...,  4.6609e+00,
            2.0677e-01,  1.0563e+00],
          [ 1.7387e+00,  4.1835e-02,  1.1628e+00,  ...,  5.0198e+00,
            3.0541e+00, -1.0667e+01],
          ...,
          [ 4.5812e+00, -2.3860e+00, -1.0572e+01,  ..., -8.7403e+00,
            2.1202e+00, -1.0105e+01],
          [-5.3226e+00,  3.0809e+00,  4.4157e+00,  ...,  3.0978e+00,
            5.2042e-01, -8.0152e-01],
          [ 5.3460e+00,  6.0196e+00,  3.6563e+00,  ...,  4.3829e+00,
           -6.5481e+00,  3.6770e+00]],

         [[ 1.1323e+01, -1.4397e+00,  1.3587e+01,  ...,  2.2712e+00,
           -7.0213e-01,  6.8064e+00],
          [-2.6686e-01, -2.0717e+00,  1.4094e+01,  ..., -4.5449e-01,
            3.5650e+00, -4.1616e-01],
          [ 1.5873e+00,  1.1489e+00, -1.4001e-01,  ..., -1.0459e-01,
            1.0134e+00,  1.5372e+01],
          ...,
          [ 4.2435e-01, -8.4514e-02,  1.9573e+01,  ...,  1.3096e+01,
           -2.9339e-01,  1.5592e+01],
          [ 1.0565e+01, -1.4513e+00,  6.2372e+00,  ...,  3.7070e+00,
            5.9108e+00,  1.7746e+00],
          [-5.8835e-01,  3.4681e+00,  3.3161e+00,  ...,  1.1358e+00,
            1.0497e+01, -8.6927e-01]],

         [[ 3.5601e+00,  7.7500e+00,  1.0555e+00,  ...,  3.4526e+00,
            5.4212e+00,  4.1915e+00],
          [-9.1893e-02,  7.9895e-01, -7.3801e+00,  ...,  2.6522e+00,
           -8.1608e-02,  2.0742e+00],
          [-2.1261e+00,  1.6051e+00,  2.1891e+00,  ...,  1.4754e+00,
            1.6074e-02, -3.2832e+00],
          ...,
          [-1.3360e+00, -3.9876e+00, -6.7073e+00,  ...,  2.6689e+00,
            1.2149e+00, -4.9108e+00],
          [-3.5790e+00,  7.1246e-01,  6.1450e-03,  ..., -3.8866e+00,
            5.1352e-01, -5.4159e-01],
          [-3.1178e+00, -5.0416e+00, -4.7822e+00,  ..., -9.7147e+00,
           -9.2487e+00, -2.6500e+00]],

         ...,

         [[-9.5878e+00, -8.7002e+00, -5.8944e+00,  ..., -7.4368e-01,
           -1.9612e+00, -5.4044e+00],
          [-2.8311e+00, -8.3496e+00, -7.9494e+00,  ..., -4.4251e+00,
           -2.4272e+00, -7.0375e+00],
          [-2.1593e+00, -1.9787e+00,  1.3324e+00,  ..., -8.5349e+00,
            2.0919e+00, -1.0314e+01],
          ...,
          [-2.8095e+00, -1.3335e+00, -1.1761e+01,  ..., -7.9736e+00,
           -7.6109e+00, -7.3705e+00],
          [-7.0648e+00, -3.8597e+00, -7.6450e+00,  ..., -5.9929e-01,
           -7.8991e+00, -1.6621e+00],
          [-4.8733e-02, -6.1678e+00, -5.1505e+00,  ..., -2.7523e+00,
           -7.5580e+00, -1.0027e+01]],

         [[ 1.4237e+00, -4.2261e+00, -1.5834e+00,  ..., -1.6786e+00,
           -8.9872e-01, -3.4053e+00],
          [ 1.2113e+00, -9.1405e+00, -1.4661e+00,  ..., -2.0021e+00,
           -2.5327e+00, -5.6965e+00],
          [-8.5124e-01,  1.3246e+00, -2.2257e+00,  ..., -4.3420e+00,
           -1.4860e+00, -8.7517e-01],
          ...,
          [ 2.3198e-01, -2.3994e+00,  1.3740e-01,  ...,  8.8430e-01,
           -9.7738e+00, -1.3141e+00],
          [-2.6459e+00, -9.6048e-01, -8.4412e+00,  ...,  8.2699e-01,
           -1.2804e+00,  7.5256e-02],
          [ 8.7823e-01, -8.4639e-01, -5.5828e+00,  ..., -3.8812e+00,
           -3.5241e+00, -1.1436e+01]],

         [[-9.8434e+00,  2.2491e+00, -9.8094e+00,  ..., -1.0420e+00,
           -1.7324e-02, -1.1885e+00],
          [ 1.8917e+00, -1.1749e+00, -1.4480e+01,  ..., -3.5109e+00,
           -2.0118e+00, -6.1199e-01],
          [-1.1223e+00, -2.0223e+00,  7.4275e-01,  ...,  4.3683e-01,
           -1.1194e+00, -1.0861e+01],
          ...,
          [ 1.7008e+00, -1.9578e+00, -1.3817e+01,  ..., -4.8846e+00,
           -3.7437e+00, -9.4763e+00],
          [-6.5823e+00,  3.2142e+00, -6.5143e+00,  ..., -9.9540e-01,
           -7.9715e+00, -8.4695e-01],
          [ 1.0737e+00, -1.1735e+00, -6.3560e+00,  ..., -3.9431e+00,
           -8.7197e+00, -3.0655e+00]]],


        [[[ 6.0075e+00,  2.9829e+00,  5.4720e-01,  ..., -3.7747e+00,
           -1.6704e+00, -6.1667e+00],
          [ 8.7172e+00, -3.8604e+00,  9.4573e-01,  ...,  4.2484e+00,
           -8.8961e+00,  1.3971e+00],
          [ 3.1430e+00,  7.7646e+00,  3.7368e+00,  ...,  5.7664e+00,
           -1.1232e+01, -3.8206e+00],
          ...,
          [ 6.0924e+00,  2.6652e+00,  2.1534e+00,  ...,  3.0394e+00,
            2.6310e+00,  1.4668e+00],
          [ 8.1217e+00, -5.5440e+00,  1.2208e+00,  ..., -6.6198e+00,
            4.3178e+00, -9.9014e+00],
          [ 1.7434e-01,  1.3055e+00,  3.6563e+00,  ..., -3.8642e+00,
            5.1773e+00, -6.8622e+00]],

         [[ 1.0163e+00, -3.3881e+00, -2.8326e+00,  ...,  3.4356e+00,
           -1.2094e+00,  1.1610e+01],
          [-4.3041e+00,  1.5070e+01,  2.4692e-01,  ..., -1.6102e+00,
            1.3389e+01, -5.0282e-01],
          [-5.9519e-01,  3.9021e+00, -1.6933e+00,  ..., -1.7636e+00,
            1.6336e+01,  2.1620e+00],
          ...,
          [-2.9697e+00, -4.9208e-01,  5.4738e+00,  ..., -3.2097e+00,
            3.8336e-01,  7.8879e+00],
          [ 6.4547e+00,  5.7315e+00,  2.4578e+00,  ...,  1.5565e+01,
            6.2951e-01,  1.5040e+01],
          [ 8.5192e-01, -1.4575e+00,  5.9452e+00,  ...,  3.8698e+00,
            7.4964e-01,  1.1794e+01]],

         [[ 2.2816e+00,  6.7094e+00,  8.7219e+00,  ...,  3.1028e+00,
            6.7060e+00,  1.4464e+00],
          [ 1.3398e+00, -3.4836e+00,  1.4776e+00,  ...,  1.1059e+00,
           -5.8610e+00, -2.6909e-02],
          [ 2.1880e+00,  1.3065e+00,  2.0044e+00,  ...,  3.9363e-01,
           -7.4427e+00,  5.9955e-01],
          ...,
          [ 2.7026e+00,  1.0266e+00, -3.9750e+00,  ...,  1.1511e-01,
            2.7212e-02,  3.8934e+00],
          [ 2.3337e+00, -5.0934e+00, -1.0141e+00,  ..., -4.5169e+00,
           -4.3041e-02, -4.4352e+00],
          [-7.6196e+00, -7.3900e+00, -9.8710e+00,  ..., -7.4361e+00,
           -7.3628e+00, -7.6874e+00]],

         ...,

         [[-5.8783e+00, -4.9645e+00, -9.3438e+00,  ..., -4.4777e+00,
           -3.1152e+00, -7.3344e+00],
          [-1.7012e+00, -7.0169e+00, -1.3367e-01,  ..., -5.8335e+00,
           -2.9789e+00, -6.8422e+00],
          [-5.6938e+00, -5.0716e+00, -8.7754e+00,  ..., -1.0062e+01,
           -8.5531e+00,  4.8048e+00],
          ...,
          [-1.4653e-01,  2.5005e-01, -4.5251e+00,  ..., -7.2559e-01,
           -5.5258e+00, -7.6569e+00],
          [-9.8104e+00, -1.0944e+00, -5.5918e+00,  ..., -9.5838e+00,
           -6.3275e-01, -5.1200e+00],
          [-1.4280e+00,  9.8811e-01, -2.6171e+00,  ..., -4.3736e+00,
           -6.1173e+00, -6.9777e+00]],

         [[-4.2896e+00,  5.7543e-02, -6.8408e+00,  ...,  6.9426e-01,
            7.6362e-01, -4.1535e+00],
          [-9.1107e-01, -3.0082e+00,  9.9385e-01,  ..., -4.3142e+00,
           -2.3043e+00, -1.0355e+01],
          [-3.6213e+00, -4.8181e+00, -7.5293e+00,  ..., -7.2462e+00,
           -6.6916e-01, -3.8707e+00],
          ...,
          [-1.7730e-01, -5.9890e-01, -3.5566e+00,  ...,  8.8028e-02,
           -3.8770e+00, -4.8012e+00],
          [-9.4276e+00, -6.2567e+00, -1.9895e+00,  ...,  8.3226e-01,
           -2.4635e+00, -4.0336e+00],
          [-4.6126e+00, -8.1328e-01, -1.8470e+00,  ..., -3.6460e+00,
           -4.5896e+00, -3.0558e+00]],

         [[-1.2493e+00,  2.7069e+00, -1.2138e+00,  ..., -2.4928e+00,
            1.9143e+00, -9.7554e+00],
          [ 4.2818e+00, -8.6038e+00, -5.8478e-01,  ...,  3.3719e+00,
           -4.5103e+00, -2.0950e+00],
          [ 2.7163e+00, -4.9832e+00, -4.6880e-01,  ..., -6.0805e-02,
           -1.1873e+01, -3.1786e+00],
          ...,
          [ 3.3274e+00,  2.5207e+00, -3.6708e+00,  ...,  1.1031e+00,
           -3.9848e-01, -3.7807e+00],
          [-6.3624e+00, -1.5517e+00,  7.5914e-01,  ..., -1.2706e+01,
           -1.3660e+00, -8.4269e+00],
          [-2.3102e+00,  1.8461e+00, -4.0350e+00,  ..., -4.8226e+00,
           -2.2772e+00, -1.1901e+01]]],


        [[[-5.3089e+00,  6.4166e+00, -6.8450e+00,  ...,  3.0484e+00,
            8.8779e-01, -1.4340e+01],
          [ 7.6532e+00,  5.7146e+00, -4.7712e+00,  ..., -6.9933e+00,
            5.2417e+00,  1.9826e+00],
          [ 1.7021e+00,  7.0956e-01, -8.4613e-01,  ...,  7.3653e-01,
            3.6153e+00, -1.2525e+00],
          ...,
          [ 5.9555e+00, -3.5211e+00,  4.5845e+00,  ...,  6.2242e+00,
           -4.9657e+00,  4.3552e+00],
          [-1.5301e+00,  3.3871e+00,  6.9147e-01,  ...,  2.8682e+00,
           -2.4078e+00,  4.0558e+00],
          [-1.9326e+00,  6.1443e-01, -8.8946e+00,  ..., -6.5294e+00,
            4.1390e+00, -1.0230e+01]],

         [[ 1.1683e+01, -3.7293e+00,  8.2093e+00,  ...,  5.8495e-01,
           -1.8615e+00,  1.5324e+01],
          [-3.6798e+00, -1.5685e+00,  9.0611e+00,  ...,  1.5070e+01,
           -5.0984e-01,  5.7873e-01],
          [ 4.4220e+00,  2.7162e+00,  1.8239e+00,  ...,  2.9268e+00,
           -3.4802e+00, -3.2766e+00],
          ...,
          [-2.8489e+00, -1.3073e+00, -3.1676e+00,  ..., -2.0915e-01,
            1.1976e+01, -8.2914e-01],
          [ 6.5069e+00, -1.4891e+00,  1.5762e+00,  ..., -1.7948e+00,
            9.0457e+00, -2.4200e+00],
          [ 9.7356e+00,  2.3982e+00,  1.8995e+01,  ...,  1.5948e+01,
            1.3840e-01,  1.4330e+01]],

         [[-1.0790e+00,  8.0308e+00,  2.9523e+00,  ...,  4.0037e+00,
            5.5350e+00,  9.3414e-02],
          [ 1.3473e-01, -1.3791e+00, -6.1079e+00,  ..., -7.5394e+00,
           -5.9767e+00, -3.7274e+00],
          [-9.6001e-01,  1.9832e+00,  2.4758e+00,  ...,  2.3948e+00,
            1.1087e+00, -1.1710e+00],
          ...,
          [ 1.1001e+00, -3.8215e+00,  1.3172e+00,  ..., -2.3431e+00,
           -2.5002e+00,  2.0965e+00],
          [-2.6912e+00,  4.8765e+00,  6.9675e+00,  ...,  2.4994e+00,
           -5.8789e-01,  5.0264e+00],
          [-5.6265e+00, -6.2234e+00, -9.9354e+00,  ..., -8.6726e+00,
           -3.9908e+00, -6.7377e+00]],

         ...,

         [[-4.6358e+00, -6.2105e+00, -4.5394e+00,  ..., -4.0827e+00,
           -5.4951e+00, -8.9051e+00],
          [-9.3827e-01, -9.6505e+00, -7.5041e+00,  ..., -8.9395e+00,
           -7.5414e-01, -1.9387e-01],
          [-2.2748e+00, -5.2475e+00,  2.9579e-01,  ..., -1.0384e+00,
           -3.7133e+00,  2.1712e+00],
          ...,
          [-4.3560e+00, -3.6715e-02,  4.3798e-01,  ..., -9.4615e+00,
           -8.7997e+00, -7.7494e+00],
          [-3.2850e+00, -4.0302e+00, -1.7714e+00,  ...,  9.5723e-02,
           -7.0057e+00, -1.0475e+01],
          [-6.6503e+00, -4.9178e+00, -9.4870e+00,  ..., -1.1196e+01,
           -7.1939e+00, -8.3586e+00]],

         [[-5.0313e+00, -4.9987e+00, -9.6636e-01,  ..., -2.9338e+00,
           -5.3003e+00, -1.9836e+00],
          [-8.5575e-01, -5.8069e+00, -2.7377e+00,  ..., -5.4617e-01,
           -9.1020e-01,  1.6041e+00],
          [-5.2501e+00,  1.2012e+00, -2.2547e+00,  ..., -2.0173e+00,
           -5.2966e+00, -3.7321e-01],
          ...,
          [ 9.4993e-01, -2.3440e+00, -1.8691e+00,  ..., -5.7994e+00,
           -1.5264e-01, -4.2742e+00],
          [-3.9573e+00, -8.2388e+00, -5.1011e+00,  ..., -3.9187e+00,
           -2.1310e+00, -6.6231e+00],
          [-2.6912e+00, -4.8253e+00, -2.4590e+00,  ..., -2.4253e+00,
           -6.0310e+00, -1.9835e+00]],

         [[-5.6130e+00,  1.5045e+00, -4.3884e+00,  ..., -1.0281e+00,
           -2.1793e+00, -7.1258e+00],
          [ 4.7215e+00,  1.1978e-01, -1.3202e+01,  ..., -1.2931e+01,
           -1.9061e+00,  3.1473e-01],
          [-1.3411e+00,  1.0093e+00, -2.3578e-02,  ...,  9.8828e-02,
            2.6032e-01,  3.0817e-01],
          ...,
          [ 3.0105e+00,  8.7894e-02,  2.9974e+00,  ..., -1.1259e-01,
           -1.0205e+01,  1.4914e+00],
          [-1.7476e+00,  7.3002e-01,  9.1252e-01,  ..., -2.1691e+00,
           -1.2416e+01,  1.7065e+00],
          [-3.9426e+00, -1.7801e+00, -1.0031e+01,  ..., -1.4031e+01,
           -1.7012e+00, -8.2731e+00]]],


        [[[ 8.4952e+00,  6.6572e+00,  7.1634e+00,  ..., -4.4250e+00,
            3.9915e-01,  7.5612e-01],
          [-2.7846e+00,  2.8621e+00, -2.1968e+00,  ...,  1.1307e+00,
            6.2344e+00,  8.9173e-02],
          [ 5.2230e+00,  3.7516e-01,  2.7892e+00,  ...,  6.4741e-01,
            3.2088e+00, -1.1053e+00],
          ...,
          [-2.9640e+00, -1.9907e+00, -2.5158e+00,  ...,  6.0821e+00,
           -1.0965e+01, -7.7981e-01],
          [ 1.6023e+00, -3.0682e-01,  2.6658e+00,  ...,  1.5450e+00,
           -1.7731e+00, -2.8304e+00],
          [ 1.4510e+00,  2.3755e+00,  3.8491e+00,  ..., -2.7556e+00,
           -5.8633e+00,  4.4956e-01]],

         [[-3.0172e+00,  3.4978e+00,  7.4015e+00,  ...,  7.2696e+00,
            5.4348e+00,  1.8777e-02],
          [ 1.3804e+01,  2.9835e+00,  2.7670e+00,  ..., -4.2782e+00,
           -4.4769e+00,  2.5723e+00],
          [ 1.8387e+00, -2.7942e-01, -2.8371e+00,  ...,  4.6671e+00,
            1.2123e+00, -8.0046e-02],
          ...,
          [ 5.6310e+00, -5.7976e-01,  9.4391e-01,  ...,  3.5885e+00,
            1.3388e+01,  2.0126e-01],
          [-7.8383e-01, -6.1923e-01,  1.7017e-01,  ...,  4.7670e+00,
            1.7764e+00, -1.0202e+00],
          [-4.9477e-01,  3.9957e+00,  7.1832e+00,  ...,  4.9039e+00,
            1.0160e+01,  2.7419e+00]],

         [[ 6.4864e+00,  6.2394e+00,  4.4714e+00,  ..., -2.6463e-01,
           -5.6287e-02,  3.4257e+00],
          [-1.0039e+00,  1.6999e+00, -7.5959e-01,  ..., -4.1233e-01,
            2.9484e+00, -2.6631e-01],
          [-2.0895e+00, -8.4173e-01,  3.3876e+00,  ..., -1.4902e+00,
            5.0658e-01, -2.8299e+00],
          ...,
          [-2.1714e+00, -1.4510e+00,  3.2529e-01,  ...,  2.4452e+00,
           -5.2147e+00, -2.3137e+00],
          [ 9.7946e-01,  1.6600e+00,  1.9129e-01,  ..., -4.0613e+00,
           -8.4586e-01, -2.1381e+00],
          [-3.6712e+00, -8.2651e+00, -5.9812e+00,  ..., -9.3416e+00,
           -9.0518e+00, -6.1117e+00]],

         ...,

         [[-1.0984e+01, -9.1104e+00, -6.4418e+00,  ..., -2.0667e+00,
           -2.1424e+00, -5.6467e+00],
          [-9.8937e+00, -5.9174e+00, -5.0398e+00,  ...,  2.5154e+00,
            2.6306e+00, -4.0961e+00],
          [-5.0194e+00, -1.5151e+00, -3.9321e+00,  ..., -4.5688e+00,
           -5.0863e+00, -1.1303e+00],
          ...,
          [-6.0772e+00, -2.7435e+00, -3.9927e+00,  ..., -6.5105e+00,
           -9.8079e+00, -2.6355e+00],
          [-2.6958e+00,  1.7545e+00, -5.6109e+00,  ..., -2.6872e+00,
            1.8348e+00,  9.0143e-01],
          [-2.3880e+00, -2.7234e+00, -5.5100e+00,  ..., -2.6840e+00,
           -4.9957e+00, -1.4849e+00]],

         [[-5.0472e+00, -9.1172e+00, -6.3687e+00,  ..., -1.8357e+00,
           -4.4417e+00, -3.2424e+00],
          [-9.2601e-01, -2.2741e+00, -6.2047e-01,  ...,  2.1920e-01,
           -2.0296e+00, -4.2716e+00],
          [ 1.2096e+00, -1.3934e+00,  9.4632e-01,  ..., -2.9663e+00,
           -2.1573e-01,  4.8883e-01],
          ...,
          [-8.0154e-01, -6.6665e-01, -4.3691e-01,  ..., -4.7406e+00,
           -2.1090e+00, -3.4603e+00],
          [ 1.5799e+00, -5.5764e-01, -1.7364e-02,  ..., -2.0548e+00,
           -2.2499e+00,  3.5979e-01],
          [ 1.4363e-01, -2.2781e-01, -4.0250e+00,  ..., -2.6060e+00,
           -2.0633e+00, -6.6493e-01]],

         [[ 3.9476e+00, -5.5802e+00, -9.6231e+00,  ..., -4.3460e+00,
           -3.8544e+00,  4.8173e-01],
          [-7.8112e+00, -8.5727e+00, -6.4950e+00,  ...,  3.2030e+00,
            4.3899e+00, -4.4620e+00],
          [ 2.8494e-01,  1.1584e+00,  3.4524e+00,  ..., -9.2283e+00,
            1.9230e+00, -2.7937e-01],
          ...,
          [-4.1666e+00, -1.9450e+00, -3.1143e+00,  ..., -1.8752e+00,
           -1.0584e+01, -2.7799e+00],
          [ 1.1605e+00,  1.4293e+00, -2.0748e-01,  ..., -2.2928e+00,
           -1.6960e+00, -5.4957e-01],
          [-1.2512e-01, -2.8697e+00, -8.0678e+00,  ..., -2.0365e+00,
           -6.5314e+00, -2.5828e+00]]]], grad_fn=<AddBackward0>), (1, 11): tensor([[[[-3.2954e-01]],

         [[-4.1392e-02]],

         [[ 1.8950e-02]],

         [[-1.1438e-02]],

         [[ 5.0308e-02]],

         [[ 1.0506e-01]],

         [[-4.4698e-02]],

         [[ 1.0416e-01]],

         [[ 9.1888e-02]],

         [[-5.2962e-02]],

         [[-1.1717e-03]],

         [[ 1.0445e-01]],

         [[-1.2310e-01]],

         [[ 7.6537e-03]],

         [[-7.6144e-02]],

         [[-9.5200e-02]],

         [[-3.2903e-02]],

         [[-5.8771e-03]],

         [[ 4.2146e-02]],

         [[-2.2369e-02]],

         [[ 5.7957e-02]],

         [[-2.3260e-02]],

         [[-5.9498e-02]],

         [[ 1.9727e+00]],

         [[ 7.9506e-02]],

         [[-3.5970e-03]],

         [[-7.2696e-03]],

         [[ 1.3236e+00]],

         [[ 4.4285e-02]],

         [[ 8.9512e-02]],

         [[-1.1212e+00]],

         [[-2.7686e-01]],

         [[-8.4620e-02]],

         [[-3.3795e-02]],

         [[ 1.1169e+00]],

         [[ 1.1474e-01]],

         [[ 8.2334e-02]],

         [[ 2.9907e-02]],

         [[ 1.6733e-01]],

         [[-3.3486e-02]],

         [[-8.0503e-02]],

         [[-2.8749e-03]],

         [[ 1.6506e-02]],

         [[-7.8889e-03]],

         [[-6.7326e-02]],

         [[ 7.6423e-02]],

         [[ 2.4160e-02]],

         [[ 1.2525e-01]],

         [[-5.3605e-02]],

         [[ 1.6305e-02]],

         [[-5.0488e-02]],

         [[-9.4780e-02]],

         [[ 2.6307e-02]],

         [[-7.0703e-02]],

         [[-1.7795e-02]],

         [[-4.3913e-02]],

         [[-3.2790e-02]],

         [[ 1.0006e-02]],

         [[-2.0516e-02]],

         [[-1.6045e-01]],

         [[ 8.5491e-02]],

         [[-2.5708e-02]],

         [[-1.2039e-02]],

         [[-7.8266e-02]],

         [[-1.1724e-02]],

         [[ 4.9428e-02]],

         [[-4.5137e-02]],

         [[ 1.3128e-02]],

         [[-1.0530e-02]],

         [[-4.7346e-02]],

         [[-1.8275e-02]],

         [[ 4.3498e-02]],

         [[-1.2275e+00]],

         [[-6.1603e-01]],

         [[ 1.0947e-01]],

         [[ 3.3188e-02]],

         [[ 4.2756e-02]],

         [[-1.2146e-01]],

         [[-2.7672e-02]],

         [[-1.3488e-01]]],


        [[[-1.2282e+00]],

         [[-6.9207e-02]],

         [[ 1.4999e-03]],

         [[-2.3976e-02]],

         [[ 3.4658e-02]],

         [[ 6.5590e-02]],

         [[-1.1676e-02]],

         [[ 9.8684e-02]],

         [[ 5.3957e-02]],

         [[-5.9257e-02]],

         [[-7.7145e-03]],

         [[ 1.0464e-01]],

         [[-1.2547e-01]],

         [[-1.2971e-02]],

         [[-6.5476e-02]],

         [[-7.0972e-02]],

         [[ 2.1136e-04]],

         [[ 1.4957e-02]],

         [[ 8.2022e-02]],

         [[-4.5787e-02]],

         [[ 4.7979e-02]],

         [[-3.2828e-02]],

         [[-8.8509e-02]],

         [[ 1.8733e+00]],

         [[ 1.0333e-01]],

         [[-9.3139e-04]],

         [[-7.6340e-04]],

         [[ 9.0818e-01]],

         [[ 9.9145e-02]],

         [[ 1.2870e-01]],

         [[-1.4661e+00]],

         [[-3.3612e-01]],

         [[-1.1033e-01]],

         [[-2.3938e-02]],

         [[ 1.1764e+00]],

         [[ 1.1639e-01]],

         [[ 7.4871e-02]],

         [[ 4.7011e-02]],

         [[ 1.4714e-01]],

         [[ 5.3443e-02]],

         [[-5.7768e-02]],

         [[-3.9602e-03]],

         [[-1.3482e-02]],

         [[ 2.0430e-02]],

         [[-4.8997e-02]],

         [[ 5.6956e-02]],

         [[ 1.6053e-03]],

         [[ 1.2672e-01]],

         [[-2.3973e-02]],

         [[-1.5369e-02]],

         [[-2.1016e-02]],

         [[-7.1965e-02]],

         [[ 1.0056e-02]],

         [[-1.1172e-01]],

         [[-4.0957e-02]],

         [[-2.9695e-02]],

         [[-2.0906e-02]],

         [[ 3.5890e-02]],

         [[ 5.6055e-03]],

         [[-1.6344e-01]],

         [[ 1.1149e-01]],

         [[ 8.6193e-03]],

         [[-8.8553e-03]],

         [[-1.0676e-01]],

         [[-3.1798e-02]],

         [[ 5.2126e-02]],

         [[-5.7364e-02]],

         [[-3.4551e-02]],

         [[-5.3259e-02]],

         [[-4.3817e-02]],

         [[-8.5701e-03]],

         [[ 5.0781e-02]],

         [[-7.9956e-01]],

         [[ 2.9231e-01]],

         [[ 9.6562e-02]],

         [[ 4.8574e-02]],

         [[ 5.3199e-02]],

         [[-1.1042e-01]],

         [[-1.5728e-02]],

         [[-1.1778e-01]]],


        [[[-7.4214e-01]],

         [[-4.5831e-02]],

         [[-4.6955e-03]],

         [[ 1.0713e-02]],

         [[ 3.6458e-02]],

         [[ 1.4159e-02]],

         [[ 7.5680e-02]],

         [[ 5.7894e-02]],

         [[ 4.3541e-02]],

         [[-4.3660e-02]],

         [[-8.1107e-03]],

         [[-8.0403e-03]],

         [[-9.1961e-02]],

         [[ 2.0916e-02]],

         [[-5.2999e-02]],

         [[-2.3125e-02]],

         [[-5.3537e-02]],

         [[ 8.1381e-02]],

         [[ 1.3197e-01]],

         [[-6.5239e-02]],

         [[ 6.0029e-02]],

         [[-5.0667e-02]],

         [[ 1.3220e-03]],

         [[ 1.7878e+00]],

         [[ 8.8173e-02]],

         [[-4.2781e-02]],

         [[ 3.9467e-02]],

         [[ 1.3625e+00]],

         [[ 1.3916e-01]],

         [[ 3.6182e-01]],

         [[-1.5173e+00]],

         [[-3.5479e-01]],

         [[-1.3290e-01]],

         [[-2.9103e-02]],

         [[ 1.0781e+00]],

         [[ 1.3775e-01]],

         [[ 1.2014e-01]],

         [[-5.7623e-02]],

         [[ 4.0890e-02]],

         [[ 5.8836e-02]],

         [[-1.0300e-01]],

         [[ 5.6284e-02]],

         [[-5.3143e-02]],

         [[-3.3925e-02]],

         [[-2.9438e-02]],

         [[ 2.1809e-02]],

         [[-1.4221e-02]],

         [[ 1.5400e-01]],

         [[ 5.7543e-02]],

         [[-4.1208e-02]],

         [[ 2.2212e-02]],

         [[-6.3028e-02]],

         [[ 1.0909e-01]],

         [[-1.1350e-01]],

         [[-2.9245e-02]],

         [[ 3.8774e-02]],

         [[-8.1525e-03]],

         [[ 1.6708e-02]],

         [[ 8.9099e-02]],

         [[-1.4019e-01]],

         [[ 1.4330e-01]],

         [[ 6.3284e-02]],

         [[ 4.0107e-02]],

         [[-3.6125e-02]],

         [[ 1.3681e-02]],

         [[ 6.2493e-02]],

         [[-5.4155e-02]],

         [[ 4.7579e-02]],

         [[-8.3827e-02]],

         [[-6.2487e-02]],

         [[-4.2408e-02]],

         [[ 2.6499e-02]],

         [[-1.6416e+00]],

         [[ 6.1556e-01]],

         [[ 9.5130e-02]],

         [[-3.6979e-02]],

         [[ 7.7709e-02]],

         [[ 1.1940e-02]],

         [[ 2.5999e-02]],

         [[-9.1804e-02]]],


        [[[-7.6911e-01]],

         [[ 5.5719e-03]],

         [[ 7.5728e-03]],

         [[ 3.9353e-02]],

         [[-4.1333e-03]],

         [[ 3.9725e-02]],

         [[ 2.0427e-02]],

         [[ 6.9363e-02]],

         [[ 1.0436e-01]],

         [[-5.7364e-02]],

         [[ 1.2526e-02]],

         [[ 6.9431e-02]],

         [[-1.2298e-01]],

         [[-7.9281e-03]],

         [[-7.9848e-02]],

         [[-7.0291e-02]],

         [[-2.1328e-02]],

         [[ 1.6891e-02]],

         [[ 8.5750e-02]],

         [[-5.8206e-02]],

         [[-3.3873e-04]],

         [[-6.6840e-03]],

         [[-3.4840e-02]],

         [[ 2.2384e+00]],

         [[ 6.6063e-02]],

         [[ 1.6595e-02]],

         [[ 2.4139e-02]],

         [[ 5.3193e-01]],

         [[ 6.9625e-02]],

         [[ 3.5307e-01]],

         [[-1.4058e+00]],

         [[-4.2428e-01]],

         [[-1.2609e-01]],

         [[ 3.7517e-02]],

         [[ 1.1523e+00]],

         [[ 1.4467e-01]],

         [[ 1.1447e-01]],

         [[-3.8044e-03]],

         [[ 1.3782e-01]],

         [[ 3.3297e-02]],

         [[-8.9478e-02]],

         [[-2.0986e-02]],

         [[ 1.5957e-03]],

         [[-2.9632e-02]],

         [[-5.9831e-02]],

         [[ 4.9031e-02]],

         [[-1.2586e-02]],

         [[ 1.0060e-01]],

         [[ 9.5949e-03]],

         [[ 2.5897e-02]],

         [[-1.1623e-02]],

         [[-2.7303e-02]],

         [[-3.5173e-03]],

         [[-7.6410e-02]],

         [[-3.4299e-03]],

         [[-2.1692e-02]],

         [[ 2.0612e-02]],

         [[ 1.5891e-02]],

         [[ 4.3852e-02]],

         [[-1.2702e-01]],

         [[ 7.1252e-02]],

         [[ 1.9464e-03]],

         [[-2.4312e-02]],

         [[-7.3751e-02]],

         [[ 1.0265e-02]],

         [[ 2.1285e-02]],

         [[-1.0293e-01]],

         [[-1.0904e-03]],

         [[-9.6557e-02]],

         [[-5.4802e-02]],

         [[ 2.0492e-02]],

         [[ 2.3981e-02]],

         [[-7.5777e-01]],

         [[-4.2968e-01]],

         [[ 6.3788e-02]],

         [[ 1.1475e-02]],

         [[ 5.1990e-02]],

         [[-7.9708e-02]],

         [[-5.5385e-03]],

         [[-9.6558e-02]]],


        [[[-1.0508e+00]],

         [[-3.0886e-02]],

         [[ 3.0383e-05]],

         [[-8.3060e-03]],

         [[ 6.1881e-02]],

         [[ 6.4423e-02]],

         [[ 2.2924e-02]],

         [[ 1.1589e-01]],

         [[ 7.2073e-02]],

         [[-6.9679e-02]],

         [[-5.9223e-03]],

         [[ 2.9222e-02]],

         [[-1.3457e-01]],

         [[ 3.7856e-02]],

         [[-2.5369e-02]],

         [[-5.1544e-02]],

         [[-2.4386e-02]],

         [[ 1.5442e-02]],

         [[ 9.1866e-02]],

         [[-7.5163e-05]],

         [[ 7.4187e-02]],

         [[-5.7460e-02]],

         [[-1.0010e-01]],

         [[ 1.1105e+00]],

         [[ 1.0526e-01]],

         [[-2.8840e-02]],

         [[-5.4559e-03]],

         [[ 9.5324e-01]],

         [[ 8.5459e-02]],

         [[ 4.0763e-01]],

         [[-1.3676e+00]],

         [[-7.2923e-01]],

         [[-1.3794e-01]],

         [[-4.7686e-02]],

         [[ 1.6018e+00]],

         [[ 1.2602e-01]],

         [[ 6.8933e-02]],

         [[ 2.3652e-02]],

         [[ 5.8866e-02]],

         [[ 2.1771e-02]],

         [[-8.1918e-02]],

         [[-8.1321e-03]],

         [[-1.9095e-02]],

         [[-2.9603e-02]],

         [[-3.0547e-02]],

         [[ 3.8974e-02]],

         [[-5.2380e-03]],

         [[ 1.7229e-01]],

         [[ 1.3635e-02]],

         [[-1.4216e-02]],

         [[ 1.6009e-02]],

         [[-1.1780e-01]],

         [[ 7.0889e-02]],

         [[-1.0119e-01]],

         [[-2.9182e-02]],

         [[-2.0091e-02]],

         [[-2.1584e-02]],

         [[-1.5654e-02]],

         [[ 1.9887e-02]],

         [[-2.1245e-01]],

         [[ 1.5097e-01]],

         [[ 1.5055e-02]],

         [[ 3.0318e-02]],

         [[-8.0133e-02]],

         [[-2.7833e-02]],

         [[ 7.5295e-02]],

         [[-3.7641e-02]],

         [[-2.0009e-02]],

         [[-5.8148e-02]],

         [[-7.3551e-03]],

         [[-2.5408e-02]],

         [[ 7.3143e-02]],

         [[-1.2206e+00]],

         [[ 6.2249e-01]],

         [[ 1.3818e-01]],

         [[ 4.1081e-02]],

         [[ 7.8464e-02]],

         [[-7.9166e-02]],

         [[-5.5664e-02]],

         [[-1.2283e-01]]]], grad_fn=<AsStridedBackward1>)})

Other functionalities#

In addition to recording, the following functionalities are available and documented in Recorder:

  • On-the-fly postprocessing of activations during inference (e.g. clipping).

  • Local disabling of the recording during forward pass.

  • Direct access to the recorded layers’ Module objects.

  • Retrieval of the named parameters of the recorded layers.
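To give a feel for what these functionalities amount to, here is a minimal sketch written with plain PyTorch forward hooks. It does not use the Recorder API, whose actual methods and signatures are documented in Recorder. The small network, the `records` dictionary, the `state` flag and the `make_hook` helper are all hypothetical names introduced for illustration.

```python
import torch
from torch import nn

# Hypothetical stand-in network, NOT the tutorial's Tiniest model
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

records = {}                 # layer name -> postprocessed activation
state = {"recording": True}  # flag used to locally disable recording


def make_hook(name):
    def hook(module, args, output):
        if state["recording"]:
            # On-the-fly postprocessing: store a detached, clipped copy
            records[name] = output.detach().clamp(-1.0, 1.0)

    return hook


handles = [
    layer.register_forward_hook(make_hook(name))
    for name, layer in net.named_children()
]

x = torch.rand(3, 4)
net(x)
recorded_names = sorted(records)
max_abs = max(float(t.abs().max()) for t in records.values())

# Locally disable recording: the hooks still fire but store nothing
records.clear()
state["recording"] = False
net(x)
disabled_count = len(records)

# Direct access to a layer's Module object and its named parameters
first_layer = net.get_submodule("0")
param_names = [name for name, _ in first_layer.named_parameters()]

# Detach the hooks once done
for handle in handles:
    handle.remove()
```

In this sketch the clipping only affects the stored copies, not the values propagated through the network. Whether Recorder behaves the same way is not something this sketch establishes.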

Implementation details#

This utility works by attaching a hook to every layer of net with Module.register_forward_hook. Since layers are not aware of the context in which they are called, these hooks carry a reference to rnet, which gives them enough context to know when to trigger. As a result, two different Recorder Nets can wrap the same net without any conflict. As an implementation detail, these references are weak, so that they do not prevent proper cleanup when rnet is deleted.
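The mechanism described above can be sketched as follows. This is a simplified illustration under stated assumptions, not scio's actual implementation: `TinyRecorder`, its `active` flag and the `_remove_all` helper are hypothetical, and the real Recorder may decide when to trigger differently. Only `Module.register_forward_hook`, `weakref.ref` and `weakref.finalize` are existing APIs.

```python
import gc
import weakref

import torch
from torch import nn


def _remove_all(handles):
    # Detach every hook registered by a recorder
    for handle in handles:
        handle.remove()


class TinyRecorder:
    """Hypothetical, simplified stand-in; not scio's Recorder."""

    def __init__(self, net):
        self.net = net
        self.active = False
        self.records = {}
        self_ref = weakref.ref(self)  # weak: hooks do not keep us alive

        def make_hook(name):
            def hook(module, args, output):
                recorder = self_ref()
                # Trigger only while *this* recorder runs the forward pass
                if recorder is not None and recorder.active:
                    recorder.records[name] = output.detach()

            return hook

        self.handles = [
            layer.register_forward_hook(make_hook(name))
            for name, layer in net.named_children()
        ]
        # Remove the hooks from net once this recorder is garbage collected
        weakref.finalize(self, _remove_all, self.handles)

    def __call__(self, x):
        self.active = True
        try:
            return self.net(x)
        finally:
            self.active = False


net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
rec_a, rec_b = TinyRecorder(net), TinyRecorder(net)

rec_a(torch.rand(3, 4))  # only rec_a is active during this pass
a_names, b_names = sorted(rec_a.records), sorted(rec_b.records)

probe = weakref.ref(rec_a)
del rec_a
gc.collect()
a_is_gone = probe() is None  # hooks held no strong reference to rec_a
```

Because each hook only holds a weak reference to its recorder, deleting the recorder is enough for it to be collected, and the second recorder on the same network is unaffected.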

Total running time of the script: (0 minutes 0.193 seconds)
