如何在Pytorch中可视化网络?

raa*_*aaj 11 python pytorch

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.models as models
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchvision.models.vgg import model_urls
from torchviz import make_dot

batch_size = 3
learning_rate =0.0002
epoch = 50

resnet = models.resnet50(pretrained=True)
print resnet
make_dot(resnet)
Run Code Online (Sandbox Code Playgroud)

我想resnet从pytorch模型可视化。我该怎么做?我尝试使用,torchviz但出现错误:

'ResNet' object has no attribute 'grad_fn'
Run Code Online (Sandbox Code Playgroud)

sta*_*010 37

以下是使用不同工具的三种不同图形可视化。

为了生成示例可视化,我将使用一个简单的 RNN 来执行来自在线教程的情绪分析:

class RNN(nn.Module):

    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):

        super().__init__()
        self.embedding  = nn.Embedding(input_dim, embedding_dim)
        self.rnn        = nn.RNN(embedding_dim, hidden_dim)
        self.fc         = nn.Linear(hidden_dim, output_dim)

    def forward(self, text):

        embedding       = self.embedding(text)
        output, hidden  = self.rnn(embedding)

        return self.fc(hidden.squeeze(0))
Run Code Online (Sandbox Code Playgroud)

这是print()模型的输出。

RNN(
  (embedding): Embedding(25002, 100)
  (rnn): RNN(100, 256)
  (fc): Linear(in_features=256, out_features=1, bias=True)
)
Run Code Online (Sandbox Code Playgroud)

以下是来自三种不同可视化工具的结果。

对于所有这些,您需要有可以通过模型forward()方法的虚拟输入。获取此输入的一种简单方法是从您的 Dataloader 中检索一个批次,如下所示:

batch = next(iter(dataloader_train))
yhat = model(batch.text) # Give dummy batch to forward().
Run Code Online (Sandbox Code Playgroud)

火炬视点

https://github.com/szagoruyko/pytorchviz

我相信这个工具使用向后传递生成它的图,所以所有的框都使用 PyTorch 组件进行反向传播。

from torchviz import make_dot

make_dot(yhat, params=dict(list(model.named_parameters()))).render("rnn_torchviz", format="png")
Run Code Online (Sandbox Code Playgroud)

此工具生成以下输出文件:

火炬视输出

这是唯一一个明确提到我的模型中三层的输出embeddingrnn、 和fc。运算符名称取自向后传递,因此其中一些难以理解。

隐藏层

https://github.com/waleedka/hiddenlayer

我相信这个工具使用前向传递。

import hiddenlayer as hl

transforms = [ hl.transforms.Prune('Constant') ] # Removes Constant nodes from graph.

graph = hl.build_graph(model, batch.text, transforms=transforms)
graph.theme = hl.graph.THEMES['blue'].copy()
graph.save('rnn_hiddenlayer', format='png')
Run Code Online (Sandbox Code Playgroud)

这是输出。我喜欢蓝色的阴影。

隐藏层输出

我发现输出有太多细节并且混淆了我的架构。例如,为什么unsqueeze提到了这么多次?

耐创

https://github.com/lutzroeder/netron

此工具是适用于 Mac、Windows 和 Linux 的桌面应用程序。它依赖于首先导出为ONNX 格式的模型。然后应用程序读取 ONNX 文件并呈现它。然后有一个选项可以将模型导出到图像文件。

input_names = ['Sentence']
output_names = ['yhat']
torch.onnx.export(model, batch.text, 'rnn.onnx', input_names=input_names, output_names=output_names)
Run Code Online (Sandbox Code Playgroud)

下面是模型在应用程序中的样子。我认为这个工具非常灵活:您可以缩放和平移,还可以钻取图层和操作符。我发现的唯一缺点是它只做垂直布局。

网创截图

  • 抱歉,什么是“batch.text”? (4认同)
  • Netron 还支持水平布局(参见菜单) (3认同)

Sha*_*hai 27

make_dot期望一个变量(即带有 的张量grad_fn),而不是模型本身。
尝试:

x = torch.zeros(1, 3, 224, 224, dtype=torch.float, requires_grad=False)
out = resnet(x)
make_dot(out)  # plot graph of variable, not of a nn.Module
Run Code Online (Sandbox Code Playgroud)

  • 如何将图像保存为文件? (2认同)
  • 这显示了当我们反向传播时会发生什么。但我可以知道如何查看前支柱吗? (2认同)

Mer*_*tan 25

这可能是一个迟来的答案。但是,特别是在__torch_function__开发后,可以获得更好的可视化效果。你可以在这里尝试我的项目,torchview

对于 resnet50 的示例,您可以查看 colab 笔记本,在这里 我演示了 resnet18 模型的可视化。resnet18的图像由以下代码生成

import torchvision
from torchview import draw_graph

model_graph = draw_graph(resnet18(), input_size=(1,3,224,224), expand_nested=True)
model_graph.visual_graph
Run Code Online (Sandbox Code Playgroud)

Torchview 的 Resnet

它还接受多种输出/输入类型(例如列表、字典)

  • 它在 Google Colab 上运行良好。如果你想让它在 VSCode 上工作并将其保存为 PNG 或 SVG,请使用 ```model_graph.resize_graph(scale=5.0) # 按视图缩放 model_graph.visual_graph.render(format='svg') ``` (2认同)

Dav*_* J. 15

你可以看看 PyTorchViz ( https://github.com/szagoruyko/pytorchviz ),“一个创建 PyTorch 执行图和跟踪可视化的小包。”

PyTorchViz 可视化示例

  • `from graphviz import Source;` `model_arch = make_dot(...);` `Source(model_arch).render(filepath);` (3认同)
  • 如何将图形保存为图像? (2认同)

Cha*_*ker 8

这里是你如何与做torchviz,如果你想保存的图片:

# http://www.bnikolic.co.uk/blog/pytorch-detach.html

import torch
from torchviz import make_dot

x=torch.ones(10, requires_grad=True)
weights = {'x':x}

y=x**2
z=x**3
r=(y+z).sum()

make_dot(r).render("attached", format="png")
Run Code Online (Sandbox Code Playgroud)

你得到的图像截图:

在此处输入图片说明

来源:http : //www.bnikolic.co.uk/blog/pytorch-detach.html