If you're totally unfamiliar with type annotations in Python, my previous article should get you started.
In this post, I'm going to show you how to use type variables, or TypeVars, for fun and profit.
The problem
This function accepts anything as the argument and returns it as is. How do you explain to the type checker that the return type is the same as the type of arg?
def identity(arg):
    return arg
Why not use Any?
from typing import Any
def identity(arg: Any) -> Any:
    return arg
If you use Any, the type checker will not understand how this function works: as far as it's concerned, the function can return anything at all. The return type doesn't depend on the type of arg.
We'd really want number to be an int here, so that the type checker will catch an error on the next line:
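A minimal sketch of that situation, assuming a call like identity(42). Only the name number comes from the text above; the rest is illustrative. With Any, the checker stays silent and the mistake only shows up at runtime:

```python
from typing import Any


def identity(arg: Any) -> Any:
    return arg


number = identity(42)  # the checker sees Any here, not int

try:
    number + "!"  # not flagged by the checker, because number is Any
except TypeError as error:
    # At runtime, adding an int and a str fails.
    runtime_error = error
```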
Why not specialize the function for different types?
def identity_int(arg: int) -> int:
    return arg
def identity_str(arg: str) -> str:
    return arg
def identity_list_str(arg: list[str]) -> list[str]:
    return arg
...
This doesn't scale well. Are you going to replicate the same function 10 times? Will you remember to keep them in sync?
What if this is a library function? You won't be able to predict all the ways people will use this function.
The solution: type variables
Type variables allow you to link several types together. This is how you can use a type variable to annotate the identity function:
from typing import TypeVar
T = TypeVar("T")
def identity(arg: T) -> T:
    return arg
Here the return type is "linked" to the parameter type: whatever you put into the function, the same thing comes out.
This is how it looks in action (in VSCode with Pylance):
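In text form, the same behaviour can be sketched roughly as follows. The variable names are made up for illustration, and the types in the comments are what a checker such as mypy or Pylance is expected to infer, not output copied from a tool.

```python
from typing import TypeVar

T = TypeVar("T")


def identity(arg: T) -> T:
    return arg


number = identity(42)       # expected inferred type: int
greeting = identity("hi")   # expected inferred type: str
flags = identity([True])    # expected inferred type: list[bool]

# A checker should now flag a mistake such as `number + "!"`,
# because it knows that number is an int.
```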
Putting constraints on a type variable
Is this a well-typed function?
from typing import Union
def triple(string: Union[str, bytes]) -> Union[str, bytes]:
    return string * 3
Not really: if you pass in a str, you always get a str back, and the same goes for bytes. The Union annotation loses that information, which will cause you some pain: you know whether you'll get a str or a bytes back, but the type checker doesn't.
"If you pass in str, you get str. If you pass in bytes, you get bytes" -- sounds like a job for a type variable.
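An unconstrained type variable doesn't work here: the type checker rejects string * 3 for an arbitrary T, and that's fair enough, since not all types support multiplication. We can add a restriction that our type variable should only accept str or bytes (and their subclasses, of course).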
AnyString = TypeVar("AnyString", str, bytes)
def triple(string: AnyString) -> AnyString:
    return string * 3
unicode_scream = triple("A") + "!"
bytes_scream = triple(b"A") + b"!"
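The standard library already ships a type variable constrained to str and bytes: typing.AnyStr. A minimal sketch of the same function using it instead of a hand-rolled AnyString (note that recent Python versions mark AnyStr as deprecated, so treat this as an option rather than a recommendation):

```python
from typing import AnyStr  # predefined as TypeVar("AnyStr", str, bytes)


def triple(string: AnyStr) -> AnyStr:
    return string * 3


unicode_scream = triple("A") + "!"    # str in, str out
bytes_scream = triple(b"A") + b"!"    # bytes in, bytes out
```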
Using type variables as parameters
You can also use type variables as parameters to generic types, like list or Iterable.
from typing import Iterable, Iterator
def remove_falsey_from_list(items: list[T]) -> list[T]:
    return [item for item in items if item]
def remove_falsey(items: Iterable[T]) -> Iterator[T]:
    for item in items:
        if item:
            yield item
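A short usage sketch for remove_falsey, with made-up sample data. The element type of the input carries through to the output; the types in the comments are what a checker is expected to infer.

```python
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def remove_falsey(items: Iterable[T]) -> Iterator[T]:
    for item in items:
        if item:
            yield item


# The empty string and the zeros are falsey, so they are dropped.
names = list(remove_falsey(["Ada", "", "Grace"]))   # expected: list[str]
numbers = list(remove_falsey((0, 1, 2, 0, 3)))      # expected: list[int]
```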
However, this gets tricky pretty fast. I'll cover it in depth in the next article.
Links
- mypy documentation on generic functions: https://mypy.readthedocs.io/en/stable/generics.html#generic-functions