DEV Community


A simple and lightweight approach to automatic white balancing

Adam Hibble (algomancer_3) ・ 3 min read

Over the last few months, I have been building an in-browser differentiable photo editor. The goal is for each image-processing function to be either easily parameterisable by a machine learning model or a learned transformation in its own right, so that each one can be used inside a loss function for downstream tasks. Today I want to share an approach to automatic white balance correction (which also works for other global image transformations).

I'll make no hard claims about state-of-the-art results, as I am primarily looking for a parameter-efficient (4k parameters!) and fast approach that performs better than grey-world and YUV-coordinate-based transforms. That said, in my preliminary tests it compares well with U-Net-based approaches. Importantly, it can be trained in a couple of hours on a GTX 1070 without any hyperparameter tuning, and it runs fast in the browser on a lightweight device, which is my primary use case.
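For reference, the grey-world baseline mentioned above assumes the average colour of a scene is neutral grey and rescales each channel so that the three channel means become equal. The sketch below is a minimal PyTorch version under my own assumptions: the function name `grey_world` is hypothetical, it works directly on sRGB values in [0, 1] (a more careful version would linearise first), and it uses the mean of the three channel means as the grey target.

```python
import torch


def grey_world(image: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Grey-world white balance for a batch of images shaped (batch, 3, h, w) in [0, 1].

    Hypothetical baseline helper; operates on sRGB values without linearisation.
    """
    # Per-image, per-channel mean colour: (batch, 3, 1, 1)
    channel_means = image.mean(dim=(2, 3), keepdim=True)
    # Grey target: the average of the three channel means, (batch, 1, 1, 1)
    grey = channel_means.mean(dim=1, keepdim=True)
    # Per-channel gains that map each channel mean onto the grey target
    gains = grey / (channel_means + eps)
    return torch.clamp(image * gains, 0, 1)
```

Because it only has three global gains per image, grey world cannot express the nonlinear, cross-channel corrections that a learned pixel-wise network can.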

So, let’s jump into the code.

```python
import torch
import torch.nn as nn


class WhiteBalance(nn.Module):
    def __init__(self, hidden_dim=10):
        super().__init__()
        # Small convolutional network that summarises the downsampled (e.g. 32x32) input image
        self.parameter_network = nn.Sequential(
            nn.Conv2d(3, 7, kernel_size=(3, 3), padding=1),
            nn.Conv2d(7, 14, kernel_size=(3, 3), padding=1),
            nn.Conv2d(14, 3, kernel_size=(3, 3), padding=1),
            # Assumption: pool to 8x8 so the flattened features match feature_dim = 3 * 8 * 8 = 192
            nn.AdaptiveAvgPool2d(8),
        )
        self.feature_dim = 192
        self.output_dim = 96  # Enough parameters for a 16-dim inner neural network (2 * 16 * 3)
        self.weight_size = self.output_dim // 2

        # Maps the image features to the weights of the inner network
        self.linear = nn.Sequential(
            nn.Linear(self.feature_dim, hidden_dim),
            nn.Linear(hidden_dim, self.output_dim),
        )

    def inner_network(self, image, params):
        batch, c, h, w = image.size()
        # Split parameters for each layer of the inner network
        w_1, w_2 = params[:, :self.weight_size], params[:, self.weight_size:]
        hidden_dim_inner = self.weight_size // c
        # Reshape into (hidden_dim x channels) and (channels x hidden_dim) weight matrices
        w_1, w_2 = w_1.view(-1, hidden_dim_inner, c), w_2.view(-1, c, hidden_dim_inner)
        # Apply the batch of per-image networks to every pixel
        pixels = image.permute(0, 2, 3, 1).contiguous().view(batch, -1, c)
        output = torch.selu(pixels.bmm(w_1.permute(0, 2, 1)))
        output = output.bmm(w_2.permute(0, 2, 1)).view(-1, h, w, c).permute(0, 3, 1, 2)
        return output

    def forward(self, input_image, full_resolution_image):
        batch = input_image.size(0)
        # Predict the inner network's parameters from the downsampled input image or patch
        params = self.linear(self.parameter_network(input_image).view(batch, self.feature_dim))
        # Apply the pixel-wise network to the full-resolution image
        output = self.inner_network(full_resolution_image, params)
        # Clamp to [0, 1]
        return torch.clamp(output, 0, 1)
```

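As a quick sanity check on the "4k parameters" figure, the snippet below tallies the learnable layers of `WhiteBalance` with the same shapes as above. Pooling or activation layers, if present, contribute no parameters, so the total comes to roughly 4.5k.

```python
import torch.nn as nn

# Same learnable layer shapes as WhiteBalance above
layers = [
    nn.Conv2d(3, 7, kernel_size=3, padding=1),
    nn.Conv2d(7, 14, kernel_size=3, padding=1),
    nn.Conv2d(14, 3, kernel_size=3, padding=1),
    nn.Linear(192, 10),
    nn.Linear(10, 96),
]
# Weights plus biases for each layer
counts = [sum(p.numel() for p in layer.parameters()) for layer in layers]
print(counts)       # [196, 896, 381, 1930, 1056]
print(sum(counts))  # 4459
```

Most of the budget sits in the first linear layer (192 × 10 weights), so shrinking `feature_dim` is the cheapest way to cut parameters further.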

A little about the architecture: the model takes a 32x32 sRGB image as input and outputs the parameters (96 values) for a small, bias-free, two-layer neural network with a hidden dimension of 16 and an input and output size of 3. In other words, it learns to output a pixel-wise transformation that can be applied to an image of any resolution.
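Because the inner network is just two bias-free matrix multiplies with a SELU in between, it can be written as a standalone function that accepts an image of any size. The sketch below mirrors `WhiteBalance.inner_network` but uses `torch.einsum`; the name `apply_pixelwise_mlp` is hypothetical, and the 96-value layout (first half for the 3→16 layer, second half for the 16→3 layer) follows the code above.

```python
import torch


def apply_pixelwise_mlp(image: torch.Tensor, params: torch.Tensor, hidden: int = 16) -> torch.Tensor:
    """Apply a per-image, bias-free c -> hidden -> c MLP to every pixel.

    image:  (batch, c, h, w), any h and w
    params: (batch, 2 * hidden * c), e.g. 96 values for c=3, hidden=16
    """
    batch, c, h, w = image.shape
    half = hidden * c
    w_1 = params[:, :half].view(batch, hidden, c)  # first layer: c -> hidden
    w_2 = params[:, half:].view(batch, c, hidden)  # second layer: hidden -> c
    # Flatten pixels: (batch, h * w, c)
    pixels = image.permute(0, 2, 3, 1).reshape(batch, h * w, c)
    # Each image's pixels only see that image's own weights
    hidden_act = torch.selu(torch.einsum("bpc,bkc->bpk", pixels, w_1))
    out = torch.einsum("bpk,bck->bpc", hidden_act, w_2)
    return out.reshape(batch, h, w, c).permute(0, 3, 1, 2)


torch.manual_seed(0)
params = torch.randn(2, 96)            # e.g. predicted from 32x32 thumbnails
full_res = torch.rand(2, 3, 480, 640)  # the same parameters apply at any resolution
print(apply_pixelwise_mlp(full_res, params).shape)  # torch.Size([2, 3, 480, 640])
```

Since every pixel is transformed independently, the result at a given pixel is the same whether the network sees the whole image or just that one pixel, which is what makes the resolution-independence possible.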

There are a lot of benefits to this.

  • We can easily pass the parameters to a WebGL shader.
  • We can bake the transform into a 3D lookup table (LUT) by running every color on a regular RGB grid through the inner network, which lets us apply the same white balancing to a video in real time.
  • It is quite flexible in terms of which nonlinear pixel-wise transforms we can learn.
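The LUT idea can be sketched as follows: evaluate the per-image transform once on a grid of size³ RGB colours, then look colours up instead of re-running the network for every frame. This is a minimal sketch under assumptions: the helper names `bake_lut` and `apply_lut_nearest` are hypothetical, `example_transform` is a stand-in for one image's inner network, and the lookup is nearest-neighbour for simplicity (a production version, e.g. in a shader, would normally interpolate trilinearly between grid points).

```python
import torch


def bake_lut(transform, size: int = 17) -> torch.Tensor:
    """Evaluate a pixel-wise colour transform on a size^3 RGB grid.

    transform: maps an (n, 3) tensor of RGB values in [0, 1] to (n, 3).
    Returns a LUT of shape (size, size, size, 3), indexed as lut[r, g, b].
    """
    levels = torch.linspace(0.0, 1.0, size)
    r, g, b = torch.meshgrid(levels, levels, levels, indexing="ij")
    grid = torch.stack([r, g, b], dim=-1).reshape(-1, 3)
    with torch.no_grad():
        return transform(grid).reshape(size, size, size, 3)


def apply_lut_nearest(image: torch.Tensor, lut: torch.Tensor) -> torch.Tensor:
    """Apply a baked LUT to an image shaped (3, h, w) using nearest-neighbour lookup."""
    size = lut.shape[0]
    # Snap each channel value to the nearest grid index
    idx = (image.clamp(0, 1) * (size - 1)).round().long()
    # Advanced indexing gives (h, w, 3); move channels back to the front
    return lut[idx[0], idx[1], idx[2]].permute(2, 0, 1)


def example_transform(rgb: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for one image's inner network: a fixed per-channel shift
    return (rgb * torch.tensor([1.1, 1.0, 0.9])).clamp(0, 1)


lut = bake_lut(example_transform, size=17)
frame = torch.rand(3, 120, 160)
corrected = apply_lut_nearest(frame, lut)
```

The LUT is baked once per set of predicted parameters, so every video frame afterwards only costs a table lookup per pixel.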

As far as training goes, the dataset consists of before-and-after pairs of photos (an uncorrected image and its white-balanced target), trained with an MAE (L1) loss or similar. I noticed performance gains from computing the loss at a higher resolution than the input to the parameter network: for example, the parameter network might take a 32x32 image as input while the loss is computed on a 224x224 version. It seems to have a regularisation effect.

Think something like this...

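```python
import torch.nn.functional as F

# input_image: 32x32 input to the parameter network
# larger_image, larger_target: e.g. 224x224 versions used for the loss
input_image, larger_image, larger_target = batch
output = model(input_image, larger_image)
loss = F.l1_loss(output, larger_target)  # MAE loss
```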

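One way to produce the `(input_image, larger_image, larger_target)` triple from a full-resolution before/after pair is to resample the same photo twice. This is a sketch under assumptions: `make_training_triple` is a hypothetical helper, the 32/224 sizes come from the text, and area (averaging) resampling to square sizes is my choice rather than a documented part of the method.

```python
import torch
import torch.nn.functional as F


def make_training_triple(photo: torch.Tensor, target: torch.Tensor, small: int = 32, large: int = 224):
    """Build (input_image, larger_image, larger_target) from a full-resolution pair.

    photo, target: (batch, 3, H, W) uncorrected / white-balanced images in [0, 1].
    Assumption: area resampling for both sizes; aspect ratio is not preserved.
    """
    # Low-resolution view that the parameter network sees
    input_image = F.interpolate(photo, size=(small, small), mode="area")
    # Higher-resolution views on which the loss is computed
    larger_image = F.interpolate(photo, size=(large, large), mode="area")
    larger_target = F.interpolate(target, size=(large, large), mode="area")
    return input_image, larger_image, larger_target


photo = torch.rand(2, 3, 480, 640)
target = torch.rand(2, 3, 480, 640)
input_image, larger_image, larger_target = make_training_triple(photo, target)
```

Because the inner network is pixel-wise, the parameters predicted from the 32x32 view can be applied unchanged to the 224x224 view, and the L1 loss is then taken there.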

Here are some samples (ignore the duplicates; they test slight input variations).

There are ways to speed this up and reduce the parameter count further if you need an option for a low-energy device. For example, you could apply channel-wise global average pooling to the output of the convolutional parameter network (perhaps with a learned weighting?). That would let you remove the linear layers as well as the inner neural network.
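One possible reading of this reduced variant, sketched under assumptions: keep the convolutional parameter network, average each of its 3 output channels over all spatial positions, and use the three pooled values directly as per-channel gains on the full-resolution image. The class name `ChannelGainWhiteBalance` and the `exp()` that keeps the gains positive are my assumptions, not the author's design, and a learned spatial weighting could replace the plain mean. With only the three convolutions left, this has 196 + 896 + 381 = 1,473 parameters.

```python
import torch
import torch.nn as nn


class ChannelGainWhiteBalance(nn.Module):
    """Hypothetical reduced variant: conv stack -> channel-wise average pool -> 3 gains."""

    def __init__(self):
        super().__init__()
        self.parameter_network = nn.Sequential(
            nn.Conv2d(3, 7, kernel_size=3, padding=1),
            nn.Conv2d(7, 14, kernel_size=3, padding=1),
            nn.Conv2d(14, 3, kernel_size=3, padding=1),
        )

    def forward(self, input_image, full_resolution_image):
        # Average each output channel over all spatial positions: (batch, 3)
        log_gains = self.parameter_network(input_image).mean(dim=(2, 3))
        # Assumption: treat pooled values as log-gains so the gains stay positive
        gains = torch.exp(log_gains)[:, :, None, None]
        # Per-channel scaling replaces the inner neural network
        return torch.clamp(full_resolution_image * gains, 0, 1)


model = ChannelGainWhiteBalance()
out = model(torch.rand(1, 3, 32, 32), torch.rand(1, 3, 480, 640))
```

The trade-off is expressiveness: three gains can only rescale channels, whereas the inner network can learn nonlinear, cross-channel corrections.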

But I prefer the flexibility of a global nonlinear pixel-wise transform for my use case.

Thanks for reading! If you're interested in this sort of stuff, I will be posting quite often on my Twitter.
