Skip to content

Commit

Permalink
Hiding the model internal data when retrieving index settings if the …
Browse files Browse the repository at this point in the history
…model is a Marqtune model. (#1018)

Co-authored-by: Jesse Clark <[email protected]>
Co-authored-by: TomHamer <[email protected]>
Co-authored-by: Farshid Zavareh <[email protected]>
  • Loading branch information
4 people authored Nov 5, 2024
1 parent d78a563 commit a31660c
Show file tree
Hide file tree
Showing 2 changed files with 196 additions and 6 deletions.
18 changes: 16 additions & 2 deletions src/marqo/tensor_search/models/index_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def from_marqo_index(cls, marqo_index: core.MarqoIndex) -> "IndexSettings":
treatUrlsAndPointersAsMedia=marqo_index.treat_urls_and_pointers_as_media,
filterStringMaxLength=marqo_index.filter_string_max_length,
model=marqo_index.model.name,
modelProperties=marqo_index.model.properties,
modelProperties=IndexSettings.get_model_properties(marqo_index),
normalizeEmbeddings=marqo_index.normalize_embeddings,
textPreprocessing=marqo_index.text_preprocessing,
imagePreprocessing=marqo_index.image_preprocessing,
Expand All @@ -235,7 +235,7 @@ def from_marqo_index(cls, marqo_index: core.MarqoIndex) -> "IndexSettings":
],
tensorFields=[field.name for field in marqo_index.tensor_fields],
model=marqo_index.model.name,
modelProperties=marqo_index.model.properties,
modelProperties=IndexSettings.get_model_properties(marqo_index),
normalizeEmbeddings=marqo_index.normalize_embeddings,
textPreprocessing=marqo_index.text_preprocessing,
imagePreprocessing=marqo_index.image_preprocessing,
Expand All @@ -250,6 +250,20 @@ def from_marqo_index(cls, marqo_index: core.MarqoIndex) -> "IndexSettings":
else:
raise api_exceptions.InternalError(f"Unknown index type: {type(marqo_index)}")

@classmethod
def get_model_properties(cls, marqo_index):
if marqo_index.model.properties is None:
return None

if marqo_index.model.properties.get('isMarqtuneModel', False):
# Hide all properties except for isMarqtuneModel
marqo_index.model.properties.pop('name', None)
marqo_index.model.properties.pop('dimensions')
marqo_index.model.properties.pop('model_location')
marqo_index.model.properties.pop('type')
marqo_index.model.properties.pop('trustRemoteCode', None)
return marqo_index.model.properties


class IndexSettingsWithName(IndexSettings):
indexName: str
184 changes: 180 additions & 4 deletions tests/core/index_management/test_get_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,43 @@ def setUpClass(cls) -> None:
textPreprocessing=TextPreProcessing(splitLength=3, splitMethod=TextSplitMethod.Word, splitOverlap=1),
).to_marqo_index_request('a' + str(uuid.uuid4()).replace('-', ''))

unstructured_marqtune_index = IndexSettings(
model='marqtune/model-id/release-checkpoint',
modelProperties={
"isMarqtuneModel": True,
"name": "ViT-B-32",
"dimensions": 512,
"model_location": {
"s3": {
"Bucket": "marqtune-public-bucket",
"Key": "marqo-test-open-clip-model/epoch_1.pt",
},
"auth_required": False
},
"type": "open_clip",
},
normalizeEmbeddings=False,
textPreprocessing=TextPreProcessing(splitLength=3, splitMethod=TextSplitMethod.Word, splitOverlap=1),
).to_marqo_index_request('a' + str(uuid.uuid4()).replace('-', ''))

unstructured_non_marqtune_index = IndexSettings(
type=IndexType.Unstructured,
model='marqtune/model-id/release-checkpoint',
modelProperties={
"dimensions": 384,
"model_location": {
"s3": {
"Bucket": "marqtune-public-bucket",
"Key": "marqo-test-hf-model/epoch_1.zip",
},
"auth_required": False
},
"type": "hf",
},
normalizeEmbeddings=False,
textPreprocessing=TextPreProcessing(splitLength=3, splitMethod=TextSplitMethod.Word, splitOverlap=1),
).to_marqo_index_request('a' + str(uuid.uuid4()).replace('-', ''))

structured_custom_index = IndexSettings(
type=IndexType.Structured,
allFields=[
Expand All @@ -45,17 +82,49 @@ def setUpClass(cls) -> None:
textPreprocessing=TextPreProcessing(splitLength=3, splitMethod=TextSplitMethod.Word, splitOverlap=1),
).to_marqo_index_request('a' + str(uuid.uuid4()).replace('-', ''))


structured_marqtune_index = IndexSettings(
type=IndexType.Structured,
allFields=[
FieldRequest(name='field1', type=FieldType.Text),
FieldRequest(name='field2', type=FieldType.Text),
],
tensorFields=[],
model='marqtune/model-id/release-checkpoint',
modelProperties={
"isMarqtuneModel": True,
"dimensions": 384,
"model_location": {
"s3": {
"Bucket": "marqtune-public-bucket",
"Key": "marqo-test-hf-model/epoch_1.zip",
},
"auth_required": False
},
"trustRemoteCode": True,
"type": "hf",
},
normalizeEmbeddings=False,
textPreprocessing=TextPreProcessing(splitLength=3, splitMethod=TextSplitMethod.Word, splitOverlap=1),
).to_marqo_index_request('a' + str(uuid.uuid4()).replace('-', ''))

cls.indexes = cls.create_indexes([
unstructured_default_index,
structured_default_index,
unstructured_custom_index,
structured_custom_index
structured_custom_index,
unstructured_marqtune_index,
structured_marqtune_index,
unstructured_non_marqtune_index
])

cls.unstructured_default_index = cls.indexes[0]
cls.structured_default_index = cls.indexes[1]
cls.unstructured_custom_index = cls.indexes[2]
cls.structured_custom_index = cls.indexes[3]
cls.unstructured_marqtune_index = cls.indexes[4]
cls.structured_marqtune_index = cls.indexes[5]
cls.unstructured_non_marqtune_index = cls.indexes[6]

def setUp(self) -> None:
self.clear_indexes(self.indexes)
Expand All @@ -67,7 +136,6 @@ def setUp(self) -> None:
def tearDown(self) -> None:
self.device_patcher.stop()


def test_no_index(self):
self.assertRaises(IndexNotFoundError, self.config.index_management.get_index, "non-existent-index")

Expand Down Expand Up @@ -140,7 +208,6 @@ def test_default_settings(self):
retrieved_settings = IndexSettings.from_marqo_index(retrieved_index).dict(exclude_none=True, by_alias=True)
self.assertEqual(retrieved_settings, expected_structured_default_settings)


def test_custom_settings(self):
"""adding custom settings to the index should be reflected in the returned output
"""
Expand Down Expand Up @@ -208,4 +275,113 @@ def test_custom_settings(self):
retrieved_index = self.config.index_management.get_index(self.structured_custom_index.name)
retrieved_settings = IndexSettings.from_marqo_index(retrieved_index).dict(exclude_none=True, by_alias=True)
self.assertEqual(retrieved_settings, expected_structured_custom_settings)


def test_index_settings_with_marqtune_model(self):
"""Model name, dimensions, and model location should be hidden if model is marqtune
"""
with self.subTest("Unstructured index with marqtune model"):
expected_unstructured_custom_settings = \
{
'annParameters': {
'parameters': {'efConstruction': 512, 'm': 16},
'spaceType': DistanceMetric.PrenormalizedAngular
},
'filterStringMaxLength': 50,
'imagePreprocessing': {},
'model': 'marqtune/model-id/release-checkpoint',
'modelProperties': {
"isMarqtuneModel": True,
},
'normalizeEmbeddings': False,
'textPreprocessing': {'splitLength': 3,
'splitMethod': TextSplitMethod.Word,
'splitOverlap': 1},
'audioPreprocessing': {'splitLength': 10, 'splitOverlap': 3},
'videoPreprocessing': {'splitLength': 20, 'splitOverlap': 3},
'treatUrlsAndPointersAsImages': False,
'treatUrlsAndPointersAsMedia': False,
'type': IndexType.Unstructured,
'vectorNumericType': VectorNumericType.Float
}

retrieved_index = self.config.index_management.get_index(self.unstructured_marqtune_index.name)
retrieved_settings = IndexSettings.from_marqo_index(retrieved_index).dict(exclude_none=True, by_alias=True)
self.assertEqual(retrieved_settings, expected_unstructured_custom_settings)

with self.subTest("Structured index with marqtune model"):
expected_structured_custom_settings = \
{
'allFields': [
{
'features': [],
'name': 'field1',
'type': FieldType.Text
},
{
'features': [],
'name': 'field2',
'type': FieldType.Text
}
],
'annParameters': {
'parameters': {'efConstruction': 512, 'm': 16},
'spaceType': DistanceMetric.PrenormalizedAngular
},
'imagePreprocessing': {},
'model': 'marqtune/model-id/release-checkpoint',
'modelProperties': {
"isMarqtuneModel": True,
},
'normalizeEmbeddings': False,
'tensorFields': [],
'textPreprocessing': {
'splitLength': 3,
'splitMethod': TextSplitMethod.Word,
'splitOverlap': 1
},
'audioPreprocessing': {'splitLength': 10, 'splitOverlap': 3},
'videoPreprocessing': {'splitLength': 20, 'splitOverlap': 3},
'type': IndexType.Structured,
'vectorNumericType': VectorNumericType.Float
}

retrieved_index = self.config.index_management.get_index(self.structured_marqtune_index.name)
retrieved_settings = IndexSettings.from_marqo_index(retrieved_index).dict(exclude_none=True, by_alias=True)
self.assertEqual(retrieved_settings, expected_structured_custom_settings)

with self.subTest("Unstructured index with non-marqtune model"):
expected_unstructured_custom_settings = \
{
'annParameters': {
'parameters': {'efConstruction': 512, 'm': 16},
'spaceType': DistanceMetric.PrenormalizedAngular
},
'filterStringMaxLength': 50,
'imagePreprocessing': {},
'model': 'marqtune/model-id/release-checkpoint',
'modelProperties': {
'dimensions': 384,
'model_location': {
'auth_required': False,
's3': {
'Bucket': 'marqtune-public-bucket',
'Key': 'marqo-test-hf-model/epoch_1.zip'
}
},
'type': 'hf'
},
'normalizeEmbeddings': False,
'textPreprocessing': {'splitLength': 3,
'splitMethod': TextSplitMethod.Word,
'splitOverlap': 1},
'audioPreprocessing': {'splitLength': 10, 'splitOverlap': 3},
'videoPreprocessing': {'splitLength': 20, 'splitOverlap': 3},
'treatUrlsAndPointersAsImages': False,
'treatUrlsAndPointersAsMedia': False,
'type': IndexType.Unstructured,
'vectorNumericType': VectorNumericType.Float
}

retrieved_index = self.config.index_management.get_index(self.unstructured_non_marqtune_index.name)
retrieved_settings = IndexSettings.from_marqo_index(retrieved_index).dict(exclude_none=True, by_alias=True)
self.assertEqual(retrieved_settings, expected_unstructured_custom_settings)

0 comments on commit a31660c

Please sign in to comment.