Skip to content

Commit

Permalink
Fixes treatment of inf and nan value in experiment params
Browse files Browse the repository at this point in the history
  • Loading branch information
PiotrJander committed Sep 30, 2020
1 parent d8731e9 commit 769a10f
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
25 changes: 21 additions & 4 deletions neptune/internal/backends/hosted_neptune_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pylint: disable=too-many-lines

import io
import json
import logging
import math
import os
import platform
import socket
Expand Down Expand Up @@ -52,7 +57,7 @@
from neptune.notebook import Notebook
from neptune.oauth import NeptuneAuthenticator
from neptune.projects import Project
from neptune.utils import is_float, with_api_exceptions_handler, update_session_proxies
from neptune.utils import with_api_exceptions_handler, update_session_proxies
from neptune.constants import ANONYMOUS, ANONYMOUS_API_TOKEN

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -770,19 +775,31 @@ def _get_all_items(get_portion, step):

return items

def _get_parameter_with_type(self, parameter):
string_type = 'string'
double_type = 'double'
if isinstance(parameter, bool):
return (string_type, str(parameter))
elif isinstance(parameter, float) or isinstance(parameter, int):
if math.isinf(parameter) or math.isnan(parameter):
return (string_type, json.dumps(parameter))
else:
return (double_type, str(parameter))
else:
return (string_type, str(parameter))

def _convert_to_api_parameters(self, raw_params):
Parameter = self.backend_swagger_client.get_model('Parameter')

params = []
for name, value in raw_params.items():
parameter_type = 'double' if is_float(str(value)) and not isinstance(value, six.string_types) else 'string'

(parameter_type, string_value) = self._get_parameter_with_type(value)
params.append(
Parameter(
id=str(uuid.uuid4()),
name=name,
parameterType=parameter_type,
value=str(value)
value=string_value
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ def test_convert_to_api_parameters(self, uuid4, swagger_client_factory):
'bool': False,
'float': 1.23,
'int': int(12),
'inf': float('inf'),
'-inf': float('-inf'),
'nan': float('nan'),
'list': [123, 'abc', ['def']],
'obj': some_object
})
Expand All @@ -108,6 +111,9 @@ def test_convert_to_api_parameters(self, uuid4, swagger_client_factory):
ApiParameter(id=some_uuid, name='bool', parameterType='string', value='False'),
ApiParameter(id=some_uuid, name='float', parameterType='double', value='1.23'),
ApiParameter(id=some_uuid, name='int', parameterType='double', value='12'),
ApiParameter(id=some_uuid, name='inf', parameterType='string', value='Infinity'),
ApiParameter(id=some_uuid, name='-inf', parameterType='string', value='-Infinity'),
ApiParameter(id=some_uuid, name='nan', parameterType='string', value='NaN'),
ApiParameter(id=some_uuid, name='list', parameterType='string', value="[123, 'abc', ['def']]"),
ApiParameter(id=some_uuid, name='obj', parameterType='string', value=str(some_object))
}
Expand Down

0 comments on commit 769a10f

Please sign in to comment.