Joh*_*tud 5 python torch tensorflow
I was wondering if someone could help me understand how to convert a short TF model to Torch.
Consider this TF setup:
inp = layers.Input(shape = (386, 1024, 1), dtype = tf.float32)
x = layers.Dense(2)(inp) # [None, 386, 1024, 2]
start, end = tf.split(x, 2, axis=-1)
start = tf.squeeze(start, axis = -1)
end = tf.squeeze(end, axis = -1)
model = Model(inputs = inp, outputs = [start, end])
Specifically, I am not sure which Torch command would take my data from shape (386, 1024, 1) to (386, 1024, 2), and I also don't understand what Model(inputs = inp, outputs = [start, end]) does.
Is:
inp = layers.Input(shape = (386, 1024, 1), dtype = tf.float32)
x = layers.Dense(2)(inp) # [None, 386, 1024, 2]
equivalent to:
X = torch.randn(386, 1024, 1)
X = X.expand(386, 1024, 2)
X.shape  # torch.Size([386, 1024, 2])
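Converting a model definition from TF to Torch is mostly straightforward: you can usually find the Torch equivalent of a TF function in the PyTorch documentation. Here is an example of converting your TF code: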
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
inp = layers.Input(shape = (386, 1024, 1), dtype = tf.float32)
x = layers.Dense(2)(inp) # [None, 386, 1024, 2]
start, end = tf.split(x, 2, axis=-1)
start = tf.squeeze(start, axis=-1)
end = tf.squeeze(end, axis=-1)
model = models.Model(inputs = inp, outputs = [start, end])
X = np.random.randn(3, 386, 1024, 1)
output = model(X)
print(output[0].shape, output[1].shape)
# Outputs: (3, 386, 1024) (3, 386, 1024)
Torch code:
import torch
from torch import nn
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(1, 2)

    def forward(self, x):
        x = self.fc(x)
        start, end = torch.split(x, 1, dim=-1)
        start = torch.squeeze(start, dim=-1)
        end = torch.squeeze(end, dim=-1)
        return [start, end]

net = Net()
X = torch.randn(3, 386, 1024, 1)
output = net(X)
print(output[0].size(), output[1].size())
# Outputs: torch.Size([3, 386, 1024]) torch.Size([3, 386, 1024])
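If trained weights also need to carry over, note that the two layers store their weight matrices in transposed layouts: a Keras Dense kernel has shape (in_features, out_features), while nn.Linear.weight has shape (out_features, in_features). The sketch below is only an illustration under assumptions: the `kernel` and `bias` arrays are made-up stand-ins for what `get_weights()` on the Dense layer would return, not values from the model above.

```python
import numpy as np
import torch
from torch import nn

# Hypothetical stand-ins for the Keras Dense weights (made-up values):
# kernel has shape (in_features, out_features), bias has shape (out_features,)
kernel = np.array([[0.5, -2.0]], dtype=np.float32)  # shape (1, 2)
bias = np.array([0.1, 0.2], dtype=np.float32)       # shape (2,)

fc = nn.Linear(1, 2)
with torch.no_grad():
    # nn.Linear stores weight as (out_features, in_features), hence the transpose
    fc.weight.copy_(torch.from_numpy(kernel).T)
    fc.bias.copy_(torch.from_numpy(bias))

x = torch.ones(3, 386, 1024, 1)
y = fc(x)                        # shape (3, 386, 1024, 2)

# unbind removes the last dimension and returns one tensor per channel,
# which has the same effect as split followed by squeeze
start, end = y.unbind(dim=-1)
print(start.shape, end.shape)
```

With an all-ones input, every entry of `start` is 1 * 0.5 + 0.1 and every entry of `end` is 1 * (-2.0) + 0.2, which is a quick way to check that the transpose went in the right direction.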
In addition, the following TF code:
inp = layers.Input(shape = (386, 1024, 1), dtype = tf.float32)
x = layers.Dense(2)(inp) # [None, 386, 1024, 2]
is not equivalent to the following Torch code:
X = torch.randn(386, 1024, 1)
X = X.expand(386, 1024, 2)
X.shape  # torch.Size([386, 1024, 2])
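because layers.Dense in TF corresponds to nn.Linear in Torch: it applies a learned weight and bias along the last dimension, whereas expand only repeats the existing values without adding any parameters.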
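The difference can be checked directly. The following is a minimal sketch, assuming only the shapes from the question; the seed is arbitrary and variable names such as `expanded` and `projected` are chosen here for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)  # arbitrary seed, only for repeatability

x = torch.randn(386, 1024, 1)

# expand: no parameters; both output channels show the same input values
expanded = x.expand(386, 1024, 2)
same_channels = torch.equal(expanded[..., 0], expanded[..., 1])

# nn.Linear(1, 2): each output channel is x * w + b with learnable w and b
fc = nn.Linear(1, 2)
projected = fc(x)
n_params = sum(p.numel() for p in fc.parameters())

# Reproduce the layer by hand: x @ W^T + b
manual = x @ fc.weight.T + fc.bias

print(expanded.shape, projected.shape)  # same shape, different contents
print(same_channels)                    # True: expand just repeats the data
print(n_params)                         # 4: a 2x1 weight plus 2 biases
print(torch.allclose(projected, manual, atol=1e-6))
```

Both results have shape [386, 1024, 2], but only the nn.Linear output depends on trainable parameters, which is what layers.Dense(2) provides on the TF side.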