DEV Community

Artur Curinga

Segment Tree

HackerEarth gives a good definition of a segment tree:
Segment Tree is basically a binary tree used for storing the intervals or segments. Each node in the Segment Tree represents an interval. Consider an array A of size N and a corresponding Segment Tree T:

1- The root of T will represent the whole array A[0:N-1].
2- Each leaf in the Segment Tree T will represent a single element A[i] such that 0 <= i < N.
3- The internal nodes in the Segment Tree T represent the union of elementary intervals A[i:j] where 0 <= i < j < N.

Usage

A segment tree helps us solve the following problem, taken from GeeksforGeeks:
Given an array arr[0 . . . n-1]. We should be able to:
1- Find the sum of elements from index l to r where 0 <= l <= r <= n-1
2- Change the value of a specified element of the array to a new value x, i.e. set arr[i] = x where 0 <= i <= n-1.

Or the formal definition from cp-algorithms:
We have an array a[0…n−1], and the Segment Tree must be able to find the sum of elements between the indices l and r (i.e. computing the sum $\sum_{i=l}^{r} a[i]$), and also handle changing values of the elements in the array (i.e. perform assignments of the form $a[i] = x$).

The first solution that comes to mind is to loop from index l to index r, adding up the values, and to handle an update by writing the new value directly into the array. The first operation takes O(n) and the second O(1).
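A minimal sketch of this naive baseline, for comparison with what follows. The class name NaiveRangeSum and the method name sum_range are hypothetical, chosen only for illustration.

```python
from typing import List


class NaiveRangeSum:
    """Hypothetical baseline: O(n) range sum, O(1) point update."""

    def __init__(self, nums: List[int]):
        self.nums = list(nums)

    def sum_range(self, l: int, r: int) -> int:
        # Walk every index from l to r inclusive: O(r - l + 1)
        total = 0
        for i in range(l, r + 1):
            total += self.nums[i]
        return total

    def update(self, i: int, x: int) -> None:
        # Overwrite the element in place: O(1)
        self.nums[i] = x


naive = NaiveRangeSum([1, 3, 5, 7, 9, 11])
print(naive.sum_range(1, 4))  # 24
naive.update(2, 0)
print(naive.sum_range(1, 4))  # 19
```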

A second approach is to store the prefix sums in another array, that is, the sum of the elements from index 0 up to index i for every i. Once this array is built, a range sum takes O(1) because it is the difference of two stored prefix sums, but an update takes O(n) because every prefix sum from the updated index onward has to change.
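A sketch of the prefix-sum approach, assuming the "sums from index 0 to i" reading above. PrefixSumRange and sum_range are hypothetical names. The update shifts every affected prefix by the same delta, which is where the O(n) cost comes from.

```python
from typing import List


class PrefixSumRange:
    """Hypothetical baseline: O(1) range sum, O(n) point update."""

    def __init__(self, nums: List[int]):
        self.nums = list(nums)
        # prefix[k] = nums[0] + ... + nums[k - 1], with prefix[0] = 0
        self.prefix = [0] * (len(nums) + 1)
        for k, x in enumerate(nums):
            self.prefix[k + 1] = self.prefix[k] + x

    def sum_range(self, l: int, r: int) -> int:
        # Sum of nums[l..r] as a difference of two prefix sums: O(1)
        return self.prefix[r + 1] - self.prefix[l]

    def update(self, i: int, x: int) -> None:
        # Every prefix that includes index i shifts by the same delta: O(n)
        delta = x - self.nums[i]
        self.nums[i] = x
        for k in range(i + 1, len(self.prefix)):
            self.prefix[k] += delta


ps = PrefixSumRange([1, 3, 5, 7, 9, 11])
print(ps.sum_range(1, 4))  # 24
ps.update(2, 0)
print(ps.sum_range(1, 4))  # 19
```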

Now, the Segment Tree allows us to get the sum of any interval in O(log n) and update any value in O(log n).

Implementation

There are two main implementations of a segment tree, a recursive one and an iterative one. We focus this text on the recursive one, and provide an iterative implementation at the end with a link to a discussion about the memory usage of each type.

Warning: try to implement the Segment Tree yourself before moving on to the solution.

LeetCode question 307 (Range Sum Query - Mutable, rated medium) is a great initial motivation to build a segment tree. It asks you to build a structure with methods to get the sum of an interval and to update any value in it. As discussed above, a segment tree does both in O(log n).

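First of all: what's the easiest way to store a segment tree?
Like a heap, we can store it in an array. If the root is stored at index 0, then for each index i, its left child is at index 2*i + 1, its right child is at index 2*i + 2, and its parent is at index (i - 1) // 2 (integer division). The NumArray code below instead stores the root at index 1, which gives the equally common layout where the children of node v are at 2*v and 2*v + 1 and its parent is at v // 2.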
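The 0-indexed arithmetic can be checked with three one-line helpers. The names left, right and parent are hypothetical. The loop confirms that both children map back to their parent under integer division.

```python
# 0-indexed heap-style layout (root at index 0)
def left(i: int) -> int:
    return 2 * i + 1


def right(i: int) -> int:
    return 2 * i + 2


def parent(i: int) -> int:
    # Integer division: both 2i + 1 and 2i + 2 map back to i
    return (i - 1) // 2


# Every child points back to its parent
for i in range(10):
    assert parent(left(i)) == i and parent(right(i)) == i
print(left(0), right(0), parent(4))  # 1 2 1
```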

An important difference from a heap is that a heap is a complete binary tree, while a segment tree is a full binary tree (every node has 0 or 2 children) that is not necessarily complete. So if n is not a power of 2, storing the tree in an array can leave gaps between the nodes. If n is a power of 2, the segtree array has size 2n - 1. If not, its size is 2x - 1, where x is the smallest power of 2 greater than n, and we place dummy nodes in the gaps. Since x < 2n, an array of 4n slots is always large enough, which is why the code below allocates 4 * n (with the root at index 1, index 0 is simply unused).
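The size rule can be checked with a short sketch. segtree_array_size is a hypothetical helper that computes 2x - 1 for the 0-indexed layout and compares it with the 4n allocation used by the NumArray code.

```python
def segtree_array_size(n: int) -> int:
    """Slots for the 0-indexed layout: 2x - 1, x = smallest power of 2 >= n."""
    x = 1
    while x < n:
        x *= 2
    return 2 * x - 1


for n in (6, 8, 1000):
    # n, exact size needed, size allocated by a 4n array
    print(n, segtree_array_size(n), 4 * n)
# 6 15 24
# 8 15 32
# 1000 2047 4000
```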

Example: for the input array [1, 3, 5, 7, 9, 11], the segment tree (root at index 0) is: [36, 9, 27, 4, 5, 16, 11, 1, 3, DUMMY, DUMMY, 7, 9, DUMMY, DUMMY].
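The layout in this example can be reproduced with a small sketch. It uses the same midpoint split as the NumArray code but with the root at index 0, and it marks the DUMMY slots with None. The function name build_layout is hypothetical.

```python
from typing import List, Optional


def build_layout(nums: List[int]) -> List[Optional[int]]:
    """0-indexed segment tree of sums; None marks the DUMMY slots."""
    if not nums:
        return []
    x = 1
    while x < len(nums):
        x *= 2
    tree: List[Optional[int]] = [None] * (2 * x - 1)

    def build(v: int, tl: int, tr: int) -> int:
        if tl == tr:
            s = nums[tl]                      # leaf holds one element
        else:
            tm = (tl + tr) // 2               # midpoint split
            s = build(2 * v + 1, tl, tm) + build(2 * v + 2, tm + 1, tr)
        tree[v] = s
        return s

    build(0, 0, len(nums) - 1)
    return tree


print(build_layout([1, 3, 5, 7, 9, 11]))
# [36, 9, 27, 4, 5, 16, 11, 1, 3, None, None, 7, 9, None, None]
```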

  • Construct:
    Let's go back to the actual code. We start by creating a recursive method build that constructs our segtree bottom-up. It receives two indices tl and tr that represent a segment of the input array. If the segment has length 1, we simply store the leaf node. If not, the method divides the segment into two halves and calls itself on both. When the recursion returns, the sum of the segment is stored in the current node v of the segtree array.

The arguments are:

  • v : the current node of the segtree array.
  • tl : left index of the segment in the input array.
  • tr: right index of the segment in the input array.

  • Updating a value:
    The update operation is also done recursively. We receive the index pos to be updated and, as in the build method, we split a segment into two halves. The difference is that we only recurse into the half that contains pos. Doing that, we eventually reach the segment of length 1, which is exactly the node we update with the new value val. When the recursion returns, each node on the path is set to the sum of its two children.

The arguments are:

  • v : the current node of the segtree array.
  • tl : left index of the segment in the input array.
  • tr: right index of the segment in the input array.
  • pos : position to be updated
  • val : new value
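To see why an update costs O(log n), here is an iterative trace of the same descent. It is not the post's code. update_path is a hypothetical helper that records the node indices (root = 1, children 2v and 2v + 1, as in NumArray) whose sums _update recomputes. There is one node per level, so the path has at most ceil(log2 n) + 1 entries.

```python
from typing import List


def update_path(n: int, pos: int) -> List[int]:
    """Node indices visited when updating pos in an array of length n."""
    path = []
    v, tl, tr = 1, 0, n - 1
    while True:
        path.append(v)
        if tl == tr:              # reached the leaf for pos
            return path
        tm = (tl + tr) // 2
        if pos <= tm:             # pos lies in the left half
            v, tr = 2 * v, tm
        else:                     # pos lies in the right half
            v, tl = 2 * v + 1, tm + 1


print(update_path(6, 4))  # [1, 3, 6, 13]
```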

  • Querying the sum:
    Starting at the root and going recursively, we first check whether the requested range [l, r] is exactly the segment [tl, tr] of the current node v. If it is, the node's value is the sum from l to r. If not, we call the method on both children of v, clipping the query to each child's segment (the left child gets [l, min(r, tm)] and the right child gets [max(l, tm + 1), r]), and add the results. If clipping leaves an empty range (l > r), that call returns 0.

The arguments are:

  • v : the current node of the segtree array.
  • tl : left index of the segment in the input array.
  • tr: right index of the segment in the input array.
  • l : left index of the desired sum
  • r : right index of the desired sum

Important: on every recursive call, we explicitly pass the segment [tl, tr] that the called node represents, because the array only stores sums and not the segment boundaries.
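The clipping rule means a query ends up adding only a few stored sums. The sketch below mirrors the recursion of _sum but records which segments [tl, tr] are used instead of summing them. covering_segments is a hypothetical helper. For the example array, the query [1, 4] is answered by the nodes for [1, 1], [2, 2] and [3, 4].

```python
from typing import List, Tuple


def covering_segments(n: int, l: int, r: int) -> List[Tuple[int, int]]:
    """Segments whose stored sums the recursive query adds for [l, r]."""
    found: List[Tuple[int, int]] = []

    def walk(tl: int, tr: int, ql: int, qr: int) -> None:
        if ql > qr:                  # clipped to an empty range: contributes 0
            return
        if ql == tl and qr == tr:    # exact match: this node's sum is used as-is
            found.append((tl, tr))
            return
        tm = (tl + tr) // 2
        walk(tl, tm, ql, min(qr, tm))          # left child, query clipped
        walk(tm + 1, tr, max(ql, tm + 1), qr)  # right child, query clipped

    walk(0, n - 1, l, r)
    return found


nums = [1, 3, 5, 7, 9, 11]
segs = covering_segments(len(nums), 1, 4)
print(segs)                                      # [(1, 1), (2, 2), (3, 4)]
print(sum(sum(nums[a:b + 1]) for a, b in segs))  # 24
```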

from typing import List


class NumArray:
    def __init__(self, nums: List[int]):
        self.n = len(nums)
        self.nums = nums
        # 4 * n slots always suffice; the root lives at index 1
        self.arr = [0] * (4 * self.n)
        if self.n > 0:
            self.build(1, 0, self.n - 1)

    def build(self, v, tl, tr):
        # Leaf: the segment [tl, tr] holds a single element
        if tl == tr:
            self.arr[v] = self.nums[tl]
        else:
            tm = (tl + tr) // 2
            self.build(v * 2, tl, tm)          # left child covers [tl, tm]
            self.build(v * 2 + 1, tm + 1, tr)  # right child covers [tm + 1, tr]
            self.arr[v] = self.arr[v * 2] + self.arr[v * 2 + 1]

    def _update(self, v, tl, tr, pos, val):
        if tl == tr:
            # Reached the leaf for index pos
            self.arr[v] = val
        else:
            tm = (tl + tr) // 2
            # Descend only into the half that contains pos
            if pos <= tm:
                self._update(v * 2, tl, tm, pos, val)
            else:
                self._update(v * 2 + 1, tm + 1, tr, pos, val)
            # Recompute this node from its children
            self.arr[v] = self.arr[v * 2] + self.arr[v * 2 + 1]

    def update(self, i: int, val: int) -> None:
        self._update(1, 0, self.n - 1, i, val)

    def _sum(self, v, tl, tr, l, r):
        # Empty query range contributes nothing
        if l > r:
            return 0
        # Query range matches this node's segment exactly
        if l == tl and r == tr:
            return self.arr[v]
        tm = (tl + tr) // 2
        # Clip the query to each child's segment and add the results
        return (self._sum(v * 2, tl, tm, l, min(r, tm))
                + self._sum(v * 2 + 1, tm + 1, tr, max(l, tm + 1), r))

    def sumRange(self, i: int, j: int) -> int:
        return self._sum(1, 0, self.n - 1, i, j)
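For comparison, here is a sketch of one well-known iterative (bottom-up) variant. It may not be the version the authors had in mind. The leaves sit at tree[n:2n] and each internal node i stores tree[2i] + tree[2i + 1], so it needs 2n slots instead of 4n. The class name IterativeNumArray is hypothetical, and the method names follow LeetCode 307.

```python
from typing import List


class IterativeNumArray:
    """Bottom-up segment tree for range sums (hypothetical name)."""

    def __init__(self, nums: List[int]):
        self.n = len(nums)
        self.tree = [0] * (2 * self.n)
        self.tree[self.n:] = nums                 # leaves
        for i in range(self.n - 1, 0, -1):        # internal nodes, bottom-up
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]

    def update(self, index: int, val: int) -> None:
        i = index + self.n
        self.tree[i] = val
        while i > 1:                              # walk up, recomputing parents
            i //= 2
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]

    def sumRange(self, left: int, right: int) -> int:
        # Half-open window [lo, hi) over leaf positions, shrunk level by level
        lo, hi = left + self.n, right + self.n + 1
        total = 0
        while lo < hi:
            if lo & 1:            # lo is a right child: its parent spills left
                total += self.tree[lo]
                lo += 1
            if hi & 1:            # hi - 1 is a left child: its parent spills right
                hi -= 1
                total += self.tree[hi]
            lo //= 2
            hi //= 2
        return total


obj = IterativeNumArray([1, 3, 5])
print(obj.sumRange(0, 2))  # 9
obj.update(1, 2)
print(obj.sumRange(0, 2))  # 8
```

Because addition is commutative, this variant works for any n, not only powers of 2.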

Leetcode - Question 732

Introduction

This is a solution to LeetCode question 732 (My Calendar III), meant to build familiarity with segment trees.

The Problem

We need to implement a MyCalendarThree class that stores events so that we can always add more events. The class has only one method, book(int start, int end), which adds an event on the half-open interval [start, end). The main challenge lies in the book method, where we need to detect a K-booking: K events that share a non-empty intersection (some moment contained in all of them). The book method returns the largest K such that a K-booking exists in the calendar.

The Approach

Since we're dealing with time intervals and the intersections between them, a data structure that naturally stores intervals of numbers is an adequate choice for this kind of problem. More specifically, we implement a segtree in which each node holds the maximum K within its interval.
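Before the segment tree, it helps to have a simple reference answer to test against. The sketch below is a sweep line over boundary deltas, a different technique from the one this post uses. It costs O(m log m) per booking, where m is the number of bookings so far. MyCalendarThreeSweep is a hypothetical name.

```python
from collections import defaultdict


class MyCalendarThreeSweep:
    """Hypothetical sweep-line baseline, not the segment tree approach."""

    def __init__(self):
        # delta[t] = (events starting at t) - (events ending at t)
        self.delta = defaultdict(int)

    def book(self, start: int, end: int) -> int:
        self.delta[start] += 1   # event becomes active at start
        self.delta[end] -= 1     # half-open: no longer active at end
        active = best = 0
        for t in sorted(self.delta):
            active += self.delta[t]
            best = max(best, active)
        return best


cal = MyCalendarThreeSweep()
print([cal.book(s, e) for s, e in [(10, 20), (50, 60), (10, 40), (5, 15), (5, 10), (25, 55)]])
# [1, 1, 2, 3, 3, 3]
```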

Let's get to the code.

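class Node:
    def __init__(self, l, r):
        self.start = l    # first time unit of this node (inclusive)
        self.end = r      # last time unit of this node (inclusive)
        self.val = 0      # max number of overlapping bookings inside [start, end]
        self.lazy = 0     # bookings that cover the whole [start, end]
        self.left = None
        self.right = None


class MaxSegTree:

    def __init__(self, start, end):
        self.tree = Node(start, end)

    def add(self, start, end, node) -> int:
        # Bookings are half-open [start, end); nodes are closed [node.start, node.end]
        # Disjoint: this node's max does not change
        if start > node.end or node.start >= end:
            return node.val

        # The booking covers the whole node: record it here, don't descend
        if start <= node.start and end > node.end:
            node.val += 1
            node.lazy += 1
            return node.val

        # Partial overlap: create the children on demand and recurse into both
        mid = (node.start + node.end) // 2
        if node.left is None:
            node.left = Node(node.start, mid)
        if node.right is None:
            node.right = Node(mid + 1, node.end)
        a = self.add(start, end, node.left)
        b = self.add(start, end, node.right)
        # Best child plus the bookings that cover this entire node
        node.val = max(a, b) + node.lazy
        return node.val

    def getMax(self):
        return self.tree.val


class MyCalendarThree:

    def __init__(self):
        # int(10e9) == 10**10, a root interval wider than any booking time
        self.segTree = MaxSegTree(0, int(10e9))

    def book(self, start: int, end: int) -> int:
        self.segTree.add(start, end, self.segTree.tree)
        return self.segTree.getMax()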

First, we have the Node class, the basis of any tree data structure. It stores the beginning and end of its interval (start and end, both inclusive), its K value (val, the maximum number of overlapping bookings anywhere inside the interval), the number of bookings that cover the entire interval (lazy), and the standard left and right children.

So far so good, pretty basic stuff. Now let's get into our actual SegTree.

Our MyCalendarThree class is only a wrapper to our MaxSegTree, and its book method is only a call to our add method, so by understanding the underlying data structure, you'll get the big picture.

We instantiate our segtree with a single starting interval, which becomes the root node. In MyCalendarThree, we use 0 as the start and int(10e9) as the end. Note that 10e9 in Python is 10^10, not 10^9. Either way, the root covers every time LeetCode allows (0 <= start < end <= 10^9), so every booking falls inside it. Because children are created only when a booking partially overlaps a node, such a huge range costs nothing up front.
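A rough sketch of why a root this wide stays cheap: each booking descends at most as deep as the longest midpoint-split path, and with this root that path has only a few dozen levels. split_depth is a hypothetical helper that follows the left half, which is never shorter than the right one, until it reaches a single point.

```python
def split_depth(lo: int, hi: int) -> int:
    """Levels below the root on the longest path when [lo, hi] is split at the midpoint."""
    depth = 0
    while lo < hi:
        mid = (lo + hi) // 2
        hi = mid          # the left half [lo, mid] is never shorter than the right
        depth += 1
    return depth


root_end = int(10e9)             # 10e9 is a float literal equal to 10**10
print(root_end)                  # 10000000000
print(split_depth(0, root_end))  # 34
```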

The overall strategy of the add method is to refine the initially huge root interval, splitting intervals until it reaches nodes that lie entirely inside the booking. That case is our second if.

Let's use an example and add an event that goes from 10 to 20, so start = 10 and end = 20. We skip the first if, since 10 > 10e9 and 0 >= 20 are both false. That if is taken only when the booking and the node's interval are disjoint. We also skip the second if, because 10 <= 0 is false. That if is taken only when the whole node lies inside the booking, and in that case we increment both its val and lazy fields, signaling that one more event covers that entire time period.

Going back to our example, we skip these two ifs, divide the interval in two (creating the children if they don't exist yet), and recursively call the add method on both halves. The variables a and b hold the value that each call returns. Remember, the add method returns the node's updated val, i.e. the maximum K inside that node's interval.

We keep doing this until every branch reaches an interval that is either contained within [10, 20) or disjoint from it. At the root, for this first booking, the left half contains [10, 20), so a is 1 while b is still 0. This is why we call max(a, b): it picks the child with the larger K. Then we add node.lazy, the number of bookings that cover this node's entire interval (those stopped here and were never pushed down to the children), which gives the node's maximum K. At last, we return it.

By doing this bottom-up, when we get back to the root, it already holds the largest K for the whole tree, so getMax only needs to read the val field of the root node.

This post was written by:
Artur Curinga
Mateus Melo
Larissa Moura
Giovanni Rosário
Daniel Oliveira Guerra
