Algorithm explained: K-means clustering📈 with PHP🐘

Pascal Thormeier ・4 min read

Part 5 of Algorithms explained! Every few weeks I write about an algorithm and explain and implement it!
Is there an algorithm you always wanted to know about? Leave a comment!

K-means clustering is amazing: This machine learning algorithm finds structures in unstructured data by trying to find clusters of data points. In this post I will explain how the algorithm works and implement it from scratch in PHP.

First up: Something to have a look at

Let's generate some data first, so we know what we'll be talking about.

<?php

declare(strict_types=1);

/**
 * @param float $lowerBoundX
 * @param float $lowerBoundY
 * @param float $upperBoundX
 * @param float $upperBoundY
 * @param int $numberOfPoints
 * @return array
 */
function generateDataPoints(
    float $lowerBoundX,
    float $lowerBoundY,
    float $upperBoundX,
    float $upperBoundY,
    int $numberOfPoints
): array {
    $precision = 1000;

    $lowerBoundX = (int) round($lowerBoundX * $precision);
    $lowerBoundY = (int) round($lowerBoundY * $precision);
    $upperBoundX = (int) round($upperBoundX * $precision);
    $upperBoundY = (int) round($upperBoundY * $precision);

    $points = [];

    for ($i = 0; $i < $numberOfPoints; $i++) {
        $points[] = [
            mt_rand($lowerBoundX, $upperBoundX) / $precision,
            mt_rand($lowerBoundY, $upperBoundY) / $precision,
        ];
    }

    return $points;
}

$dataPoints = [
    ...generateDataPoints(0.0, 0.0, 1.0, 1.0, 10),
    ...generateDataPoints(2.0, 2.0, 3.0, 3.0, 10),
];

I'm going to roll with these points from now on, to make results more predictable:

$dataPoints = [
    [0.567, 0.7], [0.259, 0.58], [0.89, 0.785], [0.447, 0.498], [0.254, 0.311],
    [0.741, 0.138], [0.088, 0.371], [0.146, 0.12], [0.022, 0.202], [0.111, 0.284],
    [2.45, 2.829], [2.101, 2.728], [2.018, 2.813], [2.498, 2.929], [2.613, 2.799],
    [2.663, 2.435], [2.757, 2.314], [2.571, 2.457], [2.086, 2.804], [2.636, 2.785],
];

They look like this:

Desmos showing the above points

There we go: Two nicely separated clusters. This visualization already gives away where the clusters are, sure, but the machine doesn't know that yet. So I'm going to teach it.

How the algorithm works - step by step overview

The k-means algorithm tries to find the centers of potential clusters.

There are a few steps:

  1. Initialize some amount of cluster centers, called centroids
  2. Assign each point to the closest centroid
  3. Move each centroid to the average of its assigned points
  4. Repeat steps 2 and 3 until the centroids barely move or stop moving entirely
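
To make the four steps a bit more tangible, here's a minimal sketch in one dimension. The function name `sketchKMeans1d`, the sample values and the tolerance are made up for illustration, and it stops on the largest centroid movement, which is just one possible stopping rule.

```php
<?php

declare(strict_types=1);

/**
 * One-dimensional k-means, only meant to show the four steps in isolation.
 * Hypothetical helper: name, signature and defaults are assumptions.
 *
 * @param float[] $points
 * @param float[] $centroids Step 1: the initial centroids, chosen by the caller
 * @return float[]
 */
function sketchKMeans1d(array $points, array $centroids, float $tolerance = 0.0001): array
{
    do {
        // Step 2: assign each point to the closest centroid.
        $groups = array_fill_keys(array_keys($centroids), []);
        foreach ($points as $point) {
            $distances = array_map(fn (float $c): float => abs($point - $c), $centroids);
            $groups[array_search(min($distances), $distances, true)][] = $point;
        }

        // Step 3: move each centroid to the average of its assigned points.
        $moved = 0.0;
        foreach ($groups as $index => $group) {
            if ($group === []) {
                continue; // no points assigned: leave this centroid where it is
            }
            $new = array_sum($group) / count($group);
            $moved = max($moved, abs($new - $centroids[$index]));
            $centroids[$index] = $new;
        }
        // Step 4: repeat while any centroid still moved noticeably.
    } while ($moved > $tolerance);

    return $centroids;
}

// Two obvious groups (around 1.5 and around 9.5), deliberately poor starting centroids.
$result = sketchKMeans1d([1.0, 2.0, 9.0, 10.0], [0.0, 5.0]);
```

With these values the centroids land on 1.5 and 9.5 after the first pass, and the second pass only confirms that nothing moves anymore.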

The first step is pretty straightforward. In the case of my generated data above, I'd introduce two cluster centroids.

But are there really two clusters? The top right corner of the data suggests even more clusters.

Wait a sec, how many clusters are there, actually?

Well, that depends on what you're doing. Usually you more or less know how many clusters to expect in your data. There are ways to find the number of clusters dynamically, but they usually come with some kind of trade-off. For learning purposes, we'll assume we know the number of clusters.
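
One common heuristic for this (not used in the rest of this post) is the so-called elbow method: run the clustering for several cluster counts and watch where the summed squared distance from each point to its closest centroid stops dropping sharply. The sketch below only computes that sum for centroids it is given; the function name and the tiny data set are assumptions for illustration.

```php
<?php

declare(strict_types=1);

/**
 * Sum of squared distances from every point to its closest centroid.
 * Hypothetical helper, not part of the implementation in this post.
 *
 * @param array<int, float[]> $points
 * @param array<int, float[]> $centroids
 */
function sketchWithinClusterSumOfSquares(array $points, array $centroids): float
{
    $total = 0.0;
    foreach ($points as [$x, $y]) {
        // Squared distance to every centroid; only the smallest one counts.
        $squared = array_map(
            fn (array $c): float => ($x - $c[0]) ** 2 + ($y - $c[1]) ** 2,
            $centroids
        );
        $total += min($squared);
    }

    return $total;
}

// A made-up data set with two obvious groups.
$points = [[0.0, 0.0], [0.0, 1.0], [4.0, 4.0], [4.0, 5.0]];

// One centroid at the overall mean versus one centroid per group.
$oneCluster = sketchWithinClusterSumOfSquares($points, [[2.0, 2.5]]);
$twoClusters = sketchWithinClusterSumOfSquares($points, [[0.0, 0.5], [4.0, 4.5]]);
```

Here the sum drops from 33.0 with one centroid to 1.0 with two, which hints that two clusters describe this data far better than one. The trade-off is that the clustering has to run once per candidate count.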

Ok, back to actually finding the clusters.

Let's initialize some random centroids somewhere:

$centroids = generateDataPoints(0.0, 0.0, 3.0, 3.0, 2);

var_dump($centroids);

/*
array(2) {
  [0] =>
  array(2) {
    [0] =>
    double(2.4393)
    [1] =>
    double(2.6893)
  }
  [1] =>
  array(2) {
    [0] =>
    double(0.3525)
    [1] =>
    double(0.3989)
  }
}
*/

I'll again stick with the centroids shown above to keep things traceable. Let's see what they look like in Desmos:

Desmos showing the data points, but this time with the cluster centers/centroids

They're in the top/top left of the whole data set, marked as little crosses. Next up, I need a function to measure the distance between two points. The Euclidean distance will do:

/**
 * @param array $p1
 * @param array $p2
 * @return float
 */
function getDistance(array $p1, array $p2): float {
    return sqrt(($p2[0] - $p1[0]) ** 2 + ($p2[1] - $p1[1]) ** 2);
}
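
The function above is fixed to two coordinates. If your data has more dimensions, the same idea carries over. Here's a rough sketch, assuming both points are plain lists of floats of equal length; `sketchDistanceNd` is a made-up name.

```php
<?php

declare(strict_types=1);

/**
 * Euclidean distance for points with any number of coordinates.
 * Hypothetical generalization of the two-dimensional version.
 *
 * @param float[] $p1
 * @param float[] $p2
 */
function sketchDistanceNd(array $p1, array $p2): float
{
    if (count($p1) !== count($p2)) {
        throw new InvalidArgumentException('Both points need the same number of coordinates.');
    }

    // Add up the squared difference per coordinate, then take the root.
    $sumOfSquares = 0.0;
    foreach ($p1 as $i => $coordinate) {
        $sumOfSquares += ($p2[$i] - $coordinate) ** 2;
    }

    return sqrt($sumOfSquares);
}

$flat = sketchDistanceNd([0.0, 0.0], [3.0, 4.0]);              // the classic 3-4-5 triangle
$spatial = sketchDistanceNd([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]); // identical points
```

`$flat` comes out as 5.0 and `$spatial` as 0.0.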

Awesome! I can now use this function to create another one, which finds the index of the centroid nearest to a given point:

/**
 * @param array $p
 * @param array $centroids
 * @return int
 */
function getNearestCentroidIndex(array $p, array $centroids): int {
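    // Distance from $p to every centroid, keeping the centroid keys
    $distances = array_map(function (array $centroid) use ($p): float {
        return getDistance($p, $centroid);
    }, $centroids);

    // Key of the smallest distance; on a tie, the first match wins
    return (int) array_search(min($distances), $distances, true);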
}
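
A quick way to see this in action is to feed it a few points by hand. The sketch below is a standalone stand-in with a made-up name and made-up centroids at the corners of the data range. It also shows what happens when a point is exactly as far from one centroid as from the other.

```php
<?php

declare(strict_types=1);

/**
 * Index of the centroid closest to $p.
 * Hypothetical standalone stand-in for the function above.
 *
 * @param float[] $p
 * @param array<int, float[]> $centroids
 */
function sketchNearestIndex(array $p, array $centroids): int
{
    $distances = array_map(
        fn (array $c): float => sqrt(($c[0] - $p[0]) ** 2 + ($c[1] - $p[1]) ** 2),
        $centroids
    );

    // array_search() returns the first matching key, so ties go to the lower index.
    return (int) array_search(min($distances), $distances, true);
}

$centroids = [[0.0, 0.0], [3.0, 3.0]]; // made-up centroids, not the ones from this post

$lower = sketchNearestIndex([0.567, 0.7], $centroids);  // a point from the lower left cluster
$upper = sketchNearestIndex([2.45, 2.829], $centroids); // a point from the upper right cluster
$tie = sketchNearestIndex([1.5, 1.5], $centroids);      // exactly halfway between both
```

`$lower` is 0, `$upper` is 1, and the tie also resolves to 0, simply because that centroid comes first in the array.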

Next, I need a function that returns an average point over a set of points:

/**
 * @param array $points
 * @return array
 */
function getAveragePoint(array $points): array {
    $pointsCount = count($points);
    if ($pointsCount === 0) {
        // No points given: fall back to the origin to avoid dividing by zero
        return [0.0, 0.0];
    }

    return [
        array_sum(array_column($points, 0)) / $pointsCount,
        array_sum(array_column($points, 1)) / $pointsCount,
    ];
}
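
Note what the empty case means for k-means: a centroid that gets no points assigned would jump to the origin. One alternative is to hand in a fallback, for example the centroid's previous position. The following is only a sketch of that idea; the function name and the extra `$fallback` parameter are assumptions and not part of the implementation in this post.

```php
<?php

declare(strict_types=1);

/**
 * Average of a set of 2D points, or $fallback if the set is empty.
 * Hypothetical variant of the averaging function above.
 *
 * @param array<int, float[]> $points
 * @param float[] $fallback
 * @return float[]
 */
function sketchAverageOrFallback(array $points, array $fallback): array
{
    if ($points === []) {
        return $fallback; // nothing to average: stay where we were
    }

    $count = count($points);

    return [
        array_sum(array_column($points, 0)) / $count,
        array_sum(array_column($points, 1)) / $count,
    ];
}

$mean = sketchAverageOrFallback([[0.0, 0.0], [2.0, 0.0], [1.0, 3.0]], [9.0, 9.0]);
$kept = sketchAverageOrFallback([], [9.0, 9.0]);
```

`$mean` is `[1.0, 1.0]`, while the empty set simply gives back `[9.0, 9.0]`.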

Almost there. I can use these functions now to implement the movement of the centroids:

/**
 * @param array $centroids
 * @param array $dataPoints
 * @return array
 */
function moveCentroids(array $centroids, array $dataPoints): array {
    $nearestCentroidsMap = array_map(function (array $point) use ($centroids): array {
        return [
            ...$point,
            getNearestCentroidIndex($point, $centroids)
        ];
    }, $dataPoints);

    $newCentroids = [];

    foreach ($centroids as $key => $value) {
        $newCentroids[$key] = getAveragePoint(array_filter($nearestCentroidsMap, function (array $point) use ($key) {
            return $point[2] === $key;
        }));
    }

    return $newCentroids;
}

And that's it. Now for the looping. I first create a new set of centroids. Then I measure the average moving distance between the previous and the new centroids. If the movement was still too large, I do another iteration, until there's next to no movement anymore.

do {
    $newCentroids = moveCentroids($centroids, $dataPoints);
    $movedDistances = array_map(function ($a, $b) {
        return getDistance($a, $b);
    }, $centroids, $newCentroids);

    $averageDistanceTravelled = array_sum($movedDistances) / count($movedDistances);

    $centroids = $newCentroids;
} while ($averageDistanceTravelled > 0.0001);
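
If you'd like to run everything in one go, here's a compact, self-contained sketch of the same loop. It is not the exact code from above: the function name, the iteration cap and the starting centroids `[0.0, 0.0]` and `[3.0, 3.0]` are assumptions I made for illustration, and a centroid without points simply stays where it is.

```php
<?php

declare(strict_types=1);

/**
 * Compact k-means for 2D points. Hypothetical helper for illustration.
 *
 * @param array<int, float[]> $points
 * @param array<int, float[]> $centroids initial centroids
 * @return array{0: array<int, float[]>, 1: int} final centroids and iterations used
 */
function sketchRunKMeans(array $points, array $centroids, float $tolerance = 0.0001, int $maxIterations = 100): array
{
    $iterations = 0;

    do {
        $iterations++;

        // Per centroid: running sum of x, running sum of y, number of points.
        $sums = array_fill_keys(array_keys($centroids), [0.0, 0.0, 0]);
        foreach ($points as [$x, $y]) {
            $distances = array_map(
                fn (array $c): float => sqrt(($c[0] - $x) ** 2 + ($c[1] - $y) ** 2),
                $centroids
            );
            $nearest = array_search(min($distances), $distances, true);
            $sums[$nearest][0] += $x;
            $sums[$nearest][1] += $y;
            $sums[$nearest][2]++;
        }

        // Move the centroids and add up how far they travelled.
        $moved = 0.0;
        foreach ($sums as $index => [$sumX, $sumY, $count]) {
            if ($count === 0) {
                continue; // no points assigned: keep the old position
            }
            $new = [$sumX / $count, $sumY / $count];
            $moved += sqrt(($new[0] - $centroids[$index][0]) ** 2 + ($new[1] - $centroids[$index][1]) ** 2);
            $centroids[$index] = $new;
        }
        $averageMoved = $moved / count($centroids);
        // The cap is a safety net in case the centroids never settle.
    } while ($averageMoved > $tolerance && $iterations < $maxIterations);

    return [$centroids, $iterations];
}

// The fixed data points from the beginning of this post.
$dataPoints = [
    [0.567, 0.7], [0.259, 0.58], [0.89, 0.785], [0.447, 0.498], [0.254, 0.311],
    [0.741, 0.138], [0.088, 0.371], [0.146, 0.12], [0.022, 0.202], [0.111, 0.284],
    [2.45, 2.829], [2.101, 2.728], [2.018, 2.813], [2.498, 2.929], [2.613, 2.799],
    [2.663, 2.435], [2.757, 2.314], [2.571, 2.457], [2.086, 2.804], [2.636, 2.785],
];

[$finalCentroids, $iterations] = sketchRunKMeans($dataPoints, [[0.0, 0.0], [3.0, 3.0]]);
```

With these starting centroids the first pass already splits the data cleanly, and the second pass only confirms that nothing moves anymore.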

Ok, time to run that thing and see if it works:

var_dump($newCentroids);

/*
array(2) {
  [0] =>
  array(2) {
    [0] =>
    double(2.4393)
    [1] =>
    double(2.6893)
  }
  [1] =>
  array(2) {
    [0] =>
    double(0.3525)
    [1] =>
    double(0.3989)
  }
}
*/

The numbers look good so far. Let's plot them to see if they really fit:

The centroids, now showing in the middle of the clusters

Awesome, it works! This only took two iterations because the data is pretty straightforward. Larger data sets that are a lot less structured might need a few more iterations.

That's it, thanks for reading!


I hope you enjoyed reading this article as much as I enjoyed writing it! If so, leave a ❤️ or a 🦄! I write tech articles in my free time and like to drink coffee every once in a while.

If you want to support my efforts, buy me a coffee or follow me on Twitter 🐦!
