she*_*rry 4 python filter mnist deep-learning keras
我目前正在使用Keras在MNIST数据集上训练前馈神经网络。我正在使用以下格式加载数据集
(X_train, Y_train), (X_test, Y_test) = mnist.load_data()
但是我只想用数字0和4训练我的模型,而不是全部。如何只选择2位数字?我对python相当陌生,可以弄清楚如何过滤mnist数据集...
Y_train
并Y_test
为您提供图像标签,您可以将其与numpy.where
一起使用,以滤除带有0和4的部分标签。您所有的变量都是numpy数组,因此您可以轻松完成;
import numpy as np
train_filter = np.where((Y_train == 0 ) | (Y_train == 4))
test_filter = np.where((Y_test == 0) | (Y_test == 4))
Run Code Online (Sandbox Code Playgroud)
您可以使用这些过滤器按索引获取数组的子集。
X_train, Y_train = X_train[train_filter], Y_train[train_filter]
X_test, Y_test = X_test[test_filter], Y_test[test_filter]
Run Code Online (Sandbox Code Playgroud)
如果您对两个以上的标签感兴趣,则语法可能会因为where和or而变得冗长。因此,您还可以numpy.isin
用来创建蒙版。
train_mask = np.isin(Y_train, [0, 4])
test_mask = np.isin(Y_test, [0, 4])
Run Code Online (Sandbox Code Playgroud)
您可以像以前一样使用这些掩码进行布尔索引。
您有标签文件以及训练和测试:
train_images = mnist.train_images()
train_labels = mnist.train_labels()
test_images = mnist.test_images()
test_labels = mnist.test_labels()
Run Code Online (Sandbox Code Playgroud)
您可以将它们与简单的列表理解一起使用来过滤数据集
zero_four_test = [test_images[key] for (key, label) in enumerate(test_labels) if int(label) == 0 or int(label) == 4]
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
3201 次 |
最近记录: |