Over the last few months, I have been building an in-browser differentiable photo editor. The goal is for every image processing function to be either easily parameterisable by a machine learning model or a learned transformation in its own right, so that each one can be used in a loss function for downstream tasks. Today I want to share an automatic white balance correction approach, which extends naturally to other global image transformations.
I’ll make no hard claims about state-of-the-art results, as I am primarily looking for a parameter-efficient (4k parameters!) and fast approach that performs better than grey-world and YUV-coordinate-based transforms. However, in my preliminary tests, it holds up well against U-Net-based approaches. Importantly, it can be trained in a couple of hours on a GTX 1070 without any hyperparameter tuning. It is also performant in the browser on a lightweight device, which is my primary use case.
So, let’s jump into the code.
import torch
import torch.nn as nn


def weights_init(m):
    # The original helper isn't shown; a plausible stand-in for the conv layers.
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight)
        nn.init.zeros_(m.bias)


class WhiteBalance(nn.Module):
    def __init__(self, hidden_dim=10):
        super(WhiteBalance, self).__init__()
        self.parameter_network = nn.Sequential(
            nn.Conv2d(3, 7, kernel_size=(3, 3), padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(7, 14, kernel_size=(3, 3), padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(14, 3, kernel_size=(3, 3), padding=1),
        )
        self.feature_dim = 192
        self.output_dim = 96  # Enough parameters for a 16-dim inner neural network
        self.weight_size = self.output_dim // 2
        self.parameter_network.apply(weights_init)
        self.linear = nn.Sequential(
            nn.Linear(self.feature_dim, hidden_dim),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(hidden_dim, self.output_dim),
        )

    def inner_network(self, image, params):
        batch, c, h, w = image.size()
        # Split parameters for each layer of the inner network
        w_1, w_2 = params[:, :self.weight_size], params[:, self.weight_size:]
        hidden_dim_inner = self.weight_size // c
        # Reshape matrix into hidden_dim x channels
        w_1, w_2 = w_1.view(-1, hidden_dim_inner, c), w_2.view(-1, c, hidden_dim_inner)
        # Apply batch of neural networks to image pixelwise
        pixels = image.permute(0, 2, 3, 1).contiguous().view(batch, -1, c)
        output = torch.selu(pixels.bmm(w_1.permute(0, 2, 1)))
        output = output.bmm(w_2.permute(0, 2, 1)).view(-1, h, w, c).permute(0, 3, 1, 2)
        return output

    def forward(self, input_image, full_resolution_image):
        batch = input_image.size(0)
        # Output parameters to a small inner neural network
        # using a downsampled input image or patch.
        params = self.linear(self.parameter_network(input_image).view(batch, self.feature_dim))
        # Apply the pixelwise neural network to the full-scale image
        output = self.inner_network(full_resolution_image, params)
        # Scale to [0, 1]
        return torch.clamp(output, 0, 1)
So, a little about the architecture: the model takes a 32x32 sRGB image as input and outputs the parameters for a small two-layer neural network with a hidden dimension of 16 and an input and output size of 3. In other words, we learn to output a pixel-wise transformation that can be applied to an image of any resolution.
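As a quick sanity check of the shapes and the parameter count (the image sizes here are just examples, not anything the model requires beyond the 32x32 input):

import torch

model = WhiteBalance()
print(sum(p.numel() for p in model.parameters()))  # total parameter count

small = torch.rand(2, 3, 32, 32)   # downsampled input to the parameter network
full = torch.rand(2, 3, 512, 512)  # the image we actually want to correct
out = model(small, full)
print(out.shape)                   # torch.Size([2, 3, 512, 512])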
There are a lot of benefits to this.
- We can easily pass the parameters to a WebGL shader.
- We can bake the transform into a 3D lookup table (LUT) by feeding every colour in a coarse RGB grid through the inner network once, allowing us to apply the same white balancing over a video in realtime (see the sketch after this list).
- It is pretty flexible in terms of what type of nonlinear pixel-wise transforms we can learn.
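Here is a sketch of the LUT idea, using a hypothetical bake_lut helper (the name, the lut_size default, and the grid layout are my assumptions; model is the WhiteBalance module above and input_image is a single 32x32 image):

import torch

def bake_lut(model, input_image, lut_size=33):
    # Predict the inner-network parameters once from the downsampled image.
    params = model.linear(
        model.parameter_network(input_image).view(1, model.feature_dim)
    )
    # Build a lut_size^3 grid of RGB values in [0, 1] and treat it as one "image".
    coords = torch.linspace(0.0, 1.0, lut_size)
    r, g, b = torch.meshgrid(coords, coords, coords, indexing="ij")
    grid = torch.stack([r, g, b], dim=0).reshape(1, 3, lut_size, -1)
    with torch.no_grad():
        lut = torch.clamp(model.inner_network(grid, params), 0, 1)
    # One corrected RGB triple per grid entry, ready to upload as a 3D texture.
    return lut.reshape(3, -1).t()

At playback time, each frame then only needs one interpolated texture lookup per pixel, which is what keeps video realtime.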
As far as training goes, the dataset would be before-and-after pairs of white-balanced photos with an MAE loss or similar. I noticed performance gains from computing the loss at a higher resolution than the input to the parameter network: for example, the parameter network sees a 32x32 image, but the loss is computed on a 224x224 version. It seems to have a regularising effect.
Think something like this...
import torch.nn.functional as F

input_image, larger_image, larger_target = batch
output = model(input_image, larger_image)
loss = F.l1_loss(output, larger_target)
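Wired into a full training step, that might look like the following (the Adam optimiser, learning rate, and loader are my choices for illustration, not part of the original setup):

from torch.optim import Adam

optimizer = Adam(model.parameters(), lr=1e-3)

for batch in loader:  # yields (input_image, larger_image, larger_target)
    input_image, larger_image, larger_target = batch
    output = model(input_image, larger_image)
    loss = F.l1_loss(output, larger_target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()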
There are ways to speed this up and reduce the parameter count further if you are looking for an option for a low-energy device. For example, you could do a channel-wise average pooling over the convolutional output of the parameter network (perhaps with a learned weighting?). That would let you remove the linear layers as well as the inner neural network.
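My reading of that variant, as a sketch (the class name and the per-channel-gain interpretation are my assumptions):

import torch
import torch.nn as nn

class TinyWhiteBalance(nn.Module):
    def __init__(self, parameter_network):
        super().__init__()
        self.parameter_network = parameter_network  # the 3-channel conv stack from above
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, input_image, full_resolution_image):
        # Average each output channel to a single value and use it as a gain,
        # so no linear layers and no inner network are needed.
        gains = self.pool(self.parameter_network(input_image))  # (batch, 3, 1, 1)
        return torch.clamp(full_resolution_image * gains, 0, 1)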
But I prefer the flexibility of a global nonlinear pixel-wise transform for my use case.
Thanks for reading! If you're interested in this sort of thing, I'll be posting quite often to my Twitter.