Introduction and Basics
With the prevalence of artificial intelligence, data science engineers and developers working in various domains are widely using machine learning (ML) algorithms to make their tasks simpler and their lives easier. For example, economists use AI to predict future market prices to make a profit, and doctors use AI to classify whether a tumor is malignant or benign. Machine learning algorithms also help Google Maps find the fastest route to our destinations, AccuWeather forecast the weather for 3.5 million locations weeks in advance, and Facebook automatically detect faces and suggest tags. The impetus behind such ubiquitous use of AI is machine learning algorithms. If you want to learn ML algorithms but haven't gotten your feet wet yet, you are in the right place. The first algorithm almost every machine learning enthusiast starts with is linear regression. In this article, we will look at how the linear regression algorithm works and how it can be used in your machine learning projects to build better models.
Most machine learning problems can be assigned to one of two broad categories: supervised learning and unsupervised learning.
In supervised learning, we are given a dataset (input) together with the correct outputs, so we already know what our output should look like. All our algorithm does is figure out the relationship between the input and the output. It then uses that relationship to predict outputs for new, unseen inputs. It is called supervised learning because the process of an algorithm learning from the training dataset can be thought of as a teacher supervising the learning process. We know the correct answers, the algorithm iteratively makes predictions on the training data, and the teacher corrects it. Learning stops when the algorithm achieves an acceptable level of performance.
Supervised learning problems are further categorized into:
- Regression: A regression problem is when the output variable is a real value, such as “dollars” or “weight”.
- Classification: A classification problem is when the output variable is a category, such as “malignant” or “benign”, or “disease” or “no disease”.
In regression problems (which are our focus today), we map our input variables to some continuous function. For example: given data about the sizes of houses, predict their prices. The price of a house is a function of its size. Since a price can take any real value, the output is continuous.
In classification problems, we map our input variables to a discrete function. For example: given data about the favorite colors of different people, predict whether each person is a girl or a boy. This is a discrete function, because you are mapping every input to either a boy or a girl.
Linear Regression is a Regression algorithm.
What is Linear Regression?
From fig1 below, let's say we have a dataset that records the ‘prices’ and ‘sizes’ of houses in a fantasy country somewhere in the world. This will be our training data (or training set). Our goal is to design a model that can predict the price of a house given its size. Using the training data, we obtain a regression line that gives the minimum error.
fig1: House prices given their size
Linear Regression Algorithm
The idea is to feed a dataset to our learning algorithm, which then outputs a function called the hypothesis.
The hypothesis approximates the target function that maps inputs to outputs.
The line seen in fig1 is the hypothesis function. It is the line of best fit: it passes as close as possible to the points by minimizing the error. Here the error is the average squared vertical distance of the scattered points from the line, as shown in the image below.
The diagram below conveniently gives a brief overview of the Linear Regression Algorithm.
The training set of housing prices is fed into the learning algorithm, whose main job is to produce a function that by convention is called h (for hypothesis). You then give that hypothesis function the size of a house as input x, and it outputs the estimated house price y.
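To make the "training set in, hypothesis out" pipeline concrete, here is a minimal Python sketch. The function name `learn` and the toy data are hypothetical. It assumes the straight-line hypothesis h(x) = θ0 + θ1x. As its stand-in learning algorithm it uses the closed-form ordinary least-squares solution for one variable, not the gradient descent method this article builds up to.

```python
from typing import Callable, Sequence


def learn(sizes: Sequence[float], prices: Sequence[float]) -> Callable[[float], float]:
    """Hypothetical learning algorithm: takes a training set, returns a hypothesis h.

    Uses the closed-form least-squares fit for one variable (a stand-in, not
    the gradient descent approach described later).
    """
    m = len(sizes)
    x_mean = sum(sizes) / m
    y_mean = sum(prices) / m
    # Slope = covariance(x, y) / variance(x); the 1/m factors cancel out.
    num = sum((x - x_mean) * (y - y_mean) for x, y in zip(sizes, prices))
    den = sum((x - x_mean) ** 2 for x in sizes)
    if den == 0:
        raise ValueError("all sizes are identical; the slope is undefined")
    theta1 = num / den
    theta0 = y_mean - theta1 * x_mean

    def h(x: float) -> float:
        # The learned hypothesis: estimated price for a house of size x.
        return theta0 + theta1 * x

    return h


# Made-up training set (size in square meters, price in thousands).
sizes = [50.0, 80.0, 110.0, 140.0]
prices = [150.0, 240.0, 330.0, 420.0]
h = learn(sizes, prices)
print(h(100.0))  # 300.0, since this toy data lies exactly on price = 3 * size
```

The returned `h` is an ordinary function, which mirrors the diagram: once training is done, you only need `h` to make predictions.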
For linear regression with one variable, our hypothesis function is of the form
$$h_\theta(x) = \theta_0 + \theta_1 x$$
where θ0 and θ1 are the parameters: θ0 is the intercept and θ1 is the slope of the line. Our humble hypothesis function has only one input variable, x, which is why this task is often called linear regression with one variable. Experts also call it univariate linear regression, where univariate means "one variable".
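As a small sketch, the hypothesis can be written in Python as a plain function of the two parameters and the input. The function name `hypothesis` and the parameter values are illustrative assumptions.

```python
def hypothesis(theta0: float, theta1: float, x: float) -> float:
    """Univariate linear hypothesis h(x) = theta0 + theta1 * x."""
    # theta0 shifts the line up or down (intercept); theta1 tilts it (slope).
    return theta0 + theta1 * x


print(hypothesis(50.0, 2.0, 100.0))  # 50 + 2 * 100 = 250.0
```

Changing `theta0` moves the whole line vertically, while changing `theta1` changes how steeply the predicted price grows with size.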
Hold on, we can't determine the hypothesis function just by looking at the graph! We need a systematic way to figure it out. This is where the cost function comes to our rescue.
Cost Function
At this point we know that there's a hypothesis function to be deduced. For our hypothesis function to fit our training data as well as possible, we first need to find the parameters θ0 and θ1. If you recall how the equation of a line works, these two parameters control the height (θ0, the intercept) and the slope (θ1) of the line. By tweaking θ0 and θ1, we want to find the line that best represents our data. Picture 3 below shows what I mean:
We want something like the first example in the picture above. To achieve our goal, we need to find the best values of θ0 and θ1. Our dataset has several examples where we know both the size of the house x and the actual price of the house y.
So, in a nutshell, the idea is to choose θ0 and θ1 so that, at least on the existing dataset, the hypothesis function makes accurate predictions of y when given x as input. Once we are satisfied, we can use the hypothesis function with its pretty parameters to make predictions on new input data.
The cost function helps us figure out which hypothesis best fits our data. More formally, it measures the accuracy of our hypothesis.
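Mathematically, the cost function is:
$$J(\theta_0, \theta_1) = \frac{1}{2m}\sum_{i=1}^{m}\left(h_\theta(x^{(i)}) - y^{(i)}\right)^2$$
where m is the number of training examples.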
From a mathematical point of view, I want the difference $h_\theta(x^{(i)}) - y^{(i)}$ to be very small for each i-th point in my dataset. Here $h_\theta(x^{(i)})$ is the prediction of the hypothesis when we input the size of house number i, while $y^{(i)}$ is the actual price of house number i. If that difference is small, our hypothesis has made an accurate prediction, because its output is close to the actual data.
Note the 1/(2m) factor and the summation: we are computing (half of) the mean of the squared errors. The 2 in the denominator will ease some calculations in later steps, because it cancels the 2 that appears when the square is differentiated. The squaring is done so that negative errors do not cancel positive ones.
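Here is a minimal plain-Python sketch of this cost function. It assumes the squared-error form J(θ0, θ1) = 1/(2m) Σ (hθ(x(i)) − y(i))² with the straight-line hypothesis. The function name `cost` and the toy data are illustrative.

```python
from typing import Sequence


def cost(theta0: float, theta1: float, xs: Sequence[float], ys: Sequence[float]) -> float:
    """Squared-error cost J(theta0, theta1) = 1/(2m) * sum((h(x_i) - y_i)^2)."""
    m = len(xs)
    total = 0.0
    for x, y in zip(xs, ys):
        error = (theta0 + theta1 * x) - y  # h(x_i) - y_i
        total += error ** 2                # squaring keeps +/- errors from cancelling
    return total / (2 * m)


xs = [1.0, 2.0, 3.0]
ys = [1.0, 2.0, 3.0]
print(cost(0.0, 1.0, xs, ys))  # the line y = x fits perfectly: 0.0
print(cost(0.0, 0.5, xs, ys))  # squared errors sum to 3.5, so 3.5 / 6 ≈ 0.583
```

A lower value means a better fit, so comparing `cost` across candidate (θ0, θ1) pairs tells us which hypothesis fits the data best.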
So now we have our hypothesis function and a way of measuring how well it fits the data (the cost function). All we need to do is find the parameter values of the hypothesis function for which the cost function is minimal. This is where gradient descent comes into the picture.
Gradient Descent
The next important concept needed to understand linear regression is gradient descent. Gradient descent is a method of updating θ0 and θ1 to reduce the cost function. The idea is to start with arbitrary (for example, random) values of θ0 and θ1 and then update them iteratively until the cost reaches its minimum. Gradient descent is a general algorithm for minimizing a function. For linear regression, it turns out that the squared-error cost function is convex, i.e. it has no local minima other than the global minimum, as shown in the following figure.
We put θ0 on the x axis and θ1 on the y axis, with the cost function on the vertical z axis. Each point on the surface is the value of the cost function for our hypothesis with those specific θ parameters. The graph above depicts such a setup. We will know that we have succeeded when our cost function is at the very bottom of the pits in our graph, i.e. when its value is minimal. In the image below, the red circles show the minimum points in the graph. So the algorithm to minimize the function is to compute the direction of steepest descent, take a small step downhill, and repeat that over and over again until you reach the function's minimum.
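Mathematically, gradient descent repeats the following update until convergence:
$$\theta_j := \theta_j - \alpha \frac{\partial}{\partial \theta_j} J(\theta_0, \theta_1)$$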
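Here j = 0, 1 is the index of the parameter being updated (θ0 or θ1).
i) := denotes assignment: the right-hand side is computed and assigned to the left-hand side.
ii) α is the learning rate, i.e. the size of the steps taken downhill; it is always positive.
iii) The derivative term points in the direction of steepest ascent of J, so subtracting it (scaled by α) moves the parameters in the direction of steepest descent.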
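For the squared-error cost J(θ0, θ1) = 1/(2m) Σ (hθ(x(i)) − y(i))² with hθ(x) = θ0 + θ1x, the derivative term can be worked out explicitly. The chain rule produces a factor of 2 from the square, and that 2 cancels the 2 in 1/(2m). This is the simplification the 1/(2m) factor was chosen for:

$$
\frac{\partial J}{\partial \theta_j} = \frac{1}{2m}\sum_{i=1}^{m} 2\left(h_\theta(x^{(i)}) - y^{(i)}\right)\frac{\partial h_\theta(x^{(i)})}{\partial \theta_j},
\qquad
\frac{\partial h_\theta(x)}{\partial \theta_0} = 1,
\quad
\frac{\partial h_\theta(x)}{\partial \theta_1} = x
$$

$$
\frac{\partial J}{\partial \theta_0} = \frac{1}{m}\sum_{i=1}^{m}\left(h_\theta(x^{(i)}) - y^{(i)}\right),
\qquad
\frac{\partial J}{\partial \theta_1} = \frac{1}{m}\sum_{i=1}^{m}\left(h_\theta(x^{(i)}) - y^{(i)}\right)x^{(i)}
$$

Plugging these two expressions into the update rule gives the concrete update step for univariate linear regression.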
For linear regression, the cost surface looks like a cup, and the optimization problem consists of finding its lowest point. When the 3-D picture becomes too messy, it is common to switch to another representation: the contour plot. A contour plot is a graphical technique that represents a 3-dimensional surface by plotting constant horizontal slices, called contours, in 2 dimensions. The figure below shows what I'm talking about.
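A contour plot is simply J evaluated over a grid of (θ0, θ1) values. The sketch below computes that grid in plain Python and finds its lowest cell, which sits at the centre of the innermost contour. The grid ranges, the helper name `cost_grid`, and the toy data are illustrative assumptions.

```python
from typing import List, Sequence


def cost(theta0: float, theta1: float, xs: Sequence[float], ys: Sequence[float]) -> float:
    """Squared-error cost J(theta0, theta1) = 1/(2m) * sum((h(x_i) - y_i)^2)."""
    m = len(xs)
    return sum(((theta0 + theta1 * x) - y) ** 2 for x, y in zip(xs, ys)) / (2 * m)


def cost_grid(
    xs: Sequence[float],
    ys: Sequence[float],
    theta0_values: Sequence[float],
    theta1_values: Sequence[float],
) -> List[List[float]]:
    """J on a grid: row r uses theta1_values[r], column c uses theta0_values[c]."""
    return [[cost(t0, t1, xs, ys) for t0 in theta0_values] for t1 in theta1_values]


xs = [1.0, 2.0, 3.0, 4.0]
ys = [3.0, 5.0, 7.0, 9.0]                        # made-up data lying on y = 1 + 2x
theta0_values = [0.5 * i for i in range(-4, 9)]  # -2.0, -1.5, ..., 4.0
theta1_values = [0.5 * i for i in range(-2, 9)]  # -1.0, -0.5, ..., 4.0
grid = cost_grid(xs, ys, theta0_values, theta1_values)

# The lowest grid value marks the centre of the innermost contour.
best_cost, best_t0, best_t1 = min(
    (grid[r][c], theta0_values[c], theta1_values[r])
    for r in range(len(theta1_values))
    for c in range(len(theta0_values))
)
print(best_cost, best_t0, best_t1)  # 0.0 1.0 2.0
```

If matplotlib is installed, `plt.contour(theta0_values, theta1_values, grid)` should draw the contour plot itself. It accepts 1-D x and y coordinates whose lengths match the number of columns and rows of the grid.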
At each step, all parameters θ0, θ1, ..., θn must be updated simultaneously: compute every derivative using the current parameter values, then assign all the new values. Updating one parameter before computing the derivative for another in the same iteration leads to an incorrect implementation.
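A minimal batch gradient descent sketch for univariate linear regression follows. It uses the partial derivatives of the squared-error cost, (1/m) Σ (h(x) − y) for θ0 and (1/m) Σ (h(x) − y)·x for θ1. The default `alpha`, the iteration count, the zero starting point, and the toy data are assumptions chosen for illustration. The key detail is that both gradients are computed from the old parameters before either parameter changes.

```python
from typing import Sequence, Tuple


def gradient_descent(
    xs: Sequence[float],
    ys: Sequence[float],
    alpha: float = 0.1,
    iterations: int = 1000,
) -> Tuple[float, float]:
    """Batch gradient descent for h(x) = theta0 + theta1 * x with squared-error cost."""
    m = len(xs)
    theta0, theta1 = 0.0, 0.0  # arbitrary starting point
    for _ in range(iterations):
        # Errors h(x_i) - y_i, computed with the CURRENT (old) parameters.
        errors = [(theta0 + theta1 * x) - y for x, y in zip(xs, ys)]
        grad0 = sum(errors) / m                             # dJ/dtheta0
        grad1 = sum(e * x for e, x in zip(errors, xs)) / m  # dJ/dtheta1
        # Simultaneous update: both new values are derived from the old ones.
        # Updating theta0 first and then recomputing the errors for grad1
        # would mix old and new parameters, which is the incorrect version.
        theta0, theta1 = theta0 - alpha * grad0, theta1 - alpha * grad1
    return theta0, theta1


xs = [1.0, 2.0, 3.0, 4.0]
ys = [3.0, 5.0, 7.0, 9.0]  # made-up data lying exactly on y = 1 + 2x
theta0, theta1 = gradient_descent(xs, ys)
print(round(theta0, 3), round(theta1, 3))  # should be close to 1.0 and 2.0
```

Python's tuple assignment evaluates the whole right-hand side before assigning, which makes the simultaneous update natural to express without temporary variables.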
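To understand gradient descent better, let θ0 = 0, so that the hypothesis becomes h(x) = θ1x and the cost depends on θ1 alone.
Gradient descent for a single parameter:
$$\theta_1 := \theta_1 - \alpha \frac{d}{d\theta_1} J(\theta_1)$$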
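The following sketch illustrates the single-parameter update with θ0 fixed at 0. It assumes the squared-error cost J(θ1) = 1/(2m) Σ (θ1x(i) − y(i))², whose derivative is (1/m) Σ (θ1x(i) − y(i))·x(i). The helper names `dJ_dtheta1` and `step`, the learning rate, and the toy data are illustrative.

```python
from typing import Sequence


def dJ_dtheta1(theta1: float, xs: Sequence[float], ys: Sequence[float]) -> float:
    """Derivative of J(theta1) = 1/(2m) * sum((theta1 * x - y)^2), with theta0 fixed at 0."""
    m = len(xs)
    return sum((theta1 * x - y) * x for x, y in zip(xs, ys)) / m


def step(theta1: float, xs: Sequence[float], ys: Sequence[float], alpha: float) -> float:
    """One gradient descent update: theta1 := theta1 - alpha * dJ/dtheta1."""
    return theta1 - alpha * dJ_dtheta1(theta1, xs, ys)


xs = [1.0, 2.0, 3.0]
ys = [2.0, 4.0, 6.0]  # made-up data; J(theta1) is smallest at theta1 = 2
alpha = 0.1

for start in (4.0, 0.0):
    slope = dJ_dtheta1(start, xs, ys)
    print(f"theta1={start}: derivative={slope:+.3f}, next theta1={step(start, xs, ys, alpha):.3f}")
```

Starting above the minimum (θ1 = 4), the derivative is positive and the update moves θ1 down toward 2. Starting below it (θ1 = 0), the derivative is negative and the update moves θ1 up.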
From the above equation, we can infer the following:
Derivative term
- When the derivative term is positive, i.e. J increases as θ1 increases (the slope of J at the current θ1 is positive), θ1 will decrease.
- When the derivative term is negative, i.e. J decreases as θ1 increases (the slope is negative), θ1 will increase.
So the derivative term tells us which direction to move in.
Learning Rate
The size of each step is determined by α, known as the learning rate. The value of α should be chosen with care so that gradient descent converges in a reasonable time.
- If α is too small, gradient descent can be slow.
- If α is too large, gradient descent can overshoot the minimum, fail to converge, or even diverge.
Another thing to notice is that gradient descent can converge to a local minimum even with a fixed α, as long as α is not too large. As we approach the minimum, the derivative shrinks, so gradient descent automatically takes smaller steps. There is no need to decrease α over time. With the help of gradient descent, we can find the appropriate parameters for our hypothesis.
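This sketch shows the effect of the learning rate on the single-parameter problem with θ0 = 0. It uses the same squared-error derivative as before. The three α values, the 20-iteration budget, and the toy data are illustrative assumptions. On this particular data, any α above roughly 0.43 (that is, 2 divided by the mean of x²) makes the updates diverge. Other datasets have other thresholds.

```python
from typing import List, Sequence, Tuple


def dJ_dtheta1(theta1: float, xs: Sequence[float], ys: Sequence[float]) -> float:
    """Derivative of J(theta1) = 1/(2m) * sum((theta1 * x - y)^2), with theta0 fixed at 0."""
    m = len(xs)
    return sum((theta1 * x - y) * x for x, y in zip(xs, ys)) / m


def run(
    alpha: float,
    xs: Sequence[float],
    ys: Sequence[float],
    theta1: float = 0.0,
    iterations: int = 20,
) -> Tuple[float, List[float]]:
    """Run gradient descent with a fixed alpha; return final theta1 and every step size."""
    steps = []
    for _ in range(iterations):
        step = alpha * dJ_dtheta1(theta1, xs, ys)
        steps.append(abs(step))
        theta1 -= step
    return theta1, steps


xs = [1.0, 2.0, 3.0]
ys = [2.0, 4.0, 6.0]  # made-up data; the minimum of J is at theta1 = 2

for alpha in (0.001, 0.1, 0.5):  # too small, reasonable, too large (for this data)
    final, _ = run(alpha, xs, ys)
    print(f"alpha={alpha}: theta1 after 20 steps = {final:.4f}")

# With a fixed alpha = 0.1, the step sizes still shrink as the slope flattens.
_, steps = run(0.1, xs, ys, iterations=5)
print([round(s, 4) for s in steps])
```

With α = 0.001, θ1 has barely moved from 0 after 20 steps. With α = 0.1, it ends up very close to 2. With α = 0.5, it jumps back and forth across the minimum with growing amplitude. The final list shows shrinking steps even though α never changes.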
Conclusion
Linear regression is a concept every machine learning engineer should know, and it is also the right place to start for anyone who wants to learn machine learning. Drop your comments below or say hello on Twitter. That's all from my side; I hope everything makes sense.