DEV Community

Kavin Bharathi


The fascinating Wave Function Collapse algorithm.

What is a wave function and why does it collapse?

Wave function collapse is an algorithm that can procedurally generate images, text, audio and almost anything else that follows a pattern, starting from a simple set of initial parameters and a set of rules. But the neat part is that wfc (short for wave function collapse) does all this without any artificial intelligence shenanigans. The output of wfc is a random showcase of the emergent behavior that arises from the rules you define. This emergent behavior is what fascinates me! Now let's dive deeper into the nuances of wfc!

The "wave" in wfc!

In this article we'll talk about image generation only, unless specified otherwise. The "wave function" is the initial state of the image where no pixel or "cell" has been "collapsed" yet. There are many quotes involved in the last sentence, so let's break it down into simpler parts.

The wave function:

The wave function is the grid representation of the to-be-generated image. This grid will hold the final image once the algorithm has finished running. When we say the function is being collapsed, it means that a cell is being assigned a specific sub-image (a tile). For example, consider these five tiles:

Tiles: Tile 0, Tile 1, Tile 2, Tile 3, Tile 4

These tiles are the basic constituents of the final image. Note that if we assign tiles to the cells at random, we won't get the desired "pattern", since there will be cases where tiles connect into nothingness.

Image generated with randomly assigned tiles
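To make "connecting into nothingness" concrete, here is a small sketch that assigns random tiles to a 20 x 20 grid and counts the neighboring pairs whose touching edges disagree. It works on tile indices instead of images and borrows the edge codes that the article assigns to the five tiles later on, in [up, right, down, left] order; random_grid and count_mismatches are hypothetical helper names.

```python
import random

# Edge codes per tile in [up, right, down, left] order (the values assigned later in this article)
EDGES = [
    [0, 0, 0, 0],  # tile 0
    [1, 1, 0, 1],  # tile 1
    [1, 1, 1, 0],  # tile 2
    [0, 1, 1, 1],  # tile 3
    [1, 0, 1, 1],  # tile 4
]


def random_grid(rows, cols, seed=None):
    """Assign a random tile index to every cell, ignoring all rules."""
    rng = random.Random(seed)
    return [[rng.randrange(len(EDGES)) for _ in range(cols)] for _ in range(rows)]


def count_mismatches(grid):
    """Count neighboring pairs whose touching edges disagree (no wrap-around)."""
    rows, cols = len(grid), len(grid[0])
    mismatches = 0
    for i in range(rows):
        for j in range(cols):
            tile = EDGES[grid[i][j]]
            # right edge of this tile vs. left edge of the tile to its right
            if j + 1 < cols and tile[1] != EDGES[grid[i][j + 1]][3]:
                mismatches += 1
            # bottom edge of this tile vs. top edge of the tile below it
            if i + 1 < rows and tile[2] != EDGES[grid[i + 1][j]][0]:
                mismatches += 1
    return mismatches


grid = random_grid(20, 20, seed=1)
pairs = 2 * 20 * 19  # 380 horizontal + 380 vertical neighbor pairs
print(f"{count_mismatches(grid)} of {pairs} neighboring pairs do not connect")
```

With these edge codes, 3 of the 5 tiles have a 1 on any given side, so a random pair fails to connect with probability 2 x (3/5) x (2/5) = 12/25: on average about half of the pairs are broken, which is the look of the randomly assigned image.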

To replace this haphazard placement of tiles with one that respects the rules, while still generating large images efficiently, we use the wave function collapse algorithm.

Image generated with wave function collapse

The "collapse" in wave function:

Finally, it's time for the fun part. I am going to use Python (3.8) in this article, but you can use any language of your choice once you understand the underlying algorithm. Let's go.

Step 1: Building the display

The visualization part is going to be handled by Pygame, a Python library for game development. We'll only use it to display some images on the screen (and for a few other little functionalities that'll come later).

To install Pygame on Windows,

pip install pygame

and on Unix systems,

pip3 install pygame

To build a basic display,

import pygame
pygame.init()

# global variables
width = 600
height = 600
display = pygame.display.set_mode((width, height))

def main():
    # flag toggled with the "d" key (not used yet)
    hover_toggle = False

    # game loop
    loop = True
    while loop:

        display.fill((0, 0, 0))
        # event handler
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                loop = False

            if event.type == pygame.KEYDOWN:
                if event.key == pygame.K_d:
                    hover_toggle = not hover_toggle

                # "q" also ends the game loop
                if event.key == pygame.K_q:
                    loop = False

        # update the display
        pygame.display.flip()

    # shut pygame down once the loop ends
    pygame.quit()

# calling the main function
if __name__ == "__main__":
    main()
Step 2: Creating the Cell class

Each individual "pixel" in the final image is going to be represented as a cell object. Therefore we are going to write a cell class that looks like this,

# Cell class
class Cell:
    def __init__(self, x, y, rez, options):
        self.x = x                # row index of the cell in the grid
        self.y = y                # column index of the cell in the grid
        self.rez = rez            # size of the cell in pixels
        self.options = options    # tiles the cell can still become (its possible states)
        self.collapsed = False

    # method for drawing the cell (only once a single option is left)
    def draw(self, win):
        if len(self.options) == 1:
            # the column index sets the horizontal position, the row index the vertical one
            self.options[0].draw(win, self.y * self.rez, self.x * self.rez)

    # return the entropy of the cell
    def entropy(self):
        pass

    # update the collapsed variable
    def update(self):
        pass

    # observe the cell/collapse the cell
    def observe(self):
        pass

You'll notice that the methods 'entropy()', 'update()', and 'observe()' lack any actual code in them. Before writing the code let's understand what they actually do.

  • Entropy:
    The entropy of a cell is a way of quantifying the number of choices for a cell. The choices are also called the states of the cell.

  • Observe:
    By observing a cell, we select a random state from all of its possible states. Those possible states have already been narrowed down by inferring from the neighboring cells. This is the crux of the wfc algorithm.

  • Update:
    In essence, update() is used to update the "collapsed" data variable of the cell object. The collapsed variable is True if the number of possible states of the cell is 1 (i.e., the length of the options variable is 1) and False otherwise.

Now, the content of the methods becomes,

    # return the entropy/the length of options
    def entropy(self):
        return len(self.options)

    # update the collapsed variable
    def update(self):
        self.collapsed = self.entropy() == 1

    # observe the cell/collapse the cell (needs `import random` at the top of the file)
    def observe(self):
        # a cell with no options left is a contradiction, so there is nothing to pick from
        if not self.options:
            return
        self.options = [random.choice(self.options)]
        self.collapsed = True

With this the cell class is finally complete.
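Cell.entropy() uses the number of remaining options as the entropy, which is all the minimum-entropy heuristic needs when every tile is equally likely. If you want some tiles to appear more often than others, a common refinement is to give each tile a weight and use the Shannon entropy of the remaining options. The sketch below is not part of this article's code; shannon_entropy, weighted_observe and the weights are illustrative assumptions.

```python
import math
import random


def shannon_entropy(weights):
    """Shannon entropy of a cell whose remaining options have the given positive weights.

    Uses H = log(W) - sum(w * log(w)) / W, where W is the sum of the weights.
    """
    total = sum(weights)
    return math.log(total) - sum(w * math.log(w) for w in weights) / total


def weighted_observe(options, weights, rng=random):
    """Collapse a cell: pick one option with probability proportional to its weight."""
    return rng.choices(options, weights=weights, k=1)[0]


# With equal weights H = log(n), so ranking cells by H gives the same order as len(options)
print(shannon_entropy([1, 1, 1, 1]))   # equals math.log(4)
print(shannon_entropy([10, 1, 1, 1]))  # smaller: one option dominates the cell
print(weighted_observe(["tile 0", "tile 1", "tile 2", "tile 3"], [10, 1, 1, 1]))
```

Both functions would slot in where Cell.entropy() and Cell.observe() use len() and random.choice(); tiles with a weight of zero would have to be filtered out first, because log(0) is undefined.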

Step 3: Creating the Tile class

The tile class is used for managing the individual tiles and setting the rules for assembling them in the image. The rules may vary from one tile set to another, so make sure to update them if you want to change the tile set. There are also algorithms that generate rulesets from the given images, but we'll not cover them in this article.

The tile class is simple. Each tile instance is assigned an index variable and is represented visually by the tile image. Its edges list describes the tile's top, right, bottom and left edges, in that order; two tiles can sit next to each other when their touching edges have the same value. The tiles that can be placed above, to the right of, below and to the left of a tile are stored in the self.up, self.right, self.down and self.left data variables respectively. These variables are used to update each cell and make the "collapsing" of the wave function possible.

class Tile:
    def __init__(self, img):
        self.img = img
        self.index = -1
        self.edges = []
        self.up = []
        self.right = []
        self.down = []
        self.left = []

    # draw a single tile
    def draw(self, win, x, y):
        win.blit(self.img, (x, y))

    # set the rules for each tile
    def set_rules(self, tiles):
        for tile in tiles:
            # upper edge
            if self.edges[0] == tile.edges[2]:
                self.up.append(tile)

            # right edge
            if self.edges[1] == tile.edges[3]:
                self.right.append(tile)

            # lower edge
            if self.edges[2] == tile.edges[0]:
                self.down.append(tile)

            # left edge
            if self.edges[3] == tile.edges[1]:
                self.left.append(tile)

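To see what set_rules() produces for this article's tile set without loading any images, here is a sketch that applies the same edge comparison to plain tile indices. It assumes the [up, right, down, left] edge order and the edge codes assigned in main() further below; build_rules and the dictionary layout are hypothetical.

```python
# Edge codes in [up, right, down, left] order (the values assigned in main() later on)
EDGES = {
    0: [0, 0, 0, 0],
    1: [1, 1, 0, 1],
    2: [1, 1, 1, 0],
    3: [0, 1, 1, 1],
    4: [1, 0, 1, 1],
}

# For each direction: (index of this tile's edge, index of the neighbor's touching edge)
DIRECTIONS = {"up": (0, 2), "right": (1, 3), "down": (2, 0), "left": (3, 1)}


def build_rules(edges):
    """Return rules[tile][direction] = sorted tiles allowed on that side of `tile`."""
    rules = {}
    for a, a_edges in edges.items():
        rules[a] = {}
        for direction, (mine, theirs) in DIRECTIONS.items():
            # same test as Tile.set_rules(): the two touching edges must be equal
            rules[a][direction] = sorted(
                b for b, b_edges in edges.items() if a_edges[mine] == b_edges[theirs]
            )
    return rules


rules = build_rules(EDGES)
print(rules[0]["up"])    # [0, 1]: only tiles whose bottom edge is 0 fit above tile 0
print(rules[1]["down"])  # [0, 3]
```

The rules come out symmetric: b is allowed above a exactly when a is allowed below b, so the up/down and left/right lists always agree with each other.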
Step 4: Creating the Grid class

The grid is going to hold the "wave", i.e., the image grid of the final output. This is the core of the algorithm. So, let's get started.

The grid is going to hold a width, a height, a resolution variable (rez, the size of each cell in pixels), a 2D array holding all the cells of the output, and a list of all the options (tiles) in the image.

Hence, the Grid class is initialized by,

class Grid:
    def __init__(self, width, height, rez, options):
        self.width = width
        self.height = height
        self.rez = rez
        self.w = self.width // self.rez
        self.h = self.height // self.rez
        self.grid = []
        self.options = options

We also need a simple method for drawing the 2D array onto the pygame display.

    # draw each cell in the grid
    def draw(self, win):
        for row in self.grid:
            for cell in row:
                cell.draw(win)

Then we need to initiate all the cells in the grid 2D array. We are going to make use of a method called "initiate()" to do that.

    # initiate each spot in the grid with a cell object
    def initiate(self):
        for i in range(self.w):
            self.grid.append([])
            for j in range(self.h):
                cell = Cell(i, j, self.rez, self.options)
                self.grid[i].append(cell)

Now, during each frame/cycle of the program, we loop through all the cells in the grid array and calculate the entropy of each cell. Then we "collapse" the cell with the minimum entropy to a single state. "Collapse" refers to the action of randomly selecting one state from the cell's possible options, which have already been narrowed down by its neighbors. The entropy of each cell is updated using the update method in every frame.

The heuristic_pick method (which picks a cell with the minimum entropy) is,

    # randomly pick one of the uncollapsed cells with the lowest entropy [entropy heuristic]
    def heuristic_pick(self):

        # flatten the grid into a list of cells, sorted by entropy
        grid_copy = [cell for row in self.grid for cell in row]
        grid_copy.sort(key=lambda x: x.entropy())

        # ignore cells that are already collapsed (entropy 1) or have no options left (entropy 0)
        filtered_grid = list(filter(lambda x: x.entropy() > 1, grid_copy))
        if not filtered_grid:
            return None

        # keep only the cells that share the lowest entropy
        least_entropy_cell = filtered_grid[0]
        filtered_grid = list(filter(lambda x: x.entropy() == least_entropy_cell.entropy(), filtered_grid))

        # return a random pick among them (needs `import random` at the top of the file)
        pick = random.choice(filtered_grid)
        return pick
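heuristic_pick() sorts every cell on every frame, which costs O(n log n) for n cells. If that ever becomes a bottleneck, the same choice (a random cell among those with the lowest entropy above 1) can be made in a single pass. min_entropy_pick is a hypothetical helper, and the tiny FakeCell class only stands in for the article's Cell so the sketch runs on its own.

```python
import random


def min_entropy_pick(cells, rng=random):
    """Return a random cell among those with the lowest entropy above 1, or None."""
    best, candidates = None, []
    for cell in cells:
        e = cell.entropy()
        if e <= 1:
            continue  # collapsed (1) or out of options (0)
        if best is None or e < best:
            best, candidates = e, [cell]  # new minimum: restart the candidate list
        elif e == best:
            candidates.append(cell)       # tie with the current minimum
    return rng.choice(candidates) if candidates else None


class FakeCell:
    """Stand-in for Cell: only entropy() is needed here."""

    def __init__(self, name, n_options):
        self.name, self.options = name, list(range(n_options))

    def entropy(self):
        return len(self.options)


cells = [FakeCell("a", 5), FakeCell("b", 1), FakeCell("c", 3), FakeCell("d", 3), FakeCell("e", 0)]
print(min_entropy_pick(cells).name)  # either c or d
```

Inside Grid, it would be called as min_entropy_pick(cell for row in self.grid for cell in row).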

Finally, we perform the "collapsing" phase of the algorithm. Here we loop through all the cells and update their entropy based on their neighbors. In the initial state every cell has the same entropy, so the algorithm collapses a random cell to a random state. From then on, the cell entropy updates guide the way. We check the top, right, bottom and left neighbors of each cell and change its self.options data variable accordingly. The neighbor lookups wrap around the edges of the grid, so the left neighbor of a cell in the first column is the cell in the last column of the same row. This options variable is used to quantify the entropy of each cell.

The code is given by,

    # [WAVE FUNCTION COLLAPSE] algorithm
    def collapse(self):

        # pick a cell using the entropy heuristic and collapse it to a single state
        pick = self.heuristic_pick()
        if pick is None:
            return  # no cell with more than one option is left
        self.grid[pick.x][pick.y].observe()

        # update the entropy values and superpositions of each cell in the grid.
        # Cells are updated in place, so a cell updated earlier in this pass
        # already constrains the cells visited after it.
        for i in range(len(self.grid)):
            for j in range(len(self.grid[0])):
                cell = self.grid[i][j]
                if cell.collapsed:
                    continue

                # cumulative_valid_options will hold the options that satisfy the "down", "right", "up", "left"
                # conditions of each neighbor. The grid wraps around at its edges (hence the % operations).
                cumulative_valid_options = self.options

                # check above cell
                cell_above = self.grid[(i - 1) % self.w][j]
                valid_options = []                          # holds the valid options for the current cell to fit with the above cell
                for option in cell_above.options:
                    valid_options.extend(option.down)
                cumulative_valid_options = [option for option in cumulative_valid_options if option in valid_options]

                # check right cell
                cell_right = self.grid[i][(j + 1) % self.h]
                valid_options = []                          # holds the valid options for the current cell to fit with the right cell
                for option in cell_right.options:
                    valid_options.extend(option.left)
                cumulative_valid_options = [option for option in cumulative_valid_options if option in valid_options]

                # check down cell
                cell_down = self.grid[(i + 1) % self.w][j]
                valid_options = []                          # holds the valid options for the current cell to fit with the down cell
                for option in cell_down.options:
                    valid_options.extend(option.up)
                cumulative_valid_options = [option for option in cumulative_valid_options if option in valid_options]

                # check left cell
                cell_left = self.grid[i][(j - 1) % self.h]
                valid_options = []                          # holds the valid options for the current cell to fit with the left cell
                for option in cell_left.options:
                    valid_options.extend(option.right)
                cumulative_valid_options = [option for option in cumulative_valid_options if option in valid_options]

                # finally assign the cumulative valid options as the current cell's options
                cell.options = cumulative_valid_options
                cell.update()

I've added comments to make it a little more understandable.
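If you want to experiment with the algorithm before wiring up Pygame and the tile images, here is a headless sketch of the same loop on tile indices: pick a minimum-entropy cell, observe it, then update the neighbors. It assumes the article's edge codes and wrap-around grid, and run_wfc, EDGES and DIRS are hypothetical names. One deliberate difference: instead of a single sweep over the grid per observation, it propagates changes with a stack until no cell's options shrink any further, the way many WFC implementations do, so a constraint travels as far as it needs to in one step.

```python
import random

# Edge codes in [up, right, down, left] order (the tile set used in this article)
EDGES = [[0, 0, 0, 0], [1, 1, 0, 1], [1, 1, 1, 0], [0, 1, 1, 1], [1, 0, 1, 1]]
# (row offset, column offset, index of this cell's edge, index of the neighbor's facing edge)
DIRS = [(-1, 0, 0, 2), (0, 1, 1, 3), (1, 0, 2, 0), (0, -1, 3, 1)]


def run_wfc(rows, cols, seed=None):
    """Collapse a rows x cols grid that wraps around at the edges.

    Returns (grid, ok): grid holds the set of remaining tile indices for every cell;
    ok is False if some cell ran out of options (a contradiction).
    """
    rng = random.Random(seed)
    grid = [[set(range(len(EDGES))) for _ in range(cols)] for _ in range(rows)]

    def propagate(start):
        stack = [start]
        while stack:
            i, j = stack.pop()
            for di, dj, mine, theirs in DIRS:
                ni, nj = (i + di) % rows, (j + dj) % cols
                # keep the neighbor's options whose facing edge matches at least one of ours
                allowed = {b for b in grid[ni][nj]
                           if any(EDGES[a][mine] == EDGES[b][theirs] for a in grid[i][j])}
                if allowed != grid[ni][nj]:
                    grid[ni][nj] = allowed
                    if not allowed:
                        return False  # contradiction: this neighbor has no valid tile left
                    stack.append((ni, nj))  # its shrinkage may constrain its own neighbors
        return True

    while True:
        # entropy heuristic: a random cell among those with the fewest options (more than 1)
        open_cells = [(len(grid[i][j]), i, j)
                      for i in range(rows) for j in range(cols) if len(grid[i][j]) > 1]
        if not open_cells:
            return grid, True
        lowest = min(entropy for entropy, _, _ in open_cells)
        _, i, j = rng.choice([c for c in open_cells if c[0] == lowest])

        # observe: collapse the cell to one random option, then propagate the consequences
        grid[i][j] = {rng.choice(sorted(grid[i][j]))}
        if not propagate((i, j)):
            return grid, False


grid, ok = run_wfc(8, 8, seed=0)
print("complete" if ok else "contradiction")
for row in grid:
    print(" ".join(str(next(iter(cell))) if len(cell) == 1 else "?" for cell in row))
```

Each printed number is a tile index from the article's tile set, and "?" marks a cell left unresolved by a contradiction. Like the Pygame version, this sketch can still fail on an unlucky run.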

But before testing this out, we have to make our front end compatible with our algorithm in the back end.

For that we have to add a little more to our code. First, a function for loading images with optional padding,

# function for loading images with given resolution/size
def load_image(path, rez_, padding = 0):
    img = pygame.image.load(path).convert_alpha()
    img = pygame.transform.scale(img, (rez_ - padding, rez_ - padding))
    return img


Then, we'll update the main function to initiate and update the grid ("wave").


# size of each cell/tile in pixels (example value; 600 / 30 gives a 20 x 20 grid)
rez = 30

# main game function
def main():
    # loading tile images
    options = []
    for i in range(5):
        # load tetris tile and wrap it in a Tile object
        img = load_image(f"./assets/{i}.png", rez)
        options.append(Tile(img))

    # edge conditions for tetris tiles, in [up, right, down, left] order
    options[0].edges = [0, 0, 0, 0]
    options[1].edges = [1, 1, 0, 1]
    options[2].edges = [1, 1, 1, 0]
    options[3].edges = [0, 1, 1, 1]
    options[4].edges = [1, 0, 1, 1]

    # update tile rules for each tile
    for i, tile in enumerate(options):
        tile.index = i
        tile.set_rules(options)

    # wave grid
    wave = Grid(width, height, rez, options)
    wave.initiate()

    # flag toggled with the "d" key (not used yet)
    hover_toggle = False

    # game loop
    loop = True
    while loop:

        display.fill((0, 0, 0))
        # event handler
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                loop = False

            if event.type == pygame.KEYDOWN:
                if event.key == pygame.K_d:
                    hover_toggle = not hover_toggle

                if event.key == pygame.K_q:
                    loop = False

        # grid draw function
        wave.draw(display)

        # grid collapse method to run the algorithm (one observation per frame)
        wave.collapse()

        # update the display
        pygame.display.flip()

    pygame.quit()

Finally, call the main function:

# calling the main function
if __name__ == "__main__":
    main()


Note that the algorithm is not guaranteed to find a solution. It fills the image as far as it can; a cell that ends up with no valid option is simply left empty. To find a complete solution whenever one exists, we can use a backtracking algorithm that undoes previously collapsed cell states and continues until a solution is reached.
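Here is a minimal sketch of that idea, written as a plain depth-first search over tile indices rather than as a change to the Grid class. It assumes the article's edge codes and wrap-around grid; fits and solve are hypothetical names. Instead of propagating constraints, it recomputes at each step which tiles fit the already-placed neighbors, fills the cell with the fewest fitting tiles first (the minimum-entropy idea again), and undoes a placement whenever it leads to a cell with no fitting tile. Because the search is exhaustive it finds a solution whenever one exists, but it can become slow on large grids.

```python
import random

# Edge codes in [up, right, down, left] order (the tile set used in this article)
EDGES = [[0, 0, 0, 0], [1, 1, 0, 1], [1, 1, 1, 0], [0, 1, 1, 1], [1, 0, 1, 1]]
# (row offset, column offset, index of this cell's edge, index of the neighbor's facing edge)
DIRS = [(-1, 0, 0, 2), (0, 1, 1, 3), (1, 0, 2, 0), (0, -1, 3, 1)]


def fits(grid, i, j, tile):
    """True if `tile` at (i, j) matches every neighbor that is already placed (grid wraps around)."""
    rows, cols = len(grid), len(grid[0])
    for di, dj, mine, theirs in DIRS:
        ni, nj = (i + di) % rows, (j + dj) % cols
        # on a 1-wide grid a cell can be its own neighbor
        other = tile if (ni, nj) == (i, j) else grid[ni][nj]
        if other is not None and EDGES[tile][mine] != EDGES[other][theirs]:
            return False
    return True


def solve(grid, rng):
    """Fill every None cell by depth-first search, undoing placements that lead to dead ends."""
    # pick the empty cell with the fewest fitting tiles (the minimum-entropy idea)
    best = None
    for i, row in enumerate(grid):
        for j, cell in enumerate(row):
            if cell is None:
                options = [t for t in range(len(EDGES)) if fits(grid, i, j, t)]
                if best is None or len(options) < len(best[2]):
                    best = (i, j, options)
    if best is None:
        return True  # no empty cell left: the grid is complete

    i, j, options = best
    rng.shuffle(options)  # try the fitting tiles in random order
    for tile in options:
        grid[i][j] = tile
        if solve(grid, rng):
            return True
        grid[i][j] = None  # backtrack: undo this choice and try the next tile
    return False  # no tile works here; the caller has to undo its own choice


grid = [[None] * 4 for _ in range(4)]
if solve(grid, random.Random(0)):
    for row in grid:
        print(" ".join(map(str, row)))
```

Cells can be pre-filled with mutually consistent tile indices before calling solve(), and the search will complete the rest of the grid around them.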

Thank you for reading. Feel free to comment your thoughts and opinions about this article.

GitHub repo for wfc

Top comment, by Paul Merrell:
Just to give a little more background on this. WFC is based on the Model Synthesis Algorithm published in 2007. For more info see: paulmerrell.org/model-synthesis/