是否可以在pytorch中的transform.compose中添加自己的函数

Anu*_*eep 2 machine-learning neural-network python-3.x deep-learning pytorch

我正在使用预先训练的 Alex 模型。我正在一些随机图像数据集上运行这个模型。我想在训练之前将 RGB 图像转换为 YCbCr 图像。

我想知道是否可以自己添加一个功能transform.compose,例如:

transform = transforms.Compose([
  ycbcr(), #something like this
  transforms.Resize((224, 224)),
  transforms.ToTensor(),
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
Run Code Online (Sandbox Code Playgroud)

在哪里,

def ycbcr(img):
   img = cv2.imread(img)  
   img = cv2.cvtColor(img, cv2.COLOR_BGR2ycbcr)
   t = torch.from_numpy(img)
 return t

training_dataset = datasets.ImageFolder(link_train ,transform = transform_train)

training_loader = torch.utils.data.DataLoader(training_dataset, batch_size=96, shuffle=True)
Run Code Online (Sandbox Code Playgroud)

这个过程正确吗?请帮助我如何继续?

Dav*_*d S 7

torchvision.transform您可以通过定义类来传递自定义转换。

为了更好地理解,我建议您阅读文档

在你的情况下,它将类似于以下内容:

class ycbcr(object):
    def __call__(self, img):
        """
        :param img: (PIL): Image 

        :return: ycbr color space image (PIL)
        """
        img = cv2.imread(img)  
        img = cv2.cvtColor(img, cv2.COLOR_BGR2ycbcr)
        # t = torch.from_numpy(img)

        return Image.fromarray(t)

    def __repr__(self):
        return self.__class__.__name__+'()'
Run Code Online (Sandbox Code Playgroud)

请注意,它获取一个 PIL 图像并返回一个 PIL 图像。所以你可能想正确调整你的代码。但这是定义自定义转换的一般方法。