Nathan

Shapes Classifier: Sklearn tutorial but scuffed

Scikit Learn (Sklearn) is a Python machine learning library that provides accessible tools for data science. That also means you can do some pretty interesting stuff and slap "machine learning pro" on your resume in only a few lines of code. In this tutorial, I'll show how to create a simple image classifier that you can deploy with a cool web app.

If you want to see what we'll be building, check out the GitHub repository.

Installing Libraries

Assuming you already have Python, you can simply use pip to install Sklearn. Ideally you'd want to use a virtual environment so your dependencies don't conflict. My favorite virtual environment manager is pipenv because it mirrors pip so closely, but you may also want to install Anaconda if you plan on continuing to build machine learning and data science projects.

For now, we'll just go ahead and install Sklearn and a few other dependencies globally.

$ pip install scikit-learn pillow numpy joblib flask
  • Pillow (PIL): for image manipulation
  • Numpy: converting images to arrays
  • Joblib: saving the entire model to deploy it to an API (it's really overkill for this project, but I love how easy it is to use)
  • Flask: simple web server to host the frontend

Downloading the Dataset

I used the Four Shapes Dataset by smeschke on Kaggle, which contains a bunch of black and white images of circles, squares, stars, and triangles.

Shapes Classifier

Create a new Python file or Jupyter notebook where we can build and save the machine learning model. I called mine shapes_classifier.ipynb.

Import Libraries

from PIL import Image
import numpy as np
import os

Import Dataset

After you unzip the dataset and import it into your project directory, it should look something like this:

shapes/
    circle/
    square/
    star/
    triangle/

shapes_classifier.ipynb

NOTE: Make sure you include the shapes directory in .gitignore so you don't commit the entire dataset with git!

Images by themselves cannot be directly processed by Sklearn, so we have to create a function that looks in each shape directory and converts every image into a Numpy array of pixel intensities (255 for white, 0 for black).

shapes_classifier.ipynb

def import_dataset():
    shapes = ["circle", "square", "triangle", "star"]
    X = []
    y = []

    # loop through each shape
    for shape in shapes:

        # get all shape images
        shape_files = os.listdir(f"shapes/{shape}")

        # convert each image to a numpy array
        for image in shape_files:
            X.append(np.asarray(Image.open(f"shapes/{shape}/{image}")).flatten())

            # add a corresponding label
            y.append(shape)

    return X, y

X, y = import_dataset()
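As a quick sanity check: each image in this dataset is 200x200 pixels, so every flattened array should have 40,000 entries (if your copy of the dataset differs, these numbers will too).

print(len(X), len(y))  # number of images & labels (these should match)
print(X[0].shape)      # => (40000,) i.e. 200 * 200 pixels, flattened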

Like the images, the shape labels in y (such as "circle" or "triangle") need to be encoded as numbers before Sklearn can work with them.

from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()
y = le.fit_transform(y)
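If you're curious what the encoding looks like, you can peek at the fitted encoder. LabelEncoder assigns integers in alphabetical order, so for this dataset:

print(le.classes_)                       # => ['circle' 'square' 'star' 'triangle']
print(le.transform(["circle", "star"]))  # => [0 2]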

Split Dataset

After we import the dataset, we split it into two parts: the training set and the test set. The training set is what Sklearn uses to build a model that makes predictions, whereas the test set is used to estimate the model's actual accuracy.

test_size: how big the test set should be; I usually just set it to 1/4 of the entire dataset

random_state: Sklearn splits datasets randomly but you can set a random_state to ensure reproducible results (referred to as "deterministic behavior")

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=1/4, random_state=42)
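As a sanity check, the test set should now hold roughly a quarter of the images:

print(len(X_train), len(X_test))  # len(X_test) should be about 1/4 of the full dataset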

Train the Model

For this project we will use the Random Forest Classifier, which is known for its flexibility and high accuracy without extensive tuning. A Random Forest Classifier is basically the average of many Decision Tree Classifiers, each of which repeatedly splits the dataset into similar-looking groups; the final groups are then used to classify new data. Quite frankly I don't fully understand the math beyond this basic intuition, so we'll sweep the details under the rug along with the rest of the Sklearn library.

n_estimators: how many decision tree classifiers we want to use (it technically already defaults to 100)

from sklearn.ensemble import RandomForestClassifier

# create a new classifier and train it on the data
classifier = RandomForestClassifier(n_estimators=100, random_state=42)
classifier.fit(X_train, y_train)
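If you want some intuition for the "many decision trees" idea, you can train a single Decision Tree Classifier on the same split and compare it against the forest (a quick sketch; your exact numbers will vary):

from sklearn.tree import DecisionTreeClassifier

# a single tree uses the same API as the forest,
# but is usually less accurate & more prone to overfitting
tree = DecisionTreeClassifier(random_state=42)
tree.fit(X_train, y_train)

print(tree.score(X_test, y_test))        # single tree accuracy
print(classifier.score(X_test, y_test))  # forest accuracy, for comparison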

Making Predictions

Let's use the classifier to make predictions on the testing dataset. We can compare the results to the expected results in y_test to get an understanding of the model's accuracy.

y_pred = classifier.predict(X_test)
from sklearn.metrics import accuracy_score

print(accuracy_score(y_test, y_pred)) # => 0.9997328346246327

The model achieves a 99.97% accuracy! However, an accuracy this high can be a sign of overfitting, so the model might not effectively classify new & unseen data. Try tuning the hyperparameters to reduce overfitting (I was too lazy to make any fixes); one quick generalization check is sketched below.
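Cross-validation trains and scores the model on several different splits instead of just one, which gives a rough sense of how well it generalizes. A minimal sketch (on a dataset this clean, the scores will probably still be high):

from sklearn.model_selection import cross_val_score

# train & score on 5 different train/test splits
scores = cross_val_score(RandomForestClassifier(n_estimators=100, random_state=42), X, y, cv=5)
print(scores.mean(), scores.std())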

You can also make a single prediction by loading an image with PIL & converting it to a Numpy array:

# open an image -> convert to black and white -> resize to (200, 200)
pil_image = Image.open("circle.png").convert('1').resize((200, 200))

# convert image into 1D array of pixels
test_image = [ 255 if pixel else 0 for pixel in np.asarray(pil_image).flatten() ]

# predict what the image is
le.inverse_transform(classifier.predict([test_image])) # => circle (hopefully)

Congratulations on creating a machine learning model! You can actually stop here if you don't want to deploy your model to a Flask app.

Saving the Model

Joblib makes it easy to save models and load them back into other Python scripts. Although there are potentially lighter or more practical libraries, I found Joblib very approachable and simple. Remember to save the label encoder as well so we can convert the model's output into a string we can understand.

from joblib import dump

dump(le, "model/shapes_label_encoder.joblib", compress=3)
dump(classifier, "model/shapes_classifier.joblib", compress=3)
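One gotcha: dump won't create the model directory for you, so make sure it exists before running the code above (for example):

import os
os.makedirs("model", exist_ok=True)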

Model API

Let's make a web app that allows users to submit drawings of shapes to our model. First, create some new directories & files.

shapes/
    circle/
    square/
    star/
    triangle/

model/
    shapes_classifier.joblib
    shapes_label_encoder.joblib

shapes_classifier.ipynb

# what you need to add:
web/
    static/
        app.js
        styles.css
    templates/
        index.html

server.py

Import Libraries

For the backend, we'll be making a very simple Flask webserver. Start by importing all of the libraries. BytesIO and base64 will be used to process requests containing an image encoded in base64 format. You don't need to install these two because they're built into Python.

server.py

# web server
from flask import Flask, request, render_template

# image processing
from PIL import Image
import numpy as np
from io import BytesIO
import base64

# load the model
import joblib

Load the Model

Joblib makes everything easy! We can directly call on the classifier and label encoder without importing Sklearn into server.py (Sklearn still has to be installed, though, since Joblib needs it to unpickle the model).

label_encoder = joblib.load("model/shapes_label_encoder.joblib")
classifier = joblib.load("model/shapes_classifier.joblib")

Create Flask App

Here, we create a simple Flask app with a single route that renders web/templates/index.html.

app = Flask("app", static_url_path='/', static_folder="web/static", template_folder="web/templates")

@app.route('/')
def hello_world():
    return render_template("index.html")

Adding an upload route is only slightly more complicated. First we get the request body, which will contain an "image" field with a base64 encoded image. After loading the image with PIL, the rest is almost identical to the single-prediction code we wrote earlier.

@app.route("/api/upload", methods=["POST"])
def upload():
    body = request.get_json()
    pil_image = Image.open(BytesIO(base64.b64decode(body["image"]))).convert('1').resize((200, 200))
    test_image = [ 255 if pixel else 0 for pixel in np.asarray(pil_image).flatten() ]
    prediction = label_encoder.inverse_transform(classifier.predict([ test_image ]))

    return { "prediction": prediction[0] }

Now we can start the Flask app:

app.run(host="localhost", port=5000)

If you are using Replit, make sure to change "localhost" to "0.0.0.0" so the webserver is reachable from outside the container.
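Before wiring up the frontend, you can smoke-test the upload route from a separate Python script. This sketch assumes the server is running on localhost:5000, reuses the circle.png from earlier, and uses the requests library (which you'd need to pip install):

import base64
import requests

# base64-encode a test image, just like the frontend will
with open("circle.png", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("ascii")

res = requests.post("http://localhost:5000/api/upload", json={ "image": encoded })
print(res.json())  # => {'prediction': 'circle'} (hopefully)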

Frontend App

Although this tutorial is mainly focused on Python, I'll try to add some comments about each step. For the frontend, we'll create a simple paint-like program that can set individual pixels, fill in a closed shape (a bucket tool), and upload the drawn image to the server.

Page Structure

Feel free to copy and paste the HTML into your own editor. Be sure to import the JavaScript and CSS files so we can add functionality and styling to our app.

web/templates/index.html

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="utf-8" />
    <meta http-equiv="X-UA-Compatible" content="IE=edge" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />

    <title>Shapes Classifier</title>

    <!-- add styles & js (stored in web/static) -->
    <link rel="stylesheet" href="styles.css" />
    <script src="app.js" type="module"></script>
</head>
<body>
    <div class="app">
        <!-- we'll use this canvas element to draw pixels -->
        <canvas class="app__canvas"></canvas>

        <!-- all of the app options with inline SVGs as icons -->
        <div class="app__buttons">
            <button id="prediction" title="Prediction">
                <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="12" cy="12" r="10"></circle></svg>
            </button>
            <button title="Erase Canvas">
                <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"> <path d="M2.5 2v6h6M21.5 22v-6h-6"/><path d="M22 11.5A10 10 0 0 0 3.2 7.2M2 12.5a10 10 0 0 0 18.8 4.2"/></svg>
            </button>
            <button title="Pen Tool">
                <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="#000000" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><polygon points="16 3 21 8 8 21 3 21 3 16 16 3"></polygon></svg>
            </button>
            <button title="Bucket Tool">
                <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="currentColor" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="12" cy="12" r="10"></circle></svg>
            </button>
            <button title="Upload Image">
                <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M21.2 15c.7-1.2 1-2.5.7-3.9-.6-2-2.4-3.5-4.4-3.5h-1.2c-.7-3-3.2-5.2-6.2-5.6-3-.3-5.9 1.3-7.3 4-1.2 2.5-1 6.5.5 8.8m8.7-1.6V21"/><path d="M16 16l-4-4-4 4"/></svg>
            </button>
        </div>
    </div>
</body>
</html>

I got a bit lazy with the user interface, so I just used the title attribute to label options.

(Screenshot: the page with no styles and no script)

Page Styles

I always start out by resetting the box-sizing to border-box to make CSS a bit more consistent.

web/static/styles.css

html, body {
    height: 100%;
    width: 100%;
    padding: 0;
    margin: 0;
    box-sizing: border-box;
}

*, ::before, ::after {
    box-sizing: inherit;
}

After that, I centered the drawing canvas & the buttons on the side.

.app {
    /* used to ensure the .app__canvas & .app__buttons are the same height */
    --canvas-size: 30rem;

    height: 100vh;
    width: 100vw;

    /* the centering magic */
    display: flex;
    justify-content: center;
    align-items: center;
    overflow: hidden;
    gap: 2rem;
}

.app__canvas {
    width: var(--canvas-size);
    height: var(--canvas-size);
    border-radius: 0.5rem;
    box-shadow: 0 0.5rem 1rem #eee;
}

.app__buttons {
    height: var(--canvas-size);
    display: flex;
    flex-direction: column;
    gap: 1rem;
}

The buttons still have their default styles, so let's change that by adding rounded edges and a gray border.

.app__buttons button {
    align-items: flex-start;
    border-radius: 0.5rem;
    background: none;
    outline: none;
    border: 1px solid #eee;
    padding: 1rem;
    cursor: pointer;
    color: black;
    transition: border 0.2s ease;
}

.app__buttons button:hover {
    border: 1px solid dodgerblue;
}

The prediction will be shown in the top right corner as an icon of the result. To make the prediction stand out, we can make the background a nice shade of blue. You could also add an element that explicitly spells out the result rather than showing an icon, but I wanted to keep the minimal design.

#prediction {
    background: dodgerblue;
    border: 1px solid dodgerblue;
    color: #fff;
    cursor: default;
}

We're also going to add a rotating loading animation with keyframes when you're uploading an image. You won't be able to see this animation yet because we'll need to use JS to add the .loading class to an element first.

.loading {
    animation: loading 1.5s linear infinite;
}

@keyframes loading {
    from {
        transform: rotate(0deg);
    }
    to {
        transform: rotate(360deg);
    }
}

The UI isn't terrible, but it doesn't make a lot of sense either. Try adding in your own icons, styles, and HTML to make the app more understandable.

(Screenshot: the app with styling applied)

Page Functionality

At the moment you can't actually interact with the page. After we add JS, though, the user will be able to draw on & clear the canvas with the "clear", "pen", and "bucket" tools. Again, you can (and should) expand the tools I added to make things more user friendly.

At the top of the JS file I created an object containing some SVG icons, which will be used to show the user the predicted result.

web/static/app.js

const icons = {
    upload: `<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M21.2 15c.7-1.2 1-2.5.7-3.9-.6-2-2.4-3.5-4.4-3.5h-1.2c-.7-3-3.2-5.2-6.2-5.6-3-.3-5.9 1.3-7.3 4-1.2 2.5-1 6.5.5 8.8m8.7-1.6V21"/><path d="M16 16l-4-4-4 4"/></svg>`,

    // we add the loading class to this one so it spins when we add it to the DOM
    loading: `<svg class="loading" xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><line x1="12" y1="2" x2="12" y2="6"></line><line x1="12" y1="18" x2="12" y2="22"></line><line x1="4.93" y1="4.93" x2="7.76" y2="7.76"></line><line x1="16.24" y1="16.24" x2="19.07" y2="19.07"></line><line x1="2" y1="12" x2="6" y2="12"></line><line x1="18" y1="12" x2="22" y2="12"></line><line x1="4.93" y1="19.07" x2="7.76" y2="16.24"></line><line x1="16.24" y1="7.76" x2="19.07" y2="4.93"></line></svg>`,
    square: `<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="3" y="3" width="18" height="18" rx="2" ry="2"></rect></svg>`,
    triangle: `<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M3 20h18L12 4z"/></svg>`,
    star: `<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><polygon points="12 2 15.09 8.26 22 9.27 17 14.14 18.18 21.02 12 17.77 5.82 21.02 7 14.14 2 9.27 8.91 8.26 12 2"></polygon></svg>`,
    circle: `<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="12" cy="12" r="10"></circle></svg>`
}

Next, I selected the canvas and all of the buttons.

const canvas = document.querySelector(".app__canvas")

// used to draw stuff on the canvas
const ctx = canvas.getContext("2d") 

// resize canvas to canvas dimensions specified in CSS
canvas.width = canvas.offsetWidth
canvas.height = canvas.offsetHeight

const [ prediction, refreshButton, penButton, fillButton, uploadButton ] = document.querySelectorAll(".app__buttons button")

We also need a way to store the canvas's state and configuration: how large each drawn pixel should be, the mouse position, whether or not we are uploading an image, and an array of all of the canvas's pixels.

const pixelSize = 10

const mouse = { x: 0, y: 0, down: false, tool: "pen" }
let loading = false
let image = []

You can get the mouse position relative to the canvas by taking the absolute mouse position (from e.clientX and e.clientY) and subtracting the canvas's position relative to the viewport. We can also use Math.floor to snap each x & y coordinate to a grid, like pixel art. For example, with pixelSize set to 10, a click 57px from the canvas's left edge snaps to x = 50.

const updateMousePosition = (e) => {
    const rect = canvas.getBoundingClientRect()
    mouse.x = Math.floor((e.clientX - rect.left) / pixelSize) * pixelSize
    mouse.y = Math.floor((e.clientY - rect.top) / pixelSize) * pixelSize
}

Adding event listeners will allow us to update the mouse's position every time we move the mouse or click on the canvas. I also included more event listeners to update the mouse's tool between "fill" and "pen".

// update mouse position
canvas.addEventListener("mousemove", updateMousePosition)
canvas.addEventListener("mouseup", () => { mouse.down = false })

// update mouse tool
penButton.addEventListener("click", () => { mouse.tool = "pen" })
fillButton.addEventListener("click", () => { mouse.tool = "fill" })

// reset array of pixels (we haven't actually drawn anything to the screen yet, hang tight!)
refreshButton.addEventListener("click", () => { image = [] })

In order to render pixels to the canvas, we need to create a render loop that will continually refresh to reflect new changes. requestAnimationFrame is commonly used to recursively call a render function that updates the canvas.

If the mouse is down and the user's tool is "pen", we add the current pixel to the image array (if it isn't already there); then, every frame, we loop through the array and draw each pixel to the screen.

const render = () => {

    // clear canvas
    ctx.fillStyle = "#fff"
    ctx.fillRect(0, 0, canvas.width, canvas.height)

    // if mouse is down and the tool is a pen then append mouse coords to image array
    if(mouse.down && mouse.tool == "pen") {
        const pixel = { x: mouse.x, y: mouse.y }
        if(!image.find(p => p.x === pixel.x && p.y === pixel.y)) { 
            image.push(pixel) 
        }
    }

    // loop through image array to draw pixels
    for(const pixel of image) {
        ctx.fillStyle = "#000"
        ctx.fillRect(pixel.x, pixel.y, pixelSize, pixelSize)
    }

    window.requestAnimationFrame(render)
}

To kick off the animation loop, just call render once.

render()

I didn't put the fill tool inside the render loop because it would be computationally expensive to run it every frame while the mouse is down. We only need to trigger the tool once, when the mouse is clicked!

canvas.addEventListener("mousedown", e => { 
    updateMousePosition(e)
    mouse.down = true

    if(mouse.tool == "fill") {
        // fill tool code here
    }
})

Filling the interior of a closed shape to implement the bucket tool is more complicated, but intuitively it sort of makes sense. Here's the idea:
1) Make a queue (an array) that stores the "jobs" we need to complete (adding each pixel within the closed shape to the image).
2) Loop through the queue:
   - If the pixel from the queue already exists in the image, don't do anything; this means we've hit an edge.
   - Otherwise, add the pixel to the image array, then add all of its surrounding pixels to the queue.

(Diagram: the fill-tool queue)

I also added a variable called maxLoop to prevent the loop from running through the queue forever, which could happen if the shape we decided to fill wasn't actually closed. There are other ways to restrict this behavior (e.g. checking whether a pixel has hit the edge of the canvas), but this seemed the easiest at the time.

        // fill tool code here
        const queue = [ { x: mouse.x, y: mouse.y } ]
        let maxLoop = 7000

        while(queue.length > 0 && maxLoop > 0) {
            const pixel = queue.shift()
            if(!image.find(p => p.x === pixel.x && p.y === pixel.y)) {
                image.push(pixel)
                queue.push({ x: pixel.x - pixelSize, y: pixel.y })
                queue.push({ x: pixel.x + pixelSize, y: pixel.y })
                queue.push({ x: pixel.x, y: pixel.y - pixelSize })
                queue.push({ x: pixel.x, y: pixel.y + pixelSize })
            }   
            maxLoop--
        }


I've saved the best for last: uploading the canvas to the Flask webserver in order to make predictions. Using the fetch API, we can send over a POST request with a base64 encoded image to the server for processing.

uploadButton.addEventListener("click", async () => {

    // only upload image if we aren't already uploading one
    if(!loading) {

        // set the upload option to a loading indicator (it'll spin around, yay)
        uploadButton.innerHTML = icons.loading
        loading = true

        // POST canvas to server
        const json = await fetch("/api/upload", {
            method: "POST",
            headers: { "Content-Type": "application/json" },
            body: JSON.stringify({ image: canvas.toDataURL("image/png").split(',').pop() })
        }).then(res => res.json())

        // display prediction to the user
        prediction.innerHTML = icons[json.prediction]

        // reset loading & upload button
        uploadButton.innerHTML = icons.upload
        loading = false
    }
})

Closing

A huge thank you if you've made it to the end. 🥳 This was a massive tutorial, so if you find a typo or something that doesn't work, make sure to tell me so I can fix it!

Hopefully this gave you a brief introduction to machine learning with Sklearn. I look forward to what you do with this project. If you want a challenge, I recommend making the frontend user interface more friendly and using Sklearn on more datasets on Kaggle.
