是否可以找到 spaCy POS 标签的不确定性?

Joe*_*e50 8 python nlp spell-checking spacy

我正在尝试构建一个非英语拼写检查器,它依赖于 spaCy 对句子的分类,这允许我的算法使用 POS 标签和各个标记的语法依赖性来确定不正确的拼写(在我的情况下,更具体地说:不正确的拼写)荷兰语复合词中的拆分)。

然而,如果句子包含语法错误,spaCy 似乎会错误地对句子进行分类,例如将名词分类为动词,即使分类后的单词看起来根本不像动词。

因此,我想知道是否有可能获得 spaCy 分类的不确定性,以便能够判断 spaCy 是否在处理一个句子。毕竟,如果 spaCy 正在努力进行分类,这将使我的拼写检查器更有信心该句子包含错误。

有没有办法知道 spaCy 认为一个句子在语法上是否正确(无需指定我的语言中所有正确句子结构的模式),或者获得分类确定性?


根据@Sergey Bushmanov 评论中的建议进行编辑:

我发现https://spacy.io/api/tagger#predict,这对于获取标签的概率可能有用。但是,我不太确定我在看什么,并且我并没有真正理解文档对输出的含义。我正在使用以下代码:

import spacy

nlp = spacy.load('en_core_web_sm')
text = "This is an example sentence for the Spacy tagger."
doc = nlp(text)

docs = nlp(text, disable=['tagger'])
scores, tensors = nlp.tagger.predict([docs])

print(scores)
probs = tensors[0]
for p in probs:
    print(p, max(p), p.tolist().index(max(p)))
Run Code Online (Sandbox Code Playgroud)

我猜测这会打印预测的一些整数表示(考虑到“整数”和“表示”获得相同的分数),然后为句子中的每个单词打印一个由 96 个浮点组成的数组。它还列出了最高分和最高分的位置,但似乎对于大多数单词来说,数组中有多个项目p获得相似的值。现在我想知道这些数组的含义是什么,以及如何从中提取每个分类的概率。


问题是:如何解释此输出以获得 spaCy 标记器找到的特定标签的特定概率?或者用另一种方式提出同样的问题是:上面代码生成的输出是什么意思?

小智 4

>>> nlp = spacy.load("en_core_web_sm")
>>> tagger = nlp.get_pipe("tagger")
>>> doc = nlp("Turn left")
>>> tagger.model.predict([doc])[0][1]
array([2.4706091e-07, 9.5889463e-06, 7.8214543e-07, 1.0063847e-06,
       1.4711081e-07, 8.9995199e-05, 1.3229882e-05, 1.7524673e-07,
       1.8464769e-05, 2.4248957e-06, 1.2176755e-06, 3.3774859e-07,
       1.3199920e-06, 1.2011193e-06, 9.4455345e-06, 2.1991875e-05,
       1.6732251e-02, 1.3964747e-07, 2.0764594e-07, 7.0467541e-07,
       1.4303426e-07, 3.7962508e-07, 1.2130551e-03, 3.1479198e-07,
       4.8646534e-08, 6.1310317e-07, 1.0607551e-05, 3.7493783e-06,
       2.7809198e-08, 1.2118652e-05, 9.9081490e-03, 1.8219554e-06,
       4.7322575e-07, 1.8754436e-05, 6.2416703e-08, 9.5453437e-08,
       1.8937490e-05, 6.3916352e-03, 3.7999314e-01, 1.5741379e-03,
       5.8360571e-01, 9.6441705e-05, 1.7456010e-04, 5.1820080e-06,
       1.2672864e-06, 9.7453121e-06, 2.4000105e-05, 5.1192428e-06,
       2.4821245e-05], dtype=float32)
>>> r = [*enumerate(tagger.model.predict([doc])[0][1])]
>>> r.sort(key=lambda x: x[1])
>>> r
[(28, 2.7809198e-08), (24, 4.8646534e-08), (34, 6.24167e-08), (35, 9.545344e-08), (17, 1.3964747e-07), (20, 1.4303426e-07), (4, 1.4711081e-07), (7, 1.7524673e-07), (18, 2.0764594e-07), (0, 2.470609e-07), (23, 3.1479198e-07), (11, 3.377486e-07), (21, 3.7962508e-07), (32, 4.7322575e-07), (25, 6.1310317e-07), (19, 7.046754e-07), (2, 7.8214543e-07), (3, 1.0063847e-06), (13, 1.2011193e-06), (10, 1.2176755e-06), (44, 1.2672864e-06), (12, 1.319992e-06), (31, 1.8219554e-06), (9, 2.4248957e-06), (27, 3.7493783e-06), (47, 5.119243e-06), (43, 5.182008e-06), (14, 9.4455345e-06), (1, 9.588946e-06), (45, 9.745312e-06), (26, 1.0607551e-05), (29, 1.2118652e-05), (6, 1.3229882e-05), (8, 1.8464769e-05), (33, 1.8754436e-05), (36, 1.893749e-05), (15, 2.1991875e-05), (46, 2.4000105e-05), (48, 2.4821245e-05), (5, 8.99952e-05), (41, 9.6441705e-05), (42, 0.0001745601), (22, 0.0012130551), (39, 0.001574138), (37, 0.006391635), (30, 0.009908149), (16, 0.016732251), (38, 0.37999314), (40, 0.5836057)]
Run Code Online (Sandbox Code Playgroud)

您可以在此处看到前 2 个匹配项(在列表末尾)(38, 0.37999314)、(40, 0.5836057) 的置信度不高 (~50%),因此您可以看到一些模糊性的迹象。

>>> tagger.labels
('$', "''", ',', '-LRB-', '-RRB-', '.', ':', 'ADD', 'AFX', 'CC', 'CD', 'DT', 'EX', 'FW', 'HYPH', 'IN', 'JJ', 'JJR', 'JJS', 'LS', 'MD', 'NFP', 'NN', 'NNP', 'NNPS', 'NNS', 'PDT', 'POS', 'PRP', 'PRP$', 'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP$', 'WRB', 'XX', '``')
>>> tagger.labels[40]
'VBN'
>>> tagger.labels[38]
'VBD'

Run Code Online (Sandbox Code Playgroud)

看起来有一些特定于语言的标签,并且需要一些映射才能获得通用 POS 标签。