如何在自定义模型上使用 Allen NLP 解释

Aja*_*ld 1 heatmap allennlp

我希望使用 Allen NLP Interpret 进行集成可视化和显着性映射。关于自定义变压器模型,您能告诉我该怎么做吗?

小智 5

这可以通过在您的自定义模型周围使用 AllenNLP 包装器来完成。解释模块需要一个 Predictor 对象,因此您可以编写自己的对象,或使用现有的对象。

这是分类模型的示例:

from allennlp.data.vocabulary import Vocabulary

from allennlp.predictors.text_classifier import TextClassifierPredictor
from allennlp.data.dataset_readers import TextClassificationJsonReader

import torch
  
class ModelWrapper(Model):
    def __init__(self, vocab, your_model):
        super().__init__(vocab)
        self.your_model = your_model
        self.logits_to_probs = torch.nn.Softmax()
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, tokens, label=None):
        if label is not None:
            outputs = self.your_model(tokens, label=label)
        else:
            outputs = self.your_model(tokens)
        probs = self.logits_to_probs(outputs["logits"])
        if label is not None:
            loss = self.loss(outputs["logits"], label)
            outputs["loss"] = loss
        outputs["probs"] = probs
        return outputs
Run Code Online (Sandbox Code Playgroud)

您的自定义变压器模型可能没有可识别的TextFieldEmbedder. 这是模型的初始嵌入层,根据该层计算显着性解释器的梯度。这些可以通过覆盖预测器中的以下方法来指定。

class PredictorWrapper(TextClassifierPredictor):
    def get_interpretable_layer(self):
        return self._model.model.bert.embeddings.word_embeddings # This is the initial layer for huggingface's `bert-base-uncased`; change according to your custom model.

    def get_interpretable_text_field_embedder(self):
        return self._model.model.bert.embeddings.word_embeddings
    
predictor = PredictorWrapper(model=ModelWrapper(vocab, your_model),
                             dataset_reader=TextClassificationJsonReader())
Run Code Online (Sandbox Code Playgroud)

现在您有了一个 AllenNLP 预测器,它可以与解释模块一起使用,如下所示:

from allennlp.interpret.saliency_interpreters import SimpleGradient
interpreter = SimpleGradient(predictor)
interpreter.saliency_interpret_from_json({"sentence": "This is a good movie."})
Run Code Online (Sandbox Code Playgroud)

这应该为您提供每个输入标记的梯度。