
Let the compiler do the work for you!

Jan van Brügge (@jvanbruegge) ・ 5 min read

Recently I came across a little programming puzzle: take a binary search tree and return a new tree in which every node is replaced by the sum of the nodes to its right, i.e. the sum of all nodes that come after it in an in-order traversal. So given this tree:

    5
   / \
  2    7
 /    / \
1    6   8
      \
       6

We start at the rightmost number, 8, and set it to zero (because the sum of nothing is zero). The 7 above it comes next; its new value is 8, because there is only one node to its right. After that comes the "second" 6, i.e. the child of the other 6, because it is further to the right. In general, the order is always right subtree, self, left subtree. In the end we get this tree; check that you understand where every number comes from:

    27
   /  \
  32   8
 /    / \
34   21  0
      \
       15
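To double-check the numbers, here is a small Java sketch that only walks the example tree in the order right subtree, self, left subtree and records each node's new value; it does not rebuild the tree yet. The class name ReverseInOrder and its minimal Node type are illustrative assumptions, not part of the solution below.

```java
import java.util.ArrayList;
import java.util.List;

// Illustrative sketch: visit the example tree in reverse in-order
// (right subtree, self, left subtree) and record each node's new value.
class ReverseInOrder {
    // Hypothetical minimal node type for this sketch; null means "no child".
    static final class Node {
        final Node left;
        final int value;
        final Node right;

        Node(Node left, int value, Node right) {
            this.left = left;
            this.value = value;
            this.right = right;
        }
    }

    // Returns the running sum after the whole subtree has been visited.
    static int visit(Node node, int sumSoFar, List<Integer> visited, List<Integer> newValues) {
        if (node == null) {
            return sumSoFar; // an empty subtree adds nothing
        }
        int sum = visit(node.right, sumSoFar, visited, newValues); // 1. right subtree
        visited.add(node.value);                                   // 2. self: its new value is
        newValues.add(sum);                                        //    the sum of everything to its right
        return visit(node.left, sum + node.value, visited, newValues); // 3. left subtree
    }

    // The example tree from the puzzle.
    static Node exampleTree() {
        return new Node(
            new Node(new Node(null, 1, null), 2, null),
            5,
            new Node(new Node(null, 6, new Node(null, 6, null)), 7, new Node(null, 8, null)));
    }

    public static void main(String[] args) {
        List<Integer> visited = new ArrayList<>();
        List<Integer> newValues = new ArrayList<>();
        int total = visit(exampleTree(), 0, visited, newValues);
        System.out.println(visited);   // [8, 7, 6, 6, 5, 2, 1]
        System.out.println(newValues); // [0, 8, 15, 21, 27, 32, 34]
        System.out.println(total);     // 35
    }
}
```

Putting each entry of newValues back at the position of the corresponding visited node gives exactly the result tree above, and the final running sum of 35 is the total of all nodes.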

Implementing this algorithm in Java

For the following implementations, we will only use stuff from the standard library.

First, we define ourselves a tree:

public class Node {
    Node left;
    int value;
    Node right;

    public Node(Node l, int v, Node r) {
        this.left = l;
        this.value = v;
        this.right = r;
    }
}

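Now, as always with recursive data structures like lists and trees, we will define the algorithm as a recursive function. We need to keep track of the current sum and build up a new tree with the new data. Since a Java method can only return one value, we bundle the new subtree and the running sum into a small Pair class: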

public class Pair {
    Node n;
    int v;

    public Pair(Node n, int v) {
        this.n = n;
        this.v = v;
    }
}

public class Node {
    Node left;
    int value;
    Node right;

    public Node(Node l, int v, Node r) { /* ... */ }

    public Pair solve(int currentSum) {
        // Store the current sum in here, as fallback if right is null
        Pair rightResult = new Pair(null, currentSum);
        // First, go to the right subtree, if it exists
        if(this.right != null) {
            rightResult = this.right.solve(currentSum);
        }
        int sum = rightResult.v + this.value;

        // Again, save the sum as fallback
        Pair leftResult = new Pair(null, sum);
        if(this.left != null) {
            leftResult = this.left.solve(sum);
        }

        // Finally create a new node (to replace self)
        Node newSelf = new Node(leftResult.n, rightResult.v, rightResult.n);
        // And return it together with the sum
        return new Pair(newSelf, leftResult.v);
    }
}

Now, we just need a main function to run this:

public class Main {
    static Node testTree = new Node(
        new Node(
            new Node(null, 1, null),
            2,
            null
        ),
        5,
        new Node(
            new Node(
                null,
                6,
                new Node(null, 6, null)
            ),
            7,
            new Node(null, 8, null)
        )
    );

    public static void main(String[] args) {
        Pair result = testTree.solve(0);
        System.out.println(result.n);
    }
}

To see our result, we also need a toString method on the Node:

public class Node {
    Node left;
    int value;
    Node right;

    public Node(Node l, int v, Node r) { /* ... */ }

    public Pair solve(int currentSum) { /* ... */ }

    public String toString() {
        String leftTree = left == null ? " " : left.toString();
        String rightTree = right == null ? " " : right.toString();
        return "Node(" + leftTree + ", " + value + ", " + rightTree + ")";
    }
}

If you run the code now, you will see

Node(Node(Node( , 34,  ), 32,  ), 27, Node(Node( , 21, Node( , 15,  )), 8, Node( , 0,  )))

which is exactly the tree we are looking for!

Implementing this in Haskell

Now you may ask: what does this have to do with the title? The compiler does not write any code for us here. That is because the Java compiler is merely a "checking" compiler (as I call it). It complains if you mismatch types or make syntax errors, but it does not do any work for you; it is basically like a teacher who checks your answer afterwards.

The Haskell compiler is different, for several reasons. First, Haskell infers types, so you do not have to annotate them everywhere, and its error messages tell you which type is expected where. Second, Haskell focuses a lot more on fundamental abstractions than other commonly used languages.

One of these fundamental abstractions is Functor. If your datatype is a Functor, you can map a function over it. Nothing more, nothing less. For example, an array or a list is a Functor: mapping applies a function to each element (you may know this from Java streams or JavaScript's Array.map). In Haskell, this works for any datatype that "contains" other data. So let's define such a datatype. Note that we leave the type of the data in the tree abstract; in the Java example it was int:

module Main where

data Tree a = Leaf | Node (Tree a) a (Tree a)

This declaration is fundamentally the same as the Node class in the Java example. Instead of null we use Leaf to signal an empty child, and instead of naming the fields left, value and right, we simply put them in that order.

So, back to Functor: the compiler can automatically generate the code you would need to implement this mapping. All you need to do is enable that feature (the DeriveFunctor extension) and add a deriving clause:

{-# LANGUAGE DeriveFunctor #-}
module Main where

data Tree a = Leaf | Node (Tree a) a (Tree a)
            deriving Functor

This would already allow us to increment every value in the tree by one, for example:

incrementNodes :: Tree Int -> Tree Int
incrementNodes tree = fmap (+1) tree
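For comparison, here is a rough Java sketch of what `deriving Functor` saves us from writing: a map that rebuilds the tree with the same shape and applies a function to every value. The generic Tree class, the show helper (which only imitates the layout of Haskell's derived Show output) and the class name FunctorSketch are illustrative assumptions, not code the compiler generates.

```java
import java.util.function.Function;

// Rough Java counterpart of the instance that `deriving Functor` generates.
class FunctorSketch {
    // Hypothetical generic version of `data Tree a`; null plays the role of Leaf.
    static final class Tree<A> {
        final Tree<A> left;
        final A value;
        final Tree<A> right;

        Tree(Tree<A> left, A value, Tree<A> right) {
            this.left = left;
            this.value = value;
            this.right = right;
        }
    }

    // fmap :: (a -> b) -> Tree a -> Tree b
    static <A, B> Tree<B> map(Function<A, B> f, Tree<A> tree) {
        if (tree == null) {
            return null; // fmap f Leaf = Leaf
        }
        // Same shape: recurse into both children and apply f to the value.
        return new Tree<>(map(f, tree.left), f.apply(tree.value), map(f, tree.right));
    }

    // Renders a tree in the style of Haskell's derived Show, e.g. "Node Leaf 1 Leaf".
    static String show(Tree<?> tree) {
        if (tree == null) {
            return "Leaf";
        }
        return "Node " + showArg(tree.left) + " " + tree.value + " " + showArg(tree.right);
    }

    // Nested nodes are wrapped in parentheses, Leaf is not.
    private static String showArg(Tree<?> tree) {
        return tree == null ? "Leaf" : "(" + show(tree) + ")";
    }

    public static void main(String[] args) {
        Tree<Integer> tree = new Tree<>(new Tree<>(null, 1, null), 2, null);
        // Java version of: incrementNodes tree = fmap (+1) tree
        System.out.println(show(map(x -> x + 1, tree))); // Node (Node Leaf 2 Leaf) 3 Leaf
    }
}
```

Note that map builds a new tree and leaves the original untouched, just like fmap.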

But we would like to map and collect at the same time. We have solved the first half, so let's do the second. For collapsing a structure down to a single value, there is the type class Foldable (it does not actually require a Functor instance). Again, arrays and lists are Foldable, and folding works much like reduce on Java 8 streams or Array.reduce in JavaScript. The derived instance visits the constructor fields in the order they are declared, so for our Tree that is left subtree, value, right subtree: an in-order traversal. And again, the compiler can generate this for us:

{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
module Main where

data Tree a = Leaf | Node (Tree a) a (Tree a)
            deriving (Functor, Foldable)

Now we can, for example, calculate the sum of all values in the tree:

sumNodes :: Tree Int -> Int
sumNodes tree = foldr (+) 0 tree
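Again for comparison, a rough Java sketch of the foldr that the derived Foldable instance gives us. It assumes the in-order field order described above, so folding treats the tree like the list 1, 2, 5, 6, 6, 7, 8. The Tree class and the name FoldableSketch are illustrative.

```java
import java.util.function.BiFunction;

// Rough Java counterpart of foldr from the derived Foldable instance.
class FoldableSketch {
    // Hypothetical generic version of `data Tree a`; null plays the role of Leaf.
    static final class Tree<A> {
        final Tree<A> left;
        final A value;
        final Tree<A> right;

        Tree(Tree<A> left, A value, Tree<A> right) {
            this.left = left;
            this.value = value;
            this.right = right;
        }
    }

    // foldr :: (a -> b -> b) -> b -> Tree a -> b
    // Folds the values as if they were the in-order list: left values, this value, right values.
    static <A, B> B foldr(BiFunction<A, B, B> f, B z, Tree<A> tree) {
        if (tree == null) {
            return z; // foldr f z Leaf = z
        }
        B rightResult = foldr(f, z, tree.right);        // fold the right part first,
        B withValue = f.apply(tree.value, rightResult); // combine this value with it,
        return foldr(f, withValue, tree.left);          // then fold the left part into that
    }

    // The example tree from the puzzle.
    static Tree<Integer> exampleTree() {
        return new Tree<>(
            new Tree<>(new Tree<>(null, 1, null), 2, null),
            5,
            new Tree<>(new Tree<>(null, 6, new Tree<>(null, 6, null)), 7, new Tree<>(null, 8, null)));
    }

    public static void main(String[] args) {
        Tree<Integer> tree = exampleTree();
        // Java version of: sumNodes tree = foldr (+) 0 tree
        int sum = foldr((a, acc) -> a + acc, 0, tree);
        // Building a string shows the order in which the values are combined.
        String order = foldr((a, acc) -> acc.isEmpty() ? String.valueOf(a) : a + " " + acc, "", tree);
        System.out.println(sum);   // 35
        System.out.println(order); // 1 2 5 6 6 7 8
    }
}
```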

Now to the last step: mapping and folding in one. For this we need another abstraction, Traversable. This class lets you map over a structure while also keeping track of some "side effects", and it requires your type to be both a Functor and Foldable. For example, we can keep track of some local state: the current sum. Like the classes before, the compiler can generate it for us:

{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
module Main where

data Tree a = Leaf | Node (Tree a) a (Tree a)
            deriving (Functor, Foldable, Traversable)

With all this in place, our final solution is pretty simple:

solve :: Tree Int -> (Int, Tree Int)
solve tree = mapAccumR (\a b -> (a + b, a)) 0 tree

mapAccumR is a function from the Haskell standard library (exported by Data.Traversable and Data.List) that is defined like this:

mapAccumR :: Traversable t => (a -> b -> (a, c)) -> a -> t b -> (a, t c)
mapAccumR fun init x = runStateR (traverse (StateR . flip fun) x) init

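The exact definition is not that important. Just note that inside the parentheses it uses StateR to turn your function into a stateful effect. StateR is a state applicative that runs its effects from right to left, so when traverse combines the effects for all elements and runStateR runs them starting from the initial value, the elements are visited last to first. Because the derived Traversable instance goes through the left subtree, the value and then the right subtree, running it backwards gives exactly the order the puzzle needs: right subtree, self, left subtree.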
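Java cannot abstract over the container type t, so the closest Java equivalent is a mapAccumR written for this one tree type. The sketch below is hand-specialized and is not how base implements it (base goes through traverse and StateR, as shown above); the Tree and Pair classes, the show helper and the name MapAccumRSketch are illustrative assumptions.

```java
import java.util.function.BiFunction;

// Hand-specialized sketch of mapAccumR for a single tree type.
class MapAccumRSketch {
    // Hypothetical generic version of `data Tree a`; null plays the role of Leaf.
    static final class Tree<A> {
        final Tree<A> left;
        final A value;
        final Tree<A> right;

        Tree(Tree<A> left, A value, Tree<A> right) {
            this.left = left;
            this.value = value;
            this.right = right;
        }
    }

    // Stand-in for a Haskell tuple (s, b).
    static final class Pair<S, T> {
        final S first;
        final T second;

        Pair(S first, T second) {
            this.first = first;
            this.second = second;
        }
    }

    // mapAccumR :: (s -> a -> (s, b)) -> s -> Tree a -> (s, Tree b)
    // Threads the accumulator right to left: right subtree, value, left subtree.
    static <S, A, B> Pair<S, Tree<B>> mapAccumR(BiFunction<S, A, Pair<S, B>> f, S acc, Tree<A> tree) {
        if (tree == null) {
            return new Pair<>(acc, null); // Leaf: the accumulator passes through unchanged
        }
        Pair<S, Tree<B>> right = mapAccumR(f, acc, tree.right);
        Pair<S, B> self = f.apply(right.first, tree.value);
        Pair<S, Tree<B>> left = mapAccumR(f, self.first, tree.left);
        return new Pair<>(left.first, new Tree<>(left.second, self.second, right.second));
    }

    // Java version of: solve tree = mapAccumR (\a b -> (a + b, a)) 0 tree
    static Pair<Integer, Tree<Integer>> solve(Tree<Integer> tree) {
        return mapAccumR((a, b) -> new Pair<>(a + b, a), 0, tree);
    }

    // Renders a tree in the style of Haskell's derived Show.
    static String show(Tree<?> tree) {
        if (tree == null) {
            return "Leaf";
        }
        return "Node " + showArg(tree.left) + " " + tree.value + " " + showArg(tree.right);
    }

    private static String showArg(Tree<?> tree) {
        return tree == null ? "Leaf" : "(" + show(tree) + ")";
    }

    // The example tree from the puzzle.
    static Tree<Integer> exampleTree() {
        return new Tree<>(
            new Tree<>(new Tree<>(null, 1, null), 2, null),
            5,
            new Tree<>(new Tree<>(null, 6, new Tree<>(null, 6, null)), 7, new Tree<>(null, 8, null)));
    }

    public static void main(String[] args) {
        Pair<Integer, Tree<Integer>> result = solve(exampleTree());
        System.out.println(result.first);        // 35
        System.out.println(show(result.second));
    }
}
```

Running main should print the total 35 followed by the same tree the Haskell program prints. The difference is that every new container type would need its own copy of mapAccumR in Java.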

Java required us to write a toString method by hand. In Haskell we can again let the compiler do this by deriving Show. So with a main function to make everything runnable, our complete code for the puzzle is just:

{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
module Main where

import Data.Traversable (mapAccumR)

data Tree a = Leaf | Node (Tree a) a (Tree a)
            deriving (Show, Functor, Foldable, Traversable)

solve :: Tree Int -> (Int, Tree Int)
solve tree = mapAccumR (\a b -> (a + b, a)) 0 tree

testTree :: Tree Int
testTree =
    Node
        (Node
            (Node Leaf 1 Leaf)
            2
            Leaf
        )
        5
        (Node
            (Node
                Leaf
                6
                (Node Leaf 6 Leaf)
            )
            7
            (Node Leaf 8 Leaf)
        )

main :: IO ()
main = do
    let (_, tree) = solve testTree -- we only need the new tree, not the total
    print tree

Conclusion

As you can see, if the compiler has your back, you can drastically cut down the number of lines of code. More code always means more bugs, so everything we do not have to write ourselves is a win.


Jan van Brügge

@jvanbruegge

Student at the Technical University of Munich, chief software engineer for the MOVE-II CubeSat, likes functional programming and sports

Discussion


Minor correction: JavaScript has Array.reduce, not fold.

Also, it seems to me that it would be helpful to explain mapAccumR, since that is doing the real heavy lifting here. Presumably you could implement that in Java (or maybe there is a library that does it already) and the solutions would become more or less the same.

 

Yes, I fixed the Array.reduce part.

The nice thing is, mapAccumR is not doing much either. It relies completely on the power of Traversable, the class whose instance the compiler generated for us. If you changed the type signature of solve to allow Traversables other than Tree, the same code would work without any changes for lists or arrays, for example. All thanks to the power of fundamental abstractions.
I've added a short paragraph that shows that mapAccumR is not a lot of code.

 

Maybe it would help clear the confusion by mentioning that deriving creates functions with certain rules (map, fold, etc.) that operate on the types?

 

I've been having doubts about the use of deriving when it comes to code style. It sure as hell is convenient, but it hides information from the programmer. It would cause problems, for instance, if for some reason someone were to switch the order of the left and right branches (ignoring why the hell someone would do that). Admittedly, this is also an argument for records.

For Aeson in particular I'd just rather resort to declaring my own instances.

What are your thoughts on this issue?

 

I really advocate using deriving as much as possible. DerivingVia basically allows you to create a standard set of instances you can derive for other data types. Handwritten instances should IMO be kept as small as possible.
As I said, more code always means more bugs, and it makes the code harder to read. The derived instances always behave the same, so they are independent of local conventions. Also, due to the laws attached to the classes, most of the time there is only one lawful instance for a given datatype.

 

Wow, that's cool!
You should consider making this a library for Java!

 

The problem is that Java's type system is not able to express these powerful abstractions. I've added the code for mapAccumR to the article. You can see that it works with any Traversable t, and its third argument has type t b, meaning this traversable t contains values of type b, whatever that might be. In Java terms, t is a generic, kind of like <T extends Traversable>; the problem is the B. You cannot have generic generics in Java (also called higher-kinded types).

mapAccumR<T extends Traversable, B>(/* ... */, T<B> x)

is not valid Java

 

Well, challenge accepted. I shall find a work around!
Wish me luck!

Arrow-kt implements support for higher-kinded types in Kotlin (which has a similar limitation). They emulate them with a generic Kind<T, B> type and some tools to convert it back into a usable Kotlin type.

I suggest you check it out!

 

Wow, that is way cleaner! Let the computer do it, I say.