diff --git a/pliers/extractors/__init__.py b/pliers/extractors/__init__.py index df7999bf..9520dd43 100644 --- a/pliers/extractors/__init__.py +++ b/pliers/extractors/__init__.py @@ -63,7 +63,7 @@ from .misc import MetricExtractor from .models import (TensorFlowKerasApplicationExtractor, TFHubImageExtractor, TFHubTextExtractor, - TFHubExtractor) + TFHubExtractor, TFHubAudioExtractor) from .text import (ComplexTextExtractor, DictionaryExtractor, PredefinedDictionaryExtractor, LengthExtractor, NumUniqueWordsExtractor, PartOfSpeechExtractor, @@ -128,6 +128,7 @@ 'TFHubExtractor', 'TFHubImageExtractor', 'TFHubTextExtractor', + 'TFHubAudioExtractor', 'ComplexTextExtractor', 'DictionaryExtractor', 'PredefinedDictionaryExtractor', diff --git a/pliers/extractors/models.py b/pliers/extractors/models.py index e6c1a4d5..a589557c 100644 --- a/pliers/extractors/models.py +++ b/pliers/extractors/models.py @@ -6,7 +6,7 @@ from pliers.extractors.image import ImageExtractor from pliers.extractors.base import Extractor, ExtractorResult from pliers.filters.image import ImageResizingFilter -from pliers.stimuli import ImageStim, TextStim +from pliers.stimuli import ImageStim, TextStim, AudioStim from pliers.stimuli.base import Stim from pliers.support.exceptions import MissingDependencyError from pliers.utils import (attempt_to_import, verify_dependencies, @@ -26,34 +26,34 @@ class TFHubExtractor(Extractor): Args: url_or_path (str): url or path to TFHub model. You can browse models at https://tfhub.dev/. - features (optional): list of labels (for classification) - or other feature names. The number of items must - match the number of features in the output. For example, - if a classification model with 1000 output classes is passed + features (optional): list of feature names matching output dimensions + + For example, for a classification model with 1000 output classes + this must be a list containing 1000 items. (e.g. 
EfficientNet B6, - see https://tfhub.dev/tensorflow/efficientnet/b6/classification/1), - this must be a list containing 1000 items. If a text encoder - outputting 768-dimensional encoding is passed (e.g. base BERT), - this must be a list containing 768 items. Each dimension in the - model output will be returned as a separate feature in the - ExtractorResult. + https://tfhub.dev/tensorflow/efficientnet/b6/classification/1). + Alternatively, the model output can be packed into a single feature (i.e. a vector) by passing a single-element list - (e.g. ['encoding']) or a string. Along the lines of - the previous examples, if a single feature name is - passed here (e.g. if features=['encoding']) for a TFHub model - that outputs a 768-dimensional encoding, the extractor will - return only one feature named 'encoding', which contains the - encoding vector as a 1-d array wrapped in a list. + or a string. For example, for a model that outputs a + 768-dimensional encoding, the value 'encoding' will result + in a 1-d array wrapped in a list named 'encoding'. + If no value is passed, the extractor will automatically compute the number of features in the model output and return an equal number of features in pliers, labeling each feature with a generic prefix + its positional index in the model output (feature_0, feature_1, ... ,feature_n). + + Note that for saved models, the feature names are inferred + from the output signature, but can be overridden. transform_out (optional): function to transform model output for compatibility with extractor result transform_inp (optional): function to transform Stim.data for compatibility with model input format + output_key (str): key of the desired entry in the output + dictionary. Set to None if the output is not a dictionary, + or to output all keys in the dictionary. 
keras_kwargs (dict): arguments to hub.KerasLayer call ''' @@ -62,12 +62,14 @@ class TFHubExtractor(Extractor): def __init__(self, url_or_path, features=None, transform_out=None, transform_inp=None, - keras_kwargs=None): + output_key=None, keras_kwargs=None): verify_dependencies(['tensorflow_hub']) if keras_kwargs is None: keras_kwargs = {} self.keras_kwargs = keras_kwargs + self.output_key = output_key self.model = hub.KerasLayer(url_or_path, **keras_kwargs) + self.url_or_path = url_or_path self.features = features self.transform_out = transform_out @@ -75,11 +77,21 @@ def __init__(self, url_or_path, features=None, super().__init__() def get_feature_names(self, out): + # Manual feature names always take precedence if self.features: return listify(self.features) + # Infer feature names from output else: - return ['feature_' + str(i) - for i in range(out.shape[-1])] + # If dict, use provided output key, or all keys + if isinstance(out, dict): + if self.output_key: + return [self.output_key] + else: + return list(out.keys()) + # Worst case, use generic feature names + else: + return ['feature_' + str(i) + for i in range(out.shape[-1])] def _preprocess(self, stim): if self.transform_inp: @@ -91,17 +103,56 @@ def _preprocess(self, stim): return stim.data def _postprocess(self, out): + # If key is provided, return only that key + if self.output_key: + try: + out = out[self.output_key] + except KeyError: + raise ValueError(f'{self.output_key} is not a valid key. ' + 'Check which keys are available in the output ' + f'at the model URL ({self.url_or_path})') + except (IndexError, TypeError): + raise ValueError(f'Model output is not a dictionary. 
' + 'Try initialize the extractor with output_key=None.') + + # If output is a dict and no output key, return all keys + if isinstance(out, dict): + out = np.vstack(list(out.values())).T + elif isinstance(out, tf.Tensor): + out = out.numpy() + + # Squeeze out all singleton dimensions + out = out.squeeze() + if self.transform_out: out = self.transform_out(out) - return out.numpy().squeeze() + return out + + def _get_timing(self, out, stim): + """ Returns the timing of the output. + Args: + out: output of the model + stim: input stimulus + + Returns: + onsets: onsets of the output + durations: durations of the output + orders: order of the output + """ + + return stim.onset, stim.duration, None def _extract(self, stim): inp = self._preprocess(stim) out = self.model(inp) - out = self._postprocess(out) features = self.get_feature_names(out) + out = self._postprocess(out) + + onsets, durations, orders = self._get_timing(out, stim) + return ExtractorResult(listify(out), stim, self, - features=features) + onsets=onsets, durations=durations, + features=features, orders=orders) def _to_df(self, result): if len(result.features) == 1: @@ -116,41 +167,84 @@ def _to_df(self, result): class TFHubImageExtractor(TFHubExtractor): - ''' TFHub Extractor class for image models + ''' TFHub Extractor class for image models. + + Note that some models may require specific input shapes. + You can reshape inputs using filters, such as ImageResizingFilter + or ImageRescaleFilter. + Args: url_or_path (str): url or path to TFHub model - features (optional): list of labels (for classification) - or other feature names. If not specified, returns - numbered features (feature_0, feature_1, ... ,feature_n) - keras_kwargs (dict): arguments to hub.KerasLayer call + input_dtype (optional): dtype of input data. 
Defaults to tf.float32 ''' _input_type = ImageStim - _log_attributes = ('url_or_path', 'features', 'keras_kwargs') def __init__(self, url_or_path, - features=None, input_dtype=None, - keras_kwargs=None): + **kwargs): self.input_dtype = input_dtype if input_dtype else tf.float32 - if keras_kwargs is None: - keras_kwargs = {} - self.keras_kwargs = keras_kwargs - logging.warning('Some models may require specific input shapes.' - ' Incompatible shapes may raise errors' - ' at extraction. If needed, you can reshape' - ' your input image using ImageResizingFilter, ' - ' and rescale using ImageRescalingFilter') - super().__init__(url_or_path, features, keras_kwargs=keras_kwargs) + super().__init__(url_or_path, **kwargs) def _preprocess(self, stim): x = tf.convert_to_tensor(stim.data, dtype=self.input_dtype) x = tf.expand_dims(x, axis=0) return x +class TFHubAudioExtractor(TFHubExtractor): + + ''' TFHub Extractor class for audio models. + + Note that some models may require a specific sampling frequency. + You can resample inputs using AudioResamplingFilter. + + Args: + url_or_path (str): url or path to TFHub model + input_dtype (optional): dtype of input data. Defaults to tf.float32 + ''' + + _input_type = AudioStim + + def __init__(self, + url_or_path, + input_dtype=None, + **kwargs): + + self.input_dtype = input_dtype if input_dtype else tf.float32 + + super().__init__(url_or_path, **kwargs) + + def _preprocess(self, stim): + x = tf.convert_to_tensor(stim.data, dtype=self.input_dtype) + return x + + def _get_timing(self, out, stim): + """ Returns the timing of the output. + + Assumes model returns a fixed sampling frequency, + and deduces durations and onsets from the sampling frequency. 
+ + Args: + out: output of the model + stim: input stimulus + + Returns: + onsets: onsets of the output + durations: durations of the output + orders: order of the output + """ + + durations = [stim.duration / out.shape[0]] * out.shape[0] + onsets = np.arange(out.shape[0]) * durations[0] + if stim.onset is not None: + onsets += stim.onset + onsets = onsets.tolist() + orders = list(range(len(onsets))) + + return onsets, durations, orders class TFHubTextExtractor(TFHubExtractor): @@ -162,7 +256,11 @@ class TFHubTextExtractor(TFHubExtractor): The number of items must match the number of features in the model output. For example, if a text encoder outputting 768-dimensional encoding is passed - (e.g. base BERT), this must be a list containing 768 items. + output_key (str): key to desired embedding in output + dictionary (see documentation at + https://www.tensorflow.org/hub/common_saved_model_apis/text). + Set to None if the output is not a dictionary, or to + output all keys (e.g. base BERT), this must be a list containing 768 items. Each dimension in the model output will be returned as a separate feature in the ExtractorResult. Alternatively, the model output can be packed into a single @@ -176,7 +274,8 @@ class TFHubTextExtractor(TFHubExtractor): output_key (str): key to desired embedding in output dictionary (see documentation at https://www.tensorflow.org/hub/common_saved_model_apis/text). - Set to None is the output is not a dictionary. + Set to None if the output is not a dictionary, or to + output all keys preprocessor_url_or_path (str): if the model requires preprocessing through another TFHub model, specifies the url or path to the preprocessing module. 
Information on @@ -196,7 +295,7 @@ def __init__(self, preprocessor_kwargs=None, keras_kwargs=None, **kwargs): - super().__init__(url_or_path, features, + super().__init__(url_or_path, features, output_key=output_key, keras_kwargs=keras_kwargs, **kwargs) self.output_key = output_key @@ -217,22 +316,6 @@ def _preprocess(self, stim): self.preprocessor_url_or_path, **self.preprocessor_kwargs) x = preprocessor(x) return x - - def _postprocess(self, out): - if not self.output_key: - return out.numpy().squeeze() - else: - try: - return out[self.output_key].numpy().squeeze() - except KeyError: - raise ValueError(f'{self.output_key} is not a valid key.' - 'Check which keys are available in the output ' - 'embedding dictionary in TFHub docs ' - '(https://www.tensorflow.org/hub/common_saved_model_apis/text)' - f' or at the model URL ({self.url_or_path})') - except (IndexError, TypeError): - raise ValueError(f'Model output is not a dictionary. ' - 'Try initialize the extractor with output_key=None.') class TensorFlowKerasApplicationExtractor(ImageExtractor): diff --git a/pliers/tests/extractors/test_model_extractors.py b/pliers/tests/extractors/test_model_extractors.py index 111c6e7d..0bd2111b 100644 --- a/pliers/tests/extractors/test_model_extractors.py +++ b/pliers/tests/extractors/test_model_extractors.py @@ -9,6 +9,7 @@ from pliers.extractors import (TensorFlowKerasApplicationExtractor, TFHubExtractor, TFHubImageExtractor, + TFHubAudioExtractor, TFHubTextExtractor, BertExtractor, BertSequenceEncodingExtractor, @@ -40,7 +41,7 @@ TOKENIZER_URL = 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2' ELECTRA_URL = 'https://tfhub.dev/google/electra_small/2' SPEECH_URL = 'https://tfhub.dev/google/speech_embedding/1' - +SPICE_URL = 'https://tfhub.dev/google/spice/2' pytestmark = pytest.mark.skipif( environ.get('skip_high_memory', False) == 'true', reason='high memory') @@ -447,3 +448,21 @@ def compute_expected_length(stim, ext): with pytest.raises(ValueError) as err: 
AudiosetLabelExtractor(top_n=10, labels=labels) assert 'Top_n and labels are mutually exclusive' in str(err.value) + +def test_spice_extractor(): + audio_stim = AudioStim(join(AUDIO_DIR, 'homer.wav')) + audio_filter = AudioResamplingFilter(target_sr=16000) + audio_resampled = audio_filter.transform(audio_stim) + + ext = TFHubAudioExtractor(SPICE_URL, keras_kwargs=dict( + signature='serving_default', signature_outputs_as_dict=True)) + r_orig = ext.transform(audio_stim).to_df() + assert r_orig.shape == (74, 6) + + assert r_orig.onset.min() == 0.0 + assert r_orig.duration.min() == r_orig.duration.max() == pytest.approx(0.04594594594594594) + assert r_orig.uncertainty[0] == pytest.approx(0.974131, abs=1e-5) + assert r_orig.pitch[0] == pytest.approx(0.171392, abs=1e-5) + + + \ No newline at end of file