
That `overloaded` Trick: Overloading Lambdas in C++17

Tamir Bahar (tmr232) ・4 min read

C++17 gave us std::variant. Simply put, it is a type-safe union. To access the value it stores, you can either request a specific type (using std::get or something similar) or "visit" the variant, automatically handling only the data type that is actually there.
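For comparison with visiting, here is a minimal sketch of the "request a specific type" route. The helper name describe and its output strings are assumptions made for illustration; std::holds_alternative, std::get and std::get_if are the standard library facilities.

```cpp
#include <string>
#include <variant>

using var_t = std::variant<int, const char*>;

// Hypothetical helper: inspects the variant by asking for specific types.
std::string describe(const var_t& v) {
    // std::holds_alternative checks which type is currently stored.
    if (std::holds_alternative<int>(v)) {
        // std::get throws std::bad_variant_access on a type mismatch,
        // so it is only called after the check above.
        return "int: " + std::to_string(std::get<int>(v));
    }
    // std::get_if returns a pointer to the value, or nullptr on a mismatch.
    if (const auto* str = std::get_if<const char*>(&v)) {
        return std::string("string: ") + *str;
    }
    return "unknown";
}
```

With this approach every branch is written by hand, and nothing forces all alternatives to be covered.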
Visiting is done using std::visit, and is fairly straightforward.


#include <variant>
#include <cstdio>
#include <vector>

using var_t = std::variant<int, const char*>; // (1)

struct Print { // (2)
    void operator() (int i) {
        printf("%d\n", i);
    }

    void operator () (const char* str) {
        puts(str);
    }
};

int main() {
    std::vector<var_t> vars = {1, 2, "Hello, World!"}; // (3)

    for (auto& v : vars) {
        std::visit(Print{}, v); // (4)
    }

    return 0;
}

In (1) we define our variant type. In (2) we define a class with an overloaded operator(). This is needed for the call to std::visit. In (3) we define a vector of variants. In (4) we visit each variant. We pass in an instance of Print, and overload resolution ensures that the correct overload will be called for every type.
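std::visit also hands back whatever the selected overload returns, which the Print example does not show. A minimal sketch, assuming a hypothetical visitor named Describe (all of its overloads need the same return type):

```cpp
#include <string>
#include <variant>

using var_t = std::variant<int, const char*>;

// Hypothetical visitor: like Print, but each overload returns a value.
struct Describe {
    std::string operator()(int i) const {
        return "int: " + std::to_string(i);
    }
    std::string operator()(const char* str) const {
        return std::string("string: ") + str;
    }
};

// std::visit forwards the chosen overload's return value to the caller.
std::string describe(const var_t& v) {
    return std::visit(Describe{}, v);
}
```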
But this example forces us to write and name an object for the overloaded operator(). We can do better. In fact, the example for std::visit on cppreference already does. Here is an example derived from it:


#include <variant>
#include <cstdio>
#include <vector>

template<class... Ts> struct overloaded : Ts... { using Ts::operator()...; }; // (1)
template<class... Ts> overloaded(Ts...) -> overloaded<Ts...>;  // (2)

using var_t = std::variant<int, const char*>;

int main() {
    std::vector<var_t> vars = {1, 2, "Hello, World!"};

    for (auto& v : vars) {
        std::visit(overloaded {  // (3)
            [](int i) { printf("%d\n", i); },
            [](const char* str) { puts(str); }
        }, v);
    }

    return 0;
}

This is certainly more compact, and we removed the Print struct. But how does it work? You can see a class-template (1), lambdas passed in as arguments for the construction (3), and something with an arrow and some more template magic (2). Let's build it step by step.

First, we want to break the print functions out of Print and compose them later.


struct PrintInt { // (1)
    void operator() (int i) {
        printf("%d\n", i);
    }
};

struct PrintCString { // (2)
    void operator () (const char* str) {
        puts(str);
    }
};

struct Print : PrintInt, PrintCString { // (3)
    using PrintInt::operator();
    using PrintCString::operator();
};

In (1) and (2), we define the same operators as before, but in separate structs. In (3), we inherit from both of those structs, then explicitly bring in their operator() with using-declarations. Without them, a call would be ambiguous, because operator() would be found in two different base classes. The result behaves exactly as before. Next, we convert Print into a class template. I'll jump ahead and convert it directly to a variadic template.


template <class... Ts> // (1)
struct Print : Ts... {
    using Ts::operator()...;
};

int main() {
    std::vector<var_t> vars = {1, 2, "Hello, World!"};

    for (auto& v : vars) {
        std::visit(Print<PrintCString, PrintInt>{}, v); // (2)
    }

    return 0;
}

In (1) we define the template. We take an arbitrary number of classes, inherit from them, and use their operator(). In (2) we instantiate the Print class-template with PrintCString and PrintInt to get their functionality.
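To see what the using-declaration contributes, the following sketch compares two hypothetical templates, NoUsing and WithUsing, which differ only in that line. The operators return an int tag instead of printing so the result can be checked; those names and return values are assumptions made for illustration.

```cpp
#include <type_traits>

struct PrintInt {
    int operator()(int) const { return 1; }
};

struct PrintCString {
    int operator()(const char*) const { return 2; }
};

// Hypothetical variant of Print without the using-declaration.
template <class... Ts>
struct NoUsing : Ts... {};

// Same shape as Print in the article.
template <class... Ts>
struct WithUsing : Ts... {
    using Ts::operator()...;
};

// Name lookup finds operator() in two different bases, so the call is
// ambiguous and the type is not invocable at all.
static_assert(!std::is_invocable_v<NoUsing<PrintInt, PrintCString>, int>);

// The using-declarations bring both operators into one scope, where
// ordinary overload resolution picks between them.
static_assert(std::is_invocable_v<WithUsing<PrintInt, PrintCString>, int>);
static_assert(std::is_invocable_v<WithUsing<PrintInt, PrintCString>, const char*>);

// Returns 1 for the int overload and 2 for the const char* overload.
int which(int i) { return WithUsing<PrintInt, PrintCString>{}(i); }
int which(const char* s) { return WithUsing<PrintInt, PrintCString>{}(s); }
```

The pack expansion inside a using-declaration (using Ts::operator()...;) is itself a C++17 feature.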
Next, we want to use lambdas to do the same. This is possible because lambdas are not functions; they are objects implementing operator().


int main() {
    std::vector<var_t> vars = {1, 2, "Hello, World!"};
    auto PrintInt = [](int i) { printf("%d\n", i); }; // (1)
    auto PrintCString = [](const char* str) { puts(str); };

    for (auto& v : vars) {
        std::visit(
            Print<decltype(PrintCString), decltype(PrintInt)>{PrintCString, PrintInt}, // (2)
            v);
    }

    return 0;
}

In (1) we define the lambdas we need. In (2) we instantiate the template with our lambdas. This is ugly. Since every lambda has a unique type, we need to define the lambdas before using them as template arguments (naming their types with decltype). Then, we need to pass the lambdas as arguments for aggregate initialization, as lambdas in C++17 have a deleted default constructor. We are close, but not quite there yet.
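The two properties of lambdas used here can be checked directly. This is a small sketch; the names first, second, Wrap and call_wrapped are hypothetical and exist only for this illustration.

```cpp
#include <type_traits>

// Two lambdas with identical bodies, named only for this illustration.
auto first = [](int i) { return i + 1; };
auto second = [](int i) { return i + 1; };

// Each lambda expression has its own closure type...
static_assert(!std::is_same_v<decltype(first), decltype(second)>);
// ...and that type is a class, which is why Print can inherit from it.
static_assert(std::is_class_v<decltype(first)>);

// Hypothetical wrapper: inherits the call operator from a single lambda.
template <class T>
struct Wrap : T {
    using T::operator();
};

int call_wrapped(int i) {
    // The template argument has to be spelled with decltype, and the
    // lambda object itself is passed in via aggregate initialization.
    Wrap<decltype(first)> w{first};
    return w(i);
}
```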
The <decltype(PrintCString), decltype(PrintInt)> part is really ugly, and causes repetition. But it is needed, as the compiler cannot deduce Print's template arguments from the initializer on its own. So, in proper C++ style, we will create a function to circumvent that.


template <class... Ts> // (1)
auto MakePrint(Ts... ts) {
    return Print<Ts...>{ts...};
}

int main() {
    std::vector<var_t> vars = {1, 2, "Hello, World!"};

    for (auto& v : vars) {
        std::visit(
            MakePrint( // (2)
                [](const char* str) { puts(str); },
                [](int i) { printf("%d\n", i); }
                ),
            v);
    }

    return 0;
}

In (1) we define our helper function, which deduces the lambda types and passes them on to Print. In (2) we take advantage of our newfound type deduction to define the lambdas inline. But this is C++17, and we can do better.

C++17 added user-defined deduction guides. Those allow us to instruct the compiler to perform the same actions as our helper function, but without adding another function. Using a suitable deduction guide, the code is as follows.


#include <variant>
#include <cstdio>
#include <vector>

using var_t = std::variant<int, const char*>;

template <class... Ts>
struct Print : Ts... {
    using Ts::operator()...;
};

template <class...Ts> Print(Ts...) -> Print<Ts...>; // (1)

int main() {
    std::vector<var_t> vars = {1, 2, "Hello, World!"};

    for (auto& v : vars) {
        std::visit(
            Print{ // (2)
                [](const char* str) { puts(str); },
                [](int i) { printf("%d\n", i); }
            },
            v);
    }

    return 0;
}

In (1) we define a deduction guide which acts as our previous helper function, and in (2) we initialize Print directly instead of calling a helper function. Done.
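A deduction guide may be easier to read in isolation. The sketch below uses a hypothetical aggregate named Box, which is not part of the original example, to show that a guide only tells the compiler which specialization to pick; the initialization itself still happens as usual.

```cpp
#include <string>
#include <type_traits>

// Hypothetical aggregate used only to show a deduction guide in isolation.
template <class T>
struct Box {
    T value;
};

// "When Box is initialized from a T, deduce Box<T>."
template <class T>
Box(T) -> Box<T>;

// A guide may also map to a different type than the argument's.
Box(const char*) -> Box<std::string>;

Box int_box{42};        // deduced as Box<int>
Box text_box{"hello"};  // deduced as Box<std::string>

static_assert(std::is_same_v<decltype(int_box), Box<int>>);
static_assert(std::is_same_v<decltype(text_box), Box<std::string>>);
```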

Now we have fully recreated the original example. As Print is no longer indicative of the class template's behavior, overloaded is probably a better name.

template<class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
template<class... Ts> overloaded(Ts...) -> overloaded<Ts...>;
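As a usage sketch, overloaded also combines well with return values and a generic fallback lambda. The value_t alias, the describe function and the output strings below are assumptions made for illustration.

```cpp
#include <string>
#include <variant>

template<class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
template<class... Ts> overloaded(Ts...) -> overloaded<Ts...>;

// Hypothetical variant with one more alternative than the article's var_t.
using value_t = std::variant<int, const char*, double>;

std::string describe(const value_t& v) {
    return std::visit(overloaded{
        [](int i) { return "int: " + std::to_string(i); },
        [](const char* str) { return std::string("string: ") + str; },
        // Generic fallback: as a template it loses ties against the
        // non-template overloads above, so here it handles double.
        [](const auto&) { return std::string("something else"); }
    }, v);
}
```

All three lambdas return std::string, since std::visit needs every overload to produce the same type.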


Discussion


It would be nice to mention why we changed () to {} when we switched to the deduction guide. E.g., if you add a constructor

overloaded(Ts... ts) : Ts(ts)... {}

then you can use () again. At least I was curious about this.

 

I did mention that we need to use aggregate initialization. Though I now think my reasoning for it is wrong. Also - I never tried making it work with ().

 

This also works without deduction guide:

template<class... Ts> struct overload : Ts... {
  overload(Ts...) = delete;
  using Ts::operator()...;
};

I'm confused with this part:

        Print{ // (2)
            [](const char* str) { puts(str); },
            [](int i) { printf("%d\n", i); }
        }

The braces indicate a constructor call, but there is none. There is a non-member helper template though, but it has no body. So, what's happening here?

It's not a non-member helper template, it's a user defined deduction guide. See this article:

arne-mertz.de/2017/06/class-templa...

It works also without deduction guide and the 'using' fix is for g++. 'operator()' cannot be ambiguous in visitors but gcc is very careful with overloading derived functions. In clang we can easily have the same results with minimal code (checked output assembler with Compiler Explorer):

template<class... Ts> struct overload : Ts... {
  overload(Ts...) = delete;
};

It needs aggregate initialization plus variadic templates.

 

Very cool code transformations. :) Thank you.

 

Excellent breakdown, ta!

 

That's just... weird...
PS: (...) this in the comment is not fold expression

 

What's weird?

 

Where the standard is going to :) nevertheless the implementation is cool!

 

(2) and (3) in second code snippet are mixed up.

 

Fixed. Thank you!