
K-nearest neighbor

kNN, or the k-nearest neighbor algorithm, is a machine learning algorithm that makes predictions by using proximity to compare a new data point with the training data it has stored (memorized). Because it relies on stored instances rather than on a model built in advance, kNN is called an instance-based, or 'lazy', learner, and it can be used to solve both classification and regression problems. kNN works off the assumption that similar points are found near one another: birds of a feather flock together.

As a classification algorithm, kNN assigns a new data point to the class that is most common among its k nearest neighbors. As a regression algorithm, kNN predicts the average of the target values of the k training points closest to the query point.

kNN is a supervised learning algorithm. The 'k' is the number of nearest neighbors considered when classifying a point or predicting a value, and 'NN' refers to those nearest neighbors themselves.

How does kNN work?

The kNN algorithm is a supervised learning algorithm, meaning it is given a labeled training dataset, which it stores. It relies on this labeled data to produce an appropriate output when given new, unlabeled data.

This enables the algorithm to solve classification or regression problems. Because kNN's computation happens at query time rather than during a training phase, it must keep the entire training set available, which makes it heavily reliant on memory and storage.
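To make the 'lazy' part concrete, here is a minimal sketch, assuming numeric feature vectors and Euclidean distance (`math.dist`, Python 3.8+). The class and method names (`LazyKNN`, `fit`, `neighbors`) are hypothetical. `fit` only stores the data; every distance is computed when a query arrives.

```python
import math


class LazyKNN:
    """Hypothetical lazy learner: 'training' just memorizes the data."""

    def __init__(self, k):
        self.k = k
        self.X = []
        self.y = []

    def fit(self, X, y):
        # No model is built here: the training set is simply stored.
        self.X = list(X)
        self.y = list(y)
        return self

    def neighbors(self, query):
        # All the work happens at query time: measure the distance from the
        # query to every stored point, then keep the k closest (index, label).
        order = sorted(range(len(self.X)), key=lambda i: math.dist(query, self.X[i]))
        return [(i, self.y[i]) for i in order[: self.k]]


model = LazyKNN(k=2).fit([(0, 0), (1, 1), (5, 5)], ["a", "a", "b"])
print(model.neighbors((0.2, 0.1)))  # [(0, 'a'), (1, 'a')]
```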

For classification problems, the kNN algorithm assigns a class label by majority vote: it uses the label that appears most frequently among the k nearest neighbors of the query point. In other words, the output of a classification problem is the mode of the neighbors' labels.
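A minimal from-scratch sketch of the majority vote, assuming Euclidean distance and a small hypothetical 2-D dataset; `knn_classify` is an illustrative name, not a library function. The mode is taken with `collections.Counter.most_common`.

```python
import math
from collections import Counter


def knn_classify(X_train, y_train, query, k=3):
    """Predict the most common label among the k nearest training points."""
    # Sort training points by Euclidean distance to the query and keep k.
    nearest = sorted(zip(X_train, y_train), key=lambda pair: math.dist(query, pair[0]))[:k]
    # The prediction is the mode of the neighbors' labels.
    votes = Counter(label for _, label in nearest)
    return votes.most_common(1)[0][0]


X_train = [(1, 1), (1, 2), (2, 1), (6, 6), (6, 7), (7, 6)]
y_train = ["red", "red", "red", "green", "green", "green"]
print(knn_classify(X_train, y_train, (2, 2)))  # red
print(knn_classify(X_train, y_train, (5, 6)))  # green
```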

For regression problems, the algorithm predicts the mean of the nearest neighbors' target values instead of a class label, so the output of a regression query is a real number.
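The regression variant only changes the final step: instead of a vote, it averages the neighbors' targets. A minimal sketch, assuming Euclidean distance and hypothetical one-feature data; `knn_regress` is an illustrative name.

```python
import math
from statistics import mean


def knn_regress(X_train, y_train, query, k=3):
    """Predict a real value as the mean target of the k nearest training points."""
    nearest = sorted(zip(X_train, y_train), key=lambda pair: math.dist(query, pair[0]))[:k]
    return mean(target for _, target in nearest)


# Hypothetical one-feature data (each point is a 1-tuple).
X_train = [(1.0,), (1.5,), (2.0,), (3.0,), (3.5,)]
y_train = [100, 150, 200, 300, 350]
print(knn_regress(X_train, y_train, (1.6,), k=3))  # 150 (mean of 150, 200, 100)
```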

4 types of kNN distance metrics

The key to the kNN algorithm is determining the distance between the query point and the other data points. The chosen distance metric shapes the decision boundaries, which partition the feature space into regions assigned to different classes. There are different methods used to calculate distance:

Euclidean distance is the most common distance measure. It measures the length of the straight line between the query point and the other point being measured.
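For two points x and y with n coordinates, the Euclidean distance is the square root of the sum of squared coordinate differences. A minimal sketch, assuming both points have the same length (`zip` would silently truncate otherwise); `euclidean` is an illustrative name.

```python
import math


def euclidean(x, y):
    """Straight-line distance: square root of the summed squared differences."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))


print(euclidean((0, 0), (3, 4)))  # 5.0
# The standard library's math.dist (Python 3.8+) computes the same quantity.
print(math.dist((0, 0), (3, 4)))  # 5.0
```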

Manhattan distance is another popular distance measure. It is the sum of the absolute differences between the coordinates of two points. It is often visualized on a grid and referred to as taxicab geometry: how do you travel from point A (your query point) to point B (the point being measured) when you can only move along the grid lines?
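A minimal sketch of the taxicab distance, assuming equal-length numeric points; `manhattan` is an illustrative name.

```python
def manhattan(x, y):
    """Taxicab distance: sum of absolute coordinate differences."""
    return sum(abs(a - b) for a, b in zip(x, y))


# Going from (0, 0) to (3, 4) on a grid takes 3 steps along x plus 4 along y.
print(manhattan((0, 0), (3, 4)))  # 7
```

Compare this with the straight-line (Euclidean) distance of 5.0 between the same two points.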

Minkowski distance is a generalization of the Euclidean and Manhattan distance metrics, and other distance metrics can be derived from it. It is defined in a normed vector space. Its parameter p determines the type of distance: p=1 gives the Manhattan distance, and p=2 gives the Euclidean distance.
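The Minkowski distance of order p raises each absolute coordinate difference to the power p, sums them, and takes the p-th root. The sketch below, assuming equal-length numeric points and p >= 1 (for p < 1 the triangle inequality fails, so it is not a metric), checks that p=1 and p=2 reproduce the Manhattan and Euclidean distances. `minkowski` is an illustrative name.

```python
def minkowski(x, y, p):
    """Minkowski distance of order p (p >= 1)."""
    if p < 1:
        raise ValueError("p must be >= 1 for Minkowski distance to be a metric")
    return sum(abs(a - b) ** p for a, b in zip(x, y)) ** (1 / p)


x, y = (0, 0), (3, 4)
print(minkowski(x, y, 1))  # 7.0 (Manhattan)
print(minkowski(x, y, 2))  # 5.0 (Euclidean)
print(minkowski(x, y, 3))  # ~4.498
```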

Hamming distance, also referred to as the overlap metric, is used with Boolean or string vectors to identify where vectors do not match. For two vectors or strings of equal length, it counts the number of positions at which they differ. It is especially useful for error detection and error-correcting codes.
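A minimal sketch that works for strings and Boolean or 0/1 lists alike; `hamming` is an illustrative name, and it rejects inputs of different lengths because the distance is only defined for equal lengths.

```python
def hamming(s, t):
    """Count the positions at which two equal-length sequences differ."""
    if len(s) != len(t):
        raise ValueError("Hamming distance is only defined for equal-length inputs")
    return sum(a != b for a, b in zip(s, t))


print(hamming("karolin", "kathrin"))        # 3
print(hamming([1, 0, 1, 1], [1, 1, 1, 0]))  # 2
```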

How to choose the best k value

To choose the best k value (the number of nearest neighbors considered), you must experiment with a few values to find the one that produces the most accurate predictions, that is, the fewest errors. Determining the best value is a balancing act:

Low k values make predictions unstable. Take this example: a query point that actually belongs to the red-triangle class happens to be surrounded by two green dots and one red triangle. If k=1 and the point closest to the query point is one of the green dots, the algorithm will incorrectly predict a green dot. Low k values mean high variance (the model fits too closely to the training data), high complexity, and low bias (the model is complex enough to fit the training data well).

High k values make predictions more stable. A higher k value tends to reduce the effect of noise, and can improve accuracy, because the mode or mean is computed over more neighbors. However, if k is too high, the model will likely have low variance, low complexity, and high bias (the model is NOT complex enough to fit the training data well).
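The trade-off can be seen on a tiny hypothetical 1-D dataset containing one mislabeled point. This sketch assumes Euclidean distance and a from-scratch `knn_classify` (an illustrative name). With k=1 the prediction follows the noisy point; a moderate k outvotes it; and with k equal to the training-set size n, every query gets the overall majority class, the extreme case of high bias.

```python
import math
from collections import Counter


def knn_classify(X_train, y_train, query, k):
    """Majority vote among the k nearest training points (Euclidean)."""
    nearest = sorted(zip(X_train, y_train), key=lambda pair: math.dist(query, pair[0]))[:k]
    return Counter(label for _, label in nearest).most_common(1)[0][0]


# Hypothetical data: a red cluster (1-5) with one noisy green point at 2.6,
# plus a small green cluster at 9-10. Red is the overall majority (5 of 8).
X_train = [(1,), (2,), (2.6,), (3,), (4,), (5,), (9,), (10,)]
y_train = ["red", "red", "green", "red", "red", "red", "green", "green"]
n = len(X_train)

print(knn_classify(X_train, y_train, (2.5,), k=1))  # green: follows the noisy point
print(knn_classify(X_train, y_train, (2.5,), k=3))  # red: the noise is outvoted
print(knn_classify(X_train, y_train, (9.5,), k=3))  # green: local structure is kept
print(knn_classify(X_train, y_train, (9.5,), k=n))  # red: k = n predicts the majority
```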

Ideally, you want a k value that balances high variance against high bias. It is also recommended to choose an odd number for k to avoid ties in classification, although an odd k only guarantees no ties when there are two classes.
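The odd-k rule can be checked exhaustively. This sketch (with the hypothetical helper `is_tied`) enumerates every possible set of neighbor labels for two classes and reports whether a tie is possible for each k; it also shows that with three classes an odd k can still tie.

```python
from collections import Counter
from itertools import product


def is_tied(labels):
    """True if two or more labels share the highest vote count."""
    counts = Counter(labels).most_common()
    return len(counts) > 1 and counts[0][1] == counts[1][1]


# Every possible combination of k neighbor labels drawn from two classes.
for k in range(1, 8):
    tie_possible = any(is_tied(votes) for votes in product("AB", repeat=k))
    print(k, tie_possible)  # True exactly when k is even

print(is_tied("ABC"))  # True: with three classes, k=3 can still tie
```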

Some commonly used approaches to choosing k:


1. Cross-Validation

  • Perform k-fold cross-validation on your dataset (here k is the number of folds, which is unrelated to the number of neighbors).
  • Test your kNN model with different k values and measure the performance (e.g., accuracy, F1-score).
  • Choose the k value that provides the best performance on the validation data.

2. Rule of Thumb

  • A commonly suggested starting point is $k = \sqrt{n}$, where n is the total number of data points in your training dataset.
  • This value is just a heuristic and should be fine-tuned.

3. Odd k for Binary Classification

  • Use an odd k value to avoid ties when there are two classes (binary classification).

4. Avoid Overfitting and Underfitting

  • Small k (e.g., k = 1) can lead to overfitting because the model becomes too sensitive to noise.
  • Large k can lead to underfitting because it over-smooths the decision boundary.
  • Experiment with different k values to strike a balance.

5. Domain Knowledge

  • Consider the problem you're solving and the structure of your dataset. In some cases, domain-specific knowledge might suggest a reasonable range for k.

The right k value ultimately depends on your dataset. The square root of N, where N is the number of data points in the training dataset, is a reasonable starting point, and cross-validation can then help you find the k value best suited to your data.
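The square-root rule of thumb fits in a tiny helper. This sketch assumes you want an odd starting value for binary problems and rounds up to the next odd number when needed; both the helper name `sqrt_heuristic_k` and the choice to round up rather than down are illustrative assumptions.

```python
import math


def sqrt_heuristic_k(n_train, n_classes=2):
    """Starting k from the square-root rule of thumb.

    For binary problems the value is bumped to the next odd number,
    since an odd k only rules out ties when there are two classes.
    """
    k = max(1, round(math.sqrt(n_train)))
    if n_classes == 2 and k % 2 == 0:
        k += 1
    return k


print(sqrt_heuristic_k(100))  # 11 (sqrt is 10, bumped to odd)
print(sqrt_heuristic_k(50))   # 7
```

Treat the result as a first candidate to refine with cross-validation, not as a final answer.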

Advantages of the kNN algorithm

The kNN algorithm is often described as one of the simplest supervised learning algorithms, which gives it several advantages:

Easy to implement: Given how simple it is and how accurate it can be, kNN is often one of the first classifiers that a data scientist will learn.

Adaptable: Because all training data is stored, the kNN algorithm adjusts its predictions to include new training samples as soon as they are added to its dataset, with no separate retraining step.
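Adaptability follows directly from storing the data: adding a sample is just appending to the training lists, and the very next query uses it. A minimal sketch with a hypothetical 2-D dataset, Euclidean distance, and an illustrative `knn_classify` helper.

```python
import math
from collections import Counter


def knn_classify(X_train, y_train, query, k):
    """Majority vote among the k nearest training points (Euclidean)."""
    nearest = sorted(zip(X_train, y_train), key=lambda pair: math.dist(query, pair[0]))[:k]
    return Counter(label for _, label in nearest).most_common(1)[0][0]


X_train = [(0, 0), (0, 1), (5, 5)]
y_train = ["a", "a", "b"]
query = (3, 3)
print(knn_classify(X_train, y_train, query, k=1))  # b

# "Retraining" is just appending: the new sample affects the next query.
X_train.append((3, 2.9))
y_train.append("c")
print(knn_classify(X_train, y_train, query, k=1))  # c
```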

Few hyperparameters: kNN requires only a k value and a distance metric, which makes it a fairly uncomplicated algorithm to configure.
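Both hyperparameters can be exposed as plain function arguments. In this sketch (illustrative names, hypothetical two-point training set), swapping the metric alone changes the prediction: (3, 3) is closer to the origin than (0, 5) in Euclidean terms (about 4.24 vs 5), but farther in Manhattan terms (6 vs 5).

```python
from collections import Counter


def euclidean(x, y):
    return sum((a - b) ** 2 for a, b in zip(x, y)) ** 0.5


def manhattan(x, y):
    return sum(abs(a - b) for a, b in zip(x, y))


def knn_classify(X_train, y_train, query, k=1, metric=euclidean):
    """kNN classifier whose only two settings are k and the distance metric."""
    nearest = sorted(zip(X_train, y_train), key=lambda pair: metric(query, pair[0]))[:k]
    return Counter(label for _, label in nearest).most_common(1)[0][0]


X_train = [(3, 3), (0, 5)]
y_train = ["a", "b"]
query = (0, 0)
print(knn_classify(X_train, y_train, query, metric=euclidean))  # a
print(knn_classify(X_train, y_train, query, metric=manhattan))  # b
```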

Limitations of kNN

While the kNN algorithm is simple, it also has a set of challenges and limitations, due in part to its simplicity:

Difficult to scale: Because kNN stores the entire training set, it takes up a lot of memory and data storage, which raises storage costs. Predictions are also computationally intensive: a brute-force query computes the distance to every stored training point, so the cost of each prediction grows with the size of the training set.

Curse of dimensionality: This refers to a phenomenon in which a fixed set of training examples becomes increasingly sparse as the number of dimensions (features) grows, because the volume of the feature space grows rapidly with each added dimension. In other words, the training data can't keep up with the dimensionality of the space. As a result, even the nearest neighbors of a query point tend to lie far away, and the distances to the nearest and farthest points become more similar, so predictions become less accurate.
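One way to see the effect is to measure how much farther the farthest point is than the nearest one, relative to the nearest distance. This sketch assumes points drawn uniformly at random from the unit hypercube with a fixed seed; `relative_contrast` is an illustrative name. As the dimension grows, the contrast shrinks, meaning "nearest" and "farthest" become harder to tell apart.

```python
import math
import random


def relative_contrast(dim, n_points=200, seed=0):
    """(farthest - nearest) / nearest distance from a random query to random points."""
    rng = random.Random(seed)
    points = [[rng.random() for _ in range(dim)] for _ in range(n_points)]
    query = [rng.random() for _ in range(dim)]
    dists = [math.dist(query, p) for p in points]
    return (max(dists) - min(dists)) / min(dists)


for dim in (2, 10, 100, 1000):
    print(dim, round(relative_contrast(dim), 3))
```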

Prone to overfitting: As shown earlier, the value of k affects the algorithm's behavior, and overfitting is most likely when k is too low. Lower values of k can overfit the data, while higher values of k 'smooth' the prediction values because the algorithm averages values over a greater area.