DEV Community

Rahul Raj

How To Build An Artificial Neural Network in Java

Let me give you a brief note on the purpose of this article before we discuss the technical part. I see plenty of Python-based tutorials on how to start coding deep learning models, train them and deploy them into production. One way or another, my research always ended up considering Python for any machine learning practice, mostly because of its simplicity and the fact that the majority of machine learning libraries are released in Python. Python is interesting, but it was never my primary choice. I was born and brought up in Java, and it's always painful to see a technology constraint standing in the way of learning something new. So, should I transform into a "Python developer" to learn machine learning at its best? This question was all over my mind. Some folks would argue that you're a software engineer and you're expected to learn any technology stack without hesitation. Who said Python is ugly? It is indeed beautiful; it does the job with less coding effort.

Mathematical computations that take plenty of Java code can be done in a few lines of Python. Most organizations look for all-rounders who have a taste for multiple technologies. But it is also fair to think about how to leverage your existing technology stack to implement what you dream about. My search for Java-based deep learning frameworks ended up at DL4J. There may be other Java frameworks/libraries that do the job, but in my view DL4J is the only promising commercial-grade deep learning framework in Java to date. Its founders released DL4J with the same concern in mind, for the people of Java. Having said that, do not restrict yourself to a particular machine learning library. The right solution depends on your model, and you will have to research what works best in your case. Now let's get our hands dirty! I will be using the DeepLearning4j framework throughout this article, as my aim is to provide a Java-oriented solution.

I'm using the example below to demonstrate the implementation of a neural network. An international banking company wants to increase its customer retention rate by predicting which customers are most likely to leave the bank.

The customer data set (CSV format) looks like this: Click Here

1) Data Pre-Processing

Data in a human-readable format may not make sense for machine learning purposes. We will be dealing only with numbers while we train the neural network. We also need to take note of categorical inputs and transform them properly before feeding them to the neural network. As you can see, there are 14 fields in the customer data set. The last field, "Exited", tells whether the customer left the bank or not: '1' indicates that the customer has left. So it is going to be our output label. Check the data set and inspect which fields could influence the output. The first three fields, RowNumber, CustomerId and Surname, can safely be dropped since they are not deciding factors. That leaves 10 fields for consideration apart from the output label. If you inspect them, you will see two fields, Geography and Gender, whose values are not numeric. We need to transform them into numbers in a meaningful way before passing them to the neural network. The 'Gender' field can be mapped to a binary value (0 or 1) depending on male/female. The 'Geography' field, on the other hand, has multiple values, so we can use one-hot encoding to turn it into numbers.
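The first two clean-up steps described above, dropping the identifier columns and mapping Gender to a binary value, can be sketched in plain Java as below. This is a library-free illustration rather than the DL4J transform the article uses. It assumes the CSV columns arrive in the order described above (RowNumber, CustomerId and Surname first), and the Female = 0 / Male = 1 mapping is an arbitrary choice; RowCleaner and its methods are hypothetical names.

```java
import java.util.Arrays;
import java.util.Locale;

/** Plain-Java illustration of the first clean-up steps (not DL4J code). */
final class RowCleaner {

    // Assumption: RowNumber, CustomerId and Surname are the first three columns.
    static final int ID_COLUMNS = 3;

    private RowCleaner() { }

    /** Drops the identifier columns, which carry no predictive signal. */
    static String[] dropIdentifiers(String[] row) {
        return Arrays.copyOfRange(row, ID_COLUMNS, row.length);
    }

    /** Maps Gender to 0 or 1; the Female = 0, Male = 1 choice is an arbitrary assumption. */
    static int encodeGender(String gender) {
        switch (gender.trim().toLowerCase(Locale.ROOT)) {
            case "female":
                return 0;
            case "male":
                return 1;
            default:
                throw new IllegalArgumentException("Unknown gender: " + gender);
        }
    }
}
```

Applied to a 14-field row, dropIdentifiers leaves 11 values: the 10 candidate input fields followed by Exited.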

In DL4J, we can define a schema for the data set and then feed this schema into a transform process, which applies all the encoding and transformation steps.

Refer to the schema implementation code here

After this encoding, the 'Geography' column is converted into multiple columns with binary values. If we have 3 countries in the data set, for example, it will be mapped to three columns, each representing one country. We also have to avoid the dummy variable trap by removing one of these columns: its value is fully determined by the other two, so keeping all three adds a redundant, perfectly correlated input. The removed category becomes the base category against which the others are compared. For example, we removed "France" and kept it as the base for the other country values.

Refer to this block of code
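Here is a minimal plain-Java sketch of one-hot encoding with a dropped base category, for readers who want to see what the transformation produces. It is not DL4J's API: DummyEncoder is a hypothetical helper, the caller supplies the category list, and the first entry of that list is treated as the removed base category (France, in the example above).

```java
import java.util.List;

/** Plain-Java illustration of one-hot encoding with a dropped base category (not DL4J code). */
final class DummyEncoder {

    private DummyEncoder() { }

    /**
     * Encodes value as (categories.size() - 1) binary columns.
     * Assumption: categories.get(0) is the base category; it is encoded as all zeros,
     * which avoids the dummy variable trap.
     */
    static int[] encode(String value, List<String> categories) {
        int index = categories.indexOf(value);
        if (index < 0) {
            throw new IllegalArgumentException("Unknown category: " + value);
        }
        int[] columns = new int[categories.size() - 1];
        if (index > 0) {
            columns[index - 1] = 1; // shifted by one because the base column is removed
        }
        return columns;
    }
}
```

With the list France, Spain, Germany (the second and third names are assumed for illustration), France encodes to [0, 0], Spain to [1, 0] and Germany to [0, 1], so three countries need only two input columns.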

Now it's time to split the data set into training and test sets. We will use the training set to train the neural network, and the test set to measure how well the network has been trained. The mean error and success rate will be calculated at the end of each epoch. Use CSVRecordReader to read the CSV file and pass it to TransformProcessRecordReader to apply the transformation we defined above. Both record readers are implementations of the RecordReader interface.

Refer to this block of code
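The split itself can be sketched without DL4J as a shuffle followed by a cut. The 80/20 ratio is inferred from the 8000-row training set mentioned below, and the fixed random seed is an added assumption that makes the split reproducible; TrainTestSplit is a hypothetical helper, not the record-reader code the article links to.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Plain-Java illustration of a shuffled train/test split (not DL4J code). */
final class TrainTestSplit<T> {
    final List<T> train;
    final List<T> test;

    private TrainTestSplit(List<T> train, List<T> test) {
        this.train = train;
        this.test = test;
    }

    /** Shuffles a copy of rows with a fixed seed (an assumption), then cuts it at trainFraction. */
    static <E> TrainTestSplit<E> split(List<E> rows, double trainFraction, long seed) {
        if (trainFraction < 0.0 || trainFraction > 1.0) {
            throw new IllegalArgumentException("trainFraction must be in [0, 1]");
        }
        List<E> shuffled = new ArrayList<>(rows);
        Collections.shuffle(shuffled, new Random(seed));
        int trainSize = (int) Math.round(shuffled.size() * trainFraction);
        return new TrainTestSplit<>(
                new ArrayList<>(shuffled.subList(0, trainSize)),
                new ArrayList<>(shuffled.subList(trainSize, shuffled.size())));
    }
}
```

For the 10000 customer rows, split(rows, 0.8, 42L) yields 8000 training rows and 2000 test rows, with every row in exactly one of the two lists.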

Now let's define which columns are inputs and which is the output. After the transformation, the data set has 12 columns (the 11 input features plus the 'Exited' label), so the index values are 0 to 11. The last column is the expected output and all the other columns are input features.

The input features will look like this:

(Image: sample of the input features)

And the output data will look like this:

(Image: sample of the output labels)
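Separating a transformed row into inputs and label only requires knowing that the label is the last column, as stated above. A plain-Java sketch with hypothetical names:

```java
import java.util.Arrays;

/** Plain-Java illustration: the last column is the label, the rest are inputs (not DL4J code). */
final class FeatureLabelSplit {

    private FeatureLabelSplit() { }

    /** All columns except the last one are input features. */
    static double[] features(double[] row) {
        return Arrays.copyOf(row, row.length - 1);
    }

    /** The last column is the 'Exited' label (1 = the customer left the bank). */
    static double label(double[] row) {
        return row[row.length - 1];
    }
}
```

In DL4J you normally don't slice arrays by hand; the label column index is passed to RecordReaderDataSetIterator instead.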

Remember to define a batch size for your data set. The batch size is the number of examples passed from the data set to the neural network in one chunk. We have 10000 entries in our data set. We could use a batch size of 8000 (the size of the training set) so that the whole training set is transferred in a single chunk. But remember, choosing a larger batch size makes a big difference: the weights are updated once per batch, so a large batch size means fewer updates per epoch. So define a smaller batch size and use a DataSetIterator to retrieve data sets from the file.

A second, better approach is to use DataSetIteratorSplitter, which splits a single DataSetIterator into training and test iterators.

We can now get the training and test set iterators to pass to the neural network model once it's ready to be trained.
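The batch-size trade-off is simple arithmetic: each mini-batch produces one weight update, so one epoch over n examples with batch size b performs ceil(n / b) updates. The plain-Java sketch below (hypothetical names; DL4J's iterators do the batching for you) makes that concrete.

```java
import java.util.ArrayList;
import java.util.List;

/** Plain-Java illustration of mini-batching (DL4J iterators do this for you). */
final class MiniBatches {

    private MiniBatches() { }

    /** One weight update per mini-batch, so this is also the number of updates per epoch. */
    static int updatesPerEpoch(int numExamples, int batchSize) {
        requirePositive(batchSize);
        return (numExamples + batchSize - 1) / batchSize; // ceiling division
    }

    /** Splits examples into consecutive batches (views); the last one may be smaller. */
    static <T> List<List<T>> batches(List<T> examples, int batchSize) {
        requirePositive(batchSize);
        List<List<T>> result = new ArrayList<>();
        for (int start = 0; start < examples.size(); start += batchSize) {
            int end = Math.min(start + batchSize, examples.size());
            result.add(examples.subList(start, end));
        }
        return result;
    }

    private static void requirePositive(int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
    }
}
```

For the 8000-row training set, a batch size of 8000 gives a single update per epoch, while a batch size of 32 gives 250.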

Now, can we go ahead and feed this data to the neural network? Absolutely not! If you inspect the data, you will see that it is not scaled properly. The features we feed to the neural network should be on comparable scales. The magnitudes of 'Balance' and 'EstimatedSalary' are far higher than those of most other fields, so if we process them as they are, the computation could depend heavily on these two fields and hide the effect of the other fields when predicting the output. So we need to do feature scaling here.

Refer here
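The article doesn't name the scaling method, so the sketch below assumes z-score standardization, one common choice: each column is shifted by its mean and divided by its standard deviation, with both statistics computed on the training set only and reused for the test set. DL4J ships normalizers such as NormalizerStandardize for this; the plain-Java Standardizer class here is hypothetical and only makes the arithmetic visible.

```java
/** Plain-Java illustration of z-score standardization (assumed scaling method, not DL4J code). */
final class Standardizer {
    private final double[] mean;
    private final double[] std;

    private Standardizer(double[] mean, double[] std) {
        this.mean = mean;
        this.std = std;
    }

    /** Computes per-column mean and (population) standard deviation from training rows only. */
    static Standardizer fit(double[][] trainingRows) {
        int n = trainingRows.length;
        int cols = trainingRows[0].length;
        double[] mean = new double[cols];
        double[] std = new double[cols];
        for (double[] row : trainingRows) {
            for (int j = 0; j < cols; j++) {
                mean[j] += row[j];
            }
        }
        for (int j = 0; j < cols; j++) {
            mean[j] /= n;
        }
        for (double[] row : trainingRows) {
            for (int j = 0; j < cols; j++) {
                double d = row[j] - mean[j];
                std[j] += d * d;
            }
        }
        for (int j = 0; j < cols; j++) {
            std[j] = Math.sqrt(std[j] / n);
            if (std[j] == 0.0) {
                std[j] = 1.0; // constant column: only centre it instead of dividing by zero
            }
        }
        return new Standardizer(mean, std);
    }

    /** Scales a row (training or test) with the statistics learned in fit(). */
    double[] transform(double[] row) {
        double[] out = new double[row.length];
        for (int j = 0; j < row.length; j++) {
            out[j] = (row[j] - mean[j]) / std[j];
        }
        return out;
    }
}
```

After scaling, every non-constant column of the training set has mean 0 and standard deviation 1, so Balance and EstimatedSalary no longer dwarf fields such as Tenure in the network's weighted sums.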

Remember that data pre-processing is crucial for avoiding incorrect outputs and errors, and it depends entirely on the data you have. Finally, we have data that can be fed to the neural network. Let's see how we can design the neural network model.

2) Define Neural Network Shape
First, we will define the neural network configuration: how many neurons the input layer has, the structure and connections of the hidden layers, the output layer, the activation function for each layer, the loss function for the output layer and the optimizer.

Refer to this block of code
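To make the idea of layer shape concrete, here is a plain-Java sketch of one fully connected layer: it stores an nOut x nIn weight matrix plus one bias per output neuron and computes W x + b, to which an activation function is then applied. It is illustrative only, not DL4J's DenseLayer; ReLU for hidden layers is an assumption, since the text only specifies sigmoid for the output layer.

```java
/** Plain-Java illustration of one fully connected layer (not DL4J's DenseLayer class). */
final class FullyConnectedLayer {
    private final double[][] weights; // weights[out][in]: nOut rows of nIn weights
    private final double[] bias;      // one bias per output neuron

    FullyConnectedLayer(double[][] weights, double[] bias) {
        if (weights.length != bias.length) {
            throw new IllegalArgumentException("need one bias per output neuron");
        }
        this.weights = weights;
        this.bias = bias;
    }

    /** Computes z = W x + b for one input vector. */
    double[] preActivation(double[] x) {
        double[] z = new double[weights.length];
        for (int o = 0; o < weights.length; o++) {
            if (weights[o].length != x.length) {
                throw new IllegalArgumentException("input size does not match layer nIn");
            }
            double sum = bias[o];
            for (int i = 0; i < x.length; i++) {
                sum += weights[o][i] * x[i];
            }
            z[o] = sum;
        }
        return z;
    }

    /** ReLU, assumed here for hidden layers. */
    static double[] relu(double[] z) {
        double[] a = new double[z.length];
        for (int i = 0; i < z.length; i++) {
            a[i] = Math.max(0.0, z[i]);
        }
        return a;
    }

    /** Sigmoid, as used for the single output neuron. */
    static double[] sigmoid(double[] z) {
        double[] a = new double[z.length];
        for (int i = 0; i < z.length; i++) {
            a[i] = 1.0 / (1.0 + Math.exp(-z[i]));
        }
        return a;
    }

    /** Trainable parameters: nIn * nOut weights plus nOut biases. */
    static int parameterCount(int nIn, int nOut) {
        return nIn * nOut + nOut;
    }
}
```

A layer with 11 inputs and 6 neurons has parameterCount(11, 6) = 72 trainable values (66 weights and 6 biases), and a 6-to-1 output layer adds 7 more.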

As you can see, we have added dropout between the input layer and the output layer. This helps reduce over-fitting. Note also that we don't drop a large portion of the neurons (only 10%), to avoid under-fitting at the same time. There are 11 input features and one output category, so we configured the input and output layers accordingly. How do we decide the number of neurons in the hidden layer? A common rule of thumb is the average of the input and output neuron counts, which here is (11 + 1) / 2 = 6. Our expected output indicates whether the customer will leave the bank or not, expressed as the probability that the customer leaves. This makes the problem look like logistic regression, so we use the sigmoid activation function in the output layer. We also need to specify the loss function used to calculate the error. For a sigmoid output predicting a 0/1 label, the matching loss function is binary cross-entropy (log loss), which averages -[y log(p) + (1 - y) log(1 - p)] over the examples, where y is the expected output and p the predicted probability. (The sum of squared differences between actual and expected outputs would instead be the mean squared error loss.)
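The output activation, the loss and the hidden-layer rule of thumb from the paragraph above can be written out in plain Java as follows. OutputLoss is a hypothetical helper, and the clamping constant in binaryCrossEntropy is an added assumption to avoid evaluating log(0); libraries handle that edge case in their own ways.

```java
/** Plain-Java illustration of the output activation, loss and sizing rule (not DL4J code). */
final class OutputLoss {

    private OutputLoss() { }

    /** Squashes any real number into (0, 1), so it can be read as a probability. */
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /**
     * Mean binary cross-entropy: -[y log(p) + (1 - y) log(1 - p)] averaged over examples.
     * y[i] is the expected output (0 or 1), p[i] the predicted probability.
     */
    static double binaryCrossEntropy(double[] y, double[] p) {
        final double eps = 1e-12; // assumption: clamp p so log(0) is never evaluated
        double sum = 0.0;
        for (int i = 0; i < y.length; i++) {
            double pi = Math.min(Math.max(p[i], eps), 1.0 - eps);
            sum -= y[i] * Math.log(pi) + (1.0 - y[i]) * Math.log(1.0 - pi);
        }
        return sum / y.length;
    }

    /** Rule of thumb from the text: average of input and output neuron counts. */
    static int hiddenNeurons(int nIn, int nOut) {
        return (nIn + nOut) / 2;
    }
}
```

A prediction of 0.5 for a customer who actually left costs ln 2, about 0.693, while a confident correct prediction of 0.9 costs about 0.105. Confident wrong predictions are penalized much more heavily, which is what makes this loss a good match for a sigmoid probability output.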

3) Train the model and predict results

Build and initialize the model as shown here

Then we can start the neural network training by calling the fit() method.

Refer here
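Conceptually, each call to fit() repeats three steps for every mini-batch: a forward pass, a gradient of the loss, and a weight update. The sketch below shows that loop for the simplest possible case, a single sigmoid neuron (logistic regression) trained with plain mini-batch gradient descent on binary cross-entropy. It is not DL4J's implementation: a real multi-layer network also back-propagates the error through its hidden layers, applies dropout and uses whichever optimizer was configured. LogisticSgd is a hypothetical name.

```java
/**
 * Plain-Java illustration of what training does per mini-batch, for the simplest model:
 * one sigmoid neuron (logistic regression), not the article's multi-layer network.
 */
final class LogisticSgd {
    private final double[] w;
    private double b;

    LogisticSgd(int nIn) {
        this.w = new double[nIn]; // weights start at zero for simplicity
    }

    /** Forward pass: predicted probability that the label is 1. */
    double predict(double[] x) {
        double z = b;
        for (int i = 0; i < w.length; i++) {
            z += w[i] * x[i];
        }
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** One epoch: one gradient step on the mean binary cross-entropy per mini-batch. */
    void fitEpoch(double[][] x, double[] y, int batchSize, double learningRate) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        for (int start = 0; start < x.length; start += batchSize) {
            int end = Math.min(start + batchSize, x.length);
            double[] gradW = new double[w.length];
            double gradB = 0.0;
            for (int n = start; n < end; n++) {
                double error = predict(x[n]) - y[n]; // dLoss/dz for sigmoid + cross-entropy
                for (int i = 0; i < w.length; i++) {
                    gradW[i] += error * x[n][i];
                }
                gradB += error;
            }
            int m = end - start;
            for (int i = 0; i < w.length; i++) {
                w[i] -= learningRate * gradW[i] / m;
            }
            b -= learningRate * gradB / m;
        }
    }
}
```

For a sigmoid output combined with binary cross-entropy, the derivative of the loss with respect to the neuron's weighted input simplifies to p - y, which is why the update needs no separate sigmoid derivative.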

Now we can go ahead and evaluate our results using the Evaluation class in DL4J.

Refer here

The evaluation metrics will be displayed in the form of a confusion matrix, as shown below:

(Image: confusion matrix from the evaluation)
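The headline figures derived from a binary confusion matrix are easy to compute by hand. In the plain-Java sketch below, class 1 means "Exited", and a probability above the threshold counts as a predicted exit, matching the greater-than-0.5 rule used for predictions later on. BinaryConfusionMatrix is a hypothetical helper; DL4J's Evaluation class reports comparable statistics itself.

```java
/** Plain-Java illustration of binary confusion-matrix counts and derived metrics (not DL4J code). */
final class BinaryConfusionMatrix {
    int truePositives;
    int trueNegatives;
    int falsePositives;
    int falseNegatives;

    /** actual[i] is 1 if the customer exited; a probability above threshold predicts 1. */
    static BinaryConfusionMatrix of(int[] actual, double[] probability, double threshold) {
        BinaryConfusionMatrix m = new BinaryConfusionMatrix();
        for (int i = 0; i < actual.length; i++) {
            boolean predictedExit = probability[i] > threshold;
            boolean actuallyExited = actual[i] == 1;
            if (predictedExit && actuallyExited) {
                m.truePositives++;
            } else if (!predictedExit && !actuallyExited) {
                m.trueNegatives++;
            } else if (predictedExit) {
                m.falsePositives++; // predicted exit, but the customer stayed
            } else {
                m.falseNegatives++; // missed a customer who actually left
            }
        }
        return m;
    }

    double accuracy() {
        int total = truePositives + trueNegatives + falsePositives + falseNegatives;
        return total == 0 ? 0.0 : (double) (truePositives + trueNegatives) / total;
    }

    double precision() {
        int predictedPositive = truePositives + falsePositives;
        return predictedPositive == 0 ? 0.0 : (double) truePositives / predictedPositive;
    }

    double recall() {
        int actualPositive = truePositives + falseNegatives;
        return actualPositive == 0 ? 0.0 : (double) truePositives / actualPositive;
    }

    double f1() {
        double p = precision();
        double r = recall();
        return p + r == 0.0 ? 0.0 : 2 * p * r / (p + r);
    }
}
```

Accuracy alone can look good on a data set where most customers stay; precision and recall show how well the model handles the customers who actually leave.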

Now let's see the neural network in action. Let us predict if the following customer will leave the bank:

We need to write a new schema for the prediction input since, unlike the training data, it doesn't have the label. So the only change is to remove the label from the schema discussed previously. Here is the new schema implementation for accepting input from the user: Click here

We may design an API function that returns the prediction results in the form of an INDArray.

Now we can create a Maven project and build it. Then reference the JAR file in your application (for example, Spring Boot) as a Maven dependency and use the API to predict results. Since the results are in INDArray format, you need to write logic to display custom results as per your requirements. Here, our neural network predicts the probability that a customer leaves the organisation. So, if the probability is greater than 0.5 (50%), you may handle the result in the UI to indicate an unhappy customer.
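Turning the predicted probability into a user-facing result could look like the sketch below. ChurnPredictor is a hypothetical interface standing in for the prediction API described above, and the message wording is an assumption; a real implementation backed by DL4J would typically read the single output value from the returned INDArray (for example with getDouble(0)) before applying the 0.5 threshold.

```java
/** Hypothetical prediction API standing in for the article's service; not part of DL4J. */
interface ChurnPredictor {
    /** Returns the predicted probability (0 to 1) that the customer leaves the bank. */
    double probabilityOfLeaving(double[] features);
}

/** Maps a predicted probability to a message the UI can show (wording is an assumption). */
final class ChurnDecision {
    static final double THRESHOLD = 0.5; // "greater than 0.5" means likely to leave

    private ChurnDecision() { }

    static boolean isLikelyToLeave(double probability) {
        return probability > THRESHOLD;
    }

    static String describe(ChurnPredictor predictor, double[] features) {
        double p = predictor.probabilityOfLeaving(features);
        long percent = Math.round(p * 100);
        return isLikelyToLeave(p)
                ? "Unhappy customer: " + percent + "% chance of leaving"
                : "Likely to stay: " + percent + "% chance of leaving";
    }
}
```

Because ChurnPredictor has a single method, a lambda such as f -> 0.73 is enough to exercise the decision logic in a unit test without loading a model.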

Congrats! We have just developed a standard neural network with a prediction accuracy of about 85%. Feel free to explore concrete DL4J examples here, and message me on LinkedIn with any queries or clarifications.
