DEV Community



Code a Neural Network from scratch to solve the binary MNIST problem


This article develops a 3-layer Neural Network (NN) from scratch (i.e., using only NumPy) to solve the binary MNIST problem. The project is a practical guide to the foundations of deep learning and the architecture of neural networks, and it focuses on building the network from the ground up, that is, on the mathematics running under the hood of an NN. It extends an earlier project titled "Code a 2-layer Neural Network from Scratch," where I explained in detail the mathematics behind the scenes of a neural network (see that article for more details). In other words, binary MNIST serves here as a use case for a neural network built from scratch.

Load MNIST dataset

Once the helper files are available in AWS SageMaker, we import the pre-defined functions from utils_data to load the MNIST dataset.

from utils_data import *

This experiment handles binary MNIST only, so we need a function that keeps only the images of the digits 0 and 1 (the digits 2 to 9 are out of scope in this example).
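The filtering step can be done with a boolean mask over the labels. The helpers below are hypothetical stand-ins (the actual load_binary_mnist() lives in utils_data and may work differently); they assume images of shape (m, 28, 28) and integer labels of shape (m,).

```python
import numpy as np

def filter_binary_digits(X, Y, digits=(0, 1)):
    """Keep only the samples whose label is one of `digits` (hypothetical helper)."""
    mask = np.isin(Y, digits)  # True where the label is 0 or 1
    return X[mask], Y[mask]

def load_binary_mnist_sketch(X_train, Y_train, X_test, Y_test):
    """Apply the same 0/1 filter to the train and test splits."""
    X_train_b, Y_train_b = filter_binary_digits(X_train, Y_train)
    X_test_b, Y_test_b = filter_binary_digits(X_test, Y_test)
    return X_train_b, Y_train_b, X_test_b, Y_test_b
```

Because the mask is computed from the labels, images and labels stay aligned after filtering.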

We apply load_mnist() to load the full dataset and then load_binary_mnist() to keep only the digits 0 and 1.

X_train_org, Y_train_org, X_test_org, Y_test_org = load_mnist()
X_train_org, Y_train_org, X_test_org, Y_test_org = load_binary_mnist(X_train_org, Y_train_org, X_test_org, Y_test_org)

Data visualization

visualize_multi_images(X_train_org, Y_train_org, layout=(3, 3), figsize=(10, 10), fontsize=12)


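I store the data in an AWS S3 bucket so it can be reused later if needed. The object key and its S3 URL are defined as follows:

# BUCKET_NAME: name of the S3 bucket (defined elsewhere in the notebook)
key = "data/mnist.npz"
bucket_url = f"s3://{BUCKET_NAME}/{key}"  # full S3 URL of the stored dataset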
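The snippet only builds the S3 URL. One possible way to actually write the arrays to that location is to bundle them into a single .npz file with NumPy and upload it with boto3's upload_file(). The function names below are hypothetical, and calling upload_to_s3() requires boto3 and AWS credentials (for example, the SageMaker execution role).

```python
import numpy as np

def save_mnist_npz(path, X_train, Y_train, X_test, Y_test):
    """Bundle the four arrays into one compressed .npz file (hypothetical helper)."""
    np.savez_compressed(path, X_train=X_train, Y_train=Y_train,
                        X_test=X_test, Y_test=Y_test)

def upload_to_s3(local_path, bucket_name, key):
    """Upload a local file to s3://bucket_name/key and return that URL."""
    import boto3  # imported lazily so saving works even without boto3 installed
    boto3.client("s3").upload_file(local_path, bucket_name, key)
    return "s3://{}/{}".format(bucket_name, key)

# Example usage in the notebook, with BUCKET_NAME and key as defined earlier:
# save_mnist_npz("mnist.npz", X_train_org, Y_train_org, X_test_org, Y_test_org)
# bucket_url = upload_to_s3("mnist.npz", BUCKET_NAME, key)
```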

Data preparation for model training

We apply the helper function make_inputs() to prepare the binary MNIST data for training the Neural Network.

X_train, X_test, Y_train, Y_test = make_inputs(X_train_org, X_test_org, Y_train_org, Y_test_org)
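make_inputs() itself is not shown in the article. A common preparation for this kind of network, sketched here as an assumption rather than the author's exact helper, flattens each 28x28 image into a 784-value column (matching the 784 input units of the network), scales pixel values from 0-255 to [0, 1], and reshapes the labels into a (1, m) row vector.

```python
import numpy as np

def make_inputs_sketch(X_train_org, X_test_org, Y_train_org, Y_test_org):
    """Hypothetical version of make_inputs().

    Assumes images of shape (m, 28, 28) with pixel values 0-255 and labels of
    shape (m,). Returns X with shape (784, m) and Y with shape (1, m), i.e.
    one column per example.
    """
    def flatten_and_scale(X):
        m = X.shape[0]
        # (m, 28, 28) -> (m, 784) -> (784, m), then scale to [0, 1]
        return X.reshape(m, -1).T.astype(np.float64) / 255.0

    def as_row(Y):
        return Y.reshape(1, -1).astype(np.float64)  # (m,) -> (1, m)

    return (flatten_and_scale(X_train_org), flatten_and_scale(X_test_org),
            as_row(Y_train_org), as_row(Y_test_org))
```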

Build a Neural Network for solving binary MNIST

To build the Neural Network, we first define helper functions that serve as the building blocks of its architecture. I do not list those functions here because they would make this article unnecessarily long; I only present the construction of the Neural Network. For details on the helper functions and the components of nn_Llayers_binary(), please see this repository or refer to the earlier article.
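As a rough illustration of what such building blocks usually look like, here is a hypothetical NumPy-only sketch (not the code in utils_binary): parameter initialisation, forward propagation, a binary cross-entropy cost, backpropagation, and a gradient-descent update. The specific choices (He initialisation, ReLU hidden layers, a single sigmoid output unit, full-batch gradient descent) are assumptions made for this sketch.

```python
import numpy as np

def init_params(layer_dims, seed=0):
    """He initialisation: W[l] has shape (n_l, n_{l-1}), b[l] has shape (n_l, 1)."""
    rng = np.random.default_rng(seed)
    params = {}
    for l in range(1, len(layer_dims)):
        scale = np.sqrt(2.0 / layer_dims[l - 1])
        params[f"W{l}"] = rng.standard_normal((layer_dims[l], layer_dims[l - 1])) * scale
        params[f"b{l}"] = np.zeros((layer_dims[l], 1))
    return params

def sigmoid(Z):
    return 1.0 / (1.0 + np.exp(-Z))

def forward(X, params):
    """ReLU in hidden layers, sigmoid in the output layer; caches (A_prev, Z) per layer."""
    L = len(params) // 2
    A, caches = X, []
    for l in range(1, L + 1):
        A_prev = A
        Z = params[f"W{l}"] @ A_prev + params[f"b{l}"]
        A = sigmoid(Z) if l == L else np.maximum(0.0, Z)
        caches.append((A_prev, Z))
    return A, caches

def cost(AL, Y, eps=1e-12):
    """Binary cross-entropy averaged over the m examples (columns)."""
    m = Y.shape[1]
    AL = np.clip(AL, eps, 1 - eps)  # avoid log(0)
    return float(-np.sum(Y * np.log(AL) + (1 - Y) * np.log(1 - AL)) / m)

def backward(AL, Y, caches, params):
    """Gradients dW[l] and db[l] of the cost, from the output layer back to layer 1."""
    L = len(caches)
    m = Y.shape[1]
    grads = {}
    # Sigmoid + cross-entropy: dJ/dZ[L] = (AL - Y) / m; the 1/m is applied in dW and db
    dZ = AL - Y
    for l in range(L, 0, -1):
        A_prev, Z = caches[l - 1]
        if l < L:
            dZ = dA * (Z > 0)  # ReLU derivative: pass the gradient only where Z > 0
        grads[f"dW{l}"] = dZ @ A_prev.T / m
        grads[f"db{l}"] = np.sum(dZ, axis=1, keepdims=True) / m
        dA = params[f"W{l}"].T @ dZ  # gradient sent back to layer l - 1
    return grads

def train(X, Y, layer_dims, learning_rate=0.01, number_iterations=250):
    """Full-batch gradient descent; returns the parameters and the cost per iteration."""
    params = init_params(layer_dims)
    costs = []
    for _ in range(number_iterations):
        AL, caches = forward(X, params)
        costs.append(cost(AL, Y))
        grads = backward(AL, Y, caches, params)
        for l in range(1, len(layer_dims)):
            params[f"W{l}"] -= learning_rate * grads[f"dW{l}"]
            params[f"b{l}"] -= learning_rate * grads[f"db{l}"]
    return params, costs
```

With inputs of shape (784, m) and labels of shape (1, m), a call such as train(X_train, Y_train, [784, 128, 64, 1], 0.01, 250) uses the same hyperparameters as this article, although it is not guaranteed to reproduce the author's results.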

Set up the hyperparameters

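layer_dims = [784, 128, 64, 1]  # input size, two hidden layers, one output unit
learning_rate = 0.01            # step size of each gradient update
number_iterations = 250         # number of training iterations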
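Each entry of layer_dims is a layer width: 784 inputs (one per pixel of a 28x28 image), hidden layers of 128 and 64 units, and a single output unit. Assuming the usual convention that W[l] has shape (n_l, n_{l-1}) and b[l] has shape (n_l, 1), the network has 784 × 128 + 128 + 128 × 64 + 64 + 64 × 1 + 1 = 108,801 trainable parameters. The short check here reproduces that arithmetic (count_parameters is a hypothetical helper).

```python
def count_parameters(layer_dims):
    """Number of weights plus biases in a fully connected network."""
    return sum(n_in * n_out + n_out
               for n_in, n_out in zip(layer_dims[:-1], layer_dims[1:]))

layer_dims = [784, 128, 64, 1]
for l, (n_in, n_out) in enumerate(zip(layer_dims[:-1], layer_dims[1:]), start=1):
    print(f"W{l}: ({n_out}, {n_in})  b{l}: ({n_out}, 1)")
print(count_parameters(layer_dims))  # 108801
```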

Train the Neural Network

from utils_binary import *

parameters, costs, time = nn_Llayers_binary(X_train, Y_train, layer_dims, learning_rate, number_iterations, print_cost=False)

Compute accuracy on train and test datasets

Yhat_train = predict_binary(X_train, Y_train, parameters)
train_accuracy = compute_accuracy(Yhat_train, Y_train)

Yhat_test = predict_binary(X_test, Y_test, parameters)
test_accuracy = compute_accuracy(Yhat_test, Y_test)

print(f"Train accuracy: {train_accuracy} %")
print(f"Test accuracy: {test_accuracy} %")
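predict_binary() and compute_accuracy() come from the helper module and are not shown. A plausible minimal version, assuming the network ends in a sigmoid unit, thresholds the output probabilities at 0.5 and reports the percentage of predictions that match the labels. Unlike the author's predict_binary(X, Y, parameters), this hypothetical sketch takes the output activations AL directly, so it does not depend on a particular forward-pass implementation.

```python
import numpy as np

def predict_binary_sketch(AL, threshold=0.5):
    """Turn sigmoid outputs of shape (1, m) into 0/1 predictions."""
    return (AL > threshold).astype(int)

def compute_accuracy_sketch(Yhat, Y):
    """Percentage of examples whose prediction matches the label."""
    return round(float(np.mean(Yhat == Y)) * 100, 2)

AL = np.array([[0.9, 0.2, 0.7, 0.4]])
Y = np.array([[1, 0, 0, 0]])
Yhat = predict_binary_sketch(AL)      # [[1, 0, 1, 0]]
print(compute_accuracy_sketch(Yhat, Y))  # 75.0
```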

The accuracy output

Train accuracy: 99.65 %
Test accuracy: 99.81 %

MNIST is not a difficult dataset, and restricting it to binary MNIST, where we only distinguish between 0 and 1, makes the task even simpler. It is therefore no surprise to see such high accuracy on both the train and test datasets, even though the network used in this experiment is not an advanced one.

Here are some visualizations of the misclassified cases.
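To find which images to plot, one option is to compare the predictions with the labels and keep the column indices where they differ. This is a sketch that assumes Yhat_test and Y_test both have shape (1, m) and that the examples in X_test follow the same order as X_test_org; misclassified_indices is a hypothetical helper.

```python
import numpy as np

def misclassified_indices(Yhat, Y):
    """Indices of the examples whose prediction differs from the label."""
    return np.flatnonzero(Yhat.ravel() != Y.ravel())

# Example usage with the variables from this article:
# idx = misclassified_indices(Yhat_test, Y_test)
# X_test_org[idx] and Y_test_org[idx] then hold the misclassified images and their labels
```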



This repository could make a great introductory project for those new to artificial intelligence, machine learning, and deep learning. Experimenting with this simple neural network taught me many basic principles that operate behind the scenes.

Top comments (2)

Jon Randy 🎖️ (jonrandy)

I remember doing something like this back in the late '80s using BASIC on my 8-bit ZX Spectrum. No libraries, slow as anything, but it worked. Just very simple number recognition

hoangng

Very interesting to hear; doing simple things like this without the libraries is great for understanding how things work.
BTW, I used to work in Bangkok in 2014, so how are things in Bangkok?