oar*_*ish 3 numpy matplotlib python-3.x subsampling
我有一个问题,我想绘制数据分布,其中某些值经常出现,而其他值则非常罕见。总点数约为30.000。渲染像 png 或(上帝禁止)pdf 这样的图需要很长时间,而且 pdf 太大而无法显示。
所以我想对数据进行二次采样以绘制绘图。我想要实现的是删除很多重叠的点(密度高的点),但保留密度低的点,概率几乎为 1。
现在,numpy.random.choice允许指定一个概率向量,我根据数据直方图计算出该概率向量,并进行了一些调整。但我似乎无法得到我的选择,以便真正保留稀有点。
我附上了数据的图像;分布的右尾部的点少了几个数量级,所以我想保留这些点。数据是 3d 的,但密度仅来自一维,因此我可以将其用作给定位置中有多少个点的度量
考虑以下函数。它将沿轴将数据分入相同的分箱中,并且
这允许将原始数据保留在低密度区域中,但显着减少在高密度区域中绘制的数据量。同时,所有特征都通过足够密集的分箱得以保留。
import numpy as np; np.random.seed(42)
def filt(x,y, bins):
d = np.digitize(x, bins)
xfilt = []
yfilt = []
for i in np.unique(d):
xi = x[d == i]
yi = y[d == i]
if len(xi) <= 2:
xfilt.extend(list(xi))
yfilt.extend(list(yi))
else:
xfilt.extend([xi[np.argmax(yi)], xi[np.argmin(yi)]])
yfilt.extend([yi.max(), yi.min()])
# prepend/append first/last point if necessary
if x[0] != xfilt[0]:
xfilt = [x[0]] + xfilt
yfilt = [y[0]] + yfilt
if x[-1] != xfilt[-1]:
xfilt.append(x[-1])
yfilt.append(y[-1])
sort = np.argsort(xfilt)
return np.array(xfilt)[sort], np.array(yfilt)[sort]
Run Code Online (Sandbox Code Playgroud)
为了说明这个概念,让我们使用一些玩具数据
x = np.array([1,2,3,4, 6,7,8,9, 11,14, 17, 26,28,29])
y = np.array([4,2,5,3, 7,3,5,5, 2, 4, 5, 2,5,3])
bins = np.linspace(0,30,7)
Run Code Online (Sandbox Code Playgroud)
然后调用xf, yf = filt(x,y,bins)并绘制原始数据和过滤后的数据得出:
具有约 30000 个数据点的问题的用例如下所示。使用所提出的技术可以将绘制点的数量从 30000 个减少到大约 500 个。这个数字当然取决于所使用的分箱 - 这里是 300 个分箱。在这种情况下,函数的计算时间约为 10 毫秒。这不是超级快,但与绘制所有点相比仍然是一个很大的进步。
import matplotlib.pyplot as plt
# Generate some data
x = np.sort(np.random.rayleigh(3, size=30000))
y = np.cumsum(np.random.randn(len(x)))+250
# Decide for a number of bins
bins = np.linspace(x.min(),x.max(),301)
# Filter data
xf, yf = filt(x,y,bins)
# Plot results
fig, (ax1, ax2, ax3) = plt.subplots(nrows=3, figsize=(7,8),
gridspec_kw=dict(height_ratios=[1,2,2]))
ax1.hist(x, bins=bins)
ax1.set_yscale("log")
ax1.set_yticks([1,10,100,1000])
ax2.plot(x,y, linewidth=1, label="original data, {} points".format(len(x)))
ax3.plot(xf, yf, linewidth=1, label="binned min/max, {} points".format(len(xf)))
for ax in [ax2, ax3]:
ax.legend()
plt.show()
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
2538 次 |
| 最近记录: |