Ayub Ali Emon

A Guide to Connect Machine Learning Model Backend to Frontend Using Flask

Introduction

Many of us work with machine learning models for various purposes, often in Jupyter Notebook, Google Colab, or other environments. At some point, we may wonder how we can save a trained model and make predictions based on user input directly from our website or app. Many different methods are available for this. In this article, I will show a simple approach: saving a trained model as a file and using that model file to predict on actual input data from users and show the prediction results.
Here, I have trained a basic Linear Regression model from the scikit-learn library on the mtcars dataset to predict Miles Per Gallon (mpg) based on Horsepower (hp) and Weight (wt). I also created a simple web app using Flask, jQuery, and Bootstrap 5 to collect and process user input and generate predictions. You don't need to worry about the technical details; you can easily try it out by cloning/downloading my GitHub repository.

All the resource files can be found on this GitHub Repository:
github.com/alfa-echo-niner-ait/ml-predict-app

I have also hosted our sample web app online. Feel free to try it out at this link:
Simple Flask Web App to Predict From User Input

Requirements

To try out this tutorial, you need to have Python installed on your local machine and a text editor or IDE that supports Python and notebook (.ipynb) files. Here, I will be using my all-time favourite, Visual Studio Code, as I can simply install a few extensions to continue my work.
Now you can either clone my repository or download the source code directly as a zip file from GitHub and extract it to your local storage. To clone the repository, create a new folder, open a terminal in that folder, and run the git clone command.

git clone https://github.com/alfa-echo-niner-ait/ml-predict-app.git

After cloning/downloading the source code, you need to install the necessary packages. Here we will mainly need the following packages:

  • flask
  • pandas
  • scikit-learn
  • joblib
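
You can install these packages individually with pip:

pip install flask pandas scikit-learn joblib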

Or you can just run the following command in the folder where you have cloned/downloaded the source files:

pip install -r requirements.txt

However, it is good practice to create a virtual environment first when working on larger projects. Since our project is small enough, we are good to go as is.
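
If you would rather set one up anyway, here is a minimal sketch (the exact activation command varies by platform):

# Create and activate a virtual environment, then install the requirements
python -m venv venv
source venv/bin/activate   # On Windows: venv\Scripts\activate
pip install -r requirements.txt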

Processing

Train Model

For this tutorial, I've used a simple car dataset. Here is an overview of the dataset:

Dataset Overview

We will be predicting mpg based on hp and wt, so we need these three columns from our dataset. I will assign them accordingly and split the data into training and testing sets.

from sklearn.model_selection import train_test_split

# Select the feature columns (hp, wt) and the target column (mpg)
X = data[['hp', 'wt']]
Y = data['mpg']

# Split into training and testing sets (80% train, 20% test)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=11)

If you are planning to use your model to predict on user input data, pay attention to the training process, especially to picking the right columns. Including unnecessary or excessive columns affects the model's performance. Make sure which columns you really need and exclude the rest. Also, choose the right model for your dataset.

Here I have used the LinearRegression model from the scikit-learn library for training and testing.

from sklearn.linear_model import LinearRegression
lr_model = LinearRegression()
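
The notebook then fits the model on the training split and evaluates it on the test split; a minimal sketch of those steps looks like this:

# Fit the model on the training data
lr_model.fit(X_train, Y_train)

# Evaluate on the held-out test data (R^2 score)
accuracy = lr_model.score(X_test, Y_test)
print(f"Model accuracy: {accuracy:.2%}")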

You can check out the full model training and testing process, which is included with the resource files, in /Model/Model_train.ipynb.

Save Model

After we finish training, testing, and evaluating the accuracy and other necessary metrics, and are satisfied with our model's performance, we are ready to save the trained model. My model achieved an accuracy of 88.48%, which is good enough.

Model Accuracy

Here, we use the joblib library to save our model with the joblib.dump() function. To save the model, you pass your trained model (in our case lr_model) and a filename with an extension. The model will then be saved in the root directory of our notebook file under the given filename (model_filename).

If model_filename ends with one of the supported extensions ('.z', '.gz', '.bz2', '.xz' or '.lzma'), the corresponding compression method will be used automatically.

import joblib

# Define the save model name
model_filename = 'car_model.joblib'
# Save the model
joblib.dump(lr_model, model_filename)

The joblib.dump() function can receive the following attributes:

joblib.dump(value, filename, compress=0, protocol=None, cache_size=None)

protocol and cache_size are optional. compress accepts an int from 0 to 9, a bool, or a 2-tuple, or you can simply leave it at its default.
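
For example, to save a compressed copy of the model (an illustrative filename, not the one used in the repository):

# Save a compressed copy of the model with compression level 3
joblib.dump(lr_model, 'car_model_compressed.joblib', compress=3)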

Load Model

Now that we have saved our model, it's time to load it and make predictions. We can load the saved model simply by using the joblib.load() function.

loaded_model = joblib.load(model_filename)

This function has two parameters: filename (str, pathlib.Path, or file object) and mmap_mode ({None, 'r+', 'r', 'w+', 'c'}), which is optional.

WARNING: joblib.load relies on the pickle module and can therefore execute arbitrary Python code. It should therefore never be used to load files from untrusted sources.

joblib.load(filename, mmap_mode=None)
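
For example, if your saved file contains large numpy arrays, you can memory-map them instead of loading everything into memory (not needed for our small model, shown only as an illustration):

loaded_model = joblib.load(model_filename, mmap_mode='r')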

Predict With Loaded Model

To predict with your loaded model, you have to pass data in exactly the same format as your training data. You can check the training and testing data format by printing their types, or look back at your training process to find the format.

Check Data Type
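
A quick way to check this in the notebook:

print(type(X_train))   # <class 'pandas.core.frame.DataFrame'>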

For example, here we have the hp and wt values separately, convert them into a pandas.core.frame.DataFrame, and then use the DataFrame for prediction.

import pandas as pd

# Input data
hp = 110
wt = 3.2

# Convert accordingly
data = {
    "hp": [hp],
    "wt": [wt]
}
test_data = pd.DataFrame(data)

# Predict
predicted_mpg = loaded_model.predict(test_data)

# Print Result
print(f"Result: {predicted_mpg[0]:.2f} mpg")

Working with Flask

Load Model in the App

If you have already cloned/downloaded the source code from my GitHub repository, you might have noticed that we load our model file when we initialize the app in /source/__init__.py; this prevents reloading the model again and again. In MODEL_NAME, put the filename of your saved model. Here, I have placed my model file inside the /source/static/model folder. If you place your model in a different folder, remember to modify model_path.

import os
import joblib
from flask import Flask

MODEL_NAME = "car_model.joblib"

# Initialize the app
app = Flask(__name__)

# Load the model when the application starts
model_path = os.path.join(os.path.dirname(__file__), 'static/model/' + MODEL_NAME)
model = joblib.load(model_path)

User Input

In our /source/templates/index.html, we have a simple form to take hp and wt input from the user.

<!-- User Input -->
<input type="text" id="hp" class="form-control-lg" placeholder="Enter HP, e.g. 112" required>
<input type="text" id="wt" class="form-control-lg" placeholder="Enter WT, e.g. 3.2" required>
<!-- Process input/output via main.js file -->
<button class="btn btn-lg btn-info" type="button" id="submitData">Get MPG</button>

Send Data to Flask Server

When the user submits the data, we read the form data, send it to our Flask server, and get a response using an AJAX request. The code can be found in our /source/static/js/main.js file. However, you can use any other method you like to exchange request/response data with your server. Here's how I did it:

$(document).ready(function () {
    // Send data to flask server upon submit
    $("#submitData").click(function () {
        $("#resText").show()
        $("#resText").html("Loading....");
        // Read and then clear the form fields
        let hp = $("#hp").val();
        let wt = $("#wt").val();
        $("#hp").val("");
        $("#wt").val("");

        // Data request
        $.ajax({
            type: "POST",
            url: predict_url,
            data: {
                hp: hp,
                wt: wt
            },
            // Print result
            success: function (response) {
                $("#resText").html(response);
            }
        });
    });
});

Process, Predict and Return Result

We have defined a /predict route in /source/routes.py of our Flask app to read the data from the AJAX request, process it, make the prediction, and return the result.

# Note: this snippet assumes that pandas (as pd), Flask's request, and the app and model objects from __init__.py are imported at the top of routes.py

# Process user input and return response data
@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST' and model:
        # Receive & store the form input data by user
        hp = request.form.get('hp')
        wt = request.form.get('wt')

        # Prepare received data for test data according to your trained model
        # Note: Prepare your user input as per your trained model prediction data
        data = {
            "hp": [hp],
            "wt": [wt]
        }
        test_data = pd.DataFrame(data)
        try:
            # Calculate the prediction result as per your trained model
            predict_mpg = model.predict(test_data)
            result = f"{predict_mpg[0]:.2f}"
            return f"Predicted : <strong>{result} MPG</strong>"
        except ValueError:
            return "<i class='text-danger'>Please Enter Correct Data!</i>"

    return "Failed to load model!"
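
Once the app is running, you can also test the /predict route without the frontend by using the requests library (not part of the project requirements; shown only as a quick check, assuming the Flask development server is running on its default port):

import requests

# Send the same form fields the AJAX request would send
response = requests.post("http://127.0.0.1:5000/predict", data={"hp": 110, "wt": 3.2})
print(response.text)  # e.g. "Predicted : <strong>... MPG</strong>"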

Here we process the user data the same way as shown in the "Predict With Loaded Model" section.

Conclusion

With this tutorial, I have tried to demonstrate some basic ideas for making a machine-learning model interact with users. These ideas are not limited to this setup; you can experiment with different models and different frameworks like Django, FastAPI, etc., or build more complex applications that not only predict on user input but also concurrently feed the model live user data and predict various aspects to improve user experience and provide better services.

And you're welcome to explore the internet for more ideas on this topic and build a strong skill set. I have added links to the technologies discussed throughout this tutorial in the Reference Links section.

Reference Links
