DEV Community

Shahed Bin Satter


Segment Tree 101 - From Prefix Sum to Segment Tree

If the cover photo does not make much sense to you, then you're in the right place. Segment trees are a broad topic, but in this article we will stick to the basic segment tree, which we can later upgrade to handle more complex problems (more on that later).

In this tutorial, we will look at the basic formation of segment trees: what they are and how they apply to problems. Along the way, we will also take on some problems from CSES.
...

Prequel

Let's start off with the problem: Range Sum Queries I.

It is a simple problem to begin with, and I am pretty sure you can solve it all by yourself. You are given an array as input, and multiple query ranges [a,b]. For each query, your task is to calculate the sum of the values at indices a through b. I would like you to take a moment, solve this problem, and come back here.


Once you have that sweet "AC", think about the runtime complexity of your solution! Does each query take O(n) time, like the code segment below?

int query(int arr[], int start, int end) { // O(n) 
  int sum = 0;
  for (int i = start; i <= end; i++) {
    sum += arr[i];
  }
  return sum;
}

If so, do you think the solution above can be further optimized? Yes! Let's try to optimize our queries using something called "Prefix Sum":

// make sure you pass a & b as 0-based indices
// use cumulative sum of array, instead of the array itself
// cumulative sum of [3,2,4,5,1,1,5,3] is [3,5,9,14,15,16,21,24]
int query(int sum[], int a, int b) { // O(1)
  if (a == 0) return sum[b];
  return sum[b] - sum[a - 1];
}

You may be wondering: what does this have to do with a segment tree? And you're right: to solve this problem, we don't need a segment tree; a prefix sum is sufficient.
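For completeness, here is a minimal sketch of the full prefix-sum approach: build the cumulative array once in O(n), then answer each query in O(1). The helper names `buildPrefix` and `rangeSum` are hypothetical, indices are 0-based as in the snippet above, and the sums use `long long` on the assumption that adding many large values may not fit in an `int`.

```cpp
#include <cassert>
#include <vector>

// Builds the cumulative (prefix) sum: prefix[i] = arr[0] + ... + arr[i].
// long long is an assumption of this sketch, to avoid int overflow on big inputs.
std::vector<long long> buildPrefix(const std::vector<int>& arr) {
    std::vector<long long> prefix(arr.size());
    long long running = 0;
    for (int i = 0; i < (int)arr.size(); i++) {
        running += arr[i];
        prefix[i] = running;
    }
    return prefix;
}

// Sum of arr[a..b] (0-based, inclusive) in O(1), same logic as query() above.
long long rangeSum(const std::vector<long long>& prefix, int a, int b) {
    assert(0 <= a && a <= b && b < (int)prefix.size());
    if (a == 0) return prefix[b];
    return prefix[b] - prefix[a - 1];
}
```

For the example array, `buildPrefix` yields the same `[3,5,9,14,15,16,21,24]` shown in the comment above, and `rangeSum(prefix, 4, 7)` is 24 - 14 = 10.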

Now, suppose you're given a new requirement: we must be able to update a value in the original array.

This isn't a complicated requirement, but each 'update' invalidates our prefix sums: changing one element changes every prefix sum from that index onward, so keeping them correct costs O(n) per update. With that, we're all set to tackle Range Sum Queries II.
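To make that cost concrete, here is a small sketch (the helper `prefixAssign` is hypothetical, with 0-based indices) of what a single point update does to an existing prefix-sum array: every entry from the updated index onward has to shift by the same delta.

```cpp
#include <cassert>
#include <vector>

// Applies arr[k] = val and repairs the prefix-sum array to match.
// Every prefix[i] with i >= k includes arr[k], so all of them shift by the
// same delta: this loop is O(n) per update in the worst case (k = 0).
void prefixAssign(std::vector<int>& arr, std::vector<long long>& prefix,
                  int k, int val) {
    assert(0 <= k && k < (int)arr.size());
    long long delta = (long long)val - arr[k];
    arr[k] = val;
    for (int i = k; i < (int)prefix.size(); i++) {
        prefix[i] += delta;
    }
}
```

Setting `arr[4]` from 1 to 2 in the example array shifts `prefix[4..7]` by 1, so the last entry becomes 25.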

What is a Segment Tree?

From Wikipedia …

In computer science, a segment tree, also known as a statistic tree, is a tree data structure used for storing information about intervals, or segments. It allows querying which of the stored segments contain a given point. It is, in principle, a static structure; that is, it's a structure that cannot be modified once it's built. A similar data structure is the interval tree.

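If the definition above did not make much sense to you, you're in luck. (Note that it describes a static structure, while the segment tree in this article also supports updates, as we will see.) Put simply, a segment tree is: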

  1. A tree-like (binary) structure.
  2. Each node stores collective (summary) information about its children, such as their sum.

Let's take the array [3, 2, 4, 5, 1, 1, 5, 3]. If we build a segment tree storing the sums of the values in this array, we get the following binary-tree-like structure.

Segment Tree (sum) for an input array [3, 2, 4, 5, 1, 1, 5, 3]

If you start looking from the top, it may not make sense at first. Instead, notice that the leaf nodes are exactly the elements of the input array, and each node in an intermediate layer holds nothing but the sum of its two children. Simple, right? And that simplicity is also what makes it elegant.
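Since the figure isn't reproduced here, the following sketch recomputes its values bottom-up for the example array. It stores each level in a separate vector (the helper `buildLevels` is hypothetical and assumes the array length is a power of two), which is not the compact layout used later in this article; it only reproduces the numbers in the tree.

```cpp
#include <cassert>
#include <vector>

// Builds the sum tree level by level, bottom-up.
// levels[0] holds the leaves (the input array); each following level holds
// the sums of adjacent pairs from the level below, ending with the root.
std::vector<std::vector<int>> buildLevels(const std::vector<int>& arr) {
    assert(!arr.empty() && (arr.size() & (arr.size() - 1)) == 0);  // power of two
    std::vector<std::vector<int>> levels{arr};
    while (levels.back().size() > 1) {
        const std::vector<int>& below = levels.back();
        std::vector<int> above;
        for (int i = 0; i + 1 < (int)below.size(); i += 2) {
            above.push_back(below[i] + below[i + 1]);  // parent = sum of two children
        }
        levels.push_back(above);
    }
    return levels;
}
```

For [3, 2, 4, 5, 1, 1, 5, 3] the levels above the leaves come out as [5, 9, 2, 8], then [14, 10], then [24] at the root.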

By repeatedly splitting the array in the middle into two halves, we use divide and conquer to compute the sum of every segment. Hopefully, the following diagram (the indices covered by each node are marked in red) makes it clearer:

The values in the array that contribute to the sum stored in each node are marked in red

Performing Range Queries on a Segment Tree

Now, let's see why this is helpful. We will be using 1-based indexing (just as in our sample problem). With the structure above, imagine I ask you to find the sum of each of the following index ranges:

a). [1,8]
b). [1,4]
c). [5,8]

Would you go through the whole array, summing up each required range? The Big O becomes a Big No (pun intended)! Instead, let's imagine 'walking' the tree.

a). [1,8] - Let's start at the root of the tree above. The root already stores the sum over the whole array, from its start to its end (which is [1,8]). Hence, the answer is 24. See the figure below for a visualization.

b). [1,4] - Once again, let's start at the root, which covers the range [1,8]. Notice that our required range only partially overlaps it. This means it's time to divide and conquer: we visit both the left and the right child of the root.
The left child of the root already stores the sum (14) for the range [1,4], which lies fully 'inside' our required range. The right child spans the range [5,8], which does not overlap with our required range at all. Hence, the right child's sum can be discarded.
Since there are no more partial overlaps, we have found our result: 14.

c). [5,8] - We know the root spans the range [1,8], and its left and right children span [1,4] and [5,8] respectively. Since only [5,8] overlaps our range (and does so completely), we return the value stored in that node (i.e. 10).

Animated steps of finding range sums

Interestingly, each node stores the sum of a contiguous sub-array of the original array. As the diagram shows, for some ranges (e.g. [4,4]) we must walk all the way down to a leaf, but never deeper than the height of the tree. Since each level halves the range, the height is O(log n), and a query visits only a small constant number of nodes (at most four) per level. This means we just reduced each query from O(n) to O(log n). That's a win!
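One way to convince yourself of the O(log n) bound is to count the nodes a query actually visits. The sketch below mirrors the recursion of the range query (the name `countVisited` is hypothetical; it tracks only ranges, not sums). At any level, at most two nodes can partially overlap [a,b] (the ones containing a and b), and only those recurse, so at most four nodes are visited per level.

```cpp
#include <cassert>

// Counts the nodes the recursive range query visits for [a, b], starting
// from a node covering [left, right]. Only the recursion shape matters, so
// no sums are stored; every visited node adds 1 to the count.
int countVisited(int left, int right, int a, int b) {
    assert(left <= right);
    if (b < left || a > right) return 1;    // no overlap: visited, then stop
    if (a <= left && b >= right) return 1;  // full overlap: visited, then stop
    int mid = (left + right) / 2;           // partial overlap: visit both children
    return 1 + countVisited(left, mid, a, b) + countVisited(mid + 1, right, a, b);
}
```

Trying every range on an array of 256 elements (9 levels, root to leaves) stays within the 4 × 9 = 36 bound.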

Okay, what about a different kind of range, say [1,5]?

To answer that, let's query a few more ranges, ones that span parts of several nodes. Following along with the diagram below should make calculating the following sums more intuitive:

a). [1,5] - We start off at the root (covering the range [1,8]). Since it only partially overlaps our range, we break [1,8] down into the ranges of its left and right children, [1,4] and [5,8] respectively.
Since the left child [1,4] lies fully inside our required range, we add its 14 to our result.
The right child [5,8] partially overlaps, so it needs to be divided further into its children, [5,6] and [7,8]. [7,8] is fully outside our required range and can be discarded. We subdivide [5,6] into [5,5] and [6,6], and in the end we use only the sum stored in [5,5], which is 1.
Finally, we find our sum is 14 + 1 = 15.

b). [3,6] = [3,4] + [5,6] = 9 + 2 = 11.

c). [4,5] = [4,4] + [5,5] = 5 + 1 = 6.

d). [2,7] = [2,2] + [3,4] + [5,6] + [7,7] = 2 + 9 + 2 + 5 = 18.

Animated diagram of finding range sums visualized step by step
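The nodes used in each of the answers above are exactly the ones where the recursion hits a full overlap. The following sketch (the function `decompose` is hypothetical) collects those node ranges for a query, which makes it easy to double-check the breakdowns by hand.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Appends to out the ranges [left, right] of the nodes whose stored sums make
// up the answer for query [a, b], i.e. the nodes where the query recursion
// reaches a full overlap. Nodes are reported from left to right.
void decompose(int left, int right, int a, int b,
               std::vector<std::pair<int, int>>& out) {
    assert(left <= right);
    if (b < left || a > right) return;        // no overlap: contributes nothing
    if (a <= left && b >= right) {            // full overlap: use this node's sum
        out.push_back({left, right});
        return;
    }
    int mid = (left + right) / 2;             // partial overlap: split further
    decompose(left, mid, a, b, out);
    decompose(mid + 1, right, a, b, out);
}
```

For [2,7] on [1,8] it yields [2,2], [3,4], [5,6] and [7,7], matching the breakdown in d).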

Finding the sum of any range becomes much simpler now. Given a query range [a,b], we recursively walk down the tree: a node lying fully inside [a,b] contributes its stored sum, a node lying fully outside it is ignored, and only nodes that partially overlap are split further into their children.

But how do we store this data? Do we need to build another, pointer-based tree structure?

Well, we could, but we can also borrow the principle behind a similar, very famous data structure: the (binary) heap. We give every node in the tree an id, numbering them sequentially level by level, starting with the root at id 1. As in a heap, for every node i:

  1. left child of node i can be found at index: (2 * i)
  2. right child of node i can be found at index: (2 * i + 1)

Segment Tree with ids

Now that you understand the theory behind combining the sums of segments, let's look at the code. Our query takes a range of indices [a,b]. Since we are walking a separate tree rather than the array itself, we need to pass in three other parameters (more on them below).

// [left, right] => range represented by current node, left <= right
// [a, b] => range of interest, query range start and end, a <= b
// tree[] is the global array holding the segment tree
int query(int id, int left, int right, int a, int b) {
   if (b < left || a > right) {
       // no overlap: either the query range ends before the current node range starts,
       // or the current node range ends before the query range starts
       // e.g. query range: [6,8], current node range: [1,4]
       return 0;
   }
   if (a <= left && b >= right) {
       // full overlap: current node range falls completely within the query range
       // e.g. query range: [3,6], current node range: [3,4]
       return tree[id];
   }
   // partial overlap
   // e.g. query range: [2,7], current node range: [1,4]
   int mid = (left + right) / 2;
   int sum = 0;
   // subquery by querying left and right child
   sum += query(2 * id, left, mid, a, b);
   sum += query(2 * id + 1, mid + 1, right, a, b);

   return sum;
}

As promised, the three parameters:

  1. id is the id of the node we are currently at
  2. left is the start of the range the current node covers
  3. right is the end of the range the current node covers.

You may be wondering where these values come from, but it is very intuitive. The root has an id of 1 and covers the whole range [1,8]. Therefore, the nodes' ids and their [left, right] ranges are:

  1. [1,8] - this is the root, has id 1.
  2. [1,4] - left child of Node 1, id: 2 * 1 = 2
  3. [5,8] - right child of Node 1, id: 2 * 1 + 1 = 3

… and so on. You get the idea, right? Here is how left and right are determined: we take the current node's range [left, right] and divide it in the middle at mid = (left + right) / 2. The left child then takes the range [left, mid], and the right child takes the range [mid + 1, right].
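To see the numbering for the whole example, this sketch (the helper `labelNodes` is hypothetical) walks the tree with the same id and range rules and records each node's [left, right] range.

```cpp
#include <cassert>
#include <map>
#include <utility>

// Records the range [left, right] covered by each node id, using the same
// rules as query(): the root is id 1, the children of id are 2 * id and
// 2 * id + 1, and a node's range is split at mid = (left + right) / 2.
void labelNodes(int id, int left, int right,
                std::map<int, std::pair<int, int>>& ranges) {
    assert(left <= right);
    ranges[id] = {left, right};
    if (left == right) return;                 // leaf: no children
    int mid = (left + right) / 2;
    labelNodes(2 * id, left, mid, ranges);
    labelNodes(2 * id + 1, mid + 1, right, ranges);
}
```

For n = 8 it produces 15 nodes, for example id 3 → [5,8], id 6 → [5,6] and id 12 → [5,5].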

Now, why go through all this trouble?

You may have noticed that we haven't yet seen what makes segment trees so useful. Suppose that after a few queries, you are asked to update the value at index 5. With a prefix sum, this means updating every value in the prefix sum array from index 5 onward, an O(n) operation, and we want to steer clear of that.

Efficient Updates on a Segment Tree

With a segment tree, we can reduce this to an O(log n) operation. How, you ask? For instance, in the array [3, 2, 4, 5, 1, 1, 5, 3], say we want to set the value at index 5 to 2. Let's 'walk' the tree again, keeping in mind that our index of interest is 5.

  1. We start at the root, which spans the range [1,8]. Our target falls in the right subtree (since mid is 4 and 5 > 4), so we only visit the right child.
  2. Our current node covers [5,8], whose mid is (5 + 8) / 2 = 6 with integer division. Now, we visit the left subtree, since target index ≤ mid.
  3. Now, our node covers the range [5,6]. Its mid is (5 + 6) / 2 = 5 with integer division. Hence, we proceed with the left child, since target index ≤ mid.
  4. We have reached the leaf of the tree (where left = right). We update the value stored in this index to the target value. 
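The descent above only ever follows one child, so it can also be written as a loop. This sketch (the function `updatePath` is hypothetical) records the ranges visited from the root down to the leaf without touching any sums; its length is the height of the tree plus one.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Returns the node ranges visited on the way from the root [left, right]
// down to the leaf for idx, following the same "idx <= mid" rule as update().
std::vector<std::pair<int, int>> updatePath(int left, int right, int idx) {
    assert(left <= idx && idx <= right);
    std::vector<std::pair<int, int>> path;
    while (true) {
        path.push_back({left, right});
        if (left == right) break;               // reached the leaf
        int mid = (left + right) / 2;
        if (idx <= mid) right = mid;            // go to left child [left, mid]
        else left = mid + 1;                    // go to right child [mid + 1, right]
    }
    return path;
}
```

`updatePath(1, 8, 5)` returns [1,8], [5,8], [5,6], [5,5], the same four steps as the walk above.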

Simple, right? But we're not done yet. Note that after the update, the sum property no longer holds for the ancestors of that leaf. So, as we go back up the recursion stack, we update each parent with its new sum, which is the sum of its two children.

// [left, right] => current node range
// idx => index of the array to update
// val => value to be placed in arr[idx]
void update(int id, int left, int right, int idx, int val) { // O (log N)
   if (left == right) {
       tree[id] = val;
       return;
   }
   int mid = (left + right) / 2;
   if (idx <= mid) {
       update(id * 2, left, mid, idx, val);
   } else {
       update(id * 2 + 1, mid + 1, right, idx, val);
   }
   tree[id] = tree[2 * id] + tree[2 * id + 1];
}

Voila, our segment tree now stays up to date whenever a value changes, and it's faaast!
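To check the numbers from the walk-through, here is a self-contained sketch that puts build, query and update together and runs them on the example array. The function bodies follow the article's query and update plus the build function described in the next section; storing `arr` and `tree` in `std::vector`, using `long long` sums, and the helper `loadExample` are assumptions of this sketch.

```cpp
#include <cassert>
#include <vector>

std::vector<int> arr;          // 1-indexed input: arr[1..n], arr[0] unused
std::vector<long long> tree;   // node id -> sum of the range it covers

void build(int id, int left, int right) {
    if (left == right) { tree[id] = arr[left]; return; }  // leaf
    int mid = (left + right) / 2;
    build(2 * id, left, mid);
    build(2 * id + 1, mid + 1, right);
    tree[id] = tree[2 * id] + tree[2 * id + 1];
}

long long query(int id, int left, int right, int a, int b) {
    if (b < left || a > right) return 0;              // no overlap
    if (a <= left && b >= right) return tree[id];     // full overlap
    int mid = (left + right) / 2;                     // partial overlap
    return query(2 * id, left, mid, a, b) + query(2 * id + 1, mid + 1, right, a, b);
}

void update(int id, int left, int right, int idx, int val) {
    if (left == right) { tree[id] = val; return; }    // leaf for idx
    int mid = (left + right) / 2;
    if (idx <= mid) update(2 * id, left, mid, idx, val);
    else update(2 * id + 1, mid + 1, right, idx, val);
    tree[id] = tree[2 * id] + tree[2 * id + 1];       // refresh on the way up
}

// Loads the example array [3, 2, 4, 5, 1, 1, 5, 3] and builds its tree.
void loadExample() {
    arr = {0, 3, 2, 4, 5, 1, 1, 5, 3};
    tree.assign(4 * 8, 0LL);
    build(1, 1, 8);
}
```

After `update(1, 1, 8, 5, 2)`, the root goes from 24 to 25, the node for [5,8] (id 3) from 10 to 11, and `query(1, 1, 8, 1, 5)` from 15 to 16.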

Birth of a Segment Tree

Building the Tree in the First Place

So far, we have mastered the theory of segment trees and the art of querying ranges and updating values. But we have not yet built the tree from our array. If you have understood everything above, the rest is simple. There are two ways to do it:

  1. The slower way - Slow, but simple. For every index i from 1 to n, we call our update function to place the (i-1)-th element of the 0-based input array at position i of the tree.
for (int i = 1; i <= n; i++)  // array is the 0-based input of size n
    update(1, 1, n, i, array[i - 1]);

We perform an O(log N) update N times, so building the tree this way costs O(N log N).

  2. The faster way - The O(N) way is to recurse down the tree until we reach the leaves. At each leaf, we store the value from the corresponding array index. On our way back up to the root, we set each node to the sum of its two children.
// [left, right] => range spanned by current node
// arr is assumed to be 1-indexed here (arr[1..n]), matching the 1-based ranges
void build(int id, int left, int right) { // O (n)
   if (left == right) {
       // when we divide and conquer on the middle element, left and right become equal at the leaf nodes
       tree[id] = arr[left];
       return;
   }
   int mid = (left + right) / 2;
   // build left child first
   build(2 * id, left, mid);
   // build right child next
   build(2 * id + 1, mid + 1, right);
   // current node sum is the sum of its left and right children
   tree[id] = tree[id * 2] + tree[id * 2 + 1];
}
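Why is this O(N)? build() does a constant amount of work per call and is called exactly once per node, and a tree in which every internal node has two children has 2N - 1 nodes for N leaves. The sketch below (`countBuildCalls` is a hypothetical helper) counts those calls directly.

```cpp
#include <cassert>

// Counts how many times build() would be invoked for the range [left, right].
// build() makes exactly one call per node and does O(1) work per call,
// so this count is proportional to the build's running time.
int countBuildCalls(int left, int right) {
    assert(left <= right);
    if (left == right) return 1;                 // leaf
    int mid = (left + right) / 2;
    return 1 + countBuildCalls(left, mid) + countBuildCalls(mid + 1, right);
}
```

For any n, `countBuildCalls(1, n)` equals 2n - 1, compared with n separate O(log n) updates in the slower way.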

Wrapping up

A segment tree is a simple yet powerful technique that can make both your queries and your updates much faster. The snippet below shows how we declare our array and tree and call the three functions. With this recursive layout, node ids always stay below 4 times the size of the array, so allocating 4 × n entries for the tree is a safe choice.

const int MAXN = 200000;  // assumed upper bound on n; set it from the problem's constraints
int n;                    // size of array
int arr[MAXN + 1];        // global array of input elements, 1-indexed: arr[1..n]
int tree[4 * MAXN];       // global structure for containing segment tree

// initialize with global array arr
build(1, 1, n);
// query
query(1, 1, n, a, b); // query against range [a, b]
// update
update(1, 1, n, idx, val); // set arr[idx] = val and refresh the tree
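To sanity-check the 4 × n sizing, this sketch (`maxNodeId` is a hypothetical helper) finds the largest node id the recursive layout actually uses. Because the left half gets the extra element, the tree's depth is ⌈log2 n⌉, so every id stays below 2^(⌈log2 n⌉ + 1), which is less than 4n.

```cpp
#include <algorithm>
#include <cassert>

// Returns the largest node id used when the range [left, right] is stored at
// node id, following the article's layout: children at 2 * id and 2 * id + 1,
// split at mid = (left + right) / 2.
int maxNodeId(int id, int left, int right) {
    assert(left <= right);
    if (left == right) return id;   // ids grow with depth, so leaves hold the maxima
    int mid = (left + right) / 2;
    return std::max(maxNodeId(2 * id, left, mid),
                    maxNodeId(2 * id + 1, mid + 1, right));
}
```

For n = 8 the largest id is 15, and for every n from 1 to 2000 it stays below 4n.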

You can find the code for my solution to Range Sum Queries II by clicking here; feel free to use it as a template. I hope you enjoyed this article, and I hope to write more on this topic. Here are some more problems to give you some practice:

  1. Sherlock and Subarray Queries - Solution
  2. Supercomputer - Solution
  3. Maximum Sum - Solution
  4. Multiset - Solution


A sample query flow chart is shown below (full-res image here). It shows the entire flow taken for an array covering the range [1,4]. If you liked this post, please subscribe. Happy Learning!

Query Flow

Top comments (1)

ProgramCpp

how to handle inserts into the input array?