Pytorch 的 share_memory_() 与内置 Python 的共享内存:为什么在 Pytorch 中我们不需要访问共享内存块?

gtg*_*gtg 5 python shared-memory multiprocessing pytorch

在尝试了解内置包multiprocessingPytorch 的multiprocessing包时,我观察到两者之间存在不同的行为。我觉得这很奇怪,因为Pytorch 的包与内置包完全兼容。

具体来说,我指的是进程之间共享变量的方式。在 Pytorch 中,张量通过 inplace 操作移动到共享内存share_memory_()。另一方面,通过使用shared_memory模块,我们可以得到与内置包相同的结果。

我很难理解两者之间的区别在于,使用内置版本,我们必须显式访问启动进程内的共享内存块然而,我们不需要使用 Pytorch 版本这样做。

这是一个Pytorch的玩具示例,显示了这一点:

import time

import torch
# the same behavior happens when importing:
# import multiprocessing as mp
import torch.multiprocessing as mp


def get_time(s):
    return round(time.time() - s, 1)


def foo(a):
    # wait ~1sec to print the value of the tensor.
    time.sleep(1.0)
    with lock:
        #-------------------------------------------------------------------
        # WITHOUT explicitely accessing the shared memory block, we can observe
        # that the tensor has changed:
        #-------------------------------------------------------------------
        print(f"{__name__}\t{get_time(s)}\t\t{a}")


# global variables.
lock = mp.Lock()
s = time.time()


if __name__ == '__main__':
    print("Module\t\tTime\t\tValue")
    print("-"*50)

    # create tensor and assign it to shared memory.
    a = torch.zeros(2).share_memory_()
    print(f"{__name__}\t{get_time(s)}\t\t{a}")

    # start child process.
    p0 = mp.Process(target=foo, args=(a,))
    p0.start()

    # modify the value of the tensor after ~0.5sec.
    time.sleep(0.5)
    with lock:
        a[0] = 1.0

    print(f"{__name__}\t{get_time(s)}\t\t{a}")
    time.sleep(1.5)

    p0.join()
Run Code Online (Sandbox Code Playgroud)

其输出(如预期):

Module          Time            Value
--------------------------------------------------
__main__        0.0             tensor([0., 0.])
__main__        0.5             tensor([1., 0.])
__mp_main__     1.0             tensor([1., 0.])
Run Code Online (Sandbox Code Playgroud)

这是一个带有内置包的玩具示例:

import time
import multiprocessing as mp
from multiprocessing import shared_memory

import numpy as np


def get_time(s):
    return round(time.time() - s, 1)


def foo(shm_name, shape, type_):
    #-------------------------------------------------------------------
    # WE NEED TO explicitely access the shared memory block to observe
    # that the array has changed:
    #-------------------------------------------------------------------
    existing_shm = shared_memory.SharedMemory(name=shm_name)
    a = np.ndarray(shape, type_, buffer=existing_shm.buf)

    # wait ~1sec to print the value.
    time.sleep(1.0)
    with lock:
        print(f"{__name__}\t{get_time(s)}\t\t{a}")


# global variables.
lock = mp.Lock()
s = time.time()


if __name__ == '__main__':
    print("Module\t\tTime\t\tValue")
    print("-"*35)

    # create numpy array and shared memory block.
    a = np.zeros(2,)
    shm = shared_memory.SharedMemory(create=True, size=a.nbytes)
    a_shared = np.ndarray(a.shape, a.dtype, buffer=shm.buf)
    a_shared[:] = a[:]
    print(f"{__name__}\t{get_time(s)}\t\t{a_shared}")

    # start child process.
    p0 = mp.Process(target=foo, args=(shm.name, a.shape, a.dtype))
    p0.start()

    # modify the value of the vaue after ~0.5sec.
    time.sleep(0.5)
    with lock:
        a_shared[0] = 1.0

    print(f"{__name__}\t{get_time(s)}\t\t{a_shared}")
    time.sleep(1.5)

    p0.join()
Run Code Online (Sandbox Code Playgroud)

正如预期的那样,它等效地输出:

Module          Time            Value
-----------------------------------
__main__        0.0             [0. 0.]
__main__        0.5             [1. 0.]
__mp_main__     1.0             [1. 0.]
Run Code Online (Sandbox Code Playgroud)

所以我很难理解为什么我们不需要在内置版本和 Pytorch 版本中遵循相同的步骤,即 Pytorch 如何能够避免显式访问共享内存块的需要?

PS 我使用的是 Windows 操作系统和 Python 3.9

J_H*_*J_H 7

您正在给 pytorch 作者写一封情书。也就是说,你在拍拍他们的背,祝贺他们的包装工作“干得好!” 这是一个可爱的图书馆。

让我们退后一步,使用一个非常简单的数据结构,即字典d。如果父进程d使用一些值进行初始化,然后启动一对工作子进程,则每个子进程都有一个d.

那是怎么发生的?该multiprocessing模块分叉了工作人员,查看了一组定义的变量,其中包括d,并将 这些(键,值)对从父级向下序列化到子级。

所以此时我们有 3 个独立 的副本d。如果父级或任一子级修改d,则其他 2 个副本完全不受影响。

现在切换到 pytorch 包装器。您提供了一些简洁的代码,演示了如果我们想要 3 个对相同共享结构的引用而不是 3 个独立副本,应用程序需要执行的小 .SharedMemory() 舞蹈。pytorch 包装器序列化对公共数据结构的引用 ,而不是生成副本。在幕后它正在做和你一样的舞蹈。但是在应用程序级别没有重复的措辞,因为细节已经很好地抽象出来了,FTW!

为什么在 Pytorch 中我们不需要访问共享内存块?

tl;dr:我们确实需要访问它。但图书馆承担了担心细节的重担,所以我们不必担心。


Ahm*_*AEK 3

pytorch 对共享内存有一个简单的包装器,python 的共享内存模块只是对底层操作系统相关函数的包装器。

可以完成的方法是,您不序列化数组或共享内存本身,而仅使用文档中的__getstate____setstate__方法序列化创建它们所需的内容,以便您的对象既充当代理又充当容器同时。

下面的bar类可以通过这种方式为代理和容器加倍,如果用户不必担心共享内存部分,这非常有用。

import time
import multiprocessing as mp
from multiprocessing import shared_memory
import numpy as np

class bar:
    def __init__(self):
        self._size = 10
        self._type = np.uint8
        self.shm = shared_memory.SharedMemory(create=True, size=self._size)
        self._mem_name = self.shm.name
        self.arr = np.ndarray([self._size], self._type, buffer=self.shm.buf)

    def __getstate__(self):
        """Return state values to be pickled."""
        return (self._mem_name, self._size, self._type)

    def __setstate__(self, state):
        """Restore state from the unpickled state values."""
        self._mem_name, self._size, self._type = state
        self.shm = shared_memory.SharedMemory(self._mem_name)
        self.arr = np.ndarray([self._size], self._type, buffer=self.shm.buf)

def get_time(s):
    return round(time.time() - s, 1)

def foo(shm, lock):
    # -------------------------------------------------------------------
    # without explicitely access the shared memory block we observe
    # that the array has changed:
    # -------------------------------------------------------------------
    a = shm

    # wait ~1sec to print the value.
    time.sleep(1.0)
    with lock:
        print(f"{__name__}\t{get_time(s)}\t\t{a.arr}")

# global variables.
s = time.time()

if __name__ == '__main__':
    lock = mp.Lock()  # to work on windows/mac.

    print("Module\t\tTime\t\tValue")
    print("-" * 35)

    # create numpy array and shared memory block.
    a = bar()
    print(f"{__name__}\t{get_time(s)}\t\t{a.arr}")

    # start child process.
    p0 = mp.Process(target=foo, args=(a, lock))
    p0.start()

    # modify the value of the vaue after ~0.5sec.
    time.sleep(0.5)
    with lock:
        a.arr[0] = 1.0

    print(f"{__name__}\t{get_time(s)}\t\t{a.arr}")
    time.sleep(1.5)

    p0.join()
Run Code Online (Sandbox Code Playgroud)

python只是让在类中隐藏这些细节变得更加容易,而不会用这些细节打扰用户。

编辑:我希望他们使锁不可继承,这样你的代码就可以在锁上引发错误,相反,有一天你会发现它实际上并没有锁定......在它使你的应用程序在生产中崩溃之后。