Recorder

class scio.recorder.Recorder(net, /, *, force_static_flow=True, **summary_kwargs)

Bases: Module

Wrapper class for operating inside a PyTorch neural network.

A “recorder net” (referred to as rnet) is a wrapped instance of torch.nn.Module, augmented with inspection utilities.

Parameters:
  • net (nn.Module) – Net to be recorded.

  • force_static_flow (bool) –

    If True, the control flow is assumed to be static, meaning that it is the same for every processed input. Two counterexamples:

    • if the network applies resizing that depends on the input shape;

    • if the network uses if statements or loops whose behavior depends on layer output values.

    More information on dynamic control flow. Defaults to True.

  • **summary_kwargs – Passed to torchinfo.summary(). Must include either input_data or input_size. The keyword force_static_flow is restricted by design.
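For illustration, here is a hypothetical toy network (plain PyTorch, not part of scio) matching the second counterexample above: which layer runs depends on a value computed at runtime, so its control flow is not static and force_static_flow=True would not describe it correctly. The class and attribute names are made up for this sketch.

```python
import torch
from torch import nn


class DynamicNet(nn.Module):
    """Hypothetical toy net whose control flow depends on layer output values."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(5, 4)
        self.fc_pos = nn.Linear(4, 3)  # used when the hidden activations sum to > 0
        self.fc_neg = nn.Linear(4, 3)  # used otherwise

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.fc1(x)
        # Data-dependent branch: the layers executed vary from input to input.
        if h.sum() > 0:
            return self.fc_pos(h)
        return self.fc_neg(h)
```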

Example

Consider the following toy rnet, built with Recorder and layers from torch.nn:

from torch.nn import Linear, ReLU, Sequential
from scio.recorder import Recorder
toynet = Sequential(Linear(5, 4), ReLU(), Linear(4, 3), ReLU())
rnet = Recorder(toynet, input_size=(1, 5))

First, we can easily inspect the architecture of the network in a human-readable form, with labelled layers:

>>> rnet
Recorder instance for the following network
============================================================================================================================================
Layer (type (var_name):depth-idx)        Input Shape               Output Shape              Param #                   Param %
============================================================================================================================================
Sequential (Sequential)                  [1, 5]                    [1, 3]                    --                             --
├─Linear (0): 1-1                        [1, 5]                    [1, 4]                    24                         61.54%
├─ReLU (1): 1-2                          [1, 4]                    [1, 4]                    --                             --
├─Linear (2): 1-3                        [1, 4]                    [1, 3]                    15                         38.46%
├─ReLU (3): 1-4                          [1, 3]                    [1, 3]                    --                             --
============================================================================================================================================
Total params: 39
Trainable params: 39
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.00
============================================================================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
============================================================================================================================================
Currently recording: None
============================================================================================================================================

Note how layers are labelled with depth-idx pairs (e.g. 1-3). The (hidden) label 0-1 refers to the entire network. In the example above, the last line reports that no layer is being recorded. This can be changed with record():

>>> rnet.record((0, 1), (1, 3))
>>> rnet
[...]
Currently recording: 0-1, 1-3
[...]
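To build intuition for these labels, the following hypothetical helper (plain PyTorch, not scio's or torchinfo's actual code) reproduces them for simple networks: the root gets 0-1, and each depth keeps its own counter, incremented in pre-order traversal. It assumes layers are called in registration order, which holds for nn.Sequential but not in general.

```python
from collections import defaultdict

from torch import nn


def depth_idx_labels(net: nn.Module) -> dict[tuple[int, int], nn.Module]:
    """Hypothetical helper: map (depth, idx) labels to modules, in pre-order."""
    counters: defaultdict[int, int] = defaultdict(int)  # last idx used per depth
    labels: dict[tuple[int, int], nn.Module] = {}

    def visit(module: nn.Module, depth: int) -> None:
        counters[depth] += 1
        labels[(depth, counters[depth])] = module
        for child in module.children():  # children in registration order
            visit(child, depth + 1)

    visit(net, 0)
    return labels


toynet = nn.Sequential(nn.Linear(5, 4), nn.ReLU(), nn.Linear(4, 3), nn.ReLU())
labels = depth_idx_labels(toynet)
print(sorted(labels))  # [(0, 1), (1, 1), (1, 2), (1, 3), (1, 4)]
print(labels[(1, 3)])  # Linear(in_features=4, out_features=3, bias=True)
```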

After a forward pass, the associated activations (i.e. the outputs of the targeted modules) are accessible through the activations attribute:

>>> rnet(torch.rand((1, 5)))  # Forward pass for 1 random sample
tensor([[0.0000, 0.7325, 0.0000]], grad_fn=<ReluBackward0>)
>>> rnet.activations
mappingproxy({(1, 3): tensor([[-0.1398,  0.7325, -0.2175]], grad_fn=<AddmmBackward0>), (0, 1): tensor([[0.0000, 0.7325, 0.0000]], grad_fn=<ReluBackward0>)})
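One common way to capture such activations in plain PyTorch is a forward hook. The sketch below illustrates that general technique under assumptions; it is not scio's implementation. It stores each targeted module's output in a dict and exposes it through a read-only types.MappingProxyType, mirroring the mappingproxy shown above. The (depth, idx) keys are assigned by hand here.

```python
from types import MappingProxyType

import torch
from torch import nn

toynet = nn.Sequential(nn.Linear(5, 4), nn.ReLU(), nn.Linear(4, 3), nn.ReLU())
targets = {(0, 1): toynet, (1, 3): toynet[2]}  # labels assigned by hand
records: dict[tuple[int, int], torch.Tensor] = {}


def make_hook(label: tuple[int, int]):
    def hook(module: nn.Module, args: tuple, output: torch.Tensor) -> None:
        records[label] = output  # keep the module output as-is (no copy)
    return hook


handles = [m.register_forward_hook(make_hook(lbl)) for lbl, m in targets.items()]
out = toynet(torch.rand((1, 5)))
activations = MappingProxyType(records)  # read-only view over the records
for handle in handles:
    handle.remove()  # detach the hooks once done

# Inner modules finish first, so (1, 3) is stored before (0, 1).
print(list(activations))  # [(1, 3), (0, 1)]
```

As in the output above, the record for (0, 1) equals the network output, and it is the ReLU of the record for (1, 3).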

Other utilities are documented below.

Warning

With dynamic control flow, the result is not guaranteed: recording might work, fail, or fail silently (i.e. the recorded activations could actually correspond to the wrong layers).

Useful methods defined here

forward(*args[, activation_postproc, ...])

Forward pass with optional recording & on-the-fly processing.

record(*depth_idxs)

Set the recorded layers.

Useful attributes defined here

activations

Recorded activations from the latest forward pass.

net

The wrapped torch.nn.Module instance.

recorded_modules

Map to recorded modules.

recorded_params

Map to named parameters from recorded modules only.

recording

Tuple of layers which are being recorded.

forward(*args, activation_postproc=None, dont_record=False, **kwargs)

Forward pass with optional recording & on-the-fly processing.

During instantiation, the structure of the network is computed with torchinfo. By using record(), the user defines layers of interest. During a forward pass, their activations are recorded and later accessible through the activations attribute. If required, these are also postprocessed in-place before they are propagated further in the network (see activation_postproc).
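The ordering described above (record first, then postprocess before propagating further) can be sketched with a plain PyTorch forward hook. This is an illustration under assumptions, not scio's code: it relies on PyTorch's rule that a forward hook returning a tensor replaces the module output while returning None keeps it, and the postprocessor zero_inplace and helper attach are hypothetical.

```python
import torch
from torch import nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(5, 4), nn.ReLU(), nn.Linear(4, 3))
records: dict[tuple[int, int], torch.Tensor] = {}


def zero_inplace(output: torch.Tensor) -> None:
    """Hypothetical postprocessor: edits the output in place, returns None."""
    output.zero_()


def attach(label, module, postproc):
    def hook(mod, args, output):
        records[label] = output  # 1. record (same tensor object, no copy)
        return postproc(output)  # 2. postprocess; None keeps the edited output
    return module.register_forward_hook(hook)


handle = attach((1, 2), net[1], zero_inplace)
with torch.no_grad():
    out = net(torch.rand((1, 5)))
handle.remove()

# The in-place edit shows up in the record; downstream, only the bias remains.
print(records[(1, 2)])  # tensor([[0., 0., 0., 0.]])
```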

Parameters:
  • *args – Passed to self.net.

  • activation_postproc (list[Postprocessor] | Postprocessor, optional) –

    If provided, the postprocessing to apply to activations before propagating them further in the network. We define:

    type Postprocessor = Callable[[Tensor], Tensor | None]
    

    where the signature reads output -> modified output or None. Postprocessing is applied after recording, and not on a copy: any in-place operation therefore also affects the records. If a list, its elements are per-layer postprocessors, matched to recorded layers in the order in which those layers are called. The list must contain at least one postprocessor per recorded layer; any remaining elements are ignored. If a single Postprocessor, it is applied to all recorded layers.

  • dont_record (bool) – If True, activations are not recorded. Postprocessing, if any, will still occur. Defaults to False.

  • **kwargs – Passed to self.net. The keywords activation_postproc and dont_record are restricted by design.
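The dispatch rules for activation_postproc (a single Postprocessor applies to every recorded layer; a list is matched to recorded layers in call order, must be long enough, and extra elements are ignored) can be restated as a small pure-Python helper. The name per_layer_postprocs and the ValueError it raises are hypothetical; scio's actual handling may differ.

```python
from collections.abc import Callable, Sequence
from typing import Any, Optional, Union

# Stand-in for the documented Callable[[Tensor], Tensor | None].
Postprocessor = Callable[[Any], Any]


def per_layer_postprocs(
    postproc: Optional[Union[Postprocessor, Sequence[Postprocessor]]],
    n_recorded: int,
) -> list:
    """Hypothetical helper: one postprocessor (or None) per recorded layer."""
    if postproc is None:
        return [None] * n_recorded  # no postprocessing at all
    if callable(postproc):
        return [postproc] * n_recorded  # single Postprocessor shared by all layers
    if len(postproc) < n_recorded:
        raise ValueError(f"Need {n_recorded} postprocessors, got {len(postproc)}")
    return list(postproc[:n_recorded])  # extra elements are ignored


print(per_layer_postprocs([abs, None, round], 2))  # [<built-in function abs>, None]
```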

Note

As with any torch.nn.Module instance, call the instance directly (e.g. rnet(x)) instead of calling this method.

Raises:

RuntimeError – If not all the requested activations were found during the forward pass.
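A completeness check of this kind can be sketched in pure Python, assuming the recorded activations are collected in a mapping keyed by (depth, idx). The function check_all_recorded is hypothetical and not part of scio's documented API.

```python
from collections.abc import Iterable, Mapping

DepthIdx = tuple[int, int]


def check_all_recorded(recording: Iterable[DepthIdx], activations: Mapping) -> None:
    """Hypothetical check: raise if a requested layer produced no activation."""
    missing = [label for label in recording if label not in activations]
    if missing:
        raise RuntimeError(f"Activations not found for layers: {missing}")


# (1, 3) was requested but never recorded, e.g. because that layer was not called.
try:
    check_all_recorded([(0, 1), (1, 3)], {(0, 1): "output"})
except RuntimeError as err:
    print(err)  # Activations not found for layers: [(1, 3)]
```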

record(*depth_idxs)

Set the recorded layers.

Parameters:

*depth_idxs (DepthIdx) – The 2-tuples (depth, idx) of layers to record. Duplicates and order are ignored.

Example

rnet.record((0, 1), (1, 3))

Raises:

ValueError – If the requested layers were not found.
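The behavior described for record() (duplicates and order ignored, unknown layers rejected with ValueError) can be sketched as follows. Sorting is one way to make the result order-independent and is consistent with the recording value ((0, 1), (1, 3)) documented on this page, but whether scio sorts is an assumption; the helper normalize_depth_idxs and its known_labels argument are hypothetical.

```python
from collections.abc import Iterable

DepthIdx = tuple[int, int]


def normalize_depth_idxs(
    depth_idxs: Iterable[DepthIdx], known_labels: Iterable[DepthIdx]
) -> tuple[DepthIdx, ...]:
    """Hypothetical helper: deduplicate, sort, and validate (depth, idx) labels."""
    requested = set(depth_idxs)  # duplicates and order are dropped here
    unknown = requested - set(known_labels)
    if unknown:
        raise ValueError(f"Layers not found: {sorted(unknown)}")
    return tuple(sorted(requested))


known = [(0, 1), (1, 1), (1, 2), (1, 3), (1, 4)]  # labels of the toy rnet
print(normalize_depth_idxs([(1, 3), (0, 1), (1, 3)], known))  # ((0, 1), (1, 3))
```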

activations

Recorded activations from the latest forward pass.

Example

>>> rnet.activations
mappingproxy({(1, 3): tensor([[0.0729, 0.0405, 0.2566]], grad_fn=<AddmmBackward0>), (0, 1): tensor([[0.0729, 0.0405, 0.2566]], grad_fn=<ReluBackward0>)})
>>> rnet.activations[(1, 3)]
tensor([[0.0729, 0.0405, 0.2566]], grad_fn=<AddmmBackward0>)

net

The wrapped torch.nn.Module instance.

recorded_modules

Map to recorded modules.

Example

>>> rnet.recorded_modules
{(0, 1): Sequential(
  (0): Linear(in_features=5, out_features=4, bias=True)
  (1): ReLU()
  (2): Linear(in_features=4, out_features=3, bias=True)
  (3): ReLU()
), (1, 3): Linear(in_features=4, out_features=3, bias=True)}

recorded_params

Map to named parameters from recorded modules only.

Example

>>> rnet.recorded_params
{'_net.0.weight': Parameter containing:
tensor([[-0.2762, -0.1564,  0.1478, -0.1285, -0.1707],
        [ 0.1708, -0.3490, -0.3634,  0.4290, -0.0088],
        [ 0.2128,  0.2551,  0.4042,  0.4373,  0.2639],
        [ 0.1333,  0.0394,  0.2923,  0.2545, -0.4048]], requires_grad=True), '_net.0.bias': Parameter containing:
tensor([ 0.2976,  0.3207, -0.0690, -0.1086], requires_grad=True), '_net.2.weight': Parameter containing:
tensor([[ 0.3134,  0.0609, -0.2515, -0.1318],
        [ 0.0082,  0.3514, -0.3874,  0.0739],
        [ 0.4850, -0.3249,  0.4155, -0.1247]], requires_grad=True), '_net.2.bias': Parameter containing:
tensor([-0.0571, -0.2613,  0.0395], requires_grad=True)}
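One way to compute such a map in plain PyTorch (a sketch, not scio's implementation) is to filter the root module's named_parameters() by identity against the parameters of the recorded modules. The names in the output above carry a _net. prefix, presumably because the wrapped network is stored under that attribute; the sketch works on the bare network and records only the (1, 3) layer, so its names have no prefix.

```python
from torch import nn

toynet = nn.Sequential(nn.Linear(5, 4), nn.ReLU(), nn.Linear(4, 3), nn.ReLU())
recorded_modules = {(1, 3): toynet[2]}  # record only the last Linear layer

# Identities of every parameter owned by a recorded module (or its submodules).
recorded_ids = {id(p) for m in recorded_modules.values() for p in m.parameters()}
recorded_params = {
    name: p for name, p in toynet.named_parameters() if id(p) in recorded_ids
}
print(list(recorded_params))  # ['2.weight', '2.bias']
```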

recording

Tuple of layers which are being recorded.

Example

>>> rnet.recording
((0, 1), (1, 3))