DEV Community

Whiteboarding in Python


Sum of Three, Sum of Four, and Beyond? In Python!

<< #21 Tree Practice | View Solution on GitHub

Actual pic of me after trying to solve this problem: a cat asleep on a computer (Image: Bend Pet Express)

Today's problem has stumped many a programmer. A colleague of mine and I sat down with "Sum of Three", talked it out, and after more than an hour, finally finished a solution that was by no means elegant or optimal. I was joking with him that next we should try "Sum of Four", which I had totally made up. Turns out, it's real. After another long whiteboarding session, we came up with...a bunch of ideas and no solution.

We were so focused on finding an elegant solution that we missed being able to solve the problem at all. But here's an obvious solution: brute force. Brute force may be simple, and it certainly isn't optimal, but it gets the job done. And when you have no other solution, it helps to have at least one way to solve it.

It ain't much, but it's honest work.
(Image: KnowYourMeme)

Today, we're going to look at "Sum of Three" and "Sum of Four", and solve them using brute force. Then, you'll see how we can expand the brute force solution using recursion to allow for "Sum of K", which will work for 3, 4, or any number of values we want to sum. You'll see that a brute force solution might not scale, but you can still build some powerful things with it.

Let's get started with the simpler problem, "Sum of Three". Here is the prompt:

# Given an array nums of n integers and a target,
# are there elements a, b, c in nums
# such that a + b + c = target?
# Find all unique triplets in nums that sum to target.

# Example:

# Input: 
# nums = [-1,0,1,2,-1,-4]
# target = 0

# Output: [[-1,-1,2],[-1,0,1]]

# Notice that the solution set must not contain duplicate triplets.

We're given a list of numbers, and we want to find every group of three numbers in the list that sums to zero, or to whatever number is specified as the target. How would you go about solving this?

1. Approach

The brute force solution will involve iteration. Imagine we look at each number, check it against every other number, and then check each of those pairs against every remaining number. What does this sound like? If you said a triple-nested for loop, you're right (I'm sure there must also be a recursive way to do this part...let me know in the comments if you found it!).

But there's one other caveat, as listed in the prompt: Notice that the solution set must not contain duplicate triplets. How will we account for this? If you've been following this blog, you're probably thinking of using a set (more on this here). And yep, that's what we will be using. We'll encode each working triplet as a string and put it in the set, so we can check whether we've already included it.
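For comparison, the same brute-force idea (try every combination, remember what you've already seen) can be sketched with the standard library. This is a swapped-in technique, not the step-by-step solution this post builds: it uses `itertools.combinations` in place of hand-written loops and stores tuples, which are hashable, instead of string codes. The function name `sum_of_k_combos` is hypothetical.

```python
from itertools import combinations

def sum_of_k_combos(nums, target, k):
    """Brute force: test every k-element combination of the sorted numbers."""
    seen = set()   # tuples we've already recorded
    output = []
    # Sorting first means each combination comes out in ascending order,
    # so the same values picked in a different order give the same tuple.
    for combo in combinations(sorted(nums), k):
        if sum(combo) == target and combo not in seen:
            seen.add(combo)
            output.append(list(combo))
    return output

print(sum_of_k_combos([-1, 0, 1, 2, -1, -4], 0, 3))
# [[-1, -1, 2], [-1, 0, 1]]
```

Because `combinations` walks the indices in the same order as nested loops would, the results come out in the same order as a hand-written version. Using `sorted()` also leaves the caller's list untouched.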

2. Setup

We're ready to start writing this bad boy. Our method sum_of_three will take in two arguments: the list nums and the target.

def sum_of_three(nums, target):

Next, we need to check for an edge case. What if nums has fewer than 3 numbers in it? Then we can't form a triplet, so we'll return an empty list.

def sum_of_three(nums, target):
  if len(nums) < 3:
    return []

Now we need to initialize some variables that we'll be using. First, the output list, which will contain lists of triplets. We'll also start one of those triplet lists and call it current. Once we populate current with a working triplet, we can append it to output.

  output = []
  current = []

And, as discussed earlier, we'll need to make a set. We'll also sort the list so that the same triplet can't sneak into the set twice in a different order (for instance, [1, -2, 1] and [1, 1, -2] are the same triplet, so only one should be added; after sorting, both come out as [-2, 1, 1]).

  checked = set()
  nums.sort()
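To see why sorting matters, here is a small sketch that encodes two orderings of the same triplet as delimited strings, the way this post does a bit later. The helper name `encode` is hypothetical. Sorting nums up front has the same effect as sorting each triplet, because the loops always pick elements in increasing index order.

```python
def encode(triplet):
    # Join the numbers with "&" so the triplet can be stored in a set as a string.
    return "&".join(str(n) for n in triplet)

a = [1, -2, 1]
b = [1, 1, -2]

print(encode(a) == encode(b))                  # False: different orders, different codes
print(encode(sorted(a)) == encode(sorted(b)))  # True: both become "-2&1&1"
```

One side effect to be aware of: `list.sort()` sorts in place, so calling sum_of_three reorders the caller's nums. If that matters, sort a copy with `sorted(nums)` instead.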

3. Iteration

We're ready to start our for loops. We're making three of them, and the most important part to consider is our indices' upper and lower bounds.

For the first loop, we'll start with the first number in the list and go up to the third-to-last spot. Why? Each number needs two more numbers after it to form a triplet, and once we reach the second-to-last or last index, there aren't enough numbers left after it.

  for i in range(len(nums) - 2):

As a review, range(start, end) accepts two arguments: a start (inclusive) and an end (exclusive). If the start argument is missing, it's assumed to be 0.
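Here's a quick check of those bounds on the sorted example input (six numbers, so indices 0 through 5). The specific values of i and j below are just sample points.

```python
nums = [-4, -1, -1, 0, 1, 2]   # the example input after sorting; len(nums) == 6

# Outer loop: i runs from 0 up to (but not including) len(nums) - 2.
print(list(range(len(nums) - 2)))         # [0, 1, 2, 3]

# Middle loop when i = 0: j starts right after i and stops before len(nums) - 1.
print(list(range(0 + 1, len(nums) - 1)))  # [1, 2, 3, 4]

# Inner loop when j = 1: k starts right after j and runs through the last index.
print(list(range(1 + 1, len(nums))))      # [2, 3, 4, 5]
```

When i reaches its last value, 3, indices 4 and 5 are still left for j and k, which is exactly enough for one final triplet.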

Okay, what will we do inside this first for loop? Here's the agenda:

  1. Add the first number to current.
  2. Subtract that number from the target.
  3. Check it with the other numbers.
  4. Undo steps 1 and 2 so we can move on to the next number.

Let's write these in code. The first one is easy: simply append the number at index i to current.

    current.append(nums[i])

The next step is also pretty simple. You could subtract from target directly, but I prefer to make a new variable, working_target, which makes it clear that the value is changing.

    working_target = target - nums[i]

Step 3 is more complicated. "Check it with the other numbers" is my shorthand for the nested for loops, so for now I'll just leave a placeholder comment there.

Step 4 is simple. Remove the last element from current with current.pop(). As for moving the target back, we don't need to do anything, because working_target is reassigned at the start of each iteration. Altogether, our loop looks like this so far:

  for i in range(len(nums) - 2):
    current.append(nums[i])
    working_target = target - nums[i]
    # do some for loops
    current.pop()

Great, we're ready to move onto loop #2.

What are our indices? We start at the index after i and go until the second-to-last spot in nums.

    for j in range(i + 1, len(nums) - 1):

Next, we follow our steps 1-4 again. That is, we append the number at j to current, subtract it from working_target, leave space for the next loop, and then undo steps 1 and 2.

    for j in range(i + 1, len(nums) - 1):
      working_target -= nums[j]
      current.append(nums[j])
      # do another for loop
      current.pop()
      working_target += nums[j]

Finally, we can write our innermost loop. The indices should be familiar by now: we start at j + 1 and go until the last index.

      for k in range(j + 1, len(nums)):

We append the number at k to current, but this time, instead of subtracting it from working_target, we check whether the subtraction would give zero. If the numbers all sum to the target, then target - a - b - c = 0 (if you remember back to algebra class).

      for k in range(j + 1, len(nums)):
        current.append(nums[k])
        if working_target - nums[k] == 0:

What goes in the if statement? Now it's time to use our set to determine if we've checked this triplet already.

Lists are mutable, so Python won't let us put them into a set (they aren't hashable). Instead, we'll encode our triplet as a string. The following code makes a string with &s between the numbers; for example, [1, 1, -2] becomes 1&1&-2. You can use any character you want to delineate the numbers, but you need something, so that 1, 11 isn't confused with 11, 1 (without a delimiter, both would become 111).

          # make code to put in duplicate set
          code = str(current[0])
          for n in range(1, len(current)):
            code += "&" + str(current[n])
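Here's a small sketch checking that the loop above builds the same string `str.join` would, and showing why the delimiter matters. The last few lines show an alternative this post doesn't use: tuples, unlike lists, are hashable, so a tuple version of current could go straight into the set.

```python
current = [1, 1, -2]

# The post's loop: start with the first number, then add "&" + each remaining number.
code = str(current[0])
for n in range(1, len(current)):
    code += "&" + str(current[n])

# The same string built with str.join.
joined = "&".join(str(x) for x in current)
print(code, joined, code == joined)   # 1&1&-2 1&1&-2 True

# Without a delimiter, [1, 11] and [11, 1] collide; with one, they don't.
print("".join(["1", "11"]) == "".join(["11", "1"]))    # True: both are "111"
print("&".join(["1", "11"]) == "&".join(["11", "1"]))  # False: "1&11" vs "11&1"

# Alternative (not used in this post): store tuples in the set directly.
checked = set()
checked.add(tuple(current))
print((1, 1, -2) in checked)   # True
```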

Now we can check to see if the code is in the set. If not, we can add it to the set and add current to our final output. Remember to add a copy of current, and not just the reference. Otherwise, our triplet could change after we add it to output!

          if code not in checked:
            output.append(current.copy())
            checked.add(code)
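Here is a minimal sketch of what goes wrong without the copy. The names output_ref and output_copy are just for illustration.

```python
output_ref = []
output_copy = []
current = [-1, -1, 2]

output_ref.append(current)          # stores a reference to the same list object
output_copy.append(current.copy())  # stores an independent snapshot

current.pop()   # the loops pop values off current as they move on

print(output_ref)   # [[-1, -1]]: the "saved" triplet changed
print(output_copy)  # [[-1, -1, 2]]: the copy is unaffected
```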

Lastly, we pop the third number off current. Altogether, the final for loop looks like this:

      for k in range(j + 1, len(nums)):
        current.append(nums[k])
        if working_target - nums[k] == 0:
          # make code to put in duplicate set
          code = str(current[0])
          for n in range(1, len(current)):
            code += "&" + str(current[n])
          if code not in checked:
            output.append(current.copy())
            checked.add(code)
        current.pop()

4. Wrapping up Sum of Three

All that's left is to return our output list. Then, we're done! Here is the full method; be sure that your indentation matches exactly after all those for loops:

def sum_of_three(nums, target):
  if len(nums) < 3:
    return []
  output = []
  current = []
  checked = set()
  nums.sort()
  for i in range(len(nums) - 2):
    current.append(nums[i])
    working_target = target - nums[i]
    for j in range(i + 1, len(nums) - 1):
      working_target -= nums[j]
      current.append(nums[j])
      for k in range(j + 1, len(nums)):
        current.append(nums[k])
        if working_target - nums[k] == 0:
          # make code to put in duplicate set
          code = str(current[0])
          for n in range(1, len(current)):
            code += "&" + str(current[n])
          if code not in checked:
            output.append(current.copy())
            checked.add(code)
        current.pop()
      current.pop()
      working_target += nums[j]
    current.pop()
  return output

Here is some sample driver code to test the method:

nums = [-1,0,1,2,-1,-4]
target = 0
print(sum_of_three(nums, target))

This will give us the final output of [[-1, -1, 2], [-1, 0, 1]]. Notice that there are no duplicate triplets, even though either of the two -1s could be used to make the second triplet.

5. Step Up Complexity: Sum of Four

You're feeling pretty good; you just finished Sum of Three. But can you solve...Sum of Four?

Although it sounds more complex, conceptually we can use the same brute force method. I know our runtime will go from the terrible O(N^3) to the even worse O(N^4), but, like I said earlier, it will still *run*.
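To put those growth rates in numbers: the nested loops visit each combination of indices exactly once, so the innermost loop of sum_of_three runs C(N, 3) = N(N-1)(N-2)/6 times and that of sum_of_four runs C(N, 4) times. A quick sketch using `math.comb` (available in Python 3.8 and later):

```python
from math import comb  # math.comb requires Python 3.8+

# comb(n, k) counts the ways to choose k indices out of n,
# which is how many times the innermost loop body runs.
for n in (10, 100, 1000):
    print(f"n={n}: sum_of_three checks {comb(n, 3)}, sum_of_four checks {comb(n, 4)}")
```

For n = 1000, that works out to about 166 million triplets versus about 41 billion quadruplets.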

How will we achieve this? We can simply adjust our code from sum_of_three by adding another loop. Each loop starts at the previous index + 1 and ends progressively closer to the end of the list. Here is what they will look like:

  for i in range(len(nums) - 3):
    for j in range(i + 1, len(nums) - 2):
      for k in range(j + 1, len(nums) - 1):
        for m in range(k + 1, len(nums)):

In each loop, we'll follow the same steps as before: appending to current, subtracting from the target, checking some stuff, and then undoing the previous actions. I'll save you some scrolling and post the method in full:

def sum_of_four(nums, target):
  if len(nums) < 4:
    return []
  output = []
  current = []
  nums.sort()
  checked = set()
  for i in range(len(nums) - 3):
    working_target = target - nums[i]
    current.append(nums[i])
    for j in range(i + 1, len(nums) - 2):
      working_target -= nums[j]
      current.append(nums[j])
      for k in range(j + 1, len(nums) - 1):
        working_target -= nums[k]
        current.append(nums[k])
        for m in range(k + 1, len(nums)):
          current.append(nums[m])
          if working_target - nums[m] == 0:
            # make code to put in duplicate set
            code = str(current[0])
            for n in range(1, len(current)):
              code += "&" + str(current[n])
            if code not in checked:
              output.append(current.copy())
              checked.add(code)
          current.pop()
        current.pop()
        working_target += nums[k]
      current.pop()
      working_target += nums[j]
    current.pop()
  return output

This exactly matches our method for "Sum of Three", just with an extra loop. We can run the following driver code:

nums = [1,0,-1,0,-2,2, 0]
target = 0
print(sum_of_four(nums, target))

And our result is [[-2, -1, 1, 2], [-2, 0, 0, 2], [-1, 0, 0, 1]]. It works!

6. Sum of Five...and Beyond?

You wrote "Sum of Three" and "Sum of Four"; you're feeling pretty confident. What if I told you next you need to solve...Sum of Five?

No problem, you say. Just add another for loop. But you might have noticed something. This looks a little repetitive:

for j in range(i + 1, len(nums) - 2):
  working_target -= nums[j]
  current.append(nums[j])
  for k in range(j + 1, len(nums) - 1):
    working_target -= nums[k]
    current.append(nums[k])
    for m in range(k + 1, len(nums)):
      current.append(nums[m])

And you're right. When we have repetitive code, a good practice is to write a method for it. So what if we wrote a recursive method to run each for loop? Then, we could tell it how many numbers we want to sum together, and it will know how many nested loops to construct.

We'll set up the method exactly as we did before, although this time we also take in an argument k: how many numbers (triplet, quadruplet, etc.) to sum.

def sum_of_k(nums, target, k):
  if len(nums) < k:
    return []
  output = []
  current = []
  nums.sort()
  checked = set()

Next, we'll make a recursive method called sum_helper() that accepts...pretty much every variable we have so far. Additionally, it needs to know which index to start its loop from (the i + 1 from our nested loops); we'll call this prev, and it starts at 0. The helper modifies output, checked, and current in place, so we just need to call it and then return output.

  sum_helper(nums, target, k, output, checked, current, 0)
  return output

7. Iteration in the Recursive Method

Let's think about our approach for the recursive helper method. Each recursive call will reduce k by 1. After all, "Sum of Four" is just "Sum of Three" with an extra number attached.

Our recursive method will need a base case. If k is zero, we don't have to do anything. In this case, we'll simply return out of the function.

  if k == 0:
    return

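If k is greater than zero, we want to loop through nums. We'll start at the index prev and go up to (but not including) len(nums) - k + 1. We get this by looking at the pattern in our previous solutions: in "Sum of Three", the outermost loop (k = 3) stopped before len(nums) - 2, and the innermost loop (k = 1) stopped before len(nums).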

  for i in range(prev, len(nums) - k + 1):
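One way to convince yourself the bound is right is to strip the recursion down to just the indices. The hypothetical helper `index_combos` below uses the same `range(prev, n - k + 1)` bounds to build every increasing run of k indices, and it should match `itertools.combinations`, which produces combinations in the same order. Unlike sum_helper, its base case is k == 0, where it simply records the indices it has picked.

```python
from itertools import combinations

def index_combos(n, k, prev=0, current=None, out=None):
    """Recursively build every increasing tuple of k indices from range(n)."""
    if current is None:
        current, out = [], []
    if k == 0:
        out.append(tuple(current))  # every "loop" has picked an index
        return out
    for i in range(prev, n - k + 1):   # same bounds as sum_helper
        current.append(i)
        index_combos(n, k - 1, i + 1, current, out)  # the next "loop" starts after i
        current.pop()
    return out

print(index_combos(5, 3)[:4])  # [(0, 1, 2), (0, 1, 3), (0, 1, 4), (0, 2, 3)]
print(index_combos(6, 3) == list(combinations(range(6), 3)))  # True
```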

What's next? Looking back at our agenda, we append the current number to current.

    current.append(nums[i])

You might think the next step is to subtract that number from the target, but we need to check something first. This for loop has to work for every value of k, so we need to handle the case where this is the innermost loop (k == 1).

In that case, we check whether the target minus the number is zero. If it is, we do exactly what we did before (you could probably copy and paste with minimal edits): encode current, check it against the set, and if it's unique, add it to both the set and the output.

    if k == 1 and target - nums[i] == 0:
      # make code to put in duplicate set
      code = str(current[0])
      for n in range(1, len(current)):
        code += "&" + str(current[n])
      if code not in checked:
        output.append(current.copy())
        checked.add(code)

Now we can subtract from the working target and write our recursive case. But wait, we can do this all in one step! Here is the recursive call:

    sum_helper(nums, target - nums[i], k - 1, output, checked, current, i + 1)

So the target shrinks by nums[i], k goes down by 1, and i + 1 becomes the new starting index, prev.

Finally, we have to undo our addition to current:

    current.pop()

And that's it! Together, here are our methods:

def sum_of_k(nums, target, k):
  if len(nums) < k:
    return []
  output = []
  current = []
  nums.sort()
  checked = set()
  sum_helper(nums, target, k, output, checked, current, 0)
  return output

def sum_helper(nums, target, k, output, checked, current, prev):
  if k == 0:
    return
  for i in range(prev, len(nums) - k + 1):
    current.append(nums[i])
    if k == 1 and target - nums[i] == 0:
      # make code to put in duplicate set
      code = str(current[0])
      for n in range(1, len(current)):
        code += "&" + str(current[n])
      if code not in checked:
        output.append(current.copy())
        checked.add(code)
    sum_helper(nums, target - nums[i], k - 1, output, checked, current, i + 1)
    current.pop()

Our method is still pretty unwieldy, but we don't have to adjust it for each new value of k. The following driver code will work whether k is 3, 4, or beyond!

nums = [-1,0,1,2,-1,-4, -2]
target = 0
k = 4
print(sum_of_k(nums, target, k))
# -> [[-2, -1, 1, 2], [-1, -1, 0, 2]]

Again, this is a brute force solution. The runtime will be O(N^k), which is pretty bad if k is large. This problem could probably be solved more elegantly with linear algebra, so if some of you are particularly knowledgeable on that front and come up with something, let me know in the comments!
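For anyone who wants something faster than brute force in the meantime: this isn't the linear algebra idea above, but a widely used alternative is to sort once, recurse down to k == 2, and finish each branch with two pointers moving inward from both ends. That replaces the innermost pair of loops with a single linear pass, bringing the runtime down to roughly O(N^(k-1)) for k >= 2. The sketch below assumes k >= 2 (it returns an empty list otherwise), and the name `k_sum_two_pointer` is hypothetical. Duplicates are skipped by comparing neighboring values in the sorted list, so no set is needed.

```python
def k_sum_two_pointer(nums, target, k):
    """Return each unique k-number combination (as an ascending list) that sums to target."""
    nums = sorted(nums)  # sort a copy so the caller's list is left alone
    output = []

    def helper(start, remaining, goal, current):
        if remaining == 2:
            # Two pointers over nums[start:], moving inward from both ends.
            lo, hi = start, len(nums) - 1
            while lo < hi:
                pair_sum = nums[lo] + nums[hi]
                if pair_sum < goal:
                    lo += 1   # need a bigger sum
                elif pair_sum > goal:
                    hi -= 1   # need a smaller sum
                else:
                    output.append(current + [nums[lo], nums[hi]])
                    lo += 1
                    hi -= 1
                    # Skip repeats of the low value so each pair is reported once.
                    while lo < hi and nums[lo] == nums[lo - 1]:
                        lo += 1
            return
        for i in range(start, len(nums) - remaining + 1):
            # Skip a value already tried in this position.
            if i > start and nums[i] == nums[i - 1]:
                continue
            helper(i + 1, remaining - 1, goal - nums[i], current + [nums[i]])

    if k >= 2 and len(nums) >= k:
        helper(0, k, target, [])
    return output


print(k_sum_two_pointer([-1, 0, 1, 2, -1, -4, -2], 0, 4))
# [[-2, -1, 1, 2], [-1, -1, 0, 2]]
```

On the examples in this post, it returns the same combinations, in the same order, as the brute force versions.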


Sheamus Heikkila is formerly a Teaching Assistant at General Assembly. This blog is not associated with GA.
