diff --git a/charmhelpers/osplatform.py b/charmhelpers/osplatform.py index 1ace468f7..28e407bb1 100644 --- a/charmhelpers/osplatform.py +++ b/charmhelpers/osplatform.py @@ -9,19 +9,13 @@ def get_platform(): will be returned (which is the name of the module). This string is used to decide which platform module should be imported. """ - # linux_distribution is deprecated and will be removed in Python 3.7 - # Warnings *not* disabled, as we certainly need to fix this. - if hasattr(platform, 'linux_distribution'): - tuple_platform = platform.linux_distribution() - current_platform = tuple_platform[0] - else: - current_platform = _get_platform_from_fs() + current_platform = _get_current_platform() if "Ubuntu" in current_platform: return "ubuntu" elif "CentOS" in current_platform: return "centos" - elif "debian" in current_platform: + elif "debian" in current_platform or "Debian" in current_platform: # Stock Python does not detect Ubuntu and instead returns debian. # Or at least it does in some build environments like Travis CI return "ubuntu" @@ -36,6 +30,24 @@ def get_platform(): .format(current_platform)) +def _get_current_platform(): + """Return the current platform information for the OS. + + Attempts to lookup linux distribution information from the platform + module for releases of python < 3.7. For newer versions of python, + the platform is determined from the /etc/os-release file. + """ + # linux_distribution is deprecated and will be removed in Python 3.7 + # Warnings *not* disabled, as we certainly need to fix this. + if hasattr(platform, 'linux_distribution'): + tuple_platform = platform.linux_distribution() + current_platform = tuple_platform[0] + else: + current_platform = _get_platform_from_fs() + + return current_platform + + def _get_platform_from_fs(): """Get Platform from /etc/os-release.""" with open(os.path.join(os.sep, 'etc', 'os-release')) as fin: @@ -47,3 +59,4 @@ def _get_platform_from_fs(): for k, v in content.items(): content[k] = v.strip('"') return content["NAME"] + diff --git a/tests/test_osplatform.py b/tests/test_osplatform.py new file mode 100644 index 000000000..f76049bd5 --- /dev/null +++ b/tests/test_osplatform.py @@ -0,0 +1,75 @@ +# +# Copyright (C) 2024 Canonical Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +try: + import unittest.mock as mock +except ImportError: + import mock + +from charmhelpers import osplatform + + +class TestPlatform(unittest.TestCase): + + @mock.patch.object(osplatform, "_get_current_platform") + def test_get_platform_ubuntu(self, _platform): + _platform.return_value = "Ubuntu" + self.assertEqual("ubuntu", osplatform.get_platform()) + + @mock.patch.object(osplatform, "_get_current_platform") + def test_get_platform_centos(self, _platform): + _platform.return_value = "CentOS" + self.assertEqual("centos", osplatform.get_platform()) + + @mock.patch.object(osplatform, "_get_current_platform") + def test_get_platform_debian(self, _platform): + _platform.return_value = "debian gnu/linux" + self.assertEqual("ubuntu", osplatform.get_platform()) + + _platform.return_value = "Debian GNU/Linux" + self.assertEqual("ubuntu", osplatform.get_platform()) + + @mock.patch.object(osplatform, "_get_current_platform") + def test_get_platform_elementary(self, _platform): + _platform.return_value = "elementary linux" + self.assertEqual("ubuntu", osplatform.get_platform()) + + @mock.patch.object(osplatform, "_get_current_platform") + def test_get_platform_pop_os(self, _platform): + _platform.return_value = "Pop!_OS" + self.assertEqual("ubuntu", osplatform.get_platform()) + + @mock.patch.object(osplatform, "_get_current_platform") + def test_get_platform_unknown(self, _platform): + _platform.return_value = "crazy custom flavor" + self.assertRaises(RuntimeError, osplatform.get_platform) + + @mock.patch.object(osplatform, "_get_platform_from_fs") + @mock.patch.object(osplatform, "platform") + def test_get_current_platform_module(self, _platform, _platform_from_fs): + _platform.linux_distribution.return_value = ("Ubuntu", "test") + self.assertEqual("Ubuntu", osplatform._get_current_platform()) + _platform_from_fs.assert_not_called() + + @mock.patch.object(osplatform, "_get_platform_from_fs") + @mock.patch.object(osplatform, "platform") + def test_get_current_platform_fs(self, _platform, _platform_from_fs): + # make sure hasattr says False + del _platform.linux_distribution + _platform_from_fs.return_value = "foobar" + self.assertEqual("foobar", osplatform._get_current_platform()) +