有没有办法检索随机 torchvision 变换中使用的特定参数?

Nic*_*ill 5 python affinetransform pytorch torchvision data-augmentation

我可以在训练期间通过应用随机变换(旋转/平移/重新缩放)来增强数据,但我不知道选择的值。

我需要知道应用了哪些值。我可以手动设置这些值,但这样我就失去了火炬视觉变换提供的很多好处。

有没有一种简单的方法可以让这些价值观以合理的方式实现并在培训期间应用?

这是一个例子。我希望能够打印出每个图像上应用的旋转角度、平移/重新缩放:

import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms


RandAffine = transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.8, 1.2))

rotate = transforms.RandomRotation(degrees=45)
shift = RandAffine
composed = transforms.Compose([rotate,
                               shift])

# Apply each of the above transforms on sample.
fig = plt.figure()
sample = np.zeros((28,28))
sample[5:15,7:20] = 255
sample = transforms.ToPILImage()(sample.astype(np.uint8))
title = ['None', 'Rot','Aff','Comp']
for i, tsfrm in enumerate([None,rotate, shift, composed]):
    if tsfrm:
        t_sample = tsfrm(sample)
    else:
        t_sample = sample
    ax = plt.subplot(1, 5, i + 2)
    plt.tight_layout()
    ax.set_title(title[i])
    ax.imshow(np.reshape(np.array(list(t_sample.getdata())), (-1,28)), cmap='gray')    

plt.show()

Run Code Online (Sandbox Code Playgroud)

Iva*_*van 5

恐怕没有简单的方法可以解决这个问题:Torchvision 的随机变换实用程序的构建方式是在调用时对变换参数进行采样。它们是独特的随机变换,因为(1)用户无法访问所使用的参数,并且(2)相同的随机变换不可重复。

从 Torchvision 0.8.0开始,随机变换通常由两个主要函数构建:

  • get_params:它将根据变换的超参数进行采样(初始化变换运算符时提供的内容,即参数的值范围)

  • forward:应用变换时执行的函数。重要的是它获取参数,get_params然后使用关联的确定性函数将其应用于输入。对于RandomRotationF.rotate将被调用。同样,RandomAffine将使用F.affine.

解决您的问题的一种方法是从您自己那里采样参数get_params并调用函数式确定性API。因此,您不会使用RandomRotationRandomAffine,也不会使用任何其他Random*转换。


例如,让我们看一下T.RandomRotation(为了简洁我删除了注释)。

class RandomRotation(torch.nn.Module):
    def __init__(
        self, degrees, interpolation=InterpolationMode.NEAREST, expand=False, 
        center=None, fill=None, resample=None):
        # ...

    @staticmethod
    def get_params(degrees: List[float]) -> float:
        angle = float(torch.empty(1).uniform_(float(degrees[0]), \
            float(degrees[1])).item())
        return angle

    def forward(self, img):
        fill = self.fill
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * F._get_image_num_channels(img)
            else:
                fill = [float(f) for f in fill]
        angle = self.get_params(self.degrees)

        return F.rotate(img, angle, self.resample, self.expand, self.center, fill)

    def __repr__(self):
        # ...
Run Code Online (Sandbox Code Playgroud)

考虑到这一点,这里是一个可能需要修改的覆盖T.RandomRotation

class RandomRotation(T.RandomRotation):
    def __init__(*args, **kwargs):
        super(RandomRotation, self).__init__(*args, **kwargs) # let super do all the work

        self.angle = self.get_params(self.degrees) # initialize your random parameters

    def forward(self): # override T.RandomRotation's forward
        fill = self.fill
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * F._get_image_num_channels(img)
            else:
                fill = [float(f) for f in fill]

        return F.rotate(img, self.angle, self.resample, self.expand, self.center, fill)
Run Code Online (Sandbox Code Playgroud)

我基本上复制了T.RandomRotationforward函数,唯一的区别是参数是在__init__一次)而不是内部forward每次调用时)采样的。Torchvision 的实现涵盖了所有情况,您通常不需要复制完整的forward. 在某些情况下,您几乎可以立即调用功能版本。例如,如果您不需要设置参数fill,则可以丢弃该部分并仅使用:

class RandomRotation(T.RandomRotation):
    def __init__(*args, **kwargs):
        super(RandomRotation, self).__init__(*args, **kwargs) # let super do all the work

        self.angle = self.get_params(self.degrees) # initialize your random parameters

    def forward(self): # override T.RandomRotation's forward
        return F.rotate(img, self.angle, self.resample, self.expand, self.center)
Run Code Online (Sandbox Code Playgroud)

如果您想覆盖其他随机变换,您可以查看源代码。该 API 是相当不言自明的,为每个转换实现覆盖时不应该遇到太多问题。