ger*_*abs 4 tensorflow pytorch torchscript huggingface-transformers
我正在关注本教程:https : //huggingface.co/transformers/torchscript.html
来创建我的自定义 BERT 模型的跟踪,但是在运行完全相同的模型时,dummy_input我收到一个错误:
TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect.
We cant record the data flow of Python values, so this value will be treated as a constant in the future.
Run Code Online (Sandbox Code Playgroud)
在我的模型和标记器中加载后,创建跟踪的代码如下:
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = tokenizer.tokenize(text)
# Masking one of the input tokens
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors]
traced_model = torch.jit.trace(model, dummy_input)
Run Code Online (Sandbox Code Playgroud)
这dummy_input是张量列表,所以我不确定Boolean类型在这里发挥作用。有谁明白为什么会发生这个错误以及布尔转换是否正在发生?
非常感谢
当您尝试对具有数据相关控制流的模型进行建模时,会出现此警告。torch.jit.trace
这个简单的例子应该更清楚:
import torch
class Foo(torch.nn.Module):
def forward(self, tensor):
# It is data dependent
# Trace will only work with one path
if tensor.max() > 0.5:
return tensor ** 2
return tensor
model = Foo()
traced = torch.jit.script(model) # No warnings
traced = torch.jit.trace(model, torch.randn(10)) # Warning
Run Code Online (Sandbox Code Playgroud)
从本质上说,BERT模型具有一定的控制流(如if,for循环)依赖于数据,因此你收到警告。
你可以在这里看到 BERTforward代码。
你没问题,如果:
None传递给 的值forward)并且它会在之后保持这种状态script(例如在推理调用期间)__init__(如配置),因为这不会改变例如:
elif input_ids is not None:
input_shape = input_ids.size()
batch_size, seq_length = input_shape
Run Code Online (Sandbox Code Playgroud)
只会作为一个分支运行torch.jit.trace,因为它只是跟踪张量上的操作,并且不知道这样的控制流。
HuggingFace 团队可能已经意识到这一点,并且此警告不是问题(尽管您可能会仔细检查您的用例或尝试使用torch.jit.script)
torch.jit.script这会很困难,因为整个模型必须torchscript兼容(torchscript有一个 Python 子集可用,而且很可能无法与 BERT 一起开箱即用)。
仅在必要时才这样做(可能不是)。
| 归档时间: |
|
| 查看次数: |
1724 次 |
| 最近记录: |