我有一个类A定义了我所有的网络。我用 包裹这个torch.nn.DataParallel。当我将转发函数称为 时a(),它工作正常。但是,我还想调用 的一些其他函数A,同时仍然保留DataParallel功能。这可能吗?或者我只需要执行转发功能?
最小非工作示例(只是为了更好地传达上下文):
class A(torch.nn.module)
def __init__():
blah blah blah
def forward(some_arguments):
blah blah blah
def func1(some_arguments):
blah blah blah
a = A()
a = torch.nn.DataParallel(a, device_ids=[0, 1])
# calling forward function
outputs = a(inputs) # works fine.
# calling func1
outputs1 = a.func1(inputs) # does not work.
outputs1 = a.module.func1(inputs) # works without parallelizing data. I am not sure if this is the right thing to do
Run Code Online (Sandbox Code Playgroud)