This post describes how I added automatic differentiation to Tensorken. Tensorken is my attempt to build a fully featured yet easy-to-understand and hackable implementation of a deep learning library in Rust. It takes inspiration from the likes of PyTorch, Tinygrad, and JAX.

Tensorken's approach to automatic differentiation (or AD) is heavily inspired by JAX. Like JAX, Tensorken supports higher-order derivatives - besides the first derivative, it can calculate the second, third, and so on. Tensorken supports both forward and reverse-mode AD, and can arbitrarily compose the two. Finally, thanks to good fundamentals explained in the previous two posts (part 1 and part 2), Tensorken can compute derivatives on the CPU or GPU.

All code for this post is in the Tensorken repository, tagged v0.3.

*There's a "it's turtles all the way down" reference somewhere in this post, and only then will this image make sense. Generated by Lexica's Aperture model.*

## Previously in Tensors From Scratch: neural networks, matrix multiplication, and GPU acceleration

Modern neural networks, for example, large language models (LLMs) like OpenAI's ChatGPT and GPT-4, Microsoft's Bing, Google's Bard, and Anthropic's Claude, are powered by tensors. Tensors are multi-dimensional arrays augmented with operations that execute efficiently on modern hardware, most notably GPUs.

To understand all that, I am building a neural network library like PyTorch or JAX, from the ground up in Rust. These libraries consist of:

- A tensor library, to provide efficient operations to slice, dice, and apply bulk operations to tensors.
- Accelerators, to accelerate tensor operations on the GPU.
- Automatic differentiation, to train neural networks via gradient descent.
- Neural network building blocks, to simplify using common activation functions and layers.

In the first post, I focussed on the tensor library. I described almost twenty fundamental tensor operations and abstracted them in a Rust trait called `RawTensor`. `RawTensor` had a single implementation, `CpuRawTensor`, which executes tensor computations on the CPU. In the second post, I implemented `RawTensor` again in `WgpuRawTensor` to execute on the GPU using wgpu, Rust's implementation of WebGPU. We dove into the nitty-gritty of GPU programming in general and wgpu in particular.

This third part of the series describes how to add automatic differentiation to Tensorken. Automatic differentiation (AD) is a technique to compute derivatives of tensor computations, without programmer intervention. AD is crucial because neural networks are trained via gradient descent, which relies on the efficient calculation of derivatives.

## How to train your neural network

Let's sketch how to train a neural network to emphasize how important AD is for deep learning.

First, gather training data, and lots of it. Training data are lots of input-expected output pairs. The input examples are encoded as numbers and aggregated in a tensor X. The outputs go in a tensor y. Think of y as the correct predictions for the inputs X. For a language model, each example in X could be a sequence of words, and y the next word, encoded as numbers. (How to encode text as numbers is an interesting problem that's not relevant to this story.)

Second, decide on the architecture of your neural network. A neural network consists of tensors Wᵢ that contain the parameters of the network. That's what you download when you get a model's weights. The architecture determines how many parameters we have and how we combine the input X with parameters Wᵢ to obtain an output y'. Whatever the architecture is, we can execute it to predict y':

y' = f(X, Wᵢ).

I'm simplifying - researchers distinguish weights W and biases b, but in the end, they're both part of the trainable parameters so I'm just lumping them together in the Ws.

Third, using the expected y and the prediction y', calculate the loss ℒ. The loss is a single number that is high when the prediction is bad, and low when it is good. The loss is calculated by comparing the network's prediction y' with the expected output y:

ℒ = l(y, y').

Fourth, calculate the gradient G of the loss. The scalar loss value ℒ is a function of X, Wᵢ, and y. Imagine the loss function as describing a (highly dimensional!) landscape. Training the network to improve its predictions means changing the parameters Wᵢ to make the loss small. We'd like to know *how* we should change the parameters to achieve that.

Now is when the gradient comes in. Going back to the landscape analogy, to make the loss smaller we'd like to know the best direction to "move" in to go "down" - that is, from the current value of the parameters, find the direction with the highest slope. If you remember some calculus, the derivative of a function at a point is that slope. So, to calculate the gradient, we calculate the loss function's derivative with respect to each parameter Wᵢ. In other words, we'll have a number for each parameter that tells us how to change that parameter to make the loss smaller.

Fifth, update the parameters using the gradient. There are many ways of doing this. The simplest is to multiply the gradient with a small number ϵ and subtract it from the parameters:

Wᵢ <- Wᵢ − ϵGᵢ.

That's one training step done! Your neural network just got a tiny bit better. Now repeat from step 2 until you've had enough. You can stop when the loss becomes small enough, when it stops changing for some number of iterations, or when your AWS bill exceeds the budget.
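To make the five steps concrete, here is a minimal sketch in plain Rust, without any tensor library. It is not Tensorken code: the model is a single-weight line `y' = w * x`, the loss is the mean squared error, and the gradient is derived by hand, which is exactly the tedious part that AD removes. The names `predict`, `loss`, `loss_grad`, and `train` are made up for this illustration.

```rust
/// Step 2: the "architecture" is a single weight: y' = w * x.
fn predict(w: f64, x: f64) -> f64 {
    w * x
}

/// Step 3: mean squared error between expected and predicted outputs.
fn loss(w: f64, xs: &[f64], ys: &[f64]) -> f64 {
    let n = xs.len() as f64;
    xs.iter()
        .zip(ys)
        .map(|(&x, &y)| (predict(w, x) - y).powi(2))
        .sum::<f64>()
        / n
}

/// Step 4: the gradient dL/dw, derived by hand for this one model.
fn loss_grad(w: f64, xs: &[f64], ys: &[f64]) -> f64 {
    let n = xs.len() as f64;
    xs.iter()
        .zip(ys)
        .map(|(&x, &y)| 2.0 * (predict(w, x) - y) * x)
        .sum::<f64>()
        / n
}

/// Step 5, repeated: w <- w - eps * g.
fn train(mut w: f64, xs: &[f64], ys: &[f64], eps: f64, steps: usize) -> f64 {
    for _ in 0..steps {
        let g = loss_grad(w, xs, ys);
        w -= eps * g;
    }
    w
}

fn main() {
    // Step 1: training data generated from y = 3x.
    let xs = [1.0, 2.0, 3.0, 4.0];
    let ys = [3.0, 6.0, 9.0, 12.0];
    let w0 = 0.0;
    let w = train(w0, &xs, &ys, 0.01, 200);
    println!("loss before: {}", loss(w0, &xs, &ys));
    println!("loss after:  {}", loss(w, &xs, &ys));
    println!("learned w:   {w}");
}
```

Every change to `predict` or `loss` in this sketch would require re-deriving `loss_grad` by hand.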

Tensorken can already do almost all of those steps. Running a network, calculating a loss, and updating the parameters amounts to applying tensor operations. What's missing is calculating the gradient via the loss function's derivative. In the olden days, people would calculate the derivative of the network by hand, symbolically, and then implement it manually. Clearly tedious and error-prone, not to mention limiting the complexity and size of the networks. Modern neural network libraries calculate a function's output and its derivative without programmer intervention using the miracle of automatic differentiation.

AD is a vast and intricate topic. For a (much) longer primer on the basics, see my earlier post. If you are unfamiliar with AD I encourage you to read it or any of the AD primers in the links.

The following section demonstrates Tensorken's AD capabilities and interface via small examples. Then we'll dive into implementation details, but I'll stay away from the detailed mechanics of AD since that's already covered elsewhere. Instead, I'll focus on how Tensorken implements higher-order, mixed-mode, JAX-style AD as an elegant and minimal Rust library.

## The Autodiff Cookbook in Tensorken

To demonstrate Tensorken's AD capabilities, I translated a significant part of JAX's Autodiff Cookbook to Tensorken. I reproduced and edited part of the original text here. JAX's license is Apache 2.0, so I hope this does not incur the wrath of Google. The titles in this section are similar to the ones in JAX's cookbook, in case you want to compare. The full example code is in jax_autodiff_cookbook.rs.

Before we begin - Tensorken runs on the CPU if you create tensors via the `Cpu32` type alias and on the GPU via `Wgpu32`. (The `32` is because they work with 32-bit floating point numbers.) To make it easy to switch, I'll use the `Tr` type alias throughout:

```
type Tr = Cpu32; // or Wgpu32
```

### Gradients

You can differentiate a function using `grad1`. The `1` indicates the number of arguments of the function - a poor man's variadic arguments. In the text, I'll sometimes refer to the family of `grad1`, `grad2` functions as `grad`. In the code, I'll use the function with the correct number of arguments.

To start with, we'll use a simple scalar function - a function that takes a single number and returns a single number:

```
let p = Tr::scalar(2.0);
let df = grad1(|x| x.tanh(), &p);
> df: [0.07065082]
```

In Tensorken, all arguments must be a tensor `Tr` - it doesn't support mixed tensors and scalar numbers. To turn a number into a tensor we first use `Tr::scalar`. It makes a tensor with shape `[1]`.

`grad1` takes a function of one argument f and evaluates ∇f(x), the derivative of f at a given point x. You can think of ∇ as a higher-order function that takes a differentiable function f and produces a function that evaluates the derivative.

*Pronouncing ∇: I say "grad", I've heard people say "del", and the symbol's Unicode name is "nabla".*

Similarly, if you have a Rust function `f` that evaluates the mathematical function f, then `grad(f, p)` computes the value ∇f(p).

Unlike JAX, Tensorken does not directly expose ∇f as a first-class function, mostly because I had a hard time accomplishing that in Rust and staying sane! It required returning a closure from `grad(f)` so you can write `grad(f)(p)`, but satisfying the compiler proved difficult. So far this hasn't been a constraint in practice.

Like JAX, Tensorken does support applying `grad` to functions that themselves call `grad` to calculate higher-order derivatives:

```
let ddf = grad1(|x| grad1(|x| x.tanh(), x), &p);
let dddf = grad1(|x| grad1(|x| grad1(|x| x.tanh(), x), x), &p);
> ddf: [-0.13621868]
> dddf: [0.25265408]
```

Let's try computing gradients with `grad` in a linear logistic regression model. In other words, a simple neural network with one neuron. First, the setup:

```
// Outputs probability of a label being true.
fn predict<'t, T>(w: &'t T, b: &'t T, inputs: &T) -> T
where
    T: TensorLike<'t>,
{
    (inputs.dot(w) + b).sigmoid()
}
```

The function `predict` encodes the architecture of our toy model. Its parameters are a vector `w` and a scalar `b`, for weights and bias. As you can see, we're multiplying the weights with the inputs and adding the bias. Then we use `sigmoid` to squish the output values in the [0, 1] interval. This model predicts the probability of an outcome based on some input measurements.

Why is this equivalent to a neural network with a single neuron? Say the vector `w` has three elements - three weights. We thus have three inputs as well, in `inputs`. The function `dot` multiplies each input with its corresponding weight, and then adds them up. The bias `b` in the neuron analogy is typically a negative number, which represents a threshold that `inputs.dot(w)` must exceed to "activate" the neuron.

All arguments are tensors, but are represented by a generic argument `T`. The type needs to be generic so automatic differentiation can work. We'll see later why. `T: TensorLike` is a handy constraint to make tensor operations like `dot`, `+`, and `sigmoid` available on `T`. You'll see the `TensorLike` constraint often when using Tensorken's AD: to make functions differentiable, replace concrete `Tensor` types with `T: TensorLike`.

Let's run the model.

```
// Build a toy dataset.
// These are four measurements of some unspecified variable, one in each row.
let inputs = Tr::new(
    &[4, 3],
    &[
        0.52, 1.12, 0.77, //
        0.88, -1.08, 0.15, //
        0.52, 0.06, -1.30, //
        0.74, -2.49, 1.39,
    ],
);
// These are four observed outcomes, one for each row in the input.
let targets = Tr::new(&[4], &[1.0, 1.0, 0.0, 1.0]);
// Initialize the parameters w and b randomly
let key = 0;
let mut rng = StdRng::seed_from_u64(key);
let w = Tr::randn(&[3], &mut rng);
let b = Tr::randn(&[1], &mut rng);
let prediction = predict(&w, &b, &inputs);
> prediction: [0.4059896 0.37711427 0.9770815 0.007901279]
```

The inputs could be "changes in temperature observed on three consecutive days" and the targets could be "temperature went up or down on the next day". We're then training a model that predicts the probability of the temperature going up given three days' changes in temperature.

Since we initialized the model randomly, its prediction is random. We got unlucky: if you compare `targets` (what we want) with `prediction` (what we have) there is a big difference. The 3rd and 4th predictions are especially bad, almost the exact opposite of the training data.

To improve our model, we first need to quantify how crap the model is via a loss function.

```
// Training loss is the negative log-likelihood of the training examples.
fn loss<'t, T>(w: &'t T, b: &'t T, inputs: &T, targets: &'t T) -> T
where
    T: TensorLike<'t>,
    for<'s> &'s T: TensorLikeRef<T>,
{
    let prediction = predict(w, b, inputs);
    // ones_like makes a tensor of the same shape with all values equal to 1.
    let label_probs = &prediction * targets
        + (&prediction.ones_like() - &prediction) * (targets.ones_like() - targets);
    -label_probs.log().sum(&[0])
}
let l = loss(&w, &b, &inputs, &targets);
> loss: [10.4931755]
```

This loss function is *negative log-likelihood*. You can intuit why it works: `prediction` is "compared" with `targets` in `label_probs`. It contains a high value for predictions that are close to the target. We then take the `log` of each, which exaggerates bad predictions: the logarithm goes to -infinity as `label_probs` approaches zero. Since the logarithm of a value between 0 and 1 is negative, we negate it to get a positive number. Then we take the `sum` of the vector so we have a single positive loss number that is high when the model is doing badly, and low when it's making good predictions.
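As a sanity check on that intuition, the same arithmetic can be redone in plain Rust on the four numbers printed for `prediction` earlier. This is only a sketch with ordinary `f64` slices instead of tensors, and `negative_log_likelihood` is a name made up for this illustration. It should print a value close to the loss of `10.4931755` reported above, with most of it coming from the 3rd and 4th examples.

```rust
/// Negative log-likelihood for binary targets (1.0 = true, 0.0 = false).
fn negative_log_likelihood(predictions: &[f64], targets: &[f64]) -> f64 {
    let log_likelihood: f64 = predictions
        .iter()
        .zip(targets)
        .map(|(&p, &t)| {
            // Probability the model assigned to the label that was observed.
            let label_prob = p * t + (1.0 - p) * (1.0 - t);
            label_prob.ln()
        })
        .sum();
    -log_likelihood
}

fn main() {
    // Values copied from the outputs printed in this post.
    let prediction = [0.4059896, 0.37711427, 0.9770815, 0.007901279];
    let targets = [1.0, 1.0, 0.0, 1.0];
    println!("loss: {}", negative_log_likelihood(&prediction, &targets));
}
```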

Now we can improve the model by adjusting its weights and biases. We use `grad` to differentiate the `loss` function with respect to the parameters `w` and `b`:

```
// Differentiate loss wrt weights
let w_grad = grad1(
    |w| {
        loss(
            w,
            &Reverse::lift(&b),
            &Reverse::lift(&inputs),
            &Reverse::lift(&targets),
        )
    },
    &w,
);
print!("w_grad: {w_grad}");
// Differentiate loss wrt bias
let b_grad = grad1(
    |b| {
        loss(
            &Reverse::lift(&w),
            b,
            &Reverse::lift(&inputs),
            &Reverse::lift(&targets),
        )
    },
    &b,
);
> w_grad: [-1.0830948 2.5363755 -3.2000453]
> b_grad: [-1.2319121]
```

To make the types work out, we need to `Reverse::lift` all the arguments to `loss` we do NOT want to differentiate. They are treated as constants. The type is called `Reverse` because Tensorken uses reverse mode AD in this case. The `Reverse` type reveals why we need to make the arguments to `loss` and `predict` generic: the `grad` function, while taking a plain `Tr` type as the second argument, passes `Reverse<Tr>` to the closure. So the function `f` can be called with `Tr`, `Reverse<Tr>`, or other types we'll see later.

Here's the simplified signature for `grad1`. We'll get to the full signature later:

```
pub fn grad1<F>(f: F, at: &Tr) -> Tr where F: Fn(&Reverse<Tr>) -> Reverse<Tr>
```

Briefly, `Reverse` is a wrapper to interpret tensor operations so they calculate the derivative along with the main result. In this example, it'll run a different `dot`, `+`, and `sigmoid` compared to calling `loss` with plain tensors of type `Tr`.

Calling the `loss` function twice is not ideal - we're doing twice the work. We can also calculate the gradients with respect to both `w` and `b` at the same time, using `grad2`.

```
let (w_grad, b_grad) = grad2(
    |w, b| loss(w, b, &Reverse::lift(&inputs), &Reverse::lift(&targets)),
    &w,
    &b,
);
> w_grad: [-1.0830948 2.5363755 -3.2000453]
> b_grad: [-1.2319121]
```

Finally, let's do a single training iteration and check if that improves our model.

```
// Update parameters
let new_w = &w - &w_grad;
let new_b = &b - &b_grad;
// Predict
let new_prediction = predict(&new_w, &new_b, &inputs);
let new_loss = loss(&new_w, &new_b, &inputs, &targets);
> new_prediction: [0.7384342 0.99262685 0.7747804 0.9996524]
> new_loss: [1.8016509]
```

A massive improvement - we're now only 1.8 crap, down from 10.5!

### Evaluate a function and its gradient using `value_and_grad`

In a real training run, we'd do the above in a loop while keeping an eye on the loss to see when to stop. Again `loss` is called twice: once inside `grad` and once outside. Luckily, we don't have to. Another convenient family of functions is `value_and_grad` to efficiently compute a function and its gradient.

```
let (loss_value, (w_grad, b_grad)) = value_and_grad2(
    |w, b| loss(w, b, &Reverse::lift(&inputs), &Reverse::lift(&targets)),
    &w,
    &b,
);
> loss: [10.4931755]
> w_grad: [-1.0830948 2.5363755 -3.2000453]
> b_grad: [-1.2319121]
```

### Checking against numerical differences

Our loss improved, which is a good indication that things work. To gain confidence we can compare Tensorken's derivatives with finite differences.

```
// step size for finite difference
let eps = Tr::scalar(1e-4);
let half_eps = &eps / Tr::scalar(2.);
let b_grad_numerical = (loss(&w, &(&b + &half_eps), &inputs, &targets)
    - loss(&w, &(&b - &half_eps), &inputs, &targets))
    / &eps;
> b_grad_numerical [-1.2207031]
> b_grad_autodiff [-1.2319121]
```

Close enough.
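The same check can be tried on a function whose derivative is known in closed form. The sketch below uses plain `f64` Rust rather than Tensorken, and `central_difference` and `tanh_derivative` are names made up for the illustration. It compares a central finite difference of `tanh` at 2.0 with the analytic derivative 1 - tanh²(x), which should agree with the `0.07065082` that `grad1` printed earlier. With 64-bit floats the two numbers agree to many more digits than above, so the gap in the 32-bit output is likely rounding rather than a wrong derivative.

```rust
/// Central finite difference: (f(x + eps/2) - f(x - eps/2)) / eps.
fn central_difference(f: impl Fn(f64) -> f64, x: f64, eps: f64) -> f64 {
    (f(x + eps / 2.0) - f(x - eps / 2.0)) / eps
}

/// Analytic derivative of tanh: 1 - tanh(x)^2.
fn tanh_derivative(x: f64) -> f64 {
    1.0 - x.tanh().powi(2)
}

fn main() {
    let x = 2.0;
    let numerical = central_difference(f64::tanh, x, 1e-4);
    let analytic = tanh_derivative(x);
    println!("numerical: {numerical}");
    println!("analytic:  {analytic}");
}
```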

### Jacobians using `jacfwd` and `jacrev`

Ignoring bias `b` for now, the `loss` function is a function of three parameters, represented as a single tensor `w` with three elements. It has a single scalar output, represented as a tensor with a single element. Taking the gradient of this function results in a vector of three elements, the sensitivity of the loss to each parameter. This picture becomes more complicated if there is more than one output parameter. `grad` still gives an answer, but what does it mean?

```
let deriv = grad1(
    |w| predict(w, &Reverse::lift(&b), &Reverse::lift(&inputs)),
    &w,
);
> deriv: [0.34956074 -0.0017646346 0.20271438]
```

Remember that `predict` returns a vector with four elements, and the input `w` is a vector with three elements. We get a vector with three sensitivities - one for each input. But the sensitivity of which output? There are four. As we'll check below, `grad` returns the sum of the sensitivity of all outputs. That's typically not what we want: we'd like to disaggregate the sensitivities.

The usual approach is to represent the sensitivity of each output with respect to each input as a matrix, called the Jacobian. In this case, a 4 by 3 matrix - number of outputs by number of inputs. Tensorken can compute Jacobians, in forward and reverse mode using `jacfwd` and `jacrev`:

```
let J = jacfwd(
    |w| predict(w, &Forward::lift(&b), &Forward::lift(&inputs)),
    &w,
);
> jacfwd result, with shape [4, 3]
┌ ┐
│ 0.12540425 0.2701015 0.18569478 │
│ 0.20671119 -0.25369102 0.03523486 │
│ 0.01164451 0.0013435973 -0.029111274 │
│ 0.0058007482 -0.019518733 0.010895999 │
└ ┘
let J = jacrev(
    |w| predict(w, &Reverse::lift(&b), &Reverse::lift(&inputs)),
    &w,
);
> jacrev result, with shape [4, 3]
┌ ┐
│ 0.12540427 0.27010152 0.18569478 │
│ 0.20671119 -0.25369102 0.03523486 │
│ 0.01164451 0.0013435973 -0.029111274 │
│ 0.005800748 -0.019518731 0.010895998 │
└ ┘
```

These two functions compute the same values (up to machine precision), but differ in their implementation: `jacfwd` uses forward-mode automatic differentiation, which is more efficient for "tall" Jacobian matrices, while `jacrev` uses reverse-mode, which is more efficient for "wide" Jacobian matrices. For matrices that are near-square, `jacfwd` probably has an edge over `jacrev`.

We can now check that `grad` computed the sum of the sensitivity of all four outputs:

```
&J.sum(&[0])
> [0.34956074 -0.0017646346 0.20271438]
```
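The same check can be repeated by hand on the printed numbers. The following plain-Rust sketch (no Tensorken involved, `sum_rows` is a helper made up for this illustration) sums the `jacrev` matrix over its rows and prints each column total next to the corresponding `grad1` output. The values should match up to the rounding of the printed digits.

```rust
/// Sum a row-major matrix over its rows (axis 0), one total per column.
fn sum_rows<const N: usize>(rows: &[[f64; N]]) -> [f64; N] {
    let mut totals = [0.0; N];
    for row in rows {
        for (total, &value) in totals.iter_mut().zip(row) {
            *total += value;
        }
    }
    totals
}

fn main() {
    // The jacrev output printed earlier: 4 outputs by 3 inputs.
    let jacobian = [
        [0.12540427, 0.27010152, 0.18569478],
        [0.20671119, -0.25369102, 0.03523486],
        [0.01164451, 0.0013435973, -0.029111274],
        [0.005800748, -0.019518731, 0.010895998],
    ];
    // The grad1 output printed earlier.
    let grad = [0.34956074, -0.0017646346, 0.20271438];
    let sums = sum_rows(&jacobian);
    for (s, g) in sums.iter().zip(&grad) {
        println!("column sum {s:.7} vs grad {g:.7}");
    }
}
```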

Using a composition of `jacfwd` and `jacrev` gives us a way to compute dense Hessian matrices. Hessian matrices contain all the second derivatives.

```
let hessian = jacfwd(
    |w| {
        jacrev(
            |w| {
                predict(
                    w,
                    &Reverse::lift(&Forward::lift(&b)),
                    &Reverse::lift(&Forward::lift(&inputs)),
                )
            },
            w,
        )
    },
    &w,
);
println!("hessian with shape {:?}", hessian.shape());
> hessian shape [4, 3, 3]
```

Why this shape? We start with a function f: n → m. Traditionally, we'd write f: ℝⁿ → ℝᵐ, but that there are n inputs and m outputs is more important than that we're talking about real numbers, so I'll omit the ℝ from now on.

At a point x ∈ n we expect to get the shapes

- f(x) ∈ m, the value of f at x,
- ∂f(x) ∈ m × n, the Jacobian matrix at x,
- ∂²f(x) ∈ m × n × n, the Hessian at x,

and so on.

To implement `hessian` we could have used `jacfwd(jacrev(f))` or `jacrev(jacfwd(f))` or any other composition of the two. But forward-over-reverse is typically the most efficient. That's because, in the inner Jacobian computation, we're often differentiating a function with a wide Jacobian (maybe like a loss function f: n → 1), while in the outer Jacobian computation, we're differentiating a function with a square Jacobian (since ∇f: n → n), which is where forward-mode wins out.

Note we now need to `lift` the inputs twice to make Rust's type checker happy.

That concludes the tour of Tensorken's AD capabilities. It packs a lot of punch - now let's see how to fit it in a small package.

## A tale of two functions

All AD functions like `jacfwd`, `jacrev`, and `grad` are implemented in terms of two function-type pairs: `jvp` with `Forward`, and `vjp` with `Reverse`. JVP stands for Jacobian-vector product, and VJP stands for Vector-Jacobian product. These functions are directly inspired by JAX. To explain their names, we need some math background that deserves a standalone post. If you can't wait, refer to this section in JAX's Autodiff Cookbook.

I'll now introduce `jvp` and `vjp`, and the beginnings of how AD works in Tensorken. I assume some background knowledge about AD, in particular AD for scalar functions. See my earlier post for a primer.

### From scalars to tensors

Forward AD on scalar functions works by replacing operators and functions on numbers with versions that operate on a *dual number* - a `(f32, f32)` tuple. The first element is the *primal*, which the function computes without AD. The second is the derivative, or *tangent*. Operations on dual numbers are straightforward:

- apply the operation to the primal(s), and
- apply differentiation rules to the tangent(s).

For example, multiplication on dual numbers is:

(p₁, t₁) . (p₂, t₂) = (p₁.p₂, p₁.t₂ + p₂.t₁)
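A minimal sketch of scalar dual numbers with operator overloading could look as follows. This is an illustration in plain Rust, not Tensorken's `Forward` type: the `Dual` struct, its `variable` constructor, and the use of `f64` are assumptions made for the example. The `Mul` implementation is the product rule from the formula above.

```rust
use std::ops::{Add, Mul};

/// A dual number: the primal value and its tangent (derivative).
#[derive(Clone, Copy, Debug, PartialEq)]
struct Dual {
    primal: f64,
    tangent: f64,
}

impl Add for Dual {
    type Output = Dual;
    fn add(self, rhs: Dual) -> Dual {
        // Sum rule: tangents add.
        Dual {
            primal: self.primal + rhs.primal,
            tangent: self.tangent + rhs.tangent,
        }
    }
}

impl Mul for Dual {
    type Output = Dual;
    fn mul(self, rhs: Dual) -> Dual {
        // Product rule: (p1, t1) * (p2, t2) = (p1 * p2, p1 * t2 + p2 * t1).
        Dual {
            primal: self.primal * rhs.primal,
            tangent: self.primal * rhs.tangent + rhs.primal * self.tangent,
        }
    }
}

impl Dual {
    /// Seed an input: a tangent of 1 means "differentiate with respect to this".
    fn variable(x: f64) -> Dual {
        Dual { primal: x, tangent: 1.0 }
    }

    fn tanh(self) -> Dual {
        let p = self.primal.tanh();
        // Chain rule with d/dx tanh(x) = 1 - tanh(x)^2.
        Dual { primal: p, tangent: (1.0 - p * p) * self.tangent }
    }
}

fn main() {
    let x = Dual::variable(2.0);
    println!("tanh:  {:?}", x.tanh());
    println!("x*x+x: {:?}", x * x + x);
}
```

The tangent of `x.tanh()` at 2.0 should agree with the `grad1` result for `tanh` shown earlier in the post.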

Reverse mode is more involved. The primal computation is identical, but instead of calculating the tangent alongside the primal, we collect a trace - essentially a stack of operations. A reverse pass through the trace calculates the derivatives.

Exactly how these operations are replaced is a concern for the implementation. Common methods are code transformation in the compiler, code generation, and operator overloading. Tensorken uses trait-based overloading.

Forward mode needs little extra memory beyond bringing the tangent along for the ride, while reverse mode needs to keep a trace that's as long as the computation is deep. As a result, for scalar-to-scalar functions, forward mode is more efficient.

That situation changes if we consider functions from many scalars to one, or vice versa. One extreme is a function that takes a single input and computes n outputs. That's great for forward mode: in one execution of the function on dual numbers, we'll have both the primal result and the derivative - or in other words the sensitivity of each output to a small change in the single input.

However, a function that takes many inputs and has a single output is efficient only in reverse mode. In forward mode, we'd need as many executions of the function as there are inputs - we'd have to pass 1 as tangent for each input separately. In reverse mode, we still need the extra memory for the trace, but one forward pass for the primal and one backward pass for the partial derivatives is all we need.
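To illustrate the "one forward pass, one backward pass" claim, here is a toy scalar tape in plain Rust. It is a sketch under simplifying assumptions (only `add` and `mul`, `f64` scalars, node indices instead of wrapper types) and does not reflect how Tensorken's `Reverse` is implemented. The names `Tape`, `Node`, and `backward` are made up for this example.

```rust
/// One recorded operation: for each of two parents, the parent's index
/// and the local partial derivative with respect to it.
struct Node {
    parents: [(usize, f64); 2],
}

/// The trace: nodes and their primal values, in the order they were computed.
struct Tape {
    nodes: Vec<Node>,
    values: Vec<f64>,
}

impl Tape {
    fn new() -> Self {
        Tape { nodes: Vec::new(), values: Vec::new() }
    }

    fn push(&mut self, value: f64, parents: [(usize, f64); 2]) -> usize {
        self.nodes.push(Node { parents });
        self.values.push(value);
        self.nodes.len() - 1
    }

    /// An input has no real parents; it points at itself with zero weight.
    fn input(&mut self, value: f64) -> usize {
        let id = self.nodes.len();
        self.push(value, [(id, 0.0), (id, 0.0)])
    }

    fn add(&mut self, a: usize, b: usize) -> usize {
        let value = self.values[a] + self.values[b];
        self.push(value, [(a, 1.0), (b, 1.0)])
    }

    fn mul(&mut self, a: usize, b: usize) -> usize {
        let (va, vb) = (self.values[a], self.values[b]);
        self.push(va * vb, [(a, vb), (b, va)])
    }

    /// Backward pass: walk the trace in reverse, accumulating d(output)/d(node).
    fn backward(&self, output: usize) -> Vec<f64> {
        let mut adjoints = vec![0.0; self.nodes.len()];
        adjoints[output] = 1.0;
        for i in (0..=output).rev() {
            let adjoint = adjoints[i];
            for &(parent, partial) in &self.nodes[i].parents {
                adjoints[parent] += partial * adjoint;
            }
        }
        adjoints
    }
}

fn main() {
    // f(x, y, z) = x * y + y * z: three inputs, one output.
    let mut tape = Tape::new();
    let (x, y, z) = (tape.input(2.0), tape.input(3.0), tape.input(4.0));
    let xy = tape.mul(x, y);
    let yz = tape.mul(y, z);
    let out = tape.add(xy, yz);
    // A single backward pass yields all three partial derivatives.
    let grads = tape.backward(out);
    println!("f = {}", tape.values[out]);
    println!("df/dx = {}, df/dy = {}, df/dz = {}", grads[x], grads[y], grads[z]);
}
```

Forward mode on the same function would need three runs, one per input seeded with tangent 1.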

The good news is that if you understand this, nothing much changes if we allow tensors instead of scalars. After all, a tensor is a container of scalars, and operations on tensors can be broken down into operations on scalars. That's not how we want to implement them though! Bulk operations are where the performance is at.

For forward mode, we'll overload tensor operations to propagate a "dual tensor", a tuple of a primal and a tangent tensor. For reverse mode, we'll build up a trace of tensor operations in the forward pass and get the tangent tensors from a backward pass.

One difference with scalar AD is that we need to take the shape of tensors into account. Besides arithmetic operations like addition and multiplication, we also need to figure out differentiation rules for `sum`, `reshape`, and others, which affect the shape of both primal and derivative tensors.

### JVP for forward-mode AD

Here's the signature of `jvp`:

```
pub fn jvp1<T: Diffable + Clone, F>(f: F, at: &T, tangent: &T) -> (T, T)
where
    for<'a> F: Fn(&'a Forward<T>) -> Forward<T>,
```

As the `1` suffix indicates, this version takes a single primal tensor `at` and a single tangent tensor `tangent`. It evaluates the primal and tangent of the function `f` and returns them as a tuple.

To understand why this signature makes sense, think of AD as a program transformation. Without AD we'd write programs that boil down to:

```
let p1 = f1(x);
let p2 = f2(p1);
...
```

With forward AD and `jvp` we can rewrite them as:

```
let (p1,t1) = jvp(f1, x, x.ones_like());
let (p2,t2) = jvp(f2, p1, t1);
...
```

That illustrates how programs that compose functions can be transformed into programs that compose `jvp`-wrapped functions.
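The composition property can be tried out on scalars. The sketch below is a plain-Rust stand-in, not Tensorken's `jvp1`: `Dual` replaces `Forward<T>`, the scalar `jvp` function is hypothetical, and the tangent seed `1.0` plays the role of `x.ones_like()`. Chaining two `jvp` calls should give the same primal and tangent as one `jvp` call over the composed function.

```rust
/// Scalar stand-in for a dual tensor: (primal, tangent).
#[derive(Clone, Copy, Debug)]
struct Dual {
    primal: f64,
    tangent: f64,
}

impl Dual {
    fn sin(self) -> Dual {
        Dual {
            primal: self.primal.sin(),
            tangent: self.primal.cos() * self.tangent,
        }
    }

    fn exp(self) -> Dual {
        let e = self.primal.exp();
        Dual { primal: e, tangent: e * self.tangent }
    }
}

/// Evaluate `f` at `at` and push `tangent` through it.
/// Returns (primal output, tangent output).
fn jvp(f: impl Fn(Dual) -> Dual, at: f64, tangent: f64) -> (f64, f64) {
    let out = f(Dual { primal: at, tangent });
    (out.primal, out.tangent)
}

fn main() {
    let x = 0.5;
    // Chain two jvp calls: the tangent out of the first feeds the second.
    let (p1, t1) = jvp(Dual::sin, x, 1.0);
    let (p2, t2) = jvp(Dual::exp, p1, t1);
    // One jvp call over the composed function gives the same answer.
    let (p, t) = jvp(|d| d.sin().exp(), x, 1.0);
    println!("chained:  ({p2}, {t2})");
    println!("composed: ({p}, {t})");
}
```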

Importantly, `at` and `tangent` must have the same shape, and the two tensors in the output tuple have the same shape. `f` computes `out.0` from `at`, and `jvp` additionally computes `out.1` from `at` and `tangent`.

`jvp` works for any tensor-like type that implements `Diffable`. `Diffable` is the foundational trait that defines Tensorken's primitive, differentiable tensor operations. Higher-level operations like `matmul` are built on these. Keeping the tensor generic in `jvp` allows it to work with any `Diffable` implementation - something we'll use when doing higher-order AD. We'll come back to the details soon.

### VJP for backward AD

Here is the signature for `vjp`:

```
pub fn vjp1<'b, 't, T: Diffable + Clone + 't, F>(f: F, at: &T) -> (T, PullBack<'t, T>)
where
    for<'a> F: Fn(&'a Reverse<'a, 't, T>) -> Reverse<'a, 't, T>,
```

It looks a bit different because reverse mode has a backward pass. What's the same are the differentiable function `f: F` and the primal input `at`. `vjp` calls `f` with `Reverse` wrapping the input `T`. Since reverse mode needs two passes, `vjp` only returns the primal directly.

`PullBack` (a term from differential geometry) is a named struct that executes the backward pass. It takes a *cotangent*, a tensor in the shape of the *output* of `f`, and calculates the tangent, a tensor in the shape of the input of `f`.

```
impl<T: Diffable + Clone> PullBack<'_, T> {
    pub fn call(&self, cotangent: &T) -> T
}
```

It's all backward! But that's why reverse mode AD is more efficient if you have the right tensor shape.
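The shapes involved are easiest to see on a fixed linear function. The sketch below is plain Rust and only an analogy for `vjp1` and `PullBack`: the function `vjp_matvec` is hypothetical, and a closure stands in for the `PullBack` struct. For f(x) = A x with A of shape 2 by 3, the primal has 2 elements, the cotangent passed to the pullback has 2 elements, and the result has 3 elements, like the input.

```rust
/// f(x) = A x for a fixed matrix A (rows = outputs, columns = inputs).
/// Returns the primal output and a pullback closure for the backward pass.
fn vjp_matvec<'a>(
    a: &'a [Vec<f64>],
    x: &[f64],
) -> (Vec<f64>, impl Fn(&[f64]) -> Vec<f64> + 'a) {
    let primal: Vec<f64> = a
        .iter()
        .map(|row| row.iter().zip(x).map(|(&aij, &xj)| aij * xj).sum::<f64>())
        .collect();
    // The pullback computes cotangent^T * A: it takes something shaped like
    // the output and returns something shaped like the input.
    let pullback = move |cotangent: &[f64]| {
        let cols = a.first().map_or(0, |row| row.len());
        let mut tangent = vec![0.0; cols];
        for (row, &ct) in a.iter().zip(cotangent) {
            for (t, &aij) in tangent.iter_mut().zip(row) {
                *t += ct * aij;
            }
        }
        tangent
    };
    (primal, pullback)
}

fn main() {
    // A has 2 rows (outputs) and 3 columns (inputs).
    let a = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
    let x = [1.0, 0.0, -1.0];
    let (primal, pullback) = vjp_matvec(&a, &x);
    println!("primal (2 elements): {primal:?}");
    // A cotangent of all ones gives the summed sensitivity of both outputs.
    println!("tangent (3 elements): {:?}", pullback(&[1.0, 1.0]));
    // The pullback can be reused with a different cotangent.
    println!("first output only: {:?}", pullback(&[1.0, 0.0]));
}
```

Calling the pullback with a cotangent of all ones sums the sensitivities of all outputs, which mirrors what `grad` did for `predict` earlier.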

A short note on why `jvp` and `vjp` have different signatures. On the one hand, we could re-write `jvp` to return a `PushForward` struct with a `call` function that works similarly to `vjp`'s `PullBack`. However, that would require keeping a trace of the operations around so users can `call` multiple times with different tangents. That jeopardizes the memory efficiency of forward mode. The ability to re-execute the differentiating pass with different tangents does not offset the added memory usage.

On the other hand, we could write `vjp` with a signature like `jvp` by making the `PullBack` internal and `call`ing it at the end. But in reverse mode, we have to expend the memory anyway, so we might as well make it available to the user for potential reuse.

## Interpreters all the way down

We're now at the point where we can dive into the code, and it's interpreters all the way down.

Before AD, Tensorken's core was the `RawTensor` trait, with implementations for the CPU and the GPU. It's useful to think of this trait as the definition of a language for primitive tensor operations, and implementations of the trait as interpreters of that language. Interpreters don't necessarily have to produce a tensor - for debugging and testing a pretty-printing interpreter for `RawTensor` is useful:

```
impl RawTensor for String {
    type Elem = f32;
    fn exp(&self) -> Self {
        format!("{self}.exp()")
    }
    fn add(&self, other: &Self) -> Self {
        format!("({self} + {other})")
    }
    // etc
}
```

We can use it as follows:

```
let t1: String = RawTensor::new(&[2, 2], &[1., 2., 3., 4.]);
let t2: String = RawTensor::new(&[2, 2], &[5., 6., 7., 8.]);
let r = t1.exp().add(&t2.log());
> r: "(new([2, 2], [1.0, 2.0, 3.0, 4.0]).exp() + new([2, 2], [5.0, 6.0, 7.0, 8.0]).log())"
```

Or even:

```
let t1: String = "A".to_string();
let t2: String = "B".to_string();
let r = t1.exp().add(&t2.log());
> r: "(A.exp() + B.log())"
```

We could generate source code or an abstract syntax tree this way, turning the interpreter into a compiler of sorts. That is the essence of the final tagless approach I described in depth in an earlier post. It has many extensibility advantages, which we'll take advantage of soon.
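The idea fits in a few lines of self-contained Rust. The following sketch uses a made-up three-operation trait `Ops` instead of `RawTensor`, with one interpreter that computes a number and one that pretty-prints. The generic function `program` is written once against the trait and run through both interpreters.

```rust
/// A tiny "language" of operations, in the spirit of `RawTensor`.
trait Ops {
    fn constant(value: f64) -> Self;
    fn exp(&self) -> Self;
    fn add(&self, other: &Self) -> Self;
}

/// Interpreter 1: evaluate to a number.
impl Ops for f64 {
    fn constant(value: f64) -> Self {
        value
    }
    fn exp(&self) -> Self {
        f64::exp(*self)
    }
    fn add(&self, other: &Self) -> Self {
        self + other
    }
}

/// Interpreter 2: pretty-print the expression.
impl Ops for String {
    fn constant(value: f64) -> Self {
        format!("{value:?}")
    }
    fn exp(&self) -> Self {
        format!("{self}.exp()")
    }
    fn add(&self, other: &Self) -> Self {
        format!("({self} + {other})")
    }
}

/// One program, written once against the trait.
fn program<T: Ops>() -> T {
    let a = T::constant(0.0);
    let b = T::constant(1.0);
    a.exp().add(&b)
}

fn main() {
    let value: f64 = program();
    let text: String = program();
    println!("{text} = {value}");
}
```

Adding a third interpreter, for example one that counts operations, would not require touching `program`.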

What does this have to do with automatic differentiation? AD is achieved by hard-coding how to differentiate primitive operations like addition and multiplication, and composing those primitive rules via the chain rule. The primitive operations define a language of tensor operations which we can interpret in a few ways - in particular, as straightforward tensor operations without differentiation via `Tensor`, as a forward mode differentiated program via `Forward`, or as a reverse mode differentiated program via `Reverse`. As for `RawTensor`, we represent the primitive operations of the differentiable language as a trait, `Diffable`, and then implement this trait for each interpreter.

Let's start with the trait definition:

```
pub trait Diffable {
    type Elem: Num;
    fn log(&self) -> Self;
    fn exp(&self) -> Self;
    fn elementwise_add(&self, other: &Self) -> Self;
    fn elementwise_sub(&self, other: &Self) -> Self;
    fn elementwise_mul(&self, other: &Self) -> Self;
    fn elementwise_div(&self, other: &Self) -> Self;
    fn elementwise_pow(&self, exp: &Self) -> Self;
    fn elementwise_eq(&self, other: &Self) -> Self;
    fn sum(&self, axes: &[usize]) -> Self;
    fn max(&self, axes: &[usize]) -> Self;
    fn reshape(&self, shape: &[usize]) -> Self;
    fn permute(&self, dims: &[usize]) -> Self;
    fn expand(&self, shape: &[usize]) -> Self;
    fn pad(&self, padding: &[(usize, usize)]) -> Self;
    fn crop(&self, limits: &[(usize, usize)]) -> Self;
    fn new(shape: &[usize], data: &[Self::Elem]) -> Self;
    fn shape(&self) -> &[usize];
}
```

`Diffable`'s operations are similar to `RawTensor`'s, and we can categorize them in much the same way - unary operations, binary operations, reduce-like operations, and shape-changing operations. Missing is the optimized fused multiply-add in `RawTensor`, which illustrates the difference in intent between `RawTensor` and `Diffable`. While we could make `RawTensor` differentiable, I'll now try to convince you we don't want to.

Fused multiply-add is an optimized operation that we need on the lowest level to have some hope of efficiency. It is likely that to make Tensorken more efficient, we'll need to add more special-purpose operations to better exploit hardware primitives, reduce memory usage, and so on.

We don't (necessarily) want to figure out how to differentiate those special-purpose operations - we'd like a small set of primitive operations, define their derivatives, and then compose those into higher-level operations. We then get derivatives of those higher-level operations for free, because differentiation is so beautifully composable. Separating `Diffable`

from `RawTensor`

allows us to add efficient, special-purpose operations to `RawTensor`

without figuring out their derivatives. Vice versa, we can add operations to `Diffable`

without having to change `RawTensor`

and its implementations.

Before `Diffable`, we translated user-facing operations like matrix multiplication to `RawTensor` operations, which were interpreted by a concrete `RawTensor` like `CpuRawTensor`. Now we add another interpreter, `Diffable`, between the user-facing operations and `RawTensor`, which not only calculates the primal results but also derivatives. `Diffable` interpreters execute both primal and derivative calculations as `RawTensor` operations. That means we can combine all implementations of `Diffable` with all implementations of `RawTensor`. So we can do forward AD on the GPU, reverse AD on the CPU, or any other combination.

Let's make our way down the interpreter layers to see how this works in practice. We'll start with matrix multiplication and end up at `CpuRawTensor`.

Each of the sections that follow is one layer of the interpreter lasagne:

High-level tensor operations like `matmul` are translated to `Diffable` operations like `sum` and `elementwise_mul`.

A `Diffable` interpreter like `Forward` or `Reverse` translates primitive operations like `sum` and `elementwise_mul` to `RawTensor` operations, adding calculation of derivatives.

A `RawTensor` interpreter like `CpuRawTensor` executes the operations on a particular device.

*A nicely layered design. Is it lunchtime yet? Photo by Parnis Azimi on Unsplash*

### User-facing layer: matrix multiplication in terms of `Diffable`

Here is a sketch of `matmul`, omitting everything that is not an operation on `Diffable`:

```
pub trait DiffableExt: Diffable
{
fn matmul(&self, other: &Self) -> Self {
// preconditions, shape manipulation omitted
// special cases omitted
let l = self.reshape(&l_shape);
// shape manipulation omitted
let r = other
.reshape(&r_shape)
.transpose(r_shape.ndims() - 1, r_shape.ndims() - 2);
// after multiply: [..., m, o, n]
let prod = l.mul(&r);
// after sum: [..., m, o, 1]
let sum = prod.sum(&[prod.shape().ndims() - 1]);
// after reshape: [..., m, o]
let s = sum.shape();
sum.reshape(&s[..s.ndims() - 1])
}
}
```

Tensorken has three implementations of `Diffable`: `Tensor`, `Forward`, and `Reverse`. `Tensor` doesn't do any differentiation at all - it translates `Diffable` to `RawTensor` operations. `Forward` and `Reverse` augment the operations with their respective mode of AD. We'll come back to these later - first, we need to find a Rust vehicle for the user-facing operations that are not in `Diffable`. We could re-implement them on each implementation of `Diffable`, but that is redundant. Instead, I've defined `DiffableExt`, a sub-trait of `Diffable` with a blanket implementation:

```
pub trait DiffableExt: Diffable
{
// all the fns we want, like matmul, go here.
// They'll need to be defined in terms of Diffable,
// because that's all that's available.
fn matmul(&self, other: &Self) -> Self { ... }
}
impl<T: Diffable> DiffableExt for T {}
```

The advantage is that we only have to implement `Diffable` on a concrete type; anything defined on `DiffableExt` is then available too (as long as `DiffableExt` is in scope).
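To see the extension-trait pattern in isolation, here is a minimal sketch on a toy type. The names `Prim`, `PrimExt`, `VecTensor`, and `dot` are made up for this illustration and are not part of Tensorken; the two-operation trait only stands in for `Diffable`.

```rust
// Toy stand-in for the primitive-operation trait (hypothetical, not Tensorken's Diffable).
trait Prim: Sized {
    fn elementwise_mul(&self, other: &Self) -> Self;
    fn sum(&self) -> Self;
}

// Derived operations live on a sub-trait and only use the primitives.
trait PrimExt: Prim {
    fn dot(&self, other: &Self) -> Self {
        self.elementwise_mul(other).sum()
    }
}

// Blanket implementation: every Prim gets the PrimExt methods for free.
impl<T: Prim> PrimExt for T {}

// A concrete type only has to implement the primitives.
#[derive(Debug, Clone, PartialEq)]
struct VecTensor(Vec<f64>);

impl Prim for VecTensor {
    fn elementwise_mul(&self, other: &Self) -> Self {
        VecTensor(self.0.iter().zip(&other.0).map(|(a, b)| a * b).collect())
    }
    fn sum(&self) -> Self {
        VecTensor(vec![self.0.iter().sum()])
    }
}
```

`VecTensor` never mentions `dot`, yet `VecTensor(vec![1.0, 2.0, 3.0]).dot(&other)` compiles as soon as `PrimExt` is in scope.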

### The first `Diffable` implementation: Tensor

We now need a concrete type to present to users. `Tensor` is that type. Its definition is mysteriously simple:

```
pub struct Tensor<T>(T);
```

The idea is that the generic type argument `T` is a `Diffable`. Why not add the type constraint here? Because it's unnecessary - for all interesting implementations, `T` is `Diffable`. Constraining `T` here adds nothing new.

We can now make `Tensor<T>` implement `Diffable` for any `T` that's `Diffable`:

```
impl<T: Diffable> Diffable for Tensor<T> {
type Elem = T::Elem;
fn log(&self) -> Self {
Tensor(self.0.log())
}
// etc
}
```

All operations delegate to `T`. Full implementation here.

### From `Diffable` to `RawTensor`

That gets us nowhere - we can have a differentiable `Tensor<T>` if we have a differentiable `T`. To execute tensor operations we need to get to a `RawTensor`. We can do that by interpreting `Diffable` operations as `RawTensor` operations. In Rust, this means creating a blanket implementation of `Diffable` for any `RawTensor`:

```
impl<T: Num, TTensor: RawTensor<Elem = T>> Diffable for TTensor {
type Elem = T;
fn log(&self) -> Self {
self.log()
}
// etc
}
```

Since `Diffable` is a subset of `RawTensor`, the implementation is again straightforward. A type like `Tensor<CpuRawTensor>` now works, and we can apply all operations in `Diffable` and `DiffableExt` to it.

It seems like we went around in a big circle. After Tensorken parts 1 and 2, we had a `Tensor<T: RawTensor>` with high-level operations like `matmul` and primitive operations on `RawTensor`. Now we have `Tensor<T: Diffable>` with high-level operations like `matmul` moved to `DiffableExt`, differentiable primitive operations on `Diffable`, and primitive executable operations still on `RawTensor`.

What we gained is the ability to have other `Diffable` implementations. We're going to use that ability now.

### Forward-mode AD with `Forward`

The `Forward` type wraps `T` with extra stuff so we can transform and trace the computation to calculate the derivative. In interpreter terms, `Forward` is an interpreter for the `Diffable` language that calculates the derivative alongside the primal result using forward-mode AD. It does that by applying all tensor operations on a dual tensor.

The `Forward` type:

```
pub enum Forward<T> {
Lift(T),
Forward(T, T),
}
```

As with `Tensor<T>`, the `T` here is a `Diffable` tensor. The `Forward` case should make sense - it holds the primal and the tangent tensors. We use the `Lift` case if we're not interested in computing the derivative of a tensor. Lifted tensors are treated as constants for the derivative computation. Another way of saying this is that their derivative is zero. We avoid many multiplications with zero by having a dedicated case instead of using the functionally equivalent `Forward(t, zero)`.

We can understand `jvp1`'s implementation now:

```
pub fn jvp1<T: Diffable + Clone, F>(f: F, at: &T, tangent: &T) -> (T, T)
where
for<'a> F: Fn(&'a Forward<T>) -> Forward<T>,
{
let forward = Forward::Forward(at.clone(), tangent.clone());
let result = f(&forward);
match result {
Forward::Lift(p) => (p.clone(), p.zeros_like()),
Forward::Forward(p, t) => (p, t),
}
}
```

We wrap the `at` and `tangent` arguments in `Forward`, then call `f` with them and unwrap the `Forward` from the result.

`Forward` must implement `Diffable` for this to work. Finally, we come to the implementation of differentiation rules for the primitive operations:

```
impl<T: Clone + Diffable> Diffable for Forward<T> {
type Elem = T::Elem;
fn elementwise_mul(&self, rhs: &Self) -> Self {
self.binary::<MulOp<T>>(rhs)
}
fn sum(&self, axes: &[usize]) -> Self {
self.unary::<SumOp, _>(axes)
}
// etc
}
```

Full implementation here.

The calculations of the primal and the derivatives are encapsulated in `Op` structs. The `unary` and `binary` functions handle the `Lift` and `Forward` enum cases in one place, and delegate to a given `Op` struct for the calculation:

```
impl<T: Diffable> Forward<T> {
fn unary<Op: UnaryOp<T, Args = TArgs> + UnaryDiffOp<T>, TArgs: ?Sized>(
&self,
args: &TArgs,
) -> Self {
let (primal, op) = Op::f(self.primal(), args);
match self {
Forward::Lift(_) => Forward::Lift(primal),
Forward::Forward(_, tan) => Self::Forward(primal, op.dfda(tan)),
}
}
}
```

`binary` is similar but more involved because it has 4 combinations of `Lift` and `Forward`.
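To give a feel for those four combinations, here is a scalar sketch with `f64` standing in for a tensor. The `Dual` enum and its `mul` method are invented for this illustration; Tensorken's real `binary` is generic over `Op` structs and is not shown in this post.

```rust
// Scalar stand-in for Forward<T>: f64 plays the role of the tensor (illustrative only).
#[derive(Debug, Clone, Copy, PartialEq)]
enum Dual {
    Lift(f64),
    Forward(f64, f64), // (primal, tangent)
}

impl Dual {
    fn primal(&self) -> f64 {
        match self {
            Dual::Lift(p) | Dual::Forward(p, _) => *p,
        }
    }

    // Multiplication, spelling out all four Lift/Forward combinations.
    fn mul(&self, rhs: &Dual) -> Dual {
        let (a, b) = (self.primal(), rhs.primal());
        let primal = a * b;
        match (self, rhs) {
            // Two constants: the result is a constant too.
            (Dual::Lift(_), Dual::Lift(_)) => Dual::Lift(primal),
            // Only the left side carries a tangent: da * b.
            (Dual::Forward(_, da), Dual::Lift(_)) => Dual::Forward(primal, da * b),
            // Only the right side carries a tangent: db * a.
            (Dual::Lift(_), Dual::Forward(_, db)) => Dual::Forward(primal, db * a),
            // Both carry a tangent: the full product rule.
            (Dual::Forward(_, da), Dual::Forward(_, db)) => {
                Dual::Forward(primal, da * b + db * a)
            }
        }
    }
}
```

The mixed cases are where `Lift` pays off: one of the two product-rule terms is skipped entirely instead of being multiplied by zero.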

Here's `MulOp`:

```
pub(crate) struct MulOp<TTensor>(TTensor, TTensor);
impl<TTensor: Clone + Diffable> BinaryOp<TTensor> for MulOp<TTensor> {
fn f(a: &TTensor, b: &TTensor) -> (TTensor, Self) {
(a.elementwise_mul(b), MulOp(a.clone(), b.clone()))
}
}
impl<TTensor: Diffable> BinaryDiffOp<TTensor> for MulOp<TTensor> {
fn dfda(&self, d: &TTensor) -> TTensor {
d.elementwise_mul(&self.1) // da * b
}
fn dfdb(&self, d: &TTensor) -> TTensor {
d.elementwise_mul(&self.0) // db * a
}
}
```

Differentiation rules often capture intermediate results or arguments of the primal computation. So `f` returns not only the result of the primal computation but also a struct to store whatever data is needed for the derivative computation. For `MulOp`, it captures the input tensors `a` and `b`.

`dfda` and `dfdb` define how to compute the derivative with respect to the first and second argument, given `d`, the derivative of downstream functions. The differentiation rule for elementwise tensor multiplication is essentially the same as for scalar multiplication.

Unary operations are similar but don't define `dfdb`:

```
pub(crate) struct SumOp(Vec<usize>);
impl<TTensor: Diffable> UnaryOp<TTensor> for SumOp {
type Args = [usize];
fn f(a: &TTensor, axes: &Self::Args) -> (TTensor, Self) {
let r = a.sum(axes);
(r, SumOp(axes.to_vec()))
}
}
impl<TTensor: Diffable> UnaryDiffOp<TTensor> for SumOp {
fn dfda(&self, d: &TTensor) -> TTensor {
d.sum(&self.0)
}
}
```

`SumOp` only needs the reduced axes from the primal computation to calculate `dfda`. The derivative of the sum is the sum of the derivatives, so we can apply the same `sum` to primal and tangent.

You can find all the ops here and here.

### `Forward<Forward<T>>` for higher order derivatives

Reiterating this signature:

```
pub fn jvp1<T: Diffable + Clone, F>(f: F, at: &T, tangent: &T) -> (T, T)
where
for<'a> F: Fn(&'a Forward<T>) -> Forward<T>
```

Since the only requirement on `T` is that it's `Diffable` and `Forward` is `Diffable`, besides a `Tensor<T>` we can pass a `Forward<T>` to `jvp1` to calculate higher-order derivatives.

```
let p: Tensor<CpuRawTensor<f32>> = Tr::scalar(2.0);
let ddf = diff1(|x: &Forward<Tensor<_>>|
diff1(|x: &Forward<Forward<Tensor<_>>>| x.tanh(), x),
&p
);
```

As we'll see soon, this same design will allow us to combine forward and reverse modes up to arbitrary depth, by building up types like `Forward<Reverse<Tensor<..>>>`.

This scheme works because the `Diffable` operations are implemented in terms of `Diffable`. That looks circular, but it's not: it's a stack of interpreters with a `Tensor` at the bottom, which translates `Diffable` operations to `RawTensor` operations:

`Forward<Forward<...>>: Diffable` -> ... -> `Tensor: Diffable` -> `RawTensor`

Somewhat surprisingly, differentiating a differentiated program gets us the second derivative. One way to make sense of that is that if you compute second or third derivatives symbolically, that's exactly what you do: you apply the differentiation rules multiple times. If you want to do a "fun" exercise, you can work out that stacking `Forward` types amounts to operating on duals-of-duals up to the desired order. If you work out the differentiation rules by hand, you'll find that it yields the correct higher-order derivative.
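If you'd rather let the compiler do the "fun" exercise, here is a small scalar sketch of duals-of-duals. It assumes a generic `Dual<T>` struct without a `Lift` case and a test function `cube`, both invented for this illustration; it mirrors the idea behind `Forward<Forward<T>>`, not Tensorken's actual types.

```rust
use std::ops::{Add, Mul};

// A dual number over any numeric-like T, so that Dual<Dual<f64>> is allowed.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Dual<T> {
    primal: T,
    tangent: T,
}

impl<T: Add<Output = T>> Add for Dual<T> {
    type Output = Dual<T>;
    fn add(self, rhs: Self) -> Self {
        Dual {
            primal: self.primal + rhs.primal,
            tangent: self.tangent + rhs.tangent,
        }
    }
}

impl<T: Copy + Add<Output = T> + Mul<Output = T>> Mul for Dual<T> {
    type Output = Dual<T>;
    fn mul(self, rhs: Self) -> Self {
        Dual {
            primal: self.primal * rhs.primal,
            // Product rule, written only in terms of T's own operations.
            tangent: self.tangent * rhs.primal + self.primal * rhs.tangent,
        }
    }
}

// f(x) = x^3, generic so it runs on f64, Dual<f64>, Dual<Dual<f64>>, ...
fn cube<T: Copy + Mul<Output = T>>(x: T) -> T {
    x * x * x
}

// Second derivative of x^3 via a dual of duals.
fn second_derivative_of_cube(x: f64) -> f64 {
    // The inner dual seeds dx = 1 for the first derivative.
    let inner = Dual { primal: x, tangent: 1.0 };
    // The outer dual differentiates the inner computation once more.
    let outer = Dual {
        primal: inner,
        tangent: Dual { primal: 1.0, tangent: 0.0 },
    };
    cube(outer).tangent.tangent
}
```

The product rule for `Dual<T>` is written once, in terms of `T`'s multiplication and addition. Instantiating `T` with `Dual<f64>` then differentiates the differentiation rule itself, which is the same trick as stacking `Forward` on `Forward`.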

### Reverse-mode AD with `Reverse`

The implementation of `Diffable` for `Reverse` follows the same pattern as `Forward` but is more involved. Because reverse mode accumulates the derivative in a separate backward pass, we can no longer compute everything on the fly when we compute the primal. Instead, `Reverse` builds a trace of operations in a forward pass while calculating the primal result, then accumulates derivatives in the backward pass.

The difference with forward mode is visible in the signature of `vjp`:

```
pub fn vjp1<'b, 't, T: Diffable + Clone + 't, F>(f: F, at: &T) -> (T, PullBack<'t, T>)
where
for<'a> F: Fn(&'a Reverse<'a, 't, T>) -> Reverse<'a, 't, T>
```

Like `jvp`, it returns the primal result. Unlike `jvp`, it doesn't return the tangent, but instead a `PullBack` struct. The only available operation on that is `call`:

```
pub fn call(&self, cotangent: &T) -> Vec<T>
where
T: Diffable + Clone,
```

This takes a cotangent tensor - in other words, a tensor with the same shape as the result of `f` - and returns the tangents of all the arguments of `f`. Here's how `vjp` is used:

```
pub fn value_and_gradn<'t, T: Diffable + Clone + 't, F>(f: F, at: &[&T]) -> (T, Vec<T>)
where
for<'a> F: Fn(&'a [Reverse<'a, 't, T>]) -> Reverse<'a, 't, T>,
{
// one forward pass, tracing
let (primal, pullback) = vjpn(f, at);
// one backward pass, accumulating derivatives
let tangents = pullback.call(&primal.ones_like());
// but we get multiple tangents in one go
(primal, tangents)
}
```

Other implementations for `grad` functions follow a similar pattern.

The details of how this is implemented (via a `Trace` type) are explained in my post on AD, so I won't repeat them here. It is not substantially different from the scalar case. Briefly, here is the `Reverse` type:

```
pub enum Reverse<'a, 't, T> {
Lift(T),
Reverse(&'a Trace<'t, T>, T, usize),
}
```

Like `Forward`, it has a `Lift` case for tensors we don't want to differentiate. The `Reverse` case contains the primal `T`, and some administrative data to record the trace and do the backward pass.
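For readers who haven't seen the scalar case, here is a minimal tape-based sketch of what "record a trace, then run a backward pass" can look like. `Tape`, `Node`, and `Var` are hypothetical names for this illustration; Tensorken's `Trace` and `Reverse` work on tensors and their internals may differ.

```rust
use std::cell::RefCell;

// One recorded operation: for each of two parents, its index on the tape
// and the local partial derivative with respect to it.
struct Node {
    parents: [(usize, f64); 2],
}

#[derive(Default)]
struct Tape {
    nodes: RefCell<Vec<Node>>,
}

// A traced value: a reference to the tape, the primal, and its tape index.
#[derive(Clone, Copy)]
struct Var<'a> {
    tape: &'a Tape,
    value: f64,
    index: usize,
}

impl Tape {
    fn push(&self, parents: [(usize, f64); 2]) -> usize {
        let mut nodes = self.nodes.borrow_mut();
        nodes.push(Node { parents });
        nodes.len() - 1
    }
    // Inputs have no real parents: point at themselves with weight zero.
    fn var(&self, value: f64) -> Var<'_> {
        let index = self.nodes.borrow().len();
        self.push([(index, 0.0), (index, 0.0)]);
        Var { tape: self, value, index }
    }
}

impl<'a> Var<'a> {
    fn mul(self, rhs: Var<'a>) -> Var<'a> {
        let index = self.tape.push([(self.index, rhs.value), (rhs.index, self.value)]);
        Var { tape: self.tape, value: self.value * rhs.value, index }
    }
    fn add(self, rhs: Var<'a>) -> Var<'a> {
        let index = self.tape.push([(self.index, 1.0), (rhs.index, 1.0)]);
        Var { tape: self.tape, value: self.value + rhs.value, index }
    }
    // Backward pass: walk the tape from last to first, accumulating derivatives.
    fn grad(self) -> Vec<f64> {
        let nodes = self.tape.nodes.borrow();
        let mut grads = vec![0.0; nodes.len()];
        grads[self.index] = 1.0;
        for i in (0..nodes.len()).rev() {
            let g = grads[i];
            for &(parent, weight) in nodes[i].parents.iter() {
                grads[parent] += weight * g;
            }
        }
        grads
    }
}

// f(x, y) = x * y + x, returning (f, df/dx, df/dy).
fn value_and_grad(x: f64, y: f64) -> (f64, f64, f64) {
    let tape = Tape::default();
    let (vx, vy) = (tape.var(x), tape.var(y));
    let out = vx.mul(vy).add(vx);
    let grads = out.grad();
    (out.value, grads[vx.index], grads[vy.index])
}
```

One forward pass records the whole computation, and one backward pass yields the derivative with respect to every input at once. The `(reference to the trace, primal, index)` triple in `Var` has the same shape as the `Reverse` case of the enum above.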

The implementation of `Diffable` looks similar to `Forward`:

```
impl<T: Clone + Diffable> Diffable for Reverse<'_, '_, T> {
type Elem = T::Elem;
fn elementwise_mul(&self, rhs: &Self) -> Self {
self.binary::<MulOp<T>>(rhs)
}
fn sum(&self, axes: &[usize]) -> Self {
self.unary::<SumOp, _>(axes)
}
// other omitted
}
```

Again we have `unary` and `binary` helper methods to deal with `Lift` and call the appropriate functions on the `Op` structs.

Interestingly, even though reverse mode calculates derivatives backward, from the output to the input, `MulOp` is identical for forward and reverse mode. This is true for all elementwise operations.

`sum` however is different from forward mode. In the backward pass, we get a `d` in the shape of the result of the `sum` (i.e. with fewer elements) and we need to produce a tensor in the shape of the input of `sum`. To do that, we need `expand`:

```
pub(crate) struct SumOp(Vec<usize>);
impl<TTensor: Diffable> UnaryOp<TTensor> for SumOp {
type Args = [usize];
fn f(a: &TTensor, axes: &Self::Args) -> (TTensor, Self) {
let r = a.sum(axes);
(r, SumOp(a.shape().to_vec()))
}
}
impl<TTensor: Diffable> UnaryDiffOp<TTensor> for SumOp {
fn dfda(&self, d: &TTensor) -> TTensor {
d.expand(&self.0)
}
}
```
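The contrast between the two `sum` rules is easy to check on plain slices. This sketch uses a 1-D `&[f64]` for the tensor and a single `f64` for the fully reduced result; `sum_jvp` and `sum_vjp` are names I made up for the illustration.

```rust
// Forward mode: push an input tangent through `sum`.
// The tangent has the input's shape, and the result has the output's shape.
fn sum_jvp(x: &[f64], tangent: &[f64]) -> (f64, f64) {
    (x.iter().sum(), tangent.iter().sum())
}

// Reverse mode: pull an output cotangent back through `sum`.
// The cotangent has the output's shape, and the result has the input's shape,
// so the single value is expanded (repeated) to the input's length.
fn sum_vjp(x: &[f64], cotangent: f64) -> (f64, Vec<f64>) {
    (x.iter().sum(), vec![cotangent; x.len()])
}
```

Forward mode maps input-shaped tangents to output-shaped ones, so it can reuse `sum`. Reverse mode goes the other way and has to grow the cotangent back to the input shape, which is why the reverse `SumOp` stores the input shape rather than the axes.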

Full implementation for reverse mode is in ad_reverse.rs and the reverse operations are in ad_ops_reverse.rs.

After all that, we can run all the examples in the demo section. However, there is one remaining issue.

## Un-blowing up matmul, again

The problem is serious. Repeating the (pseudo-code) implementation of `matmul` in `DiffableExt`:

```
pub trait DiffableExt: Diffable
{
fn matmul(&self, other: &Self) -> Self {
// preconditions, shape manipulation omitted
// special cases omitted
let l = self.reshape(&l_shape);
// shape manipulation omitted
let r = other
.reshape(&r_shape)
.transpose(r_shape.ndims() - 1, r_shape.ndims() - 2);
// TROUBLE BEGINS HERE
// after multiply: [..., m, o, n]
let prod = l.mul(&r);
// after sum: [..., m, o, 1]
let sum = prod.sum(&[prod.shape().ndims() - 1]);
// after reshape: [..., m, o]
let s = sum.shape();
sum.reshape(&s[..s.ndims() - 1])
}
}
```

See that `mul` followed by `sum`? In the second part of Tensors from Scratch, I explained that this blows up memory, to the point where this approach is utterly unscalable. The fused multiply-add function in `RawTensor` came to the rescue - we rewrote the separate `sum` and `mul` calls into one `l.fused_multiply_add(&r, dims)`, which made it efficient. Now we've regressed to the previous bad situation. What gives?

First, `Diffable` doesn't have `fused_multiply_add`, so we can't write the optimized version directly. We could add `fused_multiply_add` to `Diffable` as a primitive operation, but then we have to define a differentiation rule for it in the various modes. One of the main reasons for `Diffable`'s existence is to avoid that.

Second, while manually fusing `mul` and `sum` worked for this particular case, users may inadvertently write a `mul` followed by a `sum`, and fall into this trap themselves. Worse, while we're calculating derivatives by composing operations in forward or reverse mode, Tensorken itself may introduce a `mul` followed by a `sum`. Manually fusing all cases is not going to work. We need a better solution.

If we were writing a compiler, it'd be straightforward to go through the abstract syntax tree of tensor operations and transform any `l.mul(r).sum(axes)` into `l.fused_multiply_add(r, axes)`. Can we do a similar optimization here?
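For comparison, this is roughly what the compiler-style rewrite could look like if we did have a syntax tree. The `Expr` enum and the `fuse` and `input` functions are hypothetical; Tensorken has no such AST, which is the whole point of what follows.

```rust
// A tiny, hypothetical AST of tensor operations.
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Input(String),
    Mul(Box<Expr>, Box<Expr>),
    Sum(Box<Expr>, Vec<usize>),
    FusedMultiplyAdd(Box<Expr>, Box<Expr>, Vec<usize>),
}

// Convenience constructor for leaves.
fn input(name: &str) -> Box<Expr> {
    Box::new(Expr::Input(name.to_string()))
}

// Bottom-up rewrite: every Sum directly on top of a Mul becomes one fused node.
fn fuse(expr: Expr) -> Expr {
    match expr {
        Expr::Sum(inner, axes) => match fuse(*inner) {
            Expr::Mul(l, r) => Expr::FusedMultiplyAdd(l, r, axes),
            other => Expr::Sum(Box::new(other), axes),
        },
        Expr::Mul(l, r) => Expr::Mul(Box::new(fuse(*l)), Box::new(fuse(*r))),
        Expr::FusedMultiplyAdd(l, r, axes) => {
            Expr::FusedMultiplyAdd(Box::new(fuse(*l)), Box::new(fuse(*r)), axes)
        }
        leaf @ Expr::Input(_) => leaf,
    }
}
```

With a data structure to pattern-match on, the optimization is a few lines. The interpreter-based design has no such data structure, so the same effect has to be achieved by a different route.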

Let's think about what's happening from the perspective of interpreters. We have defined a language for writing differentiable programs using the trait `Diffable`. Everything we do with tensors - `matmul`, `crop`, `max`, `sigmoid`, as well as getting derivatives - is eventually a program in terms of the operations on `Diffable`. We have three interpreters for `Diffable`: one that translates the differentiable program to `RawTensor` operations, and two that augment the differentiable program with forward or reverse mode AD. No matter how many times we stack `Diffable` on top of `Diffable`, eventually the program gets run via a `RawTensor` interpreter.

So far we only have concrete `RawTensor` interpreters, which calculate the results on CPU or GPU, or print a string representing the result. But we can also write an interpreter that spits out a new, optimized `RawTensor` interpreter, with all `mul` + `sum` fused into `fused_multiply_add`.

This technique - which I didn't invent at all, to be clear - is introduced more gradually and gracefully in my post on typed tagless final interpreters. Here I'll give a whirlwind tour of the implementation.

We'll use a type called `Fuse<T>`. `T` is the target optimized `RawTensor`. Whenever `mul` followed by `sum` is detected in the unoptimized, original `RawTensor`, `Fuse` rewrites the two operations to a fused equivalent.

```
enum FuseCtx {
Sum(Vec<usize>),
NotSum,
}
pub struct Fuse<T>(Rc<dyn Fn(&FuseCtx) -> T>);
```

The function from `FuseCtx` to the fused `T: RawTensor` is a factory function we'll build up while interpreting the original `RawTensor` as `Fuse<T>`. In other words, `Fuse<T>` interprets `RawTensor` as a function that, given a `FuseCtx`, produces an optimized `RawTensor`. It works in two passes: a first pass builds up the factory function, then a second pass runs the function to get a new `RawTensor`.

Since `Fuse` only needs to fuse multiply and sum operations, it delays the application of `sum`, and instead passes `Sum(axes)` to the continuation via `FuseCtx`. The continuation calls the delayed `sum` if it can't fuse or `fused_multiply_add` if it can. Here's the implementation of `mul` where fusing happens:

```
impl<TRaw: RawTensor + Clone + 'static> RawTensor for Fuse<TRaw> {
type Elem = TRaw::Elem;
fn mul(&self, other: &Self) -> Self {
let f_lhs = Rc::clone(&self.0);
let f_rhs = Rc::clone(&other.0);
let nextctx = FuseCtx::NotSum;
Fuse::new(move |ctx| match ctx {
FuseCtx::Sum(axes) => f_lhs(&nextctx).fused_multiply_add(&f_rhs(&nextctx), axes),
FuseCtx::NotSum => f_lhs(&nextctx).mul(&f_rhs(&nextctx)),
})
}
}
```

The context passed in the closure represents what the next operation is, from the perspective of the current operation. If it's `sum`, the `Sum` enum case, we fuse. If it's anything else, represented by `NotSum`, we know the operation has already been applied and we can't fuse. Since `mul` is not a `sum`, we pass `NotSum` as the next context.

Here is the implementation of `sum`:

```
fn sum(&self, axes: &[usize]) -> Self {
let f = Rc::clone(&self.0);
let my_axes = axes.to_vec();
Fuse::new(move |ctx| match ctx {
FuseCtx::Sum(sum_axes) => f(&FuseCtx::Sum(combine_axes(&my_axes, sum_axes))),
FuseCtx::NotSum => f(&FuseCtx::Sum(my_axes.clone())),
})
}
```

We do not apply `sum` straight away to the resulting interpreter. Instead, we pass `Sum` through to the next operation, so it gets a chance to fuse it. Any operation that doesn't fuse needs to apply the delayed `sum` if it gets the `Sum` enum case. We might as well fuse consecutive `sum` calls into one by combining axes, hence the first match arm.

Fusing happens in two passes: the first pass builds the `FuseCtx -> RawTensor` function. The second pass creates the optimized `RawTensor` by calling the function:

```
impl<T> Fuse<T> {
fn run(&self) -> T {
(self.0)(&FuseCtx::NotSum)
}
}
```
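To watch the pieces above work together, here is a self-contained toy version. It assumes a cut-down trait `Raw` in place of `RawTensor` and a `Printer` backend that only renders the operations it receives, both invented for this sketch. It also simplifies the real thing: nested sums are applied one after the other rather than merged with `combine_axes`.

```rust
use std::rc::Rc;

// Hypothetical, cut-down stand-in for RawTensor.
trait Raw: Sized {
    fn exp(&self) -> Self;
    fn mul(&self, other: &Self) -> Self;
    fn sum(&self, axes: &[usize]) -> Self;
    fn fused_multiply_add(&self, other: &Self, axes: &[usize]) -> Self;
}

// A "device" that only renders the operations it is asked to run.
#[derive(Clone)]
struct Printer(String);

impl Raw for Printer {
    fn exp(&self) -> Self {
        Printer(format!("exp({})", self.0))
    }
    fn mul(&self, other: &Self) -> Self {
        Printer(format!("mul({}, {})", self.0, other.0))
    }
    fn sum(&self, axes: &[usize]) -> Self {
        Printer(format!("sum({}, {:?})", self.0, axes))
    }
    fn fused_multiply_add(&self, other: &Self, axes: &[usize]) -> Self {
        Printer(format!("fma({}, {}, {:?})", self.0, other.0, axes))
    }
}

enum FuseCtx {
    Sum(Vec<usize>),
    NotSum,
}

struct Fuse<T>(Rc<dyn Fn(&FuseCtx) -> T>);

impl<T: Raw + Clone + 'static> Fuse<T> {
    fn new(f: impl Fn(&FuseCtx) -> T + 'static) -> Self {
        Fuse(Rc::new(f))
    }
    // Wrap a concrete tensor; a leaf applies any delayed sum itself.
    fn lift(t: T) -> Self {
        Fuse::new(move |ctx: &FuseCtx| match ctx {
            FuseCtx::Sum(axes) => t.sum(axes),
            FuseCtx::NotSum => t.clone(),
        })
    }
    // Second pass: call the factory to get the optimized T.
    fn run(&self) -> T {
        (self.0)(&FuseCtx::NotSum)
    }
}

impl<T: Raw + Clone + 'static> Raw for Fuse<T> {
    // Cannot fuse: run the operation, then apply the delayed sum if any.
    fn exp(&self) -> Self {
        let f = Rc::clone(&self.0);
        Fuse::new(move |ctx: &FuseCtx| {
            let result = f(&FuseCtx::NotSum).exp();
            match ctx {
                FuseCtx::Sum(axes) => result.sum(axes),
                FuseCtx::NotSum => result,
            }
        })
    }
    // The fusion point: a pending sum turns mul into fused_multiply_add.
    fn mul(&self, other: &Self) -> Self {
        let (l, r) = (Rc::clone(&self.0), Rc::clone(&other.0));
        Fuse::new(move |ctx: &FuseCtx| {
            let (lhs, rhs) = (l(&FuseCtx::NotSum), r(&FuseCtx::NotSum));
            match ctx {
                FuseCtx::Sum(axes) => lhs.fused_multiply_add(&rhs, axes),
                FuseCtx::NotSum => lhs.mul(&rhs),
            }
        })
    }
    // Delay this sum by handing it to the operand as context.
    fn sum(&self, axes: &[usize]) -> Self {
        let f = Rc::clone(&self.0);
        let my_axes = axes.to_vec();
        Fuse::new(move |ctx: &FuseCtx| {
            let inner = f(&FuseCtx::Sum(my_axes.clone()));
            match ctx {
                FuseCtx::Sum(outer) => inner.sum(outer),
                FuseCtx::NotSum => inner,
            }
        })
    }
    fn fused_multiply_add(&self, other: &Self, axes: &[usize]) -> Self {
        self.mul(other).sum(axes)
    }
}

// Render one program that fuses and one that cannot.
fn demo() -> (String, String) {
    let a = Fuse::lift(Printer("a".to_string()));
    let b = Fuse::lift(Printer("b".to_string()));
    let fused = a.mul(&b).sum(&[1]).run();
    let unfused = a.mul(&b).exp().sum(&[1]).run();
    (fused.0, unfused.0)
}
```

In `demo`, the first program renders as `fma(a, b, [1])`, while the second, where an `exp` sits between `mul` and `sum`, renders as `sum(exp(mul(a, b)), [1])`.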

Link to full implementation of fusing.

Now I can finally reveal the full `Cpu32` and `Wgpu32` types:

```
pub type Cpu32 = Tensor<ShapeTracker<Fuse<CpuRawTensor<f32>>>>;
pub type Wgpu32<'d> = Tensor<ShapeTracker<Fuse<WgpuRawTensor<'d, f32>>>>;
```

The remaining unknown there is `ShapeTracker`. `ShapeTracker` is a `RawTensor` implementation that abstractly interprets the operations by only tracking tensor shapes. It delegates all operations to the `RawTensor` it wraps, except `shape`:

```
pub struct ShapeTracker<T>(ShapeStrider, T);
/// This implementation passes every operation through
/// to self.1, except for shape.
impl<TRaw: RawTensor> RawTensor for ShapeTracker<TRaw> {
type Elem = TRaw::Elem;
fn exp(&self) -> Self {
Self(self.0.clone(), self.1.exp())
}
// etc
fn shape(&self) -> &[usize] {
self.0.shape()
}
}
```

Because `Fuse` does not track shapes but does need to implement `RawTensor::shape`, it can only return its shape by running the delayed computation. We don't want that - some derivative operations require access to the shape of tensors, and it would be bad if we had to run the tensor computation at that point.

`ShapeTracker` solves this for us - it can answer `shape` queries without executing tensor operations. The order is important here. `ShapeTracker` needs to wrap `Fuse`, which needs to wrap the concrete `CpuRawTensor` or `WgpuRawTensor`.
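Here is a toy sketch of why that ordering helps. It assumes a cut-down trait `Raw`, a plain `Vec<usize>` instead of `ShapeStrider`, and a `Delayed` type standing in for `Fuse` whose `shape` panics to model "this would force the computation". None of these are Tensorken's real definitions.

```rust
// Hypothetical, cut-down stand-in for RawTensor.
trait Raw: Sized {
    fn exp(&self) -> Self;
    fn sum(&self, axes: &[usize]) -> Self;
    fn shape(&self) -> Vec<usize>;
}

// Stand-in for a lazy tensor such as Fuse: answering `shape` would force
// the delayed computation, which this toy models with a panic.
struct Delayed;

impl Raw for Delayed {
    fn exp(&self) -> Self {
        Delayed
    }
    fn sum(&self, _axes: &[usize]) -> Self {
        Delayed
    }
    fn shape(&self) -> Vec<usize> {
        panic!("shape would force the delayed computation")
    }
}

// Tracks the shape next to the wrapped tensor (a Vec here, for simplicity).
struct ShapeTracker<T>(Vec<usize>, T);

impl<T: Raw> Raw for ShapeTracker<T> {
    fn exp(&self) -> Self {
        // Elementwise operations leave the shape unchanged.
        ShapeTracker(self.0.clone(), self.1.exp())
    }
    fn sum(&self, axes: &[usize]) -> Self {
        // Reduced axes keep size 1, as in the matmul comments earlier.
        let mut shape = self.0.clone();
        for &axis in axes {
            shape[axis] = 1;
        }
        ShapeTracker(shape, self.1.sum(axes))
    }
    fn shape(&self) -> Vec<usize> {
        // Answered from the tracked shape; the wrapped tensor is never asked.
        self.0.clone()
    }
}
```

Calling `shape` on `ShapeTracker(vec![2, 3, 4], Delayed).exp().sum(&[2])` returns `[2, 3, 1]` without ever touching `Delayed::shape`. With the layers the other way around, the shape query would reach the lazy tensor first.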

## I love it when a plan comes together

Thanks to the power of interpreters aka final tagless encoding, Tensorken gained a capable yet small and extensible AD implementation. So far, I'm really happy with how Tensorken turned out. I started seriously researching deep learning from an implementation perspective at the beginning of 2023 with only some prior exposure to automatic differentiation. I randomly ran into the typed tagless final interpreters paper while I was studying TinyGrad, and figured that TinyGrad's style would lend itself well to the tagless final style. I could not have hoped for a better outcome.

After that, I saw a post on Reddit praising JAX and immediately preferred the functional style over PyTorch's imperative AD interface. It was much more challenging to implement in Rust though! Those signatures look straightforward now, but it took a lot of struggling with closures, lifetimes and lifetimes and closures and then lifetimes some more before everything came together. All this to say - I got lucky trying an implementation style I hadn't ever used and struggled for a long time. When AD finally worked, it felt almost magical. Persistence is worth some IQ points.

Now that Tensorken has all the pieces of a full-fledged deep learning library, it's time to put it to the test. I intend to follow along with Andrej Karpathy's Zero to Hero neural networks course, translating it from PyTorch to Tensorken. At the end of that, we should have a home-grown, walking and talking nanoGPT. Without a doubt, there'll be many interesting problems in Tensorken itself to solve along the way.

Many thanks for reading!

## References

JAX: The Autodiff Cookbook. I adapted some parts for this post.

JAX: Autodidax - JAX core from scratch. Great insight into how JAX works under the hood, using a small implementation of JAX.

[Video: Automatic Differentiation by Matthew Johnson]. Matthew Johnson is one of the authors of JAX. Here he talks about the principles underlying Autograd, another Python-based AD library.

TinyGrad. A small but powerful tensor library with reverse-mode AD in Python.

Swift's Differentiable Programming Manifesto. Swift has a powerful differentiable programming component, integrated with the compiler.

Swift For Tensorflow (Google Drive). A great overview of Swift's approach to AD.

DiffSharp. A tensor library with support for differentiable programming for .NET.
