Diving inside Neural Networks

This tutorial provides a short practical overview of the Recorder class, which is designed to ease interactions with the internal states of PyTorch Neural Network objects (torch.nn.Module) during or after inference.

from scio.recorder import Recorder

Wrapping and visualizing your Neural Network

Let us first load a Neural Network. Any architecture would do; here we use a lightweight Tiniest architecture trained on CIFAR10 and hosted on our hub. We also prepare some random input data for later use in this tutorial.

import torch

inputs = torch.rand(5, 3, 32, 32)  # Random inputs with 5 samples
net = torch.hub.load("ThalesGroup/scio:hub", "tiniest", trust_repo=True, verbose=False)
net = net.to(inputs)  # Match the device and dtype of the inputs

To wrap it into a Recorder Net, rnet, one only needs to specify either an input_size (including the batch dimension) or some input_data. The Recorder then immediately analyzes and stores the control flow of the model, using the torchinfo library.

rnet = Recorder(net, input_data=inputs[[0]])
rnet  # Visualize layers
Recorder instance for the following network
============================================================================================================================================
Layer (type (var_name):depth-idx)        Input Shape               Output Shape              Param #                   Param %
============================================================================================================================================
Tiniest (Tiniest)                        [1, 3, 32, 32]            [1, 10]                   --                             --
├─Conv2d (conv1): 1-1                    [1, 3, 32, 32]            [1, 48, 32, 32]           1,344                       1.38%
├─LayerNorm2d (ln1): 1-2                 [1, 48, 32, 32]           [1, 48, 32, 32]           --                             --
│    └─LayerNorm (ln): 2-1               [1, 32, 32, 48]           [1, 32, 32, 48]           96                          0.10%
├─Block (l1): 1-3                        [1, 48, 32, 32]           [1, 48, 32, 32]           48                          0.05%
│    └─Conv2d (dwconv1): 2-2             [1, 12, 32, 32]           [1, 12, 32, 32]           120                         0.12%
│    └─Conv2d (dwconv2): 2-3             [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─Conv2d (dwconv3): 2-4             [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─LayerNorm2d (ln): 2-5             [1, 48, 32, 32]           [1, 48, 32, 32]           --                             --
│    │    └─LayerNorm (ln): 3-1          [1, 32, 32, 48]           [1, 32, 32, 48]           96                          0.10%
│    └─Conv2d (fc1): 2-6                 [1, 48, 32, 32]           [1, 96, 32, 32]           4,704                       4.82%
│    └─Conv2d (fc2): 2-7                 [1, 48, 32, 32]           [1, 48, 32, 32]           2,352                       2.41%
├─Block (l2): 1-4                        [1, 48, 32, 32]           [1, 48, 32, 32]           48                          0.05%
│    └─Conv2d (dwconv1): 2-8             [1, 12, 32, 32]           [1, 12, 32, 32]           120                         0.12%
│    └─Conv2d (dwconv2): 2-9             [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─Conv2d (dwconv3): 2-10            [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─LayerNorm2d (ln): 2-11            [1, 48, 32, 32]           [1, 48, 32, 32]           --                             --
│    │    └─LayerNorm (ln): 3-2          [1, 32, 32, 48]           [1, 32, 32, 48]           96                          0.10%
│    └─Conv2d (fc1): 2-12                [1, 48, 32, 32]           [1, 96, 32, 32]           4,704                       4.82%
│    └─Conv2d (fc2): 2-13                [1, 48, 32, 32]           [1, 48, 32, 32]           2,352                       2.41%
├─Block (l3): 1-5                        [1, 48, 32, 32]           [1, 48, 32, 32]           48                          0.05%
│    └─Conv2d (dwconv1): 2-14            [1, 12, 32, 32]           [1, 12, 32, 32]           120                         0.12%
│    └─Conv2d (dwconv2): 2-15            [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─Conv2d (dwconv3): 2-16            [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─LayerNorm2d (ln): 2-17            [1, 48, 32, 32]           [1, 48, 32, 32]           --                             --
│    │    └─LayerNorm (ln): 3-3          [1, 32, 32, 48]           [1, 32, 32, 48]           96                          0.10%
│    └─Conv2d (fc1): 2-18                [1, 48, 32, 32]           [1, 96, 32, 32]           4,704                       4.82%
│    └─Conv2d (fc2): 2-19                [1, 48, 32, 32]           [1, 48, 32, 32]           2,352                       2.41%
├─Conv2d (dsconv): 1-6                   [1, 48, 32, 32]           [1, 80, 16, 16]           3,920                       4.02%
├─LayerNorm2d (ln2): 1-7                 [1, 80, 16, 16]           [1, 80, 16, 16]           --                             --
│    └─LayerNorm (ln): 2-20              [1, 16, 16, 80]           [1, 16, 16, 80]           160                         0.16%
├─Block (l4): 1-8                        [1, 80, 16, 16]           [1, 80, 16, 16]           80                          0.08%
│    └─Conv2d (dwconv1): 2-21            [1, 20, 16, 16]           [1, 20, 16, 16]           200                         0.21%
│    └─Conv2d (dwconv2): 2-22            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─Conv2d (dwconv3): 2-23            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─LayerNorm2d (ln): 2-24            [1, 80, 16, 16]           [1, 80, 16, 16]           --                             --
│    │    └─LayerNorm (ln): 3-4          [1, 16, 16, 80]           [1, 16, 16, 80]           160                         0.16%
│    └─Conv2d (fc1): 2-25                [1, 80, 16, 16]           [1, 160, 16, 16]          12,960                     13.29%
│    └─Conv2d (fc2): 2-26                [1, 80, 16, 16]           [1, 80, 16, 16]           6,480                       6.64%
├─Block (l5): 1-9                        [1, 80, 16, 16]           [1, 80, 16, 16]           80                          0.08%
│    └─Conv2d (dwconv1): 2-27            [1, 20, 16, 16]           [1, 20, 16, 16]           200                         0.21%
│    └─Conv2d (dwconv2): 2-28            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─Conv2d (dwconv3): 2-29            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─LayerNorm2d (ln): 2-30            [1, 80, 16, 16]           [1, 80, 16, 16]           --                             --
│    │    └─LayerNorm (ln): 3-5          [1, 16, 16, 80]           [1, 16, 16, 80]           160                         0.16%
│    └─Conv2d (fc1): 2-31                [1, 80, 16, 16]           [1, 160, 16, 16]          12,960                     13.29%
│    └─Conv2d (fc2): 2-32                [1, 80, 16, 16]           [1, 80, 16, 16]           6,480                       6.64%
├─Block (l6): 1-10                       [1, 80, 16, 16]           [1, 80, 16, 16]           80                          0.08%
│    └─Conv2d (dwconv1): 2-33            [1, 20, 16, 16]           [1, 20, 16, 16]           200                         0.21%
│    └─Conv2d (dwconv2): 2-34            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─Conv2d (dwconv3): 2-35            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─LayerNorm2d (ln): 2-36            [1, 80, 16, 16]           [1, 80, 16, 16]           --                             --
│    │    └─LayerNorm (ln): 3-6          [1, 16, 16, 80]           [1, 16, 16, 80]           160                         0.16%
│    └─Conv2d (fc1): 2-37                [1, 80, 16, 16]           [1, 160, 16, 16]          12,960                     13.29%
│    └─Conv2d (fc2): 2-38                [1, 80, 16, 16]           [1, 80, 16, 16]           6,480                       6.64%
├─AdaptiveAvgPool2d (avgpool): 1-11      [1, 80, 16, 16]           [1, 80, 1, 1]             --                             --
├─Linear (fc): 1-12                      [1, 80]                   [1, 10]                   810                         0.83%
============================================================================================================================================
Total params: 97,530
Trainable params: 97,530
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 44.73
============================================================================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 9.05
Params size (MB): 0.39
Estimated Total Size (MB): 9.45
============================================================================================================================================
Currently recording: None
============================================================================================================================================
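The wrapping call above passes inputs[[0]] rather than inputs[0]. Indexing with a list keeps the leading batch dimension, so the summary is computed on a batch of one sample. This matches the [1, 3, 32, 32] input shape reported in the first row. A quick check in plain PyTorch:

```python
import torch

inputs = torch.rand(5, 3, 32, 32)  # same shape as in this tutorial

# Integer indexing drops the batch dimension...
single = inputs[0]
# ...whereas indexing with a list keeps it (a batch of size 1)
batch_of_one = inputs[[0]]

print(tuple(single.shape))        # (3, 32, 32)
print(tuple(batch_of_one.shape))  # (1, 3, 32, 32)

# Two equivalent ways of keeping the batch dimension
assert torch.equal(batch_of_one, inputs[0:1])
assert torch.equal(batch_of_one, inputs[0].unsqueeze(0))
```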

Tip

To customize the summary, refer to the options of torchinfo.summary. For example, depth=2 limits the displayed tree to two levels of nesting.

Note

In the case of dynamic control flow, refer to the force_static_flow argument of Recorder.

In many ways, this wrapper is transparent to the user. For example, one can naturally process data with rnet(inputs).
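One way a wrapper can remain transparent while still observing internal states is through PyTorch forward hooks (torch.nn.Module.register_forward_hook). The sketch below uses a hypothetical TinyRecorder written only to illustrate the idea. It is not scio's implementation. It also identifies layers by their named_modules() names rather than by the depth-idx identifiers that Recorder uses.

```python
import types

import torch
from torch import nn


class TinyRecorder(nn.Module):
    """Hypothetical minimal recorder, for illustration only (not scio's Recorder)."""

    def __init__(self, net, *names):
        super().__init__()
        self.net = net
        self._activations = {}
        modules = dict(net.named_modules())
        for name in names:
            # The hook fires after the layer's forward() and stores its output
            modules[name].register_forward_hook(self._make_hook(name))

    def _make_hook(self, name):
        def hook(module, args, output):
            self._activations[name] = output
        return hook

    @property
    def activations(self):
        # Read-only view on the internal dict, updated by every forward pass
        return types.MappingProxyType(self._activations)

    def forward(self, *args, **kwargs):
        # Transparent: simply delegate to the wrapped network
        return self.net(*args, **kwargs)


# Usage on a hypothetical toy network (unrelated to Tiniest)
toy_net = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 32 * 32, 10),
)
trec = TinyRecorder(toy_net, "0", "3")  # record the conv and the final linear layer
x = torch.rand(5, 3, 32, 32)
out = trec(x)
print(torch.allclose(out, toy_net(x)))  # True: same output as the bare network
print({k: tuple(v.shape) for k, v in trec.activations.items()})
# {'0': (5, 8, 32, 32), '3': (5, 10)}
```

Exposing the internal dict through types.MappingProxyType lets users read the recorded states without being able to modify them, while the hooks keep updating the dict on each forward pass.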

Selecting layers of interest

The penultimate line of the above summary lists the layers currently set to be recorded (stored in rnet.recording). By default, right after instantiation, there are none.

print(repr(rnet).split("\n")[-2])  # Show penultimate summary line
Currently recording: None

Any set of layers can be selected with rnet.record(), passing the depth-idx identifiers from the summary as (depth, idx) tuples (e.g. 1-9 becomes (1, 9)). For example, the following specifies that the outputs of the first Block (1-3) and of the penultimate layer, avgpool (1-11), should be recorded.

rnet.record((1, 3), (1, 11))
print(repr(rnet).split("\n")[-2])  # Show penultimate summary line
Currently recording: 1-3, 1-11

Note

Though not shown in the summary, it is possible to select 0-1 to refer to the entire model.
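In the summary, each identifier reads depth-idx. Here, depth is the nesting level of the layer: 0 for the whole model, hence 0-1. The idx part appears to be a running counter per depth, incremented in call order across the whole network. For example, the first Block is 1-3 and its first child is 2-2, because 2-1 is already taken by the LayerNorm inside ln1. The sketch below reproduces this numbering on a hypothetical, truncated description of the first layers of Tiniest. It assumes each layer is called once, in declaration order, which seems to hold here but not for every network. The parse helper, also hypothetical, turns a summary identifier into the tuple form passed to rnet.record().

```python
from collections import defaultdict

# Hypothetical nested description of the first layers of Tiniest,
# as (name, children) pairs in call order (read off the summary)
TREE = ("Tiniest", [
    ("conv1", []),
    ("ln1", [("ln", [])]),
    ("l1", [
        ("dwconv1", []), ("dwconv2", []), ("dwconv3", []),
        ("ln", [("ln", [])]),
        ("fc1", []), ("fc2", []),
    ]),
    ("l2", [("dwconv1", [])]),  # truncated for brevity
])


def depth_idx(tree):
    """Assign (depth, idx) to every layer: idx is a per-depth counter in pre-order."""
    counters = defaultdict(int)
    ids = {}

    def visit(node, path, depth):
        name, children = node
        counters[depth] += 1
        ids[path] = (depth, counters[depth])
        for child in children:
            child_path = f"{path}.{child[0]}" if path else child[0]
            visit(child, child_path, depth + 1)

    visit(tree, "", 0)  # the root (whole model) gets the empty path
    return ids


def parse(identifier):
    """Turn a summary identifier such as "1-9" into a (depth, idx) tuple."""
    depth, idx = identifier.split("-")
    return int(depth), int(idx)


ids = depth_idx(TREE)
print(ids[""])            # (0, 1): the entire model
print(ids["l1"])          # (1, 3)
print(ids["l1.dwconv1"])  # (2, 2)
print(ids["l1.ln.ln"])    # (3, 1)
print(ids["l2.dwconv1"])  # (2, 8)
print(parse("1-9"))       # (1, 9)
```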

Warning

torchinfo.summary can only detect torch.nn.Module calls. As such, for activation functions to be visible as layers in the summary, a Neural Network must declare their Module implementation as an attribute and call it, instead of calling their functional counterpart. A single attribute can serve several calls: one self.relu = nn.ReLU() can be used multiple times in forward().
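torchinfo collects layers by registering hooks on submodules. As a result, a functional call such as torch.nn.functional.relu leaves no trace, whereas a shared nn.ReLU attribute is seen at every call. The sketch below illustrates this mechanism with plain forward hooks rather than with torchinfo itself. The two toy networks and the called_layers helper are hypothetical.

```python
import torch
import torch.nn.functional as F
from torch import nn


class FunctionalNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        self.fc2 = nn.Linear(4, 4)

    def forward(self, x):
        # Functional activations are not nn.Module calls: hooks never see them
        return F.relu(self.fc2(F.relu(self.fc1(x))))


class ModuleNet(FunctionalNet):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()  # a single attribute, reused twice in forward()

    def forward(self, x):
        return self.relu(self.fc2(self.relu(self.fc1(x))))


def called_layers(net, x):
    """Return the names of the submodules called during one forward pass."""
    calls = []
    handles = [
        module.register_forward_hook(lambda m, args, out, name=name: calls.append(name))
        for name, module in net.named_modules()
        if name  # skip the root module itself
    ]
    net(x)
    for handle in handles:
        handle.remove()  # leave the network unchanged
    return calls


x = torch.rand(2, 4)
print(called_layers(FunctionalNet(), x))  # ['fc1', 'fc2']
print(called_layers(ModuleNet(), x))      # ['fc1', 'relu', 'fc2', 'relu']
```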

Capturing internal states

Once the layers to record are set, every forward pass automatically stores the corresponding internal states in the rnet.activations mapping, whose keys are the (depth, idx) tuples.

out = rnet(inputs)  # Forward pass, records the activations
rnet.activations
mappingproxy({(1, 3): tensor([[[[-1.7022e+00, -8.1401e+00,  1.3161e+00,  ...,  4.2003e+00,
            3.1560e+00, -5.7395e+00],
          [ 8.2100e+00,  3.4759e+00,  3.0745e+00,  ...,  3.8526e-01,
            4.9458e+00,  7.8325e-01],
          [ 7.4736e+00,  7.1036e+00,  5.7739e+00,  ..., -3.3250e-01,
            5.7480e-01, -2.9331e+00],
          ...,
          [ 7.0612e+00,  2.4737e+00, -9.3712e+00,  ...,  7.5873e-01,
           -3.0715e+00,  1.3497e+00],
          [-6.2940e+00,  4.0016e-02,  4.7528e+00,  ...,  7.5221e-01,
            5.2391e+00,  2.6424e+00],
          [ 2.3914e+00, -2.5762e+00,  5.1090e+00,  ...,  6.8672e+00,
            2.0650e+00, -5.4759e+00]],

         [[ 1.6582e+00,  8.6526e+00, -1.0749e+00,  ...,  2.8766e+00,
           -5.5018e-01,  5.1222e+00],
          [-1.4082e+00, -1.4155e+00,  3.1035e+00,  ...,  3.2815e+00,
           -1.9658e-01, -8.0176e-01],
          [-4.5562e+00,  1.3145e+00,  6.4768e+00,  ...,  4.0439e+00,
           -3.5790e-02,  6.9798e+00],
          ...,
          [-3.3441e+00,  1.0621e+00,  1.3430e+01,  ...,  4.6467e+00,
           -8.9313e-02,  2.4199e+00],
          [ 1.1811e+01,  1.9927e+00, -3.3585e+00,  ...,  1.3413e-02,
           -2.6428e+00,  2.1664e+00],
          [ 7.7002e-01,  7.3239e+00,  1.8310e+00,  ...,  8.6775e-01,
           -4.2041e-01,  1.0785e+01]],

         [[-2.5241e-01,  1.5138e+00,  5.0879e+00,  ...,  6.6563e+00,
            4.2072e+00,  3.8277e+00],
          [ 1.7377e+00,  3.1053e+00,  1.7281e+00,  ..., -9.0835e-01,
           -5.1616e+00, -3.9332e+00],
          [ 4.0045e+00, -1.3305e+00, -4.0025e+00,  ...,  5.4383e-01,
            4.8545e+00,  2.0072e+00],
          ...,
          [ 3.1167e+00, -9.8798e-02, -7.7456e+00,  ..., -3.2560e+00,
           -4.4654e+00, -7.2924e-01],
          [-4.1589e+00, -1.8304e+00, -2.9414e-02,  ..., -2.2777e+00,
            4.6770e-01, -1.9580e+00],
          [-6.0022e+00, -8.6861e+00, -9.0680e+00,  ..., -4.7578e+00,
           -6.5775e+00, -6.8521e+00]],

         ...,

         [[-2.3116e+00, -5.4705e+00, -2.1156e+00,  ..., -3.0074e+00,
           -6.0479e+00, -3.9592e+00],
          [-8.3428e+00, -2.0116e+00, -4.2231e+00,  ..., -4.6542e+00,
           -6.2577e-01, -2.7182e+00],
          [-2.2370e+00, -7.2099e+00, -4.8415e+00,  ..., -9.0038e-01,
            4.2641e+00, -6.0504e+00],
          ...,
          [-7.6772e+00,  5.7089e-01, -4.6675e+00,  ..., -6.8624e+00,
           -2.1480e+00, -3.0441e+00],
          [-1.1076e+01, -7.1043e-01, -1.0415e+00,  ..., -1.0344e-01,
            8.1842e-01, -1.0041e+00],
          [-1.1128e+00, -4.1235e+00, -5.3697e+00,  ..., -2.0745e+00,
           -1.9154e+00, -5.4500e+00]],

         [[-3.6294e+00, -1.2266e+00, -1.4711e+00,  ..., -4.7060e+00,
           -1.9022e-01, -1.3691e+00],
          [-1.2468e+00, -3.0143e-01, -9.9140e-01,  ...,  4.5021e-02,
            6.1783e-01, -1.7144e+00],
          [-1.9860e+00, -4.8642e-01, -4.3711e+00,  ..., -1.7553e+00,
           -1.6518e+00, -3.4629e+00],
          ...,
          [-7.6630e+00, -9.3625e-01, -3.5440e+00,  ..., -2.4302e+00,
           -4.6675e+00, -2.9561e+00],
          [-2.3291e+00,  5.9916e-02, -5.8214e-01,  ..., -3.6602e+00,
           -1.1828e+00,  6.9605e-01],
          [-3.1220e+00, -6.2251e-01, -3.7096e+00,  ..., -2.9578e+00,
           -2.4020e+00, -1.4033e+00]],

         [[-4.3595e-01, -8.9710e+00, -2.4228e-01,  ..., -4.7376e+00,
            1.0235e+00,  5.8505e-01],
          [ 3.3992e+00,  3.0944e+00, -5.5038e-01,  ..., -4.6174e+00,
           -3.4045e-01, -7.0661e-01],
          [ 2.3746e+00,  9.6190e-01, -1.1166e+01,  ..., -1.8727e+00,
           -2.0663e+00, -4.8873e+00],
          ...,
          [-1.8387e+00, -6.9693e-01, -9.4641e+00,  ..., -6.7464e+00,
           -1.2005e+00, -3.1699e+00],
          [-8.2991e+00, -1.2437e+00,  6.0767e+00,  ..., -6.4197e-01,
            6.6454e-01, -1.3952e-01],
          [-1.7444e+00, -4.1268e+00, -2.1140e+00,  ..., -3.7640e+00,
           -1.1238e+00, -7.8603e+00]]],


        [[[ 5.0805e+00,  8.7671e-02,  3.4758e+00,  ...,  3.1559e+00,
            3.5970e+00, -7.2715e-01],
          [-3.0254e+00,  2.2791e+00, -6.0707e+00,  ..., -4.3978e-02,
           -3.6069e+00, -2.4760e+00],
          [ 5.1218e+00,  1.0057e+00,  1.5490e+00,  ..., -3.0058e+00,
           -5.9112e+00, -1.2956e+00],
          ...,
          [ 2.1297e-01,  5.1988e+00,  3.5929e-01,  ...,  3.2178e+00,
           -9.8183e+00, -1.1763e+00],
          [ 5.3548e+00, -1.0457e+01,  4.2351e-01,  ...,  4.9426e+00,
           -5.5472e+00, -5.5033e+00],
          [ 3.0098e+00, -6.6520e+00,  2.4066e+00,  ..., -6.9228e+00,
            5.4740e+00,  2.2548e+00]],

         [[ 1.9567e+00,  3.2346e+00, -2.1096e+00,  ...,  4.4348e+00,
            8.5156e-01, -2.2350e+00],
          [ 1.0953e+01, -2.1557e+00,  1.0032e+01,  ...,  1.4487e+00,
            9.7532e+00,  3.1700e+00],
          [-2.7019e+00,  3.3467e+00, -1.4208e+00,  ...,  2.7078e+00,
            1.3399e+01,  6.6587e-01],
          ...,
          [ 2.9796e+00, -1.6890e+00, -3.2654e+00,  ..., -2.0565e+00,
            1.4040e+01,  2.8258e+00],
          [-5.8443e-02,  1.2664e+01,  1.4248e-02,  ...,  6.1573e-01,
            1.6862e+01,  2.0706e+00],
          [ 3.0496e+00,  1.1344e+01,  1.5120e+00,  ...,  1.4912e+01,
            6.4290e-01,  4.0607e-01]],

         [[ 1.3388e+00,  1.6310e+00,  6.3092e+00,  ...,  5.9953e+00,
            7.9956e+00,  7.2546e+00],
          [-2.7757e-01,  1.3050e+00, -3.9240e+00,  ..., -4.0540e-01,
           -3.9568e+00, -3.2103e+00],
          [ 2.7136e+00,  2.2255e+00,  2.7000e+00,  ..., -1.7754e+00,
           -6.4103e+00, -1.9679e+00],
          ...,
          [-2.9834e+00,  1.7282e+00,  6.7993e-01,  ...,  1.9443e+00,
            7.1873e-02,  3.7831e+00],
          [-1.4561e-01, -8.4292e+00, -7.0686e+00,  ..., -8.1948e-01,
           -4.6235e+00, -2.3658e+00],
          [-5.3289e+00, -8.2413e+00, -7.1389e+00,  ..., -1.0153e+01,
           -6.2089e+00, -4.4526e+00]],

         ...,

         [[-5.7656e+00, -2.2378e+00, -6.3310e+00,  ..., -5.1985e+00,
           -9.4067e+00, -7.1324e+00],
          [-1.0993e+01,  4.1962e-01, -7.7093e+00,  ..., -2.6705e+00,
           -8.7259e+00, -8.8013e-01],
          [-2.1185e+00, -4.3493e+00,  8.9117e-01,  ...,  1.5759e+00,
           -6.3580e+00,  9.0258e-01],
          ...,
          [-9.2017e-01, -7.5772e+00, -4.9637e-01,  ..., -5.8232e+00,
           -8.7223e+00, -3.3486e+00],
          [-2.9168e+00, -4.0060e+00, -3.4872e+00,  ..., -5.7944e+00,
           -8.8498e+00,  3.0686e-01],
          [-3.4334e+00, -7.0884e+00, -2.7922e-01,  ..., -1.0622e+01,
           -6.7335e+00, -4.0388e+00]],

         [[ 1.7408e-01, -1.4576e+00, -1.2563e+00,  ..., -4.4072e+00,
           -6.4268e+00, -7.1240e+00],
          [-2.6519e+00, -3.4047e+00, -1.7169e+00,  ..., -5.6319e+00,
            8.7324e-01, -2.0666e+00],
          [ 2.4200e-01, -1.7536e+00, -2.0500e+00,  ..., -3.9415e+00,
           -2.8769e+00, -3.6237e+00],
          ...,
          [-2.0481e+00, -8.9259e-01, -8.5338e-01,  ..., -9.2954e+00,
           -1.9070e+00, -6.3516e+00],
          [-1.4993e+00, -4.0809e+00, -5.8545e+00,  ..., -6.6398e+00,
           -1.1979e+00, -1.4924e+00],
          [-4.5012e+00, -2.6363e+00, -2.7986e+00,  ..., -9.6162e-01,
           -7.9291e+00, -4.7604e+00]],

         [[ 7.8305e-01, -2.0886e+00,  3.4647e+00,  ..., -2.8440e+00,
           -1.1117e+00,  8.7459e-01],
          [-1.0883e+01,  1.8911e+00, -9.3761e+00,  ..., -2.1934e+00,
           -4.7214e+00, -2.8249e+00],
          [ 2.2531e+00, -4.9591e-01,  2.7125e+00,  ..., -3.9317e+00,
           -9.3989e+00, -1.9396e+00],
          ...,
          [ 8.0485e-01,  3.3610e+00, -2.2048e+00,  ..., -2.9378e+00,
           -5.2954e+00, -9.4283e-01],
          [ 4.9259e+00, -5.7080e+00, -7.3379e+00,  ..., -3.1136e+00,
           -9.6568e+00,  1.5648e+00],
          [-7.1825e-01, -8.1123e+00, -3.0114e+00,  ..., -1.4047e+01,
           -2.4005e+00, -3.6580e+00]]],


        [[[ 4.9511e+00, -1.5027e+00, -1.9901e+00,  ...,  5.0412e+00,
            1.0706e+00,  1.0697e+00],
          [ 6.2415e+00,  5.0564e+00,  7.1823e+00,  ...,  2.6606e+00,
           -4.2820e+00, -2.6922e+00],
          [ 2.6145e+00, -7.9739e+00, -3.2277e+00,  ..., -5.7815e-01,
            2.0001e+00, -6.0351e-01],
          ...,
          [ 2.4288e+00, -6.2262e+00, -1.1152e+00,  ...,  2.1734e-01,
           -2.6232e+00, -1.4114e+00],
          [ 7.7376e+00,  1.9148e+00,  5.6097e+00,  ..., -8.7782e+00,
           -2.0215e+00, -5.4768e+00],
          [-3.0471e+00,  3.3768e+00, -9.9078e+00,  ...,  3.4497e+00,
            5.2545e+00, -1.2969e+00]],

         [[ 4.1481e+00, -9.4472e-01,  4.9418e+00,  ...,  2.6200e+00,
           -1.1429e+00,  1.1250e+00],
          [ 4.8664e+00, -8.3640e-01, -2.8931e+00,  ...,  3.2391e+00,
            7.3649e+00,  7.5707e-01],
          [ 2.4969e+00,  1.1368e+01,  3.4130e+00,  ...,  2.6504e+00,
            1.0609e+00,  5.4673e-01],
          ...,
          [-1.8373e+00,  6.1022e+00, -6.7085e-01,  ...,  2.4950e+00,
            1.8695e+00,  3.7362e-01],
          [ 5.4703e-01,  1.7698e+00,  8.2906e-01,  ...,  1.3463e+01,
            7.3126e+00,  4.2411e+00],
          [ 5.5332e+00,  4.1572e-01,  1.9763e+01,  ..., -5.6383e-01,
            2.3872e-01, -1.6456e+00]],

         [[ 1.6635e+00,  4.7899e+00,  4.6182e+00,  ...,  3.7803e+00,
            5.6378e+00,  3.1067e+00],
          [ 1.4200e-01,  4.7036e+00,  2.7844e+00,  ..., -3.7973e+00,
           -5.9692e+00, -5.1718e+00],
          [-9.5488e-01, -5.1446e+00, -5.2252e+00,  ..., -2.6949e+00,
           -1.1517e+00,  1.7335e+00],
          ...,
          [-1.0350e+00, -1.8431e-01,  1.8109e+00,  ...,  2.8584e+00,
           -9.8807e-01,  2.1699e+00],
          [-1.4714e+00, -2.8016e+00, -3.2801e+00,  ..., -2.4862e+00,
            1.2551e+00, -4.3983e-01],
          [-6.9402e+00, -7.4319e+00, -8.3717e+00,  ..., -6.5469e+00,
           -6.9269e+00, -5.0238e+00]],

         ...,

         [[-5.0227e+00, -5.6273e-01, -3.4659e+00,  ..., -7.5313e+00,
            1.0427e+00, -2.2290e+00],
          [-6.1889e+00, -9.1081e+00, -1.1549e+01,  ..., -5.5020e+00,
           -6.7357e+00, -2.7971e-01],
          [-2.6193e+00, -7.4915e+00, -4.3165e+00,  ..., -8.1270e-01,
           -7.7799e-01,  1.2551e+00],
          ...,
          [-4.2430e+00, -5.7854e+00, -2.2560e+00,  ..., -1.5200e+00,
           -7.4140e-01,  1.1314e+00],
          [-5.5488e+00, -2.2813e+00, -8.8323e+00,  ..., -1.0026e+01,
           -4.5795e+00, -3.4211e+00],
          [-4.6952e+00, -2.8062e+00, -1.2410e+01,  ..., -5.3215e+00,
           -3.2938e+00,  1.9967e+00]],

         [[-1.2178e+00,  1.6422e-01, -1.1672e-01,  ..., -4.9102e-01,
            1.5544e+00,  2.0597e+00],
          [-3.8109e+00, -1.3714e+01, -1.1449e+01,  ..., -3.6783e+00,
           -2.6467e+00, -1.1190e+00],
          [-2.0145e+00, -8.8556e-01,  2.8787e-01,  ..., -3.5085e+00,
           -3.4998e+00,  1.6843e+00],
          ...,
          [-1.5821e+00, -7.1529e-01, -2.7843e+00,  ..., -1.1820e+00,
           -3.6872e+00, -2.0156e+00],
          [-5.4658e+00, -1.8714e+00, -5.0533e+00,  ..., -1.3434e+00,
           -4.6807e+00, -4.6568e+00],
          [-4.4704e+00, -4.3855e+00, -2.0163e-01,  ..., -6.6995e+00,
           -1.7948e+00,  4.0083e-01]],

         [[-1.6396e+00,  6.7258e-01, -3.5196e+00,  ..., -1.6236e-01,
            3.0336e+00,  4.3678e-01],
          [-5.2991e+00, -1.4800e+00,  7.0375e-01,  ..., -2.9283e+00,
           -5.6261e+00, -2.0446e+00],
          [ 2.3331e-01, -8.6102e+00, -3.7920e+00,  ..., -3.9211e+00,
           -8.3980e-01,  3.2438e+00],
          ...,
          [ 3.1507e+00, -2.6826e+00, -1.6874e+00,  ..., -4.1145e-01,
           -5.2226e-01,  1.0141e+00],
          [-3.9451e+00, -2.5996e+00, -2.6936e+00,  ..., -9.9214e+00,
           -2.6286e+00, -3.2309e+00],
          [-3.8168e+00, -3.4380e+00, -1.1530e+01,  ..., -3.0331e+00,
            1.6014e-01, -3.6031e-01]]],


        [[[ 2.0893e+00, -2.3039e+00,  2.0787e+00,  ...,  5.1635e+00,
           -9.0531e+00,  4.4809e+00],
          [-1.0887e-01,  3.1873e+00,  4.4716e+00,  ...,  4.6777e+00,
            6.2017e+00, -2.5134e+00],
          [-2.8318e+00, -2.5922e+00, -2.4505e-02,  ..., -4.2166e+00,
            3.7548e-02, -1.2096e+01],
          ...,
          [ 3.9919e+00,  4.7979e-01,  3.6397e+00,  ..., -6.4143e+00,
            2.8919e+00, -4.7591e+00],
          [ 4.5705e+00,  2.5882e+00,  2.0689e+00,  ...,  2.8891e+00,
            2.9826e+00, -8.5048e+00],
          [-3.1135e-01, -7.5162e+00,  5.6306e+00,  ...,  5.1670e+00,
           -5.2382e+00, -2.8540e+00]],

         [[ 1.3881e+00,  5.5435e+00,  2.9294e+00,  ...,  7.7402e-01,
            1.7676e+01,  1.3867e-01],
          [ 6.8434e+00, -2.3691e-01,  3.4590e+00,  ..., -2.7104e+00,
            7.0901e+00,  1.3174e+01],
          [ 7.0824e+00,  4.8466e-01,  6.4088e+00,  ...,  1.0828e+01,
           -2.0479e-01,  1.1387e+01],
          ...,
          [-8.3610e-01, -4.0287e-01, -4.9684e-01,  ...,  8.8358e+00,
           -1.8629e+00,  1.0501e+01],
          [ 2.4810e+00,  2.7490e+00,  3.8610e+00,  ..., -2.6410e+00,
           -5.4583e-02,  8.4745e+00],
          [ 1.6425e+00,  1.3265e+01,  1.9957e+00,  ...,  2.1364e+00,
            8.3680e+00, -1.2503e-01]],

         [[-1.0420e+00,  2.1431e+00,  6.4447e+00,  ...,  5.0374e+00,
            1.2574e+00,  3.6897e+00],
          [-1.4454e+00,  7.9292e-01, -2.6673e+00,  ...,  7.4208e+00,
            4.6575e+00,  2.3305e+00],
          [-1.9125e+00, -1.1830e+00, -1.1700e+00,  ...,  1.2454e-01,
            3.8036e+00, -2.5609e+00],
          ...,
          [-2.1241e+00, -2.9599e+00,  2.5515e+00,  ..., -1.7982e+00,
            1.1366e+00, -2.0058e+00],
          [-3.0118e+00,  1.6905e+00,  1.1247e-01,  ...,  1.5423e+00,
            3.8171e+00, -3.5545e+00],
          [-5.4530e+00, -9.4980e+00, -8.3122e+00,  ..., -8.5646e+00,
           -8.4361e+00, -4.9283e+00]],

         ...,

         [[-2.6127e+00, -3.0723e+00, -6.5077e+00,  ..., -6.2837e+00,
           -1.1476e+01, -7.2984e+00],
          [-8.3010e+00,  6.1631e-01, -4.8115e+00,  ..., -5.2064e+00,
           -9.8241e+00, -1.1031e+01],
          [-3.8627e+00, -2.2239e+00, -5.6432e+00,  ..., -4.3798e+00,
           -2.3981e+00, -4.8689e+00],
          ...,
          [ 1.8968e-01, -4.3457e+00,  1.6314e+00,  ..., -8.3611e+00,
           -5.7451e+00, -7.8080e+00],
          [-1.1148e+00, -5.9908e+00, -1.2987e+00,  ..., -2.0315e+00,
           -8.8658e+00, -1.6768e+00],
          [-2.0954e+00, -9.9398e+00, -3.4920e+00,  ...,  6.3892e-01,
           -4.9292e+00,  2.4073e+00]],

         [[-3.9088e+00, -4.0540e+00, -1.6110e+00,  ..., -7.3893e-01,
            3.3084e-01, -1.6310e+00],
          [-3.6002e+00, -7.9951e+00, -2.0541e+00,  ..., -9.3239e+00,
           -1.0841e+01, -4.5850e+00],
          [-5.7898e+00,  6.6386e-01, -1.8813e+00,  ..., -2.6447e+00,
           -8.1268e+00, -3.8014e+00],
          ...,
          [ 8.2782e-01, -2.3320e+00, -1.6123e+00,  ..., -5.6575e-01,
           -1.0072e+01, -2.3292e+00],
          [ 2.2774e+00, -5.0485e-01, -1.0462e+00,  ..., -4.1099e-01,
           -6.4708e+00, -3.4683e+00],
          [-2.1650e+00, -2.7324e+00, -7.5223e-01,  ..., -1.8702e+00,
           -2.9764e+00, -1.5140e+00]],

         [[ 1.3596e+00, -3.5608e+00,  9.0926e-01,  ...,  1.8114e+00,
           -1.4149e+01,  1.8307e+00],
          [-7.4231e+00, -1.5974e+00, -1.2709e+00,  ...,  6.7976e-01,
           -7.5322e+00, -9.9841e+00],
          [-2.2412e+00, -4.9828e-01, -2.6589e+00,  ..., -5.4160e+00,
           -5.3216e+00, -6.5535e+00],
          ...,
          [ 5.4452e-02, -4.7757e+00,  1.4365e+00,  ..., -6.7303e+00,
           -3.2851e+00, -1.1895e+01],
          [ 8.0024e-02, -3.5524e+00, -9.8891e-01,  ..., -9.9743e-02,
           -7.4416e-01, -5.7622e+00],
          [-1.3401e+00, -1.2223e+01, -9.6632e-01,  ..., -2.7656e+00,
           -4.8970e+00, -1.3839e+00]]],


        [[[ 3.6690e+00,  2.5520e+00,  2.9692e+00,  ...,  3.4680e+00,
            2.3649e+00, -1.5641e+00],
          [ 3.4036e+00,  2.7713e+00,  6.4362e+00,  ..., -3.3172e+00,
           -6.8631e+00, -1.0134e+01],
          [ 8.8785e+00, -6.3975e+00,  3.3293e+00,  ...,  8.7862e-01,
            5.7202e+00,  4.8354e+00],
          ...,
          [ 6.8225e+00, -2.5770e+00, -2.3185e+00,  ...,  2.1159e+00,
           -5.9010e-01, -5.3468e+00],
          [ 6.3929e+00, -2.1199e+00,  4.6196e+00,  ...,  4.3277e+00,
            2.9618e+00,  3.2522e+00],
          [-1.8842e+00,  7.4935e+00, -1.0892e+01,  ...,  4.0326e+00,
            2.9407e+00, -7.4702e-01]],

         [[ 5.3414e+00, -1.8673e+00, -8.6500e-01,  ..., -1.6556e+00,
            2.0963e+00,  1.7088e-01],
          [ 4.5298e-01, -2.0708e+00,  4.4050e+00,  ...,  8.0618e+00,
            1.4643e+01,  8.3957e+00],
          [ 1.7456e-01,  1.3286e+01,  2.1636e-02,  ..., -2.4302e+00,
            4.0165e-01, -3.6621e+00],
          ...,
          [-2.3843e+00,  9.6206e+00, -8.0462e-01,  ...,  1.1077e+00,
            3.4552e+00,  8.0494e+00],
          [-1.8785e+00,  1.1684e+01, -1.9896e+00,  ...,  1.9554e+00,
           -2.1284e-01, -1.8010e+00],
          [ 1.2430e+01, -7.8955e-01,  1.5918e+01,  ..., -5.2934e-01,
            5.4331e+00, -2.4154e+00]],

         [[ 2.3463e-01,  5.7875e+00,  5.8817e+00,  ...,  4.2302e+00,
            3.8826e+00,  3.3450e+00],
          [ 5.3408e-01,  2.1319e+00,  2.1296e+00,  ..., -3.4848e+00,
           -5.7868e+00, -3.7917e+00],
          [-1.3288e-02, -3.5988e+00,  5.1664e-01,  ..., -1.9988e-01,
            1.2574e+00,  4.0582e+00],
          ...,
          [ 2.5484e-01, -6.6267e+00, -2.5156e+00,  ...,  1.5981e+00,
           -2.6924e-01, -1.9111e+00],
          [ 2.7637e+00, -4.3802e+00, -2.9469e+00,  ..., -1.9701e+00,
           -1.3910e+00, -5.0275e-01],
          [-6.4593e+00, -4.5150e+00, -8.7036e+00,  ..., -3.8838e+00,
           -8.2901e+00, -2.6530e+00]],

         ...,

         [[-2.9976e+00, -6.1895e+00, -2.9982e+00,  ..., -1.0275e+00,
           -6.3693e+00, -5.2613e-01],
          [ 1.1100e+00, -2.6757e+00, -5.5097e+00,  ..., -3.8175e+00,
           -1.1384e+01, -6.2965e+00],
          [-6.1898e+00, -1.0496e+01, -7.8370e+00,  ..., -2.4375e+00,
           -5.9865e+00, -1.0095e+01],
          ...,
          [-2.7099e+00, -8.5365e+00, -4.3424e+00,  ..., -1.4820e+00,
           -4.2424e+00, -4.9803e+00],
          [-8.2978e+00, -4.8246e+00, -3.3537e-01,  ..., -6.1985e+00,
           -2.5679e+00, -1.0037e+01],
          [-8.6422e+00, -9.3835e+00, -9.4268e+00,  ..., -2.2890e+00,
           -2.9433e+00,  1.7462e+00]],

         [[-2.0019e+00, -5.2207e+00, -1.7289e+00,  ..., -2.2713e+00,
           -2.0510e+00, -7.7137e-01],
          [ 9.0098e-01, -1.8118e+00, -3.5080e+00,  ..., -2.9670e+00,
           -1.0092e+00, -2.7505e+00],
          [-1.5150e+00, -2.1111e+00, -1.0562e+01,  ...,  1.6245e+00,
           -8.8989e+00, -9.1484e+00],
          ...,
          [-7.7905e+00, -6.0290e-01, -2.5958e+00,  ..., -1.0867e+00,
           -3.6504e+00, -4.2370e+00],
          [-4.3087e+00, -2.1581e+00, -3.9353e-01,  ..., -9.1619e-01,
           -5.6986e-01, -1.0848e+01],
          [-3.5276e+00, -9.9181e+00, -1.9165e+00,  ...,  9.8782e-02,
           -4.8288e+00,  9.0963e-01]],

         [[-2.4254e+00, -4.5376e-01, -1.2398e+00,  ...,  1.9969e+00,
            7.6587e-01, -9.2107e-01],
          [ 1.3693e+00,  2.4650e-01, -9.1861e+00,  ..., -4.4209e+00,
           -1.3174e+01, -4.6477e+00],
          [ 1.8506e+00, -1.2098e+01, -1.8639e+00,  ...,  2.4029e+00,
           -2.3152e+00,  1.8254e+00],
          ...,
          [-2.4255e-01, -1.0004e+01, -3.1749e-01,  ...,  7.4667e-01,
           -1.6225e+00, -4.1853e+00],
          [ 2.2631e+00, -6.9062e+00,  1.7108e+00,  ...,  3.5608e-01,
            2.0018e+00, -2.1197e+00],
          [-1.2604e+01, -9.0950e-01, -8.7572e+00,  ...,  4.8706e+00,
           -4.6455e+00, -3.0807e-01]]]], grad_fn=<AddBackward0>), (1, 11): tensor([[[[-1.0670e+00]],

         [[-4.3418e-02]],

         [[ 4.2365e-02]],

         [[-6.9876e-03]],

         [[ 1.7812e-02]],

         [[ 2.1149e-02]],

         [[-2.4057e-02]],

         [[ 7.5805e-02]],

         [[ 8.4818e-02]],

         [[-6.0109e-02]],

         [[-5.8056e-02]],

         [[ 1.0286e-01]],

         [[-1.2813e-01]],

         [[-3.2312e-02]],

         [[-6.6257e-02]],

         [[-8.3487e-02]],

         [[ 2.5726e-02]],

         [[ 2.0722e-02]],

         [[ 8.6520e-02]],

         [[-9.7668e-02]],

         [[ 1.5154e-02]],

         [[ 1.2321e-02]],

         [[-4.5299e-02]],

         [[ 2.4467e+00]],

         [[ 9.0947e-02]],

         [[ 2.6310e-05]],

         [[ 3.2019e-02]],

         [[ 8.9005e-01]],

         [[ 9.6995e-02]],

         [[ 4.2771e-01]],

         [[-1.3770e+00]],

         [[-7.2000e-02]],

         [[-9.7183e-02]],

         [[-1.2281e-03]],

         [[ 1.6013e+00]],

         [[ 1.2338e-01]],

         [[ 1.1452e-01]],

         [[ 3.6941e-03]],

         [[ 1.9058e-01]],

         [[ 6.3625e-02]],

         [[-8.9511e-02]],

         [[-1.0939e-04]],

         [[-1.6465e-02]],

         [[-9.1979e-03]],

         [[-6.3224e-02]],

         [[ 4.5120e-02]],

         [[-8.6800e-03]],

         [[ 8.8166e-02]],

         [[-3.6262e-02]],

         [[ 9.5349e-06]],

         [[-3.3885e-02]],

         [[-3.4650e-02]],

         [[-2.6191e-02]],

         [[-1.0277e-01]],

         [[ 1.0132e-02]],

         [[-3.2154e-02]],

         [[-4.0415e-02]],

         [[ 7.2653e-02]],

         [[ 3.2351e-02]],

         [[-1.1845e-01]],

         [[ 8.7250e-02]],

         [[ 1.5899e-02]],

         [[-2.1419e-02]],

         [[-1.2630e-01]],

         [[-2.0656e-02]],

         [[ 3.8991e-02]],

         [[-8.1988e-02]],

         [[-3.4685e-02]],

         [[-9.5655e-02]],

         [[-5.1056e-02]],

         [[ 3.7123e-02]],

         [[ 5.1359e-02]],

         [[-3.2432e-01]],

         [[-9.4480e-02]],

         [[ 5.1250e-02]],

         [[ 3.8884e-02]],

         [[ 4.1830e-02]],

         [[-9.5562e-02]],

         [[ 1.3547e-02]],

         [[-1.4625e-01]]],


        [[[-8.3459e-01]],

         [[-7.0291e-02]],

         [[-1.2763e-02]],

         [[-1.5234e-02]],

         [[ 5.3395e-02]],

         [[ 7.7670e-02]],

         [[-2.6480e-02]],

         [[ 1.2930e-01]],

         [[ 7.3504e-02]],

         [[-4.2087e-02]],

         [[-7.7048e-03]],

         [[ 8.1294e-02]],

         [[-1.3147e-01]],

         [[ 2.2291e-02]],

         [[-6.0163e-02]],

         [[-7.2472e-02]],

         [[ 1.0571e-02]],

         [[ 2.1816e-02]],

         [[ 6.7408e-02]],

         [[-2.7233e-02]],

         [[ 3.4666e-02]],

         [[-4.3219e-02]],

         [[-5.7314e-02]],

         [[ 2.1528e+00]],

         [[ 1.1179e-01]],

         [[-2.2600e-02]],

         [[ 2.1942e-02]],

         [[ 9.9519e-01]],

         [[ 7.0537e-02]],

         [[-1.0749e-01]],

         [[-1.6291e+00]],

         [[-4.4280e-01]],

         [[-8.4148e-02]],

         [[ 8.0464e-03]],

         [[ 1.0146e+00]],

         [[ 1.0817e-01]],

         [[ 6.2207e-02]],

         [[ 3.9145e-02]],

         [[ 1.2623e-01]],

         [[ 1.4152e-02]],

         [[-4.4623e-02]],

         [[-4.4330e-03]],

         [[-1.4209e-02]],

         [[-2.9335e-03]],

         [[-4.0488e-02]],

         [[ 5.0055e-02]],

         [[-2.6504e-02]],

         [[ 1.3933e-01]],

         [[-1.2863e-02]],

         [[-1.8533e-02]],

         [[-3.7252e-02]],

         [[-9.3649e-02]],

         [[ 2.6300e-02]],

         [[-1.0248e-01]],

         [[-3.8544e-02]],

         [[-2.4356e-02]],

         [[-3.1093e-02]],

         [[ 1.7427e-02]],

         [[-1.2240e-02]],

         [[-1.4430e-01]],

         [[ 7.5422e-02]],

         [[-8.2282e-03]],

         [[ 2.1688e-02]],

         [[-8.6355e-02]],

         [[-2.6483e-02]],

         [[ 4.8176e-02]],

         [[-7.8586e-02]],

         [[ 7.3879e-04]],

         [[-2.2237e-02]],

         [[-4.7223e-02]],

         [[-3.6401e-02]],

         [[ 6.5606e-02]],

         [[-7.2511e-01]],

         [[-4.5284e-01]],

         [[ 1.1181e-01]],

         [[ 1.3715e-02]],

         [[ 5.5456e-02]],

         [[-9.0925e-02]],

         [[-2.8159e-02]],

         [[-1.2489e-01]]],


        [[[-5.6662e-01]],

         [[-1.7372e-02]],

         [[ 3.6378e-02]],

         [[ 8.1078e-03]],

         [[ 4.7308e-02]],

         [[ 4.4889e-02]],

         [[ 1.6944e-02]],

         [[ 6.6558e-02]],

         [[ 9.0538e-02]],

         [[-6.3685e-02]],

         [[-1.9302e-02]],

         [[ 8.1401e-02]],

         [[-1.6091e-01]],

         [[-1.9780e-02]],

         [[-3.0893e-02]],

         [[-6.8446e-02]],

         [[-1.3715e-02]],

         [[ 1.1084e-02]],

         [[ 8.5815e-02]],

         [[-6.6132e-02]],

         [[ 3.6838e-02]],

         [[-6.0206e-03]],

         [[-6.9936e-02]],

         [[ 1.6245e+00]],

         [[ 7.3798e-02]],

         [[-2.3805e-02]],

         [[-4.6397e-03]],

         [[ 6.0671e-01]],

         [[ 1.0871e-01]],

         [[ 9.7804e-02]],

         [[-1.5399e+00]],

         [[-1.1413e-01]],

         [[-1.4293e-01]],

         [[-2.3569e-02]],

         [[ 2.0211e+00]],

         [[ 1.5142e-01]],

         [[ 1.0056e-01]],

         [[ 9.7862e-03]],

         [[ 1.3259e-01]],

         [[ 4.7402e-02]],

         [[-8.1663e-02]],

         [[ 6.7836e-03]],

         [[-1.7629e-02]],

         [[-3.9694e-02]],

         [[-3.2865e-02]],

         [[ 4.0548e-02]],

         [[ 1.4034e-02]],

         [[ 1.3091e-01]],

         [[-4.7296e-02]],

         [[ 1.9194e-02]],

         [[ 6.8931e-03]],

         [[-5.2722e-02]],

         [[ 4.3650e-04]],

         [[-1.1848e-01]],

         [[-1.5542e-02]],

         [[-2.0323e-02]],

         [[-2.2623e-03]],

         [[ 4.1010e-02]],

         [[ 3.3822e-02]],

         [[-1.7070e-01]],

         [[ 1.1287e-01]],

         [[ 6.3826e-03]],

         [[-1.2915e-02]],

         [[-1.1304e-01]],

         [[-1.6583e-02]],

         [[ 4.8252e-02]],

         [[-5.7537e-02]],

         [[-2.6370e-02]],

         [[-1.0637e-01]],

         [[-7.5141e-02]],

         [[ 2.2686e-02]],

         [[ 5.3278e-02]],

         [[-1.1578e+00]],

         [[ 1.7661e-01]],

         [[ 1.0572e-01]],

         [[ 2.8829e-02]],

         [[ 8.7615e-02]],

         [[-8.5253e-02]],

         [[-1.0406e-02]],

         [[-1.3934e-01]]],


        [[[-9.3577e-01]],

         [[-4.3846e-02]],

         [[ 3.8813e-02]],

         [[ 2.1416e-02]],

         [[ 2.9245e-02]],

         [[ 2.7624e-03]],

         [[ 4.0589e-02]],

         [[ 9.1565e-02]],

         [[ 8.4967e-02]],

         [[-8.0109e-02]],

         [[-3.3492e-02]],

         [[ 6.6276e-03]],

         [[-1.0586e-01]],

         [[-1.4072e-03]],

         [[-7.0655e-02]],

         [[-4.2444e-02]],

         [[-5.7034e-02]],

         [[ 2.8436e-02]],

         [[ 1.0867e-01]],

         [[-6.0945e-02]],

         [[ 6.3024e-02]],

         [[-8.4975e-03]],

         [[-3.5578e-02]],

         [[ 1.3830e+00]],

         [[ 6.5107e-02]],

         [[ 1.2732e-02]],

         [[ 6.0123e-02]],

         [[ 5.6068e-01]],

         [[ 1.3857e-01]],

         [[ 5.7051e-01]],

         [[-1.3417e+00]],

         [[-2.6874e-01]],

         [[-1.4000e-01]],

         [[-3.4107e-02]],

         [[ 1.5092e+00]],

         [[ 1.5911e-01]],

         [[ 1.0113e-01]],

         [[-5.0543e-02]],

         [[ 6.8653e-02]],

         [[ 1.6767e-02]],

         [[-9.4661e-02]],

         [[ 3.1447e-03]],

         [[-4.7750e-02]],

         [[-4.2482e-02]],

         [[-4.4464e-02]],

         [[ 4.5177e-02]],

         [[-7.6296e-03]],

         [[ 1.3239e-01]],

         [[ 9.2028e-03]],

         [[ 1.5603e-02]],

         [[-1.8425e-03]],

         [[-5.7741e-02]],

         [[ 5.6675e-02]],

         [[-9.9763e-02]],

         [[-2.4355e-02]],

         [[ 1.5806e-02]],

         [[ 1.7609e-02]],

         [[ 1.9650e-02]],

         [[ 8.5162e-02]],

         [[-1.3551e-01]],

         [[ 1.1470e-01]],

         [[ 5.0192e-02]],

         [[ 1.9493e-02]],

         [[-5.4585e-02]],

         [[ 1.9097e-02]],

         [[ 4.5244e-02]],

         [[-6.2959e-02]],

         [[-1.0459e-02]],

         [[-1.1219e-01]],

         [[-6.3300e-02]],

         [[-4.7604e-03]],

         [[ 3.0598e-02]],

         [[-1.4422e+00]],

         [[ 5.3024e-01]],

         [[ 8.9773e-02]],

         [[-1.8431e-02]],

         [[ 6.1410e-02]],

         [[-3.8839e-02]],

         [[ 5.5084e-03]],

         [[-1.0831e-01]]],


        [[[-3.3189e-01]],

         [[-1.9446e-02]],

         [[ 5.5837e-02]],

         [[ 6.8228e-03]],

         [[ 1.0625e-02]],

         [[ 8.1591e-02]],

         [[-2.1832e-02]],

         [[ 9.5458e-02]],

         [[ 8.2751e-02]],

         [[-5.1256e-02]],

         [[-3.3994e-02]],

         [[ 6.3541e-02]],

         [[-1.4032e-01]],

         [[-2.3860e-02]],

         [[-9.0546e-02]],

         [[-9.8386e-02]],

         [[-4.7722e-02]],

         [[ 4.0285e-02]],

         [[ 5.6187e-02]],

         [[-4.2866e-02]],

         [[ 6.9068e-02]],

         [[ 3.1833e-02]],

         [[-3.3349e-02]],

         [[ 1.9428e+00]],

         [[ 6.5216e-02]],

         [[-3.6645e-03]],

         [[ 4.2890e-03]],

         [[ 8.9104e-01]],

         [[ 9.1819e-02]],

         [[ 9.3136e-02]],

         [[-1.4033e+00]],

         [[-2.6981e-01]],

         [[-1.1012e-01]],

         [[-2.1555e-02]],

         [[ 1.1218e+00]],

         [[ 1.4088e-01]],

         [[ 1.1866e-01]],

         [[ 5.5112e-03]],

         [[ 1.3498e-01]],

         [[ 1.9925e-02]],

         [[-1.2176e-01]],

         [[ 4.7967e-03]],

         [[-1.4316e-03]],

         [[-4.7118e-04]],

         [[-5.6228e-02]],

         [[ 6.4555e-02]],

         [[ 7.8806e-03]],

         [[ 1.1947e-01]],

         [[-6.1449e-02]],

         [[ 2.2647e-03]],

         [[-3.3782e-02]],

         [[-7.3577e-02]],

         [[ 6.5389e-02]],

         [[-7.7591e-02]],

         [[-1.1609e-02]],

         [[-4.7356e-02]],

         [[ 7.8632e-03]],

         [[ 1.0340e-02]],

         [[ 3.5183e-02]],

         [[-1.4570e-01]],

         [[ 1.0716e-01]],

         [[ 2.1980e-02]],

         [[ 1.2434e-02]],

         [[-6.8202e-02]],

         [[ 2.0452e-02]],

         [[ 7.7962e-02]],

         [[-4.5192e-02]],

         [[ 7.6807e-03]],

         [[-8.4734e-02]],

         [[-2.8607e-02]],

         [[-4.5014e-03]],

         [[ 8.4037e-02]],

         [[-9.3664e-01]],

         [[-6.6527e-02]],

         [[ 9.7697e-02]],

         [[ 5.0429e-02]],

         [[ 4.1691e-02]],

         [[-8.6985e-02]],

         [[-3.5593e-03]],

         [[-1.3485e-01]]]], grad_fn=<AsStridedBackward1>)})

Other functionalities#

In addition to recording, the following functionalities are available and documented in Recorder:

  • On-the-fly postprocessing of activations during inference (e.g. clipping).

  • Local disabling of recording during the forward pass.

  • Direct access to the recorded layers' Module objects.

  • Retrieval of the named parameters of the recorded layers.
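The first two items rest on a standard PyTorch mechanism: a forward hook that returns a non-None value replaces the module's output. The sketch below uses only plain PyTorch, not Recorder's own API (whose methods are documented in Recorder); the toy network, the `state` flag and the `hooks_disabled` context manager are hypothetical names chosen for illustration.

```python
from contextlib import contextmanager

import torch
from torch import nn

torch.manual_seed(0)
# Toy stand-in for `net`; any torch.nn.Module behaves the same way.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

state = {"active": True}  # hypothetical flag shared with the hook


def clip_hook(module, args, output):
    """Clip the layer's output to [-0.1, 0.1] while the flag is set."""
    if state["active"]:
        return output.clamp(-0.1, 0.1)  # a non-None return replaces the output
    return None  # None leaves the output untouched


@contextmanager
def hooks_disabled():
    """Temporarily turn the hook off, restoring the previous flag on exit."""
    previous = state["active"]
    state["active"] = False
    try:
        yield
    finally:
        state["active"] = previous


handle = net[0].register_forward_hook(clip_hook)
x = torch.full((3, 4), 10.0)
clipped = net[0](x)  # hook active: values lie in [-0.1, 0.1]
with hooks_disabled():
    raw = net[0](x)  # hook inactive: plain linear output
handle.remove()  # detach the hook once it is no longer needed
```

Using a context manager for the "disabled" state guarantees the flag is restored even if the forward pass raises.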

Implementation details#

This utility works by attaching hooks to every layer of net with Module.register_forward_hook. Since a layer is not aware of the context in which it is called, each hook carries a reference to rnet, and with it enough context to know when to trigger. As a result, two different Recorder Nets can wrap the same net without any conflict. As an implementation detail, note that these references are weak, so that everything is properly cleaned up when rnet is deleted.

Total running time of the script: (0 minutes 0.153 seconds)
