Skip to content

Commit

Permalink
Fix tensorflow import in dataset wrapper to handle invalid tensorflow…
Browse files Browse the repository at this point in the history
… installation due to unsupported protobuf
  • Loading branch information
imatiach-msft committed Aug 21, 2024
1 parent eb571a5 commit 91d9e38
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 7 deletions.
7 changes: 7 additions & 0 deletions python/ml_wrappers/dataset/dataset_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@
import tensorflow as tf
except ImportError:
module_logger.debug('Could not import tensorflow, required if using a Tensorflow model')
except TypeError as te:
# handle issues with protobuf version mismatch on some environments
module_logger.debug(
"Could not import tensorflow in ml-wrappers due to"
"TypeError, required if using a Tensorflow model. "
"Inner exception: {0}".format(te))


SAMPLED_STRING_ROWS = 10

Expand Down
10 changes: 7 additions & 3 deletions python/ml_wrappers/model/tensorflow_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,14 @@ def is_sequential(model):
:rtype: bool
"""
model_type = str(type(model))
# the sequential namespace changed in tensorflow 2.13
old_sequential_ns = "keras.engine.sequential.Sequential'>"
new_sequential_ns = "keras.src.engine.sequential.Sequential'>"
return (model_type.endswith(old_sequential_ns) or model_type.endswith(new_sequential_ns))
# the sequential namespace changed in tensorflow 2.13
new_sequential_ns1 = "keras.src.engine.sequential.Sequential'>"
# it changed again in tensorflow 2.17
new_sequential_ns2 = "keras.src.models.sequential.Sequential'>"
return any([model_type.endswith(old_sequential_ns),
model_type.endswith(new_sequential_ns1),
model_type.endswith(new_sequential_ns2)])


class WrappedTensorflowModel(object):
Expand Down
2 changes: 1 addition & 1 deletion python/ml_wrappers/version.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name = 'ml_wrappers'
_major = '0'
_minor = '5'
_patch = '5'
_patch = '6'
version = '{}.{}.{}'.format(_major, _minor, _patch)
2 changes: 0 additions & 2 deletions tests/main/test_model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,10 @@ def test_wrap_lightgbm_regression_model(self, housing):
train_regression_model_numpy(create_lightgbm_regressor, housing)
train_regression_model_pandas(create_lightgbm_regressor, housing)

@pytest.mark.skip("Keras API failing in tests with latest tensorflow")
def test_wrap_keras_regression_model(self, housing):
train_regression_model_numpy(create_keras_regressor, housing)
train_regression_model_pandas(create_keras_regressor, housing)

@pytest.mark.skip("Keras API failing in tests with latest tensorflow")
def test_wrap_scikit_keras_regression_model(self, housing):
train_regression_model_numpy(create_scikit_keras_regressor, housing)
train_regression_model_pandas(create_scikit_keras_regressor, housing)
Expand Down
1 change: 0 additions & 1 deletion tests/main/test_tf_model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def test_wrap_scikit_keras_regression_model(self, housing):
train_regression_model_numpy(wrapped_init, housing)
train_regression_model_pandas(wrapped_init, housing)

@pytest.mark.skip("Keras API failing in tests with latest tensorflow")
def test_validate_is_sequential(self):
sequential_layer = tf.keras.Sequential(layers=None, name=None)
assert is_sequential(sequential_layer)
Expand Down

0 comments on commit 91d9e38

Please sign in to comment.