Scikit-learn (Sklearn) is a Python machine learning library that provides accessible tools for data science. That also means you can do some pretty interesting stuff and slap "machine learning pro" on your resume with only a few lines of code. In this tutorial, I'll show you how to create a simple image classifier that you can deploy with a cool web app.
If you want to see what we'll be building, check out the GitHub repository.
Installing Libraries
Assuming you already have Python, you can simply use pip to install Sklearn. Ideally you'd want to use a virtual environment to ensure your dependencies don't have any conflicts. My favorite virtual environment manager is pipenv because it mirrors pip so closely, but you may also want to install anaconda if you plan on continuing to make machine learning and data science projects.
For now, we'll just go ahead and install Sklearn and a few other dependencies globally.
$ pip install scikit-learn pillow numpy joblib flask
- Pillow (PIL): for image manipulation
- Numpy: converting images to arrays
- Joblib: saving the entire model to deploy it to an API (it's really overkill for this project, but I love how easy it is to use)
- Flask: simple web server to host the frontend
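To sanity-check the installation, you can try importing everything from a Python shell. This is just a quick sketch; the exact version number on your machine will differ.

# every dependency should import without errors
import sklearn, PIL, numpy, joblib, flask
print(sklearn.__version__)  # any recent version should work for this tutorial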
Downloading the Dataset
I used the Four Shapes Dataset by smeschke on Kaggle, which contains a bunch of black and white images of circles, squares, stars, and triangles.
Shapes Classifier
Create a new Python file or Jupyter notebook where we can build and save the machine learning model. I called mine shapes_classifier.ipynb.
Import Libraries
from PIL import Image
import numpy as np
import os
Import Dataset
After you unzip the dataset and import it into your project directory, it should look something like this:
shapes/
  circle/
  square/
  star/
  triangle/
shapes_classifier.ipynb
NOTE: Make sure you include the shapes directory in .gitignore so you don't commit the entire dataset with git!
Images by themselves cannot be directly processed by Sklearn, so we have to create a function that looks in each shape directory and converts each image into a Numpy array of pixel intensities (255 for white, 0 for black).
shapes_classifier.ipynb
def import_dataset():
    shapes = ["circle", "square", "triangle", "star"]
    X = []
    y = []
    # loop through each shape
    for shape in shapes:
        # get all shape images
        shape_files = os.listdir(f"shapes/{shape}")
        # convert each image to a numpy array
        for image in shape_files:
            X.append(np.asarray(Image.open(f"shapes/{shape}/{image}")).flatten())
            # add a corresponding label
            y.append(shape)
    return X, y
X, y = import_dataset()
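Before encoding the labels, it's worth a quick sanity check that the data loaded correctly. This sketch assumes the Kaggle dataset's 200x200 pixel images, so each flattened row should contain 40,000 values.

# the number of images and labels should match
print(len(X), len(y))
# each 200x200 image flattens into 40,000 pixel values
print(X[0].shape)  # => (40000,)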
Like images, the shape labels (like "circle" or "triangle") in y must be encoded so Sklearn can process them.
from sklearn.preprocessing import LabelEncoder
le = LabelEncoder()
y = le.fit_transform(y)
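If you're curious what the encoder actually did: le.classes_ stores the labels in alphabetical order, and each label's index becomes its encoded value.

print(list(le.classes_))          # => ['circle', 'square', 'star', 'triangle']
print(le.transform(["circle"]))   # => [0]
print(le.inverse_transform([3]))  # => ['triangle']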
Split Dataset
After we import the dataset, we split it into 2 parts: the training set and the test set. The training set will be processed by Sklearn to create a model that makes predictions, whereas the test set is used to estimate the actual accuracy of the model.
- test_size: how big the test set should be; I usually just set it to 1/4 of the entire dataset
- random_state: Sklearn splits datasets randomly, but you can set a random_state to ensure reproducible results (referred to as "deterministic behavior")
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=1/4, random_state=42)
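As a quick check, the split sizes should reflect the 3:1 ratio we asked for:

print(len(X_train), len(X_test))  # the test set holds roughly 1/4 of the samples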
Train the Model
For this project we will use the Random Forest Classifier, which is known for its flexibility and high accuracy without extensive tuning. A Random Forest Classifier is basically the average of many Decision Tree Classifiers, each of which continually splits the dataset into similar-looking groups; the final groups are then used to classify new data. Quite frankly, I don't fully understand the math behind the model beyond basic intuition, so we'll just sweep the process under the rug along with the rest of the Sklearn library.
- n_estimators: how many decision tree classifiers we want to use (it technically already defaults to 100)
from sklearn.ensemble import RandomForestClassifier
# create a new classifier and train it on the data
classifier = RandomForestClassifier(n_estimators=100, random_state=42)
classifier.fit(X_train, y_train)
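If you want some intuition for why the "forest" part matters, try training a single decision tree on the same data. This is just a side experiment, not part of the final model; a lone tree typically scores a bit lower than the forest that averages 100 of them.

from sklearn.tree import DecisionTreeClassifier

# a single tree, using the same random seed for comparability
tree = DecisionTreeClassifier(random_state=42)
tree.fit(X_train, y_train)
print(tree.score(X_test, y_test))  # usually a bit below the forest's accuracy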
Making Predictions
Let's use the classifier to make predictions on the testing dataset. We can compare the results to the expected results in y_test to get an understanding of the model's accuracy.
y_pred = classifier.predict(X_test)
from sklearn.metrics import accuracy_score
print(accuracy_score(y_test, y_pred)) # => 0.9997328346246327
The model achieves 99.97% accuracy! However, extremely high accuracies can be an indicator of overfitting, so the model might not be able to effectively classify new & unseen data. Try tuning the hyperparameters to reduce overfitting (I was too lazy to make any fixes).
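One way to hedge against an overly optimistic single split is cross-validation, which retrains and scores the model on several different splits. A minimal sketch (be warned: with 40,000 features per image, this can take a few minutes):

from sklearn.model_selection import cross_val_score

# 5-fold cross-validation: 5 different train/test splits, 5 accuracy scores
scores = cross_val_score(RandomForestClassifier(n_estimators=100, random_state=42), X, y, cv=5)
print(scores.mean())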
You can also make a single prediction by loading an image with PIL & converting it to a Numpy array:
# open an image -> convert to black and white -> resize to (200, 200)
pil_image = Image.open("circle.png").convert('1').resize((200, 200))
# convert image into 1D array of pixels
test_image = [ 255 if pixel else 0 for pixel in np.asarray(pil_image).flatten() ]
# predict what the image is
le.inverse_transform(classifier.predict([test_image])) # => circle (hopefully)
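Random Forests can also report class probabilities via predict_proba, which is a handy way to see how confident the model is. The probabilities line up with le.classes_ because the labels were encoded as 0 through 3 in that same order.

# probability for each class, in the label encoder's (alphabetical) order
probs = classifier.predict_proba([test_image])[0]
for label, p in zip(le.classes_, probs):
    print(label, round(p, 3))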
Congratulations on creating a machine learning model! You can actually stop here if you don't want to deploy your model to a Flask app.
Saving the Model
Joblib makes it easy to save models and load them back into other Python scripts. Although there are potentially lighter or more practical libraries, I found Joblib very approachable and simple. Remember to save the label encoder as well so we can convert the model's output into a string we can understand.
from joblib import dump

# joblib won't create missing directories, so make model/ first (os was imported earlier)
os.makedirs("model", exist_ok=True)
dump(le, "model/shapes_label_encoder.joblib", compress=3)
dump(classifier, "model/shapes_classifier.joblib", compress=3)
Model API
Let's make a web app that allows users to submit drawings of shapes to our model. First, create some new directories & files.
shapes/
  circle/
  square/
  star/
  triangle/
model/
  shapes_classifier.joblib
  shapes_label_encoder.joblib
shapes_classifier.ipynb
# what you need to add:
web/
  static/
    app.js
    styles.css
  templates/
    index.html
server.py
Import Libraries
For the backend, we'll be making a very simple Flask web server. Start by importing all of the libraries. BytesIO and base64 will be used to process requests containing an image encoded in base64 format. You don't need to install these because they're part of Python's standard library.
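If the base64 round trip feels abstract, here's the idea in isolation; a standalone sketch reusing the circle.png from earlier:

import base64
from io import BytesIO
from PIL import Image

# encode an image file into a base64 string (what the frontend will send)...
with open("circle.png", "rb") as f:
    encoded = base64.b64encode(f.read()).decode()
# ...and decode it back into a PIL image (what the server will do)
decoded = Image.open(BytesIO(base64.b64decode(encoded)))
print(decoded.size)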
server.py
# web server
from flask import Flask, request, render_template
# image processing
from PIL import Image
import numpy as np
from io import BytesIO
import base64
# load the model
import joblib
Load the Model
Joblib makes everything easy! We can directly call on the classifier and label encoder without importing Sklearn into server.py.
label_encoder = joblib.load("model/shapes_label_encoder.joblib")
classifier = joblib.load("model/shapes_classifier.joblib")
Create Flask App
Here, we create a simple Flask app with a single route that renders web/templates/index.html.
app = Flask("app", static_url_path='/', static_folder="web/static", template_folder="web/templates")
@app.route('/')
def hello_world():
    return render_template("index.html")
Adding an upload route is only slightly more complicated. First we get the request body, which will contain an "image" field with a base64 encoded image. After loading the image with PIL, the rest is almost identical to the single prediction code we wrote earlier.
@app.route("/api/upload", methods=["POST"])
def upload():
body = request.get_json()
pil_image = Image.open(BytesIO(base64.b64decode(body["image"]))).convert('1').resize((200, 200))
test_image = [ 255 if pixel else 0 for pixel in np.asarray(pil_image).flatten() ]
prediction = label_encoder.inverse_transform(classifier.predict([ test_image ]))
return { "prediction": prediction[0] }
Now we can start the Flask app:
app.run(host="localhost", port=5000)
If you are using Replit, make sure to change "localhost" to "0.0.0.0" so the web server is accessible from outside the container.
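Before building the frontend, you can poke the API directly from a separate Python script. A hypothetical test: it assumes the server is running, circle.png exists, and the requests library is installed (pip install requests).

import base64
import requests

with open("circle.png", "rb") as f:
    encoded = base64.b64encode(f.read()).decode()
resp = requests.post("http://localhost:5000/api/upload", json={ "image": encoded })
print(resp.json())  # => {'prediction': 'circle'}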
Frontend App
Although this tutorial is mainly focused on Python, I'll try to add some comments about each step. For the frontend, we'll create a simple paint-like program that can set individual pixels, fill in a closed shape (a bucket tool), and upload the drawn image to the server.
Page Structure
Feel free to copy and paste the HTML into your own editor. Be sure to import the JavaScript and CSS files so we can add functionality and styling to our app.
web/templates/index.html
<!DOCTYPE html>
<html lang="en">
  <head>
    <meta charset="utf-8" />
    <meta http-equiv="X-UA-Compatible" content="IE=edge" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>Shapes Classifier</title>
    <!-- add styles & js (stored in web/static) -->
    <link rel="stylesheet" href="styles.css" />
    <script src="app.js" type="module"></script>
  </head>
  <body>
    <div class="app">
      <!-- we'll use this canvas element to draw pixels -->
      <canvas class="app__canvas"></canvas>
      <!-- all of the app options with inline SVGs as icons -->
      <div class="app__buttons">
        <button id="prediction" title="Prediction">
          <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="12" cy="12" r="10"></circle></svg>
        </button>
        <button title="Erase Canvas">
          <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"> <path d="M2.5 2v6h6M21.5 22v-6h-6"/><path d="M22 11.5A10 10 0 0 0 3.2 7.2M2 12.5a10 10 0 0 0 18.8 4.2"/></svg>
        </button>
        <button title="Pen Tool">
          <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="#000000" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><polygon points="16 3 21 8 8 21 3 21 3 16 16 3"></polygon></svg>
        </button>
        <button title="Bucket Tool">
          <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="currentColor" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="12" cy="12" r="10"></circle></svg>
        </button>
        <button title="Upload Image">
          <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M21.2 15c.7-1.2 1-2.5.7-3.9-.6-2-2.4-3.5-4.4-3.5h-1.2c-.7-3-3.2-5.2-6.2-5.6-3-.3-5.9 1.3-7.3 4-1.2 2.5-1 6.5.5 8.8m8.7-1.6V21"/><path d="M16 16l-4-4-4 4"/></svg>
        </button>
      </div>
    </div>
  </body>
</html>
I got a bit lazy with the user interface, so I just used the title attribute to label options.
Page Styles
I always start out by resetting the box-sizing to border-box to make CSS a bit more consistent.
web/static/styles.css
html, body {
  height: 100%;
  width: 100%;
  padding: 0;
  margin: 0;
  box-sizing: border-box;
}

*, ::before, ::after {
  box-sizing: inherit;
}
After that, I centered the drawing canvas & the buttons on the side.
.app {
  /* used to ensure the .app__canvas & .app__buttons are the same height */
  --canvas-size: 30rem;
  height: 100vh;
  width: 100vw;
  /* the centering magic */
  display: flex;
  justify-content: center;
  align-items: center;
  overflow: hidden;
  gap: 2rem;
}

.app__canvas {
  width: var(--canvas-size);
  height: var(--canvas-size);
  border-radius: 0.5rem;
  box-shadow: 0 0.5rem 1rem #eee;
}

.app__buttons {
  height: var(--canvas-size);
  display: flex;
  flex-direction: column;
  gap: 1rem;
}
The buttons still have their default styles, so let's change that by adding some rounded edges and a gray border.
.app__buttons button {
  align-items: flex-start;
  border-radius: 0.5rem;
  background: none;
  outline: none;
  border: 1px solid #eee;
  padding: 1rem;
  cursor: pointer;
  color: black;
  transition: border 0.2s ease;
}

.app__buttons button:hover {
  border: 1px solid dodgerblue;
}
Predictions will be shown in the top right corner as an icon of the result. To make the prediction stand out, we can give it a nice shade of blue as a background. You can also make a new element that explicitly spells out the result rather than showing an icon, but I wanted to keep the minimal design.
#prediction {
  background: dodgerblue;
  border: 1px solid dodgerblue;
  color: #fff;
  cursor: default;
}
We're also going to add a rotating loading animation with keyframes for when you're uploading an image. You won't be able to see this animation yet because we'll need to use JS to add the .loading class to an element first.
.loading {
  animation: loading 1.5s linear infinite;
}

@keyframes loading {
  from {
    transform: rotate(0deg);
  }
  to {
    transform: rotate(360deg);
  }
}
The UI isn't terrible, but it doesn't make a lot of sense either. Try adding in your own icons, styles, and HTML to make the app more understandable.
Page Functionality
At the moment you can't actually interact with the page. After we add JS though, the user will be able to draw on & clear the canvas with the "clear", "pen", and "bucket" tools. Again, you can (and should) definitely expand the tools I added to make things more user friendly.
At the top of the JS file I created an object containing some SVG icons, which will be used to show the user the predicted result.
web/static/app.js
const icons = {
  upload: `<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M21.2 15c.7-1.2 1-2.5.7-3.9-.6-2-2.4-3.5-4.4-3.5h-1.2c-.7-3-3.2-5.2-6.2-5.6-3-.3-5.9 1.3-7.3 4-1.2 2.5-1 6.5.5 8.8m8.7-1.6V21"/><path d="M16 16l-4-4-4 4"/></svg>`,
  // we add the loading class to this one so it spins when we add it to the DOM
  loading: `<svg class="loading" xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><line x1="12" y1="2" x2="12" y2="6"></line><line x1="12" y1="18" x2="12" y2="22"></line><line x1="4.93" y1="4.93" x2="7.76" y2="7.76"></line><line x1="16.24" y1="16.24" x2="19.07" y2="19.07"></line><line x1="2" y1="12" x2="6" y2="12"></line><line x1="18" y1="12" x2="22" y2="12"></line><line x1="4.93" y1="19.07" x2="7.76" y2="16.24"></line><line x1="16.24" y1="7.76" x2="19.07" y2="4.93"></line></svg>`,
  square: `<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="3" y="3" width="18" height="18" rx="2" ry="2"></rect></svg>`,
  triangle: `<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M3 20h18L12 4z"/></svg>`,
  star: `<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><polygon points="12 2 15.09 8.26 22 9.27 17 14.14 18.18 21.02 12 17.77 5.82 21.02 7 14.14 2 9.27 8.91 8.26 12 2"></polygon></svg>`,
  circle: `<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="12" cy="12" r="10"></circle></svg>`
}
Next, I selected the canvas and all of the buttons.
const canvas = document.querySelector(".app__canvas")
// used to draw stuff on the canvas
const ctx = canvas.getContext("2d")
// resize canvas to canvas dimensions specified in CSS
canvas.width = canvas.offsetWidth
canvas.height = canvas.offsetHeight
const [ prediction, refreshButton, penButton, fillButton, uploadButton ] = document.querySelectorAll(".app__buttons button")
We also need a way to store the canvas's state and configuration, such as how large each drawn pixel should be, the mouse position, whether or not we are uploading an image, and an array of all of the canvas's pixels.
const pixelSize = 10
const mouse = { x: 0, y: 0, down: false, tool: "pen" }
let loading = false
let image = []
You can get the mouse position relative to the actual canvas by taking the absolute mouse position (obtained through e.clientX and e.clientY) and subtracting the canvas's position relative to the viewport. We can also use Math.floor to snap each x & y coordinate to a grid, like pixel art.
const updateMousePosition = (e) => {
  const rect = canvas.getBoundingClientRect()
  mouse.x = Math.floor((e.clientX - rect.left) / pixelSize) * pixelSize
  mouse.y = Math.floor((e.clientY - rect.top) / pixelSize) * pixelSize
}
Adding event listeners will allow us to update the mouse's position every time we move the mouse or click on the canvas. I also included more event listeners to update the mouse's tool between "fill" and "pen".
// update mouse position
canvas.addEventListener("mousemove", updateMousePosition)
canvas.addEventListener("mouseup", () => { mouse.down = false })
// update mouse tool
penButton.addEventListener("click", () => { mouse.tool = "pen" })
fillButton.addEventListener("click", () => { mouse.tool = "fill" })
// reset array of pixels (we haven't actually drawn anything to the screen yet, hang tight!)
refreshButton.addEventListener("click", () => { image = [] })
In order to render pixels to the canvas, we need to create a render loop that continually refreshes to reflect new changes. requestAnimationFrame is commonly used to recursively call a render function that updates the canvas.
We can add pixels to the image array (if they don't already exist) while the user's mouse tool is set to "pen", and loop through the array every frame to render the pixels to the screen.
const render = () => {
  // clear canvas
  ctx.fillStyle = "#fff"
  ctx.fillRect(0, 0, canvas.width, canvas.height)
  // if mouse is down and the tool is a pen then append mouse coords to image array
  if(mouse.down && mouse.tool == "pen") {
    const pixel = { x: mouse.x, y: mouse.y }
    if(!image.find(p => p.x === pixel.x && p.y === pixel.y)) {
      image.push(pixel)
    }
  }
  // loop through image array to draw pixels
  for(const pixel of image) {
    ctx.fillStyle = "#000"
    ctx.fillRect(pixel.x, pixel.y, pixelSize, pixelSize)
  }
  window.requestAnimationFrame(render)
}
To kick off the animation loop, just call render once.
render()
I didn't add the fill tool inside of the render loop because I thought it would be computationally expensive to run the tool every frame while the mouse is down. We only need to trigger the tool once when the mouse is clicked!
canvas.addEventListener("mousedown", e => {
  updateMousePosition(e)
  mouse.down = true
  if(mouse.tool == "fill") {
    // fill tool code here
  }
})
Filling the interior of a closed shape and implementing the bucket tool is more complicated, but intuitively it sort of makes sense. Here's the idea:
1. Make a queue (an array) that stores the "jobs" we need to complete (adding each pixel within a closed shape to the image).
2. Loop through the queue:
   - If the pixel from the queue already exists in the image, don't do anything - this means we've hit an edge.
   - Otherwise, add the current pixel in the queue to the image array, then add all of the surrounding pixels to the queue.
I also added a variable called maxLoop to prevent the loop from going through the queue forever, which could happen if the shape we decided to fill wasn't actually closed. There are other ways to restrict this behavior (e.g. checking if a pixel has hit the edge of the canvas), but this seemed the easiest to me at the time.
// fill tool code here
const queue = [ { x: mouse.x, y: mouse.y } ]
let maxLoop = 7000
while(queue.length > 0 && maxLoop > 0) {
  const pixel = queue.shift()
  if(!image.find(p => p.x === pixel.x && p.y === pixel.y)) {
    image.push(pixel)
    queue.push({ x: pixel.x - pixelSize, y: pixel.y })
    queue.push({ x: pixel.x + pixelSize, y: pixel.y })
    queue.push({ x: pixel.x, y: pixel.y - pixelSize })
    queue.push({ x: pixel.x, y: pixel.y + pixelSize })
  }
  maxLoop--
}
I've saved the best for last: uploading the canvas to the Flask web server to make predictions. Using the fetch API, we can send a POST request with a base64 encoded image to the server for processing.
uploadButton.addEventListener("click", async () => {
  // only upload an image if we aren't already uploading one
  if(!loading) {
    // set the upload option to a loading indicator (it'll spin around, yay)
    uploadButton.innerHTML = icons.loading
    loading = true
    // POST canvas to server
    const json = await fetch("/api/upload", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({ image: canvas.toDataURL("image/png").split(',').pop() })
    }).then(res => res.json())
    // display prediction to the user
    prediction.innerHTML = icons[json.prediction]
    // reset loading & upload button
    uploadButton.innerHTML = icons.upload
    loading = false
  }
})
Closing
A huge thank you if you've made it to the end. 🥳 This was a massive tutorial, so if you find a typo or something that doesn't work, make sure to tell me so I can fix it!
Hopefully this gave you a brief introduction to machine learning with Sklearn. I look forward to seeing what you do with this project. If you want a challenge, I recommend making the frontend user interface more friendly and trying Sklearn on more Kaggle datasets.