PyTorch:如何对多个图像应用相同的随机变换?

TFC*_*TFC 12 python pytorch torchvision

我正在为包含许多图像对的数据集编写一个简单的转换。作为数据增强,我想对每一对应用一些随机变换,但该对中的图像应该以相同的方式进行变换。\n例如,给定一对两个图像AB,如果A水平翻转,则B必须水平翻转作为A。那么下一对CD应该与A和进行不同的变换B,但是C和 也D以相同的方式进行变换。我正在尝试用下面的方式

\n
import random\nimport numpy as np\nimport torchvision.transforms as transforms\nfrom PIL import Image\n\nimg_a = Image.open("sample_ajpg") # note that two images have the same size\nimg_b = Image.open("sample_b.png")\nimg_c, img_d = Image.open("sample_c.jpg"), Image.open("sample_d.png")\n\ntransform = transforms.RandomChoice(\n    [transforms.RandomHorizontalFlip(), \n     transforms.RandomVerticalFlip()]\n)\nrandom.seed(0)\ndisplay(transform(img_a))\ndisplay(transform(img_b))\n\nrandom.seed(1)\ndisplay(transform(img_c))\ndisplay(transform(img_d))\n
Run Code Online (Sandbox Code Playgroud)\n

然而\xe3\x80\x81上面的代码没有选择相同的转换,并且根据我的测试,它取决于调用的次数transform

\n

有没有办法transforms.RandomChoice在指定时强制使用相同的转换?

\n

Iva*_*van 11

通常的解决方法是在第一张图像上应用变换,检索该变换的参数,然后在其余图像上应用这些参数的确定性变换。但是,这里RandomChoice不提供 API 来获取所应用变换的参数,因为它涉及可变数量的变换。在这些情况下,我通常会覆盖原始函数。

看看torchvision 的实现,它很简单:

class RandomChoice(RandomTransforms):
    def __call__(self, img):
        t = random.choice(self.transforms)
        return t(img)
Run Code Online (Sandbox Code Playgroud)

这里有两种可能的解决方案。

  1. 您可以从转换列表中进行采样__init__不是 on __call__

    import random
    import torchvision.transforms as T
    
    class RandomChoice(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.t = random.choice(self.transforms)
    
        def __call__(self, img):
            return self.t(img)
    
    Run Code Online (Sandbox Code Playgroud)

    所以你可以这样做:

    transform = RandomChoice([
         T.RandomHorizontalFlip(), 
         T.RandomVerticalFlip()
    ])
    display(transform(img_a)) # both img_a and img_b will
    display(transform(img_b)) # have the same transform
    
    transform = RandomChoice([
        T.RandomHorizontalFlip(), 
        T.RandomVerticalFlip()
    ])
    display(transform(img_c)) # both img_c and img_d will
    display(transform(img_d)) # have the same transform
    
    Run Code Online (Sandbox Code Playgroud)

  1. 或者更好的是,批量转换图像:

    import random
    import torchvision.transforms as T
    
    class RandomChoice(torch.nn.Module):
        def __init__(self, transforms):
           super().__init__()
           self.transforms = transforms
    
        def __call__(self, imgs):
            t = random.choice(self.transforms)
            return [t(img) for img in imgs]
    
    Run Code Online (Sandbox Code Playgroud)

    这允许执行以下操作:

    transform = RandomChoice([
         T.RandomHorizontalFlip(), 
         T.RandomVerticalFlip()
    ])
    
    img_at, img_bt = transform([img_a, img_b])
    display(img_at) # both img_a and img_b will
    display(img_bt) # have the same transform
    
    img_ct, img_dt = transform([img_c, img_d])
    display(img_ct) # both img_c and img_d will
    display(img_dt) # have the same transform
    
    Run Code Online (Sandbox Code Playgroud)

  • 我认为它应该是 RandomChoice() 而不是 T.RandomChoice() 否则它会调用 torchvision.transforms 的 RandomChoice 类。另外,当我用 RandomRotate 尝试此方法时,它不起作用。因为它只是从您列出的转换列表中随机选择一个转换,而不是在这些转换中。例如,如果您有一组图像需要以相同的方式进行增强,则不幸的是,此方法不起作用,因为它们仍然可能会随机程度地进行变换。 (4认同)