Wednesday, June 22, 2016

The MNIST handwritten digits dataset

MNIST


Machine Learning and OCR

Can a computer read handwritten characters? We all know by now that the answer is yes. We have handwriting recognition on our phones and Optical Character Recognition (OCR) capabilities on our PDF scanners. But have you ever wondered how it is done?

Even before deep learning algorithms came to prominence in character recognition, the document processing industry had been thriving for years and had developed specialized techniques to handle it. These techniques typically involved hand-crafted features to uniquely identify characters.

Machine learning (specifically, convolutional neural networks) reached very high accuracy rates on character recognition in the mid-90s. One of the first applications of ML-based character recognition was reading handwritten numbers on bank checks and ZIP codes on letters. (Strictly speaking, OCR refers to recognizing typeset characters, not handwritten characters, which is a much harder problem.)

Number recognition is now an easy problem for machine learning enthusiasts. Google reads house numbers from images captured by its fleet of camera cars. Reading digits in real-world situations, such as Google's street address problem, is still non-trivial. However, when confined to a 'clean' and controlled environment, recognizing digits is now a routine coding assignment for graduate students. It wasn't always so easy.

Introducing another benchmark dataset: MNIST

We previously introduced the Iris dataset as a simple test for classification algorithms. We now introduce a slightly more complex dataset for testing these algorithms: the MNIST dataset (here). It is a collection of handwritten digits from "0" through "9" written by Census Bureau employees and high school students. The name stands for Modified NIST, and NIST, of course, is the National Institute of Standards and Technology, the measurement standards agency of the US government.

With 70k entries, it is a large dataset of handwritten digits (0-9). Just like the Iris dataset, it has been used extensively by researchers, so it offers a good platform to test different algorithms and compare their performance against well-known benchmarks. The best results hover around 99.7% to 99.8% accuracy. The top techniques are statistically indistinguishable and practically perfect: the 20 to 30 errors they make on the 10k test set are on examples that could also confuse adult humans! On this basis, it is considered a 'solved' problem.

Techniques that are successful on MNIST are also successful on other digit recognition datasets, but with slightly lower accuracy, for example on a dataset called USPS (ZIP code digits; a dataset we could explore in the future). USPS is less clean and is considered slightly harder than MNIST. This drop in performance indicates that top-level algorithms are still affected by the quality of their training set, a problem that does not seem to faze a 1st grader. There are also performance disparities between classifiers trained on one dataset and used on a different one (e.g., trained on MNIST but deployed on USPS, or vice versa), further suggesting that the learned digit features generalize less well than a human's capacity to infer similarities in stroke patterns.


MNIST testing as proxy for computer vision testing

Researchers study difficult problems and develop algorithms to solve them by breaking them down into smaller problems, or by solving only a simplified version. For instance, while recognizing objects is a trivial task for humans, for computers vision is a very difficult problem involving multiple stages of image processing, e.g., edge detection, segmentation, object recognition, and classification. A very simplified version of this problem is to recognize simple shapes without background clutter.

Recognizing digits is one such simplified problem. Even so, reading digits is not trivial, even when the task is simplified to 'just' reading digits. It is hard because people write numbers in many different ways, and written digits might also overlap. Identifying and correctly classifying a number that overlaps with another might seem effortless for a 1st grader, but it makes recognition considerably more complicated.

MNIST simplifies this by presenting a dataset of well-defined and consistently processed images. The handwritten digits are centered (i.e., no need to teach a classifier where to look), individually separated (no need for segmentation, nor for resolving occlusion and overlaps), and in grayscale (i.e., a consistent color scheme) against a plain background (i.e., no background clutter). The centering is done by computing the center of mass of the pixel intensities and translating the image so that this point sits at the center of the 28x28 field. Unlike centering a bounding box, this causes some digits, for example bottom-heavy digits like "6", to appear higher than the rest. Segmentation, in this case against other digits or against background clutter, is considered an integral part of the object recognition process, i.e., recognizing partially seen objects informs segmentation, and segmentation refines object recognition.
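The center-of-mass centering described above can be sketched in a few lines of NumPy. This is a minimal illustration under our own assumptions (whole-pixel shifts, zero fill, and a made-up bottom-heavy shape standing in for a "6"); the helper names are hypothetical, and this is not the exact preprocessing code used to build MNIST.

```python
import numpy as np

def center_of_mass(img):
    """Intensity-weighted (row, col) centroid of a 2-D image."""
    img = np.asarray(img, dtype=float)
    rows, cols = np.indices(img.shape)
    total = img.sum()
    return (rows * img).sum() / total, (cols * img).sum() / total

def shift_zero_fill(img, dr, dc):
    """Translate img by whole pixels (dr down, dc right); vacated pixels become 0."""
    h, w = img.shape
    out = np.zeros_like(img)
    src_r = slice(max(0, -dr), max(0, min(h, h - dr)))
    dst_r = slice(max(0, dr), max(0, min(h, h + dr)))
    src_c = slice(max(0, -dc), max(0, min(w, w - dc)))
    dst_c = slice(max(0, dc), max(0, min(w, w + dc)))
    out[dst_r, dst_c] = img[src_r, src_c]
    return out

def center_by_mass(img):
    """Shift img so its center of mass lands (to the nearest pixel) on the frame's middle."""
    r, c = center_of_mass(img)
    h, w = img.shape
    dr = int(round((h - 1) / 2 - r))
    dc = int(round((w - 1) / 2 - c))
    return shift_zero_fill(img, dr, dc)

def bbox_center_row(img):
    """Row midway between the top-most and bottom-most inked rows."""
    inked_rows = np.nonzero(img.any(axis=1))[0]
    return (inked_rows[0] + inked_rows[-1]) / 2

# Hypothetical bottom-heavy shape: a thin stem above a solid bottom blob
digit = np.zeros((28, 28), dtype=np.uint8)
digit[5:16, 13] = 255        # thin stem, rows 5-15
digit[16:21, 10:17] = 255    # heavy bottom, rows 16-20

centered = center_by_mass(digit)
print('bounding-box center row before:', bbox_center_row(digit))
print('bounding-box center row after :', bbox_center_row(centered))
print('frame center row              :', (28 - 1) / 2)
```

The heavy bottom pulls the center of mass down, so centering by mass shifts the whole shape up: its bounding box ends up centered at row 9.5 rather than at the frame's middle row of 13.5.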

In character recognition terminology, MNIST is a set of offline, handprinted digits. Offline means the character is complete when presented, and we have no information on the sequence of strokes used to write it. Many handwriting recognizers on smartphones and tablets are online character recognizers: they have the very useful additional information of stroke sequence. (Try this experiment: write letters and words in reverse stroke order on your smartphone.) Handprinted means characters are written individually, not in cursive, removing the need to segment characters, which in itself is a very difficult problem. In general, however, no distinction is made between handprinted and handwritten for MNIST, since the context is clearly well-separated digits.
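To make the online/offline distinction concrete, here is a small sketch: a hypothetical online recognizer receives strokes as ordered lists of (row, col) points, and rasterizing them into an image (the offline view) throws the order away. The stroke coordinates and the `rasterize` helper are illustrative assumptions, not part of MNIST.

```python
import numpy as np

def rasterize(strokes, size=28):
    """Render a list of strokes (each a list of (row, col) points) into an offline image."""
    img = np.zeros((size, size), dtype=np.uint8)
    for stroke in strokes:
        for r, c in stroke:
            img[r, c] = 255
    return img

# Hypothetical online input for a "7": a top bar, then a diagonal going down-left
top_bar = [(5, c) for c in range(8, 20)]
diagonal = [(5 + k, 19 - k // 2) for k in range(1, 19)]

forward = [top_bar, diagonal]                                  # usual stroke order
backward = [list(reversed(diagonal)), list(reversed(top_bar))]  # reversed strokes and points

# The stroke sequences differ, but the offline images are identical
print(forward != backward, np.array_equal(rasterize(forward), rasterize(backward)))
```

Drawing the "7" forward or in reverse stroke order yields identical pixels, so an offline recognizer, like one trained on MNIST, cannot tell the two apart, while an online recognizer could.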

Extracting the MNIST data

We can download the original MNIST dataset from LeCun's page (here) and then rewrite it in a format of our preference (e.g., CSV). Alternatively, we could look around the internet for a prepared CSV file. :) This led me to the Kaggle site (here). Kaggle is well known among data scientists from all walks of life for its data analytics competitions. Kaggle provides 42k training and 28k test digits for MNIST, whereas the original MNIST dataset is split into 60k training and 10k test digits. Kaggle splits them differently, but the two datasets are otherwise the same, except that Kaggle randomly shuffled the order of the original MNIST digits.
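If we go the first route and convert LeCun's files ourselves, a minimal reader might look like the sketch below. It assumes the IDX layout documented on LeCun's page: a 4-byte magic number (two zero bytes, a data-type byte that is 0x08 for unsigned bytes, and the number of dimensions), followed by one big-endian 32-bit size per dimension and then the raw pixel bytes. The function names are our own, and the demo builds a tiny synthetic file in memory rather than reading the real one.

```python
import io
import struct
import numpy as np

def read_idx(f):
    """Read an unsigned-byte IDX array (the format of the original MNIST files) from a binary file object."""
    zero, dtype_code, ndim = struct.unpack('>HBB', f.read(4))
    if zero != 0 or dtype_code != 0x08:
        raise ValueError('not an unsigned-byte IDX file')
    shape = struct.unpack('>' + 'I' * ndim, f.read(4 * ndim))
    return np.frombuffer(f.read(), dtype=np.uint8).reshape(shape)

def to_label_pixel_rows(images, labels):
    """One row per digit: the label followed by its 784 pixels in row-major order."""
    return np.column_stack([labels, images.reshape(len(images), -1)])

# Demo: a tiny synthetic IDX image/label pair (2 digits) built in memory
fake_pixels = (np.arange(2 * 28 * 28) % 256).astype(np.uint8)
image_bytes = struct.pack('>HBB', 0, 0x08, 3) + struct.pack('>III', 2, 28, 28) + fake_pixels.tobytes()
label_bytes = struct.pack('>HBB', 0, 0x08, 1) + struct.pack('>I', 2) + bytes([7, 3])

images = read_idx(io.BytesIO(image_bytes))
labels = read_idx(io.BytesIO(label_bytes))
rows = to_label_pixel_rows(images, labels)
print(images.shape, labels.shape, rows.shape)

# Write CSV text with a label,pixel0,...,pixel783 header
header = 'label,' + ','.join('pixel%d' % k for k in range(784))
buf = io.StringIO()
np.savetxt(buf, rows, fmt='%d', delimiter=',', header=header, comments='')
```

On the real data, the same reader would be applied to the downloaded, gzip-compressed files, e.g. `with gzip.open('train-images-idx3-ubyte.gz', 'rb') as f: images = read_idx(f)`, and the rows written with `np.savetxt` to a file instead of an in-memory buffer.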

To make processing faster, we will only use the first 10k digits from the Kaggle training dataset. In our future experiments, this smaller training set will likely lower our accuracy relative to standard benchmark runs that use the entire training set. We can always run the full dataset given more time, but for now we only want to see whether our algorithms are pointing in the right direction. Note also that state-of-the-art benchmark runs use the original grouping of 60k training and 10k test digits, not the 42k/28k split used by Kaggle.

In [1]:
# Jupyter magic: render matplotlib figures inline in the notebook
%matplotlib inline

# import libraries
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# let's limit our dataset
train_start=1
train_end=10000

print('Running MNIST, stats:')
print('... train_start:',train_start)
print('... train_end:',train_end)

print('Copying csv to matrix...')
train_file='C:/Users/Philip/Documents/Home/Dad/kaggle_mnist_train.csv'
#train_file='/data/kaggle_mnist_train.csv'
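# read only the rows we need: column 0 is the label, columns 1-784 are the pixels
# (.iloc and .to_numpy() replace .ix and .as_matrix(), which current pandas no longer has)
temp_train=pd.read_csv(train_file, nrows=train_end).iloc[train_start-1:train_end].to_numpy(dtype='uint8')

# remove label column (labels remain available as temp_train[:,0])
print('Deleting labels...')
train=np.delete(temp_train,0,axis=1)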

print('... done.')
Running MNIST, stats:
... train_start: 1
... train_end: 10000
Copying csv to matrix...
Deleting labels...
... done.


Displaying the MNIST digits

Each row in the MNIST dataset represents a 28x28 pixel grid, flattened row by row into 784 values. Each value is an integer between 0 (background, displayed as black here) and 255 (full stroke, displayed as white). Intermediate values indicate the intensity of the stroke at that pixel. Displaying a digit is a mechanical problem that can be easily coded by reading each value and reconstructing a pixel array.
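As a quick sketch of that reconstruction, using a stand-in row whose values equal their own positions so we can see where each one lands: NumPy's `reshape` fills the 28x28 grid row by row, so the pixel at image row r, column c sits at position r*28 + c of the flattened row. The `flat_index` helper is just for illustration.

```python
import numpy as np

# Stand-in for one dataset row: 784 values, each equal to its own position
flat_row = np.arange(784)

grid = flat_row.reshape(28, 28)   # row-major: the first 28 values form the top image row

def flat_index(r, c, width=28):
    """Position in the flattened row of the pixel at image row r, column c."""
    return r * width + c

# Second row from the top, fourth column from the left
print(grid[1, 3], flat_index(1, 3))
```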

However, with some creativity, we can do better: we can display the 28x28 grid as a matplotlib heatmap, forcing it to use a grayscale colormap instead of a color scheme such as the red-orange 'hot' colormap. We scale the intensity between 0 and 255 to reflect the original pixel intensity. This matters once we display 'average' intensities, and we will explore changing this color scale later.

The code below displays the first 100 digits in the Kaggle dataset.

In [2]:
# display the first 100 digits
fig=plt.figure(figsize=(10,10))
for i in range(100):
    ax=fig.add_subplot(10,10,i+1)
    ax.set_axis_off()
    a=np.copy(train[i])
    a=np.reshape(a,(28,28))
    ax.imshow(a, cmap='gray', interpolation='nearest')
plt.show()


Handwriting peculiarities

We notice several types of distortion, each potentially affecting a sample fully or partially: translation, rotation, shearing, thinning/thickening, compression, and loss/change of pixel data due to pre-processing.

  1. Some people write digits off-center (3rd row, 2nd column: "6"). Translation is partly addressed with a centered image (via pre-processing), but parts of the image can still appear translated versus other similar digits.
  2. Some people write digits at an angle, that is, the entire digit is rotated as a whole, as opposed to leaning, which is shearing (1st row, 4th column: "4"; 5th row, 4th column: "4").
  3. Some people write with an obvious lean, usually to the right, but by different amounts (various examples).
  4. Some digits are also written with too much ink, or too little (1st row, 5th column: "0" vs the next "0").
  5. Some digits are written narrow or wide (same examples as #4; also, 9th row, 6th column: "9" and the next "9").
  6. Some digits are also not perfectly represented, e.g., missing or incorrect pixels, as part of the dataset preparation (8th row, 4th column: possibly a "2").
  7. Some digits are also written with extra style, e.g., elongated ends and other personalized effects (different examples of "7", "3", "2" and so on).
  8. Some digits are simply barely legible, pure noise, or even mislabeled (5th row, 9th column: "9" or "7"?; 7th row, 1st column: ???; 8th row, 4th column: "2"?).
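The difference between rotation (item 2) and shearing, or leaning (item 3), can be illustrated with two 2x2 linear maps applied to a synthetic vertical stroke. This is a minimal sketch with our own helper names, nearest-neighbour sampling, and zero fill; `scipy.ndimage.affine_transform` provides a more complete, interpolating version of the same idea.

```python
import numpy as np

def affine_nearest(img, A):
    """Apply the 2x2 linear map A (in (row, col) coordinates) about the image center.

    Uses inverse mapping with nearest-neighbour sampling; output pixels that
    map from outside the frame are set to 0.
    """
    h, w = img.shape
    center = np.array([[(h - 1) / 2], [(w - 1) / 2]])
    rows, cols = np.indices((h, w))
    out_coords = np.vstack([rows.ravel(), cols.ravel()]) - center
    src = np.linalg.inv(A) @ out_coords + center   # where each output pixel comes from
    sr = np.rint(src[0]).astype(int)
    sc = np.rint(src[1]).astype(int)
    inside = (sr >= 0) & (sr < h) & (sc >= 0) & (sc < w)
    out = np.zeros(h * w, dtype=img.dtype)
    out[inside] = img[sr[inside], sc[inside]]
    return out.reshape(h, w)

def rotation(theta):
    """Turn the whole digit by theta radians."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])

def shear(k):
    """Lean the digit: each row slides sideways in proportion to its height (top to the right for k > 0)."""
    return np.array([[1.0, 0.0], [-k, 1.0]])

# Hypothetical upright stroke, like a plain "1"
stroke = np.zeros((28, 28), dtype=np.uint8)
stroke[4:24, 13] = 255

leaned = affine_nearest(stroke, shear(0.3))
rotated = affine_nearest(stroke, rotation(np.pi / 2))

# Number of image rows that contain ink in each version
print('rows inked (original, sheared, rotated):',
      len(np.unique(np.nonzero(stroke)[0])),
      len(np.unique(np.nonzero(leaned)[0])),
      len(np.unique(np.nonzero(rotated)[0])))
```

Shearing only slides each row sideways, so the stroke still occupies the same 20 rows, now leaning with its top to the right; rotating by 90 degrees lays the whole stroke flat across a single row.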

A combination of these distortions tends to cause different digits to look similar (which would also confuse human readers). Clearly, this is not a simple problem. :)

How do we solve this?

If we were to ‘read’ a digit from a known set of fonts, this would be easy. We can just use a template matching algorithm and select the best fit. The less varied the fonts, the easier the process. Since MNIST is handwritten, we have an infinite variety of digits, although it is possible to create an ‘average’ template for each digit.
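A bare-bones version of template matching, sketched under our own assumptions: each template is a flattened 28x28 image, the distance is the sum of squared pixel differences, and the label of the closest template wins. The toy "1" and "0" templates below are made up for illustration; in this post's setting, the 'average' digits computed next could play the role of templates.

```python
import numpy as np

def match_template(image, templates):
    """Return the label whose template is closest to image (sum of squared pixel differences).

    image:     flattened pixel vector, e.g. shape (784,)
    templates: dict mapping label -> flattened template of the same shape
    """
    x = np.asarray(image, dtype=float)
    scores = {label: np.sum((x - np.asarray(t, dtype=float)) ** 2)
              for label, t in templates.items()}
    return min(scores, key=scores.get)

# Two hypothetical 'font' templates: a vertical bar ("1") and a rectangular ring ("0")
one = np.zeros((28, 28)); one[4:24, 13:15] = 255
zero = np.zeros((28, 28)); zero[4:24, 8:20] = 255; zero[7:21, 11:17] = 0
templates = {1: one.ravel(), 0: zero.ravel()}

# A slightly shifted, fainter "1" should still be closest to the "1" template
sample = np.zeros((28, 28)); sample[5:24, 14:16] = 200
print(match_template(sample.ravel(), templates))
```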

Test 1: 'Average' pixels

Let us explore what an average digit would look like. Note, however, that an average template will miss many of the distorted digits described above. To determine the 'average', we can simply take the mean of the pixel values for each of the 10 digits.

In [3]:
# display 'average' digits
fig=plt.figure(figsize=(10,3))
for i in range(10):
    train_ave=np.mean(train[np.where(temp_train[:,0]==i)],axis=0)
    ax=fig.add_subplot(2,5,i+1)
    ax.set_axis_off()
    a=np.copy(train_ave)
    a=np.reshape(a,(28,28))
    ax.imshow(a, cmap='gray', interpolation='nearest', clim=(0,255))
plt.show()

The 'average' images above are not necessarily good templates. Each is dominated by the writing style most prevalent in the training set. As an average it works, and it may be good enough to capture the most common writing style, but it can miss a perfectly distinct way of writing a digit.

For instance, if few people write "1" vertically or leaning to the left, the 'average' will suppress those forms because most people write right-leaning "1"s. A left-leaning "1" is practically wiped out (very close to black), since averaging mutes its pixels far below the maximum white level of 255.
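The muting effect can be stated precisely. Suppose (as a hypothetical) that only $N_p$ of the $N$ training "1"s put ink on pixel $p$. Since every pixel value is at most 255 and the other examples contribute 0, the average at $p$ is bounded by the fraction of examples that use it:

```latex
\bar{x}_p \;=\; \frac{1}{N}\sum_{n=1}^{N} x_p^{(n)} \;\le\; \frac{N_p}{N}\cdot 255,
\qquad \text{e.g.}\quad \frac{N_p}{N} = 0.10 \;\Rightarrow\; \bar{x}_p \le 25.5
```

So a pixel used only by a style that 10% of writers adopt can never be brighter than 25.5 on the 0-255 scale, i.e., nearly black.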

One could argue that a vertical or left-leaning "1" is clearly a "1" and is even more distinguishable as a "1" than a right-leaning "1". That is, a right-leaning "1" could be confused with a normal "7" whereas a left-leaning "1" has no other possible interpretation. An average template the way we calculated above would miss this.

We should try another method that picks up these other forms.

NOTE: A "hot" colormap (black-red-yellow-white) shows contrast better than grayscale (black-gray-white), as shown below. So while a grayscale image better approximates how "hard" it would be for an algorithm to differentiate between similar pixels, we will switch to the "hot" color range for subsequent plots to better highlight differences. As the original text was written with grayscale images, any mention of "gray" corresponds to a shade between red and orange in the "hot" plots.

In [4]:
# display 'average' digits
fig=plt.figure(figsize=(10,3))
for i in range(10):
    train_ave=np.mean(train[np.where(temp_train[:,0]==i)],axis=0)
    ax=fig.add_subplot(2,5,i+1)
    ax.set_axis_off()
    a=np.copy(train_ave)
    a=np.reshape(a,(28,28))
    ax.imshow(a, cmap='hot', interpolation='nearest', clim=(0,255))
plt.show()


Test 2: 'Max' pixels

To avoid the averaging effect, we can highlight all pixels that each digit triggers at least once. That is, we cover all possible examples. We can then use the maximum value assigned to that pixel from all the training examples. We could mentally simulate the output. But since we already have the base code, we might as well just show it.

In [5]:
# display 'max' digits
fig=plt.figure(figsize=(10,3))
for i in range(10):
    train_max=np.max(train[np.where(temp_train[:,0]==i)],axis=0)
    ax=fig.add_subplot(2,5,i+1)
    ax.set_axis_off()
    a=np.copy(train_max)
    a=np.reshape(a,(28,28))
    ax.imshow(a, cmap='hot', interpolation='nearest', clim=(0,255))
plt.show()

This looks worse. It did capture what we wanted, which was to cover all possible examples presented in the training data. But it also turned every digit into a similar-looking blotch, leaving us with very little discriminatory power, except for "1".

This is, in a way, expected. The shapes and orientations of each digit are so varied that if we superimpose all training examples of the same digit, we end up with a blotchy figure. Only "1" ends up with a less blotchy super-template: "1"s are oriented left, upright, or right, so the 'waist' of the "1" is not as spread out.
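One way to put a number on this loss of discriminatory power is to measure how much the inked areas of two 'max' templates overlap, for example with the Jaccard index (intersection over union of the inked pixels). The sketch below uses made-up rectangular templates; the helper names and the choice of Jaccard as the measure are our own assumptions.

```python
import numpy as np

def ink_mask(template, level=0):
    """Boolean mask of the pixels a template marks as inked (value above `level`)."""
    return np.asarray(template) > level

def jaccard(a, b):
    """Overlap of two boolean masks: |a AND b| / |a OR b| (1.0 = identical, 0.0 = disjoint)."""
    union = np.logical_or(a, b).sum()
    return np.logical_and(a, b).sum() / union if union else 1.0

# Two hypothetical 'max' templates that both cover most of the frame
blob_a = np.zeros((28, 28)); blob_a[3:25, 4:24] = 255
blob_b = np.zeros((28, 28)); blob_b[4:26, 5:25] = 255
# A hypothetical narrow, "1"-like 'max' template
thin = np.zeros((28, 28)); thin[4:24, 12:16] = 255

print(round(jaccard(ink_mask(blob_a), ink_mask(blob_b)), 3),
      round(jaccard(ink_mask(blob_a), ink_mask(thin)), 3))
```

With the arrays from the cells above, the same helpers could be applied to the real templates, e.g. `ink_mask(np.max(train[temp_train[:,0]==d], axis=0))` for each digit `d`; values near 1.0 between two digits would mean their 'max' templates are nearly indistinguishable.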

Maybe we could go back to the 'average' and find a better refinement...

Test 3: 'Threshold mean' pixels

To get around the bias of the plain average, we could be more selective about which 'average' pixels to display: for instance, only pixels whose mean is at least a threshold, such as half of the maximum value of 255.

Note that we have no guarantee that the mean of a pixel that appears 'fully' activated in all training examples will be 255 (the maximum, a white pixel). This is because the encoding (scanning handwritten black-ink images into MNIST images) might have recorded what we would see as 'white' as a very light gray pixel instead (thus, slightly less than 255).

In [6]:
# display 'threshold mean' digits
pixel_threshold=0.50
fig=plt.figure(figsize=(10,3))
for i in range(10):
    pixel_max=np.max(train[np.where(temp_train[:,0]==i)])
    print('digit:', i, ' pixel_max:',pixel_max)
    train_mean=np.mean(train[np.where(temp_train[:,0]==i)],axis=0)
    train_threshold=(train_mean>=pixel_threshold*pixel_max)*train_mean
    ax=fig.add_subplot(2,5,i+1)
    ax.set_axis_off()
    a=np.copy(train_threshold)
    a=np.reshape(a,(28,28))
    ax.imshow(a, cmap='hot', interpolation='nearest', clim=(0,pixel_max))
plt.show()
digit: 0  pixel_max: 255
digit: 1  pixel_max: 255
digit: 2  pixel_max: 255
digit: 3  pixel_max: 255
digit: 4  pixel_max: 255
digit: 5  pixel_max: 255
digit: 6  pixel_max: 255
digit: 7  pixel_max: 255
digit: 8  pixel_max: 255
digit: 9  pixel_max: 255

Notice the predominantly gray color of the remaining pixels. This is due to the heatmap's color range, which runs from 0 to 255: the mean of each pixel tends to be lower than 255, so the pixels tend to appear grayish. This makes it a little hard to see contrast between pixels. Notice further that "1" has more intense pixels than the other digits. This may be informative, or misleading, depending on how we interpret the pixel intensities.

We could increase contrast by normalizing each digit's pixel values to span the full 0-255 range, which would stretch the gaps between pixel values. Alternatively, we could simply set the heatmap color range per digit without changing the original pixel values, which produces the same normalized-pixel effect on screen.
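The first option, stretching the values, is a min-max normalization. A minimal sketch, with a hypothetical helper name and made-up pixel values:

```python
import numpy as np

def stretch_to_full_range(values, out_max=255.0):
    """Linearly map values so the smallest becomes 0 and the largest becomes out_max."""
    v = np.asarray(values, dtype=float)
    lo, hi = v.min(), v.max()
    if hi == lo:                       # flat image: nothing to stretch
        return np.zeros_like(v)
    return (v - lo) / (hi - lo) * out_max

# Hypothetical thresholded mean pixels: zeros plus a few muted grays
pixels = np.array([0.0, 0.0, 90.0, 120.0, 150.0])
print(stretch_to_full_range(pixels))
```

matplotlib's default linear normalization maps a value v to (v - vmin)/(vmax - vmin) before applying the colormap, so passing `clim=(lo, hi)` to `imshow` on the raw values produces the same picture as plotting the stretched values with `clim=(0, 255)`; that is why the second option gives the same effect without touching the data.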

In [7]:
# display 'threshold mean' digits
pixel_threshold=0.50
fig=plt.figure(figsize=(10,3))
for i in range(10):
    pixel_max=np.max(train[np.where(temp_train[:,0]==i)])
    print('digit:', i, ' pixel_max:',pixel_max)
    train_mean=np.mean(train[np.where(temp_train[:,0]==i)],axis=0)
    train_threshold=(train_mean>=pixel_threshold*pixel_max)*train_mean
    ax=fig.add_subplot(2,5,i+1)
    ax.set_axis_off()
    a=np.copy(train_threshold)
    a=np.reshape(a,(28,28))
    ax.imshow(a, cmap='hot', interpolation='nearest', clim=(train_threshold.min(),train_threshold.max()))
plt.show()
digit: 0  pixel_max: 255
digit: 1  pixel_max: 255
digit: 2  pixel_max: 255
digit: 3  pixel_max: 255
digit: 4  pixel_max: 255
digit: 5  pixel_max: 255
digit: 6  pixel_max: 255
digit: 7  pixel_max: 255
digit: 8  pixel_max: 255
digit: 9  pixel_max: 255


Test 4: 'Threshold-mean, maximized' pixels

Since handwriting varies across people, what if we presume that the gray pixels (those above some threshold intensity) are an integral part of a digit? They are, after all, classified as the same digit as the fully white pixels. Perhaps they are gray because they are used less often than the more common pixels (in our sample set), but they are nonetheless indicative of the same digit.

If we follow this thinking, we could simply force the above-threshold gray pixels to be 255 (white). There is obviously no need to artificially constrain the colormap range (pixels will be either 0 or 255), but let us be consistent in our plotting code.

In [8]:
# display 'threshold mean' digits
pixel_threshold=0.50
fig=plt.figure(figsize=(10,3))
for i in range(10):
    pixel_max=np.max(train[np.where(temp_train[:,0]==i)])
    train_mean=np.mean(train[np.where(temp_train[:,0]==i)],axis=0)
    train_threshold=(train_mean>=pixel_threshold*pixel_max)*pixel_max
    ax=fig.add_subplot(2,5,i+1)
    ax.set_axis_off()
    a=np.copy(train_threshold)
    a=np.reshape(a,(28,28))
    ax.imshow(a, cmap='hot', interpolation='nearest', clim=(train_threshold.min(),train_threshold.max()))
plt.show()

That is something of an improvement over the whitewash in Test #2! We can see a larger spread of maximum pixel values. But maybe the 50% threshold is too high, since we appear to have truncated many pixels. If we think about it, this is likely: some common variations may appear in barely 20% of the training examples. Let us then test a few thresholds from 2.5% to 50%.

In [9]:
# display 'threshold mean' digits
pixel_threshold=(0.025, 0.05, 0.10, 0.20, 0.30, 0.40, 0.50)
fig=plt.figure(figsize=(10,15))
for i in range(10):
    # per-digit statistics do not depend on the threshold, so compute them once per digit
    pixel_max=np.max(train[np.where(temp_train[:,0]==i)])
    train_mean=np.mean(train[np.where(temp_train[:,0]==i)],axis=0)
    for j,threshold in enumerate(pixel_threshold):
        # binarize: above-threshold pixels become pixel_max, the rest become 0
        train_threshold=(train_mean>=threshold*pixel_max)*pixel_max
        ax=fig.add_subplot(10,len(pixel_threshold),len(pixel_threshold)*i+j+1)
        ax.set_axis_off()
        a=np.reshape(train_threshold,(28,28))
        ax.imshow(a, cmap='hot', interpolation='nearest', clim=(train_threshold.min(),train_threshold.max()))
plt.show()

It looks good, but not much better. This assumes pixels that are not 'core' white pixels are not relevant. This might be a very harsh judgment. Perhaps we could re-consider the pixels below the threshold.

Test 4b: 'Threshold-mean, maximized plus below-threshold noise' pixels

Putting back the discarded (below-threshold) pixels might even be more informative. We should still treat any pixel above the threshold as an 'integral' part of the digit's form (so it is assigned the max value, as in Test #4), while adding back the pixels below the threshold. Let us explore this below.

In [10]:
# display 'threshold mean' digits
pixel_threshold=(0.025, 0.05, 0.10, 0.20, 0.30, 0.40, 0.50)
fig=plt.figure(figsize=(10,15))
for i in range(10):
    # per-digit statistics do not depend on the threshold, so compute them once per digit
    pixel_max=np.max(train[np.where(temp_train[:,0]==i)])
    train_mean=np.mean(train[np.where(temp_train[:,0]==i)],axis=0)
    for j,threshold in enumerate(pixel_threshold):
        # above-threshold pixels are raised to pixel_max; below-threshold means are kept
        train_threshold=np.where(train_mean>=threshold*pixel_max,pixel_max,train_mean)
        ax=fig.add_subplot(10,len(pixel_threshold),len(pixel_threshold)*i+j+1)
        ax.set_axis_off()
        a=np.reshape(train_threshold,(28,28))
        ax.imshow(a, cmap='hot', interpolation='nearest', clim=(train_threshold.min(),train_threshold.max()))
plt.show()

This looks a lot better! The series of pictures clearly shows the 'important' pixels that characterize each digit, and where the variance occurs as the threshold is raised. When the threshold is low, there are few below-threshold pixels, so most inked pixels are white. At higher thresholds, the gray pixels start to appear.

Each column of the above is exactly the same as the entries from Test #1, except that the above-threshold pixels are highlighted as white (255) pixels. This similarity is most apparent in the last column, where the threshold is 0.50, producing the fewest white pixels. A threshold of 1.0 will leave all pixels at their mean values, as in Test #1.
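Since the threshold tests differ only in what happens to pixels above and below the cutoff, they can be summarized in one helper. This is our own restatement (the function and mode names are hypothetical), applied to a few made-up mean values; it also checks the claim that a threshold of 1.0 leaves every pixel at its mean.

```python
import numpy as np

def threshold_variant(mean, pixel_max, threshold, mode):
    """Per-pixel variants of the threshold tests, applied to a digit's mean image.

    'binary': above-threshold -> pixel_max, below -> 0        (the 'maximized' test)
    'raise' : above-threshold -> pixel_max, below kept         (the 'plus below-threshold noise' test)
    'cap'   : above-threshold -> cutoff value, below kept      (the 'softened' test)
    'filter': above-threshold kept, below -> 0                 (the plain 'threshold mean' test)
    """
    mean = np.asarray(mean, dtype=float)
    cutoff = threshold * pixel_max
    above = mean >= cutoff
    if mode == 'binary':
        return np.where(above, float(pixel_max), 0.0)
    if mode == 'raise':
        return np.where(above, float(pixel_max), mean)
    if mode == 'cap':
        return np.minimum(mean, cutoff)
    if mode == 'filter':
        return np.where(above, mean, 0.0)
    raise ValueError('unknown mode: %r' % mode)

mean = np.array([0.0, 20.0, 100.0, 200.0, 255.0])   # hypothetical mean pixel values
for mode in ('binary', 'raise', 'cap', 'filter'):
    print(mode, threshold_variant(mean, 255, 0.5, mode))

# With threshold 1.0, 'raise' leaves every pixel at its mean, reproducing Test #1
print(np.array_equal(threshold_variant(mean, 255, 1.0, 'raise'), mean))
```

The 'cap' mode also shows that Test 5's two-step assignment amounts to clipping the mean image at the cutoff with `np.minimum`.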

Test 5: Softened version of Test 4

The sharp edges and sudden drop-off in pixel values are due to our method of raising the pixel values that met the cutoff to 255. We could try the opposite: lower the pixels above the threshold to the threshold value, without adjusting the values below the threshold.

In [ ]:
# display 'threshold mean' digits
pixel_threshold=(0.025, 0.05, 0.10, 0.20, 0.30, 0.40, 0.50)
fig=plt.figure(figsize=(10,15))
for i in range(10):
    # per-digit statistics do not depend on the threshold, so compute them once per digit
    pixel_max=np.max(train[np.where(temp_train[:,0]==i)])
    train_mean=np.mean(train[np.where(temp_train[:,0]==i)],axis=0)
    for j,threshold in enumerate(pixel_threshold):
        # above-threshold pixels are lowered to the cutoff; below-threshold means are kept
        # (equivalent to clipping the mean image at threshold*pixel_max)
        train_threshold=np.minimum(train_mean,threshold*pixel_max)
        ax=fig.add_subplot(10,len(pixel_threshold),len(pixel_threshold)*i+j+1)
        ax.set_axis_off()
        a=np.reshape(train_threshold,(28,28))
        ax.imshow(a, cmap='hot', interpolation='nearest', clim=(train_threshold.min(),train_threshold.max()))
plt.show()

This softens the edges around the above-threshold pixels. Since we lowered the above-threshold pixels to the threshold value, the heatmap color scale range also shrinks, causing more low-valued pixels to appear light gray instead of dark gray.

Test 6: Same as Test 3 but with different thresholds

Finally, let us expand Test 3 using different threshold cutoffs. In other words, we keep the threshold filter but retain the mean values of the above-threshold pixels instead of raising them to the maximum, and we drop the below-threshold noise rather than adding it back.

In [ ]:
# display 'threshold mean' digits
pixel_threshold=(0.025, 0.05, 0.10, 0.20, 0.30, 0.40, 0.50)
fig=plt.figure(figsize=(10,15))
for i in range(10):
    # per-digit statistics do not depend on the threshold, so compute them once per digit
    pixel_max=np.max(train[np.where(temp_train[:,0]==i)])
    train_mean=np.mean(train[np.where(temp_train[:,0]==i)],axis=0)
    for j,threshold in enumerate(pixel_threshold):
        # keep the mean of above-threshold pixels; zero out the below-threshold pixels
        train_threshold=(train_mean>=threshold*pixel_max)*train_mean
        ax=fig.add_subplot(10,len(pixel_threshold),len(pixel_threshold)*i+j+1)
        ax.set_axis_off()
        a=np.reshape(train_threshold,(28,28))
        ax.imshow(a, cmap='hot', interpolation='nearest', clim=(train_threshold.min(),train_threshold.max()))
plt.show()


Closing thoughts

The MNIST dataset is a popular benchmark for new algorithms and for variants of existing ones. It has enough training samples to allow the repetitive, data-hungry training often required by newer ML techniques (e.g., deep learning convnets). It is also possible to train on a smaller subset to see how algorithms perform with limited training data (taken to the extreme of only one or a few examples per class, this becomes the ML field of one-shot learning). We will explore some of these algorithms, including my own twists on them, via the MNIST dataset.

The output of these tests suggests it might even be possible to create a good algorithm with a generic template matching approach, as long as we can make the prototypes flexible enough to capture atypical digit orientations. It also suggests that a hand-crafted template approach (instead of a statistically learned template) might work well.
