Some time ago one of my Java colleagues from work asked:
How do you do a flat map in python?
Well, there's no built-in flat_map function, but it can be implemented in a few different ways. Instead of just discussing how to write a pythonic flat map, though, I want to focus on a more interesting question:
What is the most efficient way to do a flat map in python?
Definition
Let's start with a definition. A flat map is an operation that takes a list whose elements have type A and a function f of type A -> [B]. The function f is applied to each element of the initial list, and then all the results are concatenated. So the type of flat_map is:
flat_map :: (a -> [b]) -> [a] -> [b]
I think showing an example is much simpler than describing it:
id = lambda x: x
flat_map(id, [[1,2], [3,4]]) == [1, 2, 3, 4]
s = lambda x: x.split(",")  # split each string on commas
flat_map(s, ["a,b", "c,d"]) == ["a", "b", "c", "d"]
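The Haskell-style signature above can also be spelled out with Python type hints. This is a minimal sketch, assuming the standard `typing` generics (`TypeVar`, `Callable`, `Iterable`); the comprehension body is just one possible implementation, and `f` is allowed to return any iterable rather than strictly a list:

```python
from typing import Callable, Iterable, List, TypeVar

A = TypeVar("A")
B = TypeVar("B")


def flat_map(f: Callable[[A], Iterable[B]], xs: Iterable[A]) -> List[B]:
    # Apply f to each element of xs and concatenate the results in order.
    return [y for x in xs for y in f(x)]


print(flat_map(lambda x: x, [[1, 2], [3, 4]]))           # [1, 2, 3, 4]
print(flat_map(lambda x: x.split(","), ["a,b", "c,d"]))  # ['a', 'b', 'c', 'd']
```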
It's a really simple operation, although the general idea behind it is a marvelous one: flat_map is essentially the bind operation (>>=) of the list monad. As far as I know, there are four simple ways to implement flat_map in Python:
- for loop and extend
- double list comprehension
- map and reduce
- map and sum
Measure
Which approach is the most efficient? To answer this question I am going to measure the execution time of each implementation in three cases:
a) 100 lists each with 10 integers
b) 10 000 lists each with 10 integers
c) 10 000 lists each with 10 objects (class instances)
To do this I will use the simple check_time function below, with the identity function as the function applied by the flat map:
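import timeit

def id(x):
    return x  # identity function (note: shadows the built-in id)

def check_time(f, arg):
    n = 50
    # average time of a single f(id, arg) call over n runs, in milliseconds
    return 1000 * (timeit.timeit(lambda: f(id, arg), number=n) / n)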
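The three inputs themselves are not shown here (the gist has the full code). One possible way to build them is sketched below, assuming a hypothetical `Point` class for case c) and using the for + extend version from later in the article only as an example target; `id` and `check_time` are repeated so the block runs on its own:

```python
import timeit


def id(x):  # identity; shadows the built-in id() to match the article
    return x


def check_time(f, arg):
    n = 50
    # average time of one f(id, arg) call over n runs, in milliseconds
    return 1000 * (timeit.timeit(lambda: f(id, arg), number=n) / n)


class Point:  # hypothetical class, used only to fill case c)
    def __init__(self, x):
        self.x = x


case_a = [list(range(10)) for _ in range(100)]                   # 100 x 10 ints
case_b = [list(range(10)) for _ in range(10_000)]                # 10 000 x 10 ints
case_c = [[Point(i) for i in range(10)] for _ in range(10_000)]  # 10 000 x 10 objects


def flat_map(f, xs):  # for + extend version, as an example implementation
    ys = []
    for x in xs:
        ys.extend(f(x))
    return ys


for name, case in [("a", case_a), ("b", case_b), ("c", case_c)]:
    print(name, f"{check_time(flat_map, case):.4f} ms")
```

Timings will of course differ from machine to machine; only the relative differences between implementations are meaningful.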
All the code I used for this article is available as a gist.
Map reduce
If there's a "most functional" answer, it's the combination of map and reduce (also known as fold). It requires an import from functools (if you don't know this module or the function, check it out!), so it's not a real one-liner.
from functools import reduce
flat_map = lambda f, xs: reduce(lambda a, b: a + b, map(f, xs))
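Part of the cost is that `a + b` allocates a brand-new list at every step of the fold, copying everything accumulated so far. Below is a sketch of a variant (not benchmarked in this article) that keeps the reduce shape but extends a single accumulator in place via `operator.iadd`. The `[]` initializer matters: without it, the first list returned by `f` would be mutated (with `id`, that is the caller's own data), and an empty input would raise `TypeError`.

```python
import operator
from functools import reduce

# a + b copies the accumulator on every step; operator.iadd performs a += b,
# which extends the accumulator list in place and returns it.
flat_map_iadd = lambda f, xs: reduce(operator.iadd, map(f, xs), [])

data = [[1, 2], [3, 4]]
print(flat_map_iadd(lambda x: x, data))  # [1, 2, 3, 4]
print(data)                              # [[1, 2], [3, 4]] (input left untouched)
```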
Result:
a) 0.1230 ms
b) 1202.4405 ms
c) 3249.0119 ms
The map operation can be replaced with a list comprehension, but that results in no spectacular improvement.
Map sum
Another functional-like solution is to use sum. I really like this approach; however, its performance degrades with the length of the input list, and it gets even slower when the inner lists hold class instances instead of ints. The results are no better than the previous ones.
flat_map = lambda f, xs: sum(map(f, xs), [])
Result:
a) 0.1080 ms
b) 1150.4400 ms
c) 3296.3167 ms
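Why do both functional versions blow up on case b)? `sum(lists, [])` evaluates `((([] + l1) + l2) + ...)`, essentially like the reduce version, and every `+` copies both operands into a new list. A back-of-the-envelope sketch that counts element copies under this simple model (ignoring allocation details):

```python
def elements_copied(n_lists, list_len):
    # Step i computes acc + l_i, where acc already holds (i - 1) * list_len
    # elements, so the new list receives i * list_len element copies.
    return sum(i * list_len for i in range(1, n_lists + 1))


small = elements_copied(100, 10)     # case a)
large = elements_copied(10_000, 10)  # case b)
print(small, large)  # 50500 500050000
print(large / small)  # about 9902
```

So 100 times more input means roughly 10 000 times more copying, which is in line with the jump from case a) to case b) in the map reduce and map sum results.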
For and extend
So far our flat_map implementations were quite good on relatively small inputs and lousy on inputs of 10 000 lists. Luckily, this 'classic' implementation is roughly 500 to 1200 times faster on the large inputs! Its only drawback is the number of lines it requires, but that is definitely a good price for such an improvement in performance.
def flat_map(f, xs):
    ys = []
    for x in xs:
        ys.extend(f(x))
    return ys
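One more property of this version worth knowing: `list.extend` accepts any iterable, so `f` may return a tuple, a `range` or a generator, whereas `+` in the reduce and sum versions only concatenates sequences of the same type. A small sketch:

```python
def flat_map(f, xs):
    ys = []
    for x in xs:
        ys.extend(f(x))  # extend consumes any iterable, not only lists
    return ys


print(flat_map(range, [1, 2, 3]))                      # [0, 0, 1, 0, 1, 2]
print(flat_map(lambda s: (s, s.upper()), ["a", "b"]))  # ['a', 'A', 'b', 'B']

try:
    [0] + range(2)  # what a + b would face if f returned a range
except TypeError as e:
    print("list + range fails:", e)
```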
Result:
a) 0.0211 ms
b) 2.4535 ms
c) 2.7347 ms
Interestingly, the difference between applying this flat_map to lists of ints and to lists of class instances is no longer so large.
List comprehension
This is a truly pythonic approach (and also a functional one). Its downside is readability: I always have to check the proper order of the for clauses in a double comprehension. Of course, it's just a nit.
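A mnemonic for the clause order, sketched below: the `for` clauses of a comprehension appear in the same order as the equivalent nested loops, outermost loop first, and only the produced value `y` moves to the front.

```python
def flat_map_loops(f, xs):
    ys = []
    for x in xs:          # 1st clause: "for x in xs"
        for y in f(x):    # 2nd clause: "for y in f(x)"
            ys.append(y)  # the leading expression: "y"
    return ys


# Same clauses, same order, read left to right:
flat_map_comp = lambda f, xs: [y for x in xs for y in f(x)]

data = [[1, 2], [3, 4]]
print(flat_map_loops(lambda x: x, data))  # [1, 2, 3, 4]
print(flat_map_loops(lambda x: x, data) == flat_map_comp(lambda x: x, data))  # True
```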
flat_map = lambda f, xs: [y for x in xs for y in f(x)]
Result:
a) 0.0372 ms
b) 4.1477 ms
c) 4.5945 ms
As we can see, it performs really well: slightly slower than the previous implementation, but a list comprehension is a one-liner! Someone may even suggest switching the list comprehension to a generator expression, like this:
flat_map = lambda f, xs: (y for x in xs for y in f(x))
Result:
a) 0.0008 ms
b) 0.0008 ms
c) 0.0007 ms
Even better, eh? Of course, a generator expression is only applicable when we are going to iterate over the result. So, to measure its real performance, we should time code similar to this:
ys = flat_map(id, xs)
for y in ys:
    pass
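To plug this into `check_time`, one option is to wrap each implementation so that the timed call both builds the result and iterates over it. A sketch with a hypothetical `consumed` helper (the gist may do this differently):

```python
import timeit


def id(x):  # identity function, as in the measurement setup
    return x


def check_time(f, arg):
    n = 50
    return 1000 * (timeit.timeit(lambda: f(id, arg), number=n) / n)  # ms


def consumed(flat_map_impl):
    # Hypothetical helper: the returned function builds the result and
    # iterates over it, so a lazy generator is charged for its real work.
    def run(f, xs):
        for _ in flat_map_impl(f, xs):
            pass
    return run


comprehension = lambda f, xs: [y for x in xs for y in f(x)]
generator = lambda f, xs: (y for x in xs for y in f(x))

data = [list(range(10)) for _ in range(100)]  # case a): 100 lists of 10 ints
print(check_time(consumed(comprehension), data))
print(check_time(consumed(generator), data))
```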
Then we get the following results:
comprehension:
a) 0.0537 ms
b) 5.3700 ms
c) 5.7781 ms
generator:
a) 0.1512 ms
b) 13.0335 ms
c) 13.3203 ms
So the initial improvement was not real: creating the generator executed nothing at all. This shows that when measuring performance we should always remember the context in which the code is used.
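The laziness is easy to observe directly. A sketch that records every call of the applied function (the `traced_id` name is just for illustration):

```python
calls = []


def traced_id(x):
    calls.append(x)  # record each call of f
    return x


xs = [[1, 2], [3, 4]]
gen = (y for x in xs for y in traced_id(x))
print(calls)       # [] : building the generator called f zero times

result = list(gen)
print(result)      # [1, 2, 3, 4]
print(len(calls))  # 2 : f runs only while the generator is consumed
```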
Summary
Method | 100 x 10 int (ms) | 10 000 x 10 int (ms) | 10 000 x 10 class (ms) |
---|---|---|---|
map reduce | 0.1230 | 1202.4405 | 3249.0119 |
map sum | 0.1080 | 1150.4400 | 3296.3167 |
for + extend | 0.0211 | 2.4535 | 2.7347 |
double list comprehension | 0.0372 | 4.1477 | 4.5945 |
Simple problem, many solutions, but only one winner... or two of them. I think it is reasonable to grant first place to both loop + extend and the list comprehension. The first is great when performance is your main concern; the second is a really good choice when you need a handy one-liner.
So next time you need a quick flat_map in Python, just use a list comprehension. And most importantly, do not be tempted to use reduce or sum for relatively long lists!
Comments
Hi. I believe that the definition of flat_map via generators is wrong in the gist. It only accepts the id function as the first argument. I also suggest writing for x in xs instead of for ys in xs, since elements of xs may not be lists.
Nice post. I think the table is mixed up. It shows map sum as the fastest.
Thanks David! I've updated the table 🚀