diff --git a/memory_profiler.py b/memory_profiler.py index 25477f1..1da978d 100644 --- a/memory_profiler.py +++ b/memory_profiler.py @@ -19,6 +19,7 @@ import traceback import warnings import contextlib +import asyncio if sys.platform == "win32": # any value except signal.CTRL_C_EVENT and signal.CTRL_BREAK_EVENT @@ -683,14 +684,21 @@ def add_function(self, func): def wrap_function(self, func): """ Wrap a function to profile it. """ - + if asyncio.iscoroutinefunction(func): + async def f(*args, **kwds): + self.enable_by_count() + try: + return await func(*args, **kwds) + finally: + self.disable_by_count() + return f + def f(*args, **kwds): self.enable_by_count() try: return func(*args, **kwds) finally: self.disable_by_count() - return f def runctx(self, cmd, globals, locals): @@ -1098,14 +1106,24 @@ def profile(func=None, stream=None, precision=1, backend='psutil'): if not tracemalloc.is_tracing(): tracemalloc.start() if func is not None: - @wraps(func) - def wrapper(*args, **kwargs): - prof = LineProfiler(backend=backend) - val = prof(func)(*args, **kwargs) - show_results(prof, stream=stream, precision=precision) - return val - - return wrapper + if not asyncio.iscoroutinefunction(func): + @wraps(func) + def wrapper(*args, **kwargs): + prof = LineProfiler(backend=backend) + val = prof(func)(*args, **kwargs) + show_results(prof, stream=stream, precision=precision) + + return val + return wrapper + else: + @wraps(func) + async def wrapper(*args, **kwargs): + prof = LineProfiler(backend=backend) + val = await prof(func)(*args, **kwargs) + show_results(prof, stream=stream, precision=precision) + + return val + return wrapper else: def inner_wrapper(f): return profile(f, stream=stream, precision=precision,