Jan van Brügge

Let the compiler do the work for you!

Recently I came across a little programming puzzle: take a binary search tree and return a new tree where every node is replaced by the sum of all nodes to its right. So given this tree:

``````
    5
   / \
  2   7
 /   / \
1   6   8
     \
      6
``````

We start at the rightmost number and set it to zero (because the sum of nothing is zero). The 7 above it comes next, and the current sum is 8, because there is only one node to its right. After that comes the "second" 6, the child of the other 6, again because it is further to the right. In general, the order is always: right subtree, the node itself, left subtree. In the end we get this tree; check that you understand where every number comes from:

``````
    27
   /  \
  32   8
  /   / \
 34  21  0
      \
       15
``````

Implementing this algorithm in Java

For the following implementations, we will only use stuff from the standard library.

First, we define ourselves a tree:

``````java
public class Node {
    Node left;
    int value;
    Node right;

    public Node(Node l, int v, Node r) {
        this.left = l;
        this.value = v;
        this.right = r;
    }
}
``````

Now, as always with recursive data structures like lists and trees, we will define the algorithm as a recursive function. We need to keep track of the current sum and build up a new tree with the new values. Since a Java method can only return one value, we first define a small `Pair` helper class to carry both the new node and the running sum:

``````java
public class Pair {
    Node n;
    int v;

    public Pair(Node n, int v) {
        this.n = n;
        this.v = v;
    }
}

public class Node {
    Node left;
    int value;
    Node right;

    public Node(Node l, int v, Node r) { /* ... */ }

    public Pair solve(int currentSum) {
        // Store the current sum here, as a fallback if right is null
        Pair rightResult = new Pair(null, currentSum);
        // First, recurse into the right subtree, if it exists
        if (this.right != null) {
            rightResult = this.right.solve(currentSum);
        }
        int sum = rightResult.v + this.value;

        // Again, save the sum as a fallback
        Pair leftResult = new Pair(null, sum);
        if (this.left != null) {
            leftResult = this.left.solve(sum);
        }

        // Finally, create a new node (to replace this one);
        // its new value is the sum of everything to its right
        Node newSelf = new Node(leftResult.n, rightResult.v, rightResult.n);
        // And return it together with the running sum
        return new Pair(newSelf, leftResult.v);
    }
}
``````

Now, we just need a main function to run this:

``````java
public class Main {
    static Node testTree = new Node(
        new Node(
            new Node(null, 1, null),
            2,
            null
        ),
        5,
        new Node(
            new Node(
                null,
                6,
                new Node(null, 6, null)
            ),
            7,
            new Node(null, 8, null)
        )
    );

    public static void main(String[] args) {
        Pair result = testTree.solve(0);
        System.out.println(result.n);
    }
}
``````

To see our result, we also need a `toString` method on the Node:

``````java
public class Node {
    Node left;
    int value;
    Node right;

    public Node(Node l, int v, Node r) { /* ... */ }

    public Pair solve(int currentSum) { /* ... */ }

    @Override
    public String toString() {
        String leftTree = left == null ? " " : left.toString();
        String rightTree = right == null ? " " : right.toString();
        return "Node(" + leftTree + ", " + value + ", " + rightTree + ")";
    }
}
``````

If you run the code now, you will see

``````
Node(Node(Node( , 34,  ), 32,  ), 27, Node(Node( , 21, Node( , 15,  )), 8, Node( , 0,  )))
``````

which is exactly the tree we are looking for!

Implementing this in Haskell

Now you may ask: what does this have to do with the title? The compiler does not write any code for us here. That is because the Java compiler is merely a "checking" compiler (as I call it): it complains if you mismatch types or make syntax errors, but it does not do any work for you. It is basically like a teacher who checks your answer afterwards.

The Haskell compiler is different, and for several reasons. First, Haskell infers types, so you do not have to annotate them everywhere, and error messages tell you which type is expected where. Second, Haskell focuses a lot more on fundamental abstractions than other commonly used languages.

One of these fundamental abstractions is the `Functor`. If your datatype is a functor, you can map a function over it; nothing more, nothing less. For example, an array or list is a functor where the function is applied to each element (you may know this from Java streams or JavaScript's `Array.map`). In Haskell, this works for any datatype that "contains" other data. So let's define such a datatype. Note that we leave the type of the data in the tree abstract; in the Java example it was `int`:

``````haskell
module Main where

data Tree a = Leaf | Node (Tree a) a (Tree a)
``````

This declaration is fundamentally the same as the `Node` class in the Java example; instead of `null` we use `Leaf` to signal empty children, and we did not name the fields `left`, `value`, `right`, but just put them in that order.

So, now back to functors: the compiler can automatically generate the code needed to implement this mapping. All you need to do is enable the feature and add a `deriving` clause.

``````haskell
{-# LANGUAGE DeriveFunctor #-}
module Main where

data Tree a = Leaf | Node (Tree a) a (Tree a)
    deriving Functor
``````

This already allows us to, for example, increment all nodes in the tree by one:

``````haskell
incrementNodes :: Tree Int -> Tree Int
incrementNodes tree = fmap (+1) tree
``````
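Under the hood, the derived instance behaves like the one we would write by hand: recurse into the left subtree, apply the function to the value, recurse into the right subtree. A sketch of what `deriving Functor` generates:

```haskell
data Tree a = Leaf | Node (Tree a) a (Tree a)

-- A sketch of the Functor instance GHC derives for this Tree
instance Functor Tree where
    fmap _ Leaf         = Leaf
    fmap f (Node l v r) = Node (fmap f l) (f v) (fmap f r)
```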

But now we would like to map and collect at the same time. We have solved the mapping half, so let's do the collecting half. For collapsing a structure down to a single value, there is the type class `Foldable`. Again, arrays and lists are foldable, and it works just like `collect` on Java 8 streams or `Array.reduce` in JavaScript. And again, the compiler can generate this automatically for you:

``````haskell
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
module Main where

data Tree a = Leaf | Node (Tree a) a (Tree a)
    deriving (Functor, Foldable)
``````

Now we can, for example, calculate the sum of all nodes:

``````haskell
sumNodes :: Tree Int -> Int
sumNodes tree = foldr (+) 0 tree
``````
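The derived `Foldable` instance visits the elements in declaration order, so a right fold combines the right subtree first, then the node's value, then the left subtree. Roughly:

```haskell
data Tree a = Leaf | Node (Tree a) a (Tree a)

-- A sketch of the Foldable instance GHC derives: fold the right
-- subtree into the start value, combine the node's value, then
-- fold the left subtree into that
instance Foldable Tree where
    foldr _ z Leaf         = z
    foldr f z (Node l v r) = foldr f (f v (foldr f z r)) l
```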

Now to the last step: mapping and folding in one. For this we need one more abstraction: `Traversable`. This class lets you map over your data while simultaneously keeping track of some "side effects", and it requires your data to be a `Functor` and `Foldable` already. For example, we can keep track of some local state: the current sum. And like the classes before, the compiler can generate the instance automatically for us:

``````haskell
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
module Main where

data Tree a = Leaf | Node (Tree a) a (Tree a)
    deriving (Functor, Foldable, Traversable)
``````

With all this in place, our final solution is pretty simple:

``````haskell
solve :: Tree Int -> (Int, Tree Int)
solve tree = mapAccumR (\a b -> (a + b, a)) 0 tree
``````
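Because `mapAccumR` works on any `Traversable`, you can get a feeling for the accumulator on a plain list first; it is threaded from right to left, exactly as in our tree:

```haskell
import Data.Traversable (mapAccumR)

-- Each element is replaced by the sum of the elements to its right;
-- the first component of the result is the total sum
example :: (Int, [Int])
example = mapAccumR (\a b -> (a + b, a)) 0 [5, 7, 8]
-- example == (20, [15, 8, 0])
```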

`mapAccumR` is a function from the Haskell standard library that is defined like this:

``````haskell
mapAccumR :: Traversable t => (a -> b -> (a, c)) -> a -> t b -> (a, t c)
mapAccumR fun init x = runStateR (traverse (StateR . flip fun) x) init
``````

The exact definition is not that important; just note that in the inner parentheses it uses `StateR` to make your function behave like a side effect. `traverse` then combines the stateful effects in the right order, and `runStateR` executes them one by one.
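For our tree, the derived `traverse` looks roughly like the hand-written version below: the effects of the left subtree, the value, and the right subtree are combined in that order, and `mapAccumR`'s reversed-state applicative is what turns this into a right-to-left pass. (The `Functor` and `Foldable` instances are included only so the sketch compiles on its own; `deriving` would generate all three.)

```haskell
data Tree a = Leaf | Node (Tree a) a (Tree a)

instance Functor Tree where
    fmap _ Leaf         = Leaf
    fmap f (Node l v r) = Node (fmap f l) (f v) (fmap f r)

instance Foldable Tree where
    foldr _ z Leaf         = z
    foldr f z (Node l v r) = foldr f (f v (foldr f z r)) l

-- A sketch of the Traversable instance GHC derives
instance Traversable Tree where
    traverse _ Leaf         = pure Leaf
    traverse f (Node l v r) = Node <$> traverse f l <*> f v <*> traverse f r
```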

Java required us to write a `toString` method by hand. In Haskell we can again let the compiler do it, by deriving `Show`. So with a `main` function to make everything runnable, our complete code for the puzzle is just:

``````haskell
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
module Main where

import Data.Traversable (mapAccumR)

data Tree a = Leaf | Node (Tree a) a (Tree a)
    deriving (Show, Functor, Foldable, Traversable)

solve :: Tree Int -> (Int, Tree Int)
solve tree = mapAccumR (\a b -> (a + b, a)) 0 tree

testTree :: Tree Int
testTree =
    Node
        (Node
            (Node Leaf 1 Leaf)
            2
            Leaf
        )
        5
        (Node
            (Node
                Leaf
                6
                (Node Leaf 6 Leaf)
            )
            7
            (Node Leaf 8 Leaf)
        )

main :: IO ()
main = do
    let (_, tree) = solve testTree
    print tree
``````

Conclusion

As you can see, if the compiler has your back, you can cut down the number of lines of code significantly. More code means more opportunities for bugs, so everything we do not have to write ourselves is a win.

Jan van Brügge

The problem is that Java's type system is not able to express these powerful abstractions. I've added the code for `mapAccumR` to the article. You can see that it requires any traversable `t`, and the third argument is `t b`, meaning this traversable `t` contains a `b`, whatever that might be. In Java terms, `t` is generic, kinda like `<T extends Traversable>`; the problem is the `B`. You cannot have generic generics in Java (also called higher-kinded types).

``````
mapAccumR<T extends Traversable, B>(/* ... */, T<B> x)
``````

is not valid Java

Andrei Shikov

Arrow-kt implements support for higher-kinded types in Kotlin (which has a similar limitation). It creates a type using a generic `Kind<T, B>` and provides some tools to convert it back to a usable Kotlin type.

I suggest you check it out!

Nested Software

minor correction: JavaScript has `Array.reduce`, not `fold`.

Also, it seems to me that it would be helpful to explain `mapAccumR`, since that is doing the real heavy lifting here. Presumably you could implement that in Java (or maybe there is a library that does it already) and the solutions would become more or less the same.

Jan van Brügge

Yes, I fixed the `Array.reduce`.

The nice thing is, `mapAccumR` is not doing much either. It relies completely on the power of `Traversable`, the instance the compiler generated for us. If you changed the type signature of `solve` to allow `Traversable`s other than `Tree`, the code would also work without any changes for lists or arrays, for example. All thanks to the power of fundamental abstractions.
I've added a short paragraph that shows that `mapAccumR` is not a lot of code.
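To make that concrete, here is a sketch of the generalized version; only the signature changes, the body stays the same:

```haskell
import Data.Traversable (mapAccumR)

-- Widened from `Tree Int` to any Traversable container of Ints
solve :: Traversable t => t Int -> (Int, t Int)
solve = mapAccumR (\a b -> (a + b, a)) 0

-- The same accumulator now also works on a list:
listExample :: (Int, [Int])
listExample = solve [5, 7, 8]
-- listExample == (20, [15, 8, 0])
```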

DrBearhands

Maybe it would help clear up the confusion to mention that `deriving` creates functions with certain rules (map, fold, etc.) that operate on the types?

DrBearhands

I've been having doubts about the use of `deriving` when it comes to code style. It sure as hell is convenient but it hides information from the programmer. It would cause problems for instance if for some reason someone were to switch the order of left and right branches (ignoring why the hell someone would do that). Admittedly this is also an argument for records.

For Aeson in particular I'd just rather resort to declaring my own instances.

What are your thoughts on this issue?

Jan van Brügge

I really advocate using deriving as much as possible. `DerivingVia` basically allows you to create a standard set of instances you can derive for other data types. Handwritten instances should IMO be kept as small as possible.
As said, more code always means more bugs, and it makes the code harder to read. The derived instances always behave the same, so they are independent of local conventions. Also, due to the laws attached to these classes, most of the time there is only one lawful instance for a given datatype.

Joseph Stevens

Now that is way cleaner! Let the computer do it, I say.
