If you have ever tried machine learning in Java or Scala, you have probably noticed that there is no handy graphing tool comparable to Python's Matplotlib. Have you ever wished you could draw Matplotlib charts from Java?
Matplotlib4j is a library which gives you the power!
sh0nk / matplotlib4j
Matplotlib for java: A simple graph plot library for java, scala and kotlin with powerful python matplotlib
How to Use
Add Library
The examples below are in Java, but the library can of course also be used from other JVM languages such as Scala and Kotlin. A Scala example is given later.
First, add Matplotlib4j to the Java project where you want to use Matplotlib.
For Maven, add the following dependency.
<dependency>
<groupId>com.github.sh0nk</groupId>
<artifactId>matplotlib4j</artifactId>
<version>0.5.0</version>
</dependency>
Similarly, for Gradle (on Gradle versions older than 3.4, use compile instead of implementation):
implementation 'com.github.sh0nk:matplotlib4j:0.5.0'
That's all. Let's begin drawing!
Draw Graphs
The API mirrors Matplotlib's, so the code is intuitive to write. First create a Plot object, call a pyplot-style method on it (such as plot(), hist() or contour()) to add a graph, and finally call the show() method. Since the API follows the Builder pattern, you can chain options after each call and discover them with IDE completion.
Scatter Plot
As a starting point, let's draw a scatter plot.
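import com.github.sh0nk.matplotlib4j.NumpyUtils;
import com.github.sh0nk.matplotlib4j.Plot;
import java.util.List;
import java.util.stream.Collectors;

List<Double> x = NumpyUtils.linspace(-3, 3, 100);
// sin(x) plus uniform random noise in [0, 1)
List<Double> y = x.stream().map(xi -> Math.sin(xi) + Math.random()).collect(Collectors.toList());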
Plot plt = Plot.create();
plt.plot().add(x, y, "o").label("sin");
plt.legend().loc("upper right");
plt.title("scatter");
plt.show();
With the above Java code, we can draw the following graph.
Some NumPy functions, such as linspace and meshgrid, are provided in the NumpyUtils class to help with preparing plot data. The first block generates the x and y data to plot: a sine curve with random noise added. After that, we create a Plot object, pass the generated x and y data to the plot() method, and call show() at the end to draw the graph.
This is almost equivalent to the Python implementation below ("almost" because, strictly speaking, the numpy data generation differs). The method calls are so similar that the API should feel natural if you are a Pythonista.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(-3, 3, 100)
y = np.sin(x) + np.random.rand(100)
plt.plot(x, y, "o", label="sin")
plt.legend(loc="upper right")
plt.title("scatter")
plt.show()
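NumpyUtils.linspace plays the same role as np.linspace here: num evenly spaced values from start to end, both ends included. To make that concrete, here is a hypothetical plain-Java equivalent written only for illustration; it is not the NumpyUtils source, and it follows numpy's documented default of endpoint=True.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper for illustration only; not Matplotlib4j's NumpyUtils source.
// Mirrors numpy.linspace with its default endpoint=True.
public class LinspaceSketch {

    /** Returns num evenly spaced values from start to end, both ends included. */
    static List<Double> linspace(double start, double end, int num) {
        List<Double> values = new ArrayList<>(num);
        if (num == 1) {
            values.add(start);            // numpy returns just the start value
            return values;
        }
        double step = (end - start) / (num - 1);
        for (int i = 0; i < num; i++) {
            values.add(start + i * step);
        }
        if (num > 1) {
            values.set(num - 1, end);     // pin the last value to end exactly, as numpy does
        }
        return values;
    }

    public static void main(String[] args) {
        System.out.println(linspace(-3, 3, 7)); // [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]
    }
}
```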
Contour Plot
Next, let's draw a contour plot (contour lines).
import com.github.sh0nk.matplotlib4j.builder.ContourBuilder;

List<Double> x = NumpyUtils.linspace(-1, 1, 100);
List<Double> y = NumpyUtils.linspace(-1, 1, 100);
NumpyUtils.Grid<Double> grid = NumpyUtils.meshgrid(x, y);
List<List<Double>> zCalced = grid.calcZ((xi, yj) -> Math.sqrt(xi * xi + yj * yj));
Plot plt = Plot.create();
ContourBuilder contour = plt.contour().add(x, y, zCalced);
plt.clabel(contour)
.inline(true)
.fontsize(10);
plt.title("contour");
plt.show();
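The meshgrid and calcZ steps evaluate the function at every (x, y) point of the grid, producing the 2-D z values that contour() draws. Below is a hypothetical plain-Java sketch of that idea; it assumes numpy's default "xy" layout, where row i corresponds to y[i] and column j to x[j]. I have not checked that NumpyUtils.Grid orders its rows the same way, so treat the layout as an assumption.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;

// Hypothetical sketch of meshgrid + calcZ; not the Matplotlib4j implementation.
// Assumes numpy's default "xy" indexing: z.get(i).get(j) = f(x[j], y[i]).
public class GridSketch {

    static List<List<Double>> evaluateOnGrid(List<Double> x, List<Double> y,
                                             BiFunction<Double, Double, Double> f) {
        List<List<Double>> z = new ArrayList<>(y.size());
        for (Double yi : y) {                 // one row per y value
            List<Double> row = new ArrayList<>(x.size());
            for (Double xj : x) {             // one column per x value
                row.add(f.apply(xj, yi));
            }
            z.add(row);
        }
        return z;
    }

    public static void main(String[] args) {
        List<Double> x = List.of(-1.0, 0.0, 1.0);
        List<Double> y = List.of(0.0, 2.0);
        // Same function as the contour example: distance from the origin
        List<List<Double>> z = evaluateOnGrid(x, y, (xi, yj) -> Math.sqrt(xi * xi + yj * yj));
        System.out.println(z); // 2 rows (one per y value), 3 columns (one per x value)
    }
}
```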
Histogram
Histograms can be drawn in the same way.
Random rand = new Random();
List<Double> x1 = IntStream.range(0, 1000).mapToObj(i -> rand.nextGaussian())
.collect(Collectors.toList());
List<Double> x2 = IntStream.range(0, 1000).mapToObj(i -> 4.0 + rand.nextGaussian())
.collect(Collectors.toList());
Plot plt = Plot.create();
plt.hist()
.add(x1).add(x2)
.bins(20)
.stacked(true)
.color("#66DD66", "#6688FF");
plt.xlim(-6, 10);
plt.title("histogram");
plt.show();
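Here bins(20) asks Matplotlib to split the data range into 20 equal-width bins. As a rough illustration of what that counting means, here is a hypothetical sketch modelled on numpy.histogram's convention: the range runs from the data's minimum to its maximum, every bin is half-open except the last, which also includes the maximum. It ignores stacking, weights and colors, and is not how Matplotlib itself is implemented.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

// Hypothetical sketch modelled on numpy.histogram's equal-width binning;
// it is not how Matplotlib or Matplotlib4j is implemented internally.
public class HistogramSketch {

    /** Counts values into `bins` equal-width bins spanning the data's min..max. */
    static int[] countBins(List<Double> data, int bins) {
        double min = data.stream().mapToDouble(Double::doubleValue).min().orElse(0.0);
        double max = data.stream().mapToDouble(Double::doubleValue).max().orElse(0.0);
        if (min == max) {             // like numpy, widen a zero-width range by 0.5 per side
            min -= 0.5;
            max += 0.5;
        }
        double width = (max - min) / bins;
        int[] counts = new int[bins];
        for (double v : data) {
            // Bins are half-open [lo, hi); the maximum value falls into the last bin.
            int index = Math.min((int) ((v - min) / width), bins - 1);
            counts[index]++;
        }
        return counts;
    }

    public static void main(String[] args) {
        Random rand = new Random(42);
        List<Double> x = IntStream.range(0, 1000).mapToObj(i -> rand.nextGaussian())
                .collect(Collectors.toList());
        System.out.println(Arrays.toString(countBins(x, 20)));
    }
}
```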
Save Image to File
Matplotlib4j also supports saving to a file. This is convenient where no GUI is available, such as batch machine learning jobs running on a server.
As in the original Matplotlib, using the .savefig() method instead of .show() saves the image to a file without popping up a plot window. The only difference is that you need to call plt.executeSilently() after .savefig(). Because savefig() can itself be part of a method chain (options such as dpi() are chained onto it), an explicit terminating call is needed to actually execute the commands.
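To see why a chained call needs an explicit terminator, here is a heavily simplified, hypothetical sketch of the pattern. The class and method names are invented for illustration and do not mirror Matplotlib4j's internals: savefig() returns a builder so that dpi() can still be appended, and the Python code is only rendered when the terminal build() call says the chain is finished.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical illustration of a deferred builder chain; not Matplotlib4j's code.
public class DeferredChainSketch {

    /** Collects commands; nothing is rendered until build() is called. */
    static class Script {
        private final List<SavefigCall> calls = new ArrayList<>();

        SavefigCall savefig(String path) {
            SavefigCall call = new SavefigCall(path);
            calls.add(call);
            return call;                      // returned so options can be chained onto it
        }

        /** The terminal step: only now are all chained options known. */
        String build() {
            StringBuilder sb = new StringBuilder("import matplotlib.pyplot as plt\n");
            for (SavefigCall call : calls) {
                sb.append(call.render()).append('\n');
            }
            return sb.toString();
        }
    }

    /** Options may be set after savefig() returns, so rendering must wait. */
    static class SavefigCall {
        private final String path;
        private Integer dpi;                  // null means "use Matplotlib's default"

        SavefigCall(String path) {
            this.path = path;
        }

        SavefigCall dpi(int dpi) {
            this.dpi = dpi;
            return this;
        }

        String render() {                     // no quote escaping: illustration only
            return dpi == null
                    ? "plt.savefig('" + path + "')"
                    : "plt.savefig('" + path + "', dpi=" + dpi + ")";
        }
    }

    public static void main(String[] args) {
        Script script = new Script();
        script.savefig("/tmp/histogram.png").dpi(200); // dpi arrives after savefig()
        System.out.print(script.build());
    }
}
```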
import com.github.sh0nk.matplotlib4j.builder.HistBuilder;

Random rand = new Random();
List<Double> x = IntStream.range(0, 1000).mapToObj(i -> rand.nextGaussian())
.collect(Collectors.toList());
Plot plt = Plot.create();
plt.hist().add(x).orientation(HistBuilder.Orientation.horizontal);
plt.ylim(-5, 5);
plt.title("histogram");
plt.savefig("/tmp/histogram.png").dpi(200);
// Necessary to output the file
plt.executeSilently();
This will output an image like the one below.
Switch Python with pyenv, pyenv-virtualenv
To use Matplotlib4j, you need a Python environment with Matplotlib installed. By default, Matplotlib4j uses the Python found on your PATH, but in many cases Matplotlib is not installed in the system's default Python.
In that case, you can switch to a Python environment with Matplotlib installed, such as Anaconda, using pyenv or pyenv-virtualenv.
To use the Python from a pyenv environment, pass a PythonConfig when creating the Plot object, as follows.
Plot plot = Plot.create(PythonConfig.pyenvConfig("Arbitrary pyenv name"));
Similarly, you can specify the environment name of pyenv-virtualenv.
Plot plot = Plot.create(PythonConfig.pyenvVirtualenvConfig("Arbitrary pyenv name", "Arbitrary virtualenv name"));
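Whichever interpreter you end up with, it can save some head-scratching to confirm that it can actually import Matplotlib before plotting. The helper below is a hypothetical sketch, not part of Matplotlib4j: it runs python -c "import matplotlib" with a given interpreter and checks the exit code. The interpreter path you pass in (for example, one inside your pyenv versions directory) is an assumption about your setup.

```java
import java.io.IOException;
import java.util.List;
import java.util.concurrent.TimeUnit;

// Hypothetical helper, not part of Matplotlib4j.
public class MatplotlibCheck {

    /** Returns true if the given Python interpreter can import matplotlib. */
    static boolean hasMatplotlib(String pythonBin) {
        ProcessBuilder pb = new ProcessBuilder(List.of(pythonBin, "-c", "import matplotlib"));
        pb.redirectErrorStream(true);                       // merge stderr into stdout
        pb.redirectOutput(ProcessBuilder.Redirect.DISCARD); // we only need the exit code
        try {
            Process process = pb.start();
            if (!process.waitFor(30, TimeUnit.SECONDS)) {
                process.destroyForcibly();
                return false;
            }
            return process.exitValue() == 0;                // non-zero means ImportError etc.
        } catch (IOException e) {
            return false;                                   // interpreter missing or not executable
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
    }

    public static void main(String[] args) {
        String python = args.length > 0 ? args[0] : "python3";
        System.out.println(python + " can import matplotlib: " + hasMatplotlib(python));
    }
}
```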
Scala
When used from Scala, the scatter plot example above can be written as follows. The only things to watch out for are boxing/unboxing of numbers and converting between Scala and Java List classes.
import com.github.sh0nk.matplotlib4j.{NumpyUtils, Plot}
import scala.collection.JavaConverters._
val x = NumpyUtils.linspace(-3, 3, 100).asScala.toList
val y = x.map(xi => Math.sin(xi) + Math.random()).map(Double.box)
val plt = Plot.create()
plt.plot().add(x.asJava, y.asJava, "o")
plt.title("scatter")
plt.show()
Need More Details?
On the Tutorial page, you can find more examples, step by step, in Java, Scala and Kotlin.
Extra
How it All Started
I recently started reading a book on deep learning and decided to implement its examples in Scala, which I have been using a lot lately, since simply copying the book's Python code as-is was not much fun. I was happy to be able to write it in a functional style in Scala, but when I got to backpropagation with the steepest descent method, the loss did not drop at all, and I thought, "What's wrong?"
Of course, the usual way to tackle this is to write more tests, but I wanted a quick look at what was going on first, by plotting a graph as the book does. It turned out that there are no good graphing tools in Scala, and implementing one from scratch would be far too much work. So I decided to create a library that uses Matplotlib, a Python library I was already familiar with.
Design
Matplotlib4j calls Matplotlib by generating Python code, without using JNI or Jython. Initially I wanted to implement it with Jython, but Jython only supports Python up to 2.7 and cannot run numpy, so Matplotlib, which depends on numpy, would not work either. I abandoned that path.
There are also libraries that let you use CPython from Java code, and one of them was a candidate because it supports both Python 3 and numpy. However, it required installing a separate, environment-dependent native library for JNI, plus a package from pip on the Python side, which was too much work for something as simple as drawing graphs. In the end, I decided to implement Matplotlib4j without depending on any of these libraries.
Of course, since the generated code is executed via a file, passing variables and using return values took some tricks. Fortunately, since the purpose is only to draw graphs, one-way output to a file covers the basic functionality, and although this adds some latency, I think the performance is within an acceptable range.
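To make the file-based approach more concrete, here is a hypothetical sketch of the general idea, not Matplotlib4j's actual code: render Java data as Python literals, write a complete script to a temporary file, and run it with an external Python process. Only the one-way flow from Java to Python is shown, and the interpreter name "python3" used in main is an assumption about your environment.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.stream.Collectors;

// Hypothetical sketch of the "generate a Python script, then run it" idea.
// Not Matplotlib4j's implementation; all names here are invented for illustration.
public class ScriptRunnerSketch {

    /** Renders numbers as a Python list literal, e.g. [1.0, 2.5]. NaN/Infinity are not handled. */
    static String toPythonList(List<? extends Number> values) {
        return values.stream()
                .map(Object::toString)
                .collect(Collectors.joining(", ", "[", "]"));
    }

    /** Renders a Java string as a single-quoted Python string literal. */
    static String toPythonString(String s) {
        return "'" + s.replace("\\", "\\\\").replace("'", "\\'") + "'";
    }

    /** Builds a self-contained script that plots y against x and saves it to a file. */
    static String buildScript(List<Double> x, List<Double> y, String outputPath) {
        return String.join("\n",
                "import matplotlib",
                "matplotlib.use('Agg')  # non-interactive backend: no window needed",
                "import matplotlib.pyplot as plt",
                "plt.plot(" + toPythonList(x) + ", " + toPythonList(y) + ")",
                "plt.savefig(" + toPythonString(outputPath) + ")",
                "");
    }

    /** Writes the script to a temporary file, runs it, and returns the exit code. */
    static int run(String script, String pythonBin) throws IOException, InterruptedException {
        Path file = Files.createTempFile("plot", ".py");
        try {
            Files.writeString(file, script, StandardCharsets.UTF_8);
            Process process = new ProcessBuilder(pythonBin, file.toString())
                    .inheritIO()                  // show Python's output and errors
                    .start();
            return process.waitFor();
        } finally {
            Files.deleteIfExists(file);
        }
    }

    public static void main(String[] args) throws IOException, InterruptedException {
        String script = buildScript(List.of(1.0, 2.0, 3.0), List.of(1.0, 4.0, 9.0), "/tmp/sketch.png");
        System.out.println("exit code: " + run(script, "python3"));
    }
}
```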