小编Wan*_*han的帖子

如何在 PyTorch 中使用多处理?

我正在尝试使用具有复杂损失函数的 PyTorch。为了加速代码,我希望我可以使用 PyTorch 多处理包。

第一次试验,我将 10x1 的特征放入神经网络并得到 10x4 的输出。

之后,我想将 10x4 参数传递给一个函数来做一些计算。(以后的计算会很复杂。)

计算后,该函数将总共返回一个 10x1 的数组。该数组将设置为 NN_energy 并计算损失函数。

此外,我也想知道是否有另一种方法来创建一个向后数组来存储 NN_energy 数组,而不是使用

NN_energy = net(Data_in)[0:10,0]
Run Code Online (Sandbox Code Playgroud)

非常感谢。

完整代码:

import torch
import numpy as np
from torch.autograd import Variable 
from torch import multiprocessing

def func(msg,BOP):
    ans = (BOP[msg][0]+BOP[msg][1]/BOP[msg][2])*BOP[msg][3]
    return ans

class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden_1, n_hidden_2, n_output):
        super(Net, self).__init__()
        self.hidden_1 = torch.nn.Linear(n_feature , n_hidden_1)  # hidden layer
        self.hidden_2 = torch.nn.Linear(n_hidden_1, n_hidden_2)  # hidden layer
        self.predict  = torch.nn.Linear(n_hidden_2, n_output  )  # output layer

    def …
Run Code Online (Sandbox Code Playgroud)

pytorch

11
推荐指数
1
解决办法
2493
查看次数

标签 统计

pytorch ×1