Building a simple neural-network for digit recognition with Keras

Original author: Xavier Snelgrove
Additions and modifications: Samy Zafrany

This notebook began as a simple quick-start on digit recognition with a neural network in Keras, written for a short tutorial at the University of Toronto. Further elaborations and modifications were added later by Samy Zafrany. It is largely based on the mnist_mlp.py example from the Keras source.
Official MNIST site: http://yann.lecun.com/exdb/mnist/

In [1]:
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation
from keras.layers import Convolution2D, MaxPooling2D, Flatten
from keras.utils import np_utils
from keras.layers.advanced_activations import SReLU
from kerutils import *   # private library
import numpy as np   # used later for np.nonzero
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (7,7) # Make the figures a bit larger
%matplotlib inline
Using Theano backend.
Using gpu device 0: GeForce GTX 950 (CNMeM is enabled with initial size: 80.0% of memory, cuDNN 5103)

Download Prerequisites

To run the code in this notebook, you'll need to download the following course modules which we use in all study units of this course:

  1. http://www.samyzaf.com/cgi-bin/view_file.py?file=ML/lib/kerutils.py
  2. http://www.samyzaf.com/cgi-bin/view_file.py?file=ML/lib/dlutils.py
  3. http://www.samyzaf.com/ML/style-notebook.css

Or you can download everything in one zip file from the github repository: https://github.com/samyzaf/kerutils

First, let's load a good-looking notebook CSS style:

In [1]:
# These are css/html styles for good looking ipython notebooks
from IPython.core.display import HTML
css = read_file('style-notebook.css')
HTML('<style>%s</style>' % css)
Out[1]:

Installing and configuring your graphics card for Theano can be a bit daunting, but it is not essential for running this notebook (it will simply run more slowly on your CPU).

Load training and validation data

In [3]:
nb_classes = 10

# the data, shuffled and split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
print("x_train original shape", x_train.shape)
print("y_train original shape", y_train.shape)
('x_train original shape', (60000L, 28L, 28L))
('y_train original shape', (60000L,))
In [4]:
# What is the type of x_train ?

type(x_train)
Out[4]:
numpy.ndarray

OK, this is a standard NumPy array, which is what Keras expects.

In [5]:
# How big is x_train ?

len(x_train)
Out[5]:
60000

So our training data set consists of 60 thousand images of handwritten digits (each image in a 28x28 pixmap form). The y_train vector maps each image to its proper class: 0, 1, 2, ..., 9.

In [6]:
# View of the first 20 classes in y_train

print(y_train[0:20])
[5 0 4 1 9 2 1 3 1 4 3 5 3 6 1 7 2 8 6 9]
In [7]:
print(len(x_test), len(y_test))
10000 10000
In [8]:
print("test0 = %s\n class = %s" % (x_test[0], y_test[0]))
test0 = [[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0  84 185 159 151  60  36   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0 222 254 254 254 254 241 198 198 198 198 198 198
  198 198 170  52   0   0   0   0   0   0]
 [  0   0   0   0   0   0  67 114  72 114 163 227 254 225 254 254 254 250
  229 254 254 140   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0  17  66  14  67  67  67  59
   21 236 254 106   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   83 253 209  18   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0  22
  233 255  83   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 129
  254 238  44   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0  59 249
  254  62   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 133 254
  187   5   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   9 205 248
   58   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 126 254 182
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0  75 251 240  57
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0  19 221 254 166   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   3 203 254 219  35   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0  38 254 254  77   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0  31 224 254 115   1   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0 133 254 254  52   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0  61 242 254 254  52   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0 121 254 254 219  40   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0 121 254 207  18   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]]
 class = 7

Looking at a 28x28 matrix is not a pleasant experience, but if we draw it as an image, it is much easier on the eye. Here is some matplotlib code for drawing this matrix as a 28x28 image.

In [9]:
plt.imshow(x_test[0], cmap='gray', interpolation='none')
plt.title("Class {}".format(y_test[0]))
Out[9]:
<matplotlib.text.Text at 0x644626d8>

Let's look at some more examples of the training data. We will use the same code in a loop for displaying 15 images in a 3x5 grid:

In [10]:
n = 835 # just a random pick
for i in range(0, 15):
    plt.subplot(3,5,i+1)
    plt.imshow(x_train[n+i], cmap='gray', interpolation='none')
    plt.title("Class {}".format(y_train[n+i]), fontsize=11)
    plt.tick_params(axis='both', which='major', labelsize=6)
    plt.subplots_adjust(hspace=0.65, wspace=0.5)

Format the data for training

Our neural network is going to take a single vector for each training example, so we need to reshape the input so that each 28x28 image becomes a single 784-dimensional vector. We'll also scale the inputs to the range [0, 1] rather than [0, 255].

In [6]:
X_train = x_train.reshape(60000, 784)
X_test = x_test.reshape(10000, 784)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255
print("Training matrix shape", X_train.shape)
print("Testing matrix shape", X_test.shape)
Training matrix shape (60000L, 784L)
Testing matrix shape (10000L, 784L)
In [12]:
#Example of one training vector
X_train[0]
Out[12]:
array([ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.01176471,  0.07058824,  0.07058824,
        0.07058824,  0.49411765,  0.53333336,  0.68627453,  0.10196079,
        0.65098041,  1.        ,  0.96862745,  0.49803922,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.11764706,  0.14117648,  0.36862746,  0.60392159,
        0.66666669,  0.99215686,  0.99215686,  0.99215686,  0.99215686,
        0.99215686,  0.88235295,  0.67450982,  0.99215686,  0.94901961,
        0.7647059 ,  0.25098041,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.19215687,  0.93333334,
        0.99215686,  0.99215686,  0.99215686,  0.99215686,  0.99215686,
        0.99215686,  0.99215686,  0.99215686,  0.98431373,  0.36470589,
        0.32156864,  0.32156864,  0.21960784,  0.15294118,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.07058824,  0.85882354,  0.99215686,  0.99215686,
        0.99215686,  0.99215686,  0.99215686,  0.7764706 ,  0.71372551,
        0.96862745,  0.94509804,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.3137255 ,  0.61176473,  0.41960785,  0.99215686,  0.99215686,
        0.80392158,  0.04313726,  0.        ,  0.16862746,  0.60392159,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.05490196,
        0.00392157,  0.60392159,  0.99215686,  0.35294119,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.54509807,
        0.99215686,  0.74509805,  0.00784314,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.04313726,  0.74509805,  0.99215686,
        0.27450982,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.13725491,  0.94509804,  0.88235295,  0.627451  ,
        0.42352942,  0.00392157,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.31764707,  0.94117647,  0.99215686,  0.99215686,  0.46666667,
        0.09803922,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.17647059,
        0.72941178,  0.99215686,  0.99215686,  0.58823532,  0.10588235,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.0627451 ,  0.36470589,
        0.98823529,  0.99215686,  0.73333335,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.97647059,  0.99215686,
        0.97647059,  0.25098041,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.18039216,  0.50980395,
        0.71764708,  0.99215686,  0.99215686,  0.81176472,  0.00784314,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.15294118,
        0.58039218,  0.89803922,  0.99215686,  0.99215686,  0.99215686,
        0.98039216,  0.71372551,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.09411765,  0.44705883,  0.86666667,  0.99215686,  0.99215686,
        0.99215686,  0.99215686,  0.78823531,  0.30588236,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.09019608,  0.25882354,  0.83529413,  0.99215686,
        0.99215686,  0.99215686,  0.99215686,  0.7764706 ,  0.31764707,
        0.00784314,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.07058824,  0.67058825,  0.85882354,
        0.99215686,  0.99215686,  0.99215686,  0.99215686,  0.7647059 ,
        0.3137255 ,  0.03529412,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.21568628,  0.67450982,
        0.88627452,  0.99215686,  0.99215686,  0.99215686,  0.99215686,
        0.95686275,  0.52156866,  0.04313726,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.53333336,  0.99215686,  0.99215686,  0.99215686,
        0.83137256,  0.52941179,  0.51764709,  0.0627451 ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ], dtype=float32)

Modify the target matrices to be in the one-hot format, i.e.

0 → [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
1 → [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]
2 → [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
3 → [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
etc...
In [7]:
Y_train = np_utils.to_categorical(y_train, nb_classes)
Y_test = np_utils.to_categorical(y_test, nb_classes)
In [8]:
# Let's look at the first 19 elements of Y_train

print(Y_train[0:19])
[[ 0.  0.  0.  0.  0.  1.  0.  0.  0.  0.]
 [ 1.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  1.  0.  0.  0.  0.  0.]
 [ 0.  1.  0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  1.]
 [ 0.  0.  1.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  1.  0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  1.  0.  0.  0.  0.  0.  0.]
 [ 0.  1.  0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  1.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  1.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  1.  0.  0.  0.  0.]
 [ 0.  0.  0.  1.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  1.  0.  0.  0.]
 [ 0.  1.  0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  1.  0.  0.]
 [ 0.  0.  1.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  1.  0.]
 [ 0.  0.  0.  0.  0.  0.  1.  0.  0.  0.]]
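
To make the one-hot format concrete, here is a minimal NumPy sketch of the same mapping. The helper name one_hot is hypothetical and only for illustration; it is not the Keras implementation of to_categorical.

```python
import numpy as np

def one_hot(labels, num_classes):
    """Return a (len(labels), num_classes) float array with a single 1 per row."""
    labels = np.asarray(labels, dtype=int)
    encoded = np.zeros((labels.size, num_classes), dtype="float32")
    # Fancy indexing: in row i, set column labels[i] to 1
    encoded[np.arange(labels.size), labels] = 1.0
    return encoded

first_labels = [5, 0, 4, 1, 9]   # first five entries of y_train shown earlier
encoded = one_hot(first_labels, 10)
print(encoded)

# Taking the argmax of each row recovers the original class labels
decoded = encoded.argmax(axis=1)
print(decoded)
```

The argmax step is the inverse mapping, which is also how a class label is read off a network output vector.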

Build the neural network

Here we'll build a simple 3 layer fully connected ("Dense") network. We will also apply the Dropout technique to combat overfitting during training. This technique was proposed in a 2014 paper and has been widely adopted since then:
http://jmlr.org/papers/v15/srivastava14a.html
Additional source on Dropout:
http://machinelearningmastery.com/dropout-regularization-deep-learning-models-keras

In our first model, Dropout is applied after each of the two hidden layers: once between the first and the second hidden layer, and once between the second hidden layer and the output layer. A dropout rate between 20% and 50% seems to fit many practical cases.
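
To give a feel for what a Dropout layer does during training, here is a minimal NumPy sketch of the variant usually called "inverted dropout": a random fraction of the activations is zeroed and the survivors are rescaled so that the expected activation stays the same. The function name dropout_forward and its parameters are hypothetical; this is an illustration of the idea, not the Keras code.

```python
import numpy as np

def dropout_forward(activations, rate, rng, training=True):
    """Inverted dropout: zero a fraction `rate` of units, rescale the survivors."""
    if not training or rate == 0.0:
        return activations              # at prediction time nothing is dropped
    keep_prob = 1.0 - rate
    mask = rng.random(activations.shape) < keep_prob
    return activations * mask / keep_prob

rng = np.random.default_rng(0)
hidden = np.ones((4, 512), dtype="float32")   # stand-in for a batch of hidden activations
dropped = dropout_forward(hidden, rate=0.25, rng=rng)

print("fraction zeroed:", float((dropped == 0).mean()))
print("mean activation:", float(dropped.mean()))
```

With a rate of 0.25, roughly a quarter of the units are zeroed in each batch, so no single unit can be relied on to memorize a training example.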

The activation parameter is discussed in our Pima Indian Diabetes Database unit.

In [15]:
model1 = Sequential()
model1.add(Dense(512, input_shape=(784,), activation='relu'))
model1.add(Dropout(0.25))   # Dropout helps protect the model from memorizing or "overfitting" the training data
model1.add(Dense(512, activation='relu'))
model1.add(Dropout(0.25))
model1.add(Dense(10, activation='softmax'))
# This special "softmax" activation ensures, among other things, that the
# output is a valid probability distribution, that is, that its values are
# all non-negative and sum to 1.

Compile the model

Keras is built on top of Theano (and now TensorFlow as well), both packages that allow you to define a computation graph in Python, which they then compile and run efficiently on the CPU or GPU without the overhead of the Python interpreter.

When compiling a model, Keras asks you to specify your loss function and your optimizer. These functions are discussed in more detail in the Pima Indian Diabetes Database unit. The loss function that we will use here is called categorical crossentropy, and is a loss function well-suited for comparing two probability distributions.

The neural network output vector represents probability distributions across the ten different digits (e.g. "we're 78% confident this image is a 3, 12% sure it's an 8, 5% it's a 2, etc."), and the target is a probability distribution with 100% for the correct category, and 0 for everything else. The cross-entropy is a measure of how different your predicted distribution is from the target distribution. More detail at Wikipedia
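
The two ideas above (softmax output and cross-entropy loss) can be sketched in a few lines of NumPy. The scores in logits are made up for illustration, and the function names are hypothetical; Keras computes the same quantities on the backend, over whole batches.

```python
import numpy as np

def softmax(logits):
    """Map raw scores to non-negative values that sum to 1."""
    shifted = logits - np.max(logits)   # subtract the max for numerical stability
    exps = np.exp(shifted)
    return exps / exps.sum()

def categorical_crossentropy(target, predicted, eps=1e-7):
    """Cross-entropy -sum(t * log(p)) between a target and a predicted distribution."""
    predicted = np.clip(predicted, eps, 1.0)   # avoid log(0)
    return float(-np.sum(target * np.log(predicted)))

logits = np.array([0.5, 0.1, 2.0, 4.0, 0.0, 0.3, 0.2, 0.1, 2.5, 0.4])  # made-up scores
probs = softmax(logits)

target = np.zeros(10)
target[3] = 1.0                         # the true digit is a 3
loss_right = categorical_crossentropy(target, probs)

wrong_target = np.zeros(10)
wrong_target[7] = 1.0                   # pretend the true digit was a 7
loss_wrong = categorical_crossentropy(wrong_target, probs)

print("sum of probabilities:", probs.sum())
print("loss when the target is 3:", loss_right)
print("loss when the target is 7:", loss_wrong)
```

Because the target is one-hot, the loss reduces to minus the log of the probability assigned to the correct digit, so a confident correct prediction gives a loss near 0.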

The optimizer helps determine how quickly the model learns and how resistant it is to getting "stuck" or "blowing up". We will not discuss this in too much detail, but "adam" is often a good choice (it was co-developed at the Univ. of Toronto).

In [16]:
model1.compile(optimizer='adam', loss='categorical_crossentropy', metrics=["accuracy"])

Train the model!

This is the fun part: you can feed the training data loaded in earlier into this model and it will learn to classify digits. Note that the fit method returns a history object which logs the various loss/accuracy values across the training flow. See the following example on how to use it for inspecting the model.

In [17]:
h1 = model1.fit(
    X_train,
    Y_train,
    batch_size=32,
    nb_epoch=4,
    verbose=1,
    validation_data=(X_test, Y_test)
)
Train on 60000 samples, validate on 10000 samples
Epoch 1/4
60000/60000 [==============================] - 6s - loss: 0.2223 - acc: 0.9323 - val_loss: 0.1062 - val_acc: 0.9674
Epoch 2/4
60000/60000 [==============================] - 6s - loss: 0.1129 - acc: 0.9657 - val_loss: 0.0823 - val_acc: 0.9739
Epoch 3/4
60000/60000 [==============================] - 6s - loss: 0.0894 - acc: 0.9717 - val_loss: 0.0810 - val_acc: 0.9765
Epoch 4/4
60000/60000 [==============================] - 6s - loss: 0.0763 - acc: 0.9765 - val_loss: 0.0758 - val_acc: 0.9778
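
As a small sketch of how the logged values can be inspected, the dictionary below copies the numbers from the four-epoch log above. In a live session these lists would presumably come from h1.history; the key names ("loss", "acc", "val_loss", "val_acc") are an assumption based on the labels in the log, and the two helper functions are hypothetical.

```python
# Values copied from the four-epoch training log above; in a live session
# they would come from h1.history (assumed to be a dict of per-epoch lists).
history = {
    "loss":     [0.2223, 0.1129, 0.0894, 0.0763],
    "acc":      [0.9323, 0.9657, 0.9717, 0.9765],
    "val_loss": [0.1062, 0.0823, 0.0810, 0.0758],
    "val_acc":  [0.9674, 0.9739, 0.9765, 0.9778],
}

def best_epoch(history, key="val_loss"):
    """Return the 1-based epoch with the lowest value of `key`."""
    values = history[key]
    return min(range(len(values)), key=values.__getitem__) + 1

def generalization_gap(history):
    """Per-epoch difference between validation loss and training loss."""
    return [v - t for t, v in zip(history["loss"], history["val_loss"])]

gaps = generalization_gap(history)
print("best epoch by val_loss:", best_epoch(history))
print("val_loss - loss per epoch:", [round(g, 4) for g in gaps])
```

A gap that turns positive and keeps growing over the epochs is a common sign of overfitting; in this short run the validation loss is still improving at the last epoch.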

Finally, evaluate its performance

In [18]:
loss1, accuracy1 = model1.evaluate(X_test, Y_test, verbose=0)
print("Validation: accuracy1 = %f  ;  loss1 = %f" % (accuracy1, loss1))
Validation: accuracy1 = 0.977800  ;  loss1 = 0.075823
In [19]:
loss1, accuracy1 = model1.evaluate(X_train, Y_train, verbose=0)
print("Training: accuracy1 = %f  ;  loss1 = %f" % (accuracy1, loss1))
Training: accuracy1 = 0.987583  ;  loss1 = 0.038519

Inspecting the output

Not bad: 98.76% training accuracy and 97.78% validation accuracy on our first simple model is a good start! It's always a good idea to inspect the model output and make sure it looks sane. It is also good practice to look at some examples it gets right, and some examples it gets wrong. A more thorough inspection is obtained by looking at the progress graphs of training and validation accuracy values. The view_acc function ("view accuracy") is currently part of the kerutils module (look above for download links).

In [20]:
# Training accuracy and validation accuracy graphs

view_acc(h1)

A similar function for viewing the model loss history is view_loss

In [21]:
view_loss(h1)
In [22]:
# The predict_classes function outputs the highest probability class
# according to the trained classifier for each input example.

y_pred = model1.predict_classes(X_test)
 9760/10000 [============================>.] - ETA: 0s
In [23]:
# Check which items we got right / wrong

true_indices = np.nonzero(y_pred == y_test)[0]
false_indices = np.nonzero(y_pred != y_test)[0]
print("Number of false predictions = %d (out of %d samples)" % (len(false_indices), len(y_test)))
Number of false predictions = 222 (out of 10000 samples)
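
Beyond counting the errors, it can help to see which digits get confused with which. Here is a minimal NumPy sketch of a confusion matrix; the function name is hypothetical and the toy labels are made up for illustration (they are not MNIST results).

```python
import numpy as np

def confusion_matrix(y_true, y_pred, num_classes):
    """counts[i, j] = number of samples of true class i predicted as class j."""
    counts = np.zeros((num_classes, num_classes), dtype=int)
    # np.add.at accumulates correctly even when an (i, j) pair repeats
    np.add.at(counts, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return counts

# Tiny made-up example with 3 classes
toy_true = np.array([0, 0, 1, 1, 2, 2, 2])
toy_pred = np.array([0, 1, 1, 1, 2, 0, 2])

cm = confusion_matrix(toy_true, toy_pred, 3)
print(cm)
print("errors:", cm.sum() - np.trace(cm))
```

With the arrays from this notebook the call would be confusion_matrix(y_test, y_pred, 10); the off-diagonal entries then show, for example, how often a 4 is read as a 9.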

Take a look at the following 16 samples that model1 failed to identify. Look how close it was to identifying them (almost like a "human" error).

In [21]:
# This function draws 16 incorrectly predicted cases in a 4x4 grid, starting from the n-th error

def view_false_samples(X, y, y_pred, n):
    false_indices = np.nonzero(y_pred != y)[0]
    print("Total number of false items = %d out of %d" % (len(false_indices), len(y)))
    plt.figure(figsize=(7,7))
    for i, incorrect in enumerate(false_indices[n:n+16]):
        plt.subplot(4,4,i+1)
        plt.imshow(X[incorrect].reshape(28,28), cmap='gray', interpolation='none')
        plt.title("Pred={}, Class={}".format(y_pred[incorrect], y[incorrect]), fontsize=10)
        plt.tick_params(axis='both', which='major', labelsize=7)
        plt.subplots_adjust(hspace=0.5, wspace=0.5)
    plt.show()

view_false_samples(X_test, y_test, y_pred, 0)

As you can see, the errors that model1 made are almost "human" in most of the cases. Let's see if we can build a better model.

Second Keras Model (based on Convolution)

It turns out that in most practical cases, Keras Dense layers (fully connected layers) are a bad choice for image data sets. They may work OK for small 28x28 images but are impractical for real-size images: a 1024x1024 image requires an input layer of about 1 million neurons, and a Dense layer of the same size on top of it would need about $10^{12}$ synapses. In a convolutional layer, each neuron is connected only to a very small subset of neurons that represent the pixels in its local area (usually a 3x3 or 5x5 window around it). The rationale is that an image pixel is more likely to be related to the 3x3 or 5x5 pixels around it than to any pixel which is far away from it.
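
The arithmetic behind this comparison can be checked with a short sketch. The two helper functions are hypothetical and simply count weights plus biases; the 128 filters and the 3x3 kernel match the values used for the convolutional model later in this unit.

```python
def dense_weights(n_inputs, n_units):
    """Weights plus biases of a fully connected layer."""
    return n_inputs * n_units + n_units

def conv2d_weights(kernel_h, kernel_w, in_channels, n_filters):
    """Weights plus biases of a 2D convolutional layer (shared across positions)."""
    return (kernel_h * kernel_w * in_channels + 1) * n_filters

pixels_small = 28 * 28          # MNIST image
pixels_large = 1024 * 1024      # the "real size" image from the text

print("Dense 784 -> 512:", dense_weights(pixels_small, 512))
print("Dense 1024x1024 -> same size:", dense_weights(pixels_large, pixels_large))
print("Conv 3x3, 1 channel, 128 filters:", conv2d_weights(3, 3, 1, 128))
```

The convolutional layer needs only 1280 parameters regardless of the image size, because the same small kernels are reused at every position of the image.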

This is not the place to go into more details, but a decent quick introduction to convolutional layers can be found in this course: http://cs231n.github.io/convolutional-networks

In the MNIST case we will not gain much precision, since our images are small and convolutional layers hardly matter in such cases. But it's good to get acquainted with the technology.

Let's start by reloading our mnist data and define the required parameters for our convolutional layers.

In [9]:
# Number of classes (digits)
nb_classes = 10

# input image dimensions
img_rows, img_cols = 28, 28

# number of convolutional filters to use
nb_filters = 128

# size of pooling area for max pooling
pool_size = (2, 2)

# convolution kernel size
kernel_size = (3, 3)

# Loading the data
(x_train, y_train), (x_test, y_test) = mnist.load_data()
print("x_train original shape", x_train.shape) 
print("y_train original shape", y_train.shape)

X_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
X_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)

X_train = X_train.astype('float32')  # float32 uses less memory than float64
X_test = X_test.astype('float32')    # float32 uses less memory than float64
X_train /= 255
X_test /= 255
Y_train = np_utils.to_categorical(y_train, nb_classes)
Y_test = np_utils.to_categorical(y_test, nb_classes)
print("Training matrix shape", X_train.shape)
print("Testing matrix shape", X_test.shape)
x_train original shape (60000L, 28L, 28L)
y_train original shape (60000L,)
Training matrix shape (60000L, 28L, 28L, 1L)
Testing matrix shape (10000L, 28L, 28L, 1L)
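
Before looking at the model in the next cell, it may help to trace how the image shape changes through two 3x3 convolutions followed by 2x2 max pooling. This is a rough sketch under the assumption that both convolutions use 'valid' borders (no padding) and a stride of 1; the helper names are hypothetical.

```python
def conv_valid(size, kernel):
    """Output length of a 'valid' (no padding, stride 1) convolution along one axis."""
    return size - kernel + 1

def pool(size, window):
    """Output length of non-overlapping pooling along one axis."""
    return size // window

rows, cols, channels = 28, 28, 1
n_filters, kernel, window = 128, 3, 2   # same values as nb_filters, kernel_size, pool_size

shapes = [("input", (rows, cols, channels))]
for name in ("conv 1", "conv 2"):
    rows, cols, channels = conv_valid(rows, kernel), conv_valid(cols, kernel), n_filters
    shapes.append((name, (rows, cols, channels)))
rows, cols = pool(rows, window), pool(cols, window)
shapes.append(("max pool", (rows, cols, channels)))
shapes.append(("flatten", (rows * cols * channels,)))

for name, shape in shapes:
    print("%-8s -> %s" % (name, shape))
```

Under these assumptions the Flatten layer hands 18432 values to the Dense layer that follows it, which is where most of the weights of such a model sit.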

Here is the second model definition:

In [13]:
model2 = Sequential()
model2.add(Convolution2D(nb_filters, 3, 3, border_mode='valid', input_shape=input_shape))
model2.add(Activation(SReLU()))
model2.add(Convolution2D(nb_filters, 3, 3))
model2.add(Activation(SReLU()))
model2.add(MaxPooling2D(pool_size=pool_size))
model2.add(Dropout(0.25))

model2.add(Flatten())
model2.add(Dense(256))
model2.add(Activation(SReLU()))
model2.add(Dropout(0.5))
model2.add(Dense(nb_classes))
model2.add(Activation('softmax'))

from keras.optimizers import SGD
#sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model2.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

h = model2.fit(
    X_train,
    Y_train,
    batch_size=64,
    nb_epoch=30,
    verbose=1,
    validation_data=(X_test, Y_test),
)

#model2.save("model_2.h5")
show_scores(model2, h, X_train, Y_train, X_test, Y_test)
Train on 60000 samples, validate on 10000 samples
Epoch 1/30
60000/60000 [==============================] - 94s - loss: 0.1571 - acc: 0.9516 - val_loss: 0.0395 - val_acc: 0.9866
Epoch 2/30
60000/60000 [==============================] - 95s - loss: 0.0643 - acc: 0.9804 - val_loss: 0.0415 - val_acc: 0.9858
Epoch 3/30
60000/60000 [==============================] - 95s - loss: 0.0503 - acc: 0.9840 - val_loss: 0.0434 - val_acc: 0.9860
Epoch 4/30
60000/60000 [==============================] - 95s - loss: 0.0425 - acc: 0.9863 - val_loss: 0.0310 - val_acc: 0.9896
Epoch 5/30
60000/60000 [==============================] - 95s - loss: 0.0344 - acc: 0.9892 - val_loss: 0.0299 - val_acc: 0.9911
Epoch 6/30
60000/60000 [==============================] - 95s - loss: 0.0314 - acc: 0.9895 - val_loss: 0.0301 - val_acc: 0.9915
Epoch 7/30
60000/60000 [==============================] - 95s - loss: 0.0278 - acc: 0.9911 - val_loss: 0.0253 - val_acc: 0.9921
Epoch 8/30
60000/60000 [==============================] - 95s - loss: 0.0253 - acc: 0.9920 - val_loss: 0.0303 - val_acc: 0.9919
Epoch 9/30
60000/60000 [==============================] - 95s - loss: 0.0232 - acc: 0.9924 - val_loss: 0.0323 - val_acc: 0.9914
Epoch 10/30
60000/60000 [==============================] - 95s - loss: 0.0217 - acc: 0.9930 - val_loss: 0.0424 - val_acc: 0.9896
Epoch 11/30
60000/60000 [==============================] - 95s - loss: 0.0194 - acc: 0.9937 - val_loss: 0.0336 - val_acc: 0.9915
Epoch 12/30
60000/60000 [==============================] - 95s - loss: 0.0181 - acc: 0.9940 - val_loss: 0.0352 - val_acc: 0.9921
Epoch 13/30
60000/60000 [==============================] - 95s - loss: 0.0189 - acc: 0.9942 - val_loss: 0.0324 - val_acc: 0.9922
Epoch 14/30
60000/60000 [==============================] - 95s - loss: 0.0165 - acc: 0.9951 - val_loss: 0.0339 - val_acc: 0.9918
Epoch 15/30
60000/60000 [==============================] - 95s - loss: 0.0163 - acc: 0.9950 - val_loss: 0.0363 - val_acc: 0.9910
Epoch 16/30
60000/60000 [==============================] - 95s - loss: 0.0145 - acc: 0.9955 - val_loss: 0.0355 - val_acc: 0.9922
Epoch 17/30
60000/60000 [==============================] - 95s - loss: 0.0163 - acc: 0.9950 - val_loss: 0.0372 - val_acc: 0.9914
Epoch 18/30
60000/60000 [==============================] - 96s - loss: 0.0152 - acc: 0.9953 - val_loss: 0.0453 - val_acc: 0.9911
Epoch 19/30
60000/60000 [==============================] - 97s - loss: 0.0140 - acc: 0.9957 - val_loss: 0.0426 - val_acc: 0.9913
Epoch 20/30
60000/60000 [==============================] - 97s - loss: 0.0132 - acc: 0.9962 - val_loss: 0.0365 - val_acc: 0.9929
Epoch 21/30
60000/60000 [==============================] - 97s - loss: 0.0154 - acc: 0.9957 - val_loss: 0.0381 - val_acc: 0.9914
Epoch 22/30
60000/60000 [==============================] - 97s - loss: 0.0158 - acc: 0.9956 - val_loss: 0.0527 - val_acc: 0.9907
Epoch 23/30
60000/60000 [==============================] - 97s - loss: 0.0166 - acc: 0.9953 - val_loss: 0.0464 - val_acc: 0.9901
Epoch 24/30
60000/60000 [==============================] - 97s - loss: 0.0121 - acc: 0.9964 - val_loss: 0.0612 - val_acc: 0.9896
Epoch 25/30
60000/60000 [==============================] - 97s - loss: 0.0110 - acc: 0.9968 - val_loss: 0.0395 - val_acc: 0.9921
Epoch 26/30
60000/60000 [==============================] - 97s - loss: 0.0141 - acc: 0.9961 - val_loss: 0.0434 - val_acc: 0.9922
Epoch 27/30
60000/60000 [==============================] - 97s - loss: 0.0145 - acc: 0.9960 - val_loss: 0.0461 - val_acc: 0.9912
Epoch 28/30
60000/60000 [==============================] - 97s - loss: 0.0150 - acc: 0.9964 - val_loss: 0.0407 - val_acc: 0.9933
Epoch 29/30
60000/60000 [==============================] - 97s - loss: 0.0118 - acc: 0.9968 - val_loss: 0.0487 - val_acc: 0.9912
Epoch 30/30
60000/60000 [==============================] - 97s - loss: 0.0128 - acc: 0.9965 - val_loss: 0.0458 - val_acc: 0.9926
Training: accuracy   = 0.999883 loss = 0.000639
Validation: accuracy = 0.992600 loss = 0.045782
Over fitting score   = 0.003950
Under fitting score  = 0.004489
Params count: 4870282
stop epoch = 29
nb_epoch = 30
batch_size = 64
nb_sample = 60000
In [14]:
loss, accuracy = model2.evaluate(X_train, Y_train, verbose=0)
print("Training: accuracy = %f  ;  loss = %f" % (accuracy, loss))
Training: accuracy = 0.999883  ;  loss = 0.000639
In [15]:
loss, accuracy = model2.evaluate(X_test, Y_test, verbose=0)
print("Validation: accuracy = %f  ;  loss = %f" % (accuracy, loss))
Validation: accuracy = 0.992600  ;  loss = 0.045782

Indeed, using convolutional layers has yielded better accuracy. The training accuracy of 99.99% is very close to 100%. The validation accuracy of 99.26% is close to practical, at least for applications such as police traffic cameras and toll-road cameras (where the digits on vehicle registration plates are expected to be more regular than handwritten digits).

In [16]:
# The predict_classes function outputs the highest probability class
# according to the trained classifier for each input example.

y_pred = model2.predict_classes(X_test)
 9984/10000 [============================>.] - ETA: 0s
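
For a softmax output layer, picking the highest-probability class amounts to taking the arg-max of each row of predicted probabilities. The following is a minimal NumPy sketch of that step on made-up probabilities; `probs` and `toy_pred` are illustrative names that do not appear in the notebook, and the values are not produced by model2.

```python
import numpy as np

# Made-up class probabilities for 3 samples over the 10 digit classes
# (illustrative values only, not produced by model2).
probs = np.array([
    [0.01, 0.02, 0.90, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01],
    [0.05, 0.70, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02],
    [0.00, 0.00, 0.00, 0.10, 0.00, 0.00, 0.00, 0.10, 0.00, 0.80],
])

# The index of the largest probability in each row is the predicted digit.
toy_pred = probs.argmax(axis=1)
print(toy_pred)
```

With the trained model, the same result should presumably be obtainable as `model2.predict(X_test).argmax(axis=1)`, which is useful to know in case `predict_classes` is not available in your Keras version.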
In [18]:
# Check which items we got right / wrong

true_indices = np.nonzero(y_pred == y_test)[0]
false_indices = np.nonzero(y_pred != y_test)[0]
In [19]:
print("success samples: %d, failed samples: %d" % (len(true_indices), len(false_indices)))
success samples: 9926, failed samples: 74
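
The index bookkeeping above can be tried on a small example. This is only a sketch: the arrays `toy_true` and `toy_pred` are invented for illustration and are not taken from the MNIST run.

```python
import numpy as np

# Toy labels, invented for illustration (not taken from the MNIST run above).
toy_true = np.array([7, 2, 1, 0, 4, 1, 4, 9])
toy_pred = np.array([7, 2, 1, 0, 9, 1, 4, 4])

# np.nonzero on a boolean mask returns a tuple of index arrays;
# [0] selects the indices along the first (and only) axis.
hits = np.nonzero(toy_pred == toy_true)[0]
misses = np.nonzero(toy_pred != toy_true)[0]

# Accuracy is the fraction of correctly classified samples.
accuracy = len(hits) / float(len(toy_true))
print("hits: %s, misses: %s, accuracy: %.3f" % (hits, misses, accuracy))
```

The same arithmetic applies to the real counts: 9926 / 10000 = 0.9926, which agrees with the validation accuracy reported by `model2.evaluate`.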

Out of 10000 test images, only 74 were missed by our model. Let's take a look at some of them to try to understand why they were so hard to classify.

In [22]:
view_false_samples(X_test, y_test, y_pred, 20)
Total number of false items = 74 out of 10000

Just by looking at the last 16 false results, it looks as if our model is almost "human"-like in missing them and confusing them with similar-looking numerals. Some of these cases look too vague to be recognized even by a human eye. Improving the 99.26% accuracy score will probably require more training instances, or maybe more refined network architectures (like the Keras graph models). We may later formulate some challenges of this sort as course projects.
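
One way to check which digits get confused with which is to count the (true, predicted) pairs among the misclassified samples. The sketch below does this on invented labels; `toy_true`, `toy_pred`, and `pairs` are hypothetical names, and with the notebook's data one would pass `y_test` and `y_pred` instead.

```python
from collections import Counter
import numpy as np

# Invented labels for illustration only (not results from model2).
toy_true = np.array([4, 9, 9, 3, 5, 4, 7, 1, 4, 2])
toy_pred = np.array([9, 4, 9, 5, 5, 9, 1, 1, 9, 2])

# Boolean mask of the misclassified samples.
wrong = toy_true != toy_pred

# Count how often each (true digit, predicted digit) confusion occurs.
pairs = Counter(zip(toy_true[wrong].tolist(), toy_pred[wrong].tolist()))

for (true_digit, predicted_digit), count in pairs.most_common():
    print("true %d -> predicted %d: %d time(s)" % (true_digit, predicted_digit, count))
```

A pair that dominates this list points at a specific confusion that more training samples or augmentation might target.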

For further improvements it's recommended to look at the Keras ImageDataGenerator class, which helps generate new training samples from the existing ones using image processing techniques. Here is a good tutorial on how to do it on the MNIST data set:
http://machinelearningmastery.com/image-augmentation-deep-learning-keras/

This technique is demonstrated on the CIFAR10 image data set in a later study unit of this course.
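
To give a feel for what augmentation does, here is a minimal plain-NumPy sketch of one such transformation, a random shift with zero fill. It is not the ImageDataGenerator API, only a stand-in for the idea; the helper `random_shift` and the array `toy_image` are hypothetical and not part of Keras or of this notebook.

```python
import numpy as np

def random_shift(image, max_shift, rng):
    """Return a copy of a 2-D image shifted by up to max_shift pixels, zero-filled."""
    h, w = image.shape
    dy, dx = rng.randint(-max_shift, max_shift + 1, size=2)
    shifted = np.zeros_like(image)
    # Source and destination windows that overlap after the shift.
    src_rows = slice(max(0, -dy), min(h, h - dy))
    dst_rows = slice(max(0, dy), min(h, h + dy))
    src_cols = slice(max(0, -dx), min(w, w - dx))
    dst_cols = slice(max(0, dx), min(w, w + dx))
    shifted[dst_rows, dst_cols] = image[src_rows, src_cols]
    return shifted

rng = np.random.RandomState(0)           # fixed seed for repeatability
toy_image = np.zeros((28, 28), dtype="float32")
toy_image[10:18, 12:16] = 1.0            # a crude bright blob standing in for a digit

# Five shifted variants of the same image, each keeping the original label.
augmented = np.stack([random_shift(toy_image, 3, rng) for _ in range(5)])
print(augmented.shape)
```

Each shifted copy keeps the label of its source image, so the training set grows without any new labelling effort. ImageDataGenerator offers this kind of transformation (and others, such as rotations and zooms) and applies them on the fly during training.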

There are lots of other great examples at the Keras homepage at http://keras.io and in the source code at https://github.com/fchollet/keras