Geo*_*Liu 15 python numpy matplotlib
我正在开展一个项目,我需要将10行3列的绘图网格组合在一起.虽然我已经能够绘制并绘制子图,但是我无法生成一个没有空格的漂亮图,例如下面的gridspec文档.
.
我尝试了以下帖子,但仍然无法像示例图像中那样完全删除空白区域.有人可以给我一些指导吗?谢谢!
以下是我的代码.完整的脚本在GitHub上.注意:images_2和images_fool都是具有形状(1032,10)的扁平图像的numpy阵列,而delta是形状的图像阵列(28,28).
def plot_im(array=None, ind=0):
"""A function to plot the image given a images matrix, type of the matrix: \
either original or fool, and the order of images in the matrix"""
img_reshaped = array[ind, :].reshape((28, 28))
imgplot = plt.imshow(img_reshaped)
# Output as a grid of 10 rows and 3 cols with first column being original, second being
# delta and third column being adversaril
nrow = 10
ncol = 3
n = 0
from matplotlib import gridspec
fig = plt.figure(figsize=(30, 30))
gs = gridspec.GridSpec(nrow, ncol, width_ratios=[1, 1, 1])
for row in range(nrow):
for col in range(ncol):
plt.subplot(gs[n])
if col == 0:
#plt.subplot(nrow, ncol, n)
plot_im(array=images_2, ind=row)
elif col == 1:
#plt.subplot(nrow, ncol, n)
plt.imshow(w_delta)
else:
#plt.subplot(nrow, ncol, n)
plot_im(array=images_fool, ind=row)
n += 1
plt.tight_layout()
#plt.show()
plt.savefig('grid_figure.pdf')
Run Code Online (Sandbox Code Playgroud)
Imp*_*est 21
开头的注释:如果你想完全控制间距,请避免使用,plt.tight_layout()因为它会尝试将图中的绘图排列为同样且分布均匀.这大部分都很好并且产生令人愉快的结果,但是会随意调整间距.
您从Matplotlib示例库中引用的GridSpec示例运行良好的原因是因为子图的方面未预定义.也就是说,子图将简单地在网格上展开并保持设置间距(在这种情况下wspace=0.0, hspace=0.0)与图形大小无关.
与您绘制imshow图像的方式相反,默认情况下图像的方面设置相等(相当于ax.set_aspect("equal")).也就是说,您当然可以将set_aspect("auto")每个绘图放入(并且另外添加 wspace=0.0, hspace=0.0为GridSpec中的参数,如在库示例中),这将生成没有间距的绘图.
然而,当使用图像时,保持相等的纵横比使得每个像素都宽到高并且正方形阵列被显示为正方形图像是很有意义的.
您需要做的是使用图像大小和图形边距来获得预期结果.图中的figsize参数是以英寸为单位的图形(宽度,高度),这里可以使用两个数字的比率.并且wspace, hspace, top, bottom, left可以手动调整子图参数以给出期望的结果.以下是一个例子:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec
nrow = 10
ncol = 3
fig = plt.figure(figsize=(4, 10))
gs = gridspec.GridSpec(nrow, ncol, width_ratios=[1, 1, 1],
wspace=0.0, hspace=0.0, top=0.95, bottom=0.05, left=0.17, right=0.845)
for i in range(10):
for j in range(3):
im = np.random.rand(28,28)
ax= plt.subplot(gs[i,j])
ax.imshow(im)
ax.set_xticklabels([])
ax.set_yticklabels([])
#plt.tight_layout() # do not use this!!
plt.show()
Run Code Online (Sandbox Code Playgroud)
编辑:
当然希望不必手动调整参数.因此,可以根据行数和列数计算一些最优的.
nrow = 7
ncol = 7
fig = plt.figure(figsize=(ncol+1, nrow+1))
gs = gridspec.GridSpec(nrow, ncol,
wspace=0.0, hspace=0.0,
top=1.-0.5/(nrow+1), bottom=0.5/(nrow+1),
left=0.5/(ncol+1), right=1-0.5/(ncol+1))
for i in range(nrow):
for j in range(ncol):
im = np.random.rand(28,28)
ax= plt.subplot(gs[i,j])
ax.imshow(im)
ax.set_xticklabels([])
ax.set_yticklabels([])
plt.show()
Run Code Online (Sandbox Code Playgroud)
尝试将此行添加到您的代码中:
fig.subplots_adjust(wspace=0, hspace=0)
Run Code Online (Sandbox Code Playgroud)
并且对于每个轴对象集:
ax.set_xticklabels([])
ax.set_yticklabels([])
Run Code Online (Sandbox Code Playgroud)
小智 5
遵循 ImportanceOfBeingErnest 的答案,但如果您想使用plt.subplots及其功能:
fig, axes = plt.subplots(
nrow, ncol,
gridspec_kw=dict(wspace=0.0, hspace=0.0,
top=1. - 0.5 / (nrow + 1), bottom=0.5 / (nrow + 1),
left=0.5 / (ncol + 1), right=1 - 0.5 / (ncol + 1)),
figsize=(ncol + 1, nrow + 1),
sharey='row', sharex='col', # optionally
)
Run Code Online (Sandbox Code Playgroud)