从npy文件加载稀疏数组

ifo*_*orm 10 python sparse-array scipy

我正在尝试加载我之前保存的稀疏数组.保存稀疏数组很容易.试图阅读它是一种痛苦.scipy.load在我的稀疏数组周围返回一个0d数组.

import scipy as sp
A = sp.load("my_array"); A
array(<325729x325729 sparse matrix of type '<type 'numpy.int8'>'
with 1497134 stored elements in Compressed Sparse Row format>, dtype=object)
Run Code Online (Sandbox Code Playgroud)

为了获得稀疏矩阵,我必须展平0d数组,或者使用sp.asarray(A).这似乎是一种非常难以做到的事情.Scipy是否足够聪明才能理解它已经加载了一个稀疏数组?有没有更好的方法来加载稀疏数组?

unu*_*tbu 15

所述mmwrite/mmread在scipy.io功能可以保存/加载稀疏矩阵在矩阵市场格式.

scipy.io.mmwrite('/tmp/my_array',x)
scipy.io.mmread('/tmp/my_array').tolil()    
Run Code Online (Sandbox Code Playgroud)

mmwrite并且mmread可能就是你所需要的一切.它经过了充分测试,采用了众所周知的格式.

但是,以下可能会更快一些:

我们可以将行和列坐标和数据保存为npz格式的一维数组.

import random
import scipy.sparse as sparse
import scipy.io
import numpy as np

def save_sparse_matrix(filename,x):
    x_coo=x.tocoo()
    row=x_coo.row
    col=x_coo.col
    data=x_coo.data
    shape=x_coo.shape
    np.savez(filename,row=row,col=col,data=data,shape=shape)

def load_sparse_matrix(filename):
    y=np.load(filename)
    z=sparse.coo_matrix((y['data'],(y['row'],y['col'])),shape=y['shape'])
    return z

N=20000
x = sparse.lil_matrix( (N,N) )
for i in xrange(N):
    x[random.randint(0,N-1),random.randint(0,N-1)]=random.randint(1,100)

save_sparse_matrix('/tmp/my_array',x)
load_sparse_matrix('/tmp/my_array.npz').tolil()
Run Code Online (Sandbox Code Playgroud)

这里有一些代码建议在npz文件中保存稀疏矩阵可能比使用mmwrite/mmread更快:

def using_np_savez():    
    save_sparse_matrix('/tmp/my_array',x)
    return load_sparse_matrix('/tmp/my_array.npz').tolil()

def using_mm():
    scipy.io.mmwrite('/tmp/my_array',x)
    return scipy.io.mmread('/tmp/my_array').tolil()    

if __name__=='__main__':
    for func in (using_np_savez,using_mm):
        y=func()
        print(repr(y))
        assert(x.shape==y.shape)
        assert(x.dtype==y.dtype)
        assert(x.__class__==y.__class__)    
        assert(np.allclose(x.todense(),y.todense()))
Run Code Online (Sandbox Code Playgroud)

产量

% python -mtimeit -s'import test' 'test.using_mm()'
10 loops, best of 3: 380 msec per loop

% python -mtimeit -s'import test' 'test.using_np_savez()'
10 loops, best of 3: 116 msec per loop
Run Code Online (Sandbox Code Playgroud)


小智 5

可以使用()作为索引提取隐藏在0d数组中的对象:

A = sp.load("my_array")[()]
Run Code Online (Sandbox Code Playgroud)

这看起来很奇怪,但无论如何它似乎都有效,而且这是一个非常短的解决方法.