Skip to content

Commit

Permalink
Fixes treatment of inf and nan value in experiment params
Browse files Browse the repository at this point in the history
Fixes treatment of inf and nan value in experiment params
  • Loading branch information
PiotrJander authored Oct 1, 2020
2 parents 1a41a52 + 769a10f commit f783833
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
25 changes: 21 additions & 4 deletions neptune/internal/backends/hosted_neptune_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pylint: disable=too-many-lines

import io
import json
import logging
import math
import os
import platform
import socket
Expand Down Expand Up @@ -52,7 +57,7 @@
from neptune.notebook import Notebook
from neptune.oauth import NeptuneAuthenticator
from neptune.projects import Project
from neptune.utils import is_float, with_api_exceptions_handler, update_session_proxies
from neptune.utils import with_api_exceptions_handler, update_session_proxies
from neptune.constants import ANONYMOUS, ANONYMOUS_API_TOKEN

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -770,19 +775,31 @@ def _get_all_items(get_portion, step):

return items

def _get_parameter_with_type(self, parameter):
string_type = 'string'
double_type = 'double'
if isinstance(parameter, bool):
return (string_type, str(parameter))
elif isinstance(parameter, float) or isinstance(parameter, int):
if math.isinf(parameter) or math.isnan(parameter):
return (string_type, json.dumps(parameter))
else:
return (double_type, str(parameter))
else:
return (string_type, str(parameter))

def _convert_to_api_parameters(self, raw_params):
Parameter = self.backend_swagger_client.get_model('Parameter')

params = []
for name, value in raw_params.items():
parameter_type = 'double' if is_float(str(value)) and not isinstance(value, six.string_types) else 'string'

(parameter_type, string_value) = self._get_parameter_with_type(value)
params.append(
Parameter(
id=str(uuid.uuid4()),
name=name,
parameterType=parameter_type,
value=str(value)
value=string_value
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ def test_convert_to_api_parameters(self, uuid4, swagger_client_factory):
'bool': False,
'float': 1.23,
'int': int(12),
'inf': float('inf'),
'-inf': float('-inf'),
'nan': float('nan'),
'list': [123, 'abc', ['def']],
'obj': some_object
})
Expand All @@ -108,6 +111,9 @@ def test_convert_to_api_parameters(self, uuid4, swagger_client_factory):
ApiParameter(id=some_uuid, name='bool', parameterType='string', value='False'),
ApiParameter(id=some_uuid, name='float', parameterType='double', value='1.23'),
ApiParameter(id=some_uuid, name='int', parameterType='double', value='12'),
ApiParameter(id=some_uuid, name='inf', parameterType='string', value='Infinity'),
ApiParameter(id=some_uuid, name='-inf', parameterType='string', value='-Infinity'),
ApiParameter(id=some_uuid, name='nan', parameterType='string', value='NaN'),
ApiParameter(id=some_uuid, name='list', parameterType='string', value="[123, 'abc', ['def']]"),
ApiParameter(id=some_uuid, name='obj', parameterType='string', value=str(some_object))
}
Expand Down

0 comments on commit f783833

Please sign in to comment.