ndarray 的条件过滤

ari*_*wan 5 python arrays numpy python-3.x numpy-ndarray

假设我有以下数组数组:

Input = np.array([[[[17.63,  0.  , -0.71, 29.03],
         [17.63, -0.09,  0.71, 56.12],
         [ 0.17,  1.24, -2.04, 18.49],
         [ 1.41, -0.8 ,  0.51, 11.85],
         [ 0.61, -0.29,  0.15, 36.75]]],


       [[[ 0.32, -0.14,  0.39, 24.52],
         [ 0.18,  0.25, -0.38, 18.08],
         [ 0.  ,  0.  ,  0.  ,  0.  ],
         [ 0.  ,  0.  ,  0.  ,  0.  ],
         [ 0.43,  0.  ,  0.3 ,  0.  ]]],


       [[[ 0.75, -0.38,  0.65, 19.51],
         [ 0.37,  0.27,  0.52, 24.27],
         [ 0.  ,  0.  ,  0.  ,  0.  ],
         [ 0.  ,  0.  ,  0.  ,  0.  ],
         [ 0.  ,  0.  ,  0.  ,  0.  ]]]])

Input.shape
(3, 1, 5, 4)
Run Code Online (Sandbox Code Playgroud)

与此Input数组一起是Label 所有输入的对应数组,因此:

Label = np.array([0, 1, 2])

Label.shape
(3,)
Run Code Online (Sandbox Code Playgroud)

我需要某种方法来检查 的所有嵌套数组Input,以仅选择具有足够数据点的数组。

我的意思是我想要一种方法来消除(或者我应该说删除)所有最后 3 行的条目都是零的数组。在这样做的同时,消除该Label数组的对应项。

预期输出:

Input_filtered
array([[[[17.63,  0.  , -0.71, 29.03],
         [17.63, -0.09,  0.71, 56.12],
         [ 0.17,  1.24, -2.04, 18.49],
         [ 1.41, -0.8 ,  0.51, 11.85],
         [ 0.61, -0.29,  0.15, 36.75]]],


       [[[ 0.32, -0.14,  0.39, 24.52],
         [ 0.18,  0.25, -0.38, 18.08],
         [ 0.  ,  0.  ,  0.  ,  0.  ],
         [ 0.  ,  0.  ,  0.  ,  0.  ],
         [ 0.43,  0.  ,  0.3 ,  0.  ]]]])

Label_filtered
array([0, 1])
Run Code Online (Sandbox Code Playgroud)

我需要什么技巧?