使用 px.imshow 和 go.Scatter 绘制动画子图

Gil*_*een 7 python animation plotly

我正在尝试创建一个图形,展示图像“重建”效果随主成分(PC)数量的变化。我想用动画来显示原始图像、累积重建图像(基于 PC 1、…、i)以及仍需“重建”的剩余部分。此外,我还想把原始图像与重建图像之间的距离显示为 PC 数量的函数。

我设法创建了下图,它使底部的散点图和顶部的图像具有动画效果。

在此输入图像描述

问题是,一旦动画开始,右侧的两个图像就会“消失”。我认为它们被绘制到了“原始图像”所在的子图中,压在原始图像下面。

在此输入图像描述

这是我的代码(使用所有 3 个图像和散点图创建动画帧,然后形成图形):

import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.io as pio
from sklearn.decomposition import PCA

# Register a "custom" template whose only job is to tighten the figure
# margins, then layer it on top of plotly's "simple_white" theme so every
# figure created afterwards uses both.
pio.templates["custom"] = go.layout.Template(
    layout=dict(margin=dict(b=0, l=20, r=20, t=40)),
)
pio.templates.default = "simple_white+custom"


class AnimationButtons:
    """Factories for plotly ``updatemenus`` buttons that control an animation.

    Each factory returns a plain ``dict`` meant to be placed in
    ``layout.updatemenus[...]["buttons"]``. The factories are static methods,
    so they work both on the class (``AnimationButtons.play()``) and on an
    instance. Without ``@staticmethod`` an instance call would bind the
    instance to the first parameter.
    """

    @staticmethod
    def play_scatter(frame_duration=500, transition_duration=300):
        """Return a "Play" button that animates frames without a full redraw.

        Uses ``redraw=False`` with a ``quadratic-in-out`` easing. As the name
        suggests, it is intended for frames that only update scatter traces.

        Args:
            frame_duration: How long each frame is shown (plotly uses ms).
            transition_duration: Length of the easing between frames (ms).

        Returns:
            A button dict that calls plotly's ``animate`` method.
        """
        return dict(
            label="Play",
            method="animate",
            args=[None, {"frame": {"duration": frame_duration, "redraw": False},
                         "fromcurrent": True,
                         "transition": {"duration": transition_duration, "easing": "quadratic-in-out"}}],
        )

    @staticmethod
    def play(frame_duration=1000, transition_duration=0):
        """Return a "Play" button that fully redraws every frame.

        ``redraw=True`` is needed when frames change non-scatter traces such
        as images. ``fromcurrent`` resumes from the currently shown frame.

        Args:
            frame_duration: How long each frame is shown (plotly uses ms).
            transition_duration: Length of the linear transition (ms).

        Returns:
            A button dict that calls plotly's ``animate`` method.
        """
        return dict(
            label="Play",
            method="animate",
            args=[None, {"frame": {"duration": frame_duration, "redraw": True},
                         "mode": "immediate",
                         "fromcurrent": True,
                         "transition": {"duration": transition_duration, "easing": "linear"}}],
        )

    @staticmethod
    def pause():
        """Return a "Pause" button that stops a running animation.

        Passing ``[None]`` (a list containing ``None``) as the frame
        argument, with zero durations and ``mode="immediate"``, is plotly's
        documented way to interrupt an animation.
        """
        return dict(
            label="Pause",
            method="animate",
            args=[[None], {"frame": {"duration": 0, "redraw": False},
                           "mode": "immediate",
                           "transition": {"duration": 0}}],
        )

# Fit PCA on the flattened images: each HxW image becomes one H*W-feature sample.
pca = PCA(n_components=15).fit(X.reshape((X.shape[0], -1)))
# Principal components reshaped back into image form: (n_components, H, W).
pcs = pca.components_.reshape((-1, X.shape[1], X.shape[2]))

# Project a single image onto the components. transform() expects
# (n_samples, n_features), so the image is flattened to ONE row (1, H*W).
# reshape(-1, 1) would instead describe H*W samples with one feature each.
# The result is one scalar score per principal component.
img = X[1]
loadings = pca.transform(img.reshape(1, -1))[0]

# transform() centres its input with pca.mean_, so the reconstruction starts
# from that mean. copy() stops the in-place += below from mutating the
# fitted model's mean_.
reconstructed = pca.mean_.reshape(img.shape).copy()
distortion, frames = [], []
for i, (score, pc) in enumerate(zip(loadings, pcs)):
    # Cumulative reconstruction using principal components 0..i.
    reconstructed += score * pc
    # Squared reconstruction error for the current number of PCs.
    distortion.append(np.sum((img - reconstructed) ** 2))

    # Append an animation frame for every other reconstruction, plus the last one.
    if i % 2 == 0 or i == pca.n_components_ - 1:
        frames.append(go.Frame(
            # Trace order matches the subplot titles (original, reconstructed,
            # remaining). Every trace names its own subplot axes. The traces
            # produced by px.imshow appear to carry their own "x"/"y" axis
            # references, which made all three images jump onto the first
            # subplot as soon as the animation started.
            data=[px.imshow(img, binary_string=True).data[0].update(xaxis="x", yaxis="y"),
                  px.imshow(reconstructed.copy(), binary_string=True).data[0].update(xaxis="x2", yaxis="y2"),
                  px.imshow(img - reconstructed, binary_string=True).data[0].update(xaxis="x3", yaxis="y3"),
                  go.Scatter(x=list(range(1, len(distortion) + 1)), y=list(distortion), xaxis="x4", yaxis="y4")],
            traces=[0, 1, 2, 3],
            layout=go.Layout(title=rf"$\text{{ Image Reconstruction - Number of PCs: {i+1} }}$")))


# Three image subplots on the top row; the distortion curve spans the bottom row
# (make_subplots names that bottom subplot's axes x4/y4).
fig = make_subplots(rows=2, cols=3,
                    subplot_titles=["Original Image", "Reconstructed Image", "Remaining Reconstruction", "Distortion Level"],
                    specs=[[{}, {}, {}], [{"colspan": 3}, None, None]], row_heights=[500, 200])
fig.add_traces(data=frames[0]["data"], rows=[1, 1, 1, 2], cols=[1, 2, 3, 1])
fig.update(frames=frames)

fig.update_layout(title=frames[0]["layout"]["title"],
                  # Size the distortion x-axis to the number of PCs actually plotted.
                  xaxis4=dict(range=[0, len(distortion) + 1], autorange=False),
                  yaxis4=dict(range=[0, max(distortion) + 1], autorange=False),
                  margin=dict(t=100),
                  width=800,
                  updatemenus=[dict(type="buttons", buttons=[AnimationButtons.play(), AnimationButtons.pause()])])
fig.show()
Run Code Online (Sandbox Code Playgroud)

我尝试寻找类似的问题,但没有找到任何能在子图(subplot)中同时显示 px.imshow 和 go.Scatter 并支持动画的方案。

数据 X 是经过中心化(centered)处理的 MNIST 数字图像。下面是一个包含这类图像的 numpy 数组(X.shape=(16,5,5),即 16 张 5×5 的图像;只对第一张图像制作动画):

# Toy stand-in for the centred MNIST data: the same 5x5 image repeated 16
# times, giving a float64 array of shape (16, 5, 5).
X = np.tile(
    np.array([
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, -1.04166667e-06],
        [0.00000000e+00, 0.00000000e+00, -4.16666667e-06, -2.73437500e-06, -2.71484375e-05],
        [0.00000000e+00, 0.00000000e+00, -1.26302083e-05, -2.28515625e-05, -4.69401042e-05],
        [0.00000000e+00, -2.47395833e-06, -2.03776042e-05, -5.60546875e-05, -3.15950521e-04],
    ]),
    (16, 1, 1),
)
Run Code Online (Sandbox Code Playgroud)

我已将以上代码放在 GitHub 上的一个 Jupyter Notebook 中。

Gil*_*een 3

与 jayvessea 的建议类似,我最终改用了 px.imshow 的分面(facet)功能:先用 px.imshow 创建分面图像和动画,然后再添加散点图和所需的布局。

# Fit PCA on the flattened images: each HxW image becomes one H*W-feature sample.
pca = PCA(n_components=50).fit(X.reshape((X.shape[0], -1)))
# Principal components reshaped back into image form: (n_components, H, W).
pcs = pca.components_.reshape((-1, X.shape[1], X.shape[2]))

# Project a single image onto the components. transform() expects
# (n_samples, n_features), so the image is flattened to ONE row (1, H*W),
# giving one scalar score per principal component.
img = X[150]
loadings = pca.transform(img.reshape(1, -1))[0]

# transform() centres its input with pca.mean_, so the reconstruction starts
# from that mean. copy() stops the in-place += from mutating the fitted model.
reconstructed = pca.mean_.reshape(img.shape).copy()
distortion, images, scatters, titles = [], [], [], []
for i, (score, pc) in enumerate(zip(loadings, pcs)):
    # Cumulative reconstruction using principal components 0..i.
    reconstructed += score * pc
    # Squared reconstruction error for the current number of PCs.
    distortion.append(np.sum((img - reconstructed) ** 2))

    # Append an animation frame for every other reconstruction, plus the last one.
    if i % 2 == 0 or i == pca.n_components_ - 1:
        # Facet order: original, reconstructed, remaining (matches subplot titles).
        images.append([img.copy(), reconstructed.copy(), (img - reconstructed).copy()])
        # list(distortion) snapshots the curve so far; the scatter is pinned
        # to the bottom subplot's axes (x4/y4).
        scatters.append(go.Scatter(x=list(range(1, len(distortion) + 1)), y=list(distortion), name=3,
                                   xaxis="x4", yaxis="y4", marker_color="black"))
        titles.append(rf"$\text{{ Image Reconstruction - Number of PCs: {i+1} }}$")


# Create the figure from an animated, faceted imshow: facet_col=1 lays the three
# images side by side and animation_frame=0 steps through the collected frames.
fig = px.imshow(np.array(images), facet_col=1, animation_frame=0, binary_string=True)
for frame, scatter, title in zip(fig["frames"], scatters, titles):
    # Add the distortion curve as trace 3 of every frame and set its title.
    frame["data"] += (scatter,)
    frame["traces"] = [0, 1, 2, 3]
    frame["layout"]["title"] = title
# The static figure also needs the scatter trace (taken from the first frame).
fig.add_traces(data=fig["frames"][0]["data"][-1])

# Create a "template" figure whose layout (subplot grid, titles, axes, buttons)
# replaces the layout of `fig`.
layout = make_subplots(rows=2, cols=3,
                       subplot_titles=["Original Image", "Reconstructed Image", "Remaining Reconstruction", "Distortion Level"],
                       specs=[[{"type": "Image"}, {"type": "Image"}, {"type": "Image"}],
                              [{"type": "xy", "colspan": 3}, None, None]],
                       row_heights=[500, 200])

layout.update_layout(title=titles[0],
                     # Size the distortion x-axis to the number of PCs actually plotted.
                     xaxis4=dict(range=[0, len(distortion) + 1], autorange=False),
                     yaxis4=dict(range=[0, max(distortion) + 1], autorange=False),
                     margin=dict(t=100), width=800,
                     updatemenus=[dict(type="buttons", buttons=[AnimationButtons.play(), AnimationButtons.pause()])])

# Wholesale layout swap: this deliberately drops the layout px.imshow generated.
fig["layout"] = layout["layout"]
fig
Run Code Online (Sandbox Code Playgroud)

这不是一个非常优雅的解决方案,但它是一个足够的解决方法。

在此输入图像描述