Dav*_* S. 7 python neural-network keras tensorflow
是否有可能将向量传递到经过训练的神经网络,以便它仅从经过训练的识别类的子集中进行选择。例如,我的网络经过训练可以识别数字和字母,但是我知道接下来要运行的图像将不包含小写字母(例如序列号的图像)。然后我给它传递一个向量,告诉它不要猜测任何小写字母。由于类是互斥的,因此网络以softmax函数结尾。以下只是我想尝试的示例,但没有一个真正起作用。
import numpy as np
def softmax(arr):
return np.exp(arr)/np.exp(arr).sum()
#Stand ins for previous layer/NN output and vector of allowed answers.
output = np.array([ 0.15885351,0.94527385,0.33977026,-0.27237907,0.32012873,
0.44839673,-0.52375875,-0.99423903,-0.06391236,0.82529586])
restrictions = np.array([1,1,0,0,1,1,1,0,1,1])
#Ideas -----
'''First: Multilpy restricted before sending it through softmax.
I stupidly tried this one.'''
results = softmax(output*restrictions)
'''Second: Multiply the results of the softmax by the restrictions.'''
results = softmax(output)
results = results*restrictions
'''Third: Remove invalid entries before calculating the softmax.'''
result = output*restrictions
result[result != 0] = softmax(result[result != 0])
Run Code Online (Sandbox Code Playgroud)
所有这些都有问题。第一个导致无效选择默认为:
1/np.exp(arr).sum()
Run Code Online (Sandbox Code Playgroud)
由于softmax的输入可能为负,因此会增加赋予无效选择的可能性,并使答案更糟。(在尝试之前,应该已经调查过。)
第二个和第三个都有类似的问题,因为它们一直等到给出应用限制的答案之前。例如,如果网络正在查看字母l,但它开始确定它是数字1,则直到使用这些方法结束时才对其进行更正。因此,如果采用这种方式,则以.80的概率给出1的输出,但是随后删除了此选项,则似乎剩余的选项将重新分配,并且最高的有效答案不会像80%那样可信。剩下的选项最终要统一得多。我想说的一个例子:
output
Out[75]: array([ 5.39413513, 3.81445419, 3.75369546, 1.02716988, 0.39189373])
softmax(output)
Out[76]: array([ 0.70454877, 0.14516581, 0.13660832, 0.00894051, 0.00473658])
softmax(output[1:])
Out[77]: array([ 0.49133596, 0.46237183, 0.03026052, 0.01603169])
Run Code Online (Sandbox Code Playgroud)
(对数组进行排序使其更容易。)在原始输出中,softmax给出.70的答案是[1,0,0,0,0],但是如果这是无效的答案,则删除了重新分配如何分配剩余的4个值可能性低于50%的期权,由于太低而无法使用,因此很容易被忽略。
我已经考虑过将向量作为另一种输入传递到网络中,但是我不确定如何在不要求它学习向量告诉它做什么的情况下进行此操作,我认为这会增加训练所需的时间。
编辑:我在评论中写的太多了,所以我只在这里发布更新。最终,我确实尝试将这些限制作为网络的输入。我选择了一个热编码答案,并随机添加了额外的启用类以模拟答案键,并确保正确答案始终在键中。当密钥只有很少的启用类别时,网络会严重依赖它,并且会干扰从图像中学习功能。当密钥具有很多启用的类别时,它似乎完全忽略了密钥。这可能是一个需要优化的问题,或者是我的网络体系结构出现了问题,或者只是需要进行一些调整才能进行培训,但我始终没有解决该问题的方法。
我确实发现,当我最终减去np.inf
而不是乘以0 时,删除答案和清零几乎是相同的。我知道合奏,但是正如在对第一个响应的评论中提到的那样,我的网络正在处理CJK字符(字母只是为了使示例更容易),并拥有3000多个课程。网络已经非常庞大,这就是为什么我想研究这种方法的原因。我没想到要为每个单独的类别使用二进制网络,但是3000多个网络也存在问题(如果我理解正确的话),尽管我稍后会进行研究。
首先,我将大致浏览您列出的可用选项,并添加一些可行的替代方案及其优缺点。构建这个答案有点困难,但我希望你能明白我想要表达的内容:
显然,正如您所写的那样,可能会为清零条目提供更高的机会,一开始似乎是一种错误的方法。
替代方案:用 logit 值替换不可能的值smallest
。这与 类似softmax(output[1:])
,尽管网络对结果更加不确定。实施示例pytorch
:
import torch
logits = torch.Tensor([5.39413513, 3.81445419, 3.75369546, 1.02716988, 0.39189373])
minimum, _ = torch.min(logits, dim=0)
logits[0] = minimum
print(torch.nn.functional.softmax(logits))
Run Code Online (Sandbox Code Playgroud)
产生:
tensor([0.0158, 0.4836, 0.4551, 0.0298, 0.0158])
Run Code Online (Sandbox Code Playgroud)
是的,当你这样做时,你是对的。更重要的是,此类的实际概率实际上要低得多,约为14%
( tensor([0.7045, 0.1452, 0.1366, 0.0089, 0.0047])
)。通过手动更改输出,您实际上破坏了该神经网络已学习的属性(及其输出分布),从而使某些部分的计算变得毫无意义。这指向了这次赏金中提到的另一个问题:
我可以想象这个问题可以通过多种方式解决:
argmax
创建多个神经网络,并通过对最后的logits 求和(或者softmax
然后是“argmax”)来将它们集成起来。具有不同预测的3 种不同模型的假设情况:
import torch
predicted_logits_1 = torch.Tensor([5.39413513, 3.81419, 3.7546, 1.02716988, 0.39189373])
predicted_logits_2 = torch.Tensor([3.357895, 4.0165, 4.569546, 0.02716988, -0.189373])
predicted_logits_3 = torch.Tensor([2.989513, 5.814459, 3.55369546, 3.06988, -5.89473])
combined_logits = predicted_logits_1 + predicted_logits_2 + predicted_logits_3
print(combined_logits)
print(torch.nn.functional.softmax(combined_logits))
Run Code Online (Sandbox Code Playgroud)
这将为我们提供以下概率softmax
:
[0.11291057 0.7576356 0.1293983 0.00005554 0.]
(注意第一类现在是最有可能的)
您可以使用引导聚合和其他集成技术来改进预测。这种方法使分类决策表面更加平滑,并修复了分类器之间的相互错误(假设它们的预测差异很大)。需要很多帖子才能更详细地描述(或者需要针对特定问题的单独问题),这里或这里有一些可能会帮助您开始。
不过,我不会将这种方法与手动选择输出混合起来。
这种方法可能会产生更好的推理时间,如果您可以将其分布在多个 GPU 上,甚至可能会获得更好的训练时间。
基本上,您的每个班级都可以存在 ( 1
) 或缺席 ( 0
)。原则上,您可以N
为类训练神经网络N
,每个类输出一个无界数字 (logit)。这个数字告诉网络是否认为这个例子应该被分类为它的类别。
如果您确定某个类别不会成为结果,请确保您不运行负责此类检测的网络。从所有网络(或网络子集)获得预测后,您选择最高值(如果使用sigmoid
激活,则选择最高概率,尽管这在计算上会造成浪费)。
额外的好处是所述网络的简单性(更容易训练和微调)和简单的switch-like
行为(如果需要)。
如果我是你,我会采用2.2中概述的方法,因为你可以轻松地节省一些推理时间,并允许你以合理的方式“选择输出”。
如果这种方法还不够,您可以考虑N
网络集成,因此混合使用2.2和2.1、一些引导程序或其他集成技术。这也会提高你的准确性。
归档时间: |
|
查看次数: |
480 次 |
最近记录: |