How do I convert a TF Dense layer to PyTorch?

Joh*_*tud 5 python torch tensorflow

I'm hoping someone can help me understand how to convert a short TF model to Torch.

Consider this TF setup:

inp = layers.Input(shape = (386, 1024, 1), dtype = tf.float32)
x = layers.Dense(2)(inp)  # [None, 386, 1024, 2]
start, end = tf.split(x, 2, axis=-1)
start = tf.squeeze(start, axis = -1)
end = tf.squeeze(end, axis = -1)
model = Model(inputs = inp, outputs = [start, end])

Specifically, I'm not sure which Torch command would take my data from shape (386, 1024, 1) to (386, 1024, 2), and I also don't understand what this does: Model(inputs = inp, outputs = [start, end])

Is this:

inp = layers.Input(shape = (386, 1024, 1), dtype = tf.float32)
x = layers.Dense(2)(inp)  # [None, 386, 1024, 2]

equivalent to this?

X = torch.randn(386, 1024, 1)
X = X.expand(386, 1024, 2)
X.shape  # torch.Size([386, 1024, 2])

Mr.*_*ple 4

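Converting a model definition from TF to Torch is mostly straightforward: you can usually find the Torch equivalent of a TF function in the PyTorch documentation. Here is an example conversion, starting with the TF code: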

import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np

inp = layers.Input(shape = (386, 1024, 1), dtype = tf.float32)
x = layers.Dense(2)(inp)  # [None, 386, 1024, 2]
start, end = tf.split(x, 2, axis=-1)
start = tf.squeeze(start, axis=-1)
end = tf.squeeze(end, axis=-1)
model = models.Model(inputs = inp, outputs = [start, end])

X = np.random.randn(3, 386, 1024, 1)
output = model(X)
print(output[0].shape, output[1].shape)

# Outputs: (3, 386, 1024) (3, 386, 1024)

Torch code:

import torch
from torch import nn

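class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Counterpart of layers.Dense(2) on an input whose last dimension is 1
        self.fc = nn.Linear(1, 2)

    def forward(self, x):
        x = self.fc(x)  # [N, 386, 1024, 2]
        # torch.split takes the chunk size (1), not the number of chunks
        start, end = torch.split(x, 1, dim=-1)
        start = torch.squeeze(start, dim=-1)
        end = torch.squeeze(end, dim=-1)
        return [start, end]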

net = Net()

X = torch.randn(3, 386, 1024, 1)
output = net(X)
print(output[0].size(), output[1].size())

# Outputs: torch.Size([3, 386, 1024]) torch.Size([3, 386, 1024])
Run Code Online (Sandbox Code Playgroud)
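To check that the two models agree numerically and not only in shape, the Dense weights can be copied into the nn.Linear layer. The following is a minimal sketch: the `kernel` and `bias` arrays are made-up stand-ins for what `get_weights()` would return on the Keras Dense layer, so that the snippet runs without TensorFlow. Keras stores the Dense kernel as (in_features, units), while nn.Linear stores its weight as (out_features, in_features), so the kernel has to be transposed.

```python
import numpy as np
import torch
from torch import nn

# Stand-ins (assumed values) for `kernel, bias = dense_layer.get_weights()`.
# Keras Dense kernel layout: (in_features, units).
kernel = np.array([[0.5, -2.0]], dtype=np.float32)  # shape (1, 2)
bias = np.array([0.1, 0.3], dtype=np.float32)       # shape (2,)

fc = nn.Linear(1, 2)
with torch.no_grad():
    # nn.Linear weight layout: (out_features, in_features), hence the transpose.
    fc.weight.copy_(torch.from_numpy(kernel.T))
    fc.bias.copy_(torch.from_numpy(bias))

x = torch.randn(3, 386, 1024, 1)
with torch.no_grad():
    y = fc(x)

# What Dense computes on the last axis: x @ kernel + bias
expected = x.numpy() @ kernel + bias
print(y.shape)
print(np.allclose(y.numpy(), expected, atol=1e-5))
```

With real Keras weights in place of the stand-ins, the same comparison can be run against the output of the TF model.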

As for the following TF code:

inp = layers.Input(shape = (386, 1024, 1), dtype = tf.float32)
x = layers.Dense(2)(inp)  # [None, 386, 1024, 2]

it is not equivalent to the following Torch code:

X = torch.randn(386, 1024, 1)
X = X.expand(386, 1024, 2)
X.shape  # torch.Size([386, 1024, 2])

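because layers.Dense in TF is equivalent to nn.Linear in Torch, which applies a learned weight and bias to the last dimension, whereas expand only repeats the existing values.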
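The difference is easy to see on a tiny tensor. This is only an illustrative sketch: the weight and bias values are picked by hand so the result can be checked, whereas a real layer would start from random values and learn them during training.

```python
import torch
from torch import nn

x = torch.tensor([[1.0], [2.0], [3.0]])  # shape (3, 1)

# expand repeats the existing value along the last axis; nothing is learned.
expanded = x.expand(3, 2)
print(expanded)  # both columns are copies of x

# nn.Linear(1, 2) computes x * w_j + b_j for each of the 2 output features.
fc = nn.Linear(1, 2)
with torch.no_grad():
    # Hand-picked values (an assumption for illustration only).
    fc.weight.copy_(torch.tensor([[2.0], [-1.0]]))
    fc.bias.copy_(torch.tensor([0.5, 0.0]))
    projected = fc(x)
print(projected)  # rows: [2.5, -1.0], [4.5, -2.0], [6.5, -3.0]

# The linear layer carries trainable parameters: 2 weights + 2 biases.
n_params = sum(p.numel() for p in fc.parameters())
print(n_params)  # 4
```

Both results have shape (3, 2), but only the nn.Linear output depends on parameters that training can adjust, which is what layers.Dense(2) provides on the TF side.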