DEV Community

paulsaul621

Build a Machine Learning API With Django Rest Framework

Introduction

Finding the proper treatment, especially drug treatment, can be a difficult task. Whether you're looking for drugs to treat seasonal allergies or treating a serious illness, collecting data on patients and the treatments they respond to best can be vital for saving lives and increasing patient convenience. Here, we'll act as a data analyst compiling medical data on a set of patients all suffering from the same illness. During their treatment, each patient responded to one of five drug treatments.

In this article, we'll:

  • Build a simple Decision Tree model to predict Drug Treatment with relatively high accuracy.
  • Create an API with Django Rest Framework that will be used to consume the Decision Tree Model created.

Build a Model

Let us start by building and saving a machine learning model that will be later used to make predictions for our API.

Dataset

Find the dataset for this model on my GitHub repo. The file is a CSV with data from different patients who received treatment for the same disease. The drug each patient responded to is one of drugA, drugB, drugC, drugX, or drugY.

Let’s get started with the exploration of the dataset.

The first thing that should be done is to import dependencies. If you do not have the libraries installed, kindly do so before proceeding.

import pandas
from sklearn import tree
import pydotplus
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt
import matplotlib.image as pltimg
# plot_confusion_matrix and plot_roc_curve were removed in scikit-learn 1.2
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import joblib

Next, we load the dataset using pandas:

df = pandas.read_csv("https://raw.githubusercontent.com/paulwababu/datasets/main/drug200.csv")
print(df.head())
Output:
   Age Sex      BP Cholesterol  Na_to_K   Drug
0   23   F    HIGH        HIGH   25.355  drugY
1   47   M     LOW        HIGH   13.093  drugC
2   47   M     LOW        HIGH   10.114  drugC
3   28   F  NORMAL        HIGH    7.798  drugX
4   61   F     LOW        HIGH   18.043  drugY

To make a decision tree, all data has to be numerical.

We have to convert the non-numerical columns 'Sex', 'Cholesterol', and 'BP' into numerical values.

A pandas Series has a map() method that takes a dictionary with information on how to convert the values.

#Change string values into numerical values:
d = {'M': 1, 'F': 0}
df['Sex'] = df['Sex'].map(d)

d = {'HIGH': 0, 'LOW': 1, 'NORMAL': 2}
df['BP'] = df['BP'].map(d)

d = {'HIGH': 0, 'LOW': 1, 'NORMAL': 2}
df['Cholesterol'] = df['Cholesterol'].map(d)
print(df.head())
Output:
   Age  Sex  BP  Cholesterol  Na_to_K   Drug
0   23    0   0            0   25.355  drugY
1   47    1   1            0   13.093  drugC
2   47    1   1            0   10.114  drugC
3   28    0   2            0    7.798  drugX
4   61    0   1            0   18.043  drugY
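One thing worth knowing about map(): any value that is not a key in the dictionary becomes NaN instead of raising an error. The sketch below is not part of the original walkthrough. It uses a few made-up rows (including a deliberate lowercase typo) to show how such values could be spotted before training.

```python
import pandas as pd

# Made-up rows for illustration; "high" is a deliberate lowercase typo.
sample = pd.DataFrame({"BP": ["HIGH", "LOW", "NORMAL", "high"]})
bp_codes = {"HIGH": 0, "LOW": 1, "NORMAL": 2}

encoded = sample["BP"].map(bp_codes)

# Values that are not keys of the dictionary come back as NaN.
unmapped = sample.loc[encoded.isna(), "BP"].unique().tolist()
print(unmapped)
```

If the printed list is not empty, the dictionary is missing a category (or the data needs cleaning) and the model would otherwise be trained on NaN values.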

Then we have to separate the feature columns (independent variables) from the target column (dependent variable).

The feature columns are the columns that we try to predict from, and the target column is the column with the values we try to predict.

#independent variables
features = ['Age', 'Sex', 'BP', 'Cholesterol', 'Na_to_K']
X = df[features]

#dependent variable
y = df['Drug']

print(X)
print(y)
Output:
     Age  Sex  BP  Cholesterol  Na_to_K
0     23    0   0            0   25.355
1     47    1   1            0   13.093
2     47    1   1            0   10.114
3     28    0   2            0    7.798
4     61    0   1            0   18.043
..   ...  ...  ..          ...      ...
195   56    0   1            0   11.567
196   16    1   1            0   12.006
197   52    1   2            0    9.894
198   23    1   2            2   14.020
199   40    0   1            2   11.349

[200 rows x 5 columns]
0      drugY
1      drugC
2      drugC
3      drugX
4      drugY
       ...  
195    drugC
196    drugC
197    drugX
198    drugX
199    drugX
Name: Drug, Length: 200, dtype: object

Now we can create the actual decision tree, fit it to our data, and save a .png image of it on the computer:

dtree = DecisionTreeClassifier()
dtree = dtree.fit(X, y)
data = tree.export_graphviz(dtree, out_file=None, feature_names=features)
graph = pydotplus.graph_from_dot_data(data)
graph.write_png('mydecisiontree2.png')

img=pltimg.imread('mydecisiontree2.png')
imgplot = plt.imshow(img)
plt.show()
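The tree above is fitted on all 200 rows, so on its own it says nothing about accuracy on unseen patients. A common way to estimate that is to hold out part of the data for testing. The sketch below is an addition, not the author's code: to stay self-contained it builds a synthetic stand-in DataFrame (the rows and the labelling rule are invented), but the same split and scoring steps should apply to the real encoded df.

```python
import random

import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for the encoded DataFrame; the labelling rule is invented.
rng = random.Random(0)
rows = []
for _ in range(200):
    bp = rng.choice([0, 1, 2])
    na_to_k = round(rng.uniform(6.0, 38.0), 3)
    drug = "drugY" if na_to_k > 15 else ("drugA" if bp == 0 else "drugX")
    rows.append({"Age": rng.randint(15, 74), "Sex": rng.choice([0, 1]), "BP": bp,
                 "Cholesterol": rng.choice([0, 2]), "Na_to_K": na_to_k, "Drug": drug})
df_demo = pd.DataFrame(rows)

features = ["Age", "Sex", "BP", "Cholesterol", "Na_to_K"]
# Keep 25% of the rows aside; stratify so every drug appears in both parts.
X_train, X_test, y_train, y_test = train_test_split(
    df_demo[features], df_demo["Drug"],
    test_size=0.25, random_state=0, stratify=df_demo["Drug"],
)

model = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
accuracy = accuracy_score(y_test, model.predict(X_test))
print(f"Hold-out accuracy: {accuracy:.2f}")
```

With the real dataset, replacing df_demo with the encoded df would give an accuracy figure to back up the "relatively high accuracy" claim.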

Results Explained

The decision tree uses the earlier treatment records in the dataset to work out which drug a new patient is most likely to respond to. Let us read the different aspects of the decision tree:

[Image: mydecisiontree2.png, the rendered decision tree]

Sodium to Potassium ratio (Na_to_K)

Na_to_K <= 14.829 means that every patient with a sodium-to-potassium ratio of 14.829 or lower will follow the True arrow (to the left), and the rest will follow the False arrow (to the right).

gini = 0.694 is the Gini impurity of the samples at this node, a measure of how mixed the classes are: 0.0 would mean all of the samples got the same result, and with five drugs the maximum is 0.8, reached when the samples are spread evenly across all five.

samples = 200 means that there are 200 patients left at this point in the decision, which is all of them since this is the first step.

value = [23, 16, 16, 54, 91] means that of these 200 patients, 23 responded to "drugA", 16 to "drugB", 16 to "drugC", 54 to "drugX", and 91 to "drugY". The counts are listed in alphabetical order of the drug names.

Note, to know these values, we ran:

print(df['Drug'].value_counts())
Output:
drugY    91
drugX    54
drugA    23
drugB    16
drugC    16
Name: Drug, dtype: int64
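To see where the gini value at the root node comes from, the class counts above can be plugged into the Gini impurity formula, 1 minus the sum of the squared class proportions. This is a small illustrative sketch; the function name gini_impurity is my own and not part of scikit-learn.

```python
def gini_impurity(class_counts):
    """Return 1 - sum(p_i^2) for the class proportions p_i at a node."""
    total = sum(class_counts)
    return 1.0 - sum((count / total) ** 2 for count in class_counts)


root_counts = [23, 16, 16, 54, 91]  # drugA, drugB, drugC, drugX, drugY
print(round(gini_impurity(root_counts), 3))  # 0.694
```

The result matches the gini = 0.694 shown in the root node of the tree.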

Predict Values

We can use the Decision Tree to predict new values.

Example: Which of the five drugs (Y, X, A, B, and C) should be prescribed to a 68-year-old female with normal cholesterol and BP levels and a sodium-to-potassium ratio of 27.10?

print(dtree.predict([[68, 0, 2, 2, 27.10]]))
# ['drugY'] will be the output

The system predicts drugY. If you follow the decision tree above, you will arrive at drugY too! I cannot explain every aspect of the decision tree here; it is quite lengthy and outside the scope of this tutorial. Let us now proceed to saving the model we built:

# Save the model
joblib_file = "DecisionTreeModel.joblib"
joblib.dump(dtree, joblib_file) 
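Before moving the file into a web project, it can be reassuring to check that the saved model loads back and predicts the same thing. The sketch below is an addition under stated assumptions: it trains a throwaway tree on four made-up rows and writes to a temporary directory so it can run on its own. With the real model you would simply call joblib.load on DecisionTreeModel.joblib.

```python
import os
import tempfile

import joblib
from sklearn.tree import DecisionTreeClassifier

# Tiny made-up training set (Age, Sex, BP, Cholesterol, Na_to_K), for illustration only.
X_demo = [[23, 0, 0, 0, 25.4], [47, 1, 1, 0, 13.1], [28, 0, 2, 0, 7.8], [61, 0, 1, 0, 18.0]]
y_demo = ["drugY", "drugC", "drugX", "drugY"]
model = DecisionTreeClassifier(random_state=0).fit(X_demo, y_demo)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "DecisionTreeModel.joblib")
    joblib.dump(model, path)       # save
    restored = joblib.load(path)   # load back

patient = [[68, 0, 2, 2, 27.1]]
same = list(model.predict(patient)) == list(restored.predict(patient))
print(same)
```

Note that a joblib file is generally tied to the scikit-learn version that produced it, so it is safest to install the same version in the Django environment.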

Now we can start the Django section.

Create an API With Django Rest Framework

Following Python best practices, we will create a virtual environment for our project, and install the required packages.

First, create the project directory.

$ mkdir djangoapp
$ cd djangoapp


Now, create a virtual environment and install the required packages.

For macOS and Unix systems:

$ python3 -m venv myenv
$ source myenv/bin/activate
(myenv) $ pip install django requests djangorestframework numpy joblib scikit-learn


For Windows:

$ python -m venv myenv
$ myenv\Scripts\activate
(myenv) $ pip install django requests djangorestframework numpy joblib scikit-learn

Setting Up Your Django Application

First, navigate to the directory djangoapp we created and establish a Django project.

(myenv) $ django-admin startproject mainapp


This will auto-generate some files for your project skeleton:

mainapp/
    manage.py
    mainapp/
        __init__.py
        settings.py
        urls.py
        asgi.py
        wsgi.py

Now, navigate to the directory you just created (make sure you are in the same directory as manage.py) and create your app directory.

(myenv) $ python manage.py startapp monitor

This will create the following:

monitor/
    __init__.py
    admin.py
    apps.py
    migrations/
        __init__.py
    models.py
    tests.py
    views.py

In the mainapp/settings.py file, find the INSTALLED_APPS list and add rest_framework and the app we just created.

INSTALLED_APPS = [
    'django.contrib.admin',
    'django.contrib.auth',
    'django.contrib.contenttypes',
    'django.contrib.sessions',
    'django.contrib.messages',
    'django.contrib.staticfiles',
    'rest_framework',#new line
    'monitor', #new line
]


Ensure you are in the monitor directory, then create a new directory called templates and a new file called urls.py. The directory structure of the monitor application should look like this:

monitor/
    __init__.py
    admin.py
    apps.py
    migrations/
        __init__.py
    templates/
    models.py
    tests.py
    urls.py
    views.py

In your mainapp/urls.py file, include the monitor app's URLs, which we shall create next:

from django.contrib import admin
from django.urls import path, include

urlpatterns = [
    #path('admin/', admin.site.urls),
    path('', include('monitor.urls')),#monitor app url
]


Now, in the monitor/urls.py file, add the route for our prediction view:

from django.urls import path
from .views import Prediction

urlpatterns = [
    path('', Prediction.as_view(), name = 'prediction'),
]

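Let’s create another directory, in the same directory as manage.py, to store our machine learning model, and copy the DecisionTreeModel.joblib file we saved earlier into ml/models. I’ll also add the dataset to the project for those who want access to the whole dataset. (It is not compulsory to create a data folder.)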

(myenv) $ mkdir ml
(myenv) $ mkdir ml/models
(myenv) $ mkdir ml/data

We also need to tell Django where our machine learning model is located. Add these lines to the mainapp/settings.py file, below the line that defines BASE_DIR:

import os
MODELS = os.path.join(BASE_DIR, 'ml/models')

Load Model through apps.py

Load your machine learning models in apps.py so that when the application starts, the trained model is loaded only once. Otherwise, the trained model is loaded each time an endpoint is called, and then the response time will be slower.

Let’s update apps.py

import os
import joblib
from django.apps import AppConfig
from django.conf import settings


class ApiConfig(AppConfig):
    name = 'monitor'  # must match the app package listed in INSTALLED_APPS
    MODEL_FILE = os.path.join(settings.MODELS, "DecisionTreeModel.joblib")
    model = joblib.load(MODEL_FILE)
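The load-once idea can also be expressed with a cached loader function, which delays loading until the first request instead of at import time. The sketch below is an alternative illustration, not the approach used in this article. load_model_from_disk is a hypothetical stand-in for joblib.load so the example runs without a model file.

```python
from functools import lru_cache

load_calls = 0  # counts how often the expensive load runs


def load_model_from_disk(path):
    """Hypothetical stand-in for joblib.load(path)."""
    global load_calls
    load_calls += 1
    return {"loaded_from": path}


@lru_cache(maxsize=1)
def get_model(path="ml/models/DecisionTreeModel.joblib"):
    # The body runs on the first call only; later calls reuse the cached object.
    return load_model_from_disk(path)


first = get_model()
second = get_model()
print(first is second, load_calls)
```

Both calls return the same object and the loader runs once, which is the same effect the class attribute in apps.py achieves.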

Edit views.py

The last step is to update views.py. The views will be mainly responsible for two tasks:

  • Process incoming POST requests.
  • Make a prediction with the incoming data and give the result as a Response.

from rest_framework.views import APIView
from rest_framework.response import Response

from .apps import ApiConfig

# Query parameter names, in the same order as the training columns:
# Age, Sex, BP, Cholesterol, Na_to_K
PARAM_NAMES = ('age', 'gender', 'bp', 'cholesterol', 'salt')


class Prediction(APIView):
    def post(self, request):
        try:
            # Query parameters arrive as strings, so convert them to numbers
            row = [float(request.query_params[name]) for name in PARAM_NAMES]
        except (KeyError, ValueError):
            return Response({"error": "age, gender, bp, cholesterol and salt must be numbers"}, status=400)
        dtree = ApiConfig.model
        # predict using independent variables
        prediction_made = dtree.predict([row])
        # tolist() turns the NumPy array into a plain list for the JSON response
        response_dict = {"Predicted drug": prediction_made.tolist()}
        print(response_dict)
        return Response(response_dict, status=200)

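The endpoint expects the numeric codes used during training (for example 0 for female and 2 for normal), which callers have to know in advance. One possible refinement, sketched here as an assumption rather than as part of the original API, is a small helper that accepts the readable labels and reuses the same dictionaries from the training step. The name encode_patient is hypothetical.

```python
SEX_CODES = {"M": 1, "F": 0}
LEVEL_CODES = {"HIGH": 0, "LOW": 1, "NORMAL": 2}  # shared by BP and Cholesterol


def encode_patient(age, sex, bp, cholesterol, na_to_k):
    """Return one feature row in training order: Age, Sex, BP, Cholesterol, Na_to_K."""
    try:
        return [
            float(age),
            SEX_CODES[sex.upper()],
            LEVEL_CODES[bp.upper()],
            LEVEL_CODES[cholesterol.upper()],
            float(na_to_k),
        ]
    except KeyError as exc:
        # Re-raise with a clearer message so a view can turn it into a 400 response.
        raise ValueError(f"Unknown category: {exc.args[0]}") from exc


print(encode_patient("68", "F", "normal", "normal", "27.10"))
```

The returned row could then be passed to dtree.predict([row]) inside the view.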

Test the API

To test our API, start the development server with python manage.py runserver and then send it a request. I’ll use the requests library, which is the de facto standard for making HTTP requests in Python. You can also use Postman or Insomnia for making the requests.

Let’s predict which of the five drugs (Y, X, A, B, and C) should be prescribed to a 68-year-old female with normal cholesterol and BP levels and a sodium-to-potassium ratio of 27.

import requests

# Send the patient's values in the query string and print the decoded JSON reply
response = requests.post("http://127.0.0.1:8000/?age=68&gender=0&bp=2&cholesterol=2&salt=27")
print(response.json())
Output:
{'Predicted drug': ['drugY']}
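Hand-writing the query string gets error-prone as the number of parameters grows. As a small sketch (an addition to the article), the same URL can be built from a dictionary with the standard library's urlencode:

```python
from urllib.parse import urlencode

# Same patient as in the request above
patient = {"age": 68, "gender": 0, "bp": 2, "cholesterol": 2, "salt": 27}
url = "http://127.0.0.1:8000/?" + urlencode(patient)
print(url)
```

The requests library can do this step for you as well: passing params=patient to requests.post appends the encoded query string to the URL.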

Thanks for staying tuned! In case of any questions, feel free to ask in the comment section below.

Top comments (7)

dev-svk-flbs

Thanks for this great article...

Typically, if the inference model is light, the endpoint will return results fast. But in my case there is no way to reduce the inference time below 1 minute. In this case, the whole endpoint will be locked and won't be able to accept any further inference requests. I am also guessing the request may even time out... This is a very common situation with ML inferencing, so I am wondering whether there is an easy way to handle it.

I did some reading on DRF async but am not sure if that's the right approach.

Can you please provide some directions?

paulsaul621

Thank you! Glad you enjoyed the post. There are a number of approaches you can take, though I have not implemented them myself, apart from the first one.
1.) Implement async views with a task queue like Celery.
2.) Another approach is to use a service such as AWS Lambda, which allows you to run code without provisioning or managing servers. This can make it easier to scale your application and handle high levels of concurrency.
3.) Another option is to use a platform like Kubernetes, whose autoscaling capabilities can help you scale your application based on the number of requests.

Rukshar Alam

Would be great if you please make similar posts on deploying ML models using Celery and Kubernetes.

Marjan Ferdousi

Very well written! Love it.

paulsaul621

Thank you! Glad you liked it.

Amohammadi2

You can also create a web interface on top of it as the front-end

paulsaul621

Definitely! I'm making a post on how to do that next!