Skip to content

Commit

Permalink
Merge pull request #245 from ericpre/fix_xdist_download_test_data
Browse files Browse the repository at this point in the history
Fix download test data when using `pytest --pyargs rsciio -n`
  • Loading branch information
jlaehne authored Mar 31, 2024
2 parents 4cc6d4f + 8de9d5a commit a0dbed3
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 23 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/package_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ jobs:
# "github.event.pull_request.head.repo.full_name" is for "pull request" event while github.repository is for "push" event
# "github.event.pull_request.head.ref" is for "pull request" event while "github.ref_name" is for "push" event
POOCH_BASE_URL: https://github.com/${{ github.event.pull_request.head.repo.full_name || github.repository }}/raw/${{ github.event.pull_request.head.ref || github.ref_name }}/rsciio/tests/data/
# "-s" is used to show of output when downloading the test files
PYTEST_ARGS: "-n 2"
1 change: 1 addition & 0 deletions conda_environment_dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ dependencies:
- pytest-rerunfailures
- hyperspy-base
- setuptools-scm
- filelock
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ tiff = ["tifffile>=2020.2.16", "imagecodecs>=2020.1.31"]
usid = ["pyUSID", "sidpy<=0.12.0"]
zspy = ["zarr", "msgpack"]
tests = [
"filelock",
"pooch",
"pytest>=3.6",
"pytest-xdist",
Expand Down
45 changes: 34 additions & 11 deletions rsciio/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
# You should have received a copy of the GNU General Public License
# along with RosettaSciIO. If not, see <https://www.gnu.org/licenses/#GPL>.

import os
import json
from packaging.version import Version

from rsciio.tests.registry_utils import download_all
from filelock import FileLock
import pytest


try:
import hyperspy
Expand All @@ -33,13 +35,34 @@
pass


def pytest_configure(config):
# Run in pytest_configure hook to avoid capturing stdout by pytest and
# inform user that the test data are being downloaded
# From https://pytest-xdist.readthedocs.io/en/latest/how-to.html#making-session-scoped-fixtures-execute-only-once
@pytest.fixture(scope="session", autouse=True)
def session_data(request, tmp_path_factory, worker_id):
capmanager = request.config.pluginmanager.getplugin("capturemanager")

def _download_test_data():
from rsciio.tests.registry_utils import download_all

with capmanager.global_and_fixture_disabled():
print("Checking if test data need downloading...")
download_all()
print("All test data available.")

return "Test data available"

if worker_id == "master":
# not executing in with multiple workers, just produce the data and let
# pytest's fixture caching do its job
return _download_test_data()

# get the temp directory shared by all workers
root_tmp_dir = tmp_path_factory.getbasetemp().parent

# Workaround to avoid running it for each worker
worker_id = os.environ.get("PYTEST_XDIST_WORKER")
if worker_id is None:
print("Checking if test data need downloading...")
download_all()
print("All test data available.")
fn = root_tmp_dir / "data.json"
with FileLock(str(fn) + ".lock"):
if fn.is_file():
data = json.loads(fn.read_text())
else:
data = _download_test_data()
fn.write_text(json.dumps(data))
return data
25 changes: 13 additions & 12 deletions rsciio/tests/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
dt = [("x", np.uint8), ("y", np.uint16), ("text", (bytes, 6))]


MY_PATH = Path(__file__).parent
TEST_XML_PATH = MY_PATH / ".." / "data" / "ToastedBreakFastSDD.xml"
@pytest.fixture
def XML_TEST_NODE():
MY_PATH = Path(__file__).parent
TEST_XML_PATH = MY_PATH / ".." / "data" / "ToastedBreakFastSDD.xml"

with open(TEST_XML_PATH, "r") as fn:
weird_but_valid_xml_str = fn.read()

with open(TEST_XML_PATH, "r") as fn:
weird_but_valid_xml_str = fn.read()


XML_TEST_NODE = ET.fromstring(weird_but_valid_xml_str)
yield ET.fromstring(weird_but_valid_xml_str)


# fmt: off
Expand All @@ -42,7 +42,7 @@ def test_msxml_sanitization():
assert et[3].text == "0,2,3" # is not float


def test_default_x2d():
def test_default_x2d(XML_TEST_NODE):
"""test of default XmlToDict translation with attributes prefixed with @,
interchild_text_parsing set to 'first',
no flattening tags set, and dub_text_str set to '#value'
Expand All @@ -59,7 +59,7 @@ def test_default_x2d():
assert pynode["TestXML"]["Main"]["ClassInstance"]["Sample"]["#value"] == t


def test_skip_interchild_text_flatten():
def test_skip_interchild_text_flatten(XML_TEST_NODE):
"""test of XmlToDict translation with interchild_text_parsing set to 'skip',
three string containing list set to flattening tags. Other kwrds - default.
"""
Expand All @@ -72,7 +72,7 @@ def test_skip_interchild_text_flatten():
assert pynode["Main"]["Sample"].get("#value") is None


def test_concat_interchild_text_val_flatten():
def test_concat_interchild_text_val_flatten(XML_TEST_NODE):
"""test of XmlToDict translator with interchild_text_parsing set to
'cat' (concatenation), four flattening tags set, and dub_text_str set
to '#text'
Expand All @@ -91,7 +91,7 @@ def test_concat_interchild_text_val_flatten():
assert pynode["Sample"]["#interchild_text"] == t


def test_list_interchild_text_val_flatten():
def test_list_interchild_text_val_flatten(XML_TEST_NODE):
"""test of XmlToDict translator interchild_text_parsing set to 'list'
"""
x2d = XmlToDict(
Expand All @@ -107,7 +107,7 @@ def test_list_interchild_text_val_flatten():
]


def x2d_subclass_for_custom_bool():
def x2d_subclass_for_custom_bool(XML_TEST_NODE):
"""test subclass of XmlToDict with updated eval function"""

class CustomXmlToDict(XmlToDict):
Expand Down Expand Up @@ -390,6 +390,7 @@ def test_get_chunk_slice(shape):
assert chunk_arr.shape == (1,)*len(shape)+(len(shape), 2)
assert chunk == tuple([(i,)for i in shape])


@pytest.mark.parametrize("shape", ((10, 20, 30, 512, 512),(20, 30, 512, 512), (10, 512, 512), (512, 512)))
def test_get_chunk_slice(shape):
chunks =(1,)*(len(shape)-2) +(-1,-1)
Expand Down
1 change: 1 addition & 0 deletions upcoming_changes/245.maintenance.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix download test data when using ``pytest --pyargs rsciio -n``.

0 comments on commit a0dbed3

Please sign in to comment.