在python中的进程之间共享连续的numpy数组

eri*_*ora 22 python numpy shared-memory multiprocessing caffe

虽然我找到了类似于我的问题的许多答案,但我不相信它已经直接在这里解决了 - 我还有几个问题.共享连续numpy数组的动机如下:

  • 我正在使用在Caffe上运行的卷积神经网络对一系列连续值标签执行图像回归.
  • 图像需要特定的预处理和数据增强.
  • 约束(1)标签的连续性(它们是浮点数)和(2)数据扩充意味着我在python中预处理数据,然后使用内存数据将其作为连续的numpy数组提供Caffe中的一层.
  • 将训练数据加载到内存中相对较慢.我想将它并行化,以便:

(1)我正在编写的python创建了一个"数据处理程序"类,它实例化了两个连续的numpy数组.(2)工作进程在这些numpy数组之间交替,从磁盘加载数据,执行预处理,以及将数据插入numpy数组.(3)同时,python Caffe包装器将数据从另一个阵列发送到GPU以通过网络运行.

我有几个问题:

  1. 是否有可能在连续的numpy数组中分配内存然后将它包装在共享内存对象中(我不确定'对象'是否是正确的术语)使用类似python多处理的Array类?

  2. Numpy数组有一个.ctypes属性,我认为这对于从Array()实例化共享内存数组很有用,但似乎无法确切地确定如何使用它们.

  3. 如果在没有 numpy数组的情况下实例化共享内存,它是否保持连续?如果没有,有没有办法确保它保持连续?

有可能做这样的事情:

import numpy as np
from multiprocessing import Array
contArr = np.ascontiguousarray(np.zeros((n_images, n_channels, img_height, img_width)), dtype=np.float32)
sm_contArr = Array(contArr.ctypes.?, contArr?)
Run Code Online (Sandbox Code Playgroud)

然后用实例化实例化

p.append(Process(target=some_worker_function, args=(data_to_load, sm_contArr)))
p.start()
Run Code Online (Sandbox Code Playgroud)

谢谢!

编辑:我知道有许多库在不同的维护状态下具有类似的功能.我宁愿将此限制为纯python和numpy,但如果这不可能,我当然愿意使用它.

blu*_*lub 6

将numpy包装ndarray在多处理的周围RawArray()

跨进程共享内存中的numpy数组有多种方法。让我们看一下如何使用多处理模块来实现它。

第一重要的观察是numpy的提供np.frombuffer()功能来包装的ndarray围绕一个预先存在的对象接口支持缓冲协议(如bytes()bytearray()array()等等)。这将根据只读对象创建只读数组,并根据可写对象创建可写数组。

我们可以结合起来,与该共享内存RawArray()提供。请注意,这Array()不适用于该目的,因为它是具有锁定的代理对象,并且不会直接公开缓冲区接口。当然,这意味着我们需要自己对数字化的RawArrays进行适当的同步。

关于ndarray包裹的RawArrays存在一个复杂的问题:当进程在进程之间发送这样的数组时-实际上,一旦创建数组,它就需要将我们的数组发送给两个工作人员-对其进行腌制,然后对其进行腌制。不幸的是,这导致它创建ndarray的副本,而不是在内存中共享它们。

解决方案虽然有点丑陋,但要保持RawArrays不变直到将它们转移到worker上,并且仅在每个worker进程启动后才将它们包装在ndarray中

此外,最好直接通过来通信数组,无论是普通的RawArray还是ndarray包装的数组multiprocessing.Queue,但这都不起作用。甲RawArray不能放在这样的内部队列ndarray -wrapped一个将已封装和拆封,所以在复制的效果。

解决方法是将所有预分配的数组列表的工作进程和沟通指数进入该名单上的队列。这非常像传递令牌(索引),并且持有令牌的人都可以在关联的数组上进行操作。

主程序的结构如下所示:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import queue

from multiprocessing import freeze_support, set_start_method
from multiprocessing import Event, Process, Queue
from multiprocessing.sharedctypes import RawArray


def create_shared_arrays(size, dtype=np.int32, num=2):
    dtype = np.dtype(dtype)
    if dtype.isbuiltin and dtype.char in 'bBhHiIlLfd':
        typecode = dtype.char
    else:
        typecode, size = 'B', size * dtype.itemsize

    return [RawArray(typecode, size) for _ in range(num)]


def main():
    my_dtype = np.float32

    # 125000000 (size) * 4 (dtype) * 2 (num) ~= 1 GB memory usage
    arrays = create_shared_arrays(125000000, dtype=my_dtype)
    q_free = Queue()
    q_used = Queue()
    bail = Event()

    for arr_id in range(len(arrays)):
        q_free.put(arr_id)  # pre-fill free queue with allocated array indices

    pr1 = MyDataLoader(arrays, q_free, q_used, bail,
                       dtype=my_dtype, step=1024)
    pr2 = MyDataProcessor(arrays, q_free, q_used, bail,
                          dtype=my_dtype, step=1024)

    pr1.start()
    pr2.start()

    pr2.join()
    print("\n{} joined.".format(pr2.name))

    pr1.join()
    print("{} joined.".format(pr1.name))


if __name__ == '__main__':
    freeze_support()

    # On Windows, only "spawn" is available.
    # Also, this tests proper sharing of the arrays without "cheating".
    set_start_method('spawn')
    main()
Run Code Online (Sandbox Code Playgroud)

这将准备两个数组的列表,两个队列 -一个“免费”队列,其中MyDataProcessor放置完成处理的数组索引,然后MyDataLoader从中获取它们;还有一个“二手”队列,其中MyDataLoader放置易于填充的数组的索引,然后由MyDataProcessor获取它们从- multiprocessing.Event开始从所有工人中进行一致的保释。我们现在可以取消后者,因为我们只有一个阵列的生产者和一个消费者,但是为更多的工人做准备并没有什么坏处。

然后,在列表中用RawArrays的所有索引预填充“空” 队列,并实例化每种类型的worker之一,并向它们传递必要的通信对象。我们启动它们两个,然后等待它们完成。join()

这是MyDataProcessor的样子,它使用“已用” 队列中的数组索引并将数据发送到某个外部黑匣子(debugio.output在示例中):

class MyDataProcessor(Process):
    def __init__(self, arrays, q_free, q_used, bail, dtype=np.int32, step=1):
        super().__init__()
        self.arrays = arrays
        self.q_free = q_free
        self.q_used = q_used
        self.bail = bail
        self.dtype = dtype
        self.step = step

    def run(self):
        # wrap RawArrays inside ndarrays
        arrays = [np.frombuffer(arr, dtype=self.dtype) for arr in self.arrays]

        from debugio import output as writer

        while True:
            arr_id = self.q_used.get()
            if arr_id is None:
                break

            arr = arrays[arr_id]

            print('(', end='', flush=True)          # just visualizing activity
            for j in range(0, len(arr), self.step):
                writer.write(str(arr[j]) + '\n')
            print(')', end='', flush=True)          # just visualizing activity

            self.q_free.put(arr_id)

            writer.flush()

        self.bail.set()                     # tell loaders to bail out ASAP
        self.q_free.put(None, timeout=1)    # wake up loader blocking on get()

        try:
            while True:
                self.q_used.get_nowait()    # wake up loader blocking on put()
        except queue.Empty:
            pass
Run Code Online (Sandbox Code Playgroud)

首先它是包裹接收RawArraysndarrays使用“np.frombuffer()”,并保持新的列表,所以他们可以作为numpy的进程的运行时数组和它没有来包装他们一遍又一遍。

还要注意,MyDataProcessor只会写入self.bail Event,而不会对其进行检查。相反,如果需要告诉它退出,它将None在队列上找到一个标记而不是数组索引。当MyDataLoader没有更多数据可用并启动拆卸过程时,将执行此操作,MyDataProcessor仍可以处理队列中的所有有效数组而不会过早退出。

这是MyDataLoader的样子:

class MyDataLoader(Process):
    def __init__(self, arrays, q_free, q_used, bail, dtype=np.int32, step=1):
        super().__init__()
        self.arrays = arrays
        self.q_free = q_free
        self.q_used = q_used
        self.bail = bail
        self.dtype = dtype
        self.step = step

    def run(self):
        # wrap RawArrays inside ndarrays
        arrays = [np.frombuffer(arr, dtype=self.dtype) for arr in self.arrays]

        from debugio import input as reader

        for _ in range(10):  # for testing we end after a set amount of passes
            if self.bail.is_set():
                # we were asked to bail out while waiting on put()
                return

            arr_id = self.q_free.get()
            if arr_id is None:
                # we were asked to bail out while waiting on get()
                self.q_free.put(None, timeout=1)  # put it back for next loader
                return

            if self.bail.is_set():
                # we were asked to bail out while we got a normal array
                return

            arr = arrays[arr_id]

            eof = False
            print('<', end='', flush=True)          # just visualizing activity
            for j in range(0, len(arr), self.step):
                line = reader.readline()
                if not line:
                    eof = True
                    break

                arr[j] = np.fromstring(line, dtype=self.dtype, sep='\n')

            if eof:
                print('EOF>', end='', flush=True)   # just visualizing activity
                break

            print('>', end='', flush=True)          # just visualizing activity

            if self.bail.is_set():
                # we were asked to bail out while we filled the array
                return

            self.q_used.put(arr_id)     # tell processor an array is filled

        if not self.bail.is_set():
            self.bail.set()             # tell other loaders to bail out ASAP
            # mark end of data for processor as we are the first to bail out
            self.q_used.put(None)
Run Code Online (Sandbox Code Playgroud)

它的结构与其他工人非常相似。它有点肿的原因是它在许多点上检查了self.bail 事件,以减少卡住的可能性。(这并不是完全万无一失的,因为在检查和访问队列之间设置事件的可能性很小。如果这是个问题,则需要使用一些同步原语来仲裁对事件队列的访问。)

它还从一开始就将接收到的RawArrays包装在ndarrays中,并从外部黑匣子中读取数据(debugio.input在示例中)。

请注意,通过在step=函数中同时使用两个工作人员的main()参数,我们可以更改完成读写的比率(严格地出于测试目的-在生产环境step=中将是1读写所有numpy数组成员)。

增大两个值会使工作进程仅访问numpy数组中的几个值,从而显着加快了所有工作,这表明性能不受工作进程之间的通信的限制。如果我们将numpy数组直接放在Queues上,然后在整个进程之间来回复制它们,则增加步长不会显着改善性能-它将保持缓慢。

作为参考,这是debugio我用于测试的模块:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from ast import literal_eval
from io import RawIOBase, BufferedReader, BufferedWriter, TextIOWrapper


class DebugInput(RawIOBase):
    def __init__(self, end=None):
        if end is not None and end < 0:
            raise ValueError("end must be non-negative")

        super().__init__()
        self.pos = 0
        self.end = end

    def readable(self):
        return True

    def read(self, size=-1):
        if self.end is None:
            if size < 0:
                raise NotImplementedError("size must be non-negative")
            end = self.pos + size
        elif size < 0:
            end = self.end
        else:
            end = min(self.pos + size, self.end)

        lines = []
        while self.pos < end:
            offset = self.pos % 400
            pos = self.pos - offset
            if offset < 18:
                i = (offset + 2) // 2
                pos += i * 2 - 2
            elif offset < 288:
                i = (offset + 12) // 3
                pos += i * 3 - 12
            else:
                i = (offset + 112) // 4
                pos += i * 4 - 112

            line = str(i).encode('ascii') + b'\n'
            line = line[self.pos - pos:end - pos]
            self.pos += len(line)
            size -= len(line)
            lines.append(line)

        return b''.join(lines)

    def readinto(self, b):
        data = self.read(len(b))
        b[:len(data)] = data
        return len(data)

    def seekable(self):
        return True

    def seek(self, offset, whence=0):
        if whence == 0:
            pos = offset
        elif whence == 1:
            pos = self.pos + offset
        elif whence == 2:
            if self.end is None:
                raise ValueError("cannot seek to end of infinite stream")
            pos = self.end + offset
        else:
            raise NotImplementedError("unknown whence value")

        self.pos = max((pos if self.end is None else min(pos, self.end)), 0)
        return self.pos


class DebugOutput(RawIOBase):
    def __init__(self):
        super().__init__()
        self.buf = b''
        self.num = 1

    def writable(self):
        return True

    def write(self, b):
        *lines, self.buf = (self.buf + b).split(b'\n')

        for line in lines:
            value = literal_eval(line.decode('ascii'))
            if value != int(value) or int(value) & 255 != self.num:
                raise ValueError("expected {}, got {}".format(self.num, value))

            self.num = self.num % 127 + 1

        return len(b)


input = TextIOWrapper(BufferedReader(DebugInput()), encoding='ascii')
output = TextIOWrapper(BufferedWriter(DebugOutput()), encoding='ascii')
Run Code Online (Sandbox Code Playgroud)