pytorch 中的 tf.cast 等价物?

Sar*_*han 6 python tensorflow pytorch

我是 PyTorch 的新手。TensorFlow 有一个 API tf.cast()tf.shape()。该tf.cast在TensorFlow特定的目的,有什么等值的火炬?我有张量 x= tensor(shape(128,64,32,32)): tf.shape(x) create tensor of Dimension 1 x.shape create the true dimension。我需要在火炬中使用tf.shape(x)

tf.cast已经不仅仅是改变张着不同的角色D型的火炬。

有没有人在 torch/PyTorch 中有等效的 API?

小智 7

查看PyTorch 文档

正如他们所提到的:

print(x.dtype) # Prints "torch.int64", currently 64-bit integer type
x = x.type(torch.FloatTensor)
print(x.dtype) # Prints "torch.float32", now 32-bit float
print(x.float()) # Still "torch.float32"
print(x.type(torch.DoubleTensor)) # Prints "tensor([0., 1., 2., 3.], dtype=torch.float64)"
print(x.type(torch.LongTensor)) # Cast back to int-64, prints "tensor([0, 1, 2, 3])"
Run Code Online (Sandbox Code Playgroud)