TFC*_*TFC 12 python pytorch torchvision
我正在为包含许多图像对的数据集编写一个简单的转换。作为数据增强,我想对每一对应用一些随机变换,但该对中的图像应该以相同的方式进行变换。\n例如,给定一对两个图像A和B,如果A水平翻转,则B必须水平翻转作为A。那么下一对C和D应该与A和进行不同的变换B,但是C和 也D以相同的方式进行变换。我正在尝试用下面的方式
import random\nimport numpy as np\nimport torchvision.transforms as transforms\nfrom PIL import Image\n\nimg_a = Image.open("sample_ajpg") # note that two images have the same size\nimg_b = Image.open("sample_b.png")\nimg_c, img_d = Image.open("sample_c.jpg"), Image.open("sample_d.png")\n\ntransform = transforms.RandomChoice(\n [transforms.RandomHorizontalFlip(), \n transforms.RandomVerticalFlip()]\n)\nrandom.seed(0)\ndisplay(transform(img_a))\ndisplay(transform(img_b))\n\nrandom.seed(1)\ndisplay(transform(img_c))\ndisplay(transform(img_d))\nRun Code Online (Sandbox Code Playgroud)\n然而\xe3\x80\x81上面的代码没有选择相同的转换,并且根据我的测试,它取决于调用的次数transform。
有没有办法transforms.RandomChoice在指定时强制使用相同的转换?
Iva*_*van 11
通常的解决方法是在第一张图像上应用变换,检索该变换的参数,然后在其余图像上应用这些参数的确定性变换。但是,这里RandomChoice不提供 API 来获取所应用变换的参数,因为它涉及可变数量的变换。在这些情况下,我通常会覆盖原始函数。
看看torchvision 的实现,它很简单:
class RandomChoice(RandomTransforms):
def __call__(self, img):
t = random.choice(self.transforms)
return t(img)
Run Code Online (Sandbox Code Playgroud)
这里有两种可能的解决方案。
您可以从转换列表中进行采样__init__不是 on __call__:
import random
import torchvision.transforms as T
class RandomChoice(torch.nn.Module):
def __init__(self):
super().__init__()
self.t = random.choice(self.transforms)
def __call__(self, img):
return self.t(img)
Run Code Online (Sandbox Code Playgroud)
所以你可以这样做:
transform = RandomChoice([
T.RandomHorizontalFlip(),
T.RandomVerticalFlip()
])
display(transform(img_a)) # both img_a and img_b will
display(transform(img_b)) # have the same transform
transform = RandomChoice([
T.RandomHorizontalFlip(),
T.RandomVerticalFlip()
])
display(transform(img_c)) # both img_c and img_d will
display(transform(img_d)) # have the same transform
Run Code Online (Sandbox Code Playgroud)
或者更好的是,批量转换图像:
import random
import torchvision.transforms as T
class RandomChoice(torch.nn.Module):
def __init__(self, transforms):
super().__init__()
self.transforms = transforms
def __call__(self, imgs):
t = random.choice(self.transforms)
return [t(img) for img in imgs]
Run Code Online (Sandbox Code Playgroud)
这允许执行以下操作:
transform = RandomChoice([
T.RandomHorizontalFlip(),
T.RandomVerticalFlip()
])
img_at, img_bt = transform([img_a, img_b])
display(img_at) # both img_a and img_b will
display(img_bt) # have the same transform
img_ct, img_dt = transform([img_c, img_d])
display(img_ct) # both img_c and img_d will
display(img_dt) # have the same transform
Run Code Online (Sandbox Code Playgroud)