diff --git a/copulas/__init__.py b/copulas/__init__.py index 00ae675c..8a2f9c94 100644 --- a/copulas/__init__.py +++ b/copulas/__init__.py @@ -8,14 +8,14 @@ import contextlib import importlib +import sys +import warnings from copy import deepcopy +from operator import attrgetter import numpy as np import pandas as pd - -from copulas._addons import _find_addons - -_find_addons(group='copulas_modules', parent_globals=globals()) +from pkg_resources import iter_entry_points EPSILON = np.finfo(np.float32).eps @@ -262,3 +262,71 @@ def decorated(self, X, *args, **kwargs): return function(self, X, *args, **kwargs) return decorated + + +def _get_addon_target(addon_path_name): + """Find the target object for the add-on. + + Args: + addon_path_name (str): + The add-on's name. The add-on's name should be the full path of valid Python + identifiers (i.e. importable.module:object.attr). + + Returns: + tuple: + * object: + The base module or object the add-on should be added to. + * str: + The name the add-on should be added to under the module or object. + """ + module_path, _, object_path = addon_path_name.partition(':') + module_path = module_path.split('.') + + if module_path[0] != __name__: + msg = f"expected base module to be '{__name__}', found '{module_path[0]}'" + raise AttributeError(msg) + + target_base = sys.modules[__name__] + for submodule in module_path[1:-1]: + target_base = getattr(target_base, submodule) + + addon_name = module_path[-1] + if object_path: + if len(module_path) > 1 and not hasattr(target_base, module_path[-1]): + msg = f"cannot add '{object_path}' to unknown submodule '{'.'.join(module_path)}'" + raise AttributeError(msg) + + if len(module_path) > 1: + target_base = getattr(target_base, module_path[-1]) + + split_object = object_path.split('.') + addon_name = split_object[-1] + + if len(split_object) > 1: + target_base = attrgetter('.'.join(split_object[:-1]))(target_base) + + return target_base, addon_name + + +def _find_addons(): + """Find and load all copulas add-ons.""" + group = 'copulas_modules' + for entry_point in iter_entry_points(group=group): + try: + addon = entry_point.load() + except Exception: # pylint: disable=broad-exception-caught + msg = f'Failed to load "{entry_point.name}" from "{entry_point.module_name}".' + warnings.warn(msg) + continue + + try: + addon_target, addon_name = _get_addon_target(entry_point.name) + except AttributeError as error: + msg = f"Failed to set '{entry_point.name}': {error}." + warnings.warn(msg) + continue + + setattr(addon_target, addon_name, addon) + + +_find_addons() diff --git a/copulas/_addons.py b/copulas/_addons.py deleted file mode 100644 index 3f2f1761..00000000 --- a/copulas/_addons.py +++ /dev/null @@ -1,26 +0,0 @@ -"""Copulas add-ons functionality.""" -import warnings - -from pkg_resources import iter_entry_points - - -def _find_addons(group, parent_globals): - """Find and load add-ons based on the given group. - - Args: - group (str): - The name of the entry points group to load. - parent_globals (dict): - The caller's global scope. Modules will be added - to the parent's global scope through their name. - """ - for entry_point in iter_entry_points(group=group): - try: - module = entry_point.load() - except Exception: - msg = f'Failed to load "{entry_point.name}" from "{entry_point.module}".' - warnings.warn(msg) - continue - - if entry_point.name not in parent_globals: - parent_globals[entry_point.name] = module diff --git a/tests/unit/test___init__.py b/tests/unit/test___init__.py index 761e35cf..4d88f5f0 100644 --- a/tests/unit/test___init__.py +++ b/tests/unit/test___init__.py @@ -1,3 +1,4 @@ +import sys from unittest import TestCase from unittest.mock import MagicMock, call, patch @@ -6,8 +7,10 @@ import pytest from numpy.testing import assert_array_equal +import copulas from copulas import ( - check_valid_values, get_instance, random_state, scalarize, validate_random_state, vectorize) + _find_addons, check_valid_values, get_instance, random_state, scalarize, validate_random_state, + vectorize) from copulas.multivariate import GaussianMultivariate @@ -421,3 +424,152 @@ def test_get_instance_with_kwargs(self): assert not instance.fitted assert isinstance(instance, GaussianMultivariate) assert instance.distribution == 'copulas.univariate.truncnorm.TruncNorm' + + +@pytest.fixture() +def mock_copulas(): + copulas_module = sys.modules['copulas'] + copulas_mock = MagicMock() + sys.modules['copulas'] = copulas_mock + yield copulas_mock + sys.modules['copulas'] = copulas_module + + +@patch.object(copulas, 'iter_entry_points') +def test__find_addons_module(entry_points_mock, mock_copulas): + """Test loading an add-on.""" + # Setup + entry_point = MagicMock() + entry_point.name = 'copulas.submodule.entry_name' + entry_point.load.return_value = 'entry_point' + entry_points_mock.return_value = [entry_point] + + # Run + _find_addons() + + # Assert + entry_points_mock.assert_called_once_with(group='copulas_modules') + assert mock_copulas.submodule.entry_name == 'entry_point' + + +@patch.object(copulas, 'iter_entry_points') +def test__find_addons_object(entry_points_mock, mock_copulas): + """Test loading an add-on.""" + # Setup + entry_point = MagicMock() + entry_point.name = 'copulas.submodule:entry_object.entry_method' + entry_point.load.return_value = 'new_method' + entry_points_mock.return_value = [entry_point] + + # Run + _find_addons() + + # Assert + entry_points_mock.assert_called_once_with(group='copulas_modules') + assert mock_copulas.submodule.entry_object.entry_method == 'new_method' + + +@patch('warnings.warn') +@patch('copulas.iter_entry_points') +def test__find_addons_bad_addon(entry_points_mock, warning_mock): + """Test failing to load an add-on generates a warning.""" + # Setup + def entry_point_error(): + raise ValueError() + + bad_entry_point = MagicMock() + bad_entry_point.name = 'bad_entry_point' + bad_entry_point.module_name = 'bad_module' + bad_entry_point.load.side_effect = entry_point_error + entry_points_mock.return_value = [bad_entry_point] + msg = 'Failed to load "bad_entry_point" from "bad_module".' + + # Run + _find_addons() + + # Assert + entry_points_mock.assert_called_once_with(group='copulas_modules') + warning_mock.assert_called_once_with(msg) + + +@patch('warnings.warn') +@patch('copulas.iter_entry_points') +def test__find_addons_wrong_base(entry_points_mock, warning_mock): + """Test incorrect add-on name generates a warning.""" + # Setup + bad_entry_point = MagicMock() + bad_entry_point.name = 'bad_base.bad_entry_point' + entry_points_mock.return_value = [bad_entry_point] + msg = ( + "Failed to set 'bad_base.bad_entry_point': expected base module to be 'copulas', found " + "'bad_base'." + ) + + # Run + _find_addons() + + # Assert + entry_points_mock.assert_called_once_with(group='copulas_modules') + warning_mock.assert_called_once_with(msg) + + +@patch('warnings.warn') +@patch('copulas.iter_entry_points') +def test__find_addons_missing_submodule(entry_points_mock, warning_mock): + """Test incorrect add-on name generates a warning.""" + # Setup + bad_entry_point = MagicMock() + bad_entry_point.name = 'copulas.missing_submodule.new_submodule' + entry_points_mock.return_value = [bad_entry_point] + msg = ( + "Failed to set 'copulas.missing_submodule.new_submodule': module 'copulas' has no " + "attribute 'missing_submodule'." + ) + + # Run + _find_addons() + + # Assert + entry_points_mock.assert_called_once_with(group='copulas_modules') + warning_mock.assert_called_once_with(msg) + + +@patch('warnings.warn') +@patch('copulas.iter_entry_points') +def test__find_addons_module_and_object(entry_points_mock, warning_mock): + """Test incorrect add-on name generates a warning.""" + # Setup + bad_entry_point = MagicMock() + bad_entry_point.name = 'copulas.missing_submodule:new_object' + entry_points_mock.return_value = [bad_entry_point] + msg = ( + "Failed to set 'copulas.missing_submodule:new_object': cannot add 'new_object' to unknown " + "submodule 'copulas.missing_submodule'." + ) + + # Run + _find_addons() + + # Assert + entry_points_mock.assert_called_once_with(group='copulas_modules') + warning_mock.assert_called_once_with(msg) + + +@patch('warnings.warn') +@patch.object(copulas, 'iter_entry_points') +def test__find_addons_missing_object(entry_points_mock, warning_mock, mock_copulas): + """Test incorrect add-on name generates a warning.""" + # Setup + bad_entry_point = MagicMock() + bad_entry_point.name = 'copulas.submodule:missing_object.new_method' + entry_points_mock.return_value = [bad_entry_point] + msg = ("Failed to set 'copulas.submodule:missing_object.new_method': missing_object.") + + del mock_copulas.submodule.missing_object + + # Run + _find_addons() + + # Assert + entry_points_mock.assert_called_once_with(group='copulas_modules') + warning_mock.assert_called_once_with(msg) diff --git a/tests/unit/test__addons.py b/tests/unit/test__addons.py deleted file mode 100644 index b02a5104..00000000 --- a/tests/unit/test__addons.py +++ /dev/null @@ -1,45 +0,0 @@ -from unittest.mock import Mock, patch - -from copulas._addons import _find_addons - - -@patch('copulas._addons.iter_entry_points') -def test__find_versions(entry_points_mock): - """Test loading an add-on.""" - # Setup - entry_point = Mock() - entry_point.name = 'entry_name' - entry_point.load.return_value = 'entry_point' - entry_points_mock.return_value = [entry_point] - test_dict = {} - - # Run - _find_addons(group='group', parent_globals=test_dict) - - # Assert - entry_points_mock.assert_called_once_with(group='group') - assert test_dict['entry_name'] == 'entry_point' - - -@patch('copulas._addons.warnings.warn') -@patch('copulas._addons.iter_entry_points') -def test__find_versions_bad_addon(entry_points_mock, warning_mock): - """Test failing to load an add-on generates a warning.""" - # Setup - def entry_point_error(): - raise ValueError() - - bad_entry_point = Mock() - bad_entry_point.name = 'bad_entry_point' - bad_entry_point.module = 'bad_module' - bad_entry_point.load.side_effect = entry_point_error - entry_points_mock.return_value = [bad_entry_point] - test_dict = {} - msg = 'Failed to load "bad_entry_point" from "bad_module".' - - # Run - _find_addons(group='group', parent_globals=test_dict) - - # Assert - entry_points_mock.assert_called_once_with(group='group') - warning_mock.assert_called_once_with(msg)