Skip to content
This repository has been archived by the owner on Jun 1, 2023. It is now read-only.

Commit

Permalink
Merge pull request #51 from IdentityPython/develop
Browse files Browse the repository at this point in the history
Configuration refactoring
  • Loading branch information
peppelinux authored Nov 10, 2021
2 parents beaba7c + fc0e9ef commit 1c422a3
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 101 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def run_tests(self):
"Programming Language :: Python :: 3.8",
"Topic :: Software Development :: Libraries :: Python Modules"],
install_requires=[
"cryptojwt>=1.5.0",
"cryptojwt>=1.6.0",
"pyOpenSSL",
"filelock>=3.0.12",
'pyyaml>=5.1.2'
Expand Down
70 changes: 35 additions & 35 deletions src/oidcmsg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
__author__ = "Roland Hedberg"
__version__ = "1.5.0"
__version__ = "1.5.1"

import os
from typing import Dict
Expand Down Expand Up @@ -34,37 +34,37 @@ def proper_path(path):
return path


def add_base_path(conf: Dict[str, str], item_paths: dict, base_path: str):
"""
This is for adding a base path to path specified in a configuration
:param conf: Configuration
:param item_paths: The relative item path
:param base_path: An absolute path to add to the relative
"""
for section, items in item_paths.items():
if section == "":
part = conf
else:
part = conf.get(section)

if part:
if isinstance(items, list):
for attr in items:
_path = part.get(attr)
if _path:
if _path.startswith("/"):
continue
elif _path == "":
part[attr] = "./" + _path
else:
part[attr] = os.path.join(base_path, _path)
elif items is None:
if part.startswith("/"):
continue
elif part == "":
conf[section] = "./"
else:
conf[section] = os.path.join(base_path, part)
else: # Assume items is dictionary like
add_base_path(part, items, base_path)
# def add_base_path(conf: Dict[str, str], item_paths: dict, base_path: str):
# """
# This is for adding a base path to path specified in a configuration
#
# :param conf: Configuration
# :param item_paths: The relative item path
# :param base_path: An absolute path to add to the relative
# """
# for section, items in item_paths.items():
# if section == "":
# part = conf
# else:
# part = conf.get(section)
#
# if part:
# if isinstance(items, list):
# for attr in items:
# _path = part.get(attr)
# if _path:
# if _path.startswith("/"):
# continue
# elif _path == "":
# part[attr] = "./" + _path
# else:
# part[attr] = os.path.join(base_path, _path)
# elif items is None:
# if part.startswith("/"):
# continue
# elif part == "":
# conf[section] = "./"
# else:
# conf[section] = os.path.join(base_path, part)
# else: # Assume items is dictionary like
# add_base_path(part, items, base_path)
167 changes: 124 additions & 43 deletions src/oidcmsg/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from oidcmsg.logging import configure_logging
from oidcmsg.util import load_yaml_config

DEFAULT_FILE_ATTRIBUTE_NAMES = ['server_key', 'server_cert', 'filename', 'template_dir',
'private_path', 'public_path', 'db_file']
DEFAULT_FILE_ATTRIBUTE_NAMES = ['server_key', 'server_cert', 'filename',
'private_path', 'public_path', 'db_file', 'jwks_file']

URIS = ["redirect_uris", 'issuer', 'base_url']
DEFAULT_DIR_ATTRIBUTE_NAMES = ['template_dir']


def lower_or_upper(config, param, default=None):
Expand All @@ -22,17 +22,31 @@ def lower_or_upper(config, param, default=None):
return res


def add_base_path(conf: dict, base_path: str, file_attributes: List[str]):
def add_path_to_filename(filename, base_path):
if filename == "" or filename.startswith("/"):
return filename
else:
return os.path.join(base_path, filename)


def add_path_to_directory_name(directory_name, base_path):
if directory_name.startswith("/"):
return directory_name
elif directory_name == "":
return "./" + directory_name
else:
return os.path.join(base_path, directory_name)


def add_base_path(conf: dict, base_path: str, attributes: List[str], attribute_type: str = "file"):
for key, val in conf.items():
if key in file_attributes:
if val.startswith("/"):
continue
elif val == "":
conf[key] = "./" + val
if key in attributes:
if attribute_type == "file":
conf[key] = add_path_to_filename(val, base_path)
else:
conf[key] = os.path.join(base_path, val)
conf[key] = add_path_to_directory_name(val, base_path)
if isinstance(val, dict):
conf[key] = add_base_path(val, base_path, file_attributes)
conf[key] = add_base_path(val, base_path, attributes, attribute_type)

return conf

Expand All @@ -53,41 +67,71 @@ def set_domain_and_port(conf: dict, uris: List[str], domain: str, port: int):
return conf


class Base:
class Base(dict):
""" Configuration base class """

parameter = {}
uris = ["issuer", "base_url"]

def __init__(self,
conf: Dict,
base_path: str = '',
file_attributes: Optional[List[str]] = None,
dir_attributes: Optional[List[str]] = None,
domain: Optional[str] = "",
port: Optional[int] = 0,
):
dict.__init__(self)
self._file_attributes = file_attributes or DEFAULT_FILE_ATTRIBUTE_NAMES
self._dir_attributes = dir_attributes or DEFAULT_DIR_ATTRIBUTE_NAMES

if file_attributes is None:
file_attributes = DEFAULT_FILE_ATTRIBUTE_NAMES

if base_path and file_attributes:
if base_path:
# this adds a base path to all paths in the configuration
add_base_path(conf, base_path, file_attributes)
if self._file_attributes:
add_base_path(conf, base_path, self._file_attributes, "file")
if self._dir_attributes:
add_base_path(conf, base_path, self._dir_attributes, "dir")

def __getitem__(self, item):
if item in self.__dict__:
return self.__dict__[item]
# entity info
self.domain = domain or conf.get("domain", "127.0.0.1")
self.port = port or conf.get("port", 80)

self.conf = set_domain_and_port(conf, self.uris, self.domain, self.port)

def __getattr__(self, item, default=None):
if item in self:
return self[item]
else:
raise KeyError
return default

def get(self, item, default=None):
return getattr(self, item, default)
def __setattr__(self, key, value):
if key in self:
raise KeyError('{} has already been set'.format(key))
super(Base, self).__setitem__(key, value)

def __setitem__(self, key, value):
if key in self:
raise KeyError('{} has already been set'.format(key))
super(Base, self).__setitem__(key, value)

def __contains__(self, item):
return item in self.__dict__
def get(self, item, default=None):
return self.__getattr__(item, default)

def items(self):
for key in self.__dict__:
for key in self.keys():
if key.startswith('__') and key.endswith('__'):
continue
yield key, getattr(self, key)

def extend(self, entity_conf, conf, base_path, file_attributes, domain, port):
def extend(self,
conf: Dict,
base_path: str,
domain: str,
port: int,
entity_conf: Optional[List[dict]] = None,
file_attributes: Optional[List[str]] = None,
dir_attributes: Optional[List[str]] = None,
):
for econf in entity_conf:
_path = econf.get("path")
_cnf = conf
Expand All @@ -98,11 +142,49 @@ def extend(self, entity_conf, conf, base_path, file_attributes, domain, port):
_cls = econf["class"]
setattr(self, _attr,
_cls(_cnf, base_path=base_path, file_attributes=file_attributes,
domain=domain, port=port))
domain=domain, port=port, dir_attributes=dir_attributes))

def complete_paths(self, conf: Dict, keys: List[str], default_config: Dict, base_path: str):
for key in keys:
_val = conf.get(key)
if _val is None and key in default_config:
_val = default_config[key]
if key in self._file_attributes:
_val = add_path_to_filename(_val, base_path)
elif key in self._dir_attributes:
_val = add_path_to_directory_name(_val, base_path)
if not _val:
continue

setattr(self, key, _val)

def format(self, conf, base_path: str, domain: str, port: int,
file_attributes: Optional[List[str]] = None,
dir_attributes: Optional[List[str]] = None) -> None:
"""
Formats parts of the configuration. That includes replacing the strings {domain} and {port}
with the used domain and port and making references to files and directories absolute
rather then relative. The formatting is done in place.
:param dir_attributes:
:param conf: The configuration part
:param base_path: The base path used to make file/directory refrences absolute
:param file_attributes: Attribute names that refer to files or directories.
:param domain: The domain name
:param port: The port used
"""
if isinstance(conf, dict):
if file_attributes:
add_base_path(conf, base_path, file_attributes, attribute_type="file")
if dir_attributes:
add_base_path(conf, base_path, dir_attributes, attribute_type="dir")
if isinstance(conf, dict):
set_domain_and_port(conf, self.uris, domain=domain, port=port)


class Configuration(Base):
"""Server Configuration"""
"""Entity Configuration Base"""
uris = ["redirect_uris", 'issuer', 'base_url', 'server_name']

def __init__(self,
conf: Dict,
Expand All @@ -111,27 +193,24 @@ def __init__(self,
file_attributes: Optional[List[str]] = None,
domain: Optional[str] = "",
port: Optional[int] = 0,
dir_attributes: Optional[List[str]] = None,
):
Base.__init__(self, conf, base_path=base_path, file_attributes=file_attributes)
Base.__init__(self, conf, base_path=base_path, file_attributes=file_attributes,
dir_attributes=dir_attributes, domain=domain, port=port)

log_conf = conf.get('logging')
log_conf = self.conf.get('logging')
if log_conf:
self.logger = configure_logging(config=log_conf).getChild(__name__)
else:
self.logger = logging.getLogger('oidcrp')

self.web_conf = lower_or_upper(conf, "webserver")

# entity info
if not domain:
domain = conf.get("domain", "127.0.0.1")

if not port:
port = conf.get("port", 80)
self.web_conf = lower_or_upper(self.conf, "webserver")

if entity_conf:
self.extend(entity_conf=entity_conf, conf=conf, base_path=base_path,
file_attributes=file_attributes, domain=domain, port=port)
self.extend(conf=self.conf, base_path=base_path,
domain=self.domain, port=self.port, entity_conf=entity_conf,
file_attributes=self._file_attributes,
dir_attributes=self._dir_attributes)


def create_from_config_file(cls,
Expand All @@ -140,7 +219,9 @@ def create_from_config_file(cls,
entity_conf: Optional[List[dict]] = None,
file_attributes: Optional[List[str]] = None,
domain: Optional[str] = "",
port: Optional[int] = 0):
port: Optional[int] = 0,
dir_attributes: Optional[List[str]] = None
):
if filename.endswith(".yaml"):
"""Load configuration as YAML"""
_cnf = load_yaml_config(filename)
Expand All @@ -158,4 +239,4 @@ def create_from_config_file(cls,
return cls(_cnf,
entity_conf=entity_conf,
base_path=base_path, file_attributes=file_attributes,
domain=domain, port=port)
domain=domain, port=port, dir_attributes=dir_attributes)
1 change: 1 addition & 0 deletions tests/server_conf.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
"httpc_params": {
"verify": false
},
"hash_seed": "MustangSally",
"keys": {
"private_path": "private/jwks.json",
"key_defs": [
Expand Down
6 changes: 0 additions & 6 deletions tests/test_03_time_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,3 @@ def test_later_than_str():
b = in_a_while(seconds=20)
assert later_than(b, a)
assert later_than(a, b) is False


def test_utc_time():
utc_now = utc_time_sans_frac()
expected_utc_now = int((datetime.utcnow() - datetime(1970, 1, 1)).total_seconds())
assert utc_now == expected_utc_now
23 changes: 7 additions & 16 deletions tests/test_20_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from oidcmsg.configure import Configuration
from oidcmsg.configure import create_from_config_file
from oidcmsg.configure import lower_or_upper
from oidcmsg.configure import set_domain_and_port
from oidcmsg.util import rndstr

_dirname = os.path.dirname(os.path.abspath(__file__))
Expand All @@ -26,23 +25,14 @@ def __init__(self,
domain: Optional[str] = "",
port: Optional[int] = 0,
file_attributes: Optional[List[str]] = None,
uris: Optional[List[str]] = None
uris: Optional[List[str]] = None,
dir_attributes: Optional[List[str]] = None
):

Base.__init__(self, conf, base_path=base_path, file_attributes=file_attributes)
Base.__init__(self, conf, base_path=base_path, file_attributes=file_attributes,
dir_attributes=dir_attributes)

self.keys = lower_or_upper(conf, 'keys')

if not domain:
domain = conf.get("domain", "127.0.0.1")

if not port:
port = conf.get("port", 80)

if uris is None:
uris = URIS
conf = set_domain_and_port(conf, uris, domain, port)

self.hash_seed = lower_or_upper(conf, 'hash_seed', rndstr(32))
self.base_url = conf.get("base_url")
self.httpc_params = conf.get("httpc_params", {"verify": False})
Expand Down Expand Up @@ -74,5 +64,6 @@ def test_entity_config(filename):
assert configuration.httpc_params == {"verify": False}
assert configuration['keys']
ni = dict(configuration.items())
assert len(ni) == 4
assert set(ni.keys()) == {'keys', 'base_url', 'httpc_params', 'hash_seed'}
assert len(ni) == 9
assert set(ni.keys()) == {'base_url', '_dir_attributes', '_file_attributes', 'hash_seed',
'httpc_params', 'keys', 'conf', 'port', 'domain'}

0 comments on commit 1c422a3

Please sign in to comment.