
K-MEANS CLUSTERING

Ruthvik Raja M.V

The dataset used for this algorithm is available at the following link: https://github.com/ruthvikraja/MNIST
(Each image in this dataset is 28 x 28 pixels.)
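K-Means works on feature vectors, so a 28 x 28 image is treated as a single row of 784 pixel values. The minimal NumPy sketch below uses random stand-in arrays (an assumption; it does not load the real dataset) to show the flatten step and the reshape back to 28 x 28 that is needed later for display:

```python
import numpy as np

# Hypothetical stand-in for the MNIST arrays: 5 random 28 x 28 "images"
images = np.random.default_rng(0).integers(0, 256, size=(5, 28, 28))

# Flatten each 28 x 28 image into a 784-element feature vector
flat = images.reshape(len(images), -1)
print(flat.shape)  # (5, 784)

# Reshape one row back to 28 x 28, e.g. for plt.imshow
img = flat[0].reshape(28, 28)
```

If the arrays are already stored as 784-element rows, the flatten step is a no-op and only the reshape for display is needed.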

The goal is to cluster the input images and, for each cluster, find the 10 images that are closest to its centroid.
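Once a cluster's centroid is known, finding its 10 closest images amounts to sorting the members by distance. Here is a minimal NumPy sketch, assuming Euclidean distance and points stored as rows of a 2-D array; the helper name `nearest_to_centroid` is hypothetical:

```python
import numpy as np

def nearest_to_centroid(points, centroid, k=10):
    """Hypothetical helper: indices of the k rows of `points` closest to
    `centroid` (Euclidean distance), nearest first."""
    distances = np.linalg.norm(points - centroid, axis=1)  # one distance per row
    return np.argsort(distances)[:k]                       # ascending order, keep k

# Toy example with 1-D points stored as rows
points = np.array([[5.0], [1.0], [3.0], [0.5], [2.0]])
centroid = np.array([1.2])
print(nearest_to_centroid(points, centroid, k=3))  # [1 3 4]
```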

K-Means clustering is a technique used to partition N observations into K clusters (K <= N), in which each observation belongs to the cluster with the nearest mean. It is one of the simplest unsupervised learning algorithms, and it uses a distance metric (usually Euclidean distance) to find the closest centroid.
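In symbols (standard notation, not taken from the original post): with clusters $C_1, \dots, C_K$ and cluster means $\mu_k$, K-Means looks for the partition that minimises the within-cluster sum of squared Euclidean distances, and each observation $x_i$ is assigned to its nearest mean:

$$
\min_{C_1,\dots,C_K} \sum_{k=1}^{K} \sum_{x_i \in C_k} \lVert x_i - \mu_k \rVert^2,
\qquad
\mu_k = \frac{1}{|C_k|} \sum_{x_i \in C_k} x_i,
\qquad
c(x_i) = \arg\min_{k} \lVert x_i - \mu_k \rVert .
$$

scikit-learn's `KMeans` exposes this minimised quantity after fitting as the `inertia_` attribute.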

Algorithm workflow:
Step 1: Randomly choose K points as the cluster centres.
Step 2: Compute the distance from each observation to every centre and assign each observation to its closest centre.
Step 3: Recompute each centre as the mean of the observations assigned to it, then repeat Step 2.
Step 4: The process terminates when the change is negligible, when no observation is reassigned to another cluster, or when any other stopping criterion (such as a maximum number of iterations) is met.
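The four steps above can be sketched from scratch in NumPy. This is a minimal illustration of the classic (Lloyd's) iteration under simple assumptions (random initial centres, Euclidean distance, a fixed tolerance); the function name `lloyd_kmeans` is hypothetical, and it is not how scikit-learn implements `KMeans`, which by default uses k-means++ initialisation:

```python
import numpy as np

def lloyd_kmeans(X, k, max_iter=100, tol=1e-6, seed=0):
    """Hypothetical minimal K-Means following Steps 1-4.
    X: (n_samples, n_features) array. Returns (centroids, labels)."""
    rng = np.random.default_rng(seed)
    # Step 1: pick k distinct observations at random as the initial centres
    centroids = X[rng.choice(len(X), size=k, replace=False)].astype(float)
    labels = None
    for _ in range(max_iter):
        # Step 2: distance from every point to every centre, shape (n, k)
        dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
        new_labels = dists.argmin(axis=1)
        # Step 4: stop if no observation changed cluster
        if labels is not None and np.array_equal(new_labels, labels):
            break
        labels = new_labels
        # Step 3: move each centre to the mean of its members
        new_centroids = centroids.copy()
        for j in range(k):
            members = X[labels == j]
            if len(members) > 0:  # keep the old centre if a cluster is empty
                new_centroids[j] = members.mean(axis=0)
        # Step 4: stop if the centres barely moved
        shift = np.linalg.norm(new_centroids - centroids)
        centroids = new_centroids
        if shift < tol:
            break
    return centroids, labels

# Two well-separated toy groups of 2-D points
X = np.array([[0.0, 0.0], [0.1, 0.2], [0.2, 0.1],
              [5.0, 5.0], [5.1, 4.9], [4.9, 5.2]])
centroids, labels = lloyd_kmeans(X, k=2)
print(labels)
```

The numeric value of each label depends on the random initialisation, but the first three points always end up in one cluster and the last three in the other.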

The following Python code implements K-Means clustering with scikit-learn:

``````python
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import cv2

## Reading the numpy files:
# Each .npy file can be read with numpy's load function, e.g. x_train = np.load(<path to the .npy file>)
# (x_train, y_train, x_test and y_test are assumed to have been loaded this way)

# Since K-Means clustering is unsupervised, x_train and x_test (the independent variables) can be concatenated to fit the model,
# and so can y_train and y_test. numpy's concatenate function does this when the input arrays are passed as a tuple:

x=np.concatenate((x_train,x_test))
y=np.concatenate((y_train,y_test))
x=x.reshape(len(x), -1) # KMeans expects 2-D input: one flattened 784-pixel row per image (no change if already flat)

# Now let us define K-Means clustering with 10 clusters and a maximum of 20 iterations

# The KMeans class accepts the number of clusters to be formed plus several optional parameters; the input array is passed to fit() afterwards:

kmeans=KMeans(n_clusters=10, max_iter=20, random_state=0) # random_state=0 means the same initial centroids are chosen on every run
# n_jobs=-1 (use all processors on the local machine) was deprecated in scikit-learn 0.23 and has since been removed, so it is omitted here
# Each run stops after max_iter iterations, or earlier if the centroids converge (see the tol parameter)

kmeans.fit(x) # Fitting the model to the input images

x_predict=kmeans.predict(x) # Predicted cluster label for each image
# (or) fit_predict() can be used to fit and predict in one step
# x_predict=kmeans.fit_predict(x)

kmeans.labels_ # Cluster labels assigned during fitting (displayed as output in a notebook)

centers=kmeans.cluster_centers_ # Centroid of each cluster, one row per cluster

# Computing the 10 nearest images for each cluster:
for c in range(0,10):
    l = (x_predict == c) # Boolean mask: True where the predicted label is c

    cluster6 = x[l] # Separating the data points with cluster label c

    # Now let us see some of the images that the cluster with label c contains:

    #img=np.reshape(cluster6[0],(28,28))
    #plt.imshow(img,cmap="gray")

    # Now let us compute the distance between the centroid of this cluster and each of its data points,
    # to find the 10 data points that are closest to the centroid.

    l6 = np.linalg.norm(cluster6 - centers[c], axis=1) # Euclidean distance from each data point to the centroid of its cluster

    l6_index=np.argsort(l6) # Indices that sort the distances in ascending order

    l6_top10=l6_index[0:10] # Indices of the 10 data points closest to the centroid

    l_images=cluster6[l6_top10] # The final 10 images closest to the centroid

    #img=np.reshape(l_images[0],(28,28))
    #plt.imshow(img,cmap="gray")

    # Let us display all 10 images:

    fig=plt.figure(figsize=(10, 7)) # First create a figure object, then add subplots to it

    rows=5 # 5 rows and 2 columns to accommodate 10 images
    columns=2

    fig.add_subplot(rows, columns, 1) # The number 1 is the position of this image in the 5 x 2 grid
    plt.imshow(np.reshape(l_images[0],(28,28)),cmap="gray")
    plt.axis('off') # Hide the axis ticks and labels

    fig.add_subplot(rows, columns, 2) # Each image needs its own subplot, otherwise it overwrites the previous one
    plt.imshow(np.reshape(l_images[1],(28,28)),cmap="gray")
    plt.axis('off')

    fig.add_subplot(rows, columns, 3)
    plt.imshow(np.reshape(l_images[2],(28,28)),cmap="gray")
    plt.axis('off')

    fig.add_subplot(rows, columns, 4)
    plt.imshow(np.reshape(l_images[3],(28,28)),cmap="gray")
    plt.axis('off')

    fig.add_subplot(rows, columns, 5)
    plt.imshow(np.reshape(l_images[4],(28,28)),cmap="gray")
    plt.axis('off')

    fig.add_subplot(rows, columns, 6)
    plt.imshow(np.reshape(l_images[5],(28,28)),cmap="gray")
    plt.axis('off')

    fig.add_subplot(rows, columns, 7)
    plt.imshow(np.reshape(l_images[6],(28,28)),cmap="gray")
    plt.axis('off')

    fig.add_subplot(rows, columns, 8)
    plt.imshow(np.reshape(l_images[7],(28,28)),cmap="gray")
    plt.axis('off')