DEV Community

Wu Haotian

Posted on • Originally published at blog.whtsky.me on

Decorator type gymnastics in Python

You need to add type hints to your decorator

Say you have a simple decorator for adding logging before calling a function.

import logging

def add_log(f):
    def wrapper(*args, **kwargs):
        logging.info("Called f!")
        return f(*args, **kwargs)
    return wrapper


@add_log
def two_sum(a, b):
    return a + b
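A decorator is just a function applied when the decorated function is defined, so `@add_log` above is shorthand for `two_sum = add_log(two_sum)`. The sketch below does that rebinding by hand. The `logging.basicConfig` call is an addition for the demo: `logging.info` messages are hidden by default because the root logger's level is WARNING.

```python
import logging

# Demo-only addition: INFO messages are hidden unless the root level is lowered
logging.basicConfig(level=logging.INFO)


def add_log(f):
    def wrapper(*args, **kwargs):
        logging.info("Called f!")
        return f(*args, **kwargs)
    return wrapper


def two_sum(a, b):
    return a + b


# "@add_log" above a def is equivalent to rebinding the name by hand:
two_sum = add_log(two_sum)

print(two_sum(1, 2))     # logs "Called f!", then prints 3
print(two_sum.__name__)  # wrapper: the name now points at the inner function
```

That last line is a first hint of the problem this post is about: whatever `add_log` returns replaces the original function, including everything a type checker knows about it.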

One day you decide to add type hints to this module. Adding them to two_sum is easy:

def two_sum(a: int, b: int) -> int:
    return a + b

But you also need to add type hints to your decorator (add_log in this case); otherwise the decorated function's type becomes Any. You can verify this with mypy's reveal_type:

import logging

def add_log(f):
    def wrapper(*args, **kwargs): # type: ignore
        logging.info("Called f!")
        return f(*args, **kwargs)
    return wrapper


def two_sum(a: int, b: int) -> int:
    return a + b

reveal_type(two_sum)  # Revealed type is 'def (a: builtins.int, b: builtins.int) -> builtins.int'
reveal_type(add_log(two_sum))  # Revealed type is 'Any'
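mypy checks types statically, but a runtime analogy helps build intuition: `typing.get_type_hints` reads a function's annotations, and the wrapper returned by the untyped add_log simply has none. This is only an illustration of the information loss, not how mypy computes the `Any` it reports.

```python
import logging
from typing import get_type_hints


def add_log(f):
    def wrapper(*args, **kwargs):  # type: ignore
        logging.info("Called f!")
        return f(*args, **kwargs)
    return wrapper


def two_sum(a: int, b: int) -> int:
    return a + b


# The original function carries its annotations...
print(get_type_hints(two_sum))           # {'a': <class 'int'>, 'b': <class 'int'>, 'return': <class 'int'>}
# ...but the wrapper returned by add_log has none
print(get_type_hints(add_log(two_sum)))  # {}
```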

Simple type hints for simple decorators

Let's add type hints to this simple decorator. For a decorator that doesn't change the function's arguments or return value (like add_log above), a TypeVar bound to Callable does the job well: the decorator promises to return exactly the type it was given.

from typing import TypeVar, Callable, cast
import logging

TCallable = TypeVar("TCallable", bound=Callable)

def add_log(f: TCallable) -> TCallable:
    def wrapper(*args, **kwargs): # type: ignore
        logging.info("Called f!")
        return f(*args, **kwargs)
    return cast(TCallable, wrapper)


@add_log
def two_sum(a: int, b: int) -> int:
    return a + b


reveal_type(two_sum)  # Revealed type is 'def (a: builtins.int, b: builtins.int) -> builtins.int'
reveal_type(add_log(two_sum))  # Revealed type is 'def (a: builtins.int, b: builtins.int) -> builtins.int'
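The cast only fixes the static type: at runtime the returned object is still wrapper, with wrapper's name and no annotations. A common complement, added here and not part of the snippet above, is `functools.wraps`, which copies `__name__`, `__doc__`, `__annotations__` and similar attributes from f onto wrapper. A sketch combining both:

```python
import functools
import logging
from typing import Any, Callable, TypeVar, cast, get_type_hints

TCallable = TypeVar("TCallable", bound=Callable[..., Any])


def add_log(f: TCallable) -> TCallable:
    # functools.wraps fixes the runtime metadata (name, docstring, annotations)
    @functools.wraps(f)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        logging.info("Called f!")
        return f(*args, **kwargs)
    # cast fixes the static type: mypy treats wrapper as having f's type
    return cast(TCallable, wrapper)


@add_log
def two_sum(a: int, b: int) -> int:
    return a + b


print(two_sum(2, 3))            # 5
print(two_sum.__name__)         # two_sum
print(get_type_hints(two_sum))  # {'a': <class 'int'>, 'b': <class 'int'>, 'return': <class 'int'>}
```

The two fixes are independent: `wraps` does not change what mypy sees, so the cast is still what keeps the precise signature.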

Hard type hints for hard decorators

But what if your decorator changes the arguments and/or the return value? There is no easy way to type the arguments, but you can at least type the return value correctly:

from typing import TypeVar, Callable, Awaitable, cast
R = TypeVar('R')
def sync_to_async(f: Callable[..., R]) -> Callable[..., Awaitable[R]]:
    async def wrapper(*args, **kwargs): # type: ignore
        return f(*args, **kwargs)
    return cast(Callable[..., Awaitable[R]], wrapper)

@sync_to_async
def two_sum(a: int, b: int) -> int:
    return a + b

reveal_type(two_sum)  # Revealed type is 'def (*Any, **Any) -> typing.Awaitable[builtins.int*]'
reveal_type(two_sum(2, "3"))  # Revealed type is 'typing.Awaitable[builtins.int*]' (the bad "3" argument is not caught)
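To see both sides of this trade-off at runtime, here is a sketch that runs the decorated function with `asyncio` (the `main` coroutine is an addition for the demo). Annotating wrapper itself should also make the cast unnecessary, since an `async def` returning R produces a coroutine, which is an `Awaitable[R]`. The mistyped call is accepted by the checker and only fails once the coroutine actually runs.

```python
import asyncio
from typing import Any, Awaitable, Callable, TypeVar

R = TypeVar("R")


def sync_to_async(f: Callable[..., R]) -> Callable[..., Awaitable[R]]:
    # Annotated wrapper: returns a coroutine, which satisfies Awaitable[R]
    async def wrapper(*args: Any, **kwargs: Any) -> R:
        return f(*args, **kwargs)
    return wrapper


@sync_to_async
def two_sum(a: int, b: int) -> int:
    return a + b


async def main() -> None:
    print(await two_sum(1, 2))  # 3
    try:
        # The parameters are typed as "...", so this call type-checks;
        # the error only surfaces when the coroutine runs.
        await two_sum(2, "3")
    except TypeError as exc:
        print("TypeError:", exc)


asyncio.run(main())
```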

The dark way for typing arguments & return values

You can do some type gymnastics: generate a number of TypeVars and use overload to cover each number of arguments.

from typing import TypeVar, Callable, Awaitable, overload

A = TypeVar('A')
B = TypeVar('B')
C = TypeVar('C')
D = TypeVar('D')
E = TypeVar('E')
RV = TypeVar('RV')
@overload
def sync_to_async(f: Callable[[A], RV]) -> Callable[[A], Awaitable[RV]]:
    ...
@overload
def sync_to_async(f: Callable[[A, B], RV]) -> Callable[[A, B], Awaitable[RV]]:
    ...
@overload
def sync_to_async(f: Callable[[A, B, C], RV]) -> Callable[[A, B, C], Awaitable[RV]]:
    ...
@overload
def sync_to_async(f: Callable[[A, B, C, D], RV]) -> Callable[[A, B, C, D], Awaitable[RV]]:
    ...
@overload
def sync_to_async(f: Callable[[A, B, C, D, E], RV]) -> Callable[[A, B, C, D, E], Awaitable[RV]]:
    ...
def sync_to_async(f):  # runtime implementation; mypy only uses the overloads above
    async def wrapper(*args, **kwargs):
        return f(*args, **kwargs)
    return wrapper

@sync_to_async
def two_sum(a: int, b: int) -> int:
    return a + b

@sync_to_async
def do_log(content: str) -> None:
    print(content)

reveal_type(two_sum)  # Revealed type is 'def (builtins.int*, builtins.int*) -> typing.Awaitable[builtins.int*]'
reveal_type(do_log)  # Revealed type is 'def (builtins.str*) -> typing.Awaitable[None]'
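Two limits of this approach are worth noting. The overloads above start at one argument, so a function with no parameters matches none of them. And `Callable[[A, B], RV]` describes positional parameters only, which is why the revealed types lost the parameter names; a call like `two_sum(a=1, b=2)` would be rejected by mypy even though it works at runtime. A sketch adding the zero-argument case, trimmed to two arities for brevity (`answer` is a hypothetical example function):

```python
import asyncio
from typing import Awaitable, Callable, TypeVar, overload

A = TypeVar("A")
B = TypeVar("B")
RV = TypeVar("RV")


@overload
def sync_to_async(f: Callable[[], RV]) -> Callable[[], Awaitable[RV]]: ...
@overload
def sync_to_async(f: Callable[[A], RV]) -> Callable[[A], Awaitable[RV]]: ...
@overload
def sync_to_async(f: Callable[[A, B], RV]) -> Callable[[A, B], Awaitable[RV]]: ...
def sync_to_async(f):  # runtime implementation shared by all overloads
    async def wrapper(*args, **kwargs):
        return f(*args, **kwargs)
    return wrapper


@sync_to_async
def answer() -> int:
    # No parameters: only the Callable[[], RV] overload can match
    return 42


@sync_to_async
def two_sum(a: int, b: int) -> int:
    return a + b


print(asyncio.run(answer()))       # 42
print(asyncio.run(two_sum(1, 2)))  # 3
```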

If you insist on going this way, here's the snippet I used to generate the code above:

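gymnastics = 5  # maximum number of arguments to support

# One TypeVar per argument position: A, B, C, ...
for i in range(gymnastics):
    char = chr(i + 65)  # 65 is ord("A")
    print(f"{char} = TypeVar('{char}')")

print("RV = TypeVar('RV')")

# One @overload per arity, from 1 up to `gymnastics` arguments
for i in range(gymnastics):
    chars = [chr(n + 65) for n in range(i + 1)]
    args = ", ".join(chars)
    print(f"""@overload
def sync_to_async(f: Callable[[{args}], RV]) -> Callable[[{args}], Awaitable[RV]]:
    ...""")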
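If you regenerate these overloads often, it may be handier to wrap the snippet in a function that returns the text, so it can be written straight into a module. `generate_overloads` below is a hypothetical helper, not something from the original post; it also emits the zero-argument overload that the snippet above skips.

```python
def generate_overloads(name: str, max_arity: int) -> str:
    """Return TypeVar definitions plus @overload stubs for 0..max_arity arguments."""
    letters = [chr(ord("A") + i) for i in range(max_arity)]
    lines = [f"{c} = TypeVar('{c}')" for c in letters]
    lines.append("RV = TypeVar('RV')")
    for arity in range(max_arity + 1):
        args = ", ".join(letters[:arity])  # "" for the zero-argument case
        lines.append("@overload")
        lines.append(
            f"def {name}(f: Callable[[{args}], RV]) -> Callable[[{args}], Awaitable[RV]]:"
        )
        lines.append("    ...")
    return "\n".join(lines)


print(generate_overloads("sync_to_async", 2))
```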

The Future: PEP-612

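PEP-612 introduces ParamSpec and Concatenate, which make type-hinting decorators much easier. They are available in the typing module from Python 3.10, and in typing_extensions for older versions: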

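from typing import Callable, Concatenate, ParamSpec, TypeVar  # Python 3.10+

P = ParamSpec('P')
R = TypeVar('R')


class Context:
    """Placeholder for whatever object the decorator injects."""


def with_context(f: Callable[Concatenate[Context, P], R]) -> Callable[P, R]:
    def inner(*args: P.args, **kwargs: P.kwargs) -> R:
        context = Context()  # build (or look up) the context to inject
        return f(context, *args, **kwargs)
    return inner


@with_context
def request(context: Context) -> int:
    return 42


request()  # callers no longer pass the context; returns 42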

Unfortunately, PEP-612 is not widely supported yet: at the time of writing, mypy does not fully support it.

Fin

May the type be with you.
