DEV Community

Leandro Proença


Quicksort, a Ruby implementation using Test-driven development

Quicksort is a comparison-based sorting algorithm known for its good average-case performance.

On average, it runs in O(n log n) time, although a poor choice of pivots can degrade it to O(n²) in the worst case. Variants of it are used in many programming languages and platforms, making it one of the most important sorting algorithms.

Image source: Wikipedia

Implementation

Quicksort is a divide-and-conquer algorithm: it selects a "pivot" element from the list, then partitions the remaining elements into two sub-lists, the elements smaller than the pivot (smaller ones) and the elements larger than the pivot (larger ones), which are then sorted recursively. Elements equal to the pivot can go to either side; in this guide they go with the smaller ones.
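To make the partition step concrete, here is a minimal sketch of a single partitioning pass in Ruby. It uses the built-in Enumerable#partition rather than the hand-written method this guide builds with TDD, and partition_once is a hypothetical name used only for illustration. Elements equal to the pivot go to the smaller side.

```ruby
# One partitioning pass (hypothetical helper): take the first element as
# the pivot and split the rest into elements <= pivot and elements > pivot.
def partition_once(list)
  pivot, *remaining = list
  smaller, larger = remaining.partition { |x| x <= pivot }
  [smaller, pivot, larger]
end

smaller, pivot, larger = partition_once([3, 5, 1, 4, 2])
p smaller # => [1, 2]
p pivot   # => 3
p larger  # => [5, 4]
```

Notice that neither sub-list is guaranteed to be sorted ([5, 4] is not), which is why each one has to be processed again.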

Understanding the algorithm

Although understanding this algorithm is especially important, implementing it is not that hard.

Observe the following pseudo-code:

list = [3, 1, 2]

# the choice of pivot is arbitrary, but it may
#  affect the final time complexity. In this guide, we'll
#  choose the first element for didactic purposes
pivot = list[0]         # 3
remaining = list[1..-1] # [1, 2]

smaller, larger = partition(pivot, remaining)

# -> repeat the process for the smaller list
# -> repeat the process for the larger list

The code above only splits the list into two (still unsorted) sub-lists: elements smaller than the pivot and elements larger than it. If we repeat the process for each sub-list and place the pivot in the middle, we end up with the following solution:

sort(smaller) + [pivot] + sort(larger)
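Turned into runnable Ruby, that formula could look like the following compact sketch. It again relies on Enumerable#partition, and quick_sort is a hypothetical name chosen so it doesn't clash with the sort method we build with TDD. The base case (a list with zero or one element is already sorted) is what stops the recursion.

```ruby
# Compact recursive Quicksort: sort(smaller) + [pivot] + sort(larger)
def quick_sort(list)
  # Base case: empty and single-element lists are already sorted
  return list if list.size <= 1

  pivot, *remaining = list
  smaller, larger = remaining.partition { |x| x <= pivot }

  # The pivot is wrapped in brackets because Array#+ expects another array
  quick_sort(smaller) + [pivot] + quick_sort(larger)
end

p quick_sort([5, 1, 4, 3, 2]) # => [1, 2, 3, 4, 5]
```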

Consider the explanation above just a Quickstart (pun intended) to the overall algorithm.

For many people, recursion is not straightforward to understand, which is why this guide will try to dissect the internals of Quicksort using Test-driven development (TDD), a great technique and good practice that I use on a daily basis, especially when I'm experimenting and trying to understand the unknown.

Baby steps to the rescue

When doing TDD, we are encouraged to work in small cycles, or baby steps, in order to complete each red-green-refactor iteration.

This technique can help us understand the Quicksort algorithm: because Quicksort is divide-and-conquer by nature, it's a perfect fit for breaking down into steps through the TDD cycle.

First of all, we create a file named quicksort_test.rb, which will hold our bootstrap code:

require 'test/unit'

class QuicksortTest < Test::Unit::TestCase
  def test_sorting_empty_list
  end
end

This dummy test is a placeholder for sorting an empty array. Run ruby quicksort_test.rb. The output should look similar to the following (timings will vary):

Finished in 0.0004595 seconds.
-------------------------------------------------------------------------------
1 tests, 0 assertions, 0 failures, 0 errors, 0 pendings, 0 omissions, 0 notifications
100% passed
-------------------------------------------------------------------------------
2176.28 tests/s, 0.00 assertions/s

For now, we will write the sort method in the same file for the sake of simplicity.

require 'test/unit'                                   

def sort(list)
  []
end                      

class QuicksortTest < Test::Unit::TestCase
  def test_sorting_empty_list
    assert_equal [], sort([])
  end                                    
end                                       

Run the test again and...YAY!

Test dummy (but important!) scenarios

Yes, dummy scenarios are important. Don't underestimate them.

def test_sorting_one_number_only
  assert_equal [42], sort([42])
end                                        

def sort(list)
  return [] if list.size == 0 # without this, the empty-list test would fail
  [42]
end

Pretty cool, huh? And dummy, I know...

Let's sort two numbers:

def test_sorting_two_numbers
  assert_equal [8, 42], sort([42, 8])
end                                        
def sort(list)            
  return [] if list.size == 0 # when the list is empty
  return list if list.size == 1 # when the list has one number

  [8, 42] # OMG!                      
end                            

Please, STAHP the naivety!

OK, let's write a more sophisticated test:

def test_sorting_multiple_numbers
  list = [5, 1, 4, 3, 2]                  

  assert_equal [1, 2, 3, 4, 5], sort(list)
end                                       

The initial implementation

Suppose our input is an unsorted list:

list = [5, 1, 4, 3, 2]  

We can start the implementation by choosing the pivot. As mentioned before, in this guide we will choose the first element.

However, the first element is not always the best-performing choice of pivot (an already sorted list is the classic bad case). We use it here for learning purposes only; you can later adapt the algorithm to use another pivot strategy.
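To see why the first-element pivot can hurt, the sketch below measures how many nested partitioning levels the algorithm needs. recursion_depth is a hypothetical helper written only for this illustration. On sorted input every remaining element lands on the larger side, so the depth grows linearly with the list size, while a well-balanced input needs only a logarithmic number of levels.

```ruby
# Counts nested partitioning levels when the first element is always the pivot
def recursion_depth(list)
  return 0 if list.size <= 1

  pivot, *remaining = list
  smaller, larger = remaining.partition { |x| x <= pivot }
  1 + [recursion_depth(smaller), recursion_depth(larger)].max
end

p recursion_depth((1..100).to_a)         # sorted input   => 99
p recursion_depth([4, 2, 6, 1, 3, 5, 7]) # balanced input => 2
```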

pivot = list[0] # 5

Then, we should get the remaining list (all elements but the pivot):

remaining = list[1..-1] # [1, 4, 3, 2]

Partition recursively

Next, let's think about the partition method, which takes:

  • the pivot
  • the remaining list
  • an array of elements smaller than the pivot (starts empty)
  • an array of elements larger than the pivot (starts empty)
partition(pivot, remaining, [], [])

Why do the smaller/larger arrays start empty? Because we will call this method recursively, and each call will add elements to them.

A bit of recursion

OK, but what is recursion? In short, it's a method that calls itself:

def greet(name)
  puts "Hello #{name}"

  greet(name)
end

The code above never stops calling itself: each call adds a frame to the call stack until the stack overflows, which in Ruby raises a SystemStackError (stack level too deep).
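As a quick check, the sketch below calls a method that recurses with no stop condition and rescues the resulting error. recurse_forever is a hypothetical name; in MRI Ruby the message is "stack level too deep", and how deep the calls go before failing depends on the VM's stack size.

```ruby
# A method with no stop condition: every call pushes a new stack frame
def recurse_forever(depth = 0)
  recurse_forever(depth + 1)
end

begin
  recurse_forever
rescue SystemStackError => e
  puts "Recursion aborted: #{e.message}"
end
```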

We MUST place a stop condition before the recursive call; otherwise, the recursion would never end:

def greet(name, times = 10)
  puts "Hello #{name}"
  return if times == 0

  # decreasing the variable times on the next iteration
  #  next, will be 9
  #  then 8, 7, 6...successively until 0 meets our stop condition
  greet(name, times - 1)
end
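The same two ingredients, a stop condition plus a smaller input on every call, are what partition relies on. Here is a minimal sketch with a hypothetical sum_list method: it reduces the list to its tail on each call and carries an accumulator along, much like partition carries the smaller and larger arrays.

```ruby
# Sums a list recursively, reducing it by one element per call
def sum_list(list, total = 0)
  # Stop condition: nothing left to reduce, so return the accumulated value
  return total if list.empty?

  head, *tail = list
  # Recursive call on a smaller list with an updated accumulator
  sum_list(tail, total + head)
end

p sum_list([5, 1, 4, 3, 2]) # => 15
```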

Now, the following example shows how we WANT to call partition. Hold on, I know the method doesn't exist yet, but we are taking baby steps. A bit of pseudo-code:

unsorted_list = [3, 1] 
pivot = 3
remaining = [1]

smaller, larger = partition(pivot, remaining, [], [])

smaller #=> [1]
larger #=> []

Let's write the first test for partition, in which there are no remaining numbers: a scenario where all the elements have already been partitioned (but not yet sorted).

Why are we writing this scenario first? Because we know that, when there are no more elements to partition, we can stop the recursion and return the accumulated smaller and larger sub-lists. This is our stop condition.

def test_partition_no_remaining
  pivot     = 5
  remaining = []
  smaller   = [3, 1]
  larger    = [8, 6, 7]

  assert_equal [[3, 1], [8, 6, 7]], partition(pivot, remaining, smaller, larger)
end  
def partition(pivot, remaining, smaller, larger)
  return [smaller, larger] if remaining.size == 0 # an array never equals 0, so check its size
end

This first implementation means that, when the remaining list is empty, we can safely return the sub-arrays and stop the recursion.

Partition of two numbers:

def test_partition_two_numbers                       
  pivot = 1                                    
  remaining = [5]  

  assert_equal [[], [5]], partition(pivot, remaining, [], [])
end                                                  

Here, the remaining list is not empty, so we should think about what a partition is.

We can think of a partition as a method/function that takes a pivot, a list, and the smaller/larger arrays accumulated so far.

How can we do this recursively?

  • reducing the list
  • accumulating the smaller/larger arrays

It's not easy to write a test that isolates the logic inside a recursive piece of code. But we can extract the logic responsible for reducing and accumulating into another method.

So, in the partition method, we call next_partition before the recursive call:

def partition(pivot, list, smaller = [], larger = [])
  return [smaller, larger] if list.size == 0

  # Here, the method `next_partition` will return the reduced
  #  list (tail|remaining), the "next" smaller (accumulated) and the "next" larger (accumulated)
  tail, next_smaller, next_larger = next_partition(pivot, list, smaller, larger) 

  # Now, as we already have the stop condition in the first
  #  line, we can safely call this method recursively
  partition(pivot, tail, next_smaller, next_larger)        
end                                                  

Time to dissect next_partition before running the test.

Let's then write a test scenario for the next_partition method:

#=> next_partition(pivot, tail, smaller, larger)

list = [3, 5, 1, 4, 2]
pivot = 3
remaining = [5, 1, 4, 2]
smaller = []
larger = []

tail, next_smaller, next_larger = next_partition(pivot, remaining, smaller, larger)

assert_equal [1, 4, 2], tail
assert_equal [], next_smaller
assert_equal [5], next_larger

It may be confusing at first, but we can interpret this test as follows:

  • the unsorted list is [3, 5, 1, 4, 2]
  • we chose the pivot as being 3, the first element
  • excluding the pivot, the remaining numbers are [5, 1, 4, 2]
  • smaller starts empty
  • larger starts empty

So, for the next iteration, we have the following data:

  • the reduced list (tail of remaining) is [1, 4, 2]
  • the accumulated (next) elements smaller than the pivot (3) are []
  • the accumulated (next) elements larger than the pivot (3) are [5]

Golden tip

Let's think: next_partition should compare the pivot with the first element of the remaining list and decide whether that element goes into the accumulated smaller array or the larger one!

Let's continue our iterations one by one (manual recursion):

next_tail = [1, 4, 2]
pivot = 3 # still our first pivot
next_smaller = []
next_larger = [5]

next_tail, next_smaller, next_larger = next_partition(pivot, next_tail, next_smaller, next_larger)

assert_equal [4, 2], next_tail
assert_equal [1], next_smaller
assert_equal [5], next_larger

## Next iteration (4 is larger than the pivot)
next_tail, next_smaller, next_larger = next_partition(pivot, next_tail, next_smaller, next_larger)

assert_equal [2], next_tail
assert_equal [1], next_smaller
assert_equal [4, 5], next_larger

## Next (final) iteration (2 is smaller than the pivot)
next_tail, next_smaller, next_larger = next_partition(pivot, next_tail, next_smaller, next_larger)

assert_equal [], next_tail
assert_equal [2, 1], next_smaller
assert_equal [4, 5], next_larger

Our recursion would stop here because there's no "next tail": the whole initial list has been reduced.

Note that the accumulated smaller array, [2, 1], is still unsorted. That's why we repeat the process all over again for each sub-array, until each one has zero or one element.

Now, we can concatenate the results according to our initial sort logic:

# Partition result
* smaller: [2, 1]
* pivot: 3
* larger: [4, 5]

# Sorted = sort(smaller) + [pivot] + sort(larger)
sort([2, 1]) + [3] + sort([4, 5])
[1, 2] + [3] + [4, 5]

YAY!

[1, 2, 3, 4, 5]

But we don't have the code yet. Looking at these pseudo-code examples, it's now easy to figure out the next_partition logic:

def next_partition(pivot, list, smaller, larger)
  head = list[0]                                
  tail = list[1..-1]                            

  if head <= pivot                              
    [tail, [head] + smaller, larger]            
  else                                          
    [tail, smaller, [head] + larger]            
  end                                           
end                                             
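To double-check the manual iterations against this implementation, the sketch below repeatedly applies next_partition (copied from above) to [5, 1, 4, 2] with pivot 3 and records the state after each step. The recording loop is only an illustration, not part of the final code. Because next_partition prepends each element ([head] + smaller), the accumulated arrays come out in reverse order of arrival, so the smaller array ends as [2, 1].

```ruby
# next_partition as defined above
def next_partition(pivot, list, smaller, larger)
  head = list[0]
  tail = list[1..-1]

  if head <= pivot
    [tail, [head] + smaller, larger]
  else
    [tail, smaller, [head] + larger]
  end
end

# Apply next_partition step by step and record every intermediate state
pivot = 3
tail, smaller, larger = [5, 1, 4, 2], [], []
steps = []

until tail.empty?
  tail, smaller, larger = next_partition(pivot, tail, smaller, larger)
  steps << [tail, smaller, larger]
end

steps.each { |step| p step }
# [[1, 4, 2], [], [5]]
# [[4, 2], [1], [5]]
# [[2], [1], [4, 5]]
# [[], [2, 1], [4, 5]]
```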

We can test more partition scenarios:

def test_partition_three_numbers                 
  assert_equal [[], [6, 8]], partition(3, [8, 6])
end                                              

And, finally, sorting multiple numbers:

def test_sorting_multiple_numbers
  assert_equal [1, 2, 3, 4, 5], sort([5, 3, 1, 2, 4])
  assert_equal [1, 2, 3, 4, 5], sort([3, 5, 1, 4, 2])
  assert_equal [1, 1, 3, 5, 8], sort([3, 1, 8, 1, 5])
end
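The post doesn't show the final sort method that makes these tests pass (it lives in the Gist mentioned in the conclusion). One way to wire sort to partition is the sketch below; it reuses next_partition and partition as written above, and the base case and first-element pivot follow the choices made throughout this guide. The Gist's version may differ.

```ruby
def next_partition(pivot, list, smaller, larger)
  head = list[0]
  tail = list[1..-1]

  if head <= pivot
    [tail, [head] + smaller, larger]
  else
    [tail, smaller, [head] + larger]
  end
end

def partition(pivot, list, smaller = [], larger = [])
  return [smaller, larger] if list.size == 0

  tail, next_smaller, next_larger = next_partition(pivot, list, smaller, larger)
  partition(pivot, tail, next_smaller, next_larger)
end

# Sketch of sort built on partition
def sort(list)
  # Stop condition: empty and single-element lists are already sorted
  return list if list.size <= 1

  pivot = list[0]
  remaining = list[1..-1]
  smaller, larger = partition(pivot, remaining)

  sort(smaller) + [pivot] + sort(larger)
end

p sort([5, 3, 1, 2, 4]) # => [1, 2, 3, 4, 5]
p sort([3, 1, 8, 1, 5]) # => [1, 1, 3, 5, 8]
```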

Important note about performance

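The examples above do not sort the list “in place”, as classic Quicksort implementations do.

Every step allocates new arrays (for example, [head] + smaller copies the accumulated array), so this version uses extra memory and does more work than an in-place partition. Also, partition recurses once per element and Ruby does not eliminate tail calls by default, so very large lists may raise a SystemStackError. The main point here is to demonstrate the overall logic of Quicksort.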
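For comparison, here is a minimal sketch of an in-place variant. It swaps in a Lomuto-style partition scheme (a standard technique, not the approach built in this guide) adapted to keep the first element as the pivot; partition_in_place and quicksort! are hypothetical names. Instead of building new arrays, it swaps elements inside the original array.

```ruby
# Rearranges arr[lo..hi] around the pivot arr[lo] and returns the pivot's
# final index: elements <= pivot end up on its left, larger ones on its right.
def partition_in_place(arr, lo, hi)
  pivot = arr[lo]
  boundary = lo # arr[(lo + 1)..boundary] holds elements <= pivot

  ((lo + 1)..hi).each do |i|
    next unless arr[i] <= pivot

    boundary += 1
    arr[boundary], arr[i] = arr[i], arr[boundary]
  end

  # Move the pivot between the two regions
  arr[lo], arr[boundary] = arr[boundary], arr[lo]
  boundary
end

# Sorts arr in place by recursively sorting the regions around each pivot
def quicksort!(arr, lo = 0, hi = arr.size - 1)
  return arr if lo >= hi

  pivot_index = partition_in_place(arr, lo, hi)
  quicksort!(arr, lo, pivot_index - 1)
  quicksort!(arr, pivot_index + 1, hi)
  arr
end

numbers = [5, 1, 4, 3, 2]
quicksort!(numbers)
p numbers # => [1, 2, 3, 4, 5]
```

The trade-off is readability: the index bookkeeping is harder to follow than the head/tail recursion used in this guide, which is why the TDD walkthrough above favors the simpler version.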

Conclusion

The final code can be seen in this Gist.

In this guide, we covered the fundamentals of the Quicksort algorithm using test-driven development. We also learned a bit about recursion and how to test it (manually or not) by isolating the method/function responsible for reducing/accumulating for the next iteration.

I hope it was helpful. Feel free to drop a message or comment.

Follow me @leandronsp on dev.to, Twitter and LinkedIn.
