MNIST digit prediction

The data

Today we will start looking at the MNIST data set, a collection of images of handwritten digits. The learning goal is to predict which digit (0-9) each image represents. This is a canonical dataset for basic image processing and was probably the first dataset that a large community of researchers used as a universal benchmark for computer vision. Today I am using just a subset of the full dataset, slightly downsampled. It is the example data given in The Elements of Statistical Learning. We will explore the larger dataset over the next two weeks as we start introducing neural networks.

I’ll read in the training and testing datasets; the split into train and test should be the same as that used by most other sources:

set.seed(1)
train <- read.csv("data/mnist_train.psv", sep="|", as.is=TRUE, header=FALSE)
test <- read.csv("data/mnist_test.psv", sep="|", as.is=TRUE, header=FALSE)

Looking at the data, we see that it has 257 columns: the first column gives the true digit class and the remaining 256 give the pixel intensities (on a scale from -1 to 1) of the 16x16 pixel image.

dim(train)
## [1] 7291  257
train[1:10,1:10]
##    V1 V2 V3 V4     V5     V6     V7     V8     V9    V10
## 1   6 -1 -1 -1 -1.000 -1.000 -1.000 -1.000 -0.631  0.862
## 2   5 -1 -1 -1 -0.813 -0.671 -0.809 -0.887 -0.671 -0.853
## 3   4 -1 -1 -1 -1.000 -1.000 -1.000 -1.000 -1.000 -1.000
## 4   7 -1 -1 -1 -1.000 -1.000 -0.273  0.684  0.960  0.450
## 5   3 -1 -1 -1 -1.000 -1.000 -0.928 -0.204  0.751  0.466
## 6   6 -1 -1 -1 -1.000 -1.000 -0.397  0.983 -0.535 -1.000
## 7   3 -1 -1 -1 -0.830  0.442  1.000  1.000  0.479 -0.328
## 8   1 -1 -1 -1 -1.000 -1.000 -1.000 -1.000  0.510 -0.213
## 9   0 -1 -1 -1 -1.000 -1.000 -0.454  0.879 -0.745 -1.000
## 10  1 -1 -1 -1 -1.000 -1.000 -1.000 -1.000 -0.909  0.801
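
A quick check of this layout can be sketched as follows. So that the snippet runs on its own, it builds a small random stand-in called fake_train (a hypothetical name, not part of the course data); with the real data you would use train in its place.

```r
# Hypothetical stand-in for `train`: 5 rows, 1 label column + 256 pixel columns.
set.seed(1)
fake_train <- data.frame(cbind(sample(0:9, 5, replace = TRUE),
                               matrix(runif(5 * 256, -1, 1), nrow = 5)))

labels <- fake_train[, 1]               # first column: digit class
pixels <- as.matrix(fake_train[, -1])   # remaining columns: pixel intensities

ncol(pixels) == 16 * 16                 # one column per pixel of a 16x16 image
range(pixels)                           # should lie within [-1, 1]
table(labels)                           # number of examples of each digit
```

On the real training set, table(train[, 1]) gives a first impression of how balanced the ten classes are.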

We can plot what the image actually looks like in R using the rasterImage function:

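# Drop the label column and reshape the 256 pixel values into a 16x16 matrix
y <- matrix(as.matrix(train[3400,-1]),16,16,byrow=TRUE)
# Map intensities from [-1, 1] to [0, 1], flipped so that ink is drawn dark
y <- 1 - (y + 1)*0.5

# Open a plot region spanning roughly [-1, 1] and draw the image over it
plot(0,0)
rasterImage(y,-1,-1,1,1)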
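
The two conversion steps above are easy to misread, so here is a minimal sketch of what each one does on toy values. The names x, grey, and m are illustrative only. The sketch assumes, as the plotting code does, that pixels are stored one image row at a time.

```r
# Toy intensities on the same [-1, 1] scale as the pixel columns.
x <- c(-1, 0, 1)
grey <- 1 - (x + 1) * 0.5
grey                 # 1.0 0.5 0.0, i.e. -1 -> white, 0 -> mid grey, 1 -> black

# byrow = TRUE fills the matrix one row at a time rather than column by column.
m <- matrix(1:6, nrow = 2, ncol = 3, byrow = TRUE)
m[1, ]               # 1 2 3
```

rasterImage treats numeric values in [0, 1] as grey levels with 0 as black and 1 as white, which is why the intensities are flipped: the background value -1 ends up white and the ink value 1 ends up black.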

With a minimal amount of work, we can build a much better visualization of what these digits actually look like. Here is a grid of 35 observations. How hard do you think it will be to predict the correct classes?

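# Pick 35 random training rows and lay them out in a 5x7 grid with no margins
iset <- sample(1:nrow(train),5*7)
par(mar=c(0,0,0,0))
par(mfrow=c(5,7))
for (j in iset) {
  # Reshape the pixel columns to 16x16 and rescale to grey levels in [0, 1]
  y <- matrix(as.matrix(train[j,-1]),16,16,byrow=TRUE)
  y <- 1 - (y + 1)*0.5

  plot(0,0,xlab="",ylab="",axes=FALSE)
  rasterImage(y,-1,-1,1,1)
  box()
  # Overlay the true digit class in the lower-left corner
  text(-0.8,-0.7, train[j,1], cex=3, col="red")
}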
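
Since the same reshape-and-rescale step appears in both plotting snippets, it could be pulled into a small helper. The function digit_matrix below is a hypothetical name introduced here, not something from the course code, and the test row is made up for illustration.

```r
# Hypothetical helper: turn one observation (label followed by 256 pixel
# values) into a 16x16 matrix of grey levels in [0, 1].
digit_matrix <- function(row) {
  px <- as.numeric(unlist(row[-1]))            # drop the label, keep the pixels
  m <- matrix(px, 16, 16, byrow = TRUE)        # assumes row-by-row storage
  1 - (m + 1) * 0.5                            # -1 -> 1 (white), 1 -> 0 (black)
}

# Made-up observation: label 3, all background except the last pixel.
fake_row <- c(3, rep(-1, 255), 1)
img <- digit_matrix(fake_row)
dim(img)
```

With the real data, the loop body would then reduce to rasterImage(digit_matrix(train[j, ]), -1, -1, 1, 1) plus the box and text calls.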