Import modules

In [1]:
import os
import struct
from array import array

import numpy as np
import matplotlib.pyplot as plt

Function definitions

In [2]:
def _read_idx_images(fname):
    """Read an IDX3 image file and return its pixels as a (count, rows*cols) uint8 array.

    Raises ValueError if the header's magic number is not 2051 (0x00000803:
    unsigned-byte data, 3 dimensions) or if the file holds fewer pixel bytes
    than the header announces.
    """
    with open(fname, 'rb') as fileobject:
        magic_nr, count, rows, cols = struct.unpack(">IIII", fileobject.read(16))
        if magic_nr != 2051:
            raise ValueError("{}: bad magic number {} for an IDX3 image file".format(fname, magic_nr))
        nbytes = count * rows * cols
        pixels = np.frombuffer(fileobject.read(nbytes), dtype=np.uint8)
    if pixels.size != nbytes:
        raise ValueError("{}: expected {} pixel bytes, found {}".format(fname, nbytes, pixels.size))
    # One flattened image per row
    return pixels.reshape(count, rows * cols)


def _read_idx_labels(fname):
    """Read an IDX1 label file and return its labels as a 1-D uint8 array.

    Raises ValueError if the header's magic number is not 2049 (0x00000801:
    unsigned-byte data, 1 dimension) or if the file is truncated.
    """
    with open(fname, 'rb') as fileobject:
        magic_nr, count = struct.unpack(">II", fileobject.read(8))
        if magic_nr != 2049:
            raise ValueError("{}: bad magic number {} for an IDX1 label file".format(fname, magic_nr))
        labels = np.frombuffer(fileobject.read(count), dtype=np.uint8)
    if labels.size != count:
        raise ValueError("{}: expected {} labels, found {}".format(fname, count, labels.size))
    return labels


def load_nmnist(dataset="training", selecteddigits=range(10), path=r'Directory\containing\your\downloaded\files'):
    """Load MNIST images and labels from IDX files, keeping only selected digits.

    Parameters
    ----------
    dataset : str
        "training" reads train-images.idx3-ubyte / train-labels.idx1-ubyte;
        "testing" reads t10k-images.idx3-ubyte / t10k-labels.idx1-ubyte.
    selecteddigits : iterable of int
        Labels to keep; images whose label is not in this collection are dropped.
    path : str
        Directory containing the downloaded files.

    Returns
    -------
    X : ndarray of uint8, shape (N, rows*cols)
        One flattened image per row, in the order they appear in the file.
    T : ndarray of uint8, shape (N,)
        The label of each row of X.

    Raises
    ------
    ValueError
        If `dataset` is not "training" or "testing", a file has the wrong
        magic number or is truncated, or the image and label counts differ.
    """
    # Check training/testing specification. Must be "training" (default) or "testing"
    if dataset == "training":
        prefix = 'train'
    elif dataset == "testing":
        prefix = 't10k'
    else:
        raise ValueError("dataset must be 'testing' or 'training'")

    # os.path.join uses the platform's separator, so this also works outside Windows
    images = _read_idx_images(os.path.join(path, prefix + '-images.idx3-ubyte'))
    labels = _read_idx_labels(os.path.join(path, prefix + '-labels.idx1-ubyte'))
    if len(images) != len(labels):
        raise ValueError("image file has {} entries but label file has {}".format(len(images), len(labels)))

    # Boolean mask of the rows whose label was requested; list() lets any
    # iterable (range, set, list) be used as the selection.
    keep = np.isin(labels, list(selecteddigits))

    # Boolean indexing returns new, writable arrays in file order
    X = images[keep]
    T = labels[keep]
    return X, T
In [3]:
def vectortoimg(v, show=True, shape=(28, 28)):
    """Draw a flattened image vector as a grayscale picture on the current axes.

    Parameters
    ----------
    v : array_like
        Pixel values of one image, e.g. a row of X returned by load_nmnist.
        Its size must equal shape[0] * shape[1].
    show : bool
        If True, call plt.show() after drawing. Pass False when composing
        several images into one figure.
    shape : tuple of int
        (rows, cols) used to reshape v. Defaults to 28x28, the digit size
        used throughout this notebook.

    Returns None, so a bare call as the last line of a cell shows no repr.
    """
    plt.imshow(np.asarray(v).reshape(shape), interpolation='none', cmap='gray')
    plt.axis('off')  # axis ticks carry no meaning for a digit picture
    if show:
        plt.show()

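Load data

The loader expects the MNIST IDX files (train-images.idx3-ubyte, train-labels.idx1-ubyte, t10k-images.idx3-ubyte, t10k-labels.idx1-ubyte) to be downloaded into one local directory. Point the loader's path at that directory before running.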

In [4]:
# Directory containing the downloaded MNIST IDX files. Set it here rather than
# editing the default argument of load_nmnist.
DATA_DIR = r'Directory\containing\your\downloaded\files'
X, T = load_nmnist(dataset="training", selecteddigits=range(10), path=DATA_DIR)
In [5]:
# OTHER WAYS OF CALLING load_nmnist (uncomment one to use it instead of the call above)
# X, T = load_nmnist(dataset="training",selecteddigits=[5,6])  # training images labelled 5 or 6 only
# X, T = load_nmnist() #Loads ALL digits of training data
# X, T = load_nmnist(dataset="testing",selecteddigits=[1,2,7])  # testing images labelled 1, 2 or 7 only
In [6]:
# Sanity checks on the loaded data. The values are converted to plain Python
# ints so the printed text does not depend on NumPy's scalar repr (NumPy >= 2
# would otherwise print e.g. np.uint8(0) instead of 0).
assert X.shape[0] == T.shape[0], "X and T must have one label per image row"
print("Checking shape of matrix:", X.shape)
print("Checking min/max values:", (int(np.amin(X)), int(np.amax(X))))
print("Checking unique labels in T:", np.unique(T).tolist())
Checking shape of matrix: (60000, 784)
Checking min/max values: (0, 255)
Checking unique labels in T: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

Display images

In [7]:
# Render one row of X (row 20000) as a grayscale image and display it
row_pixels = X[20000]
vectortoimg(row_pixels, show=True)
In [8]:
# Label corresponding to the above image. int() turns the NumPy uint8 into a
# plain Python int so the cell output is a bare number on every NumPy version.
int(T[20000])
Out[8]:
3
In [9]:
# Show a grid of randomly selected rows of X as images. Be patient: drawing
# 100 subplots takes a while. A fixed seed makes the selection reproducible
# on Restart & Run All. Images are drawn with replacement, so repeats are possible.
SEED = 0
nrows = 10
ncols = 10

plt.close('all')
rng = np.random.default_rng(SEED)
picks = rng.integers(len(T), size=nrows * ncols)

# squeeze=False keeps `axes` a 2-D array even for a 1x1 grid
fig, axes = plt.subplots(nrows, ncols, figsize=(15, 15), squeeze=False)
for ax, idx in zip(axes.flat, picks):
    plt.sca(ax)  # vectortoimg draws on the current axes
    vectortoimg(X[idx], show=False)
plt.show()