DEV Community

Marcos Brizeno


Tail recursion in Elixir


You've probably written (or at least heard of) recursive functions or methods before. Chances are you've also heard that recursive functions can be more expensive than a non-recursive alternative.

Let's look at what is probably the simplest example of a recursive function: factorial. Here we define two clauses of Factorial.recursive_factorial/1, one for 0 (which returns 1) and another for any other value:

```elixir
defmodule Factorial do
  def recursive_factorial(0), do: 1
  def recursive_factorial(n), do: n * recursive_factorial(n-1)
end
```

On each call of recursive_factorial(n) we multiply n by the result of recursive_factorial(n-1), which keeps recursing until n reaches 0. At every step we need to keep the current value of n around so we can multiply by it once recursive_factorial(n-1) returns.
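Writing f as shorthand for recursive_factorial, expanding f(4) by hand makes the deferred work visible: every multiplication waits on a call that hasn't returned yet, so each value of n has to be remembered until f(0) finally returns.

```latex
\begin{aligned}
f(4) &= 4 \cdot f(3) \\
     &= 4 \cdot \bigl(3 \cdot f(2)\bigr) \\
     &= 4 \cdot \bigl(3 \cdot (2 \cdot f(1))\bigr) \\
     &= 4 \cdot \bigl(3 \cdot (2 \cdot (1 \cdot f(0)))\bigr) \\
     &= 4 \cdot \bigl(3 \cdot (2 \cdot (1 \cdot 1))\bigr) = 24
\end{aligned}
```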

This is why recursive calls are usually considered more expensive, especially in terms of space: each call needs its stack frame to be kept until the result comes back. Keeping those frames around, even if each one holds little more than n, can start causing trouble when the number gets large (say 500_000 or more, depending on where you're running it).

But there is something we can do to make it better: have the recursion run in constant stack space, just like a non-recursive alternative would.

Tail Recursion

Take a look at this other implementation and compare it to the previous one:

```elixir
defmodule Factorial do
  def tail_recursive_factorial(n), do: tail_recursive_factorial(n, 1)
  defp tail_recursive_factorial(0, acc), do: acc
  defp tail_recursive_factorial(n, acc), do: tail_recursive_factorial(n-1, acc*n)
end
```

The first call, to Factorial.tail_recursive_factorial/1, starts the process by calling the private Factorial.tail_recursive_factorial/2, which does the actual recursion.

Notice that the running product is now passed to the recursive call as the second argument, acc. This means there is no need to keep a stack frame around for each call!
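Writing g(n, acc) as shorthand for the private tail_recursive_factorial/2, the same computation for n = 4 never has a pending multiplication: each step simply hands a new pair of arguments to the next call, and the last call's acc is the answer.

```latex
\begin{aligned}
g(4, 1) &= g(3,\ 1 \cdot 4) = g(3, 4) \\
        &= g(2,\ 4 \cdot 3) = g(2, 12) \\
        &= g(1,\ 12 \cdot 2) = g(1, 24) \\
        &= g(0,\ 24 \cdot 1) = g(0, 24) \\
        &= 24
\end{aligned}
```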

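Each call now has everything it needs, so the current frame can be discarded as soon as the recursive call starts. The Erlang VM (BEAM), which Elixir runs on, does this for every call in tail position; this is known as last call optimization. The base case here is when n is 0, and it just returns the value in acc.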
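To see the difference in stack usage directly, the sketch below compares a body-recursive and a tail-recursive sum of 1..n (a simpler stand-in for factorial that avoids huge numbers) and reads the process stack size with Process.info(self(), :stack_size) at the deepest point of each recursion. The StackProbe module and its functions are hypothetical, written only for this illustration, and the exact word counts will vary between Erlang/OTP versions.

```elixir
defmodule StackProbe do
  # Hypothetical helper (not from the original post). Both functions sum
  # 1..n and also report the process stack size, in words, measured at
  # the base case, i.e. when the recursion is at its deepest.

  # Body recursive: the tuple {n + sum, words} is built AFTER the recursive
  # call returns, so every pending frame (holding its n) stays on the stack.
  def body_sum(0), do: {0, stack_words()}

  def body_sum(n) do
    {sum, words} = body_sum(n - 1)
    {n + sum, words}
  end

  # Tail recursive: the running total travels in acc and the recursive
  # call is the last thing each clause does, so frames do not pile up.
  def tail_sum(n), do: tail_sum(n, 0)

  defp tail_sum(0, acc), do: {acc, stack_words()}
  defp tail_sum(n, acc), do: tail_sum(n - 1, acc + n)

  # Process.info/2 returns {:stack_size, words} for this item.
  defp stack_words do
    {:stack_size, words} = Process.info(self(), :stack_size)
    words
  end
end

{sum, body_words} = StackProbe.body_sum(100_000)
{^sum, tail_words} = StackProbe.tail_sum(100_000)
IO.puts("body: #{body_words} words, tail: #{tail_words} words")
```

The body-recursive version should report a stack that grows roughly in proportion to n, while the tail-recursive one stays small no matter how large n gets.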

When a recursive function doesn't need to keep any temporary value around to use after the recursive call returns (the recursive call is the very last thing it does), it is considered tail recursive. Let's look at a more interesting problem.

The 'Strain' exercise on Exercism

On exercism.io there is a nice little exercise where you're asked to implement two functions, keep and discard. Given a list of items and a function fun, keep returns the items for which fun returns true, and discard returns the items for which fun returns false.
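To pin the contract down, both functions can be expressed with Elixir's built-in Enum.filter/2 and Enum.reject/2, assuming fun always returns a boolean (Enum.filter/2 actually keeps any truthy result). The StrainOracle module below is a hypothetical reference that can be handy for checking a hand-written recursive version against; it is not meant as a solution, since the point of the exercise is to write the recursion yourself.

```elixir
defmodule StrainOracle do
  # Hypothetical reference module (not part of the exercise): for functions
  # that return true/false, these match the keep/discard contract.
  def keep(list, fun), do: Enum.filter(list, fun)
  def discard(list, fun), do: Enum.reject(list, fun)
end

StrainOracle.keep([1, 2, 3, 4, 5], &(rem(&1, 2) == 0))
# => [2, 4]
StrainOracle.discard([1, 2, 3, 4, 5], &(rem(&1, 2) == 0))
# => [1, 3, 5]
```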

Some friends and I got together to solve this, and our initial implementation of keep ran fun on the head of the list: if the result was true, it returned a list with the head followed by the result of recursing on the tail; if not, it just recursed on the tail, ignoring the head:

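```elixir
defmodule Strain do
  # Base case: an empty list keeps nothing.
  def keep([], _fun) do
    []
  end

  # Evaluate fun on the head once and let keep_if/3 act on the result.
  def keep([head | _tail] = list, fun) do
    keep_if(list, fun, fun.(head))
  end

  # fun returned true: head goes in front of whatever the recursive call
  # returns, so head must be remembered until that call finishes.
  defp keep_if([head | tail], fun, true) do
    [head | keep(tail, fun)]
  end

  # fun returned false: drop head and recurse on the tail.
  defp keep_if([_head | tail], fun, false) do
    keep(tail, fun)
  end
end
```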

The main problem here is that for each call where fun returns true we need to hold on to head, and only after the recursive call finishes can head be prepended to the list it returns.

For the test scenarios we had, this wasn't a big issue, but we wanted to go a bit further and think about how to make this a tail recursive function.

Our optimization was similar to what we did for Factorial.tail_recursive_factorial/1: we create an accumulator argument and pass it along on each call instead of waiting to get the result back. Here is what it ended up looking like:

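```elixir
defmodule Strain do
  # Public entry point: start with an empty accumulator.
  def keep(list, fun) do
    keep(list, [], fun)
  end

  # Base case: nothing left to check, so the accumulator is the answer.
  defp keep([], filtered, _fun) do
    filtered
  end

  # Evaluate fun on the head once and dispatch on the boolean result.
  defp keep([head | _tail] = list, filtered, fun) do
    keep(list, filtered, fun, fun.(head))
  end

  # fun returned true: add head to the end of the accumulator and recurse.
  defp keep([head | tail], filtered, fun, true) do
    keep(tail, filtered ++ [head], fun)
  end

  # fun returned false: skip head and recurse.
  defp keep([_head | tail], filtered, fun, false) do
    keep(tail, filtered, fun)
  end
end
```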
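One thing worth knowing about the tail-recursive version above: filtered ++ [head] copies the whole accumulator every time an item is kept, so the work grows quadratically with the number of kept items. A common Elixir idiom is to prepend to the accumulator, which is constant time, and reverse it once at the end. The sketch below does that and also derives discard from keep by negating fun; the StrainSketch module name is hypothetical and this is not the code from the original exercise session.

```elixir
defmodule StrainSketch do
  # Hypothetical module; keep/discard follow the contract described above.
  def keep(list, fun), do: do_keep(list, fun, [])

  # discard keeps exactly the items for which fun is falsy.
  def discard(list, fun), do: keep(list, &(!fun.(&1)))

  # Base case: the accumulator holds the kept items in reverse order.
  defp do_keep([], _fun, acc), do: Enum.reverse(acc)

  defp do_keep([head | tail], fun, acc) do
    if fun.(head) do
      # Prepending is O(1); the list is put back in order once, at the end.
      do_keep(tail, fun, [head | acc])
    else
      do_keep(tail, fun, acc)
    end
  end
end
```

Both recursive calls sit in tail position inside the if, so do_keep/3 stays tail recursive while avoiding the repeated copying.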

Another nice exercise for practicing tail recursion is 'List Operations', where you're asked to implement basic functions that we usually take for granted, like each, map, filter, etc.
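As a starting point for that exercise, here are a few list functions written in the same accumulator style. The ListOpsSketch module name and these function signatures are assumptions for illustration; the exercise's own stub defines the names and argument order it expects.

```elixir
defmodule ListOpsSketch do
  # Hypothetical module; every recursive call below is a tail call.

  # count/1: the accumulator holds the number of elements seen so far.
  def count(list), do: count(list, 0)
  defp count([], acc), do: acc
  defp count([_head | tail], acc), do: count(tail, acc + 1)

  # reverse/1: prepending each head to the accumulator reverses the list.
  def reverse(list), do: reverse(list, [])
  defp reverse([], acc), do: acc
  defp reverse([head | tail], acc), do: reverse(tail, [head | acc])

  # map/2: results are collected in reverse, then put back in order once.
  def map(list, fun), do: map(list, fun, [])
  defp map([], _fun, acc), do: reverse(acc)
  defp map([head | tail], fun, acc), do: map(tail, fun, [fun.(head) | acc])
end
```

map/3 builds its result backwards and calls reverse/1 once at the end, because prepending to a list is cheap while appending with ++ would copy the accumulator on every step.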

Recursive functions are quite common in functional languages (many of them, Elixir included, don't even have loops), so learning about tail recursion and practicing how to implement it is a good investment :)
