Skip to content
This repository has been archived by the owner on Dec 21, 2023. It is now read-only.

Commit

Permalink
fix for minimal pkg not throw (#3199)
Browse files Browse the repository at this point in the history
* fix import check not throw

* python2 import error __str__ uses args instead of message

* tuple forwarding
  • Loading branch information
Guihao Liang authored May 19, 2020
1 parent aa78196 commit a101a96
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 12 deletions.
22 changes: 18 additions & 4 deletions src/python/turicreate/_deps/minimal_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,26 @@ def _minimal_package_import_check(name):
else:
version = __version__

# append more information
e.msg = (
"{}. This is a minimal package for SFrame only, without {} pinned"
# append more information from Import error
emsg = str(e)
emsg = (
"{}.\nThis is a minimal package for SFrame only, without {} pinned"
" as a dependency. You can try install all required packages by installing"
" the full package. For example:\n"
"pip install --force-reinstall turicreate=={}\n"
).format(e.msg, name, version)
).format(emsg, name, version)

if six.PY2:
# __str__ and __repr__ uses `args`.
# only change the first element of args tuple
args = list(e.args)
if args:
args[0] = emsg
else:
args = (emsg,)
e.args = tuple(args)
e.message = emsg
else:
e.msg = emsg

raise e
26 changes: 26 additions & 0 deletions src/python/turicreate/test/test_minimal_pkg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# -*- coding: utf-8 -*-
# Copyright © 2020 Apple Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-3-clause license that can
# be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
#
from __future__ import print_function as _ # noqa
from __future__ import division as _ # noqa
from __future__ import absolute_import as _ # noqa
import turicreate as _tc
import pytest

pytestmark = [pytest.mark.minimal]


@pytest.mark.skipif(not _tc._deps.is_minimal_pkg(), reason="skip when testing full pkg")
class TestMinimalPackage(object):
""" test minimal package for toolkits.
well, other toolkits are too hard to setup
"""

def test_audio_classifier(self):
with pytest.raises(
ImportError, match=r".*pip install --force-reinstall turicreate==.*"
):
_tc.load_audio("./dummy/audio")
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import turicreate as _tc
from turicreate.toolkits._main import ToolkitError as _ToolkitError
from turicreate._deps.minimal_package import _minimal_package_import_check


def load_audio(
Expand Down Expand Up @@ -60,7 +61,7 @@ def load_audio(
>>> audio_path = "~/Documents/myAudioFiles/"
>>> audio_sframe = tc.audio_analysis.load_audio(audio_path, recursive=True)
"""
from scipy.io import wavfile as _wavfile
_scipy = _minimal_package_import_check("scipy")

path = _tc.util._make_internal_url(path)

Expand All @@ -84,7 +85,7 @@ def load_audio(
)
for cur_file_path in all_wav_files:
try:
sample_rate, data = _wavfile.read(cur_file_path)
sample_rate, data = _scipy.io.wavfile.read(cur_file_path)
except Exception as e:
error_string = "Could not read {}: {}".format(cur_file_path, e)
if not ignore_failure:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
import time as _time

from coremltools.models import MLModel
import numpy as _np
from tensorflow import keras as _keras

# Suppresses verbosity to only errors
import turicreate.toolkits._tf_utils as _utils
Expand Down Expand Up @@ -90,11 +87,13 @@ def __init__(self):
self.gpu_policy.start()

model_path = vggish_model_file.get_model_path(format="tensorflow")
self.vggish_model = _keras.models.load_model(model_path)
_tf = _minimal_package_import_check("tensorflow")
self.vggish_model = _tf.keras.models.load_model(model_path)
else:
# Use Core ML
model_path = vggish_model_file.get_model_path(format="coreml")
self.vggish_model = MLModel(model_path)
coremltools = _minimal_package_import_check("coremltools")
self.vggish_model = coremltools.models.MLModel(model_path)

def __del__(self):
if self.mac_ver < (10, 14):
Expand Down Expand Up @@ -197,4 +196,5 @@ def get_spec(self):
else:
vggish_model_file = VGGish()
coreml_model_path = vggish_model_file.get_model_path(format="coreml")
return MLModel(coreml_model_path).get_spec()
coremltools = _minimal_package_import_check("coremltools")
return coremltools.models.MLModel(coreml_model_path).get_spec()

0 comments on commit a101a96

Please sign in to comment.