小编Kha*_*oui的帖子

AttributeError:'numpy.ndarray'对象没有属性'unsqueeze'

pyhtorch我正在使用和运行训练代码numpy

这是plot_example函数:

def plot_example(low_res_folder, gen):
    files=os.listdir(low_res_folder)
    
    gen.eval()
    for file in files:
        image=Image.open("test_images/" + file)
        with torch.no_grad():
            upscaled_img=gen(
                config1.both_transform(image=np.asarray(image))["image"]
                .unsqueeze(0)
                .to(config1.DEVICE)
            )
        save_image(upscaled_img * 0.5 + 0.5, f"saved/{file}")
    gen.train()
Run Code Online (Sandbox Code Playgroud)

我遇到的问题是该unsqueeze属性引发错误:

File "E:\Downloads\esrgan-tf2-masteren\modules\train1.py", line 58, in train_fn
    plot_example("test_images/", gen)
    
File "E:\Downloads\esrgan-tf2-masteren\modules\utils1.py", line 46, in plot_example
    config1.both_transform(image=np.asarray(image))["image"]
    
AttributeError: 'numpy.ndarray' object has no attribute 'unsqueeze'
Run Code Online (Sandbox Code Playgroud)

该网络是GAN网络,gen()代表生成器。

python-3.x pytorch numpy-ndarray

5
推荐指数
1
解决办法
1万
查看次数

标签 统计

numpy-ndarray ×1

python-3.x ×1

pytorch ×1