Skip to content
This repository has been archived by the owner on Dec 21, 2023. It is now read-only.

Commit

Permalink
Fix evaluate().explore() in Python 3
Browse files Browse the repository at this point in the history
Fixes #1866

* Use `list()` to get the same behavior from filter on 2 & 3
* Use `six.iteritems()` to get `iteritems()` in 2 and `items()` in 3
* Removes roc_curve and confusion_matrix from the serialized data (it
  was inadvertently added in #1891, and causes serialization errors)
* Add unit test coverage to prevent this type of issue
  • Loading branch information
Zach Nation committed May 21, 2019
1 parent 97822e2 commit 5d8eb5e
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 9 deletions.
5 changes: 5 additions & 0 deletions src/unity/python/turicreate/data_structures/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,11 @@ def show(self):
"""
from ..visualization._plot import _target

# Suppress visualization output if 'none' target is set
if _target == 'none':
return

try:
img = self._to_pil_image()
try:
Expand Down
6 changes: 6 additions & 0 deletions src/unity/python/turicreate/data_structures/sframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4479,6 +4479,12 @@ def explore(self, title=None):
if sys.platform != 'darwin' and sys.platform != 'linux2' and sys.platform != 'linux':
raise NotImplementedError('Visualization is currently supported only on macOS and Linux.')


# Suppress visualization output if 'none' target is set
from ..visualization._plot import _target
if _target == 'none':
return

path_to_client = _get_client_app_path()

if title is None:
Expand Down
6 changes: 6 additions & 0 deletions src/unity/python/turicreate/test/test_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,12 @@ def test_save_and_load(self):
self.test_list_fields()
print("List fields passed")

def test_evaluate_explore(self):
# Run the explore method and make sure we don't throw an exception.
# This will test the JSON serialization logic.
tc.visualization.set_target('none')
evaluation = self.model.evaluate(data)
evaluation.explore()

class ImageClassifierSqueezeNetTest(ImageClassifierTest):
@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ...visualization import _get_client_app_path, _focus_client_app

import subprocess as __subprocess
import six as _six
from six.moves import _thread

import json as _json
Expand All @@ -28,10 +29,10 @@ def __init__(self, obj = {}):
def _get_eval_json(self):
evaluation_dictionary = dict()

for key, value in self.data.iteritems():
for key, value in _six.iteritems(self.data):
if (isinstance(value, float) or isinstance(value, int)) and _math.isnan(value):
continue
if (key is "test_data"):
if (key is "test_data" or key is "confusion_matrix" or key is "roc_curve"):
continue
evaluation_dictionary[key] = value

Expand All @@ -48,7 +49,12 @@ def _get_eval_json(self):
return str(_json.dumps({ "evaluation_spec": evaluation_dictionary }, allow_nan = False))

def explore(self):
_thread.start_new_thread(_start_process, (self._get_eval_json()+"\n", self.data["test_data"], self, ))
params = (self._get_eval_json()+"\n", self.data["test_data"], self, )
# Suppress visualization output if 'none' target is set
from ...visualization._plot import _target
if _target == 'none':
return
_thread.start_new_thread(_start_process, params)


def _get_data_spec(filters, start, length, row_type, mat_type, sframe, evaluation):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -741,8 +741,8 @@ def hclusterSort(vectors, dist_fn):

excluding_names = [min_dist['from']['name'], min_dist['to']['name']]

vecs = filter(lambda v: v['name'] not in excluding_names, vecs)
distances = filter(lambda dist: (dist['from']['name'] not in excluding_names) and (dist['to']['name'] not in excluding_names), distances)
vecs = list(filter(lambda v: v['name'] not in excluding_names, vecs))
distances = list(filter(lambda dist: (dist['from']['name'] not in excluding_names) and (dist['to']['name'] not in excluding_names), distances))

for v in vecs:
total = 0
Expand Down
14 changes: 10 additions & 4 deletions src/unity/python/turicreate/visualization/_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def _run_cmdline(command):

def set_target(target='auto'):
"""
Sets the target for visualizations launched with the `show` method. If
unset, or if target is not provided, defaults to 'auto'.
Sets the target for visualizations launched with the `show` or `explore`
methods. If unset, or if target is not provided, defaults to 'auto'.
Notes
-----
Expand All @@ -63,10 +63,11 @@ def set_target(target='auto'):
* 'auto': display plot output inline when in Jupyter Notebook, and
otherwise launch a native GUI window.
* 'gui': always launch a native GUI window.
* 'none': prevent all visualizations from being displayed.
"""
global _target
if target not in ['auto', 'gui']:
raise ValueError("Expected target to be one of: 'auto', 'gui'.")
if target not in ['auto', 'gui', 'none']:
raise ValueError("Expected target to be one of: 'auto', 'gui', 'none'.")
_target = target


Expand Down Expand Up @@ -121,6 +122,11 @@ def show(self):
"""
global _target

# Suppress visualization output if 'none' target is set
if _target == 'none':
return

display = False
try:
if _target == 'auto' and \
Expand Down

0 comments on commit 5d8eb5e

Please sign in to comment.