DEV Community

John Calabrese


From DFS to Topological Sort

Recently I revisited this LeetCode question: https://leetcode.com/problems/alien-dictionary/solution/. I first learned about it a year and a half ago, when I was a data science fellow at Insight in New York. It's a fun problem, though a terrible interview question. I think it would be much better if you could at least assume that the alien dictionary has no inconsistencies (i.e., that an order can always be extracted).

DFS

Depth-first search (DFS) is a standard method for exploring a graph. For example, let's say we want to print the value of every node in a graph. Here is an (incorrect) recursive implementation.

def dfs(node):
    print(node.val)
    for child in node.children:
        dfs(child)

There are two big issues with this code. The first is that it can get stuck in a loop: for the graph 1->2->3->1, calling dfs(node_1) recurses forever (in Python, until it raises a RecursionError). Let's fix this by keeping track of which nodes the DFS has already visited.

def dfs(node):
    if node in visited:
        return None
    print(node.val)
    visited.add(node)
    for child in node.children:
        dfs(child)

visited = set()
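To see the fix at work, here is a small sketch on the cycle 1->2->3->1. The Node class is an assumption (the post never defines one); any object with a val attribute and a children list would do. The sketch passes visited in explicitly and collects values into a list instead of printing them, which makes the result easy to check.

```python
class Node:
    """Hypothetical node type: a value plus a list of child nodes."""

    def __init__(self, val):
        self.val = val
        self.children = []


def dfs(node, visited, seen_order):
    """Same logic as the fixed dfs above, but records values instead of printing."""
    if node in visited:
        return
    seen_order.append(node.val)
    visited.add(node)
    for child in node.children:
        dfs(child, visited, seen_order)


# Build the cycle 1 -> 2 -> 3 -> 1.
node_1, node_2, node_3 = Node(1), Node(2), Node(3)
node_1.children.append(node_2)
node_2.children.append(node_3)
node_3.children.append(node_1)

seen_order = []
dfs(node_1, set(), seen_order)
print(seen_order)  # [1, 2, 3]
```

The call returns after three steps: when the DFS comes back around to node 1, it finds it in visited and stops.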

The second issue is that it has no awareness of the whole graph. For the graph 1 -> 2 -> 3, calling dfs(2) will miss node 1. So let's assume we have a dictionary edges, where edges[node] lists all the children of node. Every node is a key, including nodes with no children, which map to an empty list.

def print_nodes(edges):
    def dfs(node):
        if node in visited:
            return None
        print(node)
        visited.add(node)
        for child in edges[node]:
            dfs(child)

    visited = set()
    for node in edges.keys():
        dfs(node)

For the graph

    1
   /
  2
 / \
3   4

(where I assume the edges all point down), the function above prints the nodes in the order 1, 2, 3, 4 (assuming node 1 is the first key in edges and 3 comes before 4 in edges[node_2]). This is the "pre-order" way of printing the graph: each node is printed before its children. The "post-order" way prints each node only after all of its children:

def print_nodes_post(edges):
    def dfs(node):
        if node in visited:
            return None
        visited.add(node)
        for child in edges[node]:
            dfs(child)
        print(node)

    visited = set()
    for node in edges.keys():
        dfs(node)

yielding 3,4,2,1.
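The two orders can be compared side by side with a sketch that records nodes into lists instead of printing them. The build_edges helper is hypothetical (not from the post); it makes sure that leaves like 3 and 4 get an empty entry, since edges[node] is looked up for every visited node.

```python
def build_edges(pairs):
    """Turn (parent, child) pairs into an adjacency dict.

    Every node gets a key, even leaves with no children.
    """
    edges = {}
    for parent, child in pairs:
        edges.setdefault(parent, []).append(child)
        edges.setdefault(child, [])
    return edges


def orders(edges):
    """Return (pre_order, post_order) for a DFS over the whole graph."""
    pre, post = [], []
    visited = set()

    def dfs(node):
        if node in visited:
            return
        visited.add(node)
        pre.append(node)   # recorded before the children are explored
        for child in edges[node]:
            dfs(child)
        post.append(node)  # recorded after all children are finished

    for node in edges:
        dfs(node)
    return pre, post


tree = build_edges([(1, 2), (2, 3), (2, 4)])
pre, post = orders(tree)
print(pre)   # [1, 2, 3, 4]
print(post)  # [3, 4, 2, 1]
```

Because Python dicts keep insertion order, node 1 is the first key here, which is what produces exactly these two orders.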

Obviously printing out the nodes in a graph is a silly task (we could just use edges.keys() directly). The point is that DFS lets us explore a whole graph and perform work as we move along, either before or after visiting a node's children:

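def dfs(node):
    if node in visited:
        return None
    visited.add(node)
    do_some_work_pre_children(node)   # placeholder: work done on the way down
    for child in edges[node]:
        dfs(child)
    do_some_work_post_children(node)  # placeholder: runs after every child is finished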
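As a concrete (hypothetical) instance of this template, post-order work is handy whenever a node's answer depends on its children's answers. The sketch below computes, for every node, the number of edges on the longest path starting at it. It assumes the graph is acyclic: then any child that is already visited has been fully explored, so its value is ready when the parent needs it.

```python
def longest_paths(edges):
    """Map each node to the number of edges on the longest path starting there.

    Assumes the graph is acyclic.
    """
    length = {}
    visited = set()

    def dfs(node):
        if node in visited:
            return
        visited.add(node)
        # pre-children work: nothing is needed for this task
        for child in edges[node]:
            dfs(child)
        # post-children work: length[child] is known for every child
        length[node] = max((length[child] + 1 for child in edges[node]), default=0)

    for node in edges:
        dfs(node)
    return length


dag = {0: [1], 1: [2, 3], 2: [4], 3: [4], 4: []}
print(longest_paths(dag))  # {4: 0, 2: 1, 3: 1, 1: 2, 0: 3}
```

Doing the same computation in the pre-children slot would not work, since the children's lengths are not known yet at that point.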

Detecting cycles

Given a graph, let's write a function to detect if it has any cycles. Here is an (incorrect) approach.

def is_acyclic(edges):
    def dfs(node):
        if node in visited:
            # cycle detected?
            return False

        visited.add(node)
        for child in edges[node]:
            if dfs(child) is False:
                return False
        return True

    visited = set()
    for node in edges.keys():
        if node not in visited:
            if dfs(node) is False:
                return False
    return True

The reasoning here is that if a node is already in visited when dfs is called on it, a cycle must be present. This is deeply flawed: a node can be reached along two different paths without there being any cycle. For example, take the graph

  1
 / \
2   3
 \ /
  4

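(where again all the edges point downwards), which has no cycles. The function still returns False, i.e., it reports a cycle. Starting from node 1, the DFS reaches node 4 through node 2 and then reaches it again through node 3. Starting from node 2 (or 3, or 4) fails too, because the later call dfs(1) runs into a node that was already visited.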

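To fix this, we add a second set, visited_dfs, that holds only the nodes on the current recursion path. A cycle exists exactly when some child is still on that path, i.e., when an edge points back to a node whose exploration has not finished yet.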

def is_acyclic(edges):
    def dfs(node):
        if node in visited:
            return True

        visited.add(node)
        visited_dfs.add(node)
        for child in edges[node]:
            if child in visited_dfs:
                return False
            if dfs(child) is False:
                return False

        # after we explored node and all its children
        # we can remove the temporary marker
        # if we don't we'll run into the same issue
        # as in the naive approach
        visited_dfs.remove(node)
        return True

    visited = set()
    visited_dfs = set()
    for node in edges.keys():
        if dfs(node) is False:
            return False
    return True
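A common way to write the same check is the classic white/gray/black formulation of DFS, which keeps one state per node instead of two sets. In this sketch (the function and constant names are illustrative, not from the post), GRAY plays the role of visited_dfs, and BLACK marks nodes that are in visited but no longer on the recursion path.

```python
WHITE, GRAY, BLACK = 0, 1, 2  # unvisited, on the current path, finished


def is_acyclic_colors(edges):
    """Same result as is_acyclic, using one state dict instead of two sets."""
    state = {node: WHITE for node in edges}

    def dfs(node):
        state[node] = GRAY
        for child in edges[node]:
            if state[child] == GRAY:
                # back edge: child is still on the current path, so there is a cycle
                return False
            if state[child] == WHITE and not dfs(child):
                return False
        state[node] = BLACK  # node and everything reachable from it are done
        return True

    for node in edges:
        if state[node] == WHITE and not dfs(node):
            return False
    return True


diamond = {1: [2, 3], 2: [4], 3: [4], 4: []}
cycle = {1: [2], 2: [3], 3: [1]}
print(is_acyclic_colors(diamond))  # True
print(is_acyclic_colors(cycle))    # False
```

Reaching a BLACK node again, as happens with node 4 in the diamond, is harmless; only a GRAY node signals a cycle.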

Topological Sort

Let's say you have a graph that represents tasks to be executed (e.g., ETL tasks for an orchestrator like Airflow or Prefect, or a computation graph for a neural network). An edge in this graph represents a dependency between two tasks. Because of these dependencies, we can't execute the tasks in an arbitrary order. It would be nice to extract a total order of the tasks that is compatible with the edges, i.e., if you can go from node a to node b, then a < b in the order.

We can't do this for every graph. Whatever order we choose for the nodes of 1->2->3->1, it will be incompatible with the graph structure. However, if we have a DAG (directed acyclic graph), i.e., a graph without cycles, then we can. One way to come up with such an order is via DFS. For the graph

  0
  |
  1
 / \
2   3
 \ /
  4

5

[0,1,2,3,4,5], [0,1,3,2,4,5] and [0,5,1,2,3,4] are all valid orders. The idea is to explore the graph using DFS, append each node to a list in post-order, and reverse the list at the end. This works because, in an acyclic graph, a node is appended only after everything reachable from it has been appended, so after reversing, every node comes before all of its descendants.

For example, let's start by exploring the graph above from node 2. This yields the list [4,2]. Continuing from node 1, we get [4,2,3,1]. Continuing with node 5 gives [4,2,3,1,5], and finally adding 0 gives [4,2,3,1,5,0]. Reversing this list gives [0,5,1,3,2,4], which is a valid total order.

def topological_sort(edges):
    """Assumes graph is acyclic."""
    def dfs(node):
        if node in visited:
            return None
        visited.add(node)
        for child in edges[node]:
            dfs(child)
        order.append(node)

    order = []
    visited = set()
    for node in edges.keys():
        dfs(node)

    order.reverse()
    return order
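To double-check orders like the ones listed above, a small helper can test whether an order is compatible with the edges. The name is_valid_order is hypothetical. Brute force over all permutations then confirms that the cycle 1->2->3->1 admits no valid order at all.

```python
from itertools import permutations


def is_valid_order(edges, order):
    """True if order lists every node exactly once and puts each parent before its children."""
    if len(order) != len(edges) or set(order) != set(edges):
        return False
    position = {node: i for i, node in enumerate(order)}
    return all(position[u] < position[v] for u in edges for v in edges[u])


dag = {0: [1], 1: [2, 3], 2: [4], 3: [4], 4: [], 5: []}
for order in ([0, 1, 2, 3, 4, 5], [0, 1, 3, 2, 4, 5], [0, 5, 1, 2, 3, 4], [0, 5, 1, 3, 2, 4]):
    print(order, is_valid_order(dag, order))  # True for each of these
print(is_valid_order(dag, [0, 2, 1, 3, 4, 5]))  # False: 2 comes before its parent 1

cycle = {1: [2], 2: [3], 3: [1]}
print(any(is_valid_order(cycle, list(p)) for p in permutations(cycle)))  # False

# Brute force also counts the valid orders of the DAG above.
print(sum(is_valid_order(dag, list(p)) for p in permutations(dag)))  # 12
```

Brute force is only practical for tiny graphs (n nodes mean n! permutations), but it is a handy sanity check for a topological sort implementation.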

Combining the two

If your graph is not guaranteed to be acyclic, we can combine the two operations into a single DFS.

def topological_sort_if_acyclic(edges):
    """Return an order if acyclic, else return []"""
    def dfs(node):
        if node in visited:
            return True

        visited.add(node)
        visited_dfs.add(node)

        for child in edges[node]:
            if child in visited_dfs:
                return False
            if dfs(child) is False:
                return False
        visited_dfs.remove(node)
        order.append(node)
        return True

    order = []
    visited = set()
    visited_dfs = set()
    for node in edges.keys():
        if dfs(node) is False:
            return []

    order.reverse()
    return order
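One practical caveat: all the functions above are recursive, and CPython's default recursion limit is 1000 (see sys.getrecursionlimit()), so a long chain of dependencies can raise RecursionError. Below is a sketch of an iterative variant of topological_sort_if_acyclic (the name and structure are illustrative, not from the post). An explicit stack of (node, iterator over its children) pairs replaces the call stack, and on_path plays the role of visited_dfs.

```python
_DONE = object()  # sentinel returned by next() when a node has no children left


def topological_sort_iterative(edges):
    """Return an order if the graph is acyclic, else []."""
    order = []
    visited = set()   # nodes the DFS has entered at some point
    on_path = set()   # nodes on the current DFS path (like visited_dfs)

    for root in edges:
        if root in visited:
            continue
        visited.add(root)
        on_path.add(root)
        stack = [(root, iter(edges[root]))]
        while stack:
            node, children = stack[-1]
            child = next(children, _DONE)
            if child is _DONE:
                # every child of node is finished: this is the post-order step
                stack.pop()
                on_path.remove(node)
                order.append(node)
            elif child in on_path:
                return []  # back edge, so the graph has a cycle
            elif child not in visited:
                visited.add(child)
                on_path.add(child)
                stack.append((child, iter(edges[child])))

    order.reverse()
    return order


dag = {0: [1], 1: [2, 3], 2: [4], 3: [4], 4: [], 5: []}
print(topological_sort_iterative(dag))                           # [5, 0, 1, 3, 2, 4]
print(topological_sort_iterative({1: [2], 2: [3], 3: [1]}))      # []

# A 5000-node chain is fine here; the recursive version would exceed the default limit.
chain = {i: [i + 1] for i in range(4999)}
chain[4999] = []
print(topological_sort_iterative(chain)[:3])                     # [0, 1, 2]
```

It visits nodes in the same order as the recursive version, so for the same edges dict it returns the same list. As with topological_sort_if_acyclic, an empty result is ambiguous for an empty graph.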
