Machine Learning Algorithms: KNN Classifier

Hello! Welcome to this post, where we talk about a popular machine learning algorithm: the KNN Classifier. We'll go through what it is and then write it ourselves in Python.

You can find the code in this post here: https://github.com/ashwins-code/machine-learning-algorithms/blob/main/knn-classifier.py

What is a KNN Classifier and how does it work?

KNN stands for K-Nearest-Neighbours. A KNN Classifier is a common machine learning algorithm that classifies pieces of data.

Classifying data means putting that data into certain categories. An example could be classifying text data as happy, sad or neutral.

The way a KNN Classifier works is by starting with some training data, which contains several data points and their respective classifications. When it is asked to classify some new data, it looks at the k closest data points in the training data, where closeness is usually measured by the Euclidean distance (the straight-line distance between two points).

Once it has the k closest data points, it finds the most common classification among them and returns that as the classification of the new data.
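
To make the distance part concrete, here's a minimal sketch of the Euclidean distance between two points, using the same numpy call our classifier will rely on later:

import numpy as np

a = np.array([255, 0, 0])
b = np.array([0, 255, 0])

# Euclidean distance: square root of the sum of squared differences
print(np.linalg.norm(a - b)) # about 360.6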


Our classification problem

Here's the classification problem we want to solve with the KNN Classifier we're about to create.

Given an RGB value, we want to classify it as red, blue or green.

If you don't know what an RGB value is, it describes how much red, green and blue there is in a colour. Each of the three numbers ranges from 0 to 255. Here's an example:

(255, 200, 140)

The first number is how red the colour is, the second is how green it is and the third is how blue it is. Now the following should make sense:

(255, 0, 0) #Red
(0, 255, 0) #Green
(0, 0, 255) #Blue

In the context of our problem, we will consider a value such as (195, 50, 0) to be red, since it is closest to (255, 0, 0).
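
As a quick check (a rough sketch using numpy, not part of the classifier itself), the distances from (195, 50, 0) to the three pure colours confirm that red is the nearest:

import numpy as np

point = [195, 50, 0]

for colour, value in [("red", [255, 0, 0]), ("green", [0, 255, 0]), ("blue", [0, 0, 255])]:
    print(colour, np.linalg.norm(np.subtract(point, value)))

Red comes out at roughly 78, well below green (around 283) and blue (around 325).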

The Code!

First, let's import numpy for mathematical computations

import numpy as np

Now let's create a function which returns the mode (the most common item) of a given list

def get_mode(l):
    mode = ""
    max_count = 0
    count = {}

    for i in l:
        if i not in count:
            count[i] = 0
        count[i] += 1

        if count[i] > max_count:
            max_count = count[i]
            mode = i

    return mode
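
A quick usage example:

print(get_mode(["red", "blue", "red"])) # prints "red"

Note that if two classifications are tied, whichever reached the maximum count first is kept.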

Let's create our KNN Classifier class

class knn_classifier:
    def __init__(self):
        self.data_points = []
        self.classifications = []

    def add_example(self, data_point, classification):
        #Adding training data points

        #self.data_points contains the data points themselves, self.classifications contains their respective classifications
        self.data_points.append(data_point)
        self.classifications.append(classification)

    def classify(self, input, k = 3):
        #Classifies new data
        #Sort the indices of the training points by the Euclidean distance
        #between their data point and the input, then keep the k closest
        nearest = sorted(range(len(self.data_points)), key = lambda i: np.linalg.norm(np.subtract(input, self.data_points[i])))[:k]
        #Collect the classifications of those k closest data points
        classification = [self.classifications[i] for i in nearest]

        # Return the most common classification among the k nearest neighbours
        return get_mode(classification)

Finally, let's set up a new classifier and classify some values!

classifier = knn_classifier()
training_data_points = [
    [[255, 0, 0], "red"], 
    [[0, 255, 0], "green"], 
    [[0, 0, 255], "blue"],
    [[250, 5, 5], "red"],
    [[5, 250, 5], "green"],
    [[5, 5, 250], "blue"],
    [[245, 10, 10], "red"],
    [[10, 245, 10], "green"],
    [[10, 10, 245], "blue"],
]

for point in training_data_points:
    classifier.add_example(point[0], point[1])

print (classifier.classify([250, 0, 0], k = 3))
print (classifier.classify([100, 180, 50], k = 3))
print (classifier.classify([50, 50, 190], k = 3))

Our output

red
green
blue

It works! Here's all the code together

import numpy as np


def get_mode(l):
    mode = ""
    max_count = 0
    count = {}

    for i in l:
        if i not in count:
            count[i] = 0
        count[i] += 1

        if count[i] > max_count:
            max_count = count[i]
            mode = i

    return mode

class knn_classifier:
    def __init__(self):
        self.data_points = []
        self.classifications = []

    def add_example(self, data_point, classification):
        #Adding training data points

        #self.data_points contains the data points themselves, self.classifications contains their respective classifications
        self.data_points.append(data_point)
        self.classifications.append(classification)

    def classify(self, input, k = 3):
        #Classifies new data
        #Sort the indices of the training points by the Euclidean distance
        #between their data point and the input, then keep the k closest
        nearest = sorted(range(len(self.data_points)), key = lambda i: np.linalg.norm(np.subtract(input, self.data_points[i])))[:k]
        #Collect the classifications of those k closest data points
        classification = [self.classifications[i] for i in nearest]

        # Return the most common classification among the k nearest neighbours
        return get_mode(classification)


classifier = knn_classifier()
training_data_points = [
    [[255, 0, 0], "red"], 
    [[0, 255, 0], "green"], 
    [[0, 0, 255], "blue"],
    [[250, 5, 5], "red"],
    [[5, 250, 5], "green"],
    [[5, 5, 250], "blue"],
    [[245, 10, 10], "red"],
    [[10, 245, 10], "green"],
    [[10, 10, 245], "blue"],
]

for point in training_data_points:
    classifier.add_example(point[0], point[1])

print (classifier.classify([250, 0, 0], k = 3))
print (classifier.classify([100, 180, 50], k = 3))
print (classifier.classify([50, 50, 190], k = 3))
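
As a final sanity check (this isn't part of our own implementation, just a rough comparison that assumes you have scikit-learn installed), scikit-learn's built-in KNeighborsClassifier gives the same answers on this training data:

from sklearn.neighbors import KNeighborsClassifier

X = [point[0] for point in training_data_points]
y = [point[1] for point in training_data_points]

sk_classifier = KNeighborsClassifier(n_neighbors=3)
sk_classifier.fit(X, y)

print(sk_classifier.predict([[250, 0, 0], [100, 180, 50], [50, 50, 190]]))
# expected output: ['red' 'green' 'blue']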
