我试图在 CNN 上实现一个 pytorch 框架。
我确信代码是正确的,因为它来自教程,并且当我在 GoogleDrive 上的 Jupyter Notebook 上运行它时它可以工作。
但是当我尝试将其本地化为.py文件时,它提示一个错误:
AttributeError: Can't pickle local object 'pre_datasets.<locals>.<lambda>'
我知道它是关于函数外部的推断对象,但是这个错误的确切原因是什么?
我应该如何解决它?
这是代码的主要部分。
def pre_datasets():
TRAIN_TFM = transforms.Compose(
[
transforms.Resize(size=(128, 128)),
# TODO
transforms.ToTensor(),
]
)
train_set = DatasetFolder(
root=CONFIG["train_set_path"],
loader=lambda x: Image.open(x),
extensions="jpg",
transform=TRAIN_TFM,
)
train_loader = DataLoader(
dataset=train_set,
batch_size=CONFIG["batch_size"],
shuffle=True,
num_workers=CONFIG["num_workers"],
pin_memory=True,
)
return train_loader
def train(train_loader):
...
for epoch in range(CONFIG["num_epochs"]):
...
for batch in train_loader: # error happened here
...
if __name__ == "__main__":
train_loader = …Run Code Online (Sandbox Code Playgroud)