调用用 DataParallel 包装的 torch.nn.module 类的函数

Nag*_*S N 5 pytorch dataparallel

我有一个类A定义了我所有的网络。我用 包裹这个torch.nn.DataParallel。当我将转发函数称为 时a(),它工作正常。但是,我还想调用 的一些其他函数A,同时仍然保留DataParallel功能。这可能吗?或者我只需要执行转发功能?

最小非工作示例(只是为了更好地传达上下文):

class A(torch.nn.module)
    def __init__():
        blah blah blah
    
    def forward(some_arguments):
        blah blah blah

    def func1(some_arguments):
        blah blah blah

a = A()
a = torch.nn.DataParallel(a, device_ids=[0, 1])
# calling forward function
outputs = a(inputs)  # works fine.
# calling func1
outputs1 = a.func1(inputs)  # does not work.
outputs1 = a.module.func1(inputs)  # works without parallelizing data. I am not sure if this is the right thing to do
Run Code Online (Sandbox Code Playgroud)

Vja*_*miK 1

您是否尝试过从内部而不是外部调用 func1 ?所以本质上,你会调用forward,而forward又会调用func1。如果您想有条件地调用 func1,您可以将函数名称作为参数传递给转发。这些建议也出现在此线程中https://discuss.pytorch.org/t/dataparallel-model-with-custom-functions/75053/10