From aaeb8582f494c9e2648302c48f864dee0d710457 Mon Sep 17 00:00:00 2001 From: Hamel Husain Date: Fri, 11 Oct 2024 21:09:34 -0700 Subject: [PATCH] clean up parallel async --- fastcore/parallel.py | 5 +---- nbs/03a_parallel.ipynb | 34 ++++++++++++++++++++++------------ 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/fastcore/parallel.py b/fastcore/parallel.py index 683aa071..fbaa5e92 100644 --- a/fastcore/parallel.py +++ b/fastcore/parallel.py @@ -141,13 +141,11 @@ def _add_one(x, a=1): return x+a # %% ../nbs/03a_parallel.ipynb -async def parallel_async(f, items, *args, n_workers=16, total=None, +async def parallel_async(f, items, *args, n_workers=16, timeout=None, chunksize=1, on_exc=print, **kwargs): "Applies `f` to `items` in parallel using asyncio and a semaphore to limit concurrency." import asyncio - if n_workers is None: n_workers = defaults.cpus semaphore = asyncio.Semaphore(n_workers) - results = [] async def limited_task(item): coro = f(item, *args, **kwargs) if asyncio.iscoroutinefunction(f) else asyncio.to_thread(f, item, *args, **kwargs) @@ -155,7 +153,6 @@ async def limited_task(item): return await asyncio.wait_for(coro, timeout) if timeout else await coro tasks = [limited_task(item) for item in items] - if total is None: total = len(items) return asyncio.gather(*tasks) # %% ../nbs/03a_parallel.ipynb diff --git a/nbs/03a_parallel.ipynb b/nbs/03a_parallel.ipynb index b3c9e53f..d18a7a51 100644 --- a/nbs/03a_parallel.ipynb +++ b/nbs/03a_parallel.ipynb @@ -363,7 +363,7 @@ "\n", "> ProcessPoolExecutor (max_workers=8, on_exc=,\n", "> pause=0, mp_context=None, initializer=None,\n", - "> initargs=(), max_tasks_per_child=None)\n", + "> initargs=())\n", "\n", "*Same as Python's ProcessPoolExecutor, except can pass `max_workers==0` for serial execution*" ], @@ -376,7 +376,7 @@ "\n", "> ProcessPoolExecutor (max_workers=8, on_exc=,\n", "> pause=0, mp_context=None, initializer=None,\n", - "> initargs=(), max_tasks_per_child=None)\n", + "> initargs=())\n", "\n", "*Same as Python's ProcessPoolExecutor, except can pass `max_workers==0` for serial execution*" ] @@ -479,11 +479,11 @@ "name": "stdout", "output_type": "stream", "text": [ - "0 2024-10-12 13:30:21.217649\n", - "1 2024-10-12 13:30:21.469191\n", - "2 2024-10-12 13:30:21.721034\n", - "3 2024-10-12 13:30:21.972793\n", - "4 2024-10-12 13:30:22.223159\n" + "0 2024-10-11 21:08:04.678835\n", + "1 2024-10-11 21:08:04.930711\n", + "2 2024-10-11 21:08:05.181549\n", + "3 2024-10-11 21:08:05.435812\n", + "4 2024-10-11 21:08:05.687301\n" ] } ], @@ -527,13 +527,11 @@ "outputs": [], "source": [ "#|export\n", - "async def parallel_async(f, items, *args, n_workers=16, total=None,\n", + "async def parallel_async(f, items, *args, n_workers=16,\n", " timeout=None, chunksize=1, on_exc=print, **kwargs):\n", " \"Applies `f` to `items` in parallel using asyncio and a semaphore to limit concurrency.\"\n", " import asyncio\n", - " if n_workers is None: n_workers = defaults.cpus\n", " semaphore = asyncio.Semaphore(n_workers)\n", - " results = []\n", "\n", " async def limited_task(item):\n", " coro = f(item, *args, **kwargs) if asyncio.iscoroutinefunction(f) else asyncio.to_thread(f, item, *args, **kwargs)\n", @@ -541,7 +539,6 @@ " return await asyncio.wait_for(coro, timeout) if timeout else await coro\n", "\n", " tasks = [limited_task(item) for item in items]\n", - " if total is None: total = len(items)\n", " return asyncio.gather(*tasks)" ] }, @@ -558,7 +555,20 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2 2024-10-11 21:08:23.196473 0.28277557322761593\n", + "1 2024-10-11 21:08:23.258949 0.3449957623467014\n", + "0 2024-10-11 21:08:23.341502 0.4269314183522479\n", + "5 2024-10-11 21:08:23.665180 0.3214516296250627\n", + "3 2024-10-11 21:08:23.814849 0.6164311736352199\n", + "4 2024-10-11 21:08:23.864115 0.6032399559625771\n" + ] + } + ], "source": [ "async def print_time_async(i): \n", " wait = random.random()\n",