Parallel Matrix Multiplication in Rust

Much of linear algebra revolves around matrix multiplication. Since this is an O(n^3) operation, techniques beyond the basic algorithm are needed to improve its performance. Today, I will show how I parallelized matrix multiplication, improving speed by 4x on my 12-thread machine.

In my examples, I am using the nalgebra crate to handle matrix storage (which stores dynamically sized matrices as a Vec in column-major order). This isn't too important; the only thing that really matters is the ability to index the matrix by row and column index.
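
To make the layout concrete, here is a minimal sketch of constructing and indexing a DMatrix (the dimensions and values are just for illustration):

use nalgebra::DMatrix;

// from_vec fills the matrix in column-major order
let m = DMatrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

assert_eq!(m.shape(), (2, 3)); // (rows, columns)
assert_eq!(m[(0, 0)], 1.0); // first column starts at the front of the Vec
assert_eq!(m[(0, 1)], 3.0); // second column starts at index 2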

Single-Threaded Implementation

To multiply matrices, we iterate over the columns of the rhs, and for each column, iterate over the rows of the lhs. We then zip the lhs row with the rhs column and take the dot product of the two. Iterating columns-first means the results come out in column-major order, matching nalgebra's storage.

let l_shape = lhs.shape();
let r_shape = rhs.shape();

// check for shape compatibility here...
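// for example, a minimal check could look like this (zip below silently
// stops at the shorter range, so this guard matters):
assert_eq!(l_shape.1, r_shape.0, "lhs columns must equal rhs rows");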

// the multiplication
let result: Vec<f64> = (0..r_shape.1).flat_map(move |rj| {
    // `map`, not `flat_map`: this closure yields a single f64 per lhs row
    (0..l_shape.0).map(move |li| {
        // dot product of lhs row `li` with rhs column `rj`
        (0..r_shape.0)
            .zip(0..l_shape.1)
            .map(move |(ri, lj)| {
                lhs.index((li, lj)) * rhs.index((ri, rj))
            })
            .sum::<f64>()
    })
})
.collect();

// result is a vec in column-major order
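For reference, the whole thing could be wrapped into a self-contained function like this (a sketch; matmul is my name for it, and DMatrix::from_vec expects exactly the column-major Vec we produced):

use nalgebra::DMatrix;

fn matmul(lhs: &DMatrix<f64>, rhs: &DMatrix<f64>) -> DMatrix<f64> {
    let l_shape = lhs.shape();
    let r_shape = rhs.shape();
    assert_eq!(l_shape.1, r_shape.0, "lhs columns must equal rhs rows");

    let result: Vec<f64> = (0..r_shape.1)
        .flat_map(move |rj| {
            (0..l_shape.0).map(move |li| {
                (0..r_shape.0)
                    .zip(0..l_shape.1)
                    .map(move |(ri, lj)| lhs.index((li, lj)) * rhs.index((ri, rj)))
                    .sum::<f64>()
            })
        })
        .collect();

    // from_vec consumes the data in column-major order, matching our layout
    DMatrix::from_vec(l_shape.0, r_shape.1, result)
}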

Parallelizing

I will now use rayon to parallelize this operation. It's way simpler than you may think!

use rayon::prelude::*;

let result: Vec<f64> = (0..r_shape.1).into_par_iter().flat_map(move |rj| {
    // as before, the inner closure yields a single f64 per lhs row
    (0..l_shape.0).into_par_iter().map(move |li| {
        (0..r_shape.0)
            .zip(0..l_shape.1)
            .map(move |(ri, lj)| {
                lhs.index((li, lj)) * rhs.index((ri, rj))
            })
            .sum::<f64>()
    })
})
.collect();


That's it! Just adding into_par_iter after the ranges (plus importing rayon's prelude) is all that is needed. I benchmarked this on an Intel i5-10600K (12 threads) @ 4.8 GHz, multiplying a 1000x1000 matrix by a 1000x1 vector, and the average execution time went from 8 ms down to 2 ms. That is a 75% reduction in execution time, or a 4x speedup.
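
If you want to reproduce a measurement like this, a criterion benchmark is one option. This is a sketch: it assumes criterion as a dev-dependency and the matmul wrapper from earlier (you would register one bench function per version to compare them):

use criterion::{criterion_group, criterion_main, Criterion};
use nalgebra::DMatrix;

// assumes the `matmul` sketch from earlier is importable from your crate
fn bench_matmul(c: &mut Criterion) {
    // same sizes as the benchmark in this post: 1000x1000 times 1000x1
    let lhs = DMatrix::<f64>::from_element(1000, 1000, 1.0);
    let rhs = DMatrix::<f64>::from_element(1000, 1, 1.0);

    c.bench_function("matmul 1000x1000 * 1000x1", |b| {
        b.iter(|| matmul(&lhs, &rhs))
    });
}

criterion_group!(benches, bench_matmul);
criterion_main!(benches);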
