Hello everyone, in this tutorial we’ll cover how to detect and remove outliers in a dataset using the Pandas, Matplotlib, and seaborn libraries in Python.
Here is the GitHub repo containing the full version of the code shown here, along with the train.csv data file (the source of the data used in this tutorial): House Prices Advanced GitHub Repo
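The snippets in this tutorial use a DataFrame named train_data without showing how it is created. Here is a minimal sketch of one way to load it, assuming train.csv can be read with pandas.read_csv. The helper name load_house_data and the in-memory sample rows are made up for illustration so the sketch runs without the real file; they are not taken from the repo.

```python
import io

import pandas as pd


def load_house_data(source):
    """Read only the two columns plotted in this tutorial (hypothetical helper)."""
    return pd.read_csv(source, usecols=["OverallQual", "SalePrice"])


# Made-up stand-in for train.csv so the sketch is self-contained.
sample_csv = io.StringIO(
    "Id,OverallQual,SalePrice\n"
    "1,7,210000\n"
    "2,6,150000\n"
    "3,10,170000\n"
)
train_sample = load_house_data(sample_csv)
print(train_sample.head())
```

With the real file, something like train_data = load_house_data("train.csv") should work, assuming the file sits in the working directory.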
First, let’s plot every data point in a scatter plot of OverallQual against SalePrice. Here, train_data is the DataFrame loaded from train.csv. This can be done using the following code snippet:
import matplotlib.pyplot as plt
import seaborn as sns
plot, ax = plt.subplots(1, 1, figsize = (12, 5))
sns.scatterplot(data = train_data, x = "OverallQual", y = "SalePrice", color = "blue", ax = ax)
Here is the resulting graph:
Figure 1: Initial Scatter Plot Graph
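Next, let’s identify the outliers in the graph above. Here we treat houses with the highest OverallQual rating (10) that sold for 250,000 or less as outliers, then plot the data before and after removing them: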
plot, ax = plt.subplots(1, 2, figsize = (12, 5))
# Flag houses with the top quality rating that sold for 250,000 or less
outliers = (train_data["OverallQual"] == 10) & (train_data["SalePrice"] <= 250000)
# Left plot: original data with the outliers in red
sns.scatterplot(data = train_data, x = "OverallQual", y = "SalePrice", c = ["red" if is_outlier else "blue" for is_outlier in outliers], ax = ax[0])
# Remove the flagged rows, reusing the mask defined above
train_data.drop(train_data[outliers].index, inplace = True)
# Right plot: data with the outliers removed
sns.scatterplot(data = train_data, x = "OverallQual", y = "SalePrice", color = "blue", ax = ax[1])
plt.show()
Here’s an explanation for the code above:
- plt.subplots(1, 2, figsize = (12, 5)): We create two subplots side by side in one Matplotlib figure. The first argument is the number of rows and the second is the number of columns.
- outliers = ...: We identify the outliers with a boolean mask that is True for rows meeting both criteria: train_data["OverallQual"] == 10 and train_data["SalePrice"] <= 250000.
- First sns.scatterplot: We draw the first scatter plot, which displays the original data with the outliers highlighted in red. The highlighting comes from the list comprehension passed to the c option: ["red" if is_outlier else "blue" for is_outlier in outliers].
- train_data.drop(...): Next, we remove the outliers from our dataset via the pandas.DataFrame.drop() method (more on this method in the official docs). Because inplace = True is set, train_data itself is modified.
- Second sns.scatterplot: We draw the second scatter plot, which displays the data with the outliers removed.
- plt.show(): This function displays the figure we just built.
And here’s the resulting graph from running the code snippet above:
Figure 2: The first graph shows the outliers in red; the second graph shows the data with the outliers removed
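To see the mask, the colour list, and the removal step without plotting anything, here is a minimal sketch on a small hand-made DataFrame. The toy values and the variable names toy, colors, cleaned_drop, and cleaned_mask are invented for illustration; only the two column names and the outlier criteria come from the tutorial. It also shows boolean filtering with ~outliers as an alternative to drop().

```python
import pandas as pd

# Hypothetical toy data that mimics the two columns used in the tutorial.
toy = pd.DataFrame({
    "OverallQual": [5, 7, 10, 10, 10, 8],
    "SalePrice": [120000, 200000, 450000, 160000, 184750, 250000],
})

# Same criteria as in the tutorial: top quality rating but a low sale price.
outliers = (toy["OverallQual"] == 10) & (toy["SalePrice"] <= 250000)

# One colour per row, in row order, as passed to the c option in the tutorial.
colors = ["red" if is_outlier else "blue" for is_outlier in outliers]

# Option 1: drop the flagged rows by index label (returns a new DataFrame).
cleaned_drop = toy.drop(toy[outliers].index)

# Option 2: keep only the rows that are not flagged.
cleaned_mask = toy[~outliers]

print(colors)
print(cleaned_drop)
print(cleaned_drop.equals(cleaned_mask))
```

Unlike the tutorial code, this sketch does not use inplace = True, so toy stays unchanged and the cleaned result lives in a separate variable. That can make it easier to compare the data before and after.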
And that’s one way to remove outliers from a dataset using Pandas, Matplotlib, and seaborn in Python!
Conclusion
Well, that's it for this post! Thanks for following along. If you have any questions or concerns, please feel free to post a comment and I will get back to you when I find the time.
If you found this article helpful, please share it, and make sure to follow me on Twitter and GitHub, connect with me on LinkedIn, and subscribe to my YouTube channel.