jnt*_*rcs 2 python pytorch tensor
从我在网上看到的所有内容来看,FloatTensorsPytorch 对所有内容都是默认设置,当我创建一个张量以传递给我的生成器模块时,它是FloatTensor一个DoubleTensor.
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.fully_connected = nn.Linear(100, 1024*4*4, bias=False)
def forward(self, zvec):
print(zvec.size())
fc = self.fully_connected(zvec)
return(fc.size())
gen = Generator();
gen(torch.from_numpy(np.random.normal(size=100)))
Run Code Online (Sandbox Code Playgroud)
其中产生
RuntimeError: Expected object of type torch.DoubleTensor but found type torch.FloatTensor for argument #2 'mat2'
Run Code Online (Sandbox Code Playgroud)
这里的问题是您的 numpy 输入double用作数据类型,同样的数据类型也应用于结果张量。
该weights图层的self.fully_connected,另一方面是float。当通过该层馈送数据时,会应用矩阵乘法,并且该乘法要求两个矩阵具有相同的数据类型。
所以你有两个解决方案:
通过改变:
gen(torch.from_numpy(np.random.normal(size=100)))
Run Code Online (Sandbox Code Playgroud)
到:
gen(torch.from_numpy(np.random.normal(size=100)).float())
Run Code Online (Sandbox Code Playgroud)
您输入的输入gen将转换为floatthen。
转换输入的完整工作代码:
from torch import nn
import torch
import numpy as np
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.fully_connected = nn.Linear(100, 1024*4*4, bias=False)
def forward(self, zvec):
print(zvec.size())
fc = self.fully_connected(zvec)
return(fc.size())
gen = Generator();
gen(torch.from_numpy(np.random.normal(size=100)).float()) # converting network input to float
Run Code Online (Sandbox Code Playgroud)
如果您需要双精度,您也可以将您weights的double.
改变这一行:
self.fully_connected = nn.Linear(100, 1024*4*4, bias=False)
Run Code Online (Sandbox Code Playgroud)
只为了:
self.fully_connected = nn.Linear(100, 1024*4*4, bias=False).double()
Run Code Online (Sandbox Code Playgroud)
用于转换权重的完整工作代码:
from torch import nn
import torch
import numpy as np
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.fully_connected = nn.Linear(100, 1024*4*4, bias=False).double() # converting layer weights to double()
def forward(self, zvec):
print(zvec.size())
fc = self.fully_connected(zvec)
return(fc.size())
gen = Generator();
gen(torch.from_numpy(np.random.normal(size=100)))
Run Code Online (Sandbox Code Playgroud)
所以,这两种方式应该为你工作,但如果你不需要的额外的精度double,你应该去float为double需要更多的计算能力。
| 归档时间: |
|
| 查看次数: |
6948 次 |
| 最近记录: |