Constructing a Python set from a NumPy matrix

dzh*_*lil 17 python arrays numpy set

I am trying to do the following:

>>> from numpy import *
>>> x = array([[3,2,3],[4,4,4]])
>>> y = set(x)
TypeError: unhashable type: 'numpy.ndarray'

How can I easily and efficiently create a set containing all the elements of a NumPy array?

Eri*_*got 28

If you want a set of the elements, here is another, probably faster, way:

y = set(x.flatten())
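A quick check of this approach on the array from the question, written as a sketch that assumes Python 3 and a recent NumPy. The resulting set holds NumPy scalars rather than Python ints; they hash and compare equal to the corresponding ints, and calling `.tolist()` first gives plain ints if that matters.

```python
import numpy as np

x = np.array([[3, 2, 3], [4, 4, 4]])

# flatten() returns a new 1-D copy; its elements are hashable NumPy scalars
elements = set(x.flatten())
print(elements == {2, 3, 4})  # True

# Converting to a Python list first yields a set of built-in ints
plain = set(x.flatten().tolist())
print(plain == {2, 3, 4})     # True
```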

PS: After comparing x.flat, x.flatten() and x.ravel() on a 10x100 array, I found that they all run at about the same speed. For a 3x3 array, the fastest version is the iterator version:

y = set(x.flat)

I recommend this one because it is the version that uses the least memory (it scales well with the size of the array).
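To reproduce a comparison like the one described above on your own machine, a rough `timeit` harness could look like this. The function name `time_flatteners`, the use of `np.arange` to fill the arrays, and the repeat counts are all illustrative assumptions; timings vary with hardware and NumPy version, so no particular ranking is claimed here.

```python
import timeit

import numpy as np


def time_flatteners(shape, number=1000):
    """Hypothetical helper: seconds spent by each way of building a set from an array."""
    # Array contents are an assumption; the original benchmark data is unknown
    x = np.arange(np.prod(shape)).reshape(shape)
    candidates = {
        "set(x.flat)": lambda: set(x.flat),
        "set(x.flatten())": lambda: set(x.flatten()),
        "set(x.ravel())": lambda: set(x.ravel()),
    }
    return {name: timeit.timeit(fn, number=number) for name, fn in candidates.items()}


for shape in [(10, 100), (3, 3)]:
    for name, seconds in time_flatteners(shape).items():
        print(shape, name, f"{seconds:.4f} s")
```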

PS: There is also a related NumPy function:

y = numpy.unique(x)

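This produces the same elements as set(x.flat), but as a (sorted) NumPy array. It is very fast (almost 10 times faster), but if you need a set, then set(numpy.unique(x)) is a bit slower than the other approaches (building the set adds significant overhead).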
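A small sketch of `numpy.unique` on the question's array, assuming Python 3 and a recent NumPy. It flattens its input and returns the distinct values in sorted order; converting to a set afterwards gives the same members as `set(x.flat)`.

```python
import numpy as np

x = np.array([[3, 2, 3], [4, 4, 4]])

# Sorted distinct values, returned as a 1-D ndarray
u = np.unique(x)
print(u.tolist())         # [2, 3, 4]

# Convert afterwards only if a real Python set is required
s = set(u.tolist())
print(s == set(x.flat))   # True
```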

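  • Good suggestion! You could also use set(x.ravel()), which does the same thing but creates a copy only when needed. Or, better still, use set(x.flat): x.flat is an iterator over the elements of the flattened array, but it doesn't waste time actually flattening the array. (3 upvotes)
  • Warning: this answer does *not* give you a set of vectors, but a set of numbers. If you want a set of vectors, see miku's answer below, which converts the vectors to tuples. (3 upvotes)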
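The copy-versus-view point in the first comment can be checked with `np.shares_memory`. This is a sketch assuming NumPy 1.11 or later (where `shares_memory` is available); the transposed array is used only as an example of non-contiguous data.

```python
import numpy as np

x = np.array([[3, 2, 3], [4, 4, 4]])

# flatten() always allocates a new array
print(np.shares_memory(x.flatten(), x))   # False

# ravel() returns a view when the data is already C-contiguous...
print(np.shares_memory(x.ravel(), x))     # True

# ...and falls back to a copy when it is not (here, a transposed view)
print(np.shares_memory(x.T.ravel(), x))   # False

# x.flat is an iterator object and builds no flattened array at all
print(type(x.flat).__name__)              # flatiter
```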

mik*_*iku 14

The immutable counterpart of an array is the tuple, so try converting the array of arrays into an array of tuples:

>>> from numpy import *
>>> x = array([[3,2,3],[4,4,4]])

>>> x_hashable = map(tuple, x)

>>> y = set(x_hashable)
>>> y
set([(3, 2, 3), (4, 4, 4)])
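A Python 3 version of the same idea, as a sketch assuming a recent NumPy. Going through `x.tolist()` makes each row a tuple of plain Python ints. For comparison, NumPy's own `np.unique` with `axis=0` (available since NumPy 1.13) returns the distinct rows as a sorted 2-D array instead of a set.

```python
import numpy as np

x = np.array([[3, 2, 3], [4, 4, 4]])

# Each row becomes a tuple of Python ints, which is hashable
rows = set(map(tuple, x.tolist()))
print(rows == {(3, 2, 3), (4, 4, 4)})   # True

# NumPy alternative: distinct rows, sorted lexicographically
print(np.unique(x, axis=0).tolist())    # [[3, 2, 3], [4, 4, 4]]
```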


xpe*_*oni 7

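The answers above work if you want to create a set of the elements contained in an ndarray, but if you want to create a set of ndarray objects, or use ndarray objects as keys in a dictionary, then you'll have to provide a hashable wrapper for them. See the code below for a simple example: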

from hashlib import sha1

from numpy import array, array_equal, ascontiguousarray


class hashable(object):
    r'''Hashable wrapper for ndarray objects.

        Instances of ndarray are not hashable, meaning they cannot be added to
        sets, nor used as keys in dictionaries. This is by design - ndarray
        objects are mutable, and therefore cannot reliably implement the
        __hash__() method.

        The hashable class allows a way around this limitation. It implements
        the required methods for hashable objects in terms of an encapsulated
        ndarray object. This can be either a copied instance (which is safer)
        or the original object (which requires the user to be careful enough
        not to modify it).
    '''
    def __init__(self, wrapped, tight=False):
        r'''Creates a new hashable object encapsulating an ndarray.

            wrapped
                The wrapped ndarray.

            tight
                Optional. If True, a copy of the input ndarray is created.
                Defaults to False.
        '''
        self.__tight = tight
        self.__wrapped = array(wrapped) if tight else wrapped
        # Hash the raw bytes of a C-contiguous version of the array, so that
        # non-contiguous inputs (slices, transposes) can be hashed as well.
        self.__hash = int(sha1(ascontiguousarray(wrapped).tobytes()).hexdigest(), 16)

    def __eq__(self, other):
        if not isinstance(other, hashable):
            return NotImplemented
        # Equal objects must have equal hashes, so also require the same dtype
        # and shape (a plain == would broadcast, e.g. [1] against [1, 1]).
        return (self.__wrapped.dtype == other.__wrapped.dtype
                and array_equal(self.__wrapped, other.__wrapped))

    def __hash__(self):
        return self.__hash

    def unwrap(self):
        r'''Returns the encapsulated ndarray.

            If the wrapper is "tight", a copy of the encapsulated ndarray is
            returned. Otherwise, the encapsulated ndarray itself is returned.
        '''
        if self.__tight:
            return array(self.__wrapped)

        return self.__wrapped

Using the wrapper class is simple:

>>> from numpy import arange

>>> a = arange(0, 1024)
>>> d = {}
>>> d[a] = 'foo'
Traceback (most recent call last):
  File "<input>", line 1, in <module>
TypeError: unhashable type: 'numpy.ndarray'
>>> b = hashable(a)
>>> d[b] = 'bar'
>>> d[b]
'bar'
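A lighter-weight alternative, which is a different technique from the wrapper class: instead of wrapping the array, derive an immutable key tuple from its dtype, shape, and raw bytes. The helper name `array_key` is hypothetical. Since the key is a snapshot taken when it is built, later changes to the array do not affect keys already stored, so you need to recompute the key after modifying an array.

```python
import numpy as np


def array_key(a):
    """Hypothetical helper: a hashable, value-based key for an ndarray.

    Including dtype and shape keeps arrays that share the same raw bytes but
    differ in type or layout (e.g. a 1-D array and its 2-D reshape) apart.
    """
    a = np.ascontiguousarray(a)
    return (a.dtype.str, a.shape, a.tobytes())


a = np.arange(0, 1024)
d = {array_key(a): "bar"}

# An equal copy maps to the same key; a reshaped array does not
print(d[array_key(a.copy())])              # bar
print(array_key(a.reshape(32, 32)) in d)   # False
```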