import struct
import numpy as np
from array import array
import matplotlib.pyplot as plt
def load_nmnist(dataset="training", selecteddigits=range(10), path=r'Directory\containing\your\downloaded\files'):
    """Load the images and labels of the selected digits from IDX files.

    The image file starts with a 16-byte big-endian header of four unsigned
    32-bit integers (magic number, image count, rows, cols) followed by one
    unsigned byte per pixel. The label file starts with an 8-byte header
    (magic number, label count) followed by one unsigned byte per label.

    Parameters
    ----------
    dataset : str
        "training" (reads 'train-images.idx3-ubyte' / 'train-labels.idx1-ubyte')
        or "testing" (reads 't10k-images.idx3-ubyte' / 't10k-labels.idx1-ubyte').
    selecteddigits : iterable of int
        Label values to keep; samples with any other label are dropped.
    path : str
        Directory containing the four files named above.

    Returns
    -------
    X : numpy.ndarray, dtype uint8, shape (N, rows*cols)
        One flattened image per row, in file order.
    T : numpy.ndarray, dtype uint8, shape (N,)
        Label of the corresponding row of X.

    Raises
    ------
    ValueError
        If `dataset` is not "training" or "testing", or if either file holds
        less data than the label file's header announces.
    """
    import os  # function-scope import so the block does not rely on file-level changes

    # Check training/testing specification. Must be "training" (default) or "testing"
    if dataset == "training":
        digits_name = 'train-images.idx3-ubyte'
        labels_name = 'train-labels.idx1-ubyte'
    elif dataset == "testing":
        digits_name = 't10k-images.idx3-ubyte'
        labels_name = 't10k-labels.idx1-ubyte'
    else:
        raise ValueError("dataset must be 'testing' or 'training'")

    # os.path.join keeps the path portable (the hard-coded '\\' only worked on
    # Windows); `with` guarantees the handles are closed even if parsing fails.
    with open(os.path.join(path, digits_name), 'rb') as digits_file:
        _magic, _n_images, rows, cols = struct.unpack(">IIII", digits_file.read(16))
        pixels = np.frombuffer(digits_file.read(), dtype=np.uint8)

    with open(os.path.join(path, labels_name), 'rb') as labels_file:
        _magic, size = struct.unpack(">II", labels_file.read(8))
        labels = np.frombuffer(labels_file.read(), dtype=np.uint8)

    # The label header's count drives the number of samples considered.
    pixels_per_image = rows * cols
    if labels.size < size:
        raise ValueError(
            "label file is truncated: header announces %d labels, found %d"
            % (size, labels.size)
        )
    if pixels.size < size * pixels_per_image:
        raise ValueError(
            "image file is truncated: need %d pixel bytes, found %d"
            % (size * pixels_per_image, pixels.size)
        )

    labels = labels[:size]
    images = pixels[:size * pixels_per_image].reshape(size, pixels_per_image)

    # Boolean-mask indexing copies the selected rows, so X and T are fresh,
    # writable arrays independent of the read-only file buffers.
    selected = np.isin(labels, list(selecteddigits))
    X = images[selected]
    T = labels[selected]
    return X, T
def vectortoimg(v, show=True, shape=(28, 28)):
    """Draw a flattened image vector as a grayscale image on the current axes.

    Parameters
    ----------
    v : array-like with a ``reshape`` method
        Flattened pixel values; must contain exactly ``shape[0] * shape[1]``
        elements.
    show : bool
        When True, call ``plt.show()`` immediately. Pass False to draw into
        the current subplot and show the whole figure later.
    shape : tuple of int
        (rows, cols) used to reshape `v`. Defaults to (28, 28), the size
        this function previously hard-coded.
    """
    plt.imshow(v.reshape(shape), interpolation='None', cmap='gray')
    plt.axis('off')
    if show:
        plt.show()
# Demo script: load the training set, print a few sanity checks, then display
# one chosen sample and a grid of randomly chosen samples.
X, T = load_nmnist(dataset="training", selecteddigits=range(10))
# OTHER WAYS OF CALLING load_nmnist
# X, T = load_nmnist(dataset="training",selecteddigits=[5,6])
# X, T = load_nmnist() #Loads ALL digits of training data
# X, T = load_nmnist(dataset="testing",selecteddigits=[1,2,7])
print("Checking shape of matrix:", X.shape)
print("Checking min/max values:", (np.amin(X), np.amax(X)))
print("Checking unique labels in T:", list(np.unique(T)))

# Showing a selected row of X as an image
sample_index = 20000
vectortoimg(X[sample_index])
# Label corresponding to the above image. A bare `T[20000]` expression only
# echoes in a REPL/notebook; print it so the script actually reports it.
print("Label of the displayed image:", T[sample_index])

# Show multiple randomly selected rows of X as images. Be patient
plt.close('all')
fig = plt.figure(figsize=(15, 15))
nrows = 10
ncols = 10
for cell in range(nrows * ncols):
    # Subplot positions are 1-based and run row by row.
    plt.subplot(nrows, ncols, cell + 1)
    vectortoimg(X[np.random.randint(len(T))], show=False)
plt.show()