Changing the output of a convolutional layer to a tuple of tensors


For processing video frames, I use a squeeze-and-excitation (SE) block to weight the channels of a convolutional layer. I want to combine (using torch.stack) the channels (feature maps) of the convolutional layer with the weighted channels produced by that SE block. However, when I call torch.stack(x, weighted_channels), where x holds the convolutional layer's channels, I get the following error: TypeError: stack(): argument 'tensors' (position 1) must be tuple of Tensors, not Tensor.

import torch
import torch.nn as nn
from torch.nn import Upsample

class conv(nn.Module):
    def __init__(self, in_channel, out_channel, out_sigmoid=False):
        super(conv, self).__init__()

        self.deconv = self._deconv(in_channel=512, out_channel=256, num_conv=3)
        self.upsample = Upsample(scale_factor=2, mode='bilinear')
        self.SEBlock = SE_Block(c=256)

    def _deconv(self, in_channel, out_channel, num_conv=2, kernel_size=3, stride=1, padding=1):
        layers = []
        layers.append(BasicConv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding))
        for i in range(1, num_conv):
            layers.append(_SepConv2d(out_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.deconv(x)
        x = self.upsample(x)
        stack = torch.stack(x, self.SEBlock(x))  # this call raises the TypeError
        return x



class BasicConv2d(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-3, momentum=0.001, affine=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

class _SepConv2d(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super(_SepConv2d, self).__init__()
        self.conv_s = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False, groups=in_planes)
        self.bn_s = nn.BatchNorm2d(out_planes)
        self.relu_s = nn.ReLU()

        self.conv_t = nn.Conv2d(out_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn_t = nn.BatchNorm2d(out_planes)
        self.relu_t = nn.ReLU()

    def forward(self, x):
        x = self.conv_s(x)
        x = self.bn_s(x)
        x = self.relu_s(x)

        x = self.conv_t(x)
        x = self.bn_t(x)
        x = self.relu_t(x)
        return x

class SE_Block(nn.Module):
    """credits: https://github.com/moskomule/senet.pytorch/blob/master/senet/se_module.py#L4"""
    def __init__(self, c, r=16):
        super().__init__()
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(c, c // r, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(c // r, c, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        bs, c, _, _ = x.shape
        y = self.squeeze(x).view(bs, c)
        y = self.excitation(y).view(bs, c, 1, 1)
        return x * y.expand_as(x)

I checked both arguments of torch.stack, and they are the same size.

CodePudding user response:

See https://pytorch.org/docs/stable/generated/torch.stack.html.

torch.stack(tensors, dim=0, *, out=None) → Tensor

  • tensors (sequence of Tensors) – sequence of tensors to concatenate

A sequence of tensors can be a tuple, such as (tensor1, tensor2, tensor3), or a list, such as [tensor1, tensor2, tensor3]. What you did was pass x, which is a single tensor rather than a sequence of tensors, as the tensors argument, and weighted_channels as the dim parameter.

So, as noted in the comments, either torch.stack((x, weighted_channels)) or torch.stack([x, weighted_channels]) should work.
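
Applied to the posted module, a minimal sketch of a corrected forward could look like the following. This assumes you want to stack the original and the SE-weighted feature maps along a new dimension; the variable names are illustrative, not taken from your code:

def forward(self, x):
    x = self.deconv(x)
    x = self.upsample(x)
    weighted_channels = self.SEBlock(x)                    # SE-weighted maps, same shape as x
    stacked = torch.stack((x, weighted_channels), dim=1)   # tuple of tensors -> (B, 2, C, H, W)
    return stacked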

Keep in mind that the same applies to all functions that take an arbitrary number of tensors and operate on them together, e.g. torch.cat and the other stacking functions such as torch.vstack and torch.hstack.
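
As a quick, self-contained illustration (with made-up shapes): all of these take a sequence of tensors, but torch.stack inserts a new dimension while torch.cat joins along an existing one:

import torch

a = torch.randn(4, 256, 32, 32)
b = torch.randn(4, 256, 32, 32)

print(torch.stack((a, b), dim=1).shape)  # torch.Size([4, 2, 256, 32, 32])
print(torch.cat((a, b), dim=1).shape)    # torch.Size([4, 512, 32, 32])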
