diff --git a/tests/test_notebooks.py b/tests/test_notebooks.py index 924ce16b2f..11ad2240d3 100644 --- a/tests/test_notebooks.py +++ b/tests/test_notebooks.py @@ -6,6 +6,7 @@ import os from urllib.parse import urlparse +import importlib_metadata import pytest import testbook from testbook.client import TestbookNotebookClient @@ -21,6 +22,16 @@ for nb in glob.glob(os.path.join(nb_root, '*.ipynb')) \ ] +try: + importlib_metadata.files('mosaicml') + package_name = 'mosaicml' +except importlib_metadata.PackageNotFoundError: + try: + importlib_metadata.files('composer') + package_name = 'composer' + except importlib_metadata.PackageNotFoundError: + raise RuntimeError('Could not find the package under mosaicml or composer.') + def patch_notebooks(): import itertools @@ -86,7 +97,6 @@ def modify_cell_source(tb: TestbookNotebookClient, notebook_name: str, cell_sour cell_source = cell_source.replace('batch_size=256', 'batch_size=64') cell_source = cell_source.replace('download=True', 'download=False') - package_name = os.environ.get('COMPOSER_PACKAGE_NAME', 'mosaicml') cell_source = cell_source.replace("pip install 'mosaicml", f"pip install '{package_name}") cell_source = cell_source.replace('pip install mosaicml', f'pip install {package_name}')