Skip to content

Commit

Permalink
Merge pull request #34 from clear-street/awheelock/sc-59962/add-secre…
Browse files Browse the repository at this point in the history
…t-rotation-support-in-gestalt-attribute

Add ttl and caching support in vault provider
  • Loading branch information
adisunw authored Jun 21, 2023
2 parents e5c0fa1 + ec0aaa2 commit e10bfcb
Show file tree
Hide file tree
Showing 9 changed files with 308 additions and 214 deletions.
148 changes: 78 additions & 70 deletions gestalt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from gestalt.vault import Vault
from gestalt.vault import Vault # noqa: E999
from gestalt.provider import Provider
import os
import glob
import collections.abc as collections

from typing import Dict, List, Type, Union, Optional, MutableMapping, Text, Any
import yaml
import re
import json

from .utils import flatten


def merge_into(
a: Dict[Text, Union[List[Any], Text, int, bool, float]],
Expand Down Expand Up @@ -49,21 +51,6 @@ def __init__(self) -> None:
self.regex_pattern = re.compile(
r"^ref\+([^\+]*)://([^(\+)]+)\#([^\+]+)?$")

def __flatten(
self,
d: MutableMapping[Text, Any],
parent_key: str = '',
sep: str = '.'
) -> Dict[Text, Union[List[Any], Text, int, bool, float]]:
items: List[Any] = []
for k, v in d.items():
new_key = parent_key + sep + k if parent_key else k
if isinstance(v, collections.MutableMapping):
items.extend(self.__flatten(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)

def add_config_path(self, path: str) -> None:
"""Adds a path to read configs from.
Expand Down Expand Up @@ -162,13 +149,10 @@ def build_config(self) -> None:
f'File {f} is marked as ".yaml" but cannot be read as such: {e}'
)

self.__conf_data = self.__flatten(self.__conf_data,
sep=self.__delim_char)
self.__conf_data = flatten(self.__conf_data, sep=self.__delim_char)

self.__parse_dictionary_keys(self.__conf_data)
self.__conf_data = self.__interpolate_keys(self.__conf_data)
self.__parse_dictionary_keys(self.__conf_sets)
self.__conf_sets = self.__interpolate_keys(self.__conf_sets)

def __parse_dictionary_keys(
self, dictionary: Dict[str, Union[List[Any], str, int, bool, float]]
Expand Down Expand Up @@ -208,24 +192,6 @@ def configure_provider(self, provider_name: str,
else:
raise TypeError("Provider provider is not supported")

def __interpolate_keys(
self, dictionary: Dict[str, Union[List[Any], str, int, bool, float]]
) -> Dict[str, Union[List[Any], str, int, bool, float]]:
"""Interpolates the keys in the configuration data.
"""
for path, v in self.__secret_map.items():
m = self.regex_pattern.search(path)
if m is not None:
provider = self.providers[m.group(1)]
for config_key in v:
secret = provider.get(key=config_key,
path=m.group(2),
filter=m.group(3))
dictionary.update({config_key: secret})

dictionary = self.__flatten(dictionary, sep=self.__delim_char)
return dictionary

def auto_env(self) -> None:
"""Auto env provides sane defaults for using environment variables
Expand Down Expand Up @@ -427,37 +393,18 @@ def __get(
raise TypeError(
f'Provided default is of incorrect type {type(default)}, it should be of type {t}'
)
if key in self.__conf_sets:
val = self.__conf_sets[key]
if not isinstance(val, t):
raise TypeError(
f'Given set key is not of type {t}, but of type {type(val)}'
)
return val
if self.__use_env:
e_key = key.upper().replace(self.__delim_char, '_')
if e_key in os.environ:
try:
return t(os.environ[e_key])
except ValueError as e:
raise TypeError(
f'The environment variable {e_key} could not be converted to type {t}: {e}'
)
if key in self.__conf_data:
if not isinstance(self.__conf_data[key], t):
raise TypeError(
f'The requested key of {key} is not of type {t} (it is {type(self.__conf_data[key])})'
)
return self.__conf_data[key]
if default:
return default
if key in self.__conf_defaults:
val = self.__conf_defaults[key]
if not isinstance(val, t):
raise TypeError(
f'Given default set key is not of type {t}, but of type {type(val)}'
)
return val
split_keys = key.split(self.__delim_char)
consider_keys = list()
for split_key in split_keys:
consider_keys.append(split_key)
joined_key = self.__delim_char.join(consider_keys)
config_val = self._get_config_for_key(key=key,
key_to_search=joined_key,
default=default,
object_type=t)
if config_val is not None:
return config_val

raise ValueError(
f'Given key {key} is not in any configuration and no default is provided'
)
Expand Down Expand Up @@ -597,3 +544,64 @@ def dump(self) -> Text:
ret.update(self.__conf_data)
ret.update(self.__conf_sets)
return str(json.dumps(ret, indent=4))

def _get_config_for_key(
self, key: str, key_to_search: str,
default: Optional[Union[str, int, float, bool, List[Any]]],
object_type: Type[Union[str, int, float, bool, List[Any]]]
) -> Optional[Union[str, int, float, bool, List[Any]]]:
if key_to_search in self.__conf_sets:
val = self.__conf_sets[key_to_search]
if not isinstance(val, object_type):
raise TypeError(
f'Given set key is not of type {object_type}, but of type {type(val)}'
)
return val
if self.__use_env:
e_key = key_to_search.upper().replace(self.__delim_char, '_')
if e_key in os.environ:
try:
return object_type(os.environ[e_key])
except ValueError as e:
raise TypeError(
f'The environment variable {e_key} could not be converted to type {object_type}: {e}'
)

if key_to_search in self.__conf_data:
val = self.__conf_data[key_to_search]
for provider in self.providers.values():
if isinstance(val, str) and val.startswith(provider.scheme):
regex_search = self.regex_pattern.search(val)
if regex_search is not None:
path = regex_search.group(2)
filter_ = regex_search.group(3)
remainder_filter = key[len(key_to_search):]
if len(remainder_filter) > 1:
if filter_ is not None:
filter_ = f".{filter_}{remainder_filter}"

else:
filter_ = remainder_filter

interpolated_val = provider.get(key=val,
path=path,
filter=filter_,
sep=self.__delim_char)
break
else:
interpolated_val = val
if not isinstance(interpolated_val, object_type):
raise TypeError(
f'Given set key is not of type {object_type}, but of type {type(interpolated_val)}'
)
return interpolated_val
if default:
return default
if key_to_search in self.__conf_defaults:
val = self.__conf_defaults[key_to_search]
if not isinstance(val, object_type):
raise TypeError(
f'Given default set key is not of type {object_type}, but of type {type(val)}'
)
return val
return None
12 changes: 10 additions & 2 deletions gestalt/provider.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABCMeta, abstractmethod
from typing import Tuple, Dict, Any
from typing import Tuple, Dict, Any, Optional, Union, List


class Provider(metaclass=ABCMeta):
Expand All @@ -15,7 +15,15 @@ def __init__(self, *args: Tuple[Any], **kwargs: Dict[Any, Any]):
pass

@abstractmethod
def get(self, key: str, path: str, filter: str) -> Any:
def get(self, key: str, path: str, filter: str,
sep: Optional[str]) -> Union[str, int, float, bool, List[Any]]:
"""Abstract method to get a value from the provider
"""
pass

@property
@abstractmethod
def scheme(self) -> str:
"""Returns scheme of provider
"""
pass
17 changes: 17 additions & 0 deletions gestalt/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import MutableMapping, Text, Any, Union, Dict, List
import collections.abc as collections


def flatten(
d: MutableMapping[Text, Any],
parent_key: str = '',
sep: str = '.'
) -> Dict[Text, Union[List[Any], Text, int, bool, float]]:
items: List[Any] = []
for k, v in d.items():
new_key = parent_key + sep + k if parent_key else k
if isinstance(v, collections.MutableMapping):
items.extend(flatten(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)
55 changes: 51 additions & 4 deletions gestalt/vault.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from datetime import datetime, timedelta
from time import sleep
from gestalt.provider import Provider
import requests
from jsonpath_ng import parse # type: ignore
from typing import Optional, Tuple, Any
from typing import Optional, Tuple, Any, Dict, Union, List
import hvac # type: ignore
import asyncio
import os
Expand All @@ -18,7 +19,8 @@ def __init__(self,
jwt: Optional[str] = None,
url: Optional[str] = os.environ.get("VAULT_ADDR"),
token: Optional[str] = os.environ.get("VAULT_TOKEN"),
verify: Optional[bool] = True) -> None:
verify: Optional[bool] = True,
scheme: str = "ref+vault://") -> None:
"""Initialized vault client and authenticates vault
Args:
Expand All @@ -28,13 +30,17 @@ def __init__(self,
auth_config (HVAC_ClientAuthentication): authenticates the initialized vault client
with role and jwt string from kubernetes
"""
self._scheme: str = scheme
self.dynamic_token_queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=0)
self.kubes_token_queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=0)

self.vault_client = hvac.Client(url=url,
token=token,
cert=cert,
verify=verify)
self._secret_expiry_times: Dict[str, datetime] = dict()
self._secret_values: Dict[str, Union[str, int, float, bool,
List[Any]]] = dict()

try:
self.vault_client.is_authenticated()
Expand Down Expand Up @@ -73,15 +79,31 @@ def __init__(self,
kubernetes_ttl_renew.start()

@retry(RuntimeError, delay=3, tries=3) # type: ignore
def get(self, key: str, path: str, filter: str) -> Any:
def get(
self,
key: str,
path: str,
filter: str,
sep: Optional[str] = "."
) -> Union[str, int, float, bool, List[Any]]:
"""Gets secret from vault
Args:
key (str): key to get secret from
path (str): path to secret
filter (str): filter to apply to secret
sep (str): delimiter used for flattening
Returns:
secret (str): secret
"""
# if the key has been read before and is not a TTL secret
if key in self._secret_values and key not in self._secret_expiry_times:
return self._secret_values[key]

# if the secret can expire but hasn't expired yet
if key in self._secret_expiry_times and not self._is_secret_expired(
key):
return self._secret_values[key]

try:
response = self.vault_client.read(path)
if response is None:
Expand Down Expand Up @@ -109,7 +131,28 @@ def get(self, key: str, path: str, filter: str) -> Any:
returned_value_from_secret = match[0].value
if returned_value_from_secret == "":
raise RuntimeError("Gestalt Error: Empty secret!")
return returned_value_from_secret

self._secret_values[key] = returned_value_from_secret
if "ttl" in requested_data:
self._set_secrets_ttl(requested_data, key)

return returned_value_from_secret # type: ignore

def _is_secret_expired(self, key: str) -> bool:
now = datetime.now()
secret_expires_dt = self._secret_expiry_times[key]
is_expired = now >= secret_expires_dt
return is_expired

def _set_secrets_ttl(self, requested_data: Dict[str, Any],
key: str) -> None:
last_vault_rotation_str = requested_data["last_vault_rotation"].split(
".")[0] # to the nearest second
last_vault_rotation_dt = datetime.strptime(last_vault_rotation_str,
'%Y-%m-%dT%H:%M:%S')
ttl = requested_data["ttl"]
secret_expires_dt = last_vault_rotation_dt + timedelta(seconds=ttl)
self._secret_expiry_times[key] = secret_expires_dt

async def worker(self, token_queue: Any) -> None:
"""
Expand Down Expand Up @@ -138,3 +181,7 @@ async def worker(self, token_queue: Any) -> None:
"Gestalt Error: Gestalt couldn't connect to Vault")
except Exception as err:
raise RuntimeError(f"Gestalt Error: {err}")

@property
def scheme(self) -> str:
return self._scheme
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def readme():
reqs_list = list(map(lambda x: x.rstrip(), reqs))

setup(name='gestalt-cfg',
version='3.2.0',
version='3.3.0',
description='A sensible configuration library for Python',
long_description=readme(),
long_description_content_type="text/markdown",
Expand Down
Loading

0 comments on commit e10bfcb

Please sign in to comment.