Day 33 - Nearest-neighbor classification

Today we will look at a classification method which is more flexible than trees and logistic regression. In a sense, it is the most flexible classifier possible. To classify an observation, all you do is find the most similar example in the training set and return the class of that example. This is called 1-nearest-neighbor classification, or 1-nn. More generally, we can take the k most similar examples and return the majority vote. This is k-nearest-neighbor classification. The relative frequency of classes among the neighbors can be used to get a crude measure of class probability.

Simple example

Consider the following dataset:
        x1         x2   y
 3.2283839  3.7587273 Yes
-1.2221922 -2.5341930  No
 2.6448352  2.0397598 Yes
-2.3814279 -0.9399438  No
-0.2223163 -1.1950309  No
...
The function knn.model will make a knn classifier (k=1 by default):
nn <- knn.model(y~.,frame)
This function actually doesn't do anything but put the training set into an object. All of the work is done at prediction time by the function predict:
> predict(nn,data.frame(x1=0,x2=-1))
[1] No
To classify a future (x1,x2) observation, it finds the training example (z1,z2) with smallest distance. The distance measure is the familiar Euclidean distance:
d(x,z) = sqrt((x1 - z1)^2 + (x2 - z2)^2)
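To make this concrete, here is a rough sketch of what predict has to do, written directly in R (knn.model itself may be implemented differently; frame is the training data above):
# Minimal sketch of k-nn prediction for the two-predictor example.
# Assumes the training data frame `frame` has columns x1, x2 and y.
knn_predict <- function(frame, newpoint, k = 1) {
  # Euclidean distance from the new point to every training example
  d <- sqrt((frame$x1 - newpoint$x1)^2 + (frame$x2 - newpoint$x2)^2)
  # classes of the k nearest training examples
  nearest <- frame$y[order(d)[1:k]]
  # majority vote; the relative frequencies give the crude class probabilities
  names(which.max(table(nearest)))
}
knn_predict(frame, data.frame(x1 = 0, x2 = -1))   # [1] "No", agreeing with predict(nn,...) above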
To see the model visually, use cplot(nn):

The boundary between the two classes is shown in green.

The 1-nn boundary is always a polygon. To understand this, note that each example has a small region around it where it is the nearest example. This region is a polygon, and the classification decision, whatever it is, is constant over the region. Some regions are labeled "Yes" and others "No". The class boundary is the boundary between the "Yes" regions and "No" regions, so it must be a polygon, or collection of polygons. This is more general than a tree classifier, which uses rectangles.
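One way to see these regions is to classify every point on a fine grid and color it by the predicted class, which is presumably close to what cplot draws. A sketch using the knn function from the class library (part of the VR bundle mentioned under Code below); the plotting range is chosen to cover the example data:
library(class)                       # provides knn()
xs <- seq(-4, 4, length = 200)
grid <- expand.grid(x1 = xs, x2 = xs)
pred <- knn(train = frame[, c("x1", "x2")], test = grid, cl = frame$y, k = 1)
# each grid point takes the class of its nearest training example, so the
# predictions are constant over polygonal (Voronoi) cells around each example
image(xs, xs, matrix(as.numeric(pred), 200, 200),
      col = c("white", "grey"), xlab = "x1", ylab = "x2")
points(frame$x1, frame$x2, pch = ifelse(frame$y == "Yes", 19, 1))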

When we use k > 1, the boundary is smoother:

nn <- knn.model(y~.,frame,k=3)
cplot(nn)

To choose the best k, use cross-validation via the function best.k.knn. This will be shown in the diabetes example below.

Handwritten digit recognition

A challenging classification problem, which is crucial to the operation of pen-based handheld computers, is recognizing handwritten characters. As the pen moves over the tablet, a sequence of (x,y) points is recorded. To classify the sequence, it helps to first reduce it to a small number of landmarks. In the following dataset, each pen stroke was reduced to 8 landmark points which equally divide the distance that the pen traveled. Note that this is not the same as an equal division in time, since the pen may change speed during a stroke. Each stroke is thus reduced to 16 numbers, which form a row of the data frame. Here is the result for handwritten digits:
x1  y1 x2 y2 x3 y3  x4  y4 x5 y5  x6 y6  x7  y7 x8 y8 digit
47 100 27 81 57 37  26   0  0 23  56 53 100  90 40 98     8
48  96 62 65 88 27  21   0 21 33  79 67 100 100  0 85     8
 0  57 31 68 72 90 100 100 76 75  50 51  28  25 16  0     1
 0 100  7 92  5 68  19  45 86 34 100 45  74  23 67  0     4
...
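The landmark reduction itself is easy to sketch. The following is not the preprocessing used to build this dataset, just an illustration of the idea, assuming a stroke arrives as a two-column matrix of (x,y) pen positions (the rescaling of coordinates to 0-100 is a separate step, not shown):
# Reduce a pen stroke (an n-by-2 matrix of (x,y) points) to m landmarks that
# equally divide the distance traveled by the pen, not the time elapsed.
landmarks <- function(stroke, m = 8) {
  seglen <- sqrt(rowSums(diff(stroke)^2))       # length of each recorded segment
  s <- c(0, cumsum(seglen))                     # cumulative distance traveled
  target <- seq(0, s[length(s)], length = m)    # m equally spaced arc-length positions
  cbind(x = approx(s, stroke[, 1], target)$y,   # interpolate x and y at those positions
        y = approx(s, stroke[, 2], target)$y)
}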
Here is what they look like in 2D:

For the full dataset and a description of how it was gathered, see the pendigits directory at the UCI data repository.

To keep the example simple, we will only consider the two-class problem of classifying a digit as "8" or "not 8". We break the dataset into train and test (we're going to use a small 5% training set this time):

x <- rsplit(x8,0.05); names(x) <- c("train","test")
Let's start with logistic regression:
> fit <- glm(digit8~.,x$train,family=binomial)
> misclass(fit)/nrow(x$train)
[1] 0.02181818
> misclass(fit,x$test)/nrow(x$test)
[1] 0.03208198
Can we do better with a tree?
> tr <- tree(digit8~.,x$train)
> misclass(tr)/nrow(x$train)
[1] 0.009090909
> misclass(tr,x$test)/nrow(x$test)
[1] 0.02777246
The tree can fit the training set well, but doesn't do much better on the test set. Now let's try nearest-neighbor. Intuitively, this problem seems well-suited to nearest-neighbor, because two strokes of the same digit should have similar landmark coordinates, and hence lie close together in the space of 16 numbers. Here is the result:
> nn <- knn.model(digit8~.,x$train)
> misclass(nn)/nrow(x$train)
[1] 0
> misclass(nn,x$test)/nrow(x$test)
[1] 0.002298410
On the training set, 1-nn is of course perfect, because it has memorized the training set. But this doesn't mean 1-nn has overfit; quite the contrary. Nearest neighbor makes one-tenth as many errors on the test set as the other two classifiers. The distance measure that 1-nn used was
d(a,b) = sqrt((a.x1 - b.x1)^2 + (a.y1 - b.y1)^2 + (a.x2 - b.x2)^2 + ...)
This is Euclidean distance in the 16-dimensional space defined by (x1,y1,x2,...,y8).

Computation

In the above example, the training set had 550 examples and the test set had 10442. To classify a single stroke, 550 comparisons need to be made. To classify the entire test set, 550*10442 = 5.7 million comparisons are needed. Nearest-neighbor classification clearly has a high computational cost. Is there any way to reduce it?

One approach is to use geometry to skip irrelevant computations. For example, suppose we are looking for the nearest neighbor to x. We measure the distance to example z and find it is 2. This tells us that the nearest neighbor, whatever it is, cannot be farther than 2. So there is no need to compute the distance to any examples which are outside of a sphere of radius 2 centered at x. With certain data structures, it is possible to find out which points are outside of this sphere, without computing the distance to each one of them. After the next comparison, the sphere gets smaller, and so on. This technique can substantially reduce the number of comparisons, and it is widely used.
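The flavor of the idea can be shown even without a special data structure, using a single reference point and the triangle inequality; real implementations organize the examples into a tree so that whole groups can be skipped at once. A toy sketch:
# Toy sketch of geometric pruning for 1-nn search.  `train` is an n-by-p
# matrix and `query` is a length-p vector.
nn_search_pruned <- function(train, query) {
  pivot <- train[1, ]                               # an arbitrary reference point
  d_pivot <- sqrt(colSums((t(train) - pivot)^2))    # distances to pivot: computed once, reusable
  dq <- sqrt(sum((query - pivot)^2))                # query's distance to the pivot
  best <- Inf; best_i <- NA
  for (i in order(abs(d_pivot - dq))) {             # examine likely-close examples first
    # triangle inequality: if |d(z,pivot) - d(query,pivot)| > best, then z lies
    # outside the sphere of radius `best` around the query, so skip the rest
    if (abs(d_pivot[i] - dq) > best) break
    d <- sqrt(sum((train[i, ] - query)^2))
    if (d < best) { best <- d; best_i <- i }
  }
  best_i                                            # row index of the nearest neighbor
}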

Another approach is to prune away some of the examples. As discussed under the properties section below, nearest-neighbor is largely unaffected by removing examples. Only the boundary examples are important. Examples can be pruned as follows: for each point, see if its nearest neighbor has the same class. If so, prune one of them. Repeat. For datasets with simple boundaries, this can substantially reduce the number of examples; sometimes to two or three. Note that you can combine this with the geometric method for additional speedup. With these tricks, nearest-neighbor can be quite practical.
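Here is a rough sketch of this pruning rule in plain R, written for clarity rather than speed (X is a matrix of predictors, y the vector of class labels):
# Repeatedly remove a point whose nearest neighbor has the same class; such a
# point is interior to its class region and does not shape the 1-nn boundary.
prune_examples <- function(X, y) {
  repeat {
    if (nrow(X) < 2) break
    D <- as.matrix(dist(X)); diag(D) <- Inf     # pairwise distances, ignoring self
    nn <- apply(D, 1, which.min)                # each point's nearest neighbor
    redundant <- which(y == y[nn])              # points that agree with their neighbor
    if (length(redundant) == 0) break
    X <- X[-redundant[1], , drop = FALSE]       # prune one of them, then repeat
    y <- y[-redundant[1]]
  }
  list(X = X, y = y)
}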

Diabetes classification

Now consider the Pima Indians diabetes dataset from day31. We previously identified glu and ped as the most important predictors. Let's try to classify using a 1-nn model:
nn <- knn.model(type~glu+ped,x)
cplot(nn)

The classes overlap, so we get lots of isolated regions where the predicted class changes. To smooth away these regions, use k > 1. The function best.k.knn uses cross-validation to find the k which minimizes the expected future error rate:
> nn2 <- best.k.knn(nn,1:20)
best k is 13 
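To see what best.k.knn is doing, here is a sketch of the same search using knn.cv from the class library, which returns the leave-one-out prediction for each training example. It assumes the training frame is Pima.tr from day31, with glu and ped put on a common scale as described under Standardizing predictors below:
library(class)                        # provides knn.cv()
X <- as.matrix(Pima.tr[, c("glu", "ped")])
X <- scale(X)                         # common scale; centering does not affect distances
loo_error <- sapply(1:20, function(k)
  mean(knn.cv(X, Pima.tr$type, k = k) != Pima.tr$type))
which.min(loo_error)                  # the k with the smallest leave-one-out error rate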

The second argument to best.k.knn is the range of k values to consider. By default it is 1:20. The cross-validation used is leave-one-out, which means the result should always be the same. The cross-validation curve suggests a fairly high value of k, which means that there is a lot of overlap between the classes. The returned object nn2 has k=13:
cplot(nn2)

The boundary is substantially smoother. It also reveals a subtle nonlinearity that was not captured by the logistic regression model from day31. Here are the results on test data:
> misclass(nn,Pima.te)
[1] 93
> misclass(nn2,Pima.te)
[1] 81
> misclass(fit,Pima.te)
[1] 68
Using k=13 helped, but logistic regression still wins. Why? The test data suggest that the true boundary is not far from linear. By assuming linearity, logistic regression obtained an estimate of the boundary that was closer to the truth than knn could manage from the same amount of data. Once again, the more constrained classifier has the advantage.

Standardizing predictors

The distance measure used above requires some explanation. Notice that glu ranges 60-200 while ped ranges 0-2.5. If we used Euclidean distance as above, glu would dominate the distances and the classification would only be based on glu. When predictors are measured on different scales, such as these, it is a good idea to standardize them first so that they have standard deviation 1. To do this, divide each predictor value by its standard deviation on the training set. In practice, the standard deviations are stored and used to rescale the distances as follows:
d(x,z) = sqrt((x1 - z1)^2/sd1^2 + (x2 - z2)^2/sd2^2)
This is done automatically by knn.model.
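By hand, the standardization would look like this (Pima.tr and Pima.te as before); ordinary Euclidean distance on the scaled values then equals the rescaled distance above:
# Divide each predictor by its standard deviation on the training set.
train <- as.matrix(Pima.tr[, c("glu", "ped")])
sds   <- apply(train, 2, sd)
train_scaled <- scale(train, center = FALSE, scale = sds)
# Future observations must be scaled by the *training* standard deviations:
test_scaled  <- scale(as.matrix(Pima.te[, c("glu", "ped")]), center = FALSE, scale = sds)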

Properties of nearest-neighbor

Nearest-neighbor is unique compared to the other classification methods discussed in this course. First, nearest-neighbor doesn't simplify the dataset at all; it provides no concise model of the relationship between the predictors and the response. Consequently, it isn't as useful for visualization and knowledge discovery as the other methods. It is simply a black box which gives predictions.

Second, nearest-neighbor is very robust to perturbations of the dataset; even removing random training examples has little effect on its performance. Compare this to logistic regression and classification trees, which can change substantially when a single training example is altered. On the other hand, nearest-neighbor is especially sensitive to irrelevant predictors, and doesn't perform well when some predictors are more important than others. This is the opposite of logistic regression and classification trees, which can deliberately exclude irrelevant predictors. This problem can only be fixed by designing your own distance measure which downweights the less relevant predictors.

Nearest-neighbor is especially useful for domains where a distance measure between examples is straightforward to define but a model relating the predictors to the response is not.

Code

To use knn.model, you will need the latest version of regress.r or regress.s. In R, you will have to also download the package VR via the Packages menu (see day19). This package is loaded by knn.model.

Functions introduced in this lecture: knn.model, best.k.knn

References

Using a data structure to facilitate nearest-neighbor searches is sometimes called indexing. Check out this paper and its references: "Two Algorithms for Nearest-Neighbor Search in High Dimensions".


Tom Minka
Last modified: Wed Nov 28 13:49:41 Eastern Standard Time 2001