
modules.conv.AE2d

Class · nn.Module · Source

net = mdnc.modules.conv.AE2d(
    channel, layers,
    kernel_size=3, in_planes=1, out_planes=1
)

This module is a built-in model for a 2D convolutional auto-encoder. The network structure is almost the same as mdnc.modules.conv.UNet2d, but all block-level skip connections are removed. In general, mdnc.modules.conv.UNet2d should be the better choice.

The network down-samples and then up-samples the input data according to the network depth. The depth is given by the length of the argument layers. The network structure is shown in the following chart:

flowchart TB
    b1["Block 1<br>Stack of layers[0] layers"]
    b2["Block 2<br>Stack of layers[1] layers"]
    bi["Block ...<br>Stack of ... layers"]
    bn["Block n<br>Stack of layers[n-1] layers"]
    u1["Block 2n-1<br>Stack of layers[0] layers"]
    u2["Block 2n-2<br>Stack of layers[1] layers"]
    ui["Block ...<br>Stack of ... layers"]
    b1 -->|down<br>sampling| b2 -->|down<br>sampling| bi -->|down<br>sampling| bn
    bn -->|up<br>sampling| ui -->|up<br>sampling| u2 -->|up<br>sampling| u1
    linkStyle 0,1,2 stroke-width:4px, stroke:#800 ;
    linkStyle 3,4,5 stroke-width:4px, stroke:#080 ;

The argument layers is a sequence of int. Each block \(i\) contains layers[i-1] stacked modern convolutional layers (see mdnc.modules.conv.ConvModern2d). Each down-sampling or up-sampling is configured by stride=2. The channel number is doubled at each step of the down-sampling route and reduced to ½ at each step of the up-sampling route.
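As a rough illustration of the rule above, the following sketch lists the channel number of every block and the data size seen along the down-sampling route. Both helpers (block_channels and down_sampled_sizes) are hypothetical and not part of mdnc. The size rule, where a stride of 2 maps a length L to ceil(L / 2), is an assumption that happens to agree with the summary printed in the Examples section.

```python
import math


def block_channels(channel, layers):
    """Channel number of each of the 2n-1 blocks, in network order."""
    depth = len(layers)
    # Down-sampling route: the channel number doubles at every block.
    down = [channel * 2 ** i for i in range(depth)]
    # Up-sampling route: mirrors the down-sampling route without the deepest block.
    up = down[-2::-1]
    return down + up


def down_sampled_sizes(size, depth):
    """Data size seen by each block on the down-sampling route.

    Assumption: a stride of 2 maps a length L to ceil(L / 2).
    """
    sizes = [tuple(size)]
    for _ in range(depth - 1):
        sizes.append(tuple(math.ceil(length / 2) for length in sizes[-1]))
    return sizes


layers = [2, 2, 3, 3, 3]
print(block_channels(64, layers))
print(down_sampled_sizes((64, 63), len(layers)))
```

With channel=64 and five blocks, the deepest block works on 1024 channels, so the cost of the network grows quickly with the depth.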

Arguments

Requires

Argument Type Description
channel int The channel number of the first hidden block (layer). After each down-sampling, the channel number would be doubled. After each up-sampling, the channel number would be reduced to ½.
layers (int,) A sequence of layer numbers for each block. Each number represents the number of convolutional layers of a stage (block). The stage number, i.e. the depth of the network, is the length of this list.
kernel_size int or (int, int) The kernel size of each convolutional layer.
in_planes int The channel number of the input data.
out_planes int The channel number of the output data.

Operators

__call__

y = net(x)

The forward operator, implemented by the forward() method. The input is a batch of 2D data, and the output is the final output of this network.

Requires

Argument Type Description
x torch.Tensor A tensor of 2D data with the shape (B, C, L1, L2), where B is the batch size, C is the input channel number (in_planes), and (L1, L2) is the input data size.

Returns

Argument Description
y A tensor of 2D data with the shape (B, C, L1, L2), where B is the batch size, C is the output channel number (out_planes), and (L1, L2) is the same as the input data size.
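The shape contract described above can be sketched without building the network. The helper expected_output_shape is hypothetical and not part of mdnc; it only encodes the rule that the channel axis changes from in_planes to out_planes while the batch size and the data size are preserved.

```python
def expected_output_shape(x_shape, in_planes, out_planes):
    """Check an input shape (B, C, L1, L2) and return the expected output shape."""
    if len(x_shape) != 4:
        raise ValueError('x should have 4 axes: (B, C, L1, L2).')
    batch, chan, len_1, len_2 = x_shape
    if chan != in_planes:
        raise ValueError('The channel axis of x should match in_planes.')
    # Only the channel axis changes; the data size (L1, L2) is preserved.
    return (batch, out_planes, len_1, len_2)


print(expected_output_shape((4, 3, 64, 63), in_planes=3, out_planes=1))
```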

Properties

nlayers

net.nlayers

The total number of convolutional layers along the depth of the network.
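The value can be cross-checked by hand. The sketch below assumes, based only on the summary printed in the Examples section, that the count is made of one input convolution, every layer on the down-sampling route, every layer on the up-sampling route (which mirrors layers without the deepest block), and one output convolution. The function count_conv_layers is hypothetical and not part of mdnc.

```python
def count_conv_layers(layers):
    """Estimate nlayers from the layers argument (assumed counting rule)."""
    down = sum(layers)      # all blocks on the down-sampling route
    up = sum(layers[:-1])   # mirrored blocks, the deepest block is not repeated
    return 1 + down + up + 1  # plus the input and output convolutions


print(count_conv_layers([2, 2, 3, 3, 3]))  # 25
```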

Examples

Example
import mdnc

net = mdnc.modules.conv.AE2d(64, [2, 2, 3, 3, 3], in_planes=3, out_planes=3)
print('The number of convolutional layers along the depth is {0}.'.format(net.nlayers))
mdnc.contribs.torchsummary.summary(net, (3, 64, 63), device='cpu')
The number of convolutional layers along the depth is 25.
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 64, 64, 63]           4,800
    InstanceNorm2d-2           [-1, 64, 64, 63]             128
             PReLU-3           [-1, 64, 64, 63]              64
            Conv2d-4           [-1, 64, 64, 63]          36,864
     _ConvModernNd-5           [-1, 64, 64, 63]               0
    InstanceNorm2d-6           [-1, 64, 64, 63]             128
             PReLU-7           [-1, 64, 64, 63]              64
            Conv2d-8           [-1, 64, 32, 32]          36,864
     _ConvModernNd-9           [-1, 64, 32, 32]               0
  _BlockConvStkNd-10           [-1, 64, 32, 32]               0
   InstanceNorm2d-11           [-1, 64, 32, 32]             128
            PReLU-12           [-1, 64, 32, 32]              64
           Conv2d-13          [-1, 128, 32, 32]          73,728
    _ConvModernNd-14          [-1, 128, 32, 32]               0
   InstanceNorm2d-15          [-1, 128, 32, 32]             256
            PReLU-16          [-1, 128, 32, 32]             128
           Conv2d-17          [-1, 128, 16, 16]         147,456
    _ConvModernNd-18          [-1, 128, 16, 16]               0
  _BlockConvStkNd-19          [-1, 128, 16, 16]               0
   InstanceNorm2d-20          [-1, 128, 16, 16]             256
            PReLU-21          [-1, 128, 16, 16]             128
           Conv2d-22          [-1, 256, 16, 16]         294,912
    _ConvModernNd-23          [-1, 256, 16, 16]               0
   InstanceNorm2d-24          [-1, 256, 16, 16]             512
            PReLU-25          [-1, 256, 16, 16]             256
           Conv2d-26          [-1, 256, 16, 16]         589,824
    _ConvModernNd-27          [-1, 256, 16, 16]               0
   InstanceNorm2d-28          [-1, 256, 16, 16]             512
            PReLU-29          [-1, 256, 16, 16]             256
           Conv2d-30            [-1, 256, 8, 8]         589,824
    _ConvModernNd-31            [-1, 256, 8, 8]               0
  _BlockConvStkNd-32            [-1, 256, 8, 8]               0
   InstanceNorm2d-33            [-1, 256, 8, 8]             512
            PReLU-34            [-1, 256, 8, 8]             256
           Conv2d-35            [-1, 512, 8, 8]       1,179,648
    _ConvModernNd-36            [-1, 512, 8, 8]               0
   InstanceNorm2d-37            [-1, 512, 8, 8]           1,024
            PReLU-38            [-1, 512, 8, 8]             512
           Conv2d-39            [-1, 512, 8, 8]       2,359,296
    _ConvModernNd-40            [-1, 512, 8, 8]               0
   InstanceNorm2d-41            [-1, 512, 8, 8]           1,024
            PReLU-42            [-1, 512, 8, 8]             512
           Conv2d-43            [-1, 512, 4, 4]       2,359,296
    _ConvModernNd-44            [-1, 512, 4, 4]               0
  _BlockConvStkNd-45            [-1, 512, 4, 4]               0
   InstanceNorm2d-46            [-1, 512, 4, 4]           1,024
            PReLU-47            [-1, 512, 4, 4]             512
           Conv2d-48           [-1, 1024, 4, 4]       4,718,592
    _ConvModernNd-49           [-1, 1024, 4, 4]               0
   InstanceNorm2d-50           [-1, 1024, 4, 4]           2,048
            PReLU-51           [-1, 1024, 4, 4]           1,024
           Conv2d-52           [-1, 1024, 4, 4]       9,437,184
    _ConvModernNd-53           [-1, 1024, 4, 4]               0
   InstanceNorm2d-54           [-1, 1024, 4, 4]           2,048
            PReLU-55           [-1, 1024, 4, 4]           1,024
         Upsample-56           [-1, 1024, 8, 8]               0
           Conv2d-57            [-1, 512, 8, 8]       4,718,592
    _ConvModernNd-58            [-1, 512, 8, 8]               0
  _BlockConvStkNd-59            [-1, 512, 8, 8]               0
   InstanceNorm2d-60            [-1, 512, 8, 8]           1,024
            PReLU-61            [-1, 512, 8, 8]             512
           Conv2d-62            [-1, 512, 8, 8]       2,359,296
    _ConvModernNd-63            [-1, 512, 8, 8]               0
   InstanceNorm2d-64            [-1, 512, 8, 8]           1,024
            PReLU-65            [-1, 512, 8, 8]             512
           Conv2d-66            [-1, 512, 8, 8]       2,359,296
    _ConvModernNd-67            [-1, 512, 8, 8]               0
   InstanceNorm2d-68            [-1, 512, 8, 8]           1,024
            PReLU-69            [-1, 512, 8, 8]             512
         Upsample-70          [-1, 512, 16, 16]               0
           Conv2d-71          [-1, 256, 16, 16]       1,179,648
    _ConvModernNd-72          [-1, 256, 16, 16]               0
  _BlockConvStkNd-73          [-1, 256, 16, 16]               0
   InstanceNorm2d-74          [-1, 256, 16, 16]             512
            PReLU-75          [-1, 256, 16, 16]             256
           Conv2d-76          [-1, 256, 16, 16]         589,824
    _ConvModernNd-77          [-1, 256, 16, 16]               0
   InstanceNorm2d-78          [-1, 256, 16, 16]             512
            PReLU-79          [-1, 256, 16, 16]             256
           Conv2d-80          [-1, 256, 16, 16]         589,824
    _ConvModernNd-81          [-1, 256, 16, 16]               0
   InstanceNorm2d-82          [-1, 256, 16, 16]             512
            PReLU-83          [-1, 256, 16, 16]             256
         Upsample-84          [-1, 256, 32, 32]               0
           Conv2d-85          [-1, 128, 32, 32]         294,912
    _ConvModernNd-86          [-1, 128, 32, 32]               0
  _BlockConvStkNd-87          [-1, 128, 32, 32]               0
   InstanceNorm2d-88          [-1, 128, 32, 32]             256
            PReLU-89          [-1, 128, 32, 32]             128
           Conv2d-90          [-1, 128, 32, 32]         147,456
    _ConvModernNd-91          [-1, 128, 32, 32]               0
   InstanceNorm2d-92          [-1, 128, 32, 32]             256
            PReLU-93          [-1, 128, 32, 32]             128
         Upsample-94          [-1, 128, 64, 64]               0
           Conv2d-95           [-1, 64, 64, 64]          73,728
    _ConvModernNd-96           [-1, 64, 64, 64]               0
  _BlockConvStkNd-97           [-1, 64, 64, 64]               0
   InstanceNorm2d-98           [-1, 64, 64, 63]             128
            PReLU-99           [-1, 64, 64, 63]              64
          Conv2d-100           [-1, 64, 64, 63]          36,864
   _ConvModernNd-101           [-1, 64, 64, 63]               0
  InstanceNorm2d-102           [-1, 64, 64, 63]             128
           PReLU-103           [-1, 64, 64, 63]              64
          Conv2d-104           [-1, 64, 64, 63]          36,864
   _ConvModernNd-105           [-1, 64, 64, 63]               0
 _BlockConvStkNd-106           [-1, 64, 64, 63]               0
          Conv2d-107            [-1, 3, 64, 63]           4,803
            AE2d-108            [-1, 3, 64, 63]               0
================================================================
Total params: 34,241,859
Trainable params: 34,241,859
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.05
Forward/backward pass size (MB): 79.62
Params size (MB): 130.62
Estimated Total Size (MB): 210.29
----------------------------------------------------------------
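The size figures in the summary follow from simple arithmetic. The sketch below assumes that every value is stored as a 32-bit float (4 bytes) and that 1 MB is counted as 1024² bytes. The helper megabytes is hypothetical and only reproduces the numbers printed above.

```python
def megabytes(num_values, bytes_per_value=4):
    """Convert a number of stored values into MB (assuming float32 by default)."""
    return num_values * bytes_per_value / 1024 ** 2


params_mb = megabytes(34_241_859)  # total number of parameters
input_mb = megabytes(3 * 64 * 63)  # one input sample of size (3, 64, 63)

print('Params size (MB): {0:.2f}'.format(params_mb))  # 130.62
print('Input size (MB): {0:.2f}'.format(input_mb))    # 0.05
```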

Last update: March 14, 2021
