diff --git a/neptune/internal/backends/hosted_neptune_backend.py b/neptune/internal/backends/hosted_neptune_backend.py index 6b0992dd0..8ecc39436 100644 --- a/neptune/internal/backends/hosted_neptune_backend.py +++ b/neptune/internal/backends/hosted_neptune_backend.py @@ -13,8 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # + +# pylint: disable=too-many-lines + import io +import json import logging +import math import os import platform import socket @@ -52,7 +57,7 @@ from neptune.notebook import Notebook from neptune.oauth import NeptuneAuthenticator from neptune.projects import Project -from neptune.utils import is_float, with_api_exceptions_handler, update_session_proxies +from neptune.utils import with_api_exceptions_handler, update_session_proxies from neptune.constants import ANONYMOUS, ANONYMOUS_API_TOKEN _logger = logging.getLogger(__name__) @@ -770,19 +775,31 @@ def _get_all_items(get_portion, step): return items + def _get_parameter_with_type(self, parameter): + string_type = 'string' + double_type = 'double' + if isinstance(parameter, bool): + return (string_type, str(parameter)) + elif isinstance(parameter, float) or isinstance(parameter, int): + if math.isinf(parameter) or math.isnan(parameter): + return (string_type, json.dumps(parameter)) + else: + return (double_type, str(parameter)) + else: + return (string_type, str(parameter)) + def _convert_to_api_parameters(self, raw_params): Parameter = self.backend_swagger_client.get_model('Parameter') params = [] for name, value in raw_params.items(): - parameter_type = 'double' if is_float(str(value)) and not isinstance(value, six.string_types) else 'string' - + (parameter_type, string_value) = self._get_parameter_with_type(value) params.append( Parameter( id=str(uuid.uuid4()), name=name, parameterType=parameter_type, - value=str(value) + value=string_value ) ) diff --git a/tests/neptune/internal/backends/test_hosted_neptune_backend.py b/tests/neptune/internal/backends/test_hosted_neptune_backend.py index bb104584c..43fd19bd2 100644 --- a/tests/neptune/internal/backends/test_hosted_neptune_backend.py +++ b/tests/neptune/internal/backends/test_hosted_neptune_backend.py @@ -98,6 +98,9 @@ def test_convert_to_api_parameters(self, uuid4, swagger_client_factory): 'bool': False, 'float': 1.23, 'int': int(12), + 'inf': float('inf'), + '-inf': float('-inf'), + 'nan': float('nan'), 'list': [123, 'abc', ['def']], 'obj': some_object }) @@ -108,6 +111,9 @@ def test_convert_to_api_parameters(self, uuid4, swagger_client_factory): ApiParameter(id=some_uuid, name='bool', parameterType='string', value='False'), ApiParameter(id=some_uuid, name='float', parameterType='double', value='1.23'), ApiParameter(id=some_uuid, name='int', parameterType='double', value='12'), + ApiParameter(id=some_uuid, name='inf', parameterType='string', value='Infinity'), + ApiParameter(id=some_uuid, name='-inf', parameterType='string', value='-Infinity'), + ApiParameter(id=some_uuid, name='nan', parameterType='string', value='NaN'), ApiParameter(id=some_uuid, name='list', parameterType='string', value="[123, 'abc', ['def']]"), ApiParameter(id=some_uuid, name='obj', parameterType='string', value=str(some_object)) }