DEV Community

petercour

What is kmeans?

kmeans is an unsupervised clustering algorithm used in machine learning. A cluster is another word for a group.

So what does kmeans actually do?

Starting situation

The goal of the algorithm is to find groups among data points. In machine learning, those data points come from things you measure.

For instance, you could measure people's height and salary. Each person then becomes a 2D vector (height, salary) that you can draw as a point on a 2D scatter plot.
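As a small sketch (the numbers are made up purely for illustration), such measurements are usually stored as a 2D array with one row per person and one column per measurement. This is the (n_samples, n_features) shape that scikit-learn's `KMeans` expects as input.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Hypothetical measurements: one row per person, columns are (height in cm, salary)
X = np.array([
    [170.0, 42000.0],
    [182.0, 51000.0],
    [165.0, 39000.0],
    [190.0, 75000.0],
])

print(X.shape)  # (4, 2): 4 data points, each one a 2D vector
heights, salaries = X[:, 0], X[:, 1]  # x and y coordinates for a 2D scatter plot

# Rescale each column to mean 0 and standard deviation 1
X_scaled = StandardScaler().fit_transform(X)
```

Because kmeans compares points by Euclidean distance, a salary column measured in tens of thousands would otherwise dominate a height column measured in centimetres. Rescaling each column, here with `StandardScaler`, is one common way to give both measurements a similar weight.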

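There are K groups, or formally, K clusters, and you choose K before running the algorithm. There must be at least 2 groups, because with only 1 group every point trivially belongs to it and there is nothing for the algorithm to do.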
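More precisely, kmeans looks for K centers that make the groups as tight as possible. One standard way to write this down is the objective below, where C_j is the set of points in group j and μ_j is that group's center. This sum of squared distances is what scikit-learn reports as `inertia_` (the "cost" printed by the program below). With K=1 the minimizing center is simply the mean of all points.

```latex
J = \sum_{j=1}^{K} \sum_{x_i \in C_j} \lVert x_i - \mu_j \rVert^2,
\qquad
\mu_j = \frac{1}{|C_j|} \sum_{x_i \in C_j} x_i
```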

End goal

Given the input data points, the algorithm finds K groups. In this article's example, K=3.
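A classic way to find those K groups is Lloyd's algorithm: pick K starting centers, assign every point to its nearest center, move each center to the mean of its points, and repeat until the centers stop moving. The sketch below is a minimal NumPy version for illustration only; the helpers `nearest_center` and `kmeans` are hypothetical names, not part of any library. scikit-learn's `KMeans`, used in the program below, builds on the same idea and adds refinements such as k-means++ initialization and several restarts (`n_init`).

```python
import numpy as np


def nearest_center(X, centers):
    """Return, for each row of X, the index of the closest center (Euclidean distance)."""
    dists = np.linalg.norm(X[:, None, :] - centers[None, :, :], axis=2)
    return dists.argmin(axis=1)


def kmeans(X, k, n_iter=100, seed=0):
    """Minimal Lloyd-style k-means (hypothetical helper). Returns (centers, labels)."""
    X = np.asarray(X, dtype=float)
    rng = np.random.default_rng(seed)
    # Initialization: use k distinct data points, picked at random, as starting centers
    centers = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(n_iter):
        # Assignment step: each point joins the group of its nearest center
        labels = nearest_center(X, centers)
        # Update step: move each center to the mean of its points
        # (a group that ends up empty keeps its old center)
        new_centers = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
            for j in range(k)
        ])
        # Stop once the centers no longer move
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    # Final assignment, consistent with the returned centers
    return centers, nearest_center(X, centers)


# Two obvious groups: three points near (0, 0) and three points near (10, 10)
X = np.array([[0, 0], [0, 1], [1, 0], [10, 10], [10, 11], [11, 10]])
centers, labels = kmeans(X, k=2)
print(labels)
print(centers)
```

On this toy data the two groups are found regardless of which two points happen to be picked as starting centers, but on real data Lloyd's algorithm can get stuck in a poor local solution, which is why libraries typically rerun it from several starting points.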

Once the groups are found, a new data point is assigned to the group whose center is closest to it. Which K is best? Finding out takes some searching, typically by trying several values of K and comparing the results.
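In scikit-learn, assigning a new point to its closest group is what `KMeans.predict` does. The sketch below fits a model on synthetic data like that of the program below, classifies two hypothetical new points, and checks the result against the same "closest center" rule written out by hand.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans

# Synthetic data generated the same way as in the program below
X, _ = make_blobs(n_samples=200, n_features=2, centers=4, cluster_std=1,
                  center_box=(-10.0, 10.0), random_state=1)
k_means = KMeans(n_clusters=3, n_init=10, random_state=1).fit(X)

# Two hypothetical new data points to classify
new_points = np.array([[0.0, 0.0], [-5.0, 3.0]])

# predict() returns the index of the closest cluster center for each point
predicted = k_means.predict(new_points)

# The same rule by hand: distance to every center, then pick the smallest
dists = np.linalg.norm(
    new_points[:, None, :] - k_means.cluster_centers_[None, :, :], axis=2
)
manual = dists.argmin(axis=1)

print(predicted, manual)
```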

The code

The program below uses scikit-learn (sklearn) to generate sample data and run the kmeans algorithm, and matplotlib to plot the results. It creates the images shown in this article.

#!/usr/bin/python3
from sklearn.datasets import make_blobs
from matplotlib import pyplot as plt
from sklearn.cluster import KMeans

# Generate 200 two-dimensional sample points grouped around 4 centers
X, y = make_blobs(n_samples=200,
                  n_features=2,
                  centers=4,
                  cluster_std=1,
                  center_box=(-10.0, 10.0),
                  shuffle=True,
                  random_state=1)

# Plot the raw, unlabeled data points
plt.figure(figsize=(6, 4), dpi=144)
plt.xticks(())
plt.yticks(())
plt.scatter(X[:, 0],X[:, 1], s=20, marker='o')
plt.show()

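n_cluster = 3
# Fix n_init and random_state so the result is reproducible across runs
k_means = KMeans(n_clusters=n_cluster, n_init=10, random_state=1)
k_means.fit(X)
# score() returns the negative sum of squared distances, so negate it to print a positive cost
print("kmean: k={}, cost={}".format(n_cluster, int(-k_means.score(X))))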

# Cluster index of each point, and the coordinates of each cluster center
labels = k_means.labels_
centers = k_means.cluster_centers_

markers = ['o', '^', '*']
colors = ['r', 'b', 'y']

plt.figure(figsize=(6, 4), dpi=144)
plt.xticks(())
plt.yticks(())

# Draw each cluster with its own marker and color
for c in range(n_cluster):
    cluster = X[labels == c]
    plt.scatter(cluster[:, 0], cluster[:, 1], marker=markers[c], s=20, c=colors[c])

# Highlight each center with a white disc, then label it with its cluster number
plt.scatter(centers[:, 0], centers[:, 1], marker='o', c='white', alpha=0.9, s=300)
for i, c in enumerate(centers):
    plt.scatter(c[0], c[1], marker='$%d$' % i, s=50, c=colors[i])

plt.show()
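The program fixes K=3 even though `make_blobs` generated the data around 4 centers. One common way to search for a good K, sketched here as an assumption rather than as the article's own method, is to fit `KMeans` for several values of K and compare `inertia_`, which tends to shrink as K grows (so you look for the "elbow" where it stops dropping quickly), and `sklearn.metrics.silhouette_score`, where higher is better.

```python
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

# Same synthetic data as the program above
X, _ = make_blobs(n_samples=200, n_features=2, centers=4, cluster_std=1,
                  center_box=(-10.0, 10.0), shuffle=True, random_state=1)

inertias = {}
silhouettes = {}
for k in range(2, 9):
    km = KMeans(n_clusters=k, n_init=10, random_state=1).fit(X)
    # inertia_: sum of squared distances from each point to its closest center
    inertias[k] = km.inertia_
    # silhouette_score: mean silhouette coefficient in [-1, 1], higher is better
    silhouettes[k] = silhouette_score(X, km.labels_)

for k in inertias:
    print(f"k={k}  inertia={inertias[k]:.1f}  silhouette={silhouettes[k]:.3f}")

best_k = max(silhouettes, key=silhouettes.get)
print("best k by silhouette:", best_k)
```

Neither score is a definitive answer: the elbow is read off by eye, and the silhouette favors well-separated, roughly round groups, so it is worth looking at the plots as well.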
