Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

backend: Deployments refactor; Add deployment service and fix deployment config setting #831

Open
wants to merge 39 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
080fde9
Deployments refactor; Add deployment service and fix deployment confi…
malexw Nov 6, 2024
54da111
Changes for code review
malexw Nov 14, 2024
6b8025e
Fix a number of integration and unit tests
malexw Nov 19, 2024
87d4367
Merge latest main and fix a few tests
malexw Nov 22, 2024
3775f56
Fix failing chat tests
malexw Nov 26, 2024
9a3436d
Move some tests from unit/routers to integration/routers
malexw Nov 28, 2024
0575161
Fix a few more tests
malexw Nov 29, 2024
ba6a829
Merge branch 'main' into alexw/tlk-801-when-i-fill-in-the-env-variabl…
malexw Nov 29, 2024
e154d30
Fix a few more tests
malexw Nov 29, 2024
a2cb2cc
Fix remainder of broken integration tests
malexw Dec 2, 2024
00867f4
Fix lint issues
malexw Dec 2, 2024
062dd41
Run prettier on Coral
malexw Dec 2, 2024
14bc51d
Remove old, unused model crud helper
malexw Dec 2, 2024
3617b64
Fix failing deployments unit tests
malexw Dec 2, 2024
883064d
Coral fix to account for agent.tools possibly being null
malexw Dec 2, 2024
04787a1
Fix TS styling
malexw Dec 2, 2024
721f447
Provide a dummy Cohere API key during testing
malexw Dec 2, 2024
b6ff9d7
Update Coral to align with latest version of the backend API
malexw Dec 5, 2024
6bdcad3
Fix lint issues in Coral
malexw Dec 5, 2024
dad938b
Merge branch 'main' into alexw/tlk-801-when-i-fill-in-the-env-variabl…
malexw Dec 5, 2024
af5fca0
Last few changes for code review
malexw Dec 5, 2024
a14623e
Merge branch 'main' into alexw/tlk-801-when-i-fill-in-the-env-variabl…
malexw Dec 6, 2024
fb7e0eb
Update generated API in assistants_web
malexw Dec 10, 2024
51c641e
Merge branch 'main' into alexw/tlk-801-when-i-fill-in-the-env-variabl…
malexw Dec 10, 2024
5dbf9a4
Fix assistants_web build
malexw Dec 10, 2024
dc8ab67
Fix backend lint issues
malexw Dec 10, 2024
6b62703
Merge branch 'main' into alexw/tlk-801-when-i-fill-in-the-env-variabl…
malexw Dec 10, 2024
609584d
Simplify validate_deployment_header
malexw Dec 10, 2024
e7e9d48
Don't seed the DB with deployment data, and fix a DeploymentDefinitio…
malexw Dec 14, 2024
b1dcce9
Fix backend lint issues
malexw Dec 14, 2024
81a5e88
Merge branch 'main' into alexw/tlk-801-when-i-fill-in-the-env-variabl…
malexw Dec 14, 2024
99114b1
Fix broken unit tests
malexw Dec 15, 2024
025a455
Skip cohere deployments tests since they're breaking other tests
malexw Dec 15, 2024
28982bb
Fix deployment integration tests
malexw Dec 17, 2024
c67e213
More fixes to deployments integration tests
malexw Dec 17, 2024
faed53e
Fix deployment integration tests
malexw Dec 18, 2024
b8cc3c3
What API key are we using to call Cohere in the tests?
malexw Dec 18, 2024
18d1088
Mock list_models of CoherePlatform model to avoid Cohere API calls
malexw Dec 18, 2024
ccb6a6f
Merge branch 'main' into alexw/tlk-801-when-i-fill-in-the-env-variabl…
malexw Dec 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ run-community-tests:

.PHONY: run-integration-tests
run-integration-tests:
docker compose run --rm --build backend poetry run pytest -c src/backend/pytest_integration.ini src/backend/tests/integration/$(file)
docker compose run --rm --build backend poetry run pytest -c src/backend/pytest_integration.ini src/backend/tests/integration/$(file) -rx

run-tests: run-unit-tests

Expand Down
29 changes: 9 additions & 20 deletions src/backend/chat/custom/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Any

from backend.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS,
get_default_deployment,
)
from backend.database_models.database import get_session
from backend.exceptions import DeploymentNotFoundError
malexw marked this conversation as resolved.
Show resolved Hide resolved
from backend.model_deployments.base import BaseDeployment
from backend.schemas.context import Context
from backend.services import deployment as deployment_service


def get_deployment(name: str, ctx: Context, **kwargs: Any) -> BaseDeployment:
Expand All @@ -16,22 +15,12 @@ def get_deployment(name: str, ctx: Context, **kwargs: Any) -> BaseDeployment:

Returns:
BaseDeployment: Deployment implementation instance based on the deployment name.

Raises:
ValueError: If the deployment is not supported.
"""
kwargs["ctx"] = ctx
deployment = AVAILABLE_MODEL_DEPLOYMENTS.get(name)

# Check provided deployment against config const
if deployment is not None:
return deployment.deployment_class(**kwargs, **deployment.kwargs)

# Fallback to first available deployment
default = get_default_deployment(**kwargs)
if default is not None:
return default
try:
session = next(get_session())
deployment = deployment_service.get_deployment_by_name(session, name, **kwargs)
except DeploymentNotFoundError:
deployment = deployment_service.get_default_deployment(**kwargs)

raise ValueError(
f"Deployment {name} is not supported, and no available deployments were found."
)
return deployment
6 changes: 0 additions & 6 deletions src/backend/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@
from backend.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS as MANAGED_DEPLOYMENTS_SETUP,
)
from community.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS_SETUP,
)


def start():
Expand Down Expand Up @@ -50,9 +47,6 @@ def start():

# SET UP ENVIRONMENT FOR DEPLOYMENTS
all_deployments = MANAGED_DEPLOYMENTS_SETUP.copy()
if use_community_features:
all_deployments.update(COMMUNITY_DEPLOYMENTS_SETUP)

selected_deployments = select_deployments_prompt(all_deployments, secrets)

for deployment in selected_deployments:
Expand Down
4 changes: 2 additions & 2 deletions src/backend/config/default_agent.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import datetime

from backend.config.deployments import ModelDeploymentName
from backend.config.tools import Tool
from backend.model_deployments.cohere_platform import CohereDeployment
from backend.schemas.agent import AgentPublic

DEFAULT_AGENT_ID = "default"
DEFAULT_DEPLOYMENT = ModelDeploymentName.CoherePlatform
DEFAULT_DEPLOYMENT = CohereDeployment.name()
DEFAULT_MODEL = "command-r-plus"

def get_default_agent() -> AgentPublic:
Expand Down
135 changes: 15 additions & 120 deletions src/backend/config/deployments.py
Original file line number Diff line number Diff line change
@@ -1,140 +1,35 @@
from enum import StrEnum

from backend.config.settings import Settings
from backend.model_deployments import (
AzureDeployment,
BedrockDeployment,
CohereDeployment,
SageMakerDeployment,
SingleContainerDeployment,
)
from backend.model_deployments.azure import AZURE_ENV_VARS
from backend.model_deployments.base import BaseDeployment
from backend.model_deployments.bedrock import BEDROCK_ENV_VARS
from backend.model_deployments.cohere_platform import COHERE_ENV_VARS
from backend.model_deployments.sagemaker import SAGE_MAKER_ENV_VARS
from backend.model_deployments.single_container import SC_ENV_VARS
from backend.schemas.deployment import Deployment
from backend.services.logger.utils import LoggerFactory

logger = LoggerFactory().get_logger()


class ModelDeploymentName(StrEnum):
CoherePlatform = "Cohere Platform"
SageMaker = "SageMaker"
Azure = "Azure"
Bedrock = "Bedrock"
SingleContainer = "Single Container"


use_community_features = Settings().get('feature_flags.use_community_features')
ALL_MODEL_DEPLOYMENTS = { d.name(): d for d in BaseDeployment.__subclasses__() }

# TODO names in the map below should not be the display names but ids
ALL_MODEL_DEPLOYMENTS = {
ModelDeploymentName.CoherePlatform: Deployment(
id="cohere_platform",
name=ModelDeploymentName.CoherePlatform,
deployment_class=CohereDeployment,
models=CohereDeployment.list_models(),
is_available=CohereDeployment.is_available(),
env_vars=COHERE_ENV_VARS,
),
ModelDeploymentName.SingleContainer: Deployment(
id="single_container",
name=ModelDeploymentName.SingleContainer,
deployment_class=SingleContainerDeployment,
models=SingleContainerDeployment.list_models(),
is_available=SingleContainerDeployment.is_available(),
env_vars=SC_ENV_VARS,
),
ModelDeploymentName.SageMaker: Deployment(
id="sagemaker",
name=ModelDeploymentName.SageMaker,
deployment_class=SageMakerDeployment,
models=SageMakerDeployment.list_models(),
is_available=SageMakerDeployment.is_available(),
env_vars=SAGE_MAKER_ENV_VARS,
),
ModelDeploymentName.Azure: Deployment(
id="azure",
name=ModelDeploymentName.Azure,
deployment_class=AzureDeployment,
models=AzureDeployment.list_models(),
is_available=AzureDeployment.is_available(),
env_vars=AZURE_ENV_VARS,
),
ModelDeploymentName.Bedrock: Deployment(
id="bedrock",
name=ModelDeploymentName.Bedrock,
deployment_class=BedrockDeployment,
models=BedrockDeployment.list_models(),
is_available=BedrockDeployment.is_available(),
env_vars=BEDROCK_ENV_VARS,
),
}

def get_available_deployments() -> list[type[BaseDeployment]]:
installed_deployments = list(ALL_MODEL_DEPLOYMENTS.values())

def get_available_deployments() -> dict[ModelDeploymentName, Deployment]:
if use_community_features:
if Settings().get("feature_flags.use_community_features"):
try:
from community.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS_SETUP,
)

model_deployments = ALL_MODEL_DEPLOYMENTS.copy()
model_deployments.update(COMMUNITY_DEPLOYMENTS_SETUP)
return model_deployments
except ImportError:
installed_deployments.extend(COMMUNITY_DEPLOYMENTS_SETUP.values())
except ImportError as e:
logger.warning(
event="[Deployments] No available community deployments have been configured"
event="[Deployments] No available community deployments have been configured", ex=e
)

deployments = Settings().get('deployments.enabled_deployments')
if deployments is not None and len(deployments) > 0:
return {
key: value
for key, value in ALL_MODEL_DEPLOYMENTS.items()
if value.id in Settings().get('deployments.enabled_deployments')
}

return ALL_MODEL_DEPLOYMENTS


def get_default_deployment(**kwargs) -> BaseDeployment:
# Fallback to the first available deployment
fallback = None
for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values():
if deployment.is_available:
fallback = deployment.deployment_class(**kwargs)
break

default = Settings().get('deployments.default_deployment')
if default:
return next(
(
v.deployment_class(**kwargs)
for k, v in AVAILABLE_MODEL_DEPLOYMENTS.items()
if v.id == default
),
fallback,
)
else:
return fallback


def find_config_by_deployment_id(deployment_id: str) -> Deployment:
for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values():
if deployment.id == deployment_id:
return deployment
return None


def find_config_by_deployment_name(deployment_name: str) -> Deployment:
for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values():
if deployment.name == deployment_name:
return deployment
return None
enabled_deployment_ids = Settings().get("deployments.enabled_deployments")
if enabled_deployment_ids:
return [
deployment
for deployment in installed_deployments
if deployment.id() in enabled_deployment_ids
]

return installed_deployments

AVAILABLE_MODEL_DEPLOYMENTS = get_available_deployments()
26 changes: 11 additions & 15 deletions src/backend/crud/deployment.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import os

from sqlalchemy.orm import Session

from backend.database_models import Deployment
from backend.model_deployments.utils import class_name_validator
from backend.schemas.deployment import Deployment as DeploymentSchema
from backend.schemas.deployment import DeploymentCreate, DeploymentUpdate
from backend.services.transaction import validate_transaction
from community.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS,
from backend.schemas.deployment import (
DeploymentCreate,
DeploymentDefinition,
DeploymentUpdate,
)
from backend.services.transaction import validate_transaction


@validate_transaction
Expand All @@ -19,7 +18,7 @@ def create_deployment(db: Session, deployment: DeploymentCreate) -> Deployment:

Args:
db (Session): Database session.
deployment (DeploymentSchema): Deployment data to be created.
deployment (DeploymentDefinition): Deployment data to be created.

Returns:
Deployment: Created deployment.
Expand Down Expand Up @@ -132,27 +131,24 @@ def delete_deployment(db: Session, deployment_id: str) -> None:


@validate_transaction
def create_deployment_by_config(db: Session, deployment_config: DeploymentSchema) -> Deployment:
def create_deployment_by_config(db: Session, deployment_config: DeploymentDefinition) -> Deployment:
"""
malexw marked this conversation as resolved.
Show resolved Hide resolved
Create a new deployment by config.

Args:
db (Session): Database session.
deployment (str): Deployment data to be created.
deployment_config (DeploymentSchema): Deployment config.
deployment_config (DeploymentDefinition): Deployment config.

Returns:
Deployment: Created deployment.
"""
deployment = Deployment(
name=deployment_config.name,
description="",
default_deployment_config= {
env_var: os.environ.get(env_var, "")
for env_var in deployment_config.env_vars
},
deployment_class_name=deployment_config.deployment_class.__name__,
is_community=deployment_config.name in COMMUNITY_DEPLOYMENTS
default_deployment_config=deployment_config.config,
deployment_class_name=deployment_config.class_name,
is_community=deployment_config.is_community,
)
db.add(deployment)
db.commit()
Expand Down
22 changes: 12 additions & 10 deletions src/backend/crud/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from sqlalchemy.orm import Session

from backend.database_models import Deployment
from backend.database_models.model import Model
from backend.schemas.deployment import Deployment as DeploymentSchema
from backend.schemas.deployment import DeploymentDefinition
from backend.schemas.model import ModelCreate, ModelUpdate
from backend.services.logger.utils import LoggerFactory
from backend.services.transaction import validate_transaction

logger = LoggerFactory().get_logger()


@validate_transaction
def create_model(db: Session, model: ModelCreate) -> Model:
Expand Down Expand Up @@ -127,29 +129,29 @@ def delete_model(db: Session, model_id: str) -> None:
db.commit()


def create_model_by_config(db: Session, deployment: Deployment, deployment_config: DeploymentSchema, model: str) -> Model:
def create_model_by_config(db: Session, deployment_config: DeploymentDefinition, deployment_id: str, model: str | None) -> Model:
"""
Create a new model by config if present

Args:
db (Session): Database session.
deployment (Deployment): Deployment data.
deployment_config (DeploymentSchema): Deployment config data.
model (str): Model data.
deployment_config (DeploymentDefinition): A deployment definition for any kind of deployment.
deployment_id (DeploymentDefinition): Deployment ID for a deployment from the DB.
model (str): Optional model name that should have its data returned from this call.

Returns:
Model: Created model.
"""
deployment_config_models = deployment_config.models
deployment_db_models = get_models_by_deployment_id(db, deployment.id)
logger.debug(event="create_model_by_config", deployment_models=deployment_config.models, deployment_id=deployment_id, model=model)
deployment_db_models = get_models_by_deployment_id(db, deployment_id)
model_to_return = None
for deployment_config_model in deployment_config_models:
for deployment_config_model in deployment_config.models:
model_in_db = any(record.name == deployment_config_model for record in deployment_db_models)
if not model_in_db:
new_model = Model(
name=deployment_config_model,
cohere_name=deployment_config_model,
deployment_id=deployment.id,
deployment_id=deployment_id,
)
db.add(new_model)
db.commit()
Expand Down
Loading
Loading