enu*_*ris 5 python multiprocessing pytorch
我正在尝试让一些东西与 keras 的“fit_generator”方法类似。基本上,我有一个(非常)大的小批量数据文件,我想让我的 CPU 抓取小批量数据并填充与我的 GPU 并行的队列,从队列中获取小批量数据并对其进行训练。通过让 CPU 与 GPU 并行工作(而不是让 CPU 抓取一个批次并让 GPU 在该批次上进行训练之前等待 CPU),我应该能够将训练时间减少大约一半。我对 CPU 获取一个小批量所需的时间进行了基准测试,并且它所花费的时间与我的 GPU 训练一个小批量所需的时间相当,因此 CPU 和 GPU 的并行化应该可以正常工作。我还没有在 pytorch 中找到内置方法来执行此操作,如果有,请告诉我。
所以我尝试使用 torch.multiprocessing 模块来做我想做的事情,但我无法完成训练,因为我总是在训练完成之前遇到某种错误。torch.multiprocessing 模块应该是一个包装器,其功能与常规多处理模块基本相同,只是它允许在进程之间共享 pytorch 张量。基本上,我已经将代码设置为具有 2 个函数,一个加载器函数和一个训练器函数,如下所示:
def data_gen(que,PATH,epochs,steps_per_epoch,batch_size=32):
for epoch in range(epochs):
for j in range(steps_per_epoch):
with h5py.File(PATH,'r') as f:
X = f['X'][j*batch_size:(j+1)*batch_size]
Y = f['Y'][j*batch_size:(j+1)*batch_size]
X = autograd.Variable(torch.Tensor(X).resize_(batch_size,256,25).cpu())
Y = autograd.Variable(torch.Tensor(Y).cpu())
que.put((X,Y))
que.put('stop')
que.close()
return
def train_network(que,net,optimizer,epochs):
print('Training for %s epochs...' %epochs)
for epoch in range(epochs):
while(True):
data = que.get()
if(data == 'stop'):
break
net.zero_grad()
net.hid = net.init_hid()
inp,labels = data
inp = inp.cuda()
labels = labels.cuda()
out,hid = net(inp)
loss = F.binary_cross_entropy(out,labels)
loss.backward()
optimizer.step()
print('Epoch end reached')
return
Run Code Online (Sandbox Code Playgroud)
然后我并行运行两个进程,如下所示:
if __name__ == '__main__':
tmp.set_start_method('spawn')
que = tmp.Queue(maxsize=10)
loader = tmp.Process(target=data_gen, args=(que,PATH,epochs,steps), kwargs={'batch_size':batch_size})
loader.start()
trainer = tmp.Process(target=train_network, args=(que,net,optimizer,epochs,steps))
trainer.start()
loader.join()
trainer.join()
Run Code Online (Sandbox Code Playgroud)
我在每个 epoch 结束时将 que 放入“停止”值,这样我就可以跳出训练器中的循环并进入下一个 epoch。这种“毒丸”方法似乎有效,因为代码运行了多个时期,并且训练器实际上打印了时期验证消息的结束。代码运行,并且它似乎确实加快了训练过程(我一直在尝试在一小部分数据上构建此代码的原型,因此有时很难判断我获得了多少速度),但是在训练结束时(并且总是在最后,无论我指定多少个时期),我总是会收到错误:
Process Process-2:
Traceback (most recent call last):
File "/usr/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
self.run()
File "/usr/lib/python3.6/multiprocessing/process.py", line 93, in run
self._target(*self._args, **self._kwargs)
File "/media/digitalstorm/Storage/RNN_Prototype/Lazuli_rnnprototype.py", line 307, in train_network
data = que.get()
File "/usr/lib/python3.6/multiprocessing/queues.py", line 113, in get
return _ForkingPickler.loads(res)
File "/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/reductions.py", line 70, in rebuild_storage_fd
fd = df.detach()
File "/usr/lib/python3.6/multiprocessing/resource_sharer.py", line 57, in detach
with _resource_sharer.get_connection(self._id) as conn:
File "/usr/lib/python3.6/multiprocessing/resource_sharer.py", line 87, in get_connection
c = Client(address, authkey=process.current_process().authkey)
File "/usr/lib/python3.6/multiprocessing/connection.py", line 487, in Client
c = SocketClient(address)
File "/usr/lib/python3.6/multiprocessing/connection.py", line 614, in SocketClient
s.connect(address)
FileNotFoundError: [Errno 2] No such file or directory
Run Code Online (Sandbox Code Playgroud)
或者,如果我对各种选项进行了一些修改,有时会收到如下错误:
Process Process-2:
Traceback (most recent call last):
File "/usr/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
self.run()
File "/usr/lib/python3.6/multiprocessing/process.py", line 93, in run
self._target(*self._args, **self._kwargs)
File "/media/digitalstorm/Storage/RNN_Prototype/Lazuli_rnnprototype.py", line 306, in train_network
data = que.get()
File "/usr/lib/python3.6/multiprocessing/queues.py", line 113, in get
return _ForkingPickler.loads(res)
File "/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/reductions.py", line 70, in rebuild_storage_fd
fd = df.detach()
File "/usr/lib/python3.6/multiprocessing/resource_sharer.py", line 58, in detach
return reduction.recv_handle(conn)
File "/usr/lib/python3.6/multiprocessing/reduction.py", line 182, in recv_handle
return recvfds(s, 1)[0]
File "/usr/lib/python3.6/multiprocessing/reduction.py", line 153, in recvfds
msg, ancdata, flags, addr = sock.recvmsg(1, socket.CMSG_LEN(bytes_size))
ConnectionResetError: [Errno 104] Connection reset by peer
Run Code Online (Sandbox Code Playgroud)
我不知道我哪里错了。诚然,我是多处理方面的新手,所以我很难准确地调试出了什么问题。任何帮助将不胜感激,谢谢!
由于这个问题还没有任何进展,我将发布我自己的解决方法来解决这个问题。基本上,加载进程在完成处理并将示例排队后关闭队列。它没有等待训练器进程完成,因此当训练器进程要获取下一个小批量时,它找不到它。我不太清楚为什么加载程序进程过早关闭队列,文档que.close()说这应该只告诉队列没有更多对象被发送到队列,但它实际上不应该关闭队列。另外,删除并que.close()没有解决问题,所以我认为该错误与该命令无关。对我来说解决这个问题的方法是在命令后面添加一个time.sleep(2)命令que.close()。这会强制队列在将所有内容放入队列后休眠几秒钟,并允许程序完成并退出而不会出现错误。
| 归档时间: |
|
| 查看次数: |
7934 次 |
| 最近记录: |