如何阻止 numpy meshgrid 将默认数据类型设置为 int64

b10*_*ard 6 python numpy scipy

我必须使用 numpy meshgrid 创建一个非常大的网格。为了节省内存,我使用 int8 作为我尝试网格化的数组的数据类型。然而,meshgrid 不断将类型更改为 int64,这会占用大量内存。这是问题的一个简单示例......

import numpy

grids = [numpy.arange(1, 4, dtype=numpy.int8), numpy.arange(1, 5, dtype=numpy.int8)]

print grids
print grids[0].dtype, grids[0].nbytes

x1, y1 = numpy.meshgrid(*grids)

print x1.dtype, x1.nbytes
Run Code Online (Sandbox Code Playgroud)

该脚本打印

[array([1, 2, 3], dtype=int8), array([1, 2, 3, 4], dtype=int8)]

int8 3

int64 96
Run Code Online (Sandbox Code Playgroud)

为什么网格要这样做?有什么办法可以阻止它吗?我需要创建一个巨大的数组,因此我无法使用网格网格,除非我可以控制输出的数据类型。这是预期的行为还是一个麻木的错误?我在 numpy 中使用的所有其他函数都保留数据类型或允许您使用 dtype 参数更改它。meshgrid 函数似乎不允许这样做。

Leo*_*eon 5

您可以将可选copy参数设置numpy.meshgrid()False(但请注意,它有一些限制):

meshgrid(*xi, **kwargs)

...

copybool, 选修的

如果False,则返回原始数组的视图以节省内存。默认为True. 请注意,sparse=False, copy=False可能会返回不连续的数组。此外,广播数组的多个元素可以引用单个存储位置。如果需要写入数组,请先制作副本。

证明它有效:

>>> import numpy
>>> 
>>> grids = [numpy.arange(1, 4, dtype=numpy.int8), numpy.arange(1, 5, dtype=numpy.int8)]
>>> 
>>> print grids
[array([1, 2, 3], dtype=int8), array([1, 2, 3, 4], dtype=int8)]
>>> print grids[0].dtype, grids[0].nbytes
int8 3
>>>
>>> x1, y1 = numpy.meshgrid(*grids, copy=False)
>>>                        #        ^^^^^^^^^^
>>> print x1.dtype, x1.nbytes
int8 12
Run Code Online (Sandbox Code Playgroud)