Nag*_*S N 5 pytorch dataparallel
我有一个类A定义了我所有的网络。我用 包裹这个torch.nn.DataParallel。当我将转发函数称为 时a(),它工作正常。但是,我还想调用 的一些其他函数A,同时仍然保留DataParallel功能。这可能吗?或者我只需要执行转发功能?
最小非工作示例(只是为了更好地传达上下文):
class A(torch.nn.module)
def __init__():
blah blah blah
def forward(some_arguments):
blah blah blah
def func1(some_arguments):
blah blah blah
a = A()
a = torch.nn.DataParallel(a, device_ids=[0, 1])
# calling forward function
outputs = a(inputs) # works fine.
# calling func1
outputs1 = a.func1(inputs) # does not work.
outputs1 = a.module.func1(inputs) # works without parallelizing data. I am not sure if this is the right thing to do
Run Code Online (Sandbox Code Playgroud)
您是否尝试过从内部而不是外部调用 func1 ?所以本质上,你会调用forward,而forward又会调用func1。如果您想有条件地调用 func1,您可以将函数名称作为参数传递给转发。这些建议也出现在此线程中https://discuss.pytorch.org/t/dataparallel-model-with-custom-functions/75053/10
| 归档时间: |
|
| 查看次数: |
223 次 |
| 最近记录: |