小编Vla*_*sin的帖子

pytorch 或 Huggingface/transformer 标签的代码中的何处被“重命名”为标签?

我的问题涉及这个例子,可以在伟大的huggingface/transformers库中找到。

我正在使用库创建者提供的笔记本作为我的管道的起点。它提出了一个在 Glue 数据集上微调 BERT 进行句子分类的流程。

当进入代码时,我注意到一个非常奇怪的事情,我无法解释。

InputFeatures在示例中,输入数据作为类的实例从此处引入模型:

该类有4个属性,包括label属性:

class InputFeatures:
    ...
    input_ids: List[int]
    attention_mask: Optional[List[int]] = None
    token_type_ids: Optional[List[int]] = None
    label: Optional[Union[int, float]] = None
Run Code Online (Sandbox Code Playgroud)

随后将其作为输入字典传递给forward()模型方法。这是由Trainer类完成的,例如这里的第 573-576 行:

    def _training_step(
        self, model: nn.Module, inputs: Dict[str, torch.Tensor], optimizer: torch.optim.Optimizer
    ) -> float:
        model.train()
        for k, v in inputs.items():
            inputs[k] = v.to(self.args.device)

        outputs = model(**inputs)  
Run Code Online (Sandbox Code Playgroud)

但是,该forward()方法需要标签(注意复数形式)输入参数(取自此处):

    def forward(
        self,
        input_ids=None,
        attention_mask=None, …
Run Code Online (Sandbox Code Playgroud)

python pytorch huggingface-transformers

2
推荐指数
1
解决办法
1665
查看次数

标签 统计

huggingface-transformers ×1

python ×1

pytorch ×1