AttributeError:'numpy.ndarray'对象没有属性'unsqueeze'

Kha*_*oui 5 python-3.x pytorch numpy-ndarray

pyhtorch我正在使用和运行训练代码numpy

这是plot_example函数:

def plot_example(low_res_folder, gen):
    files=os.listdir(low_res_folder)
    
    gen.eval()
    for file in files:
        image=Image.open("test_images/" + file)
        with torch.no_grad():
            upscaled_img=gen(
                config1.both_transform(image=np.asarray(image))["image"]
                .unsqueeze(0)
                .to(config1.DEVICE)
            )
        save_image(upscaled_img * 0.5 + 0.5, f"saved/{file}")
    gen.train()
Run Code Online (Sandbox Code Playgroud)

我遇到的问题是该unsqueeze属性引发错误:

File "E:\Downloads\esrgan-tf2-masteren\modules\train1.py", line 58, in train_fn
    plot_example("test_images/", gen)
    
File "E:\Downloads\esrgan-tf2-masteren\modules\utils1.py", line 46, in plot_example
    config1.both_transform(image=np.asarray(image))["image"]
    
AttributeError: 'numpy.ndarray' object has no attribute 'unsqueeze'
Run Code Online (Sandbox Code Playgroud)

该网络是GAN网络,gen()代表生成器。

小智 5

在进入任何 Pytorch 层之前,确保图像是 [批量大小、通道、高度、宽度] 形状的张量。

在这里你有 image=np.asarray(image)

我会删除这个 numpy 转换并将其保留为 torch.tensor。

或者,如果您确实希望它成为 numpy 数组,那么在它进入生成器之前,请确保torch.from_numpy()在 numpy 图像被解压缩之前按照本文档中所示的方式使用: https: //pytorch.org/docs/stable/ generated /torch.from_numpy.html

如果您不想摆脱原始转换,那么这个函数当然是一个替代方案。

沙克·贾因