From 4ec07731617f66c5522f2dbddf0b44f0cb6629e3 Mon Sep 17 00:00:00 2001 From: jtyoung84 <104453205+jtyoung84@users.noreply.github.com> Date: Thu, 4 Jan 2024 12:55:02 -0800 Subject: [PATCH 1/4] feat: upgrades to pydantic version 2 --- pyproject.toml | 3 +- src/aind_codeocean_api/credentials.py | 351 +++++++++++++++----------- tests/test_credentials.py | 31 ++- 3 files changed, 224 insertions(+), 161 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8896b28..df2e2e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,8 @@ dynamic = ["version"] dependencies = [ 'requests', - 'pydantic<2.0', + 'pydantic>2.0', + 'pydantic-settings>2.0', 'boto3' ] diff --git a/src/aind_codeocean_api/credentials.py b/src/aind_codeocean_api/credentials.py index 51f2c8e..ecf11cd 100644 --- a/src/aind_codeocean_api/credentials.py +++ b/src/aind_codeocean_api/credentials.py @@ -1,32 +1,119 @@ """Basic CodeOcean Credentials Handling.""" +import functools import json import os +from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple, Type, Union import boto3 -from pydantic import BaseSettings, Field, SecretStr, validator -from pydantic.env_settings import ( +from pydantic import Field, SecretStr, field_validator +from pydantic.fields import FieldInfo +from pydantic_settings import ( + BaseSettings, EnvSettingsSource, InitSettingsSource, - SecretsSettingsSource, + PydanticBaseSettingsSource, + SettingsConfigDict, ) -# Small helper function to get an aws secret -def get_secret(secret_name: str) -> dict: - """ - Retrieves a secret from AWS Secrets Manager. +class JsonConfigSettingsSource(PydanticBaseSettingsSource, ABC): + """Abstract base class for settings that parse json""" - param secret_name: The name of the secret to retrieve. - """ - # Create a Secrets Manager client - client = boto3.client("secretsmanager") - try: - response = client.get_secret_value(SecretId=secret_name) - finally: - client.close() - return json.loads(response["SecretString"]) + def __init__(self, settings_cls, config_file_location): + """ + Class constructor for generic settings source that parses json + Parameters + ---------- + settings_cls + Required for parent init + config_file_location + Location of json contents to parse + """ + self.config_file_location = config_file_location + super().__init__(settings_cls) + + @abstractmethod + def _retrieve_contents(self) -> Dict[str, Any]: + """Retrieve contents from config_file_location""" + + @functools.cached_property + def _json_contents(self): + """Cache contents to a property to avoid re-downloading.""" + contents = self._retrieve_contents() + return contents + + def get_field_value( + self, field: FieldInfo, field_name: str + ) -> Tuple[Any, str, bool]: + """This function needs to be implemented for every + PydanticBaseSettingsSource""" + file_content_json = self._json_contents + field_value = file_content_json.get(field_name) + return field_value, field_name, False + + def prepare_field_value( + self, + field_name: str, + field: FieldInfo, + value: Any, + value_is_complex: bool, + ) -> Any: + """This function needs to be implemented for every + PydanticBaseSettingsSource""" + return value + + def __call__(self) -> Dict[str, Any]: + """This function needs to be implemented for every + PydanticBaseSettingsSource""" + d: Dict[str, Any] = {} + + for field_name, field in self.settings_cls.model_fields.items(): + field_value, field_key, value_is_complex = self.get_field_value( + field, field_name + ) + field_value = self.prepare_field_value( + field_name, field, field_value, value_is_complex + ) + if field_value is not None: + d[field_key] = field_value + + return d + + +class ConfigFileSettingsSource(JsonConfigSettingsSource): + """Class that parses from a local json file.""" + + def _retrieve_contents(self) -> Dict[str, Any]: + """Retrieve contents from config_file_location""" + with open(self.config_file_location, "r") as f: + contents = json.load(f) + return contents + + +class AWSConfigSettingsSource(JsonConfigSettingsSource): + """Class that parses from aws secrets manager.""" + + @staticmethod + def _get_secret(secret_name: str) -> dict: + """ + Retrieves a secret from AWS Secrets Manager. + + param secret_name: The name of the secret to retrieve. + """ + # Create a Secrets Manager client + client = boto3.client("secretsmanager") + try: + response = client.get_secret_value(SecretId=secret_name) + finally: + client.close() + return json.loads(response["SecretString"]) + + def _retrieve_contents(self) -> Dict[str, Any]: + """Retrieve contents from config_file_location""" + credentials_from_aws = self._get_secret(self.config_file_location) + return credentials_from_aws class CodeOceanCredentials(BaseSettings): @@ -34,13 +121,17 @@ class CodeOceanCredentials(BaseSettings): credentials through a class constructor, environment variables, a config file, or to pull them from aws secrets manager.""" + model_config = SettingsConfigDict(env_prefix="CODEOCEAN_") + aws_secrets_name: Optional[str] = Field( default=None, repr=False, description="Optionally pull credentials from aws secrets manager.", ) config_file: Optional[Path] = Field( - default=None, + default_factory=lambda: Path(os.path.expanduser("~")) + / ".codeocean" + / "credentials.json", repr=False, description="Optionally pull credentials from local config file.", ) @@ -60,163 +151,131 @@ class CodeOceanCredentials(BaseSettings): env_prefix="codeocean_", ) - @classmethod - def default_config_file_path(cls): - """Return the default config file path.""" - return cls.Config.secrets_dir - - @validator("domain", pre=True) - def _strip_trailing_slash(cls, input_domain): - """Strips trailing slash from domain.""" - return input_domain.strip("/") + @field_validator("domain") + def _strip_trailing_slash(cls, v): + """Strips the trailing slash from the domain. For example, if the user + inputs: https://acmecorp.codeocean.com/, then it will be changed to: + https://acmecorp.codeocean.com""" + return v.strip("/") - class Config: - """This class will add custom sourcing from aws.""" + @classmethod + def settings_customise_sources( + cls, + settings_cls: Type[BaseSettings], + init_settings: InitSettingsSource, + env_settings: EnvSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, + ) -> Tuple[PydanticBaseSettingsSource, ...]: + """ + Method to pull configs from a variety sources, such as a file or aws. + Arguments are required and set by pydantic. + Parameters + ---------- + settings_cls : Type[BaseSettings] + Top level class. Model fields can be pulled from this. + init_settings : InitSettingsSource + The settings in the init arguments. + env_settings : EnvSettingsSource + The settings pulled from environment variables. + dotenv_settings : PydanticBaseSettingsSource + Settings from .env files. Currently not supported. + file_secret_settings : PydanticBaseSettingsSource + Settings from secret files such as used in Docker. Currently not + supported. - # Prefix to append to env vars - env_prefix = "CODEOCEAN_" + Returns + ------- + Tuple[PydanticBaseSettingsSource, ...] - # Default location of the config file - secrets_dir = ( - Path(os.path.expanduser("~")) / ".codeocean" / "credentials.json" - ) + """ + init_file_path = init_settings.init_kwargs.get("config_file") + default_file_path = settings_cls.model_fields[ + "config_file" + ].default_factory() + aws_secrets_path = init_settings.init_kwargs.get("aws_secrets_name") + default_file_exists = os.path.isfile(default_file_path) - @staticmethod - def settings_from_config_file(config_file: Optional[Path]): - """ - Curried function that returns a function to retrieve creds from - a file. - Parameters - ---------- - config_file : Optional[Path] - Location of json file to retrieve the creds from. - - Returns - ------- - A function that retrieves the credentials. - - """ - - def set_settings(_: BaseSettings) -> Dict[str, Any]: - """ - A simple settings source that loads from a file - """ - if config_file is None or not os.path.isfile(config_file): - return {} - else: - with open(config_file, "r") as f: - contents = json.load(f) - return contents - - return set_settings - - @staticmethod - def settings_from_aws(secrets_name: Optional[str]): - """ - Curried function that returns a function to retrieve creds from aws - Parameters - ---------- - secrets_name : Optional[str] - Name of the credentials we wish to retrieve - Returns - ------- - A function that retrieves the credentials. - """ - - def set_settings(_: BaseSettings) -> Dict[str, Any]: - """ - A simple settings source that loads from aws secrets manager - """ - credentials_from_aws = get_secret(secrets_name) - return credentials_from_aws - - return set_settings - - @classmethod - def customise_sources( - cls, - init_settings: InitSettingsSource, - env_settings: EnvSettingsSource, - file_secret_settings: SecretsSettingsSource, - ): - """Class method to return custom sources.""" - - # Check if aws_secrets_name is defined during instantiation - aws_secrets_name = init_settings.init_kwargs.get( - "aws_secrets_name" + # If user defines aws secrets, create creds from there + if aws_secrets_path is not None: + return ( + init_settings, + AWSConfigSettingsSource(settings_cls, aws_secrets_path), ) - # Check if config_file is defined during instantiation - config_file = init_settings.init_kwargs.get("config_file") - domain = init_settings.init_kwargs.get("domain") - token = init_settings.init_kwargs.get("token") - env_domain = os.getenv((cls.env_prefix + "domain").upper()) - env_token = os.getenv((cls.env_prefix + "token").upper()) - - # If user inputs aws_secrets_name, ignore all other settings - if aws_secrets_name: - return ( - init_settings, - cls.settings_from_aws(secrets_name=aws_secrets_name), - ) - # If a user defines a config_file, ignore non-init settings - elif config_file is not None: - return ( - init_settings, - cls.settings_from_config_file(config_file), - ) - # If a user attempts to construct object using init args and env - # vars, ignore looking for a default config file - elif (domain or env_domain) and (token or env_token): - return ( - init_settings, - env_settings, - ) - # Otherwise, attempt to pull creds from default creds file location - else: - return ( - cls.settings_from_config_file( - file_secret_settings.secrets_dir - ), + # else, if user defines config file path, create creds from there + elif init_file_path is not None: + # raise an exception if config file does not exist + if os.path.isfile(init_file_path) is False: + raise FileNotFoundError( + f"{init_file_path} is defined, but file not found." ) + return ( + init_settings, + ConfigFileSettingsSource(settings_cls, init_file_path), + ) + # If default file exists, create creds from init, env, and then there + elif default_file_exists: + return ( + init_settings, + env_settings, + ConfigFileSettingsSource(settings_cls, default_file_path), + ) + # Otherwise, create creds from init and env + else: + return ( + init_settings, + env_settings, + ) - def save_credentials_to_file(self, output_path: Optional[Path] = None): - """Saves domain and token to the file defined in config_file field or - the default_config_file_path if config_file is None.""" + def save_credentials_to_file(self, output_path: Union[Path, str]) -> None: + """ + Saves domain and token to a file. + Parameters + ---------- + output_path : Union[Path, str] + Location to write to - # Use default path if config_file is None - if output_path is not None: - out_path = output_path - elif self.config_file is not None: - out_path = self.config_file + Returns + ------- + None + Save contents to a file + + """ + + if isinstance(output_path, str): + out_path = Path(output_path) else: - out_path = self.default_config_file_path() + out_path = output_path out_path.parent.mkdir(parents=True, exist_ok=True) - with open(out_path, "w+") as output: + with open(out_path, "w+") as output_file: json.dump( { "domain": self.domain, "token": self.token.get_secret_value(), }, - output, + output_file, indent=4, ) def create_config_file(): """Main method to create a config file from user inputs""" + default_file_path = CodeOceanCredentials.model_fields[ + "config_file" + ].default_factory() # Prompt user user_input_file_path = ( input( f"Save to (Leave blank to save to default location" - f" {CodeOceanCredentials.default_config_file_path()}): " + f" {default_file_path}): " ) - or None + or default_file_path ) domain = input("Domain (e.g. https://acmecorp.codeocean.com): ") token = input("API Token: ") - CodeOceanCredentials( - domain=domain, token=token, config_file=user_input_file_path - ).save_credentials_to_file() + CodeOceanCredentials(domain=domain, token=token).save_credentials_to_file( + user_input_file_path + ) if __name__ == "__main__": diff --git a/tests/test_credentials.py b/tests/test_credentials.py index d3e8ae4..d5cd51e 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -8,9 +8,9 @@ from botocore.exceptions import ClientError from aind_codeocean_api.credentials import ( + AWSConfigSettingsSource, CodeOceanCredentials, create_config_file, - get_secret, ) TEST_DIR = Path(os.path.dirname(os.path.realpath(__file__))) @@ -67,7 +67,9 @@ def test_env_vars(self): self.assertEqual("123-abc", creds_combo_3.token.get_secret_value()) @patch.dict(os.environ, EXAMPLE_ENV_VARS, clear=True) - @patch("aind_codeocean_api.credentials.get_secret") + @patch( + "aind_codeocean_api.credentials.AWSConfigSettingsSource._get_secret" + ) def test_from_aws(self, mock_get_secret: MagicMock): """Tests pulling credentials from aws secrets manager.""" @@ -82,7 +84,9 @@ def test_from_aws(self, mock_get_secret: MagicMock): self.assertEqual("tkn", creds_from_aws2.token.get_secret_value()) @patch.dict(os.environ, EXAMPLE_ENV_VARS, clear=True) - @patch("aind_codeocean_api.credentials.get_secret") + @patch( + "aind_codeocean_api.credentials.AWSConfigSettingsSource._get_secret" + ) def test_from_aws_error(self, mock_get_secret: MagicMock): """Tests situation where an error is raised when attempting to set credentials through aws_secrets_name.""" @@ -135,11 +139,9 @@ def test_from_file_error( # just raise an Exception instead of attempting to fallback with self.assertRaises(Exception) as e: CodeOceanCredentials(config_file="mocked_file", domain="some_url") + err_message = ( - "ValidationError(model='CodeOceanCredentials', " - "errors=[" - "{'loc': ('token',), 'msg': 'field required', 'type':" - " 'value_error.missing'}])" + "FileNotFoundError('mocked_file is defined, but file not found.')" ) self.assertEqual(err_message, repr(e.exception)) @@ -158,7 +160,7 @@ def test_fallback_to_default_file( # Assert the domain set in the file overrides the domain set in init self.assertEqual("http://acmecorp-cfg.com", creds_from_file.domain) self.assertEqual("123-abc-c", creds_from_file.token.get_secret_value()) - default_path = creds_from_file.default_config_file_path() + default_path = creds_from_file.config_file mock_file.assert_called_once_with(default_path, "r") @patch("builtins.open", new_callable=mock_open) @@ -168,10 +170,10 @@ def test_save_to_file(self, mock_mkdir: MagicMock, mock_file: MagicMock): creds = CodeOceanCredentials(domain="domain", token="token") creds2 = CodeOceanCredentials(domain="domain", token="token") creds2.config_file = TEST_DIR / "creds1.json" - default_path = creds.default_config_file_path() - creds.save_credentials_to_file() + default_path = creds.config_file + creds.save_credentials_to_file(creds.config_file) creds.save_credentials_to_file(output_path=(TEST_DIR / "creds2.json")) - creds2.save_credentials_to_file() + creds2.save_credentials_to_file(creds2.config_file) mock_mkdir.assert_has_calls( [ call(parents=True, exist_ok=True), @@ -206,7 +208,7 @@ def test_get_secret_success(self, mock_boto3_client): # Call the get_secret method with a mock secret name secret_name = "my_secret" - secret_value = get_secret(secret_name) + secret_value = AWSConfigSettingsSource._get_secret(secret_name) # Assert that the client was called with the correct arguments mock_boto3_client.assert_called_with("secretsmanager") @@ -237,7 +239,7 @@ def test_get_secret_permission_denied(self, mock_boto3_client): ) # Assert that ClientError is raised with self.assertRaises(ClientError): - get_secret("my_secret") + AWSConfigSettingsSource._get_secret("my_secret") class TestConfigFileCreation(unittest.TestCase): @@ -284,7 +286,8 @@ def test_default_config_file( mock_input.assert_called() mock_mkdir.assert_called_once_with(parents=True, exist_ok=True) mock_file.assert_called_once_with( - CodeOceanCredentials.default_config_file_path(), "w+" + CodeOceanCredentials.model_fields["config_file"].default_factory(), + "w+", ) From 8069422498214d75661babe2892f72233af0ffc6 Mon Sep 17 00:00:00 2001 From: jtyoung84 <104453205+jtyoung84@users.noreply.github.com> Date: Thu, 4 Jan 2024 12:59:11 -0800 Subject: [PATCH 2/4] feat: pydantic ge 2.0 --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index df2e2e2..0561150 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,8 +18,8 @@ dynamic = ["version"] dependencies = [ 'requests', - 'pydantic>2.0', - 'pydantic-settings>2.0', + 'pydantic>=2.0', + 'pydantic-settings>=2.0', 'boto3' ] From 5cb45f78dcb49712875db91375492181aa4fe77b Mon Sep 17 00:00:00 2001 From: jtyoung84 <104453205+jtyoung84@users.noreply.github.com> Date: Thu, 4 Jan 2024 13:24:17 -0800 Subject: [PATCH 3/4] feat: updates doc strings --- src/aind_codeocean_api/credentials.py | 53 ++++++++++++++++++++++++--- tests/test_credentials.py | 4 +- 2 files changed, 49 insertions(+), 8 deletions(-) diff --git a/src/aind_codeocean_api/credentials.py b/src/aind_codeocean_api/credentials.py index ecf11cd..c5cf9c7 100644 --- a/src/aind_codeocean_api/credentials.py +++ b/src/aind_codeocean_api/credentials.py @@ -47,8 +47,23 @@ def _json_contents(self): def get_field_value( self, field: FieldInfo, field_name: str ) -> Tuple[Any, str, bool]: - """This function needs to be implemented for every - PydanticBaseSettingsSource""" + """ + Gets the value, the key for model creation, and a flag to determine + whether value is complex. + Parameters + ---------- + field : FieldInfo + The field + field_name : str + The field name + + Returns + ------- + Tuple[Any, str, bool] + A tuple contains the key, value and a flag to determine whether + value is complex. + + """ file_content_json = self._json_contents field_value = file_content_json.get(field_name) return field_value, field_name, False @@ -60,8 +75,25 @@ def prepare_field_value( value: Any, value_is_complex: bool, ) -> Any: - """This function needs to be implemented for every - PydanticBaseSettingsSource""" + """ + Prepares the value of a field. + Parameters + ---------- + field_name : str + The field name + field : FieldInfo + The field + value : Any + The value of the field that has to be prepared + value_is_complex : bool + A flag to determine whether value is complex + + Returns + ------- + Any + The prepared value + + """ return value def __call__(self) -> Dict[str, Any]: @@ -96,11 +128,20 @@ class AWSConfigSettingsSource(JsonConfigSettingsSource): """Class that parses from aws secrets manager.""" @staticmethod - def _get_secret(secret_name: str) -> dict: + def _get_secret(secret_name: str) -> Dict[str, Any]: """ Retrieves a secret from AWS Secrets Manager. - param secret_name: The name of the secret to retrieve. + Parameters + ---------- + secret_name : str + Secret name as stored in Secrets Manager + + Returns + ------- + Dict[str, Any] + Contents of the secret + """ # Create a Secrets Manager client client = boto3.client("secretsmanager") diff --git a/tests/test_credentials.py b/tests/test_credentials.py index d5cd51e..741a203 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -192,7 +192,7 @@ def test_save_to_file(self, mock_mkdir: MagicMock, mock_file: MagicMock): ) @patch("boto3.client") - def test_get_secret_success(self, mock_boto3_client): + def test_get_secret_success(self, mock_boto3_client: MagicMock): """Tests that secret is retrieved as expected""" # Mock the Secrets Manager client and response mock_client = Mock() @@ -224,7 +224,7 @@ def test_get_secret_success(self, mock_boto3_client): self.assertEqual(secret_value, expected_value) @patch("boto3.client") - def test_get_secret_permission_denied(self, mock_boto3_client): + def test_get_secret_permission_denied(self, mock_boto3_client: MagicMock): """Tests secret retrieval fails with incorrect aws permissions""" mock_boto3_client.return_value.get_secret_value.side_effect = ( ClientError( From 77f97128c0616d303b37ed0493151aa30fe6bc45 Mon Sep 17 00:00:00 2001 From: jtyoung84 <104453205+jtyoung84@users.noreply.github.com> Date: Thu, 4 Jan 2024 13:29:32 -0800 Subject: [PATCH 4/4] feat: updates doc strings --- src/aind_codeocean_api/credentials.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/aind_codeocean_api/credentials.py b/src/aind_codeocean_api/credentials.py index c5cf9c7..f0a90c5 100644 --- a/src/aind_codeocean_api/credentials.py +++ b/src/aind_codeocean_api/credentials.py @@ -97,8 +97,15 @@ def prepare_field_value( return value def __call__(self) -> Dict[str, Any]: - """This function needs to be implemented for every - PydanticBaseSettingsSource""" + """ + Run this when this class is called. Required to be implemented. + + Returns + ------- + Dict[str, Any] + The fields for the settings defined as a dict object. + + """ d: Dict[str, Any] = {} for field_name, field in self.settings_cls.model_fields.items():