forked from NVIDIA-AI-Blueprints/vulnerability-analysis
-
Notifications
You must be signed in to change notification settings - Fork 0
/
embedding_loader.py
86 lines (71 loc) · 3.3 KB
/
embedding_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
import typing
if typing.TYPE_CHECKING:
from langchain.embeddings.huggingface import HuggingFaceEmbeddings # nocover
from langchain.embeddings.openai import OpenAIEmbeddings # nocover
from langchain_core.embeddings import Embeddings # nocover
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings # nocover
EmbeddingTypes = typing.Literal["huggingface", "openai", "nim"]
class EmbeddingLoader:
@typing.overload
@staticmethod
def create(type: typing.Literal["openai"], *embedding_args, **embedding_kwargs) -> "OpenAIEmbeddings":
pass
@typing.overload
@staticmethod
def create(type: typing.Literal["huggingface"], *embedding_args, **embedding_kwargs) -> "HuggingFaceEmbeddings":
pass
@typing.overload
@staticmethod
def create(type: typing.Literal["nim"], *embedding_args, **embedding_kwargs) -> "NVIDIAEmbeddings":
pass
@typing.overload
@staticmethod
def create(type: str, *embedding_args, **embedding_kwargs) -> "Embeddings":
pass
@staticmethod
def create(type: str | EmbeddingTypes, *embedding_args, **embedding_kwargs) -> "Embeddings":
"""
Returns a LangChaing embedding class from a properly defined config.
Parameters
----------
embedding_type : str
The embedding model to create
embedding_args: dict[str, typing.Any]
Additional keyword arguments to pass to the model
"""
match type.lower():
case "openai":
module_name = "langchain.embeddings.openai"
module_class = "OpenAIEmbeddings"
case "huggingface":
module_name = "langchain.embeddings.huggingface"
module_class = "HuggingFaceEmbeddings"
case "nim":
module_name = "langchain_nvidia_ai_endpoints"
module_class = "NVIDIAEmbeddings"
if "api_key" not in embedding_kwargs or embedding_kwargs["api_key"] is None:
embedding_kwargs["api_key"] = os.environ.get("NVIDIA_API_KEY", None)
if "base_url" not in embedding_kwargs or embedding_kwargs["base_url"] is None:
# Use environment variable if available, otherwise don't set `base_url` to use LangChain default
if "NIM_EMBED_BASE_URL" in os.environ:
embedding_kwargs["base_url"] = os.environ.get("NIM_EMBED_BASE_URL")
module = importlib.import_module(module_name)
class_ = getattr(module, module_class)
instance = class_(*embedding_args, **embedding_kwargs)
return instance