The code below is a Livebook.
Mix.install(
[
{:nx, "~> 0.9.1"},
{:exla, "~> 0.9.1"},
{:kino, "~> 0.14.2"},
{:zigler, "~> 0.13.3"}
],
config: [nx: [default_backend: EXLA.Backend]]
)
Nx.Defn.global_default_options(compiler: EXLA, client: :host)
Warning
If you run into an error above, this probably means that you don't have Zig installed. In that case, comment out the Zigler package above.
You then won't be able to run the very last module, where we use inline Zig code.
Introduction
We want to produce an image that represents the beautiful Mandelbrot set
Source: https://en.wikipedia.org/wiki/Mandelbrot_set
We use the Nx library with the EXLA backend to speed up the computations.
We also propose to run the equivalent code in Zig in this Livebook if you want extra speed. This happens thanks to the Zigler library. The Zig code returns a binary that Nx is able to consume and Kino to display.
What is a Mandelbrot set?
In a "Mandelbrot image", each pixel has a colour representing how fast the underlying point "escapes" when calculating its iterates under a certain function.
What is an underlying point?
A pixel has coordinates (i,j), where i is the row number and j the column number. For example, in a 1024 × 768 image (WIDTH x HEIGHT), the row number varies from 0 to 767 and the column number from 0 to 1023.
We transform these pairs of integers (i,j) into points of a 2D plane. This map discretises the 2D plane.
Here, the 2D "real" plane is defined by its upper left corner, say (-2,1), and its bottom right corner, say (1,-1).
How? We use a linear mapping between the pair (i,j) and a point (x,y) in the defined zone. For example, the pixel (0,0) becomes (-2,1) and the pixel (767, 1023) becomes (1,-1).
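To make the linear mapping concrete, here is a minimal sketch in plain Elixir (no Nx). The module name `PlainPixel` and the shape of its arguments are assumptions made for this illustration; the numerical version used for the image comes later.

```elixir
defmodule PlainPixel do
  # Hypothetical helper, for illustration only.
  # Maps the pixel {row, col} of an image with h rows and w columns onto the
  # rectangle given by its top-left {x0, y0} and bottom-right {x1, y1} corners.
  def map({row, col}, {h, w}, {{x0, y0}, {x1, y1}}) do
    scale_x = (x1 - x0) / (w - 1)
    scale_y = (y1 - y0) / (h - 1)
    {x0 + col * scale_x, y0 + row * scale_y}
  end
end

corners = {{-2, 1}, {1, -1}}

# the top-left pixel lands on the top-left corner
PlainPixel.map({0, 0}, {768, 1024}, corners)
# => {-2.0, 1.0}

# the bottom-right pixel lands on (1, -1), up to floating-point rounding
PlainPixel.map({767, 1023}, {768, 1024}, corners)
```

The column index drives the x coordinate and the row index drives the y coordinate, which is why the scale factors divide by w - 1 and h - 1 respectively.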
What is iterating?
We will iterate the function: x -> x*x + c where c is a given number and x the variable.
We start with:
z0 = f(0) = c
z1 = f(z0) = z0 * z0 + c
z2 = f(z1) = z1 * z1 + c
...
Let's take an example. The module below calculates the iterates x(n) = f(x(n-1)) by a simple recursion.
The set of these iterates is called the orbit of c.
defmodule Simple do
def p(x,c), do: x**2 + c
# initial value
def iterate(1,c), do: c
# the n-th step
def iterate(n,c), do: p(iterate(n-1, c), c)
end
We calculate the first elements of the orbit of the point c=1 to see how it behaves. It looks like it diverges to infinity.
c = 1
{ c,
Simple.iterate(1,c), Simple.iterate(2,c), Simple.iterate(3,c), Simple.iterate(4,c), Simple.iterate(5,c), Simple.iterate(6,c)
}
gives:
{1, 1, 2, 5, 26, 677, 458330}
On the other hand, the point c=-1 seems well behaved: the orbit has only two values, 0 and -1, and is periodic.
c = -1
{ c,
Simple.iterate(1,c), Simple.iterate(2,c), Simple.iterate(3,c), Simple.iterate(4,c), Simple.iterate(5,c), Simple.iterate(6,c)
}
gives:
{-1, -1, 0, -1, 0, -1, 0}
In the examples above, we took a simple "real" number.
For the Mandelbrot set, we use the complex representation of a point: (x,y) -> x + y*i where i is the imaginary unit (i * i = -1).
So, each pixel (i,j) is mapped to a complex number c = projection(i,j), and we want to evaluate how the iterates of c behave under this iteration started from 0.
Iteration number?
We are interested in assigning an iteration number to each c.
The number of iterations needs to be bounded (think of a periodic orbit, which never escapes). Let max_iter be the maximum number of iterations, for example 100.
If the orbit of c remains bounded, we assign it the iteration number max_iter.
If it escapes, meaning that one iterate has a norm greater than 2, we assign it the first index at which the norm of the iterate (its absolute value as a complex number) exceeds 2.
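The rule above can be sketched in plain Elixir before moving to tensors. This is only an illustration under stated assumptions: the module `PlainOrbit` is hypothetical, complex numbers are represented as `{re, im}` tuples, and iterates are indexed from 0 with the first iterate equal to c.

```elixir
defmodule PlainOrbit do
  # Hypothetical helper, for illustration only.
  # One step of z -> z*z + c on {re, im} tuples.
  def step({x, y}, {cx, cy}), do: {x * x - y * y + cx, 2 * x * y + cy}

  # Squared norm: comparing it with 4 avoids taking a square root.
  def sq_norm({x, y}), do: x * x + y * y

  # Index of the first iterate whose norm exceeds 2 (the first iterate is c,
  # with index 0), or max_iter when no iterate escapes within max_iter steps.
  def number(c, max_iter), do: loop(c, c, 0, max_iter)

  defp loop(_z, _c, n, max_iter) when n >= max_iter, do: max_iter

  defp loop(z, c, n, max_iter) do
    if sq_norm(z) > 4, do: n, else: loop(step(z, c), c, n + 1, max_iter)
  end
end

PlainOrbit.number({2, 2}, 100)   # => 0, c is already outside the disk of radius 2
PlainOrbit.number({1, 0}, 100)   # => 2, the iterates are 1, 2, 5
PlainOrbit.number({-1, 0}, 100)  # => 100, the periodic orbit never escapes
```

The guard clause on n is what keeps a periodic or bounded orbit from looping forever.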
Complex calculus interface
We will use two types of functions:
- Elixir functions using def
- Nx functions using defn; these use a special backend (EXLA with CPU or GPU if any)
The points of the 2D plane will be represented as complex numbers as the Mandelbrot map works with complex numbers.
The function z(n+1) = z(n) * z(n) + c takes a complex number and returns a complex number.
Below is a helper module to work with complex numbers in numerical Elixir.
We use numerical functions, declared with defn. All the arguments are treated as tensors.
defmodule Ncx do
import Nx.Defn
defn i() do
Nx.Constants.i()
end
# primitive to build a complex scalar tensor
defn new(x,y) do
x + i() * y
end
# square norm
defn sq_norm(z) do
Nx.conjugate(z) |> Nx.dot(z) |> Nx.real()
end
end
Algorithm
Source: https://en.wikipedia.org/wiki/Plotting_algorithms_for_the_Mandelbrot_set
Input: image dimensions (e.g. w x h of 1500 x 1000), max iteration (e.g. 100)
Iterate over each pixel (i,j):
- map it into the 2D plane: compute its "complex coordinates"
- compute the iteration number
- compute a colour
- assemble the results and draw from the final tensor with Kino.
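As a rough, dependency-free illustration of these steps, the sketch below renders the set as text in plain Elixir. Everything in it is an assumption made for the example: the module name `AsciiMandelbrot`, the hard-coded corners (-2,1) and (1,-1), and the two-character "palette" that stands in for real colours.

```elixir
defmodule AsciiMandelbrot do
  # Hypothetical sketch, for illustration only.
  # Returns a list of h strings of w characters each.
  def render(h, w, max_iter) do
    for row <- 0..(h - 1) do
      for col <- 0..(w - 1), into: "" do
        # 1. map the pixel into the plane, corners (-2, 1) and (1, -1)
        c = {-2.0 + col * 3.0 / (w - 1), 1.0 - row * 2.0 / (h - 1)}
        # 2. and 3. iterate, then pick a "colour": blank if the orbit escapes
        if escapes?(c, max_iter), do: " ", else: "#"
      end
    end
  end

  defp escapes?(c, max_iter), do: escapes?(c, c, 0, max_iter)

  defp escapes?(_z, _c, n, max_iter) when n >= max_iter, do: false

  defp escapes?({x, y}, {cx, cy} = c, n, max_iter) do
    if x * x + y * y > 4 do
      true
    else
      escapes?({x * x - y * y + cx, 2 * x * y + cy}, c, n + 1, max_iter)
    end
  end
end

# 4. assemble and draw: here we simply print the rows
AsciiMandelbrot.render(21, 61, 50) |> Enum.each(&IO.puts/1)
```

The Nx version follows the same four steps, but works on a whole tensor of pixels at once and produces RGB values instead of characters.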
Pixel to complex plane mapping
This module transforms a pair (i,j) into a complex number.
Notice that once you are in a numerical function, the arguments become "tensors", and a tensor can be of the complex type "c64". Nx natively understands complex numbers.
defmodule Pixel do
import Nx.Defn
defn map(index, {h,w}, {top_left_x, top_left_y, bottom_right_x,bottom_right_y}) do
scale_x = Nx.divide(bottom_right_x-top_left_x, w-1)
scale_y = Nx.divide(bottom_right_y-top_left_y, h-1)
# building a complex typed tensor
Ncx.new(
top_left_x + Nx.dot(index[1],scale_x),
top_left_y + Nx.dot(index[0], scale_y)
)
end
end
Orbit and iteration number
This module computes the iteration number for a given input c.
If |c|>2, then this point is unstable. Otherwise, we have to compute for each point whether it stays bounded or not.
If it is bounded, we get a value close to max_iter (max_iter or max_iter - 1, depending on the branch taken in the code below), otherwise a lower value.
It also uses numerical functions via defn.
We cannot use the recursive form we used earlier because numerical functions don't accept multiple function clauses the way plain Elixir does. Instead we run a special while loop. Note how we use the Nx versions of while and cond, and how the double condition is handled by Nx.logical_and. Also, true is 1.
defmodule Orbit do
import Nx.Defn
defn poly(z,c) do
z*z + c
end
defn number(c,max_iter) do
condition = (Nx.real(c) +1) ** 2 + (Nx.imag(c)**2)
cond do
# points in the period-2 bulb (disk of radius 1/4 centred at -1) are all stable. Save on iterations
Nx.less(condition, 0.0625) ->
max_iter
# these points are unbounded whenever the norm is > 2
Nx.greater(Ncx.sq_norm(c), 4) ->
0
# any other point of the disk of radius 2 may or may not be bounded: we have to iterate
1 ->
{_, _, j} =
while {z=c, c, i=max_iter}, Nx.logical_and(Nx.greater(i,1), Nx.less(Ncx.sq_norm(z), 4)) do
{poly(z,c), c,i-1}
end
max_iter - j
end
end
end
Examples:
st = Ncx.new(0.2, 0.2)
dv1 = Ncx.new(0.4, 0.4)
dv2 = Ncx.new(0.3, 0.6)
dv3 = Ncx.new(2,2)
iter_max = 100
iter_dv1 = Orbit.number(dv1, iter_max) #<- we should find 8 iterations before z_n escapes from the disk of radius 2
iter_dv2 = Orbit.number(dv2, iter_max) #<- we should find 14 iterations before z_n escapes from the disk of radius 2
iter_dv3 = Orbit.number(dv3, iter_max) #<- this point is already outside the disk of radius 2
iter_st = Orbit.number(st, iter_max) #<- this point is stable and the loop runs until the iteration budget is exhausted
%{
"unstable/2: #{Nx.to_number(dv2)}" => iter_dv2 |> Nx.to_number(),
"unstable/1: #{Nx.to_number(dv1)}" => iter_dv1 |> Nx.to_number(),
"out_of_disk2: #{Nx.to_number(dv3)}" => iter_dv3 |> Nx.to_number(),
"stable: #{Nx.to_number(st)}" => iter_st |> Nx.to_number()
}
%{
"out_of_disk2: 2.0+2.0i" => 0,
"stable: 0.20000000298023224+0.20000000298023224i" => 99,
"unstable/1: 0.4000000059604645+0.4000000059604645i" => 8,
"unstable/2: 0.30000001192092896+0.6000000238418579i" => 14
}
A Colour palette
Each iteration number is an integer n. We want to associate a colour [r(n),g(n),b(n)] with it.
This will help us visualise which points of the complex plane are stable and, if not, how fast they escape.
The choice below is just an example. Other choices can be made.
defmodule Colour do
import Nx.Defn
defn normalize(n, max_iter) do
n / max_iter
end
defn rgb(n) do
cond do
Nx.equal(n, 0) ->
Nx.stack([255, 255, 0]) |> Nx.as_type(:u8)
Nx.less(n, 0.5) ->
scaled = n * 2
r = 255 * (1 - scaled)
g = 255 * (1 - scaled/2)
b = 127 * scaled
Nx.stack([r, g, b]) |> Nx.as_type(:u8)
true ->
scaled = (n - 0.5) * 2;
r = 255*(1+scaled/2)
g = 128 * (1+scaled/2)
b = 255 * (1 - scaled)
Nx.stack([r, g, b]) |> Nx.as_type(:u8)
end
end
end
Computing the Mandelbrot set
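We will now reassemble our modules.
First, an example.
dim = {500,500}; iter_max = 100
# assumption: the corners are not defined in this cell, so we reuse the values of the
# Mandelbrot module below, as {top_left_x, top_left_y, bottom_right_x, bottom_right_y};
# with other corners the results will differ
defining_points = {-2, 1.2, 0.6, -1.2}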
p = Nx.tensor([30,40])
c_i_j = Pixel.map(p,dim, defining_points)
n_i_j = Orbit.number(c_i_j, iter_max)
nm_i_j = Colour.normalize(n_i_j, iter_max)
{Nx.to_number(n_i_j), Colour.rgb(nm_i_j)} |> dbg()
p = Nx.tensor([40,70])
c_i_j = Pixel.map(p,dim, defining_points)
n_i_j = Orbit.number(c_i_j, iter_max)
nm_i_j = Colour.normalize(n_i_j, iter_max)
{Nx.to_number(n_i_j), Colour.rgb(nm_i_j)} |> dbg()
We found that this pixel is mapped to a point of the complex plane that escapes rather quickly from the disk of radius 2. It gets stamped with some colour.
{4,
#Nx.Tensor<
u8[3]
EXLA.Backend<host:0, 0.3807096825.1655832596.50891>
[234, 244, 10]
>}
The final module
We then reassemble the tensor into the desired format for Kino to consume and display it.
Note that if you want to pass arguments to a defn function that should not be treated as tensors, you need to pass them as a keyword list or a map.
defmodule Mandelbrot do
import Nx.Defn
defn compute(opts) do
top_left_x = -2; top_left_y = 1.2; bottom_right_x = 0.6; bottom_right_y = - 1.2;
defining_points = {top_left_x, top_left_y, bottom_right_x, bottom_right_y}
h = opts[:h]
w = opts[:w]
max_iter = opts[:max_iter]
# build the tensor of all index pairs [[0,0], ..., [0,w-1], [1,0], ..., [h-1,w-1]]. Thanks to PValente
iota_rows = Nx.iota({h}, type: :u16) |> Nx.vectorize(:rows)
iota_cols = Nx.iota({w}, type: :u16) |> Nx.vectorize(:cols)
cross_product = Nx.stack([iota_rows, iota_cols])
Pixel.map(cross_product,{h,w}, defining_points)
|> Orbit.number(max_iter)
|> Colour.normalize(max_iter)
|> Colour.rgb()
|> Nx.devectorize()
|> Nx.reshape({h, w, 3})
|> Nx.as_type(:u8)
end
end
Depending on your machine, the computation below can be lengthy.
If you simply want to try it out, set h = w = 400.
h = w = 400;
Mandelbrot.compute(h: h, w: w, max_iter: 100)
|> Kino.Image.new()
Parallelise it with async_stream
When the resolution of the image increases, it becomes worthwhile to parallelise the computations.
We divide the image into horizontal bands, as many as the number of CPU cores on the machine.
When you use async_stream, the BEAM - the VM that runs this code - runs the tasks in parallel on the cores.
This is only worth it if the image is large enough, as it comes with a non-negligible overhead.
We also set ordered: true as we need to assemble the results in order.
Another possible optimisation is to notice that the image is symmetric. You can compute half of the image (redefine h to be h - rem(h, cpus*2)), but you would then need to reverse a tensor to build the other half.
defmodule StreamMandelbrot do
import Nx.Defn
@doc """
Example: 42 rows, 8 cpus
42 rows = 8cpus * 5 + 2
We run 8 threads consuming 5 rows each
We just ignore the last 2 rows.
"""
def run(%{h: h, w: w} = opts) do
cpus = :erlang.system_info(:logical_processors_available)
# we eliminate a few rows from the final image, cpus - 1 at most.
h = h - rem(h,cpus)
rows_per_cpu = div(h, cpus)
Task.async_stream(0..cpus-1, fn cpu_count ->
# we shift the start index by the number of rows already consumed
iota_rows = Nx.iota({rows_per_cpu}, type: :u16) |> Nx.add(cpu_count * rows_per_cpu)|> Nx.vectorize(:rows)
# full width
iota_cols = Nx.iota({w}, type: :u16) |> Nx.vectorize(:cols)
cross_product = Nx.stack([iota_rows, iota_cols])
Nx.Defn.jit_apply(fn t ->
compute(t, opts) end, [cross_product])
end,
timeout: :infinity, ordered: true
)
|> Enum.map(fn {:ok, t} -> t end) #&elem(&1, 1)
|> Nx.stack()
|> Nx.reshape({h,w,3})
end
defn compute(cross_product, %{h: h, w: w, max_iter: max_iter}) do
top_left_x = -2; top_left_y = 1.2; bottom_right_x = 0.6; bottom_right_y = -1.2;
defining_points = {top_left_x, top_left_y, bottom_right_x, bottom_right_y}
Pixel.map(cross_product,{h,w}, defining_points)
|> Orbit.number(max_iter)
|> Colour.normalize(max_iter)
|> Colour.rgb()
|> Nx.devectorize()
|> Nx.as_type(:u8)
end
end
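The band arithmetic and the ordered collection of results can be checked in isolation, without Nx. The sketch below is a simplified stand-in, not the module above: `Bands` and its functions are hypothetical names, the work done per band is an arbitrary function, and we assume h is at least as large as the number of workers.

```elixir
defmodule Bands do
  # Hypothetical helper, for illustration only.
  # Splits h rows into `workers` equal bands and drops the rem(h, workers)
  # trailing rows. Assumes h >= workers.
  def ranges(h, workers) do
    rows_per_band = div(h, workers)

    for k <- 0..(workers - 1) do
      first = k * rows_per_band
      first..(first + rows_per_band - 1)
    end
  end

  # Runs `fun` on each band in its own task and returns the results
  # in band order, thanks to ordered: true.
  def run(h, workers, fun) do
    ranges(h, workers)
    |> Task.async_stream(fun, ordered: true, timeout: :infinity)
    |> Enum.map(fn {:ok, result} -> result end)
  end
end

# 42 rows on 8 workers: 8 bands of 5 rows, the last 2 rows are ignored
Bands.ranges(42, 8)
# => [0..4, 5..9, 10..14, 15..19, 20..24, 25..29, 30..34, 35..39]

# any per-band work can be plugged in, here the sum of the row indices
Bands.run(42, 8, &Enum.sum/1)
# => [10, 35, 60, 85, 110, 135, 160, 185]
```

Because the results come back in band order, they can be concatenated directly to rebuild the image from top to bottom.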
When we run the code, we get results much faster. On my machine, it took 44s to draw a 1M-pixel image. We get the expected performance boost.
h= w = 400;
StreamMandelbrot.run( %{h: h, w: w, max_iter: 200})
|> Kino.Image.new()
Run embedded Zig code
If we still need or want extra speed, we can also embed Zig code in Elixir within a Livebook.
Zigler offers remarkable documentation.
❗ You may need to have Zig installed on your machine.
In the Livebook, we add the dependencies (in the first cell):
Mix.install([{:zigler, "~> 0.13.3"},{:zig_get, "~> 0.13.1"}])
With Zigler, we can even inline Zig code.
The code below runs the same algorithm and spawns OS threads for concurrency.
- we use the beam memory allocator from the library.
- the slice is returned as a binary - typed as beam.term - to be easily consumed by Nx and then Kino.
defmodule Zigit do
use Zig, otp_app: :zigler,
nifs: [..., generate_mandelbrot: [:threaded]]
# release_mode: :fast
~Z"""
const beam = @import("beam");
const std = @import("std");
const Cx = std.math.Complex(f64);
const topLeft = Cx{ .re = -2.1, .im = 1.2 };
const bottomRight = Cx{ .re = 0.6, .im = -1.2 };
const w = bottomRight.re - topLeft.re;
const h = bottomRight.im - topLeft.im;
const Context = struct {res_x: usize, res_y: usize, imax: usize};
/// nif: generate_mandelbrot/3 Threaded
pub fn generate_mandelbrot(res_x: usize, res_y: usize, max_iter: usize) !beam.term {
const pixels = try beam.allocator.alloc(u8, res_x * res_y * 3);
defer beam.allocator.free(pixels);
const resolution = Context{ .res_x = res_x, .res_y = res_y, .imax = max_iter };
const res = try createBands(pixels, resolution);
return beam.make(res, .{ .as = .binary });
}
// <--- threaded version
fn createBands(pixels: []u8, ctx: Context) ![]u8 {
const cpus = try std.Thread.getCpuCount();
var threads = try beam.allocator.alloc(std.Thread, cpus);
defer beam.allocator.free(threads);
// half of the total rows
const rows_to_process = ctx.res_y / 2 + ctx.res_y % 2;
// one band is one count of cpus
// const nb_rows_per_band = rows_to_process / cpus + rows_to_process % cpus;
const rows_per_band = (rows_to_process + cpus - 1) / cpus;
// number of threads actually spawned (the loop may stop early)
var spawned: usize = 0;
for (0..cpus) |cpu_count| {
const start_row = cpu_count * rows_per_band;
// Stop if there are no rows to process
if (start_row >= rows_to_process) break;
const end_row = @min(start_row + rows_per_band, rows_to_process);
const args = .{ ctx, pixels, start_row, end_row };
threads[cpu_count] = try std.Thread.spawn(.{}, processRows, args);
spawned += 1;
}
// join only the threads that were spawned; the other slots are uninitialised
for (threads[0..spawned]) |thread| {
thread.join();
}
return pixels;
}
fn processRows(ctx: Context, pixels: []u8, start_row: usize, end_row: usize) void {
for (start_row..end_row) |current_row| {
processRow(ctx, pixels, current_row);
}
}
fn processRow(ctx: Context, pixels: []u8, row_id: usize) void {
// Calculate the symmetric row
const sym_row_id = ctx.res_y - 1 - row_id;
if (row_id <= sym_row_id) {
// loop over columns
for (0..ctx.res_x) |col_id| {
const c = mapPixel(.{ @as(usize, @intCast(row_id)), @as(usize, @intCast(col_id)) }, ctx);
const iter = iterationNumber(c, ctx.imax);
const colour = createRgb(iter, ctx.imax);
const p_idx = (row_id * ctx.res_x + col_id) * 3;
pixels[p_idx + 0] = colour[0];
pixels[p_idx + 1] = colour[1];
pixels[p_idx + 2] = colour[2];
// Process the symmetric row (if it's different from current row)
if (row_id != sym_row_id) {
const sym_p_idx = (sym_row_id * ctx.res_x + col_id) * 3;
pixels[sym_p_idx + 0] = colour[0];
pixels[sym_p_idx + 1] = colour[1];
pixels[sym_p_idx + 2] = colour[2];
}
}
}
}
fn mapPixel(pixel: [2]usize, ctx: Context) Cx {
const px_width = ctx.res_x - 1;
const px_height = ctx.res_y - 1;
const scale_x = w / @as(f64, @floatFromInt(px_width));
const scale_y = h / @as(f64, @floatFromInt(px_height));
const re = topLeft.re + scale_x * @as(f64, @floatFromInt(pixel[1]));
const im = topLeft.im + scale_y * @as(f64, @floatFromInt(pixel[0]));
return Cx{ .re = re, .im = im };
}
fn iterationNumber(c: Cx, imax: usize) ?usize {
if (c.re > 0.6 or c.re < -2.1) return 0;
if (c.im > 1.2 or c.im < -1.2) return 0;
// period-2 bulb: disk of radius 1/4 centred at -1
if ((c.re + 1) * (c.re + 1) + c.im * c.im < 0.0625) return null;
var z = Cx{ .re = 0.0, .im = 0.0 };
for (0..imax) |j| {
if (sqnorm(z) > 4) return j;
z = Cx.mul(z, z).add(c);
}
return null;
}
fn sqnorm(z: Cx) f64 {
return z.re * z.re + z.im * z.im;
}
fn createRgb(iter: ?usize, imax: usize) [3]u8 {
// If it didn't escape, return black
if (iter == null) return [_]u8{ 0, 0, 0 };
// Normalize time to [0,1[ now that we know it isn't "null"
const normalized = @as(f64, @floatFromInt(iter.?)) / @as(f64, @floatFromInt(imax));
if (normalized < 0.5) {
const scaled = normalized * 2;
return [_]u8{ @as(u8, @intFromFloat(255 * (1 - scaled))), @as(u8, @intFromFloat(255.0 * (1 - scaled / 2))), @as(u8, @intFromFloat(127 + 128 * scaled)) };
} else {
const scaled = (normalized - 0.5) * 2.0;
return [_]u8{ 0, @as(u8, @intFromFloat(127 * (1 - scaled / 2))), @as(u8, @intFromFloat(255 * (1 - scaled))) };
}
}
"""
end
We run the Zig code. It returns a binary that we are able to consume with Nx and display as an image.
To draw an image of 1M pixels, it takes a few milliseconds. Feels like magic.
h = w = 5_000
max_iter = 300;
Zigit.generate_mandelbrot(h, w, max_iter)
|> Nx.from_binary(:u8)
|> Nx.reshape({h, w, 3})
|> Kino.Image.new()