From f13e3f11d88fb0a130fa8adc00633411c57abaa1 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee <34739235+Pipboyguy@users.noreply.github.com> Date: Mon, 25 Nov 2024 23:14:47 +0200 Subject: [PATCH] Support custom Ollama Host (#2044) * Add ollama config test Signed-off-by: Marcel Coetzee * Merge host and port in single var Signed-off-by: Marcel Coetzee * Remove redundant mypy ingore Signed-off-by: Marcel Coetzee * [test](lancedb): add embedding model env var in test_model_providers Signed-off-by: Marcel Coetzee * [fix](test): remove redundant LanceDB Ollama test case Signed-off-by: Marcel Coetzee * Format Signed-off-by: Marcel Coetzee * [docs](lancedb): update embedding model provider and add custom endpoint support Signed-off-by: Marcel Coetzee --------- Signed-off-by: Marcel Coetzee --- .../impl/lancedb/configuration.py | 2 + .../impl/lancedb/lancedb_client.py | 4 +- .../dlt-ecosystem/destinations/lancedb.md | 8 +++- tests/load/lancedb/test_model_providers.py | 44 +++++++++++++++++++ 4 files changed, 55 insertions(+), 3 deletions(-) create mode 100644 tests/load/lancedb/test_model_providers.py diff --git a/dlt/destinations/impl/lancedb/configuration.py b/dlt/destinations/impl/lancedb/configuration.py index 8f6a192bb0..33642268c1 100644 --- a/dlt/destinations/impl/lancedb/configuration.py +++ b/dlt/destinations/impl/lancedb/configuration.py @@ -82,6 +82,8 @@ class LanceDBClientConfiguration(DestinationClientDwhConfiguration): """Embedding provider used for generating embeddings. Default is "cohere". You can find the full list of providers at https://github.com/lancedb/lancedb/tree/main/python/python/lancedb/embeddings as well as https://lancedb.github.io/lancedb/embeddings/default_embedding_functions/.""" + embedding_model_provider_host: Optional[str] = None + """Full host URL with protocol and port (e.g. 'http://localhost:11434'). Uses LanceDB's default if not specified, assuming the provider accepts this parameter.""" embedding_model: str = "embed-english-v3.0" """The model used by the embedding provider for generating embeddings. Check with the embedding provider which options are available. diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 1a3e1a7d34..bb0e12f8ec 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -251,6 +251,7 @@ def __init__( self.dataset_name = self.config.normalize_dataset_name(self.schema) embedding_model_provider = self.config.embedding_model_provider + embedding_model_host = self.config.embedding_model_provider_host # LanceDB doesn't provide a standardized way to set API keys across providers. # Some use ENV variables and others allow passing api key as an argument. @@ -259,12 +260,13 @@ def __init__( embedding_model_provider, self.config.credentials.embedding_model_provider_api_key, ) + self.model_func = self.registry.get(embedding_model_provider).create( name=self.config.embedding_model, max_retries=self.config.options.max_retries, api_key=self.config.credentials.api_key, + **({"host": embedding_model_host} if embedding_model_host else {}), ) - self.vector_field_name = self.config.vector_field_name @property diff --git a/docs/website/docs/dlt-ecosystem/destinations/lancedb.md b/docs/website/docs/dlt-ecosystem/destinations/lancedb.md index b2aec665ab..035f27fe32 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/lancedb.md +++ b/docs/website/docs/dlt-ecosystem/destinations/lancedb.md @@ -33,8 +33,10 @@ Configure the destination in the dlt secrets file located at `~/.dlt/secrets.tom ```toml [destination.lancedb] -embedding_model_provider = "cohere" -embedding_model = "embed-english-v3.0" +embedding_model_provider = "ollama" +embedding_model = "mxbai-embed-large" +embedding_model_provider_host = "http://localhost:11434" # Optional: custom endpoint for providers that support it + [destination.lancedb.credentials] uri = ".lancedb" api_key = "api_key" # API key to connect to LanceDB Cloud. Leave out if you are using LanceDB OSS. @@ -47,6 +49,7 @@ embedding_model_provider_api_key = "embedding_model_provider_api_key" # Not need - The `embedding_model` specifies the model used by the embedding provider for generating embeddings. Check with the embedding provider which options are available. Reference https://lancedb.github.io/lancedb/embeddings/default_embedding_functions/. +- The `embedding_model_provider_host` specifies the full host URL with protocol and port for providers that support custom endpoints (like Ollama). If not specified, the provider's default endpoint will be used. - The `embedding_model_provider_api_key` is the API key for the embedding model provider used to generate embeddings. If you're using a provider that doesn't need authentication, such as Ollama, you don't need to supply this key. :::info Available model providers @@ -61,6 +64,7 @@ embedding_model_provider_api_key = "embedding_model_provider_api_key" # Not need - "sentence-transformers" - "huggingface" - "colbert" +- "ollama" ::: ### Define your data source diff --git a/tests/load/lancedb/test_model_providers.py b/tests/load/lancedb/test_model_providers.py new file mode 100644 index 0000000000..7ad5464fe5 --- /dev/null +++ b/tests/load/lancedb/test_model_providers.py @@ -0,0 +1,44 @@ +""" +Test intricacies and configuration related to each provider. +""" + +import os +from typing import Iterator, Any, Generator + +import pytest +from lancedb import DBConnection # type: ignore +from lancedb.embeddings import EmbeddingFunctionRegistry # type: ignore +from lancedb.table import Table # type: ignore + +import dlt +from dlt.common.configuration import resolve_configuration +from dlt.common.typing import DictStrStr +from dlt.common.utils import uniq_id +from dlt.destinations.impl.lancedb import lancedb_adapter +from dlt.destinations.impl.lancedb.configuration import LanceDBClientConfiguration +from dlt.destinations.impl.lancedb.lancedb_client import LanceDBClient +from tests.load.utils import drop_active_pipeline_data, sequence_generator +from tests.pipeline.utils import assert_load_info + +# Mark all tests as essential, don't remove. +pytestmark = pytest.mark.essential + + +@pytest.fixture(autouse=True) +def drop_lancedb_data() -> Iterator[Any]: + yield + drop_active_pipeline_data() + + +def test_lancedb_ollama_endpoint_configuration() -> None: + os.environ["DESTINATION__LANCEDB__EMBEDDING_MODEL_PROVIDER"] = "ollama" + os.environ["DESTINATION__LANCEDB__EMBEDDING_MODEL"] = "nomic-embed-text" + os.environ["DESTINATION__LANCEDB__EMBEDDING_MODEL_PROVIDER_HOST"] = "http://198.163.194.3:24233" + + config = resolve_configuration( + LanceDBClientConfiguration()._bind_dataset_name(dataset_name="dataset"), + sections=("destination", "lancedb"), + ) + assert config.embedding_model_provider == "ollama" + assert config.embedding_model == "nomic-embed-text" + assert config.embedding_model_provider_host == "http://198.163.194.3:24233"