
Classifying handwritten digits using TensorFlow

by Greg


Way back in 2013, I wrote a post about digit recognition in Python. Since then a lot has changed in the Python data ecosystem: Google has given us TensorFlow, scikit-learn has matured, and "AI" is now the latest craze.

Well, I decided it's about time to freshen up the digit recognizer. It's still a really simple yet fascinating concept: taking someone's handwriting and trying to predict what that person wrote. The good news is that I've picked up a few new tricks over the years, and I'm going to share some of them with you here.

What a difference a couple years makes :)

Psst. Want to go straight to the new app? Check it out here: https://www.yhat.com/ops-demos/handwriting!

What we're doing

One of the things I wanted to change about the handwriting recognizer was to make it more "real-time." Instead of making the user draw something and then click "guess", wouldn't it be way more interesting if the app proactively tried to figure out what you were drawing?

To do this we need a way to invoke our model as the user is drawing the number. This might seem simple, but it implies two things: (1) we can make predictions on an image fast enough to keep up with someone's handwriting, and (2) we can shuttle images and predictions back and forth between the browser and our model (presumably using something like JavaScript).

Luckily, since my first post was written, Yhat has made some big performance improvements: it can now handle a few thousand predictions per second, as opposed to 500-700 a few years ago.
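Just to make that concrete, here's a rough sketch of what one of those round trips might look like from Python. The endpoint URL and payload shape here are made up for illustration; in the real app, the browser makes an equivalent call from JavaScript.

import requests

# hypothetical scoring endpoint; a real deployment would have its own URL
SCORE_URL = "https://example.yhathq.com/models/handwriting/score"

def score_partial_drawing(pixels):
    # POST the current canvas pixels, get back the model's guesses
    resp = requests.post(SCORE_URL, json={"pixels": pixels})
    resp.raise_for_status()
    return resp.json()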

The Data

One of the consequences of making an app that's constantly making predictions is that we need to train our model on incomplete data. Let's say you've written half of the number 5. That might look sort of like a 6, 9, or even a 1 (depending on your calligraphic style). To handle these sorts of partial-data situations, we need to collect data (or in our case, images) on partially completed digits.

To do this, I made a little web application that does the following:

  • Presents the visitor with a blank canvas and asks them to draw a random number
  • As the user is drawing the number, saves the image to an S3 bucket (categorized and labeled, of course; see the sketch after this list)
  • Once the user is finished, repeats steps #1 and #2 with a new number
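I won't go through the app's code here, but the save step might look something like this boto3 sketch. The bucket name, key scheme, and helper name are all made up for illustration:

import boto3

s3 = boto3.client("s3")

def save_snapshot(png_bytes, digit, session_id, stroke_count):
    # keys like "digit-5/session-abc123/stroke-003.png" keep the images
    # categorized by label and grouped by drawing session
    key = "digit-%d/session-%s/stroke-%03d.png" % (digit, session_id, stroke_count)
    s3.put_object(Bucket="handwriting-snapshots", Key=key, Body=png_bytes)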

After persuading my co-workers, I wound up with just over 80K images:

digit    # of images
0               8748
1               4066
2              11005
3               9279
4               8061
5               9909
6               8451
7               5141
8              11058
9               9187
total          84905

Phew! We've got over 80,000 numbers in various stages of being drawn. To get these images into Python, I used the skimage (scikit-image) library, which provides a consistent, easy-to-use API on top of NumPy, SciPy, and other Python imaging tools.

You can see below that I'm looping through each image file and turning it into a numpy array using the get_image_data function. get_image_data takes an image file and does the following:

  • opens it and reads it into Python as an array
  • converts it from color to grayscale (saves on size)
  • downsamples the image by a factor of 25 in each dimension (this might seem aggressive, but it doesn't make a significant difference to our classifier)
  • "flattens" it from a 2-dimensional array (20x20) into a 1-dimensional array (1x400)

And ze code...

from skimage import io, transform

def get_image_data(filename):
    # read as grayscale, shrink each dimension 25x, flatten 20x20 -> 1x400
    img = io.imread(filename, as_grey=True)
    return transform.downscale_local_mean(img, (25, 25)).flatten()

import numpy as np

# files is a list of (label, filename) pairs for the images pulled down from S3
data = []
labels = []
for i, (label, filename) in enumerate(files):
    image = get_image_data(filename)
    data.append(image)

    # one-hot encode the label: a 10-element vector with a 1 in the digit's slot
    classes = np.zeros(10)
    classes[label] = 1.0
    labels.append(classes)

    if i % 100 == 0:
        print "    %d of %d (%.2f)" % (i, len(files), float(i) / len(files))

The next step is to build a classifier that can interpret these images!

Classifying our images

Classifying handwritten digits is a fairly common tutorial/textbook problem for machine learning libraries. The MNIST dataset is commonly referenced, and you can find it in the documentation for libraries such as scikit-learn, TensorFlow, and Keras. So lucky for us, there are a lot of great starting points. Our problem really becomes adapting one of these examples to our use case.

For this post I decided to use TensorFlow. It actually ships with two tutorials for building a classifier on the MNIST dataset, which is very convenient for us!

I'll use the basic example for now because it's less code and more straightforward.
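By the way, if you'd rather follow along with the stock MNIST data instead of my partial-digit images, the TensorFlow tutorial ships with a small helper that downloads it for you (everything below uses our own data, though):

from tensorflow.examples.tutorials.mnist import input_data

# downloads MNIST to ./MNIST_data/ and one-hot encodes the labels,
# the same format as the data and labels arrays we built above
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)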

Model Code

TensorFlow requires a lot of "low-level" understanding and configuration of the underlying graph/network. To be perfectly honest, I wouldn't recommend using it unless you're confident you know what you're doing.

To set up our model, we'll create "placeholder" variables (like x below). These placeholders aren't evaluated when they're instantiated; instead, evaluation is delayed until you actually feed data through the graph.

We're also going to create 2 variables, W and b, which represent the "weights" and "bias" of our model. Initially we'll set them all to 0, but once we start training our model these variables will be updated to reflect the training data.

Lastly, we'll define our model with tf.nn.softmax. In this case we're doing softmax regression (multinomial logistic regression), but there are plenty of other approaches available.

import tensorflow as tf

# x holds a batch of flattened images: any number of rows, 400 pixels each
x = tf.placeholder(tf.float32, [None, len(data[0])])
# the weights and bias start at zero and get updated during training
W = tf.Variable(tf.zeros([len(data[0]), 10]))
b = tf.Variable(tf.zeros([10]))
# softmax turns the raw scores for each digit into probabilities
y = tf.nn.softmax(tf.matmul(x, W) + b)

Apart from the model, we also need to define a loss function (cross-entropy) and an optimizer (gradient descent) that dictate how training will be done.

# y_ holds the one-hot encoded true labels for each image in the batch
y_ = tf.placeholder(tf.float32, [None, 10])
# cross-entropy loss: penalizes confident wrong answers most heavily
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
# take gradient descent steps (learning rate 0.5) to minimize the loss
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
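For reference, the quantity we're minimizing is the usual cross-entropy, where y is our predicted probability distribution and y' is the one-hot true distribution:

H_{y'}(y) = -\sum_i y'_i \log(y_i)

That's exactly what the cross_entropy line above computes, averaged over the batch.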

OK, almost there. Just initialize our variables and start our TensorFlow session!

init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

Training

Now that everything is set up, we can actually start training our model. TensorFlow is nice because it supports iterative training: you don't have to do 100% of your training in a single step or a single line (as you typically would with scikit-learn). So as you acquire more data, you can update your model and fine-tune your weights.

Below you can see the actual training code. We're doing 10,000 batches of training on 1000 data points each time. I've written a helper function called next_batch for randomly feeding my model data.

# convert to numpy arrays so we can index them with arrays of random indices
data = np.array(data)
labels = np.array(labels)

def next_batch(n):
    # pick n random rows (with replacement) from our training set
    idx = np.random.randint(0, len(data), n)
    return data[idx], labels[idx]

for i in range(10000):
    batch_xs, batch_ys = next_batch(1000)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

    if i % 100 == 0:
        print "    batch %d" % i

Making predictions

OK, great. Now that we've got our model, we can start making predictions! I found this very unintuitive to do in TensorFlow, but alas, I didn't write the API:

# score a batch of our own images; each row of the output is 10 class probabilities
print sess.run(y, feed_dict={x: data[:10]})
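Each row of that output is a set of probabilities, one per digit. To turn a row into an actual guess, take the argmax:

probs = sess.run(y, feed_dict={x: data[:10]})
# the index of the highest probability in each row is the predicted digit
print probs.argmax(axis=1)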

I went ahead and deployed this model using ScienceOps (shameless plug) and hooked it up to the web app discussed above. I can now generate predictions in real time as the user is drawing a particular number. Pretty slick! My model could use a little work, but it's really cool to see how its predictions get better as you give it more information (i.e., as you get closer to finishing your drawing).
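For the curious, a ScienceOps deployment with the yhat client library looked roughly like the sketch below. The class name and payload format are mine, and a real version would need to decode the incoming image into the flattened 400-pixel format the model expects:

from yhat import Yhat, YhatModel

class HandwritingRecognizer(YhatModel):
    def execute(self, request):
        # assume the request carries a flattened 400-pixel array
        probs = sess.run(y, feed_dict={x: [request["pixels"]]})
        return {"probabilities": probs[0].tolist()}

yh = Yhat("USERNAME", "APIKEY", "https://cloud.yhathq.com/")
yh.deploy("HandwritingRecognizer", HandwritingRecognizer, globals())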

Check out the fully functional web app here: https://www.yhat.com/ops-demos/handwriting.

Final Thoughts

And that's it: a dataset of partially drawn digits, a softmax classifier built in TensorFlow, and a deployed model you can score against in real time. If you want to dig deeper into TensorFlow, MNIST, or image recognition, the TensorFlow MNIST tutorials mentioned above are a great place to start.