Make torch tensors of different sizes equal

Jak*_*sig 5 python machine-learning padding pytorch tensor

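I'm looking for a way to take a batch of images/targets for segmentation and return a batch in which every image has been resized (padded) to the same dimensions across the whole batch. I tried this with the code below: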

import torch

def collate_fn_padd(batch):
    '''
    Pads a batch of variable length.

    Note: it converts things ToTensor manually here since the ToTensor transform
    assumes it takes in images rather than arbitrary tensors.
    '''
    # separate the image and masks
    image_batch,mask_batch = zip(*batch)

    # pad the images and masks
    image_batch = torch.nn.utils.rnn.pad_sequence(image_batch, batch_first=True)
    mask_batch = torch.nn.utils.rnn.pad_sequence(mask_batch, batch_first=True)

    # rezip the batch
    batch = list(zip(image_batch, mask_batch))

    return batch

However, I get this error:

RuntimeError: The expanded size of the tensor (650) must match the existing size (439) at non-singleton dimension 2.  Target sizes: [3, 650, 650].  Tensor sizes: [3, 406, 439]

How can I efficiently pad the tensors to equal size and avoid this problem?

Mic*_*ngo 11

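rnn.pad_sequence only pads the sequence dimension (the first dimension of each tensor), and it requires all other dimensions to be equal. You cannot use it to pad images across two dimensions (height and width). In the question, every image has 3 channels in dimension 0, so there is nothing to pad there, while the differing heights and widths ([3, 406, 439] vs. [3, 650, 650]) trigger the error.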
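To see this concretely, here is a minimal sketch with small made-up tensors (the shapes are illustrative, not taken from the question): pad_sequence succeeds when only dimension 0 differs, but raises a RuntimeError when the trailing height and width differ, as in the question.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two "sequences" whose lengths (dim 0) differ but whose trailing dims match.
a = torch.ones(2, 3)
b = torch.ones(4, 3)
padded = pad_sequence([a, b], batch_first=True)
print(padded.shape)  # torch.Size([2, 4, 3]): only dim 0 was padded (with zeros)

# Two illustrative "images" of shape [channels, height, width]
# with the same channel count but different height and width.
img1 = torch.zeros(3, 5, 6)
img2 = torch.zeros(3, 4, 7)
try:
    pad_sequence([img1, img2], batch_first=True)
    failed = False
except RuntimeError as err:
    # Dim 0 (channels) already matches; pad_sequence never pads
    # height/width, so the mismatched trailing dims cause an error.
    failed = True
    print("RuntimeError:", err)
```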

Images can be padded with torch.nn.functional.pad instead, but you need to work out the required height and width padding manually.

import torch.nn.functional as F

# Determine maximum height and width
# The masks have the same height and width
# as the images they mask.
max_height = max([img.size(1) for img in image_batch])
max_width = max([img.size(2) for img in image_batch])

image_batch = [
    # The needed padding is the difference between the
    # max width/height and the image's actual width/height.
    F.pad(img, [0, max_width - img.size(2), 0, max_height - img.size(1)])
    for img in image_batch
]
mask_batch = [
    # Same as for the images, but masks have no channel dimension,
    # so the mask's width is dimension 1 instead of 2 (and its height is dimension 0)
    F.pad(mask, [0, max_width - mask.size(1), 0, max_height - mask.size(0)])
    for mask in mask_batch
]
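Put together as a complete collate function, the padding might look like the sketch below. The names pad_collate and mask_pad_value are hypothetical, and it assumes each dataset item is an (image, mask) tuple with the image shaped [channels, height, width] and the mask shaped [height, width]. Unlike the question's function, it stacks the padded tensors into batch tensors instead of returning a list of pairs; that is a design choice, not something the answer prescribes.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader


def pad_collate(batch, mask_pad_value=0):
    """Hypothetical collate_fn: pad (image, mask) pairs to the batch's max H and W.

    Assumes image: [channels, height, width], mask: [height, width].
    Returns images: [B, C, H_max, W_max] and masks: [B, H_max, W_max].
    """
    image_batch, mask_batch = zip(*batch)

    # Largest height and width in this batch (masks share the images' H and W).
    max_height = max(img.size(1) for img in image_batch)
    max_width = max(img.size(2) for img in image_batch)

    # [left, right, top, bottom]: pad only on the right and at the bottom.
    images = torch.stack([
        F.pad(img, [0, max_width - img.size(2), 0, max_height - img.size(1)])
        for img in image_batch
    ])
    # Masks use a configurable fill value (assumption: 0 by default).
    masks = torch.stack([
        F.pad(mask, [0, max_width - mask.size(1), 0, max_height - mask.size(0)],
              value=mask_pad_value)
        for mask in mask_batch
    ])
    return images, masks


# Usage with a small synthetic dataset (a plain list works as a map-style dataset).
dataset = [
    (torch.rand(3, 406, 439), torch.zeros(406, 439, dtype=torch.long)),
    (torch.rand(3, 650, 650), torch.zeros(650, 650, dtype=torch.long)),
]
loader = DataLoader(dataset, batch_size=2, collate_fn=pad_collate)
images, masks = next(iter(loader))
print(images.shape, masks.shape)
# torch.Size([2, 3, 650, 650]) torch.Size([2, 650, 650])
```

If 0 is a real class label in the masks, a different fill value (for example an index the loss ignores) could be passed through mask_pad_value, e.g. with functools.partial(pad_collate, mask_pad_value=255).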

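The padding sizes are specified in reverse order of the dimensions, with two values per dimension: one for the padding at the beginning and one for the padding at the end. For an image with dimensions [channels, height, width], the padding is [width_beginning, width_end, height_beginning, height_end], which can be rewritten as [left, right, top, bottom]. Therefore, the code above pads the images on the right and at the bottom. The channels are omitted because they are not padded, which also means the same padding can be applied directly to the masks.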
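A small worked example of this ordering, using a made-up 1×2×2 tensor of ones: the padding list [0, 1, 0, 2] means zero columns on the left, one on the right, zero rows on top and two at the bottom, and the same list works unchanged on a 2-D mask.

```python
import torch
import torch.nn.functional as F

# A 1x2x2 "image" ([channels, height, width]) filled with ones.
img = torch.ones(1, 2, 2)

# [left, right, top, bottom] = [0, 1, 0, 2]
padded = F.pad(img, [0, 1, 0, 2])
print(padded.shape)      # torch.Size([1, 4, 3])
print(padded[0].tolist())
# [[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]

# The same padding list applies to a 2-D mask of shape [height, width],
# because it only refers to the last two dimensions.
mask = torch.ones(2, 2)
padded_mask = F.pad(mask, [0, 1, 0, 2])
print(padded_mask.shape)  # torch.Size([4, 3])
```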