Creating subplots of images with plotly

Fel*_* B. 8 python plotly jupyter-notebook

I want to display the first 10 images of the MNIST dataset with plotly. This turns out to be more complicated than I thought. This does not work:

import numpy as np
np.random.seed(123)

import plotly.express as px
from plotly import subplots
from keras.datasets import mnist

(X_train, y_train), (X_test, y_test) = mnist.load_data()

fig = subplots.make_subplots(rows=1, cols=10)
fig.add_trace(px.imshow(X_train[0]), row=1, col=1)

because it results in

ValueError: 
    Invalid element(s) received for the 'data' property of 
        Invalid elements include: [Figure({
    'data': [{'coloraxis': 'coloraxis',
              'type': 'heatmap',
              'z': array([[0, 0, 0, ..., 0, 0, 0],
                          [0, 0, 0, ..., 0, 0, 0],
                          [0, 0, 0, ..., 0, 0, 0],
                          ...,
                          [0, 0, 0, ..., 0, 0, 0],
                          [0, 0, 0, ..., 0, 0, 0],
                          [0, 0, 0, ..., 0, 0, 0]], dtype=uint8)}],
    'layout': {'coloraxis': {'colorscale': [[0.0, '#0d0887'], [0.1111111111111111,
                                            '#46039f'], [0.2222222222222222,
                                            '#7201a8'], [0.3333333333333333,
                                            '#9c179e'], [0.4444444444444444,
                                            '#bd3786'], [0.5555555555555556,
                                            '#d8576b'], [0.6666666666666666,
                                            '#ed7953'], [0.7777777777777778,
                                            '#fb9f3a'], [0.8888888888888888,
                                            '#fdca26'], [1.0, '#f0f921']]},
               'margin': {'t': 60},
               'template': '...',
               'xaxis': {'constrain': 'domain', 'scaleanchor': 'y'},
               'yaxis': {'autorange': 'reversed', 'constrain': 'domain'}}
})]

    The 'data' property is a tuple of trace instances
    that may be specified as:
      - A list or tuple of trace instances
        (e.g. [Scatter(...), Bar(...)])
      - A single trace instance
        (e.g. Scatter(...), Bar(...), etc.)
      - A list or tuple of dicts of string/value properties where:
        - The 'type' property specifies the trace type
            One of: ['area', 'bar', 'barpolar', 'box',
                     'candlestick', 'carpet', 'choropleth',
                     'choroplethmapbox', 'cone', 'contour',
                     'contourcarpet', 'densitymapbox', 'funnel',
                     'funnelarea', 'heatmap', 'heatmapgl',
                     'histogram', 'histogram2d',
                     'histogram2dcontour', 'image', 'indicator',
                     'isosurface', 'mesh3d', 'ohlc', 'parcats',
                     'parcoords', 'pie', 'pointcloud', 'sankey',
                     'scatter', 'scatter3d', 'scattercarpet',
                     'scattergeo', 'scattergl', 'scattermapbox',
                     'scatterpolar', 'scatterpolargl',
                     'scatterternary', 'splom', 'streamtube',
                     'sunburst', 'surface', 'table', 'treemap',
                     'violin', 'volume', 'waterfall']

        - All remaining properties are passed to the constructor of
          the specified trace type

        (e.g. [{'type': 'scatter', ...}, {'type': 'bar, ...}])

Neither does

fig.add_trace(go.Image(X_train[0]), row=1, col=1)

nor

fig.add_trace(go.Figure(go.Heatmap(z=X_train[0])), 1,1)

I am running out of ideas. It should be possible to have a row of images as a header.

Fel*_* B. 9

This works, though I hope it is not the final answer:

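from plotly import subplots

fig = subplots.make_subplots(rows=2, cols=5)

# add_trace expects a trace, not a Figure, so take the heatmap trace out of the px figure
for n, image in enumerate(X_train[:10]):
    fig.add_trace(px.imshow(255 - image).data[0], row=n // 5 + 1, col=n % 5 + 1)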

# the layout gets lost, so we have to carry it over - but we cannot simply do
# fig.layout = layout since the layout has to be slightly different for subplots
# fig.layout.yaxis in a subplot refers only to the first axis for example
# update_yaxes, on the other hand, updates *all* y-axes
layout = px.imshow(X_train[0], color_continuous_scale='gray').layout
fig.layout.coloraxis = layout.coloraxis
fig.update_xaxes(**layout.xaxis.to_plotly_json())
fig.update_yaxes(**layout.yaxis.to_plotly_json())
fig.show()


Interestingly, if you click plotly's "Download plot as a png" icon, the exported picture is not the same as the one displayed by fig.show() (see the GitHub issue).


Jor*_*ier 7

Using facets

Consider the following answer, which is much simpler:

import plotly.express as px
from keras.datasets import mnist

(X_train, y_train), (X_test, y_test) = mnist.load_data()

# facet_col=0 puts each slice along axis 0 in its own subplot, five per row
fig = px.imshow(X_train[:10, :, :], binary_string=True, facet_col=0, facet_col_wrap=5)
fig.show()

It produces the following output:

[Image: output]

Using animations

For exploration purposes, although this is not exactly what you asked for, you can also create an animation and view one digit at a time:

fig = px.imshow(X_train[:10, :, :], binary_string=True, animation_frame=0)

[Image: animated output]

Sources

Imshow documentation