Is a PyTorch RNN more efficient with `batch_first=False`?

Rav*_*euk 8 python nlp pytorch

In machine translation, we always need to slice out the first time step (the SOS token) from both the targets and the predictions.

With `batch_first=False`, slicing out the first time step still keeps the tensor contiguous.

import torch
batch_size = 128
seq_len = 12
embedding = 50

# Making a dummy output that is `batch_first=False`
batch_not_first = torch.randn((seq_len, batch_size, embedding))
batch_not_first = batch_not_first[1:].view(-1, embedding)  # slicing out the first time step

However, with `batch_first=True`, the tensor is no longer contiguous after slicing. We need to make it contiguous before we can perform operations such as `view`.

batch_first = torch.randn((batch_size,seq_len,embedding))
batch_first[:,1:].view(-1, embedding) # slicing out the first time step

output>>>
"""
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-8-a9bd590a1679> in <module>
----> 1 batch_first[:,1:].view(-1, embedding) # slicing out the first time step

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
"""
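For completeness, here is a hedged sketch (not from the original post) of the two standard workarounds when `batch_first=True`: call `.contiguous()` before `.view`, or use `.reshape`, which copies only when it has to.

```python
import torch

batch_size, seq_len, embedding = 128, 12, 50
batch_first = torch.randn(batch_size, seq_len, embedding)

# Option 1: make a contiguous copy first, then view
flat1 = batch_first[:, 1:].contiguous().view(-1, embedding)

# Option 2: reshape handles the copy internally when needed
flat2 = batch_first[:, 1:].reshape(-1, embedding)

print(flat1.shape)  # torch.Size([1408, 50]) == (128 * 11, 50)
```

Both produce the same result; the cost of either is one extra data copy.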

Does this mean `batch_first=False` is better, at least for machine translation, since it saves us the `contiguous()` step? Are there cases where `batch_first=True` works better?

Szy*_*zke 5

Performance

There doesn't seem to be a significant difference between `batch_first=True` and `batch_first=False`. See the script below:

import time

import torch


def time_measure(batch_first: bool):
    torch.cuda.synchronize()
    layer = torch.nn.RNN(10, 20, batch_first=batch_first).cuda()
    if batch_first:
        inputs = torch.randn(100000, 7, 10).cuda()
    else:
        inputs = torch.randn(7, 100000, 10).cuda()


    torch.cuda.synchronize()
    start = time.perf_counter()

    for chunk in torch.chunk(inputs, 100000 // 64, dim=0 if batch_first else 1):
        _, last = layer(chunk)

    torch.cuda.synchronize()
    return time.perf_counter() - start


print(f"Time taken for batch_first=False: {time_measure(False)}")
print(f"Time taken for batch_first=True: {time_measure(True)}")

On my device (GTX 1050 Ti) with PyTorch 1.6.0 and CUDA 11.0, the results were:

Time taken for batch_first=False: 0.3275816479999776
Time taken for batch_first=True: 0.3159054920001836

(It varies slightly either way, so nothing conclusive.)

Code readability

`batch_first=True` is simpler when you want to use other PyTorch layers that require the batch as the 0th dimension, which is the case for almost all `torch.nn` layers, e.g. `torch.nn.Linear`.

With `batch_first=False`, you would have to `permute` the returned tensor anyway before passing it to such layers.
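As an illustrative sketch (not from the original answer), here is the `permute` step that a `batch_first=False` pipeline typically needs before handing the output to a batch-first consumer. Note that `permute` returns a non-contiguous view:

```python
import torch

layer = torch.nn.RNN(10, 20, batch_first=False)
inputs = torch.randn(7, 64, 10)        # (seq_len, batch, features)
output, _ = layer(inputs)              # (seq_len, batch, hidden) == (7, 64, 20)

# Many downstream ops assume batch is dimension 0, so we permute
batch_first_output = output.permute(1, 0, 2)   # (batch, seq_len, hidden)
print(batch_first_output.shape)        # torch.Size([64, 7, 20])
print(batch_first_output.is_contiguous())  # False: permute is only a view
```

So the `contiguous()` cost is not eliminated by `batch_first=False`; it just moves to a different place in the pipeline.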

Machine translation

It should be better there, as the tensor stays contiguous the whole time and no data copy has to be done. Slicing with `[1:]` also looks cleaner than `[:,1:]`.