From 8cdc7a24dba075360bba206ea52d0866a3de40ce Mon Sep 17 00:00:00 2001 From: Jeremy Howard Date: Sat, 11 May 2024 16:08:26 +1000 Subject: [PATCH] fixes #550 --- fastcore/all.py | 1 + fastcore/parallel.py | 22 ++++++++++++++-------- fastcore/xml.py | 1 + nbs/03a_parallel.ipynb | 22 ++++++++++++++-------- nbs/11_xml.ipynb | 1 + 5 files changed, 31 insertions(+), 16 deletions(-) diff --git a/fastcore/all.py b/fastcore/all.py index 7eac9eea..1a9735ce 100644 --- a/fastcore/all.py +++ b/fastcore/all.py @@ -9,3 +9,4 @@ from .meta import * from .imports import * from .script import * +from .xml import * diff --git a/fastcore/parallel.py b/fastcore/parallel.py index 9102d7c7..d944ef30 100644 --- a/fastcore/parallel.py +++ b/fastcore/parallel.py @@ -20,14 +20,20 @@ except: pass # %% ../nbs/03a_parallel.ipynb 4 -def threaded(f): - "Run `f` in a thread, and returns the thread" - @wraps(f) - def _f(*args, **kwargs): - res = Thread(target=f, args=args, kwargs=kwargs) - res.start() - return res - return _f +def threaded(process=False): + "Run `f` in a `Thread` (or `Process` if `process=True`), and returns it" + def _r(f): + @wraps(f) + def _f(*args, **kwargs): + res = (Thread,Process)[process](target=f, args=args, kwargs=kwargs) + res.start() + return res + return _f + if callable(process): + o = process + process = False + return _r(o) + return _r # %% ../nbs/03a_parallel.ipynb 6 def startthread(f): diff --git a/fastcore/xml.py b/fastcore/xml.py index 508e62db..577046c7 100644 --- a/fastcore/xml.py +++ b/fastcore/xml.py @@ -45,6 +45,7 @@ def xt(tag:str, *c, **kw): # %% ../nbs/11_xml.ipynb 9 def to_xml(elm, lvl=0): "Convert `xt` element tree into an XML string" + if hasattr(elm, '__xt__'): elm = elm.__xt__() sp = ' ' * lvl if not isinstance(elm, list): if isinstance(elm, str): elm = escape(elm) diff --git a/nbs/03a_parallel.ipynb b/nbs/03a_parallel.ipynb index 6cb51a47..0d90ec12 100644 --- a/nbs/03a_parallel.ipynb +++ b/nbs/03a_parallel.ipynb @@ -58,14 +58,20 @@ "outputs": [], "source": [ "#|export\n", - "def threaded(f):\n", - " \"Run `f` in a thread, and returns the thread\"\n", - " @wraps(f)\n", - " def _f(*args, **kwargs):\n", - " res = Thread(target=f, args=args, kwargs=kwargs)\n", - " res.start()\n", - " return res\n", - " return _f" + "def threaded(process=False):\n", + " \"Run `f` in a `Thread` (or `Process` if `process=True`), and returns it\"\n", + " def _r(f):\n", + " @wraps(f)\n", + " def _f(*args, **kwargs):\n", + " res = (Thread,Process)[process](target=f, args=args, kwargs=kwargs)\n", + " res.start()\n", + " return res\n", + " return _f\n", + " if callable(process):\n", + " o = process\n", + " process = False\n", + " return _r(o)\n", + " return _r" ] }, { diff --git a/nbs/11_xml.ipynb b/nbs/11_xml.ipynb index d289893c..8cc39b70 100644 --- a/nbs/11_xml.ipynb +++ b/nbs/11_xml.ipynb @@ -147,6 +147,7 @@ "#| export\n", "def to_xml(elm, lvl=0):\n", " \"Convert `xt` element tree into an XML string\"\n", + " if hasattr(elm, '__xt__'): elm = elm.__xt__()\n", " sp = ' ' * lvl\n", " if not isinstance(elm, list):\n", " if isinstance(elm, str): elm = escape(elm)\n",