Python Itertools - Efficient Looping

Author: tushushu
Project: https://github.com/tushushu/flying-python

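The official Python documentation describes the itertools module as a set of tools "for efficient looping". Some of these tools do improve performance; others are no faster and merely save development time, and overusing them can make code harder to read. Let's put a few of them through their paces.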

1. Running Totals

Given a list An, return the list of its running totals Sn. For example:

  • Input: [1, 2, 3, 4, 5]
  • Output: [1, 3, 6, 10, 15]

Using accumulate is roughly 2.9 times faster in the timings below (61 µs vs. 21.3 µs).

from itertools import accumulate
def _accumulate_list(arr):
    tot = 0
    for x in arr:
        tot += x
        yield tot

def accumulate_list(arr):
    return list(_accumulate_list(arr))
def fast_accumulate_list(arr):
    return list(accumulate(arr))
arr = list(range(1000))
%timeit accumulate_list(arr)
61 µs ± 2.91 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit fast_accumulate_list(arr)
21.3 µs ± 811 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
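
As a quick sanity check on the example above, the sketch below confirms that accumulate reproduces the running totals for [1, 2, 3, 4, 5]. As an aside not covered in the original, it also shows the optional second argument, which swaps addition for another two-argument function (a running maximum here). The variable names are illustrative.

```python
from itertools import accumulate

nums = [1, 2, 3, 4, 5]

# Default behaviour: running sums, matching the example in the text.
running_sums = list(accumulate(nums))
print(running_sums)  # [1, 3, 6, 10, 15]

# Optional second argument: any two-argument function, here a running maximum.
running_max = list(accumulate([3, 1, 4, 1, 5], max))
print(running_max)  # [3, 3, 4, 4, 5]
```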

2. Selecting Data

Given a list data and a list selectors of 0/1 flags, return the selected items. For example:

  • Input: [1, 2, 3, 4, 5], [0, 1, 0, 1, 0]
  • Output: [2, 4]

Using compress is roughly 2.6 times faster in the timings below (341 µs vs. 130 µs).

from itertools import compress
from random import randint
def select_data(data, selectors):
    return [x for x, y in zip(data, selectors) if y]
def fast_select_data(data, selectors):
    return list(compress(data, selectors))
data = list(range(10000))
selectors = [randint(0, 1) for _ in range(10000)]
%timeit select_data(data, selectors)
341 µs ± 17.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit fast_select_data(data, selectors)
130 µs ± 3.19 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
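
A minimal check of the example above, using illustrative variable names. compress keeps the items whose selector is truthy and stops as soon as the shorter of its two inputs is exhausted, which is the same truncation behaviour as the zip-based version.

```python
from itertools import compress

data = [1, 2, 3, 4, 5]
selectors = [0, 1, 0, 1, 0]

selected = list(compress(data, selectors))
print(selected)  # [2, 4]

# With fewer selectors than data items, the extra data is ignored.
truncated = list(compress(data, [1, 1]))
print(truncated)  # [1, 2]
```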

3. Permutations

Given a list arr and a number k, return every ordered selection of k elements from arr. For example:

  • Input: [1, 2, 3], 2
  • Output: [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1), (3, 2)]

Using permutations is about 10 times faster in the timings below (15.5 ms vs. 1.56 ms).

from itertools import permutations
def _get_permutations(arr, k, i):
    # Base case: the first k positions are fixed, so emit them.
    if i == k:
        return [arr[:k]]
    res = []
    # Try every remaining element in position i.
    for j in range(i, len(arr)):
        arr_cpy = arr.copy()
        arr_cpy[i], arr_cpy[j] = arr_cpy[j], arr_cpy[i]
        res += _get_permutations(arr_cpy, k, i + 1)
    return res

def get_permutations(arr, k):
    return _get_permutations(arr, k, 0)
def fast_get_permutations(arr, k):
    return list(permutations(arr, k))
arr = list(range(10))
k = 5
%timeit -n 1 get_permutations(arr, k)
15.5 ms ± 1.96 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit -n 1 fast_get_permutations(arr, k)
1.56 ms ± 284 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
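
Two details are worth checking against the example above. The sketch below assumes Python 3.8+ (for math.perm) and uses illustrative names. It shows that permutations yields ordered tuples, so (1, 2) and (2, 1) both appear; that the number of results is n! / (n - k)!; and, for contrast, that combinations ignores order.

```python
from itertools import combinations, permutations
from math import perm  # available in Python 3.8+

arr = [1, 2, 3]

ordered = list(permutations(arr, 2))
print(ordered)  # [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1), (3, 2)]

unordered = list(combinations(arr, 2))
print(unordered)  # [(1, 2), (1, 3), (2, 3)]

# The benchmark above picks 5 of 10 items: 10 * 9 * 8 * 7 * 6 results.
count = sum(1 for _ in permutations(range(10), 5))
print(count, perm(10, 5))  # 30240 30240
```

Note that the hand-written get_permutations returns lists rather than tuples, and in a different order, so the two results are best compared after converting to tuples and sorting.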

4. Filtering Data

Given a list arr, keep only the even numbers. For example:

  • Input: [3, 1, 4, 5, 9, 2]
  • Output: [4, 2]

Using filterfalse is actually slower here, by roughly 2 times (823 µs vs. 417 µs), so don't trust itertools blindly.

from itertools import filterfalse
def get_even_nums(arr):
    return [x for x in arr if x % 2 == 0]
def fast_get_even_nums(arr):
    return list(filterfalse(lambda x: x % 2, arr))
arr = list(range(10000))
%timeit get_even_nums(arr)
417 µs ± 18.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit fast_get_even_nums(arr)
823 µs ± 22.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
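
One plausible explanation for the slowdown is that filterfalse has to call the Python-level lambda once per element, whereas the list comprehension evaluates x % 2 == 0 inline. The sketch below is an illustration under that assumption, not a benchmark. It checks that both versions agree on the example input, and shows a variant that passes a bound method of int, (1).__and__, as the predicate so that no Python function is called per element. Whether that variant is actually faster depends on the interpreter version, so measure before relying on it.

```python
from itertools import filterfalse

arr = [3, 1, 4, 5, 9, 2]

evens_comprehension = [x for x in arr if x % 2 == 0]
evens_lambda = list(filterfalse(lambda x: x % 2, arr))

# (1).__and__(x) computes 1 & x: 0 (falsy) for even x, 1 (truthy) for odd x.
evens_builtin = list(filterfalse((1).__and__, arr))

print(evens_comprehension)  # [4, 2]
print(evens_lambda)         # [4, 2]
print(evens_builtin)        # [4, 2]
```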

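5. Conditional Termination

Given a list arr, sum its elements in order, stop as soon as an element is greater than target, and return the sum so far. For example:

  • Input: [1, 2, 3, 4, 5], 3
  • Output: 6 (4 > 3, so the summation stops)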

Using takewhile is also slower here (404 µs vs. 245 µs), so don't trust itertools blindly.

from itertools import takewhile
def cond_sum(arr, target):
    res = 0
    for x in arr:
        if x > target:
            break
        res += x
    return res
def fast_cond_sum(arr, target):
    return sum(takewhile(lambda x: x <= target, arr))
arr = list(range(10000))
target = 5000
%timeit cond_sum(arr, target)
245 µs ± 11.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit fast_cond_sum(arr, target)
404 µs ± 13.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
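
The sketch below checks the example above and illustrates one behaviour to keep in mind when the input is a one-shot iterator rather than a list: takewhile has to read the first failing element in order to test it, so that element is consumed and is not available afterwards. Names are illustrative.

```python
from itertools import takewhile

arr = [1, 2, 3, 4, 5]
target = 3

total = sum(takewhile(lambda x: x <= target, arr))
print(total)  # 6

# On a one-shot iterator, the element that ends the run (4) is consumed too.
it = iter(arr)
prefix = list(takewhile(lambda x: x <= target, it))
rest = list(it)
print(prefix)  # [1, 2, 3]
print(rest)    # [5]
```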

6. Nested Loops

Given lists arr1 and arr2, return the sums of every pair of elements, one taken from each list. For example:

  • Input: [1, 2], [4, 5]
  • Output: [1 + 4, 1 + 5, 2 + 4, 2 + 5]

Using product is roughly 1.3 times faster in the timings below (484 µs vs. 373 µs).

from itertools import product
def _cross_sum(arr1, arr2):
    for x in arr1:
        for y in arr2:
            yield x + y

def cross_sum(arr1, arr2):
    return list(_cross_sum(arr1, arr2))
def fast_cross_sum(arr1, arr2):
    return [x + y for x, y in product(arr1, arr2)]
arr1 = list(range(100))
arr2 = list(range(100))
%timeit cross_sum(arr1, arr2)
484 µs ± 16.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit fast_cross_sum(arr1, arr2)
373 µs ± 11.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
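
A small check of the example above. The second form is an alternative not used in the original: it pairs product with starmap and operator.add so that the addition is also driven from the iterator rather than a comprehension. It is shown only as a sketch, and its speed relative to the list comprehension should be measured rather than assumed.

```python
from itertools import product, starmap
from operator import add

arr1 = [1, 2]
arr2 = [4, 5]

sums = [x + y for x, y in product(arr1, arr2)]
print(sums)  # [5, 6, 6, 7]

# starmap unpacks each (x, y) pair into add(x, y).
sums_starmap = list(starmap(add, product(arr1, arr2)))
print(sums_starmap)  # [5, 6, 6, 7]
```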

7. Flattening a 2D List

Given a two-dimensional list arr, flatten it into a one-dimensional list. For example:

  • Input: [[1, 2], [3, 4]]
  • Output: [1, 2, 3, 4]

Using chain is roughly 6 times faster in the timings below (379 µs vs. 66.9 µs).

from itertools import chain
def _flatten(arr2d):
    for arr in arr2d:
        for x in arr:
            yield x

def flatten(arr2d):
    return list(_flatten(arr2d))
def fast_flatten(arr2d):
    return list(chain(*arr2d))
arr2d = [[x + y * 100 for x in range(100)] for y in range(100)]
%timeit flatten(arr2d)
379 µs ± 15.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit fast_flatten(arr2d)
66.9 µs ± 3.43 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
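
chain(*arr2d) unpacks the outer list into positional arguments before chaining. The standard library also provides chain.from_iterable, which takes the outer iterable directly and reads it lazily. The sketch below checks that both give the same result on the example input; it makes no claim about their relative speed.

```python
from itertools import chain

arr2d = [[1, 2], [3, 4]]

flat_unpacked = list(chain(*arr2d))
flat_lazy = list(chain.from_iterable(arr2d))

print(flat_unpacked)  # [1, 2, 3, 4]
print(flat_lazy)      # [1, 2, 3, 4]
```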