Jos*_*hua 10 python machine-learning mnist deep-learning tensorflow
我从LeCun网站下载了MNIST数据集.我想要的是编写Python代码以提取gzip并直接从目录中读取数据集,这意味着我不再需要下载或访问MNIST站点.
欲望过程: 访问文件夹/目录 - >提取gzip - >读取数据集(一个热编码)
怎么做?由于几乎所有教程都必须访问LeCun或Tensoflow站点才能下载和读取数据集.提前致谢!
这个tensorflow调用
from tensorflow.examples.tutorials.mnist import input_data
input_data.read_data_sets('my/directory')
Run Code Online (Sandbox Code Playgroud)
... 如果您已经有文件,将不会下载任何内容.
但如果由于某种原因你希望自己解压缩,那么你就是这样做的:
from tensorflow.contrib.learn.python.learn.datasets.mnist import extract_images, extract_labels
with open('my/directory/train-images-idx3-ubyte.gz', 'rb') as f:
train_images = extract_images(f)
with open('my/directory/train-labels-idx1-ubyte.gz', 'rb') as f:
train_labels = extract_labels(f)
with open('my/directory/t10k-images-idx3-ubyte.gz', 'rb') as f:
test_images = extract_images(f)
with open('my/directory/t10k-labels-idx1-ubyte.gz', 'rb') as f:
test_labels = extract_labels(f)
Run Code Online (Sandbox Code Playgroud)
如果您提取了MNIST数据,则可以使用NumPy直接将其低级加载:
def loadMNIST( prefix, folder ):
intType = np.dtype( 'int32' ).newbyteorder( '>' )
nMetaDataBytes = 4 * intType.itemsize
data = np.fromfile( folder + "/" + prefix + '-images-idx3-ubyte', dtype = 'ubyte' )
magicBytes, nImages, width, height = np.frombuffer( data[:nMetaDataBytes].tobytes(), intType )
data = data[nMetaDataBytes:].astype( dtype = 'float32' ).reshape( [ nImages, width, height ] )
labels = np.fromfile( folder + "/" + prefix + '-labels-idx1-ubyte',
dtype = 'ubyte' )[2 * intType.itemsize:]
return data, labels
trainingImages, trainingLabels = loadMNIST( "train", "../datasets/mnist/" )
testImages, testLabels = loadMNIST( "t10k", "../datasets/mnist/" )
Run Code Online (Sandbox Code Playgroud)
并转换为热编码:
def toHotEncoding( classification ):
# emulates the functionality of tf.keras.utils.to_categorical( y )
hotEncoding = np.zeros( [ len( classification ),
np.max( classification ) + 1 ] )
hotEncoding[ np.arange( len( hotEncoding ) ), classification ] = 1
return hotEncoding
trainingLabels = toHotEncoding( trainingLabels )
testLabels = toHotEncoding( testLabels )
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
22553 次 |
| 最近记录: |