diff --git a/frontend/checkFrontend.py b/frontend/checkFrontend.py index d1f06d5d9..d15f7abba 100755 --- a/frontend/checkFrontend.py +++ b/frontend/checkFrontend.py @@ -3,20 +3,32 @@ # SPDX-FileCopyrightText: 2009 Fermi Research Alliance, LLC # SPDX-License-Identifier: Apache-2.0 -"""Check if a glideinFrontend is running +"""Check if a glideinFrontend is running. -Arguments: - $1 = work_dir - $2 = (optional) run mode (defaults to "run") +This script checks whether a glideinFrontend is running in the specified +working directory. It optionally allows specifying a run mode. -Exit code: - 0 - Running - 1 - Not running anything - 2 - Not running my types, but another type is indeed running +Usage: + python script_name.py [run_mode] + +Args: + work_dir (str): The working directory to check for a running glideinFrontend. + run_mode (str, optional): The desired run mode to check for. Defaults to "run". + +Exit Codes: + 0: A glideinFrontend of the specified type is running. + 1: No glideinFrontend is running. + 2: A glideinFrontend of a different type is running. + +Examples: + Check for a glideinFrontend running in "my_work_dir" with the default mode: + $ python check_glidein_frontend.py my_work_dir + + Check for a glideinFrontend running in "my_work_dir" with a specific mode: + $ python check_glidein_frontend.py my_work_dir run """ import sys - from glideinwms.frontend import glideinFrontendPidLib if __name__ == "__main__": @@ -28,7 +40,7 @@ sys.exit(1) if action_type is None: - # if not defined, assume it is the standard running type + # If not defined, assume it is the standard running type action_type = "run" if len(sys.argv) >= 3: diff --git a/frontend/glideinFrontend.py b/frontend/glideinFrontend.py index c7007ba75..34c2048a5 100755 --- a/frontend/glideinFrontend.py +++ b/frontend/glideinFrontend.py @@ -3,12 +3,17 @@ # SPDX-FileCopyrightText: 2009 Fermi Research Alliance, LLC # SPDX-License-Identifier: Apache-2.0 -"""This is the main of the glideinFrontend +"""glideinFrontend Main Script. -Arguments: - $1 = work_dir -""" +This script serves as the entry point for managing the glideinFrontend processes, handling group operations, +failure monitoring, and performance aggregation. + +Usage: + python glidein_frontend_main.py +Args: + work_dir (str): The working directory for the frontend. +""" import fcntl import os @@ -33,35 +38,64 @@ ############################################################ # KEL remove this method and just call the monitor aggregator method directly below? we don't use the results def aggregate_stats(): + """Aggregate monitoring data using the monitor aggregator. + + Returns: + dict: Aggregated statistics for the frontend. + """ return glideinFrontendMonitorAggregator.aggregateStatus() ############################################################ class FailureCounter: + """Tracks and counts failures within a specific time window. + + Attributes: + my_name (str): Name or identifier for the failure counter. + max_lifetime (int): Time window in seconds for retaining failure records. + failure_times (list): List of timestamps for failures. + """ + def __init__(self, my_name, max_lifetime): + """ + Args: + my_name (str): Name or identifier for the failure counter. + max_lifetime (int): Time window in seconds for retaining failure records. + """ self.my_name = my_name self.max_lifetime = max_lifetime - self.failure_times = [] def add_failure(self, when=None): + """Record a failure event. + + Args: + when (float, optional): Timestamp of the failure. Defaults to the current time. 
+ """ if when is None: when = time.time() - self.clean_old() self.failure_times.append(when) def get_failures(self): + """Retrieve a list of failures within the retention window. + + Returns: + list: A list of timestamps for failures. + """ self.clean_old() return self.failure_times def count_failures(self): - return len(self.get_failures()) + """Count the number of failures within the retention window. - # INTERNAL + Returns: + int: The number of recorded failures. + """ + return len(self.get_failures()) - # clean out any old records def clean_old(self): + """Remove outdated failure records that exceed the retention window.""" min_time = time.time() - self.max_lifetime while self.failure_times and (self.failure_times[0] < min_time): # Assuming they are ordered @@ -70,10 +104,20 @@ def clean_old(self): ############################################################ def spawn_group(work_dir, group_name, action): + """Spawn a subprocess for a specific group. + + Args: + work_dir (str): The working directory for the frontend. + group_name (str): The name of the group to process. + action (str): The action to perform for the group. + + Returns: + subprocess.Popen: The spawned child process. + """ command_list = [sys.executable, glideinFrontendElement.__file__, str(os.getpid()), work_dir, group_name, action] child = subprocess.Popen(command_list, shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - # set it in non blocking mode + # Set stdout and stderr to non-blocking mode for fd in (child.stdout.fileno(), child.stderr.fileno()): fl = fcntl.fcntl(fd, fcntl.F_GETFL) fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK) @@ -83,28 +127,49 @@ def spawn_group(work_dir, group_name, action): ############################################################ def poll_group_process(group_name, child): - # empty stdout and stderr + """Poll the status of a group's subprocess. + + Args: + group_name (str): The name of the group being processed. + child (subprocess.Popen): The child process to poll. + + Returns: + int or None: The exit code of the process if it has exited, or None if it is still running. + """ + # Empty stdout and stderr try: tempOut = child.stdout.read() if tempOut: logSupport.log.info(f"[{group_name}]: {tempOut}") except OSError: - pass # ignore + pass # Ignore errors + try: tempErr = child.stderr.read() - if tempOut: + if tempErr: logSupport.log.warning(f"[{group_name}]: {tempErr}") except OSError: - pass # ignore + pass # Ignore errors return child.poll() ############################################################ - - -# return the list of (group,walltime) pairs def spawn_iteration(work_dir, frontendDescript, groups, max_active, failure_dict, max_failures, action): + """Execute a full iteration for managing groups and monitoring failures. + + Args: + work_dir (str): The working directory for the frontend. + frontendDescript (FrontendDescript): The frontend configuration descriptor. + groups (list): A list of group names to process. + max_active (int): The maximum number of active groups allowed simultaneously. + failure_dict (dict): A dictionary mapping group names to their respective FailureCounter objects. + max_failures (int): The maximum number of failures allowed before aborting. + action (str): The action to perform for the iteration. + + Returns: + list: A list of tuples containing group names and their respective wall times. 
+ """ childs = {} for group_name in groups: @@ -254,7 +319,21 @@ def spawn_iteration(work_dir, frontendDescript, groups, max_active, failure_dict ############################################################ def spawn_cleanup(work_dir, frontendDescript, groups, frontend_name, ha_mode): - # Invalidate glidefrontendmonitor classad + """Perform cleanup tasks for frontend processes. + + This function invalidates glidefrontendmonitor classads and performs deadvertising + for all groups. + + Args: + work_dir (str): The working directory. + frontendDescript (FrontendDescript): The frontend descriptor object. + groups (list): List of groups to clean up. + frontend_name (str): The name of the frontend. + ha_mode (str): High-availability mode. + + Returns: + None + """ try: set_frontend_htcondor_env(work_dir, frontendDescript) fm_advertiser = glideinFrontendInterface.FrontendMonitorClassadAdvertiser() @@ -301,6 +380,24 @@ def spawn( restart_interval, restart_attempts, ): + """Spawn and manage frontend groups in master/slave modes. + + This function manages the spawning and monitoring of frontend groups + in a high-availability (HA) environment, supporting master and slave roles. + + Args: + sleep_time (float): Time (in seconds) to sleep between iterations. + advertize_rate (int): Rate at which to advertise classads. + work_dir (str): The working directory for the frontend. + frontendDescript (FrontendDescript): The frontend descriptor object. + groups (list): List of groups to manage. + max_parallel_workers (int): Maximum number of parallel workers. + restart_interval (int): Interval (in seconds) before attempting a restart. + restart_attempts (int): Maximum number of restart attempts. + + Returns: + None + """ num_groups = len(groups) # TODO: Get the ha_check_interval from the config @@ -382,12 +479,17 @@ def spawn( ############################################################ def shouldHibernate(frontendDescript, work_dir, ha, mode, groups): - """ - Check if the frontend is running in HA mode. If run in master mode never - hibernate. If run in slave mode, hiberate if master is active. + """Determine if the frontend should enter hibernation. + + Args: + frontendDescript (FrontendDescript): The frontend descriptor object. + work_dir (str): The working directory for the frontend. + ha (dict): High-availability settings. + mode (str): Current operating mode ("master" or "slave"). + groups (list): List of groups being managed. - @rtype: bool - @return: True if we should hibernate else False + Returns: + bool: True if the frontend should hibernate, False otherwise. """ servicePerformance.startPerfMetricEvent("frontend", "ha_check") @@ -442,19 +544,41 @@ def shouldHibernate(frontendDescript, work_dir, ha, mode, groups): def clear_diskcache_dir(work_dir): - """Clear the cache by removing the directory used for the cachedir, and recreate it.""" + """Clear the disk cache directory and recreate it. + + This function removes the existing cache directory used by the frontend, + handles any errors if the directory does not exist, and recreates it. + + Args: + work_dir (str): The working directory for the frontend. + + Raises: + OSError: If an error occurs while attempting to remove the cache directory. + """ cache_dir = os.path.join(work_dir, glideinFrontendConfig.frontendConfig.cache_dir) try: shutil.rmtree(cache_dir) except OSError as ose: - if ose.errno != 2: # errno 2 is ok, dir is missing. Maybe it's the first execution? 
- logSupport.log.exception("Error removing cache directory %s" % cache_dir) + if ose.errno != 2: # errno 2 is okay (directory missing) + logSupport.log.exception(f"Error removing cache directory {cache_dir}") raise os.mkdir(cache_dir) def set_frontend_htcondor_env(work_dir, frontendDescript, element=None): - # Collector DN is only in the group's mapfile. Just get first one. + """Set the HTCondor environment for the frontend. + + Configures the environment variables required for HTCondor operations + based on the frontend description and element. + + Args: + work_dir (str): The working directory for the frontend. + frontendDescript (FrontendDescript): The frontend descriptor object. + element (Element, optional): The specific group element. Defaults to None. + + Returns: + None + """ groups = frontendDescript.data["Groups"].split(",") if groups: if element is None: @@ -468,43 +592,85 @@ def set_frontend_htcondor_env(work_dir, frontendDescript, element=None): def set_env(env): - for var in env: - os.environ[var] = env[var] + """Set the environment variables from a dictionary. + + Args: + env (dict): Dictionary of environment variables and their values. + + Returns: + None + """ + for var, value in env.items(): + os.environ[var] = value def clean_htcondor_env(): + """Remove HTCondor-related environment variables. + + This function clears specific environment variables used by HTCondor + to prevent conflicts with other processes. + + Returns: + None + """ for v in ("CONDOR_CONFIG", "_CONDOR_CERTIFICATE_MAPFILE", "X509_USER_PROXY"): if os.environ.get(v): del os.environ[v] -############################################################ +def spawn_removal(work_dir, frontendDescript, groups, max_parallel_workers, removal_action): + """Perform group removal operations. + This function handles removing groups based on the specified removal action. -def spawn_removal(work_dir, frontendDescript, groups, max_parallel_workers, removal_action): - failure_dict = {} - for group in groups: - failure_dict[group] = FailureCounter(group, 3600) + Args: + work_dir (str): The working directory for the frontend. + frontendDescript (FrontendDescript): The frontend descriptor object. + groups (list): List of group names to process. + max_parallel_workers (int): Maximum number of parallel workers. + removal_action (str): The specific removal action to perform. + Returns: + None + """ + failure_dict = {group: FailureCounter(group, 3600) for group in groups} spawn_iteration(work_dir, frontendDescript, groups, max_parallel_workers, failure_dict, 1, removal_action) -############################################################ def cleanup_environ(): + """Clean up environment variables. + + Removes environment variables related to CONDOR and X509 to ensure + a clean execution environment. + + Returns: + None + """ for val in list(os.environ.keys()): val_low = val.lower() - if val_low[:8] == "_condor_": - # remove any CONDOR environment variables - # don't want any surprises + if val_low.startswith("_condor_"): del os.environ[val] - elif val_low[:5] == "x509_": - # remove any X509 environment variables - # don't want any surprises + elif val_low.startswith("x509_"): del os.environ[val] -############################################################ def main(work_dir, action): + """Main entry point for the glideinFrontend. + + This function initializes logging, processes configuration, and starts + the frontend based on the specified action. + + Args: + work_dir (str): The working directory for the frontend. 
+ action (str): The action to perform (e.g., "run", "removeIdle"). + + Raises: + ValueError: If an unknown action is specified. + Exception: For any errors during initialization or processing. + + Returns: + None + """ startup_time = time.time() glideinFrontendConfig.frontendConfig.frontend_descript_file = os.path.join( @@ -512,20 +678,17 @@ def main(work_dir, action): ) frontendDescript = glideinFrontendConfig.FrontendDescript(work_dir) - # the log dir is shared between the frontend main and the groups, so use a subdir + # Configure logging logSupport.log_dir = os.path.join(frontendDescript.data["LogDir"], "frontend") - - # Configure frontend process logging logSupport.log = logSupport.get_logger_with_handlers("frontend", logSupport.log_dir, frontendDescript.data) logSupport.log.info("Logging initialized") - logSupport.log.debug("Frontend startup time: %s" % str(startup_time)) + logSupport.log.debug(f"Frontend startup time: {startup_time}") clear_diskcache_dir(work_dir) try: cleanup_environ() - # we use a dedicated config... ignore the system-wide os.environ["CONDOR_CONFIG"] = frontendDescript.data["CondorConfig"] sleep_time = int(frontendDescript.data["LoopDelay"]) @@ -535,7 +698,6 @@ def main(work_dir, action): restart_interval = int(frontendDescript.data["RestartInterval"]) groups = sorted(frontendDescript.data["Groups"].split(",")) - glideinFrontendMonitorAggregator.monitorAggregatorConfig.config_frontend( os.path.join(work_dir, "monitor"), groups ) @@ -544,62 +706,61 @@ def main(work_dir, action): raise glideinFrontendMonitoring.write_frontend_descript_xml(frontendDescript, os.path.join(work_dir, "monitor/")) + logSupport.log.info(f"Enabled groups: {groups}") - logSupport.log.info("Enabled groups: %s" % groups) - - # create lock file + # Create lock file pid_obj = glideinFrontendPidLib.FrontendPidSupport(work_dir) - # start try: pid_obj.register(action) except glideinFrontendPidLib.pidSupport.AlreadyRunning as err: pid_obj.load_registered() logSupport.log.exception( - "Failed starting Frontend with action %s. Instance with pid %s is aready running for action %s. Exception during pid registration: %s" - % (action, pid_obj.mypid, str(pid_obj.action_type), err) + f"Failed starting Frontend with action {action}. " + f"Instance with pid {pid_obj.mypid} is already running for action {pid_obj.action_type}. 
" + f"Exception during pid registration: {err}" ) raise try: - try: - if action == "run": - spawn( - sleep_time, - advertize_rate, - work_dir, - frontendDescript, - groups, - max_parallel_workers, - restart_interval, - restart_attempts, - ) - elif action in ( - "removeWait", - "removeIdle", - "removeAll", - "removeWaitExcess", - "removeIdleExcess", - "removeAllExcess", - ): - spawn_removal(work_dir, frontendDescript, groups, max_parallel_workers, action) - else: - raise ValueError("Unknown action: %s" % action) - except KeyboardInterrupt: - logSupport.log.info("Received signal...exit") - except HUPException: - logSupport.log.info("Received SIGHUP, reload config") - pid_obj.relinquish() - os.execv( - os.path.join(glideinFrontendLib.__file__, "../creation/reconfig_frontend"), - ["reconfig_frontend", "-sighupreload", "-xml", "/etc/gwms-frontend/frontend.xml"], + if action == "run": + spawn( + sleep_time, + advertize_rate, + work_dir, + frontendDescript, + groups, + max_parallel_workers, + restart_interval, + restart_attempts, ) - except Exception: - logSupport.log.exception("Exception occurred trying to spawn: ") + elif action in ( + "removeWait", + "removeIdle", + "removeAll", + "removeWaitExcess", + "removeIdleExcess", + "removeAllExcess", + ): + spawn_removal(work_dir, frontendDescript, groups, max_parallel_workers, action) + else: + raise ValueError(f"Unknown action: {action}") + except KeyboardInterrupt: + logSupport.log.info("Received signal...exit") + except HUPException: + logSupport.log.info("Received SIGHUP, reload config") + pid_obj.relinquish() + os.execv( + os.path.join(glideinFrontendLib.__file__, "../creation/reconfig_frontend"), + ["reconfig_frontend", "-sighupreload", "-xml", "/etc/gwms-frontend/frontend.xml"], + ) + except Exception: + logSupport.log.exception("Exception occurred trying to spawn: ") finally: pid_obj.relinquish() + ############################################################ # # S T A R T U P diff --git a/frontend/glideinFrontendDowntimeLib.py b/frontend/glideinFrontendDowntimeLib.py index d58555209..941fb0da9 100644 --- a/frontend/glideinFrontendDowntimeLib.py +++ b/frontend/glideinFrontendDowntimeLib.py @@ -16,33 +16,96 @@ # if end_time is None, the downtime does not have a set expiration # (i.e. it runs forever) class DowntimeFile: + """Manages downtime periods stored in a file. + + Attributes: + fname (str): The filename of the downtime file. + """ + def __init__(self, fname): + """ + Args: + fname (str): Path to the downtime file. + """ self.fname = fname - # if check_time==None, use current time def checkDowntime(self, check_time=None): - rtn = checkDowntime(self.fname, check_time) - return rtn + """Check if the specified time falls within a downtime period. + + Args: + check_time (int, optional): The time to check, in UNIX timestamp format. + Defaults to the current time. + + Returns: + bool: True if the specified time is within a downtime period, False otherwise. + """ + return checkDowntime(self.fname, check_time) - # add a scheduled downtime def addPeriod(self, start_time, end_time, create_if_empty=True): + """Add a scheduled downtime period to the file. + + Args: + start_time (int): Start time of the downtime, in UNIX timestamp format. + end_time (int): End time of the downtime, in UNIX timestamp format. Use `None` + for an indefinite downtime. + create_if_empty (bool): Whether to create the file if it doesn't exist. + + Returns: + int: 0 if the downtime was added successfully. 
+ """ return addPeriod(self.fname, start_time, end_time, create_if_empty) - # start a downtime that we don't know when it will end # if start_time==None, use current time def startDowntime(self, start_time=None, end_time=None, create_if_empty=True): + """Start an indefinite downtime. + + Args: + start_time (int, optional): Start time of the downtime, in UNIX timestamp format. + Defaults to the current time. + end_time (int, optional): End time of the downtime, in UNIX timestamp format. + Use `None` for indefinite downtime. + create_if_empty (bool): Whether to create the file if it doesn't exist. + + Returns: + int: 0 if the downtime was started successfully. + """ if start_time is None: start_time = int(time.time()) return self.addPeriod(start_time, end_time, create_if_empty) - # end a downtime (not a scheduled one) # if end_time==None, use current time def endDowntime(self, end_time=None): + """End the current downtime. + + Args: + end_time (int, optional): End time of the downtime, in UNIX timestamp format. + Defaults to the current time. + + Returns: + int: The number of downtimes ended. + """ return endDowntime(self.fname, end_time) def printDowntime(self, check_time=None): + """Print active downtime periods. + + Args: + check_time (int, optional): The time to check against, in UNIX timestamp format. + Defaults to the current time. + + Returns: + None + """ return printDowntime(self.fname, check_time) - # return a list of downtime periods (utimes) a value of None idicates "forever" for example: [(1215339200,1215439170),(1215439271,None)] def read(self, raise_on_error=False): + """Read and parse downtime periods from the file. + + Args: + raise_on_error (bool): Whether to raise an exception on errors. + + Returns: + list: A list of tuples representing downtime periods. Each tuple contains + (start_time, end_time), where `None` for end_time represents an indefinite period. + """ return read(self.fname, raise_on_error) @@ -55,6 +118,16 @@ def read(self, raise_on_error=False): # a value of None idicates "forever" # for example: [(1215339200,1215439170),(1215439271,None)] def read(fname, raise_on_error=False): + """Read downtime periods from a file. + + Args: + fname (str): Path to the downtime file. + raise_on_error (bool): Whether to raise an exception on errors. + + Returns: + list: A list of tuples representing downtime periods. Each tuple contains + (start_time, end_time), where `None` for end_time represents an indefinite period. + """ try: with open(fname) as fd: fcntl.flock(fd, fcntl.LOCK_SH | fcntl.LOCK_NB) @@ -111,6 +184,16 @@ def read(fname, raise_on_error=False): # if check_time==None, use current time def checkDowntime(fname, check_time=None): + """Check if a time falls within any downtime periods. + + Args: + fname (str): Path to the downtime file. + check_time (int, optional): The time to check, in UNIX timestamp format. + Defaults to the current time. + + Returns: + bool: True if the time is within a downtime period, False otherwise. + """ if check_time is None: check_time = int(time.time()) @@ -131,6 +214,17 @@ def checkDowntime(fname, check_time=None): # just insert a new line with start time and end time def addPeriod(fname, start_time, end_time, create_if_empty=True): + """Add a new downtime period to the file. + + Args: + fname (str): Path to the downtime file. + start_time (int): Start time of the downtime, in UNIX timestamp format. + end_time (int): End time of the downtime, in UNIX timestamp format. Use `None` for indefinite. 
+ create_if_empty (bool): Whether to create the file if it doesn't exist. + + Returns: + int: 0 if the downtime period was added successfully. + """ exists = os.path.isfile(fname) if (not exists) and (not create_if_empty): raise OSError("[Errno 2] No such file or directory: '%s'" % fname) @@ -155,6 +249,16 @@ def addPeriod(fname, start_time, end_time, create_if_empty=True): # end a downtime (not a scheduled one) # if end_time==None, use current time def endDowntime(fname, end_time=None): + """End the current downtime by updating the file. + + Args: + fname (str): Path to the downtime file. + end_time (int, optional): End time of the downtime, in UNIX timestamp format. + Defaults to the current time. + + Returns: + int: Number of downtimes ended. + """ if end_time is None: end_time = int(time.time()) @@ -229,6 +333,16 @@ def endDowntime(fname, end_time=None): def printDowntime(fname, check_time=None): + """Print downtime periods currently in effect. + + Args: + fname (str): Path to the downtime file. + check_time (int, optional): The time to check against, in UNIX timestamp format. + Defaults to the current time. + + Returns: + None + """ if check_time is None: check_time = int(time.time()) diff --git a/lib/classadSupport.py b/lib/classadSupport.py index cf6351532..b6b0bc6c7 100644 --- a/lib/classadSupport.py +++ b/lib/classadSupport.py @@ -1,8 +1,7 @@ # SPDX-FileCopyrightText: 2009 Fermi Research Alliance, LLC # SPDX-License-Identifier: Apache-2.0 -"""This module describes base classes for classads and advertisers -""" +"""This module describes base classes for classads and advertisers.""" import os import time @@ -18,12 +17,12 @@ class Classad: """Base class describing a classad.""" def __init__(self, adType, advertiseCmd, invalidateCmd): - """Constructor + """Constructor. Args: - adType (str): Type of the classad - advertiseCmd (str): Condor update-command to advertise this classad - invalidateCmd (str): Condor update-command to invalidate this classad + adType (str): Type of the classad. + advertiseCmd (str): Condor update-command to advertise this classad. + invalidateCmd (str): Condor update-command to invalidate this classad. """ self.adType = adType self.adAdvertiseCmd = advertiseCmd @@ -41,11 +40,11 @@ def __init__(self, adType, advertiseCmd, invalidateCmd): self.adParams["GlideinWMSVersion"] = "UNKNOWN" def update(self, params_dict, prefix=""): - """Update or Add ClassAd attributes + """Update or add ClassAd attributes. Args: - params_dict: new attributes - prefix: prefix used for the attribute names (Default: "") + params_dict (dict): New attributes. + prefix (str): Prefix used for the attribute names. Defaults to "". """ for k, v in list(params_dict.items()): if isinstance(v, int): @@ -56,14 +55,14 @@ def update(self, params_dict, prefix=""): self.adParams[f"{prefix}{k}"] = "%s" % escaped_v def writeToFile(self, filename, append=True): - """Write a ClassAd to file, adding a blank line if in append mode to separate the ClassAd + """Write a ClassAd to a file, adding a blank line if in append mode to separate the ClassAd. There can be no empty line at the beginning of the file: https://htcondor-wiki.cs.wisc.edu/index.cgi/tktview?tn=5147 Args: - filename: file to write to - append: write mode if False, append if True (Default) + filename (str): File to write to. + append (bool): Write mode if False, append if True. Defaults to True. 
""" o_flag = "a" if not append: @@ -71,8 +70,8 @@ def writeToFile(self, filename, append=True): try: f = open(filename, o_flag) - except Exception: - raise + except Exception as e: + raise e with f: if append and f.tell() > 0: @@ -83,12 +82,15 @@ def writeToFile(self, filename, append=True): f.write("%s" % self) def __str__(self): - """String representation of the classad.""" + """String representation of the classad. + Returns: + str: The string representation of the classad. + """ ad = "" for key, value in self.adParams.items(): - if isinstance(value, str) or isinstance(value, str): + if isinstance(value, str): # Format according to Condor String Literal definition # http://research.cs.wisc.edu/htcondor/manual/v7.8/4_1HTCondor_s_ClassAd.html#SECTION005121 classad_value = value.replace('"', r"\"") diff --git a/lib/cleanupSupport.py b/lib/cleanupSupport.py index c154f918b..a4413577a 100644 --- a/lib/cleanupSupport.py +++ b/lib/cleanupSupport.py @@ -11,14 +11,37 @@ class Cleanup: + """ + A class used to manage cleanup tasks for various objects and processes. + + Attributes: + cleanup_objects (list): A list of cleanup objects to be processed. + cleanup_pids (list): A list of process IDs for cleanup tasks. + """ + def __init__(self): + """ + Initializes a Cleanup instance with empty lists for cleanup objects and PIDs. + """ self.cleanup_objects = [] self.cleanup_pids = [] def add_cleaner(self, cleaner): + """ + Adds a cleanup object to the list of objects to be cleaned. + + Args: + cleaner: An object with a cleanup method. + """ self.cleanup_objects.append(cleaner) def start_background_cleanup(self): + """ + Starts background cleanup processes by forking the current process. + + This method forks the current process into multiple subprocesses to handle + the cleanup tasks in parallel. + """ if self.cleanup_pids: logSupport.log.warning("Earlier cleanup PIDs %s still exist; skipping this cycle" % self.cleanup_pids) else: @@ -38,6 +61,12 @@ def start_background_cleanup(self): del cleanup_lists def wait_for_cleanup(self): + """ + Waits for all cleanup subprocesses to finish. + + This method checks the status of the cleanup subprocesses and logs + when they have finished. + """ for pid in self.cleanup_pids: try: return_pid, _ = os.waitpid(pid, os.WNOHANG) @@ -49,7 +78,12 @@ def wait_for_cleanup(self): logSupport.log.warning(f"Received error {e.strerror} while waiting for PID {pid}") def cleanup(self): - # foreground cleanup + """ + Performs foreground cleanup tasks. + + This method iterates over all registered cleanup objects and calls + their cleanup methods. + """ for cleaner in self.cleanup_objects: cleaner.cleanup() @@ -59,10 +93,16 @@ def cleanup(self): class CredCleanup(Cleanup): """ - Cleans up old credential files. + A class used to clean up old credential files. """ def cleanup(self, in_use_proxies): + """ + Cleans up credential files that are no longer in use. + + Args: + in_use_proxies (list): A list of currently in-use proxy files. + """ for cleaner in self.cleanup_objects: cleaner.cleanup(in_use_proxies) @@ -70,8 +110,18 @@ def cleanup(self, in_use_proxies): cred_cleaners = CredCleanup() -# this class is used for cleanup class DirCleanup: + """ + A class used for cleaning up old files in a directory. + + Attributes: + dirname (str): The directory to clean. + fname_expression (str): A regular expression to match file names. + maxlife (int): The maximum lifetime of files in seconds. + should_log (bool): Whether to log information messages. 
+ should_log_warnings (bool): Whether to log warning messages. + """ + def __init__( self, dirname, @@ -80,6 +130,16 @@ def __init__( should_log=True, should_log_warnings=True, ): + """ + Initializes a DirCleanup instance. + + Args: + dirname (str): The directory to clean. + fname_expression (str): A regular expression to match file names. + maxlife (int): The maximum lifetime of files in seconds. + should_log (bool, optional): Whether to log information messages. Defaults to True. + should_log_warnings (bool, optional): Whether to log warning messages. Defaults to True. + """ self.dirname = dirname self.fname_expression = fname_expression self.fname_expression_obj = re.compile(fname_expression) @@ -88,6 +148,11 @@ def __init__( self.should_log_warnings = should_log_warnings def cleanup(self): + """ + Cleans up files in the directory that match the filename expression and are older than maxlife. + + This method removes files that are older than the specified maximum lifetime. + """ count_removes = 0 treshold_time = time.time() - self.maxlife @@ -109,9 +174,16 @@ def cleanup(self): if self.should_log: logSupport.log.info("Removed %i files." % count_removes) - # INTERNAL - # return a dictionary of fpaths each having the os.lstat output def get_files_wstats(self): + """ + Retrieves a dictionary of file paths and their statistics. + + This method returns a dictionary where the keys are file paths and the + values are the output of os.lstat for each file. + + Returns: + dict: A dictionary of file paths and their statistics. + """ out_data = {} fnames = os.listdir(self.dirname) @@ -129,13 +201,30 @@ def get_files_wstats(self): return out_data - # this may reimplemented by the children def delete_file(self, fpath): + """ + Deletes a file from the filesystem. + + Args: + fpath (str): The path to the file to be deleted. + """ os.unlink(fpath) -# this class is used for cleanup class DirCleanupWSpace(DirCleanup): + """ + A class used for cleaning up files in a directory based on both age and total space used. + + Attributes: + dirname (str): The directory to clean. + fname_expression (str): A regular expression to match file names. + maxlife (int): The maximum lifetime of files in seconds. + minlife (int): The minimum lifetime of files in seconds. + maxspace (int): The maximum allowed space for the files in bytes. + should_log (bool): Whether to log information messages. + should_log_warnings (bool): Whether to log warning messages. + """ + def __init__( self, dirname, @@ -146,11 +235,29 @@ def __init__( should_log=True, should_log_warnings=True, ): + """ + Initializes a DirCleanupWSpace instance. + + Args: + dirname (str): The directory to clean. + fname_expression (str): A regular expression to match file names. + maxlife (int): The maximum lifetime of files in seconds. + minlife (int): The minimum lifetime of files in seconds. + maxspace (int): The maximum allowed space for the files in bytes. + should_log (bool, optional): Whether to log information messages. Defaults to True. + should_log_warnings (bool, optional): Whether to log warning messages. Defaults to True. + """ DirCleanup.__init__(self, dirname, fname_expression, maxlife, should_log, should_log_warnings) self.minlife = minlife self.maxspace = maxspace def cleanup(self): + """ + Cleans up files in the directory based on age and total space used. + + This method removes files that are older than the specified maximum lifetime or if + the total space used by the files exceeds the specified maximum space. 
+ """ count_removes = 0 count_removes_bytes = 0 @@ -159,10 +266,8 @@ def cleanup(self): files_wstats = self.get_files_wstats() fpaths = list(files_wstats.keys()) - # order based on time (older first) fpaths.sort(key=lambda x: files_wstats[x][stat.ST_MTIME]) - # first calc the amount of space currently used used_space = 0 for fpath in fpaths: fstat = files_wstats[fpath] @@ -194,15 +299,34 @@ def cleanup(self): class DirCleanupCredentials(DirCleanup): """ - Used to cleanup old credential files saved to disk by the factory for glidein submission (based on ctime). + A class used to clean up old credential files saved to disk by the factory for glidein submission. + + Attributes: + dirname (str): The directory to clean. + fname_expression (str): A regular expression to match file names. + maxlife (int): The maximum lifetime of files in seconds. """ def __init__( self, dirname, fname_expression, maxlife # regular expression, used with re.match ): # max lifetime after which it is deleted + """ + Initializes a DirCleanupCredentials instance. + + Args: + dirname (str): The directory to clean. + fname_expression (str): A regular expression to match file names. + maxlife (int): The maximum lifetime of files in seconds. + """ DirCleanup.__init__(self, dirname, fname_expression, maxlife, should_log=True, should_log_warnings=True) def cleanup(self, in_use_creds): + """ + Cleans up credential files that are no longer in use. + + Args: + in_use_creds (list): A list of currently in-use credential files. + """ count_removes = 0 curr_time = time.time() diff --git a/lib/condorExe.py b/lib/condorExe.py index 03ca6cb70..7c52d7b7a 100644 --- a/lib/condorExe.py +++ b/lib/condorExe.py @@ -1,9 +1,9 @@ # SPDX-FileCopyrightText: 2009 Fermi Research Alliance, LLC # SPDX-License-Identifier: Apache-2.0 -# Description: -# This module implements the functions to execute condor commands - +""" +This module implements the functions to execute condor commands. +""" import os @@ -13,36 +13,59 @@ class CondorExeError(RuntimeError): - """Base class for condorExe module errors""" + """ + Base class for condorExe module errors. + """ def __init__(self, err_str): + """ + Initializes the CondorExeError with an error message. + + Args: + err_str (str): The error message. + """ RuntimeError.__init__(self, err_str) class UnconfigError(CondorExeError): + """ + Exception raised when condor is unconfigured. + """ + def __init__(self, err_str): + """ + Initializes the UnconfigError with an error message. + + Args: + err_str (str): The error message. + """ CondorExeError.__init__(self, err_str) class ExeError(CondorExeError): - def __init__(self, err_str): - CondorExeError.__init__(self, err_str) + """ + Exception raised when there is an error executing a condor command. + """ + def __init__(self, err_str): + """ + Initializes the ExeError with an error message. -# -# Configuration -# + Args: + err_str (str): The error message. + """ + CondorExeError.__init__(self, err_str) def set_path(new_condor_bin_path, new_condor_sbin_path=None): - """Set path to condor binaries, if needed + """ + Set path to condor binaries, if needed. - Works changing the global variables condor_bin_path and condor_sbin_path + Works by changing the global variables condor_bin_path and condor_sbin_path. Args: - new_condor_bin_path (str): directory where the HTCondor binaries are located - new_condor_sbin_path (str): directory where the HTCondor system binaries are located - + new_condor_bin_path (str): Directory where the HTCondor binaries are located. 
+ new_condor_sbin_path (str, optional): Directory where the HTCondor system binaries are located. Defaults to None. """ global condor_bin_path, condor_sbin_path condor_bin_path = new_condor_bin_path @@ -51,21 +74,23 @@ def set_path(new_condor_bin_path, new_condor_sbin_path=None): def exe_cmd(condor_exe, args, stdin_data=None, env={}): - """Execute an arbitrary condor command and return its output as a list of lines - Fails if stderr is not empty + """ + Execute an arbitrary condor command and return its output as a list of lines. + + Fails if stderr is not empty. Args: - condor_exe (str): condor_exe uses a relative path to $CONDOR_BIN - args (str): arguments for the command - stdin_data (str): Data that will be fed to the command via stdin - env (dict): Environment to be set before execution + condor_exe (str): Condor executable, uses a relative path to $CONDOR_BIN. + args (str): Arguments for the command. + stdin_data (str, optional): Data that will be fed to the command via stdin. Defaults to None. + env (dict, optional): Environment to be set before execution. Defaults to {}. Returns: - Lines of stdout from the command + list: Lines of stdout from the command. Raises: - UnconfigError: - ExeError: + UnconfigError: If condor_bin_path is undefined. + ExeError: If there is an error executing the command. """ global condor_bin_path @@ -79,6 +104,24 @@ def exe_cmd(condor_exe, args, stdin_data=None, env={}): def exe_cmd_sbin(condor_exe, args, stdin_data=None, env={}): + """ + Execute an arbitrary condor system command and return its output as a list of lines. + + Fails if stderr is not empty. + + Args: + condor_exe (str): Condor executable, uses a relative path to $CONDOR_SBIN. + args (str): Arguments for the command. + stdin_data (str, optional): Data that will be fed to the command via stdin. Defaults to None. + env (dict, optional): Environment to be set before execution. Defaults to {}. + + Returns: + list: Lines of stdout from the command. + + Raises: + UnconfigError: If condor_sbin_path is undefined. + ExeError: If there is an error executing the command. + """ global condor_sbin_path if condor_sbin_path is None: @@ -90,28 +133,20 @@ def exe_cmd_sbin(condor_exe, args, stdin_data=None, env={}): return iexe_cmd(cmd, stdin_data, env) -############################################################ -# -# P R I V A T E, do not use -# -############################################################ def generate_bash_script(cmd, environment): - """Print to a string a shell script setting the environment in 'environment' and running 'cmd' - If 'cmd' last argument is a file it will be printed as well in the string + """ + Print to a string a shell script setting the environment in 'environment' and running 'cmd'. + + If 'cmd' last argument is a file it will be printed as well in the string. Args: - cmd (str): command string - environment (dict): environment as a dictionary + cmd (str): Command string. + environment (dict): Environment as a dictionary. Returns: - str: multi-line string with environment, command and eventually the input file + str: Multi-line string with environment, command, and eventually the input file. 
""" script = ["script to reproduce failure:", "-" * 20 + " begin script " + "-" * 20, "#!/bin/bash"] - # FROM:migration_3_1, 3 lines - # script = ['script to reproduce failure:'] - # script.append('-' * 20 + ' begin script ' + '-' * 20) - # script.append('#!/bin/bash') - script += [f"{k}={v}" for k, v in environment.items()] script.append(cmd) script.append("-" * 20 + " end script " + "-" * 20) @@ -130,20 +165,20 @@ def generate_bash_script(cmd, environment): def iexe_cmd(cmd, stdin_data=None, child_env=None, log=None): - """Fork a process and execute cmd - rewritten to use select to avoid filling - up stderr and stdout queues. + """ + Fork a process and execute cmd - rewritten to use select to avoid filling up stderr and stdout queues. Args: - cmd (str): Sting containing the entire command including all arguments - stdin_data (str): Data that will be fed to the command via stdin - child_env (dict): Environment to be set before execution + cmd (str): Command string containing the entire command including all arguments. + stdin_data (str, optional): Data that will be fed to the command via stdin. Defaults to None. + child_env (dict, optional): Environment to be set before execution. Defaults to None. + log (optional): Logger instance. Defaults to None. Returns: - list of str: Lines of stdout from the command + list: Lines of stdout from the command. Raises: - ExeError - + ExeError: If there is an error executing the command. """ stdout_data = "" if log is None: @@ -175,13 +210,10 @@ def iexe_cmd(cmd, stdin_data=None, child_env=None, log=None): return stdout_data.splitlines() -######################### -# Module initialization -# - - def init1(): - """Set condor_bin_path""" + """ + Set condor_bin_path using various methods to locate the HTCondor binaries. + """ global condor_bin_path # try using condor commands to find it out try: @@ -217,7 +249,9 @@ def init1(): def init2(): - """Set condor_sbin_path""" + """ + Set condor_sbin_path using various methods to locate the HTCondor system binaries. + """ global condor_sbin_path # try using condor commands to find it out try: @@ -253,7 +287,9 @@ def init2(): def init(): - """Set both Set condor_bin_path and condor_sbin_path""" + """ + Initialize both condor_bin_path and condor_sbin_path. + """ init1() init2() diff --git a/lib/condorLogParser.py b/lib/condorLogParser.py index 568f7a465..a4c0dc126 100644 --- a/lib/condorLogParser.py +++ b/lib/condorLogParser.py @@ -1,11 +1,12 @@ # SPDX-FileCopyrightText: 2009 Fermi Research Alliance, LLC # SPDX-License-Identifier: Apache-2.0 -"""This module implements classes and functions to parse the condor log files. +""" +This module implements classes and functions to parse the condor log files. NOTE: -Inactive files are log files that have only completed or removed entries -Such files will not change in the future +Inactive files are log files that have only completed or removed entries. +Such files will not change in the future. """ import mmap @@ -21,16 +22,26 @@ class cachedLogClass: """ - This is the base class for most Log Parsers in lib/condorLogParser - and factory/glideFactoryLogParser. (I{virtual, do not use}) + Base class for most Log Parsers in lib/condorLogParser and factory/glideFactoryLogParser. + (Virtual, do not use directly) The Constructor for inherited classes needs to define logname and cachename - (possibly by using clInit) as well as the methods - loadFromLog, merge and isActive. 
- init method to be used by real constructors + (possibly by using clInit) as well as the methods loadFromLog, merge, and isActive. + + Attributes: + logname (str): The name of the log file. + cachename (str): The name of the cache file. """ def clInit(self, logname, cache_dir, cache_ext): + """ + Initializes the log and cache names. + + Args: + logname (str): The name of the log file. + cache_dir (str): The directory where the cache is stored. + cache_ext (str): The extension for the cache file. + """ self.logname = logname if cache_dir is None: self.cachename = logname + cache_ext @@ -39,8 +50,10 @@ def clInit(self, logname, cache_dir, cache_ext): def has_changed(self): """ - Compare to cache, and tell if the log file has changed - since last checked + Compare to cache, and tell if the log file has changed since last checked. + + Returns: + bool: True if the log file has changed, False otherwise. """ if os.path.isfile(self.logname): fstat = os.lstat(self.logname) @@ -59,11 +72,8 @@ def has_changed(self): def load(self): """ - Load data from most recent file. Update the cache if needed. - If the file has not changed, use the cache instead - (typically named something like filename.ftstpk) in a pickle format. - If file is newer, uses inherited class's loadFromLog method. - Then, save in pickle cache. + Load data from the most recent file. Update the cache if needed. + If the file has not changed, use the cache instead. """ if not self.has_changed(): # cache is newer, just load the cache @@ -89,51 +99,79 @@ def load(self): return # should never reach this point def loadCache(self): + """ + Load data from the cache file. + """ self.data = loadCache(self.cachename) return def loadFromLog(self): + """ + Load data from the log file. + + This method should be implemented by subclasses. + """ raise RuntimeError("loadFromLog not implemented!") - ####### PRIVATE ########### def saveCache(self): + """ + Save data to the cache file. + """ saveCache(self.cachename, self.data) return class logSummary(cachedLogClass): """ - This class will keep track of: - jobs in various of statuses (Wait, Idle, Running, Held, Completed, Removed) - This data is available in self.data dictionary - for example - self.data={'Idle':['123.003','123.004'],'Running':['123.001','123.002']} + Keeps track of jobs in various statuses (Wait, Idle, Running, Held, Completed, Removed). + + This data is available in self.data dictionary, for example: + self.data = {'Idle': ['123.003', '123.004'], 'Running': ['123.001', '123.002']} """ def __init__(self, logname, cache_dir): + """ + Initializes logSummary with log and cache names. + + Args: + logname (str): The name of the log file. + cache_dir (str): The directory where the cache is stored. + """ self.clInit(logname, cache_dir, ".cstpk") def loadFromLog(self): """ Parse the condor activity log and interpret the globus status code. - Stores in self.data + + Stores the result in self.data. """ jobs = parseSubmitLogFastRaw(self.logname) self.data = listAndInterpretRawStatuses(jobs, listStatuses) return def isActive(self): + """ + Determine if there are any active jobs. + + Returns: + bool: True if there are active jobs, False otherwise. 
+ """ active = False for k in list(self.data.keys()): if k not in ["Completed", "Removed"]: if len(self.data[k]) > 0: - active = True # it is enought that at least one non Completed/removed job exist + active = True # it is enough that at least one non Completed/Removed job exists return active def merge(self, other): """ - Merge self data with other info - @return: merged data, may modify other + Merge self data with other data. + + Args: + other (dict): The data to merge with. + + Returns: + dict: The merged data. """ if other is None: return self.data @@ -149,13 +187,13 @@ def merge(self, other): def diff(self, other): """ - diff self data with other info. Used to compare - previous iteration with current iteration + Compare self data with other data. - Performs symmetric difference on the two sets and - creates a dictionary for each status. + Args: + other (dict): The data to compare with. - @return: data[status]['Entered'|'Exited'] - list of jobs + Returns: + dict: A dictionary with the differences, showing entered and exited jobs. """ if other is None: outdata = {} @@ -200,24 +238,28 @@ def diff(self, other): class logCompleted(cachedLogClass): """ - This class will keep track of: - - counts of statuses (Wait, Idle, Running, Held, Completed, Removed) - - list of completed jobs - This data is available in self.data dictionary - - For example self.data= - {'completed_jobs':['123.002','555.001'], - 'counts':{'Idle': 1145, 'Completed': 2}} + Keeps track of counts of statuses (Wait, Idle, Running, Held, Completed, Removed) + and a list of completed jobs. + + This data is available in self.data dictionary, for example: + self.data = {'completed_jobs': ['123.002', '555.001'], 'counts': {'Idle': 1145, 'Completed': 2}} """ def __init__(self, logname, cache_dir): + """ + Initializes logCompleted with log and cache names. + + Args: + logname (str): The name of the log file. + cache_dir (str): The directory where the cache is stored. + """ self.clInit(logname, cache_dir, ".clspk") def loadFromLog(self): """ - Load information from condor_activity logs - Then parse globus statuses. - Finally, parse and add counts. + Load information from condor_activity logs, parse globus statuses, and add counts. + + Stores the result in self.data. """ tmpdata = {} jobs = parseSubmitLogFastRaw(self.logname) @@ -234,19 +276,29 @@ def loadFromLog(self): return def isActive(self): + """ + Determine if there are any active jobs. + + Returns: + bool: True if there are active jobs, False otherwise. + """ active = False counts = self.data["counts"] for k in list(counts.keys()): if k not in ["Completed", "Removed"]: if counts[k] > 0: - # Enough that at least one non Completed/removed job exist - active = True + active = True # Enough that at least one non Completed/Removed job exists return active def merge(self, other): """ - Merge self data with other info - @return: merged data, may modify other + Merge self data with other data. + + Args: + other (dict): The data to merge with. + + Returns: + dict: The merged data. """ if other is None: return self.data @@ -263,10 +315,13 @@ def merge(self, other): def diff(self, other): """ - Diff self.data with other info. - For use in comparing previous iteration with current iteration + Compare self data with other data. + + Args: + other (dict): The data to compare with. - Uses symmetric difference of sets. + Returns: + dict: A dictionary with the differences, showing entered and exited jobs. 
""" if other is None: if self.data is not None: @@ -317,33 +372,55 @@ def diff(self, other): class logCounts(cachedLogClass): """ - This class will keep track of - counts of statuses (Wait, Idle, Running, Held, Completed, Removed) - This data is available in self.data dictionary - For example self.data={'Idle': 1145, 'Completed': 2} + Keeps track of counts of statuses (Wait, Idle, Running, Held, Completed, Removed). + + This data is available in self.data dictionary, for example: + self.data = {'Idle': 1145, 'Completed': 2} """ def __init__(self, logname, cache_dir): + """ + Initializes logCounts with log and cache names. + + Args: + logname (str): The name of the log file. + cache_dir (str): The directory where the cache is stored. + """ self.clInit(logname, cache_dir, ".clcpk") def loadFromLog(self): + """ + Load and parse jobs from the log file, then count and interpret their statuses. + + Stores the result in self.data. + """ jobs = parseSubmitLogFastRaw(self.logname) self.data = countAndInterpretRawStatuses(jobs) return def isActive(self): + """ + Determine if there are any active jobs. + + Returns: + bool: True if there are active jobs, False otherwise. + """ active = False for k in list(self.data.keys()): if k not in ["Completed", "Removed"]: if self.data[k] > 0: - # Enough that at least one non Completed/removed job exist - active = True + active = True # Enough that at least one non Completed/Removed job exists return active def merge(self, other): """ - Merge self data with other info - @return: merged data, may modify other + Merge self data with other data. + + Args: + other (dict): The data to merge with. + + Returns: + dict: The merged data. """ if other is None: return self.data @@ -359,8 +436,13 @@ def merge(self, other): def diff(self, other): """ - Diff self data with other info - @return: diff of counts + Compare self data with other data. + + Args: + other (dict): The data to compare with. + + Returns: + dict: A dictionary with the differences in counts. """ if other is None: if self.data is not None: @@ -397,34 +479,55 @@ def diff(self, other): class logSummaryTimings(cachedLogClass): """ - This class will keep track of: - jobs in various of statuses (Wait, Idle, Running, Held, Completed, Removed) - This data is available in self.data dictionary - for example - self.data={'Idle':['123.003','123.004'],'Running':['123.001','123.002']} + Keeps track of jobs in various statuses (Wait, Idle, Running, Held, Completed, Removed) with timings. + + This data is available in self.data dictionary, for example: + self.data = {'Idle': ['123.003', '123.004'], 'Running': ['123.001', '123.002']} """ def __init__(self, logname, cache_dir): + """ + Initializes logSummaryTimings with log and cache names. + + Args: + logname (str): The name of the log file. + cache_dir (str): The directory where the cache is stored. + """ self.clInit(logname, cache_dir, ".ctstpk") def loadFromLog(self): + """ + Load and parse jobs from the log file, including timings. + + Stores the result in self.data. + """ jobs, self.startTime, self.endTime = parseSubmitLogFastRawTimings(self.logname) self.data = listAndInterpretRawStatuses(jobs, listStatusesTimings) return def isActive(self): + """ + Determine if there are any active jobs. + + Returns: + bool: True if there are active jobs, False otherwise. 
+ """ active = False for k in list(self.data.keys()): if k not in ["Completed", "Removed"]: if len(self.data[k]) > 0: - # Enough that at least one non Completed/removed job exist - active = True + active = True # Enough that at least one non Completed/Removed job exists return active def merge(self, other): """ - merge self data with other info - @return: merged data, may modify other + Merge self data with other data. + + Args: + other (dict): The data to merge with. + + Returns: + dict: The merged data. """ if other is None: return self.data @@ -440,8 +543,13 @@ def merge(self, other): def diff(self, other): """ - diff self data with other info - @return: data[status]['Entered'|'Exited'] - list of jobs + Compare self data with other data. + + Args: + other (dict): The data to compare with. + + Returns: + dict: A dictionary with the differences, showing entered and exited jobs. """ if other is None: outdata = {} @@ -473,6 +581,12 @@ def diff(self, other): for oel_e in other[s]: oel.append(oel_e[0]) + outdata_s = {"Entered": [], "Exited": []} + outdata[s] = outdata_s + + sset = set(sel) + oset = set(oel) + ################# # Need to finish @@ -507,8 +621,8 @@ def diff(self, other): class cacheDirClass: """ This is the base class for all the directory log Parser - classes. It parses some/all log files in a directory. - It should generally not be called directly. Rather, + classes. It parses some/all log files in a directory. + It should generally not be called directly. Rather, call one of the inherited classes. """ @@ -526,9 +640,19 @@ def __init__( username=None, ): """ - @param inactive_files: if None, will be reloaded from cache - @param inactive_timeout: how much time must elapse before a file can be declared inactive - @param cache_dir: If None, use dirname for the cache directory. + Initializes the cacheDirClass. + + Args: + logClass: The class used for parsing the logs. + dirname (str): The directory containing the log files. + log_prefix (str): The prefix for log files. + log_suffix (str): The suffix for log files. Defaults to ".log". + cache_ext (str): The extension for the cache file. Defaults to ".cifpk". + inactive_files (list): List of inactive files. If None, will be reloaded from cache. + inactive_timeout (int): Time in seconds before a file can be declared inactive. Defaults to 24 * 3600. + cache_dir (str): Directory for the cache files. If None, use dirname. + wrapperClass: The wrapper class, if any. + username (str): The username, if any. """ self.cdInit( logClass, @@ -557,13 +681,20 @@ def cdInit( username=None, ): """ - @param logClass: this is an actual class, not an object - @param inactive_files: if None, will be reloaded from cache - @param inactive_timeout: how much time must elapse before a file can be - declared inactive - @param cache_dir: If None, use dirname for the cache directory. + Initializes the cache directory. + + Args: + logClass: The class used for parsing the logs. + dirname (str): The directory containing the log files. + log_prefix (str): The prefix for log files. + log_suffix (str): The suffix for log files. Defaults to ".log". + cache_ext (str): The extension for the cache file. Defaults to ".cifpk". + inactive_files (list): List of inactive files. If None, will be reloaded from cache. + inactive_timeout (int): Time in seconds before a file can be declared inactive. Defaults to 24 * 3600. + cache_dir (str): Directory for the cache files. If None, use dirname. + wrapperClass: The wrapper class, if any. + username (str): The username, if any. 
""" - self.wrapperClass = wrapperClass self.username = username @@ -583,17 +714,19 @@ def cdInit( self.inactive_files = [] else: self.inactive_files = inactive_files - return def getFileList(self, active_only): """ Lists the directory and returns files that match the prefix/suffix extensions and are active (modified within - inactivity_timeout) + inactivity_timeout). - @return: a list of log files - """ + Args: + active_only (bool): If True, only return active files. + Returns: + list: A list of log files. + """ prefix_len = len(self.log_prefix) suffix_len = len(self.log_suffix) files = [] @@ -609,12 +742,11 @@ def getFileList(self, active_only): def has_changed(self): """ - Checks all the files in the list to see if any - have changed. + Checks all the files in the list to see if any have changed. - @return: True/False + Returns: + bool: True if any file has changed, False otherwise. """ - ch = False fnames = self.getFileList(active_only=True) for fname in fnames: @@ -630,15 +762,15 @@ def has_changed(self): def load(self, active_only=True): """ - For each file in the filelist, call the appropriate load() - function for that file. Merge all the data from all the files + For each file in the file list, call the appropriate load() + function for that file. Merge all the data from all the files into temporary array mydata then set it to self.data. It will save the list of inactive_files it finds in a cache for quick access. - This function should set self.data. + Args: + active_only (bool): If True, only load active files. """ - mydata = None new_inactives = [] @@ -663,7 +795,6 @@ def load(self, active_only=True): self.data = mydata # try to save inactive files in the cache - # if one was looking at inactive only if active_only and (len(new_inactives) > 0): self.inactive_files += new_inactives try: @@ -671,16 +802,19 @@ def load(self, active_only=True): except OSError: return # silently ignore, this was a load in the end - return - def diff(self, other): """ - Diff self data with other info + Compare self data with other data. This is a virtual function that just calls the class diff() function. - """ + Args: + other (dict): The data to compare with. + + Returns: + dict: The differences between self.data and other. + """ if (self.wrapperClass is not None) and (self.username is not None): dummyobj = self.wrapperClass.getObj(os.path.join(self.dirname, "dummy.txt"), self.cache_dir, self.username) else: @@ -692,11 +826,10 @@ def diff(self, other): class dirSummary(cacheDirClass): """ - This class will keep track of: - jobs in various of statuses (Wait, Idle, Running, Held, Completed, Removed) - This data is available in self.data dictionary - For example, - self.data={'Idle':['123.003','123.004'],'Running':['123.001','123.002']} + Keeps track of jobs in various statuses (Wait, Idle, Running, Held, Completed, Removed). + + This data is available in self.data dictionary, for example: + self.data = {'Idle': ['123.003', '123.004'], 'Running': ['123.001', '123.002']} """ def __init__( @@ -710,23 +843,27 @@ def __init__( cache_dir=None, ): """ - @param inactive_files: if ==None, will be reloaded from cache - @param inactive_timeout: how much time must elapse before - @param cache_dir: if None, use dirname + Initializes dirSummary with log and cache parameters. + + Args: + dirname (str): The directory containing the log files. + log_prefix (str): The prefix for log files. + log_suffix (str): The suffix for log files. Defaults to ".log". + cache_ext (str): The extension for the cache file. 
Defaults to ".cifpk". + inactive_files (list): List of inactive files. If None, will be reloaded from cache. + inactive_timeout (int): Time in seconds before a file can be declared inactive. Defaults to 24 * 3600. + cache_dir (str): Directory for the cache files. If None, use dirname. """ - self.cdInit(logSummary, dirname, log_prefix, log_suffix, cache_ext, inactive_files, inactive_timeout, cache_dir) class dirCompleted(cacheDirClass): """ - This class will keep track of: - - counts of statuses (Wait, Idle, Running, Held, Completed, Removed) - - list of completed jobs - This data is available in self.data dictionary - for example - self.data={'completed_jobs':['123.002','555.001'], - 'counts':{'Idle': 1145, 'Completed': 2}} + Keeps track of counts of statuses (Wait, Idle, Running, Held, Completed, Removed) + and a list of completed jobs. + + This data is available in self.data dictionary, for example: + self.data = {'completed_jobs': ['123.002', '555.001'], 'counts': {'Idle': 1145, 'Completed': 2}} """ def __init__( @@ -740,11 +877,17 @@ def __init__( cache_dir=None, ): """ - @param inactive_files: if ==None, will be reloaded from cache - @param inactive_timeout: how much time must elapse before - @param cache_dir: if None, use dirname + Initializes dirCompleted with log and cache parameters. + + Args: + dirname (str): The directory containing the log files. + log_prefix (str): The prefix for log files. + log_suffix (str): The suffix for log files. Defaults to ".log". + cache_ext (str): The extension for the cache file. Defaults to ".cifpk". + inactive_files (list): List of inactive files. If None, will be reloaded from cache. + inactive_timeout (int): Time in seconds before a file can be declared inactive. Defaults to 24 * 3600. + cache_dir (str): Directory for the cache files. If None, use dirname. """ - self.cdInit( logCompleted, dirname, log_prefix, log_suffix, cache_ext, inactive_files, inactive_timeout, cache_dir ) @@ -752,10 +895,10 @@ def __init__( class dirCounts(cacheDirClass): """ - This class will keep track of - counts of statuses (Wait, Idle, Running, Held, Completed, Removed) - These data is available in self.data dictionary - for example self.data={'Idle': 1145, 'Completed': 2} + Keeps track of counts of statuses (Wait, Idle, Running, Held, Completed, Removed). + + This data is available in self.data dictionary, for example: + self.data = {'Idle': 1145, 'Completed': 2} """ def __init__( @@ -769,25 +912,27 @@ def __init__( cache_dir=None, ): """ - @param inactive_files: if ==None, will be reloaded from cache - @param inactive_timeout: how much time must elapse before - @param cache_dir: if None, use dirname + Initializes dirCounts with log and cache parameters. + + Args: + dirname (str): The directory containing the log files. + log_prefix (str): The prefix for log files. + log_suffix (str): The suffix for log files. Defaults to ".log". + cache_ext (str): The extension for the cache file. Defaults to ".cifpk". + inactive_files (list): List of inactive files. If None, will be reloaded from cache. + inactive_timeout (int): Time in seconds before a file can be declared inactive. Defaults to 24 * 3600. + cache_dir (str): Directory for the cache files. If None, use dirname. 
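+
+        Example (illustrative sketch; the directory and log prefix are hypothetical):
+            dc = dirCounts("/var/log/gwms_client", log_prefix="condor_activity_", log_suffix=".log")
+            dc.load()
+            # dc.data now holds the status counts, e.g. {'Idle': 1145, 'Completed': 2}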
""" - self.cdInit(logCounts, dirname, log_prefix, log_suffix, cache_ext, inactive_files, inactive_timeout, cache_dir) class dirSummaryTimings(cacheDirClass): """ - This class will keep track of: - jobs in various of statuses (Wait, Idle, Running, Held, Completed, Removed) - This data is available in self.data dictionary - For example self.data={'Idle':[('123.003','09/28 01:38:53', - '09/28 01:42:23', '09/28 08:06:33'),('123.004','09/28 02:38:53', - '09/28 02:42:23', '09/28 09:06:33')], - 'Running':[('123.001','09/28 01:32:53', '09/28 01:43:23', - '09/28 08:07:33'),('123.002','09/28 02:38:53', '09/28 03:42:23', - '09/28 06:06:33')]} + Keeps track of jobs in various statuses (Wait, Idle, Running, Held, Completed, Removed) with timings. + + This data is available in self.data dictionary, for example: + self.data = {'Idle': [('123.003', '09/28 01:38:53', '09/28 01:42:23', '09/28 08:06:33')], + 'Running': [('123.001', '09/28 01:32:53', '09/28 01:43:23', '09/28 08:07:33')]} """ def __init__( @@ -801,11 +946,17 @@ def __init__( cache_dir=None, ): """ - @param inactive_files: if ==None, will be reloaded from cache - @param inactive_timeout: how much time must elapse before - @param cache_dir: if None, use dirname + Initializes dirSummaryTimings with log and cache parameters. + + Args: + dirname (str): The directory containing the log files. + log_prefix (str): The prefix for log files. + log_suffix (str): The suffix for log files. Defaults to ".log". + cache_ext (str): The extension for the cache file. Defaults to ".cifpk". + inactive_files (list): List of inactive files. If None, will be reloaded from cache. + inactive_timeout (int): Time in seconds before a file can be declared inactive. Defaults to 24 * 3600. + cache_dir (str): Directory for the cache files. If None, use dirname. """ - self.cdInit( logSummaryTimings, dirname, log_prefix, log_suffix, cache_ext, inactive_files, inactive_timeout, cache_dir ) @@ -845,9 +996,9 @@ def __init__( # 019 - Globus Resource Back Up # 020 - Detected Down Globus Resource # 021 - Remote error -# 022 - Remote system diconnected +# 022 - Remote system disconnected # 023 - Remote system reconnected -# 024 - Remote system cannot recconect +# 024 - Remote system cannot reconnect # 025 - Grid Resource Back Up # 026 - Detected Down Grid Resource # 027 - Job submitted to grid resource @@ -865,11 +1016,13 @@ def get_new_status(old_status, new_status): Given a job with an old and new status, will return the appropriate status to register to the job. - @param old_status: Globus job status - @param new_status: Globus job status - @return: Appropriate status for the job - """ + Args: + old_status (str): Globus job status. + new_status (str): Globus job status. + Returns: + str: Appropriate status for the job. + """ # keep the old status unless you really want to change status = old_status @@ -897,10 +1050,13 @@ def parseSubmitLogFastRaw(fname): """ Read a condor submit log. - @return: a dictionary of jobStrings each having the last statusString - For example {'1583.004': '000', '3616.008': '009'} - """ + Args: + fname (str): Filename of the log to parse. + Returns: + dict: A dictionary of jobStrings each having the last statusString. + For example, {'1583.004': '000', '3616.008': '009'} + """ jobs = {} size = os.path.getsize(fname) @@ -944,15 +1100,19 @@ def parseSubmitLogFastRaw(fname): def parseSubmitLogFastRawTimings(fname): """ - Read a condor submit log. 
Returns a dictionary of jobStrings - each having (the last statusString,firstTime,runningTime,lastTime) - plus the first and last date in the file - - for example {'9568.001':('000', '09/28 01:38:53', '', '09/28 01:38:53'),'9868.003':('005', '09/28 01:48:52', '09/28 16:11:23', '09/28 20:31:53')},'09/28 01:38:53','09/28 20:31:53' - - @return: a dictionary of jobStrings + Read a condor submit log and return a dictionary of jobStrings + each having the last statusString, firstTime, runningTime, lastTime, + plus the first and last date in the file. + + Args: + fname (str): Filename of the log to parse. + + Returns: + tuple: A dictionary of jobStrings, the first time, and the last time. + For example, ({'9568.001': ('000', '09/28 01:38:53', '', '09/28 01:38:53'), + '9868.003': ('005', '09/28 01:48:52', '09/28 16:11:23', '09/28 20:31:53')}, + '09/28 01:38:53', '09/28 20:31:53') """ - jobs = {} first_time = None @@ -1014,13 +1174,13 @@ def parseSubmitLogFastRawTimings(fname): def parseSubmitLogFastRawCallback(fname, callback): """ - Read a condor submit log - for each new event, call a callback + Read a condor submit log and for each new event, call a callback. - @param fname: Condor submit file to parse - @param callname: def callback(new_status_str,timestamp_str,job_str) + Args: + fname (str): Filename of the log to parse. + callback (function): Function to call for each new event. + Should have the signature callback(new_status_str, timestamp_str, job_str) """ - jobs = {} size = os.path.getsize(fname) @@ -1056,14 +1216,14 @@ def parseSubmitLogFastRawCallback(fname, callback): if new_status != old_status: callback(new_status, line_time, jobid) if new_status in ("005", "009"): - del jobs[jobid] # end of live, don't need it anymore + del jobs[jobid] # end of life, don't need it anymore else: jobs[jobid] = new_status else: jobs[jobid] = status callback(status, line_time, jobid) - i1 = buf.find("...", idx) + i1 = buf.find(b"...", idx) if i1 < 0: break idx = i1 + 4 # the 3 dots plus newline @@ -1072,59 +1232,88 @@ def parseSubmitLogFastRawCallback(fname, callback): return -def rawJobId2Nr(str): +def rawJobId2Nr(job_str): """ - Convert the log representation into (ClusterId,ProcessId) + Convert the log representation into (ClusterId, ProcessId). + + Args: + job_str (str): Job string in the format 'ClusterId.ProcessId'. - Return (-1,-1) in case of error + Returns: + tuple: (ClusterId, ProcessId) or (-1, -1) in case of error. """ - arr = str.split(b".") + arr = job_str.split(b".") try: - return (int(arr[0]), int(arr[1])) + return int(arr[0]), int(arr[1]) except (IndexError, ValueError): - return (-1, -1) # invalid + return -1, -1 # invalid -def rawTime2cTime(instr, year): +def rawTime2cTime(time_str, year): """ - Convert the log representation into ctime + Convert the log representation into ctime. - @return: ctime or -1 in case of error + Args: + time_str (str): Time string in the format 'MM/DD HH:MM:SS'. + year (int): The year. + + Returns: + int: ctime or -1 in case of error. 
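+
+    Example (illustrative):
+        rawTime2cTime("09/28 01:38:53", 2024)  # -> seconds since the epoch for that local time, or -1 if malformed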
""" try: ctime = time.mktime( - (year, int(instr[0:2]), int(instr[3:5]), int(instr[6:8]), int(instr[9:11]), int(instr[12:14]), 0, 0, -1) + ( + year, + int(time_str[0:2]), + int(time_str[3:5]), + int(time_str[6:8]), + int(time_str[9:11]), + int(time_str[12:14]), + 0, + 0, + -1, + ) ) except ValueError: return -1 # invalid return ctime -def rawTime2cTimeLastYear(instr): +def rawTime2cTimeLastYear(time_str): """ - Convert the log representation into ctime, - works only for the past year + Convert the log representation into ctime, works only for the past year. + + Args: + time_str (str): Time string in the format 'MM/DD HH:MM:SS'. - @return: ctime or -1 in case of error + Returns: + int: ctime or -1 in case of error. """ now = time.time() current_year = time.localtime(now)[0] - ctime = rawTime2cTime(instr, current_year) + ctime = rawTime2cTime(time_str, current_year) if ctime <= now: return ctime else: # cannot be in the future... it must have been in the past year - ctime = rawTime2cTime(instr, current_year - 1) - return ctime + return rawTime2cTime(time_str, current_year - 1) def diffTimes(start_time, end_time, year): """ - Get two condor time strings and compute the difference - The start_time must be before the end_time + Get two condor time strings and compute the difference. + The start_time must be before the end_time. + + Args: + start_time (str): Start time in the format 'MM/DD HH:MM:SS'. + end_time (str): End time in the format 'MM/DD HH:MM:SS'. + year (int): The year. + + Returns: + int: Difference in seconds or -1 in case of error. """ start_ctime = rawTime2cTime(start_time, year) end_ctime = rawTime2cTime(end_time, year) - if (start_time < 0) or (end_time < 0): + if start_ctime < 0 or end_ctime < 0: return -1 # invalid return int(end_ctime) - int(start_ctime) @@ -1132,8 +1321,17 @@ def diffTimes(start_time, end_time, year): def diffTimeswWrap(start_time, end_time, year, wrap_time): """ - Get two condor time strings and compute the difference - The start_time must be before the end_time + Get two condor time strings and compute the difference with wrapping. + The start_time must be before the end_time. + + Args: + start_time (str): Start time in the format 'MM/DD HH:MM:SS'. + end_time (str): End time in the format 'MM/DD HH:MM:SS'. + year (int): The year. + wrap_time (str): Wrap time in the format 'MM/DD HH:MM:SS'. + + Returns: + int: Difference in seconds or -1 in case of error. """ if start_time > wrap_time: start_year = year @@ -1147,7 +1345,7 @@ def diffTimeswWrap(start_time, end_time, year, wrap_time): end_year = year + 1 end_ctime = rawTime2cTime(end_time, end_year) - if (start_time < 0) or (end_time < 0): + if start_ctime < 0 or end_ctime < 0: return -1 # invalid return int(end_ctime) - int(start_ctime) @@ -1155,8 +1353,14 @@ def diffTimeswWrap(start_time, end_time, year, wrap_time): def interpretStatus(status, default_status="Idle"): """ - Transform a integer globus status to - either Wait, Idle, Running, Held, Completed or Removed + Transform an integer globus status to either Wait, Idle, Running, Held, Completed or Removed. + + Args: + status (int): Globus status code. + default_status (str): Default status to return if status code is unknown. Defaults to "Idle". + + Returns: + str: Interpreted status. 
""" if status == 5: return "Completed" @@ -1176,15 +1380,17 @@ def interpretStatus(status, default_status="Idle"): def countStatuses(jobs): """ - Given a dictionary of job statuses - (like the one got from parseSubmitLogFastRaw) - will return a dictionary of sstatus counts + Given a dictionary of job statuses, will return a dictionary of status counts. - for example: {'009': 25170, '012': 418, '005': 1503} - """ + Args: + jobs (dict): Dictionary of job statuses. + Returns: + dict: Dictionary of status counts. + For example, {'009': 25170, '012': 418, '005': 1503} + """ counts = {} - for e in list(jobs.values()): + for e in jobs.values(): try: counts[e] += 1 except KeyError: @@ -1195,19 +1401,18 @@ def countStatuses(jobs): def countAndInterpretRawStatuses(jobs_raw): """ - Given a dictionary of job statuses - (like the one got from parseSubmitLogFastRaw) - will return a dictionary of status counts + Given a dictionary of job statuses, will return a dictionary of interpreted status counts. - for example: {'Completed': 30170, 'Removed': 148, 'Running': 5013} + Args: + jobs_raw (dict): Dictionary of job statuses. - @param jobs_raw: Dictionary of job statuses - @return: Dictionary of status counts + Returns: + dict: Dictionary of interpreted status counts. + For example, {'Completed': 30170, 'Removed': 148, 'Running': 5013} """ - outc = {} tmpc = countStatuses(jobs_raw) - for s in list(tmpc.keys()): + for s in tmpc.keys(): i_s = interpretStatus(int(s[1:])) # ignore flags try: outc[i_s] += tmpc[s] @@ -1219,18 +1424,17 @@ def countAndInterpretRawStatuses(jobs_raw): def listStatuses(jobs): """ - Given a dictionary of job statuses - (like the one got from parseSubmitLogFastRaw) - will return a dictionary of jobs in each status + Given a dictionary of job statuses, will return a dictionary of jobs in each status. - For example: {'009': ["1.003","2.001"], '012': ["418.001"], '005': ["1503.001","1555.002"]} + Args: + jobs (dict): Dictionary of job statuses. - @param jobs: Dictionary of job statuses - @return: Dictionary of jobs in each status category + Returns: + dict: Dictionary of jobs in each status category. + For example, {'009': ["1.003","2.001"], '012': ["418.001"], '005': ["1503.001","1555.002"]} """ - status = {} - for k, e in list(jobs.items()): + for k, e in jobs.items(): try: status[e].append(k) except KeyError: @@ -1241,18 +1445,18 @@ def listStatuses(jobs): def listStatusesTimings(jobs): """ - Given a dictionary of job statuses + timings - (like the one got from parseSubmitLogFastRawTimings) - will return a dictionary of jobs +timings in each status + Given a dictionary of job statuses and timings, will return a dictionary of jobs and timings in each status. - For example: {'009': [("1.003",'09/28 01:38:53', '', '09/28 01:38:53'),("2.001",'09/28 03:38:53', '', '09/28 04:38:53')], '005': [("1503.001", '09/28 01:48:52', '09/28 16:11:23', '09/28 20:31:53'),("1555.002", '09/28 02:48:52', '09/28 18:11:23', '09/28 23:31:53')]} + Args: + jobs (dict): Dictionary of job statuses and timings. - @param jobs: Dictionary of job statuses and timings - @return: Dictionary of jobs+timings in each status category + Returns: + dict: Dictionary of jobs and timings in each status category. 
+ For example, {'009': [("1.003", '09/28 01:38:53', '', '09/28 01:38:53')], + '005': [("1503.001", '09/28 01:48:52', '09/28 16:11:23', '09/28 20:31:53')]} """ - status = {} - for k, e in list(jobs.items()): + for k, e in jobs.items(): try: status[e[0]].append((k,) + e[1:]) except KeyError: @@ -1263,27 +1467,23 @@ def listStatusesTimings(jobs): def listAndInterpretRawStatuses(jobs_raw, invert_function): """ - Given a dictionary of job statuses - (whatever the invert_function recognises) - will return a dictionary of jobs in each status - (syntax depends on the invert_function) - - for example with linvert_funtion==istStatuses: - {'Completed': ["2.003","5.001"], 'Removed': ["41.001"], - 'Running': ["408.003"]} - - @param jobs_raw: A dictionary of job statuses - @param invert_function: function to turn a job status into "Completed","Removed","Running", etc - @return: Dictionary of jobs in each category. - """ + Given a dictionary of job statuses, will return a dictionary of jobs in each status + according to the provided invert function. + + Args: + jobs_raw (dict): Dictionary of job statuses. + invert_function (function): Function to turn a job status into "Completed","Removed","Running", etc. + Returns: + dict: Dictionary of jobs in each category. + For example, {'Completed': ["2.003","5.001"], 'Removed': ["41.001"], 'Running': ["408.003"]} + """ outc = {} tmpc = invert_function(jobs_raw) - for s in list(tmpc.keys()): + for s in tmpc.keys(): try: i_s = interpretStatus(int(s[1:])) # ignore flags except Exception: # file corrupted, protect - # print "lairs: Unexpect line: %s"%s continue try: outc[i_s] += tmpc[s] @@ -1295,35 +1495,37 @@ def listAndInterpretRawStatuses(jobs_raw, invert_function): def parseSubmitLogFast(fname): """ - Reads a Condor submit log, return a dictionary of jobIds + Reads a Condor submit log and returns a dictionary of job IDs each having the last status. - For example {(1583,4)': 0, (3616,8): 9} + Args: + fname (str): Filename to parse. - @param fname: filename to parse - @return: Dictionary of jobIDs and last status + Returns: + dict: Dictionary of job IDs and last status. + For example, {(1583,4): 0, (3616,8): 9} """ - jobs_raw = parseSubmitLogFastRaw(fname) jobs = {} - for k in list(jobs_raw.keys()): + for k in jobs_raw.keys(): jobs[rawJobId2Nr(k)] = int(jobs_raw[k]) return jobs def parseSubmitLogFastTimings(fname, year=None): """ - Reads a Condor submit log, return a dictionary of jobIds - each having (the last status, seconds in queue, - if status==5, seconds running) + Reads a Condor submit log and returns a dictionary of job IDs + each having the last status, seconds in queue, + and if status == 5, seconds running. - For example {(1583,4)': (0,345,None), (3616,8): (5,7777,4532)} + Args: + fname (str): Filename to parse. + year (int): The year. If None, use the current year. - @param fname: filename to parse - @param year: if no year, then use the current one - @return: Dictionary of jobIDs + Returns: + dict: Dictionary of job IDs with timings. 
+ For example, {(1583,4): (0,345,None), (3616,8): (5,7777,4532)} """ - jobs_raw, first_time, last_time = parseSubmitLogFastRawTimings(fname) if year is None: @@ -1335,7 +1537,7 @@ def parseSubmitLogFastTimings(fname, year=None): jobs = {} if year_wrap: year1 = year - 1 - for k in list(jobs_raw.keys()): + for k in jobs_raw.keys(): el = jobs_raw[k] status = int(el[0]) diff_time = diffTimeswWrap(el[1], el[3], year1, first_time) @@ -1345,7 +1547,7 @@ def parseSubmitLogFastTimings(fname, year=None): running_time = None jobs[rawJobId2Nr(k)] = (status, diff_time, running_time) else: - for k in list(jobs_raw.keys()): + for k in jobs_raw.keys(): el = jobs_raw[k] status = int(el[0]) diff_time = diffTimes(el[1], el[3], year) @@ -1367,23 +1569,30 @@ def loadCache(fname): """ Loads a pickle file from a filename and returns the resulting data. - @param fname: Filename to load - @return: data retrieved from file + Args: + fname (str): Filename to load. + + Returns: + Any: Data retrieved from file. + + Raises: + RuntimeError: If the file could not be read. """ try: data = util.file_pickle_load(fname) except Exception as e: - raise RuntimeError("Could not read %s" % fname) from e + raise RuntimeError(f"Could not read {fname}") from e return data def saveCache(fname, data): """ Creates a temporary file to store data in, then moves the file into - the correct place. Uses pickle to store data. + the correct place. Uses pickle to store data. - @param fname: Filename to write to. - @param data: data to store in pickle format + Args: + fname (str): Filename to write to. + data (Any): Data to store in pickle format. """ util.file_pickle_dump(fname, data) return diff --git a/lib/condorManager.py b/lib/condorManager.py index f29eedb26..f10cd5634 100644 --- a/lib/condorManager.py +++ b/lib/condorManager.py @@ -1,8 +1,7 @@ # SPDX-FileCopyrightText: 2009 Fermi Research Alliance, LLC # SPDX-License-Identifier: Apache-2.0 -"""This module implements functions that will act on Condor -""" +"""This module implements functions that will act on Condor.""" import re @@ -12,6 +11,15 @@ ############################################## # Helper functions def pool2str(pool_name): + """ + Convert pool name to a string suitable for the Condor command line. + + Args: + pool_name (str): The name of the pool. + + Returns: + str: The pool name formatted for the command line. + """ if pool_name is not None: return "-pool %s " % pool_name else: @@ -19,6 +27,15 @@ def pool2str(pool_name): def schedd2str(schedd_name): + """ + Convert schedd name to a string suitable for the Condor command line. + + Args: + schedd_name (str): The name of the schedd. + + Returns: + str: The schedd name formatted for the command line. + """ if schedd_name is not None: return "-name %s " % schedd_name else: @@ -26,6 +43,19 @@ def schedd2str(schedd_name): def cached_exe_cmd(cmd, arg_str, schedd_name, pool_name, schedd_lookup_cache): + """ + Execute a cached Condor command. + + Args: + cmd (str): The Condor command to execute. + arg_str (str): The arguments for the command. + schedd_name (str): The name of the schedd. + pool_name (str): The name of the pool. + schedd_lookup_cache: The cache for schedd lookups. + + Returns: + str: The output of the Condor command. 
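+
+    Example (illustrative sketch; the schedd and pool names are hypothetical):
+        out = cached_exe_cmd("condor_q", "-totals ", "schedd1@submit.example.org", "pool.example.org", None)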
+ """ if schedd_lookup_cache is None: schedd_lookup_cache = condorMonitor.NoneScheddCache() @@ -45,6 +75,18 @@ def cached_exe_cmd(cmd, arg_str, schedd_name, pool_name, schedd_lookup_cache): def condorSubmitOne( submit_file, schedd_name=None, pool_name=None, schedd_lookup_cache=condorMonitor.local_schedd_cache ): + """ + Submit a new job using a submit file. Works only when a single cluster is created. + + Args: + submit_file (str): The path to the submit file. + schedd_name (str, optional): The name of the schedd. Defaults to None. + pool_name (str, optional): The name of the pool. Defaults to None. + schedd_lookup_cache (optional): The cache for schedd lookups. Defaults to condorMonitor.local_schedd_cache. + + Returns: + int: The ClusterId of the submitted job. + """ outstr = cached_exe_cmd("condor_submit", submit_file, schedd_name, pool_name, schedd_lookup_cache) # extract 'submitted to cluster xxx.' part @@ -63,6 +105,19 @@ def condorSubmitOne( def condorRemove( constraint, schedd_name=None, pool_name=None, do_forcex=False, schedd_lookup_cache=condorMonitor.local_schedd_cache ): + """ + Remove a set of jobs from the queue. + + Args: + constraint (str): The constraint to match jobs for removal. + schedd_name (str, optional): The name of the schedd. Defaults to None. + pool_name (str, optional): The name of the pool. Defaults to None. + do_forcex (bool, optional): If True, force removal. Defaults to False. + schedd_lookup_cache (optional): The cache for schedd lookups. Defaults to condorMonitor.local_schedd_cache. + + Returns: + str: The output of the condor_rm command. + """ opts = "-constraint '%s' " % constraint if do_forcex: opts += "-forcex " @@ -80,6 +135,19 @@ def condorRemoveOne( do_forcex=False, schedd_lookup_cache=condorMonitor.local_schedd_cache, ): + """ + Remove a single job from the queue. + + Args: + cluster_or_uname (str): The ClusterId or username of the job to remove. + schedd_name (str, optional): The name of the schedd. Defaults to None. + pool_name (str, optional): The name of the pool. Defaults to None. + do_forcex (bool, optional): If True, force removal. Defaults to False. + schedd_lookup_cache (optional): The cache for schedd lookups. Defaults to condorMonitor.local_schedd_cache. + + Returns: + str: The output of the condor_rm command. + """ opts = "%s " % cluster_or_uname if do_forcex: opts += "-forcex " @@ -91,6 +159,18 @@ def condorRemoveOne( # Hold a set of jobs from the queue # def condorHold(constraint, schedd_name=None, pool_name=None, schedd_lookup_cache=condorMonitor.local_schedd_cache): + """ + Hold a set of jobs in the queue. + + Args: + constraint (str): The constraint to match jobs for holding. + schedd_name (str, optional): The name of the schedd. Defaults to None. + pool_name (str, optional): The name of the pool. Defaults to None. + schedd_lookup_cache (optional): The cache for schedd lookups. Defaults to condorMonitor.local_schedd_cache. + + Returns: + str: The output of the condor_hold command. + """ opts = "-constraint '%s' " % constraint return cached_exe_cmd("condor_hold", opts, schedd_name, pool_name, schedd_lookup_cache) @@ -102,6 +182,18 @@ def condorHold(constraint, schedd_name=None, pool_name=None, schedd_lookup_cache def condorHoldOne( cluster_or_uname, schedd_name=None, pool_name=None, schedd_lookup_cache=condorMonitor.local_schedd_cache ): + """ + Hold a single job in the queue. + + Args: + cluster_or_uname (str): The ClusterId or username of the job to hold. + schedd_name (str, optional): The name of the schedd. Defaults to None. 
+ pool_name (str, optional): The name of the pool. Defaults to None. + schedd_lookup_cache (optional): The cache for schedd lookups. Defaults to condorMonitor.local_schedd_cache. + + Returns: + str: The output of the condor_hold command. + """ opts = "%s " % cluster_or_uname return cached_exe_cmd("condor_hold", opts, schedd_name, pool_name, schedd_lookup_cache) @@ -111,6 +203,18 @@ def condorHoldOne( # Release a set of jobs from the queue # def condorRelease(constraint, schedd_name=None, pool_name=None, schedd_lookup_cache=condorMonitor.local_schedd_cache): + """ + Release a set of jobs from hold in the queue. + + Args: + constraint (str): The constraint to match jobs for release. + schedd_name (str, optional): The name of the schedd. Defaults to None. + pool_name (str, optional): The name of the pool. Defaults to None. + schedd_lookup_cache (optional): The cache for schedd lookups. Defaults to condorMonitor.local_schedd_cache. + + Returns: + str: The output of the condor_release command. + """ opts = "-constraint '%s' " % constraint return cached_exe_cmd("condor_release", opts, schedd_name, pool_name, schedd_lookup_cache) @@ -122,6 +226,18 @@ def condorRelease(constraint, schedd_name=None, pool_name=None, schedd_lookup_ca def condorReleaseOne( cluster_or_uname, schedd_name=None, pool_name=None, schedd_lookup_cache=condorMonitor.local_schedd_cache ): + """ + Release a single job from hold in the queue. + + Args: + cluster_or_uname (str): The ClusterId or username of the job to release. + schedd_name (str, optional): The name of the schedd. Defaults to None. + pool_name (str, optional): The name of the pool. Defaults to None. + schedd_lookup_cache (optional): The cache for schedd lookups. Defaults to condorMonitor.local_schedd_cache. + + Returns: + str: The output of the condor_release command. + """ opts = "%s " % cluster_or_uname return cached_exe_cmd("condor_release", opts, schedd_name, pool_name, schedd_lookup_cache) @@ -131,6 +247,14 @@ def condorReleaseOne( # Issue a condor_reschedule # def condorReschedule(schedd_name=None, pool_name=None, schedd_lookup_cache=condorMonitor.local_schedd_cache): + """ + Issue a condor_reschedule command. + + Args: + schedd_name (str, optional): The name of the schedd. Defaults to None. + pool_name (str, optional): The name of the pool. Defaults to None. + schedd_lookup_cache (optional): The cache for schedd lookups. Defaults to condorMonitor.local_schedd_cache. + """ cached_exe_cmd("condor_reschedule", "", schedd_name, pool_name, schedd_lookup_cache) return @@ -138,6 +262,15 @@ def condorReschedule(schedd_name=None, pool_name=None, schedd_lookup_cache=condo ############################################## # Helper functions of condorAdvertise def usetcp2str(use_tcp): + """ + Convert use_tcp flag to a string suitable for the Condor command line. + + Args: + use_tcp (bool): If True, use TCP. + + Returns: + str: The use_tcp flag formatted for the command line. + """ if use_tcp: return "-tcp " else: @@ -145,6 +278,15 @@ def usetcp2str(use_tcp): def ismulti2str(is_multi): + """ + Convert is_multi flag to a string suitable for the Condor command line. + + Args: + is_multi (bool): If True, indicate multiple. + + Returns: + str: The is_multi flag formatted for the command line. 
+    """
     if is_multi:
         return "-multiple "
     else:
@@ -153,8 +295,21 @@ def ismulti2str(is_multi):
 
 
 ##############################################
 #
-# Remove a job from the queue
+# Advertise a classad to the collector
 #
 def condorAdvertise(classad_fname, command, use_tcp=False, is_multi=False, pool_name=None):
+    """
+    Advertise a classad to the Condor collector.
+
+    Args:
+        classad_fname (str): The filename of the file containing the classad(s).
+        command (str): The condor_advertise update command.
+        use_tcp (bool, optional): If True, use TCP. Defaults to False.
+        is_multi (bool, optional): If True, the file contains multiple classads. Defaults to False.
+        pool_name (str, optional): The name of the pool. Defaults to None.
+
+    Returns:
+        str: The output of the condor_advertise command.
+    """
     cmd_opts = f"{pool2str(pool_name)}{usetcp2str(use_tcp)}{ismulti2str(is_multi)}{command} {classad_fname}"
     return condorExe.exe_cmd_sbin("condor_advertise", cmd_opts)
diff --git a/lib/condorMonitor.py b/lib/condorMonitor.py
index 0ec7e16b5..0b99e7ae1 100644
--- a/lib/condorMonitor.py
+++ b/lib/condorMonitor.py
@@ -39,12 +39,25 @@
 
 
 def htcondor_full_reload():
+    """
+    Reloads the HTCondor configuration from the environment and updates HTCondor parameters.
+
+    If the HTCondor Python bindings are enabled, this function reloads the configuration by reading the
+    `CONDOR_CONFIG` environment variable and manually adds `_CONDOR_` prefixed environment variables to
+    the HTCondor parameters.
+
+    Returns:
+        None
+    """
     HTCONDOR_ENV_PREFIX = "_CONDOR_"
-    HTCONDOR_ENV_PREFIX_LEN = len(HTCONDOR_ENV_PREFIX)  # len of _CONDOR_ = 8
+    HTCONDOR_ENV_PREFIX_LEN = len(HTCONDOR_ENV_PREFIX)  # Length of _CONDOR_ = 8
+
     if not USE_HTCONDOR_PYTHON_BINDINGS:
         return
+
+    # Reload configuration reading CONDOR_CONFIG from the environment
     htcondor.reload_config()
+
+    # _CONDOR_ variables need to be added manually to _Params
     for i in os.environ:
         if i.startswith(HTCONDOR_ENV_PREFIX):
@@ -58,6 +71,15 @@ def htcondor_full_reload():
 
 # Set path to condor binaries
 def set_path(new_condor_bin_path):
+    """
+    Sets the path to the Condor binaries.
+
+    Args:
+        new_condor_bin_path (str): The new path to the Condor binaries.
+
+    Returns:
+        None
+    """
     global condor_bin_path
     condor_bin_path = new_condor_bin_path
 
@@ -111,6 +133,16 @@ def getScheddId(self, schedd_name, pool_name):
         return (self.iGetCmdScheddStr(schedd_name), {})
 
     def iGetCmdScheddStr(self, schedd_name):
+        """
+        Constructs a command string for the specified schedd name.
+
+        Args:
+            schedd_name (str or None): The name of the schedd. If None, an empty string is returned.
+
+        Returns:
+            str: The command string for the schedd. If `schedd_name` is None, returns an empty string.
+                Otherwise, returns the command string with the schedd name.
+        """
         if schedd_name is None:
             schedd_str = ""
         else:
@@ -126,21 +158,41 @@ class LocalScheddCache(NoneScheddCache):
     local disk lookup. Remember which one to use.
     """
 
     def __init__(self):
+        """
+        Initializes the cache with default settings.
+
+        Attributes:
+            enabled (bool): Indicates if the cache is enabled. Default is True.
+            cache (dict): Maps (schedd_name, pool_name) to the corresponding command-line schedd string and environment.
+            my_ips (list): IP addresses associated with the current host, including localhost if defined.
+        """
         self.enabled = True
         # dict of (schedd_name,pool_name)=>(cms arg schedd string,env)
         self.cache = {}
 
         self.my_ips = socket.gethostbyname_ex(socket.gethostname())[2]
         try:
             self.my_ips += socket.gethostbyname_ex("localhost")[2]
         except socket.gaierror:
             pass  # localhost not defined, ignore
 
     def enable(self):
+        """
+        Enables the cache.
+
+        Sets the `enabled` attribute to True.
+        """
        self.enabled = True
 
     def disable(self):
+        """
+        Disables the cache.
+
+        Sets the `enabled` attribute to False.
+        """
         self.enabled = False
 
     def getScheddId(self, schedd_name, pool_name):
@@ -178,42 +230,62 @@ def getScheddId(self, schedd_name, pool_name):
 
     # return None if not found
     # Can raise exceptions
     def iGetEnv(self, schedd_name, pool_name):
+        """
+        Retrieves the HTCondor environment settings for a specified schedd and pool.
+
+        This method checks the disk cache for existing data. If not found, it fetches the data from the Condor status.
+        It also checks if the schedd is local and if it is advertising its SPOOL or LOCAL directory.
+
+        Args:
+            schedd_name (str): The name of the schedd.
+            pool_name (str): The name of the pool.
+
+        Returns:
+            dict or None: A dictionary with the environment settings for the schedd if applicable,
+                or None if the directory does not exist or if the schedd is not local.
+
+        Raises:
+            RuntimeError: If the schedd is not found or if it is not advertising `ScheddIpAddr`.
+            Exception: Other exceptions may be raised during the fetching or processing of data.
+        """
         global disk_cache
         data = disk_cache.get(schedd_name + ".igetenv")  # pylint: disable=assignment-from-none
         if data is None:
             cs = CondorStatus("schedd", pool_name)
             data = cs.fetch(
                 constraint='Name=?="%s"' % schedd_name,
                 format_list=[("ScheddIpAddr", "s"), ("SPOOL_DIR_STRING", "s"), ("LOCAL_DIR_STRING", "s")],
             )
             disk_cache.save(schedd_name + ".igetenv", data)
         if schedd_name not in data:
             raise RuntimeError("Schedd '%s' not found" % schedd_name)
+
         el = data[schedd_name]
         if "SPOOL_DIR_STRING" not in el and "LOCAL_DIR_STRING" not in el:
-            # not advertising, cannot use disk optimization
+            # Not advertising, cannot use disk optimization
             return None
         if "ScheddIpAddr" not in el:
             # This should never happen
             raise RuntimeError("Schedd '%s' is not advertising ScheddIpAddr" % schedd_name)
 
         schedd_ip = el["ScheddIpAddr"][1:].split(":")[0]
-        if schedd_ip in self.my_ips:  # seems local, go for the dir
+        if schedd_ip in self.my_ips:  # Seems local, go for the directory
             l_dir = el.get("SPOOL_DIR_STRING", el.get("LOCAL_DIR_STRING"))
-            if os.path.isdir(l_dir):  # making sure the directory exists
+            if os.path.isdir(l_dir):  # Making sure the directory exists
                 if "SPOOL_DIR_STRING" in el:
                     return {"_CONDOR_SPOOL": "%s" % l_dir}
                 else:  # LOCAL_DIR_STRING, assuming spool is LOCAL_DIR_STRING/spool
                     if os.path.isdir("%s/spool" % l_dir):
                         return {"_CONDOR_SPOOL": "%s/spool" % l_dir}
             else:
-                # dir does not exist, not relevant, revert to standard behaviour
+                # Directory does not exist, not relevant, revert to standard behavior
                 return None
         else:
-            # not local
+            # Not local
             return None
 
 
 # The class does not belong here, it should be in the disk_cache module.
@@ -399,6 +471,17 @@ class CondorQuery(StoredQuery):
     """
 
     def __init__(self, exe_name, resource_str, group_attribute, pool_name=None, security_obj=None, env={}):
+        """
+        Initializes a new instance of the class.
+
+        Args:
+            exe_name (str): The name of the executable.
+            resource_str (str): The resource string.
+            group_attribute (str): The group attribute.
+            pool_name (str, optional): The name of the pool. Defaults to None.
+            security_obj (object, optional): The security object. Defaults to None.
+            env (dict, optional): The environment variables. Defaults to an empty dictionary.
+        """
         self.exe_name = exe_name
         self.env = env
         self.resource_str = resource_str
@@ -417,11 +500,11 @@ def __init__(self, exe_name, resource_str, group_attribute, pool_name=None, secu
         self.security_obj = condorSecurity.ProtoRequest()
 
     def require_integrity(self, requested_integrity):
-        """Set client integrity settings to use for condor commands
+        """
+        Set client integrity settings to use for condor commands.
 
         Args:
-            requested_integrity (str): HTCondor integrity level
-
+            requested_integrity (str): HTCondor integrity level.
         """
         if requested_integrity is None:
             condor_val = None
@@ -433,10 +516,11 @@ def require_integrity(self, requested_integrity):
         self.security_obj.set("CLIENT", "INTEGRITY", condor_val)
 
     def get_requested_integrity(self):
-        """Get the current integrity settings
-
-        Returns: None->None; REQUIRED->True; OPTIONAL->False
+        """
+        Get the current integrity settings.
 
+        Returns:
+            bool or None: None->None; REQUIRED->True; OPTIONAL->False.
         """
         condor_val = self.security_obj.get("CLIENT", "INTEGRITY")
         if condor_val is None:
@@ -444,11 +528,11 @@ def get_requested_integrity(self):
         return condor_val == "REQUIRED"
 
     def require_encryption(self, requested_encryption):
-        """Set client encryption settings to use for condor commands
+        """
+        Set client encryption settings to use for condor commands.
 
         Args:
-            requested_encryption (str): HTCondor encryption level
-
+            requested_encryption (str): HTCondor encryption level.
         """
         if requested_encryption is None:
             condor_val = None
@@ -460,10 +544,11 @@ def require_encryption(self, requested_encryption):
         self.security_obj.set("CLIENT", "ENCRYPTION", condor_val)
 
     def get_requested_encryption(self):
-        """Get the current encryption settings
-
-        Returns: None->None; REQUIRED->True; OPTIONAL->False
+        """
+        Get the current encryption settings.
 
+        Returns:
+            bool or None: None->None; REQUIRED->True; OPTIONAL->False.
         """
         condor_val = self.security_obj.get("CLIENT", "ENCRYPTION")
         if condor_val is None:
@@ -471,15 +556,19 @@ def get_requested_encryption(self):
         return condor_val == "REQUIRED"
 
     def fetch(self, constraint=None, format_list=None):
-        """Return the results obtained using HTCondor commands or python bindings
-
+        """
+        Return the results obtained using HTCondor commands or python bindings.
 
         Args:
-            constraint (str): query constraint
-            format_list (list): Classad attr & type. [(attr1, 'i'), ('attr2', 's')]
+            constraint (str, optional): Query constraint. Defaults to None.
+            format_list (list, optional): Classad attr & type. Defaults to None.
+                Example: [(attr1, 'i'), ('attr2', 's')].
 
-        Returns (dict): Dict containing the query results
+        Returns:
+            dict: Dict containing the query results.
 
+        Raises:
+            QueryError: If an error occurs during the query execution.
         """
         try:
             if USE_HTCONDOR_PYTHON_BINDINGS:
@@ -563,85 +652,118 @@ def load(self, constraint=None, format_list=None):
         """
         self.stored_data = self.fetch(constraint, format_list)
 
     def __repr__(self):
+        """
+        Returns a string representation of the object.
+
+        Returns a string containing detailed information about the object's attributes.
+
+        Returns:
+            str: A string representation of the object.
+        """
         output = "%s:\n" % self.__class__.__name__
         output += "exe_name = %s\n" % str(self.exe_name)
         output += "env = %s\n" % str(self.env)
         output += "resource_str = %s\n" % str(self.resource_str)
         output += "group_attribute = %s\n" % str(self.group_attribute)
         output += "pool_name = %s\n" % str(self.pool_name)
         output += "pool_str = %s\n" % str(self.pool_str)
         output += "security_obj = %s\n" % str(self.security_obj)
         output += "used_python_bindings = %s\n" % USE_HTCONDOR_PYTHON_BINDINGS
         output += "stored_data = %s" % str(self.stored_data)
         return output
 
 
 class CondorQ(CondorQuery):
     """Class to implement condor_q. Uses htcondor-python bindings if possible."""
 
     def __init__(self, schedd_name=None, pool_name=None, security_obj=None, schedd_lookup_cache=local_schedd_cache):
+        """
+        Initializes a new instance of the class.
+
+        Args:
+            schedd_name (str, optional): The name of the schedd. Defaults to None.
+            pool_name (str, optional): The name of the pool. Defaults to None.
+            security_obj (object, optional): The security object. Defaults to None.
+            schedd_lookup_cache (object, optional): The cache object used for schedd lookup.
+                Defaults to local_schedd_cache if not provided.
+        """
         self.schedd_name = schedd_name
         if schedd_lookup_cache is None:
             schedd_lookup_cache = NoneScheddCache()
 
         schedd_str, env = schedd_lookup_cache.getScheddId(schedd_name, pool_name)
         CondorQuery.__init__(self, "condor_q", schedd_str, ["ClusterId", "ProcId"], pool_name, security_obj, env)
 
     def fetch(self, constraint=None, format_list=None):
+        """
+        Fetches data from the Condor query.
+
+        Args:
+            constraint (str, optional): A constraint to filter the query results. Defaults to None.
+            format_list (list of tuple, optional): A list of attributes to include in the query results.
+                Defaults to None.
+
+        Returns:
+            dict: Dict containing the query results, keyed by (ClusterId, ProcId).
+        """
         if format_list is not None:
             # If format_list, make sure ClusterId and ProcId are present
             format_list = complete_format_list(format_list, [("ClusterId", "i"), ("ProcId", "i")])
         return CondorQuery.fetch(self, constraint=constraint, format_list=format_list)
 
     def fetch_using_bindings(self, constraint=None, format_list=None):
-        """Fetch the condor_q results using htcondor-python bindings
+        """
+        Fetch the condor_q results using htcondor-python bindings.
 
         Args:
-            constraint (str): Constraints to be applied to the query
-            format_list (list): Classad attr & type. [(attr1, 'i'), ('attr2', 's')]
+            constraint (str): Constraints to be applied to the query.
+            format_list (list): Classad attr & type. [(attr1, 'i'), ('attr2', 's')].
 
-        Returns (dict): Dict containing the results
-
+        Returns:
+            dict: Dict containing the results.
         """
         global disk_cache
         results_dict = {}  # defined here in case of exception
         constraint = bindings_friendly_constraint(constraint)
         attrs = bindings_friendly_attrs(format_list)
 
         self.security_obj.save_state()
         try:
             self.security_obj.enforce_requests()
             htcondor_full_reload()
             if self.pool_name:
                 collector = htcondor.Collector(str(self.pool_name))
             else:
                 collector = htcondor.Collector()
 
             if self.schedd_name is None:
                 schedd = htcondor.Schedd()
             else:
                 schedd_ad = disk_cache.get(self.schedd_name + ".locate")  # pylint: disable=assignment-from-none
                 if schedd_ad is None:
                     schedd_ad = collector.locate(htcondor.DaemonTypes.Schedd, self.schedd_name)
                     disk_cache.save(self.schedd_name + ".locate", schedd_ad)
                 schedd = htcondor.Schedd(schedd_ad)
             results = schedd.query(constraint, attrs)
             results_dict = list2dict(results, self.group_attribute)
         except Exception as ex:
             s = "default"
             if self.schedd_name is not None:
                 s = self.schedd_name
             p = "default"
             if self.pool_name is not None:
                 p = self.pool_name
             err_str = f"Error querying schedd {s} in pool {p} using python bindings: {ex}"
             raise PBError(err_str) from ex
         finally:
             self.security_obj.restore_state()
+
         return results_dict
 
 
 class CondorStatus(CondorQuery):
@@ -705,10 +827,26 @@ class BaseSubQuery(StoredQuery):
     """
 
     def __init__(self, query, subquery_func):
+        """
+        Initializes a new instance of the class.
+
+        Args:
+            query (object): The query object.
+            subquery_func (function): The function used for subquery processing.
+        """
         self.query = query
         self.subquery_func = subquery_func
 
     def fetch(self, constraint=None):
+        """
+        Fetches data based on the provided query and applies the subquery function.
+
+        Args:
+            constraint (str, optional): A constraint to filter the query results. Defaults to None.
+
+        Returns:
+            object: The result of applying the subquery function to the fetched data.
+        """
         indata = self.query.fetch(constraint)
         return self.subquery_func(self, indata)
 
@@ -717,6 +855,12 @@ def fetch(self, constraint=None):
     # and had query.load issued before
     #
     def load(self, constraint=None):
+        """
+        Loads data using a stored query and applies the subquery function to it.
+
+        Args:
+            constraint (str, optional): A constraint to filter the query results. Defaults to None.
+        """
         indata = self.query.fetchStored(constraint)
         self.stored_data = self.subquery_func(indata)
 
@@ -727,9 +871,22 @@ class SubQuery(BaseSubQuery):
     """
 
     def __init__(self, query, constraint_func=None):
+        """
+        Initializes a new instance of the class.
+
+        Args:
+            query (object): The query object.
+            constraint_func (function, optional): A function to apply constraints. Defaults to None.
+        """
         BaseSubQuery.__init__(self, query, lambda d: applyConstraint(d, constraint_func))
 
     def __repr__(self):
+        """
+        Returns a string representation of the object.
+
+        Returns:
+            str: A string representation of the object.
+        """
         output = "%s:\n" % self.__class__.__name__
         #        output += "client_name = %s\n" % str(self.client_name)
         #        output += "entry_name = %s\n" % str(self.entry_name)
@@ -783,53 +940,99 @@ class Summarize:
     Summarizing classes
     """
 
-    # hash_func - Hashing function
-    #  One argument: classad dictionary
-    #  Returns: hash value
-    #          if None, will not be counted
-    #          if a list, all elements will be used
     def __init__(self, query, hash_func=lambda x: 1):
+        """
+        Initializes a new instance of the class.
+
+        Args:
+            query (object): The query object.
+            hash_func (function, optional): The hashing function applied to each classad dictionary.
+                If it returns None, the element is not counted; if it returns a list, all elements are used.
+                Defaults to a function that always returns 1.
+        """
         self.query = query
         self.hash_func = hash_func
 
-    # Parameters:
-    #    constraint - string to be passed to query.fetch()
-    #    hash_func - if !=None, use this instead of the main one
-    # Returns a dictionary of hash values
-    #    Elements are counts (or more dictionaries if hash returns lists)
     def count(self, constraint=None, hash_func=None, flat_hash=False):
+        """
+        Counts occurrences of items based on the query results.
+
+        Args:
+            constraint (str, optional): A constraint to filter the query results. Defaults to None.
+            hash_func (function, optional): A custom hashing function to use instead of the main one. Defaults to None.
+            flat_hash (bool, optional): Whether to return a flat dictionary or nested dictionaries. Defaults to False.
+
+        Returns:
+            dict: A dictionary of hash values with counts.
+        """
         data = self.query.fetch(constraint)
         if flat_hash:
             return fetch2count_flat(data, self.getHash(hash_func))
         return fetch2count(data, self.getHash(hash_func))
 
-    # Use data pre-stored in query
-    # Same output as count
     def countStored(self, constraint_func=None, hash_func=None, flat_hash=False):
+        """
+        Counts occurrences of items based on pre-stored query results.
+
+        Args:
+            constraint_func (function, optional): A constraint function to filter the stored query results. Defaults to None.
+            hash_func (function, optional): A custom hashing function to use instead of the main one. Defaults to None.
+            flat_hash (bool, optional): Whether to return a flat dictionary or nested dictionaries. Defaults to False.
+
+        Returns:
+            dict: A dictionary of hash values with counts.
+        """
         data = self.query.fetchStored(constraint_func)
         if flat_hash:
             return fetch2count_flat(data, self.getHash(hash_func))
         return fetch2count(data, self.getHash(hash_func))
 
-    # Parameters, same as count
-    # Returns a dictionary of hash values
-    #    Elements are lists of keys (or more dictionaries if hash returns lists)
     def list(self, constraint=None, hash_func=None):
+        """
+        Lists items based on the query results.
+
+        Args:
+            constraint (str, optional): A constraint to filter the query results. Defaults to None.
+            hash_func (function, optional): A custom hashing function to use instead of the main one. Defaults to None.
+
+        Returns:
+            dict: A dictionary of hash values with lists of keys.
+        """
         data = self.query.fetch(constraint)
         return fetch2list(data, self.getHash(hash_func))
 
-    # Use data pre-stored in query
-    # Same output as list
     def listStored(self, constraint_func=None, hash_func=None):
+        """
+        Lists items based on pre-stored query results.
+
+        Args:
+            constraint_func (function, optional): A constraint function to filter the stored query results. Defaults to None.
+            hash_func (function, optional): A custom hashing function to use instead of the main one. Defaults to None.
+
+        Returns:
+            dict: A dictionary of hash values with lists of keys.
+        """
         data = self.query.fetchStored(constraint_func)
         return fetch2list(data, self.getHash(hash_func))
 
-    ### Internal
     def getHash(self, hash_func):
+        """
+        Get the hash function to use (internal).
+
+        Args:
+            hash_func (function): The custom hash function, if provided.
+
+        Returns:
+            function: The hash function to use.
+        """
         if hash_func is None:
             return self.hash_func
         else:
             return hash_func
 
 
 ############################################################
@@ -839,9 +1042,17 @@ def getHash(self, hash_func):
 ############################################################
 
 
-# check that req_format_els are present in in_format_list, and if not add them
-# return a new format_list
 def complete_format_list(in_format_list, req_format_els):
+    """
+    Checks if required format elements are present in the input format list, and if not, adds them.
+
+    Args:
+        in_format_list (list): The input format list.
+        req_format_els (list): The list of required format elements.
+
+    Returns:
+        list: The new format list with required elements added if missing.
+    """
     out_format_list = in_format_list[0:]
     for req_format_el in req_format_els:
         found = False
@@ -854,35 +1065,17 @@ def complete_format_list(in_format_list, req_format_els):
     return out_format_list
 
 
-#
-# Convert Condor XML to list
-#
-# For Example:
-#
-#
-#
-#
-# Job
-# Machine
-# 0
-#
-#
-# ON_EXIT
-#
-#
-# Job
-# Machine
-# 0
-#
-# /DC=gov/DC=fnal/O=Fermilab/OU=People/CN=Igor Sfiligoi/UID=sfiligoi
-#
-#
-#
+def xml2list_start_element(name, attrs):
+    """
+    XML handler function called when starting an XML element.
+ Args: + name (str): The name of the XML element. + attrs (dict): The attributes of the XML element. -# 3 xml2list XML handler functions -def xml2list_start_element(name, attrs): + Raises: + TypeError: If the XML element type is not supported. + """ global xml2list_data, xml2list_inclassad, xml2list_inattr, xml2list_intype if name == "c": xml2list_inclassad = {} @@ -912,29 +1105,70 @@ def xml2list_start_element(name, attrs): def xml2list_end_element(name): + """ + XML handler function called when encountering the end of an XML element. + + Args: + name (str): The name of the XML element. + + Global Variables: + xml2list_data (list): A list containing parsed classad dictionaries. + xml2list_inclassad (dict): The current classad dictionary being parsed. + xml2list_inattr (dict): The current attribute dictionary within the classad being parsed. + xml2list_intype (str): The data type of the current attribute value ('i' for integer, 'r' for float, + 'b' for boolean, 'un' for unknown type). + + Raises: + TypeError: If an unexpected XML element type is encountered. + + Notes: + - If the XML element is 'c' (classad), appends the current classad dictionary to xml2list_data. + - If the XML element is 'a' (attribute), adds the attribute to the current classad dictionary. + - Resets xml2list_intype to 's' (string) if the XML element is one of 'i', 'b', 'un', or 'r'. + - Handles cases where the XML element name is 's' (string), 'e' (end), or 'classads' by passing silently. + - Raises a TypeError if an unexpected XML element type is encountered. + + """ global xml2list_data, xml2list_inclassad, xml2list_inattr, xml2list_intype - # The following would be resetting global variables and failing ./test_frontend.py - # xml2list_data, xml2list_inclassad, xml2list_inattr, xml2list_intype = {} + if name == "c": xml2list_data.append(xml2list_inclassad) xml2list_inclassad = None elif name == "a": - xml2list_inclassad[xml2list_inattr["name"]] = xml2list_inattr["val"] # pylint: disable=unsubscriptable-object + xml2list_inclassad[xml2list_inattr["name"]] = xml2list_inattr["val"] xml2list_inattr = None elif name in ("i", "b", "un", "r"): xml2list_intype = "s" - elif name in ("s", "e"): - pass # nothing to do - elif name == "classads": - pass # top element, nothing to do + elif name in ("s", "e", "classads"): + pass # Nothing to do for these elements else: raise TypeError("Unexpected type: %s" % name) def xml2list_char_data(data): + """ + XML handler function called when receiving character data within an XML element. + + Args: + data (str): The character data received. + + Global Variables: + xml2list_data (list): A list containing parsed classad dictionaries. + xml2list_inclassad (dict): The current classad dictionary being parsed. + xml2list_inattr (dict): The current attribute dictionary within the classad being parsed. + xml2list_intype (str): The data type of the current attribute value ('i' for integer, 'r' for float, + 'b' for boolean, 'un' for unknown type). + + Notes: + - This function updates the xml2list_inattr["val"] with parsed data based on xml2list_intype. + - If xml2list_intype is 'b' and xml2list_inattr["val"] is None, it interprets the data as boolean. + - Handles unescaped double quotes in the data by replacing '\\"' with '"'. 
+ + """ global xml2list_data, xml2list_inclassad, xml2list_inattr, xml2list_intype + if xml2list_inattr is None: - # only process when in attribute + # Only process when inside an attribute return if xml2list_intype == "i": @@ -943,38 +1177,67 @@ def xml2list_char_data(data): xml2list_inattr["val"] = float(data) elif xml2list_intype == "b": if xml2list_inattr["val"] is not None: - # nothing to do, value was in attribute + # Value was already in attribute, nothing to do pass else: + # Interpret the first character of data as boolean value xml2list_inattr["val"] = data[0] in ("T", "t", "1") elif xml2list_intype == "un": - # nothing to do, value was in attribute + # Value was already in attribute, nothing to do pass else: + # Append unescaped data to the current attribute value unescaped_data = data.replace('\\"', '"') xml2list_inattr["val"] += unescaped_data def xml2list(xml_data): + """ + Parse XML data representing Condor classads and convert it into a list of dictionaries. + + This function parses the XML data using the Expat parser and extracts classads and their attributes. + + Args: + xml_data (list of str): The XML data representing Condor classads. + + Returns: + list of dict: A list containing dictionaries, where each dictionary represents a classad. + + Global Variables: + xml2list_data (list): A list containing parsed classad dictionaries. + xml2list_inclassad (dict): The current classad dictionary being parsed. + xml2list_inattr (dict): The current attribute dictionary within the classad being parsed. + xml2list_intype (str): The data type of the current attribute value ('i' for integer, 'r' for float, + 'b' for boolean, 'un' for unknown type). + + Raises: + RuntimeError: If there's an error parsing the XML data. + + """ global xml2list_data, xml2list_inclassad, xml2list_inattr, xml2list_intype + # Initialize global variables xml2list_data = [] xml2list_inclassad = None xml2list_inattr = None xml2list_intype = None + # Create an Expat parser p = xml.parsers.expat.ParserCreate() + + # Set XML handler functions p.StartElementHandler = xml2list_start_element p.EndElementHandler = xml2list_end_element p.CharacterDataHandler = xml2list_char_data + # Find the position of the XML header found_xml = -1 for line in range(len(xml_data)): - # look for the xml header if xml_data[line][:5] == "= 0: try: p.Parse(" ".join(xml_data[found_xml:]), 1) @@ -982,7 +1245,6 @@ def xml2list(xml_data): raise RuntimeError("Failed to parse XML data, TypeError: %s" % e) from e except Exception as e: raise RuntimeError("Failed to parse XML data, generic error") from e - # else no xml, so return an empty list return xml2list_data @@ -1232,31 +1494,57 @@ def fetch2count_flat(data, hash_func): # Elements are lists of keys (or more dictionaries if hash returns lists) # def fetch2list(data, hash_func): + """ + Convert data into a nested dictionary structure based on hash values. + + This function takes a dictionary of data and a hash function, and creates a nested dictionary structure + based on the hash values returned by the hash function. It uses the hash values to organize the data into + lists or dictionaries within the nested structure. + + Args: + data (dict): The input data to be converted into a nested dictionary. + hash_func (function): The hash function used to generate hash values from data elements. + + Returns: + dict: A nested dictionary structure containing the converted data. 
+ + """ return_list = {} + + # Iterate through each key in the data dictionary for k in list(data.keys()): - el = data[k] + el = data[k] # Get the data element associated with the current key + # Calculate the hash value using the provided hash function hid = hash_func(el) + + # Skip this element if the hash function returns None if hid is None: - # hash tells us it does not want to list this continue - # lel will point to the real list + # Initialize a pointer to the current level of the return dictionary lel = return_list - # check if it is a list + # Check if the hash value is a list if isinstance(hid, list): - # have to create structure inside list + # Traverse the nested dictionary structure based on the hash values for h in hid[:-1]: if h not in lel: lel[h] = {} lel = lel[h] + + # Use the last hash value to access the final level of the nested dictionary hid = hid[-1] + # Check if the hash value already exists in the current level of the return dictionary if hid in lel: + # If the hash value already exists, append the current key to the corresponding list list_el = lel[hid].append[k] else: + # If the hash value does not exist, create a new list with the current key list_el = [k] + + # Update the nested dictionary with the new list lel[hid] = list_el return return_list @@ -1334,14 +1622,52 @@ def bindings_friendly_attrs(format_list): class SummarizeMulti: + """ + Class to summarize multiple queries. + + This class aggregates the results of multiple queries into a single summary. It provides methods to count + occurrences based on specified constraints and hash functions. + + Args: + queries (list): A list of query objects to be summarized. + hash_func (function, optional): The hash function used for summarization. Defaults to a function that + always returns 1. + + Attributes: + counts (list): A list containing the results of individual queries. + + """ + def __init__(self, queries, hash_func=lambda x: 1): + """ + Initializes the SummarizeMulti object. + + Args: + queries (list): A list of query objects to be summarized. + hash_func (function, optional): The hash function used for summarization. Defaults to a function + that always returns 1. + + """ self.counts = [] for query in queries: self.counts.append(self.count(query, hash_func)) self.hash_func = hash_func - # see Count for description def count(self, constraint=None, hash_func=None): + """ + Count occurrences based on specified constraints and hash functions. + + This method counts occurrences based on the specified constraints and hash functions for all queries + stored in the SummarizeMulti object. + + Args: + constraint (str, optional): A string representing the query constraint. Defaults to None. + hash_func (function, optional): The hash function used for counting. Defaults to None. + + Returns: + dict: A dictionary containing the count of occurrences. + + """ out = {} for c in self.counts: @@ -1350,8 +1676,21 @@ def count(self, constraint=None, hash_func=None): return out - # see Count for description def countStored(self, constraint_func=None, hash_func=None): + """ + Count occurrences using stored data and specified constraints and hash functions. + + This method counts occurrences using the stored data and specified constraints and hash functions + for all queries stored in the SummarizeMulti object. + + Args: + constraint_func (function, optional): A function representing the constraint. Defaults to None. + hash_func (function, optional): The hash function used for counting. Defaults to None. 
+ + Returns: + dict: A dictionary containing the count of occurrences. + + """ out = {} for c in self.counts: @@ -1363,7 +1702,39 @@ def countStored(self, constraint_func=None, hash_func=None): # condor_q, where we have only one ProcId x ClusterId class CondorQLite(CondorQuery): + """ + Class for querying a Condor pool with simplified functionality. + + This class extends the functionality of the CondorQuery class to provide simplified querying of a Condor pool. + It is designed to handle basic query operations with a focus on ease of use and reduced complexity. + + Args: + schedd_name (str, optional): The name of the schedd. Defaults to None. + pool_name (str, optional): The name of the pool. Defaults to None. + security_obj (object, optional): The security object for authentication. Defaults to None. + schedd_lookup_cache (object, optional): The cache object for storing schedd lookup data. Defaults to + local_schedd_cache. + + Attributes: + schedd_name (str): The name of the schedd. + pool_name (str): The name of the pool. + security_obj (object): The security object for authentication. + schedd_lookup_cache (object): The cache object for storing schedd lookup data. + + """ + def __init__(self, schedd_name=None, pool_name=None, security_obj=None, schedd_lookup_cache=local_schedd_cache): + """ + Initializes the CondorQLite object. + + Args: + schedd_name (str, optional): The name of the schedd. Defaults to None. + pool_name (str, optional): The name of the pool. Defaults to None. + security_obj (object, optional): The security object for authentication. Defaults to None. + schedd_lookup_cache (object, optional): The cache object for storing schedd lookup data. Defaults to + local_schedd_cache. + + """ self.schedd_name = schedd_name if schedd_lookup_cache is None: @@ -1374,6 +1745,20 @@ def __init__(self, schedd_name=None, pool_name=None, security_obj=None, schedd_l CondorQuery.__init__(self, "condor_q", schedd_str, "ClusterId", pool_name, security_obj, env) def fetch(self, constraint=None, format_list=None): + """ + Fetches data from the Condor pool. + + This method fetches data from the Condor pool based on the specified constraint and format list. + + Args: + constraint (str, optional): The constraint for filtering the query results. Defaults to None. + format_list (list, optional): The list of attributes and their types for formatting the query results. + Defaults to None. + + Returns: + dict: A dictionary containing the query results. + + """ if format_list is not None: # check that ClusterId is present, and if not add it format_list = complete_format_list(format_list, [("ClusterId", "i")]) diff --git a/lib/condorSecurity.py b/lib/condorSecurity.py index 2d799979f..87be1f73d 100644 --- a/lib/condorSecurity.py +++ b/lib/condorSecurity.py @@ -1,8 +1,7 @@ # SPDX-FileCopyrightText: 2009 Fermi Research Alliance, LLC # SPDX-License-Identifier: Apache-2.0 -"""This module implements classes that will setup the Condor security as needed -""" +"""This module implements classes that will set up the Condor security as needed.""" import copy import os @@ -20,15 +19,28 @@ # # All info is in the state attribute class EnvState: + """ + This class manages the state of the Condor environment. + + Attributes: + filter (list): List of Condor variables to save. + state (dict): The saved state of the environment variables. + """ + def __init__(self, filter): - # filter is a list of Condor variables to save + """ + Initializes the EnvState instance. 
+ + Args: + filter (list): List of Condor variables to save. + """ self.filter = filter self.load() - ################################# - # Restore back to what you found - # when creating this object def restore(self): + """ + Restores the environment variables to their original state. + """ for condor_key in list(self.state.keys()): env_key = "_CONDOR_%s" % condor_key old_val = self.state[condor_key] @@ -38,11 +50,12 @@ def restore(self): if os.environ.get(env_key): del os.environ[env_key] - ########################################## - # Load the environment state into - # Almost never called by the user - # It gets called automatically by __init__ def load(self): + """ + Loads the current environment state into the instance. + + This method is automatically called by the __init__ method. + """ filter = self.filter saved_state = {} for condor_key in filter: @@ -60,6 +73,15 @@ def load(self): def convert_sec_filter(sec_filter): + """ + Converts a security filter dictionary to a list of Condor keys. + + Args: + sec_filter (dict): Dictionary of security contexts and features. + + Returns: + list: List of Condor keys. + """ filter = [] for context in list(sec_filter.keys()): for feature in sec_filter[context]: @@ -69,8 +91,20 @@ def convert_sec_filter(sec_filter): class SecEnvState(EnvState): + """ + This class manages the state of the Condor security environment. + + Attributes: + sec_filter (dict): Dictionary of security contexts and features. + """ + def __init__(self, sec_filter): - # sec_filter is a dictionary of [contex]=[feature list] + """ + Initializes the SecEnvState instance. + + Args: + sec_filter (dict): Dictionary of security contexts and features. + """ EnvState.__init__(self, convert_sec_filter(sec_filter)) self.sec_filter = sec_filter @@ -85,9 +119,21 @@ def __init__(self, sec_filter): # This class handle requests for ensuring # the security state is in a particular state class SecEnvRequest: + """ + This class handles requests for setting the Condor security environment state. + + Attributes: + requests (dict): Dictionary of security requests. + saved_state (SecEnvState): The saved state of the environment variables. + """ + def __init__(self, requests=None): - # requests is a dictionary of requests [context][feature]=VAL - # TODO: requests can be a self initializinf dictionary of dictionaries in PY3 + """ + Initializes the SecEnvRequest instance. + + Args: + requests (dict, optional): Dictionary of security requests. Defaults to None. + """ self.requests = {} if requests is not None: for context in list(requests.keys()): @@ -96,9 +142,15 @@ def __init__(self, requests=None): self.saved_state = None - ############################################## - # Methods for accessing the requests - def set(self, context, feature, value): # if value is None, remove the request + def set(self, context, feature, value): + """ + Sets a security request. + + Args: + context (str): The security context. + feature (str): The security feature. + value (str): The value to set. If None, the request is removed. + """ if value is not None: if context not in self.requests: self.requests[context] = {} @@ -110,6 +162,16 @@ def set(self, context, feature, value): # if value is None, remove the request del self.requests[context] def get(self, context, feature): + """ + Gets a security request. + + Args: + context (str): The security context. + feature (str): The security feature. + + Returns: + str: The value of the request, or None if not found. 
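+
+        Example:
+            Illustrative sketch; the context and feature names are examples, not a prescription:
+
+                req = SecEnvRequest()
+                req.set("CLIENT", "AUTHENTICATION", "FS")
+                req.get("CLIENT", "AUTHENTICATION")  # -> "FS"
+                req.get("CLIENT", "ENCRYPTION")      # -> None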
+ """ if context in self.requests: if feature in self.requests[context]: return self.requests[context][feature] @@ -118,13 +180,22 @@ def get(self, context, feature): else: return None - ############################################## - # Methods for preserving the old state - def has_saved_state(self): + """ + Checks if there is a saved state. + + Returns: + bool: True if there is a saved state, False otherwise. + """ return self.saved_state is not None def save_state(self): + """ + Saves the current state of the environment variables. + + Raises: + RuntimeError: If there is already a saved state. + """ if self.has_saved_state(): raise RuntimeError("There is already a saved state! Restore that first.") filter = {} @@ -134,18 +205,19 @@ def save_state(self): self.saved_state = SecEnvState(filter) def restore_state(self): + """ + Restores the environment variables to the saved state. + """ if self.saved_state is None: return # nothing to do self.saved_state.restore() self.saved_state = None - ############################################## - # Methods for changing to the desired state - - # you should call save_state before this one, - # if you want to ever get back def enforce_requests(self): + """ + Enforces the security requests by setting the environment variables. + """ for context in list(self.requests.keys()): for feature in list(self.requests[context].keys()): condor_key = f"SEC_{context}_{feature}" @@ -189,12 +261,25 @@ def enforce_requests(self): ######################################## class EnvProtoState(SecEnvState): + """ + This class manages the state of the Condor protocol security environment. + + Attributes: + filter (dict): Dictionary of contexts and features to filter. + """ + def __init__(self, filter=None): + """ + Initializes the EnvProtoState instance. + + Args: + filter (dict, optional): Dictionary of contexts and features to filter. Defaults to None. + """ if filter is not None: # validate filter for c in list(filter.keys()): if c not in CONDOR_CONTEXT_LIST: - raise ValueError(f"Invalid contex '{c}'. Must be one of {CONDOR_CONTEXT_LIST}") + raise ValueError(f"Invalid context '{c}'. Must be one of {CONDOR_CONTEXT_LIST}") for f in filter[c]: if f not in CONDOR_PROTO_FEATURE_LIST: raise ValueError(f"Invalid feature '{f}'. Must be one of {CONDOR_PROTO_FEATURE_LIST}") @@ -215,7 +300,26 @@ def __init__(self, filter=None): # the context and feature are related # to the Condor protocol handling class ProtoRequest(SecEnvRequest): - def set(self, context, feature, value): # if value is None, remove the request + """ + This class handles requests for setting the Condor protocol security environment state. + + Methods: + set: Sets a security request. + get: Gets a security request. + """ + + def set(self, context, feature, value): + """ + Sets a security request. + + Args: + context (str): The security context. + feature (str): The security feature. + value (str): The value to set. If None, the request is removed. + + Raises: + ValueError: If the context, feature, or value is invalid. + """ if context not in CONDOR_CONTEXT_LIST: raise ValueError("Invalid security context '%s'." % context) if feature not in CONDOR_PROTO_FEATURE_LIST: @@ -225,6 +329,19 @@ def set(self, context, feature, value): # if value is None, remove the request SecEnvRequest.set(self, context, feature, value) def get(self, context, feature): + """ + Gets a security request. + + Args: + context (str): The security context. + feature (str): The security feature. 
+ + Returns: + str: The value of the request, or None if not found. + + Raises: + ValueError: If the context or feature is invalid. + """ if context not in CONDOR_CONTEXT_LIST: raise ValueError("Invalid security context '%s'." % context) if feature not in CONDOR_PROTO_FEATURE_LIST: @@ -242,7 +359,29 @@ def get(self, context, feature): class GSIRequest(ProtoRequest): + """ + This class handles requests for setting the Condor GSI security environment state. + + Attributes: + x509_proxy (str): The X.509 proxy. + allow_fs (bool): If True, allows FS authentication. Defaults to True. + allow_idtokens (bool): If True, allows IDTOKENS authentication. Defaults to True. + x509_proxy_saved_state (str): The saved state of the X.509 proxy environment variable. + """ + def __init__(self, x509_proxy=None, allow_fs=True, allow_idtokens=True, proto_requests=None): + """ + Initializes the GSIRequest instance. + + Args: + x509_proxy (str, optional): The X.509 proxy. Defaults to None. + allow_fs (bool, optional): If True, allows FS authentication. Defaults to True. + allow_idtokens (bool, optional): If True, allows IDTOKENS authentication. Defaults to True. + proto_requests (dict, optional): Dictionary of protocol requests. Defaults to None. + + Raises: + ValueError: If neither IDTOKENS nor GSI is specified in the authentication options. + """ if allow_idtokens: auth_str = "IDTOKENS,GSI" else: @@ -271,8 +410,6 @@ def __init__(self, x509_proxy=None, allow_fs=True, allow_idtokens=True, proto_re self.x509_proxy_saved_state = None if x509_proxy is None: - # if 'X509_USER_PROXY' not in os.environ: - # raise RuntimeError("x509_proxy not provided and env(X509_USER_PROXY) undefined") x509_proxy = os.environ.get("X509_USER_PROXY") # Here I should probably check if the proxy is valid @@ -280,8 +417,13 @@ def __init__(self, x509_proxy=None, allow_fs=True, allow_idtokens=True, proto_re self.x509_proxy = x509_proxy - ############################################## def save_state(self): + """ + Saves the current state of the environment variables. + + Raises: + RuntimeError: If there is already a saved state. + """ if self.has_saved_state(): raise RuntimeError("There is already a saved state! Restore that first.") @@ -292,6 +434,9 @@ def save_state(self): ProtoRequest.save_state(self) def restore_state(self): + """ + Restores the environment variables to the saved state. + """ if self.saved_state is None: return # nothing to do @@ -305,8 +450,10 @@ def restore_state(self): # unset, just to prevent bugs self.x509_proxy_saved_state = None - ############################################## def enforce_requests(self): + """ + Enforces the security requests by setting the environment variables. + """ ProtoRequest.enforce_requests(self) if self.x509_proxy: os.environ["X509_USER_PROXY"] = self.x509_proxy diff --git a/lib/config_util.py b/lib/config_util.py index 71b8f92fc..ec7312904 100644 --- a/lib/config_util.py +++ b/lib/config_util.py @@ -1,11 +1,10 @@ # SPDX-FileCopyrightText: 2009 Fermi Research Alliance, LLC # SPDX-License-Identifier: Apache-2.0 -"""This module contains a list of shared utility function used by the both OSG collector and CRIC -configuration generation helper tools +"""This module contains a list of shared utility functions used by both OSG collector and CRIC +configuration generation helper tools. 
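+
+Example:
+    A minimal round-trip sketch; the file name and dictionary content are illustrative only:
+
+        write_to_yaml_file("entries.yml", {"entry_A": {"gridtype": "condor"}})
+        info = get_yaml_file_info("entries.yml")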
""" - import collections import os @@ -100,9 +99,15 @@ } -# Class to handle error in the merge script class ProgramError(Exception): - """Simple collection of program error codes and related short messages""" + """Simple collection of program error codes and related short messages. + + Args: + code (int): Error code. + + Attributes: + code (int): Exception error code. + """ codes_map = { 1: "File not found", @@ -118,15 +123,16 @@ def __init__(self, code): def get_yaml_file_info(file_name): - """Loads a yaml file into a dictionary + """Loads a yaml file into a dictionary. Args: - file_name (str): The file to load + file_name (str): The file to load. Returns: + dict: The loaded yaml file content. Raises: - ProgramError + ProgramError: If the file is not found. """ if not os.path.isfile(file_name): print("Cannot find file %s" % file_name) @@ -138,11 +144,11 @@ def get_yaml_file_info(file_name): def write_to_yaml_file(file_name, information): - """Auxiliary function used to write a python dictionary into a yaml file + """Auxiliary function used to write a python dictionary into a yaml file. Args: - file_name (string): The yaml filename that will be written out - information (dict): + file_name (str): The yaml filename that will be written out. + information (dict): The dictionary to write. """ with open(file_name, "w") as outfile: noalias_dumper = yaml.dumper.SafeDumper @@ -151,13 +157,13 @@ def write_to_yaml_file(file_name, information): def get_attr_str(attrs): - """Convert attributes from a dictionary form to the corresponding configuration string + """Convert attributes from a dictionary form to the corresponding configuration string. Args: - attrs (dict): the dictionary containing the attributes + attrs (dict): The dictionary containing the attributes. Returns: - string: the string representing the xml attributes section for a single entry + str: The string representing the xml attributes section for a single entry. """ out = "" for name, data in sorted(attrs.items()): @@ -170,7 +176,6 @@ def get_attr_str(attrs): else: data["comment"] = ' comment="' + data["comment"] + '"' if "value" in data: - # pylint: disable=line-too-long out += ( ' \n' % data @@ -179,15 +184,14 @@ def get_attr_str(attrs): return out[:-1] -# Collect all submit attributes def get_submit_attr_str(submit_attrs): - """Convert submit attributes from a dictionary form to the corresponding configuration string + """Convert submit attributes from a dictionary form to the corresponding configuration string. Args: - submit_attrs (dict): the dictionary containing the submit attributes + submit_attrs (dict): The dictionary containing the submit attributes. Returns: - string: the string representing the xml submit attributes section for a single entry + str: The string representing the xml submit attributes section for a single entry. """ out = "" if submit_attrs: @@ -198,15 +202,14 @@ def get_submit_attr_str(submit_attrs): return out -# Collect all pilots limits def get_limits_str(limits): - """Convert pilots limits from a dictionary form to the corresponding configuration string + """Convert pilots limits from a dictionary form to the corresponding configuration string. Args: - limits (dict): the dictionary containing the pilots limits + limits (dict): The dictionary containing the pilots limits. Returns: - string: the string representing the xml pilots limits section for a single entry + str: The string representing the xml pilots limits section for a single entry. 
""" out = "" if limits is not None: @@ -232,15 +235,14 @@ def get_limits_str(limits): return out -# Collect submission speed def get_submission_speed(submission_speed): - """Convert submission speed from a name to the corresponding configuration string + """Convert submission speed from a name to the corresponding configuration string. Args: - submission_speed (string): the string containing the submission speed name + submission_speed (str): The string containing the submission speed name. Returns: - string: the string representing the xml submission speed section for a single entry + str: The string representing the xml submission speed section for a single entry. """ out = "" if submission_speed: @@ -262,12 +264,15 @@ def get_submission_speed(submission_speed): def update(data, update_data, overwrite=True): - """Recursively update the information contained in a dictionary + """Recursively update the information contained in a dictionary. Args: - data (dict): The starting dictionary - update_data (dict): The dictionary that contains the new data - overwrite (bool): wether existing keys are going to be overwritten + data (dict): The starting dictionary. + update_data (dict): The dictionary that contains the new data. + overwrite (bool): Whether existing keys are going to be overwritten. + + Returns: + dict: The updated dictionary. """ for key, value in list(update_data.items()): if value is None: @@ -285,11 +290,11 @@ def update(data, update_data, overwrite=True): def write_to_xml_file(file_name, information): - """Writes out on the disk entries xml adding the necessary top level tags + """Writes out on the disk entries xml adding the necessary top level tags. Args: - file_name (str): the filename where you want to write to. - information (str): a string containing the xml for all the entries + file_name (str): The filename where you want to write to. + information (str): A string containing the xml for all the entries. """ with open(file_name, "w") as outfile: outfile.write("\n") @@ -301,13 +306,12 @@ def write_to_xml_file(file_name, information): outfile.write("\n") -# Write collected information to file def write_to_file(file_name, information): - """Take a dictionary and writes it out to disk as a yaml file + """Take a dictionary and writes it out to disk as a yaml file. Args: - file_name (str): the filename to write to disk - information (dict): the dictionary to write out as yaml file + file_name (str): The filename to write to disk. + information (dict): The dictionary to write out as yaml file. """ with open(file_name, "w") as outfile: yaml.safe_dump(information, outfile, default_flow_style=False) diff --git a/lib/defaults.py b/lib/defaults.py index 9a3d4bcbb..cffd985d1 100644 --- a/lib/defaults.py +++ b/lib/defaults.py @@ -1,78 +1,68 @@ # SPDX-FileCopyrightText: 2009 Fermi Research Alliance, LLC # SPDX-License-Identifier: Apache-2.0 -"""Collections of constants that are used throughout the GlideinWMS project +"""Collections of constants that are used throughout the GlideinWMS project. 
""" # GlideinWMS has to be compatible across versions running on different Python interpreters # Python 2 text files are the same as binary files except some newline handling # and strings are the same as bytes -# To maintain this in Python 3 is possible to write binary files and use for the strings +# To maintain this in Python 3 it is possible to write binary files and use for the strings # any encoding that preserves the bytes (0x80...0xff) through round-tripping from byte -# streams to Unicode and back, latin-1 is the best known of these (more compact). -# TODO: alt evaluate the use of latin-1 text files -BINARY_ENCODING = "latin_1" # valid aliases (case insensitive) latin-1, latin1, L1, iso-8859-1, 8859 +# streams to Unicode and back. latin-1 is the best known of these (more compact). +# TODO: Evaluate the use of latin-1 text files as an alternative. +BINARY_ENCODING = "latin_1" # valid aliases (case insensitive): latin-1, latin1, L1, iso-8859-1, 8859 -# All strings should be ASCII, so ASCII or latin-1 (256 safe) should be OK -# Anyway M2Crypto uses 'utf_8' to implement AnyStr (union of bytes and str) +# All strings should be ASCII, so ASCII or latin-1 (256 safe) should be OK. +# Anyway M2Crypto uses 'utf_8' to implement AnyStr (union of bytes and str). BINARY_ENCODING_CRYPTO = "utf_8" # valid aliases: utf-8, utf8 BINARY_ENCODING_ASCII = "ascii" # valid aliases: 646, us-ascii -BINARY_ENCODING_DEFAULT = "utf_8" # valid aliases: utf-8, utf8 Default Python 3 encoding +BINARY_ENCODING_DEFAULT = "utf_8" # valid aliases: utf-8, utf8 (default Python 3 encoding) def force_bytes(instr, encoding=BINARY_ENCODING_CRYPTO): - """Forces the output to be bytes, encoding the input if it is a unicode string (str) - - AnyStr is str or bytes types + """Forces the output to be bytes, encoding the input if it is a unicode string (str). Args: - instr (AnyStr): string to be converted - encoding (str): a valid encoding, utf_8, ascii, latin-1 (iso-8859-1') + instr (Union[str, bytes]): String to be converted. + encoding (str): A valid encoding, such as utf_8, ascii, latin-1 (iso-8859-1). Returns: - bytes: instr as bytes string + bytes: The input as a bytes string. Raises: - ValueError: if it detects an improper str conversion (b'' around the string) + ValueError: If it detects an improper str conversion (b'' around the string). """ if isinstance(instr, str): - # raise Exception("ALREADY str!") # Use this for investigations if instr.startswith("b'") and len(instr) > 2 and instr.endswith("'"): - # This may cause errors with the random strings generated for unit tests, which may start with "b'" raise ValueError( "Input was improperly converted into string (resulting in b'' characters added): %s" % instr ) - # If the encoding is known codecs can be used for more efficiency, e.g. codecs.latin_1_encode(x)[0] return instr.encode(encoding) return instr def force_str(inbytes, encoding=BINARY_ENCODING_CRYPTO): - """Forces the output to be str, decoding the input if it is a bytestring (bytes) - - AnyStr is str or bytes types + """Forces the output to be str, decoding the input if it is a bytestring (bytes). Args: - inbytes (AnyStr): string to be converted - encoding (str): a valid encoding, utf8, ascii, latin-1 + inbytes (Union[str, bytes]): String to be converted. + encoding (str): A valid encoding, such as utf_8, ascii, latin-1. Returns: - str: instr as unicode string + str: The input as a unicode string. 
Raises: - ValueError: if it detects an improper str conversion (b'' around the string) or - the input is neither string or bytes + ValueError: If it detects an improper str conversion (b'' around the string) or + the input is neither string nor bytes. """ if isinstance(inbytes, str): - # raise Exception("ALREADY str!") if inbytes.startswith("b'"): raise ValueError( "Input was improperly converted into string (resulting in b'' characters added): %s" % inbytes ) return inbytes - # if isinstance(inbytes, (bytes, bytearray)): try: return inbytes.decode(encoding) except AttributeError: - # This is not bytes, bytearray (and was not str) raise ValueError(f"Input is not str or bytes: {type(inbytes)} ({inbytes})") diff --git a/lib/disk_cache.py b/lib/disk_cache.py index c5fe90380..78d57da16 100644 --- a/lib/disk_cache.py +++ b/lib/disk_cache.py @@ -1,11 +1,11 @@ # SPDX-FileCopyrightText: 2009 Fermi Research Alliance, LLC # SPDX-License-Identifier: Apache-2.0 -"""This module helps managing objects that needs to be cached. Objects are cached both in memory +"""This module helps manage objects that need to be cached. Objects are cached both in memory and on disk so that the cache can be leveraged by multiple processes. -Each object you want to save needs to have a string id that identifies it. The id is used to -locate the object in the memory (a key in a dictionary), and on the disk (the filename). +Each object you want to save needs to have a string ID that identifies it. The ID is used to +locate the object in memory (a key in a dictionary), and on the disk (the filename). """ import contextlib @@ -18,14 +18,13 @@ @contextlib.contextmanager def get_lock(name): - """Create a "name".lock file and, using fcnt, - lock it (or wait for the lock to be released) before proceeding + """Create a `name.lock` file and, using fcntl, lock it (or wait for the lock to be released) before proceeding. - N.B. The "name".lock file is not removed after the lock is released, it is kept to be reused: + The `name.lock` file is not removed after the lock is released; it is kept to be reused: we only care about the lock status. - Params: - name (str): the name of the file you want to lock. A lockfile name.lock will be created + Args: + name (str): The name of the file you want to lock. A lockfile `name.lock` will be created. """ with open(name + ".lock", "a+") as fdesc: fcntl.flock(fdesc, fcntl.LOCK_EX) @@ -33,39 +32,46 @@ def get_lock(name): class DiskCache: - """The class that manages the cache. Objects expires after a cache_duration time (defaults - to one hour). Objects are pickled into a file. The directory to save those files has to - be specified. Methods to save and load an object by its id are provided. + """Manages the cache. Objects expire after a `cache_duration` time (defaults to one hour). + Objects are pickled into a file. The directory to save those files has to be specified. + Methods to save and load an object by its ID are provided. """ def __init__(self, cache_dir, cache_duration=3600): - """Build the DiskCache object + """Initializes the DiskCache object. Args: - cache_dir (str): the location where the pickled objects are saved - cache_duration (int): defaults 3600, the number of seconds objects are kept before - you get a miss + cache_dir (str): The location where the pickled objects are saved. + cache_duration (int): Defaults to 3600, the number of seconds objects are kept before + you get a miss. 
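+
+        Example:
+            Minimal usage sketch; the directory and object ID are illustrative, and the directory
+            must already exist:
+
+                cache = DiskCache("/tmp/my_cache", cache_duration=600)
+                cache.save("schedd_status", {"jobs": 42})
+                status = cache.get("schedd_status")  # the saved dict, or None once expired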
""" self.cache_dir = cache_dir self.mem_cache = {} self.cache_duration = cache_duration def get_fname(self, objid): - """Simple auxiliary function that returns the cache filename given a cache object id""" + """Returns the cache filename given a cache object ID. + + Args: + objid (str): The cache object ID. + + Returns: + str: The cache filename. + """ return os.path.join(self.cache_dir, objid) def get(self, objid): - """Returns the cached object given its object id ``objid``. Returns None if + """Returns the cached object given its object ID `objid`. Returns None if the object is not in the cache, or if it has expired. - First we check if the object is in the memory dictionary, otherwise we look - for its corresponding cache file, and we loads it from there. + First, we check if the object is in the memory dictionary; otherwise, we look + for its corresponding cache file, and load it from there. Args: - objid (str): the string representing the object id you want to get + objid (str): The string representing the object ID you want to get. Returns: - The cached object, or None if the object does not exist or the cache is expired + object: The cached object, or None if the object does not exist or the cache has expired. """ obj = None saved_time = 0 @@ -86,13 +92,12 @@ def save(self, objid, obj): """Save an object into the cache. Objects are saved both in memory and into the corresponding cache file (one file - for each object id). - Objects are saved paired with the timestamp representing the time when it - has been saved. + for each object ID). Objects are saved paired with the timestamp representing the time + when they have been saved. Args: - objid (str): The id of the object you are saving - obj: the python object that you want to save + objid (str): The ID of the object you are saving. + obj (object): The Python object that you want to save. """ fname = self.get_fname(objid) with get_lock(fname): diff --git a/lib/exprParser.py b/lib/exprParser.py index 9abb6cd42..b58a5899e 100644 --- a/lib/exprParser.py +++ b/lib/exprParser.py @@ -2,62 +2,51 @@ # SPDX-License-Identifier: Apache-2.0 """ -Description: general purpose python expression parser and unparser +Description: General purpose Python expression parser and unparser. """ import ast import itertools -# These are used in modules importing exprParser, like frontend_match_ana from ast import And, Not, Or # noqa: F401 from io import StringIO from .unparser import Unparser -# Keeping this line from the Python 2 version to have a list of the objects supported -# NOTE: compiler.ast is slightly different from the concrete tree in ast -# from compiler.ast import Name, Const, Keyword, List, Tuple, And, Or, Not, UnaryAdd, UnarySub, Compare, Add, Sub, Mul, FloorDiv, Div, Mod, Power, LeftShift, RightShift, Bitand, Bitor, Bitxor, CallFunc, Getattr, Subscript, Slice, Lambda - def exp_parse(expression): - """Convert an expression string into an ast object + """Convert an expression string into an AST object. Args: - expression (str): expression string + expression (str): The expression string. Returns: - ast.AST: ast tree from the expression, starting from ast.Expression node - + ast.AST: AST tree from the expression, starting from ast.Expression node. """ - # mode='exec' (default) for sequence of statements - # eval - single expression - # single - single interactive statement return ast.parse(expression, "", mode="eval") def exp_compile(obj): - """Convert an ast object into a code object + """Convert an AST object into a code object. 
Args: - obj (ast.AST): AST object to compile + obj (ast.AST): AST object to compile. Returns: - code object - + code: Compiled code object. """ return compile(obj, "", mode="eval") def exp_unparse(obj, raise_on_unknown=False): - """Convert an ast object back into a string + """Convert an AST object back into a string. Args: - obj (ast.AST): ast object to convert back to string - raise_on_unknown (bool): + obj (ast.AST): AST object to convert back to string. + raise_on_unknown (bool): Flag to raise an error on unknown nodes. Returns: - str: string with the expression - + str: String with the expression. """ with StringIO() as output: Unparser(obj, output) @@ -66,15 +55,14 @@ def exp_unparse(obj, raise_on_unknown=False): def exp_compare(node1, node2): - """Compare 2 AST trees to verify if they are the same + """Compare two AST trees to verify if they are the same. Args: - node1 (ast.AST): AST tree - node2 (ast.AST): AST tree + node1 (ast.AST): First AST tree. + node2 (ast.AST): Second AST tree. Returns: - bool: True if node1 and node2 are the same expression - + bool: True if node1 and node2 represent the same expression, False otherwise. """ if type(node1) is not type(node2): return False diff --git a/lib/fork.py b/lib/fork.py index 9ee907630..df0c63afd 100644 --- a/lib/fork.py +++ b/lib/fork.py @@ -1,11 +1,11 @@ # SPDX-FileCopyrightText: 2009 Fermi Research Alliance, LLC # SPDX-License-Identifier: Apache-2.0 -"""This module implements functions and classes to handle forking of processes and the collection of results +"""This module implements functions and classes to handle forking of processes and the collection of results. """ # TODO: This could be rewritten so that the polling lists are registered once and the fd are removed only when -# not needed anymore (currently there is an extrnal structure and the poll object is a new one each time) +# not needed anymore (currently there is an external structure and the poll object is a new one each time) import errno import os @@ -20,7 +20,7 @@ class ForkError(RuntimeError): - """Base class for this module's errors""" + """Base class for this module's errors.""" def __init__(self, msg): RuntimeError.__init__(self, msg) @@ -31,34 +31,37 @@ class FetchError(ForkError): class ForkResultError(ForkError): + """Raised when there are errors in the forked processes. + + Attributes: + nr_errors (int): Number of errors. + good_results (dict): Results of successful forks. + failed (list): List of failed forks. + """ + def __init__(self, nr_errors, good_results, failed=[]): - ForkError.__init__(self, "Found %i errors" % nr_errors) + super().__init__(f"Found {nr_errors} errors") self.nr_errors = nr_errors self.good_results = good_results self.failed = failed -################################################ -# Low level fork and collect functions - - def fork_in_bg(function_torun, *args): - """Fork and call a function with args + """Forks and calls a function with args. This function returns right away, returning the pid and a pipe to the stdout of the function process - where the output of the function will be pickled + where the output of the function will be pickled. - example: - def add(i, j): return i+j - d = fork_in_bg(add, i, j) + Example: + def add(i, j): return i + j + d = fork_in_bg(add, i, j) Args: - function_torun (function): function to call after forking the process - *args: arguments list to pass to the function + function_torun (function): Function to call after forking the process. + *args: Arguments list to pass to the function. 
Returns: - dict: dict with {'r': fd, 'pid': pid} where fd is the stdout from a pipe. - + dict: Dict with {'r': fd, 'pid': pid} where fd is the stdout from a pipe. """ r, w = os.pipe() unregister_sighandler() @@ -70,12 +73,10 @@ def add(i, j): return i+j out = function_torun(*args) os.write(w, pickle.dumps(out)) except Exception: - logSupport.log.warning("Forked process '%s' failed" % str(function_torun)) - logSupport.log.exception("Forked process '%s' failed" % str(function_torun)) + logSupport.log.warning(f"Forked process '{function_torun}' failed") + logSupport.log.exception(f"Forked process '{function_torun}' failed") finally: os.close(w) - # Exit, immediately. Don't want any cleanup, since I was created - # just for performing the work os._exit(0) else: register_sighandler() @@ -84,24 +85,18 @@ def add(i, j): return i+j return {"r": r, "pid": pid} -############################### def fetch_fork_result(r, pid): - """Used with fork clients to retrieve results - Can raise: - OSError if Bad file descriptor or file already closed or if waitpid syscall returns -1 - FetchError if a os.read error was encountered - Possible errors from os.read and pickle.load (catched here): - - EOFError if the forked process failed an nothing was written to the pipe, if cPickle finds an empty string - - IOError failure for an I/O-related reason, e.g., "pipe file not found" or "disk full" - - OSError other system-related error (includes both former OSError and IOError since Py3.4) - - pickle.UnpicklingError incomplete pickled data + """Used with fork clients to retrieve results. Args: - r (pipe): Input pipe - pid (int): pid of the child + r (int): Input pipe. + pid (int): PID of the child. Returns: - Object: Unpickled object + object: Unpickled object. + + Raises: + FetchError: If an os.read error was encountered or if the forked process failed. """ rin = b"" out = None @@ -110,16 +105,11 @@ def fetch_fork_result(r, pid): while s != b"": # "" means EOF rin += s s = os.read(r, 1024 * 1024) - # pickle can fail w/ EOFError if rin is empty. Any output from pickle is never an empty string, e.g. None is 'N.' out = pickle.loads(rin) except (OSError, EOFError, pickle.UnpicklingError) as err: - etype, evalue, etraceback = sys.exc_info() - # Adding message in case close/waitpid fail and preempt raise - logSupport.log.exception("Re-raising exception during read: %s" % err) - # Removed .with_traceback(etraceback) since already in the chaining + logSupport.log.exception(f"Re-raising exception during read: {err}") raise FetchError( - "Exception during read probably due to worker failure, original exception and trace %s: %s" - % (etype, evalue) + f"Exception during read probably due to worker failure, original exception and trace: {err}" ) from err finally: os.close(r) @@ -128,28 +118,27 @@ def fetch_fork_result(r, pid): def fetch_fork_result_list(pipe_ids): - """Read the output pipe of the children, used after forking to perform work - and after forking to entry.writeStats() + """Read the output pipe of the children, used after forking to perform work. Args: - pipe_ids (dict): Dictionary of pipe and pid + pipe_ids (dict): Dictionary of pipe and pid. Returns: - dict: Dictionary of fork_results + dict: Dictionary of fork results. + + Raises: + ForkResultError: If there are failures in fetching fork results. 
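+
+    Example:
+        Illustrative sketch; the worker function and keys are hypothetical:
+
+            def add(i, j):
+                return i + j
+
+            pipe_ids = {"first": fork_in_bg(add, 1, 2), "second": fork_in_bg(add, 3, 4)}
+            results = fetch_fork_result_list(pipe_ids)  # e.g. {"first": 3, "second": 7}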
""" out = {} failures = 0 failed = [] for key in pipe_ids: try: - # Collect the results out[key] = fetch_fork_result(pipe_ids[key]["r"], pipe_ids[key]["pid"]) except (KeyError, OSError, FetchError) as err: - # fetch_fork_result can raise OSError and FetchError - errmsg = f"Failed to extract info from child '{str(key)}' {err}" + errmsg = f"Failed to extract info from child '{key}': {err}" logSupport.log.warning(errmsg) logSupport.log.exception(errmsg) - # Record failed keys failed.append(key) failures += 1 @@ -160,35 +149,18 @@ def fetch_fork_result_list(pipe_ids): def fetch_ready_fork_result_list(pipe_ids): - """ - Read the output pipe of the children, used after forking. If there is data + """Read the output pipe of the children, used after forking. If there is data on the pipes to consume, read the data and close the pipe. - and after forking to entry.writeStats() Args: - pipe_ids (dict): Dictionary of pipe and pid + pipe_ids (dict): Dictionary of pipe and pid. Returns: - dict: Dictionary of work_done - """ - - # Timeout for epoll/poll in milliseconds: -1 is blocking, 0 non blocking, >0 timeout - # Select timeout (in seconds) = POLL_TIMEOUT/1000.0 - # Waiting at most POLL_TIMEOUT for one event to be triggered. - # If there are no ready fd by the timeout and empty list is returned (no exception triggered) - # - # From the linux kernel (v4.10) and man (http://man7.org/linux/man-pages/man2/select.2.html) - # these are the 3 sets of poll events that correspond to select read, write, error: - # define POLLIN_SET (POLLRDNORM | POLLRDBAND | POLLIN | POLLHUP | POLLERR) - # define POLLOUT_SET (POLLWRBAND | POLLWRNORM | POLLOUT | POLLERR) - # define POLLEX_SET (POLLPRI) - # To maintain a similar behavior to check readable fd in select we should check for. Anyway Python documentation - # lists different events for poll (no POLLRDBAND, ... ), but looking at the library they are there... dir(select), - # but should not be triggered or different, so the complete enough and safe option seems: - # poll_obj.register(read_fd, select.EPOLLIN | select.EPOLLHUP | select.EPOLLERR | select.EPOLLRDBAND | select.EPOLLRDNORM) - # poll_obj.register(read_fd, select.POLLIN | select.POLLHUP | select.POLLERR ) - # TODO: this may be revised to use select that seems more performant and able to support >1024: https://aivarsk.github.io/2017/04/06/select/ + dict: Dictionary of work done. + Raises: + ForkResultError: If there are failures in fetching ready fork results. 
+ """ POLL_TIMEOUT = 100 work_info = {} failures = 0 @@ -200,9 +172,6 @@ def fetch_ready_fork_result_list(pipe_ids): if time_this: t_begin = time.time() try: - # epoll tested fastest, and supports > 1024 open fds - # unfortunately linux only - # Level Trigger behavior (default) poll_obj = select.epoll() poll_type = "epoll" for read_fd in list(fds_to_entry.keys()): @@ -212,35 +181,26 @@ def fetch_ready_fork_result_list(pipe_ids): select.EPOLLIN | select.EPOLLHUP | select.EPOLLERR | select.EPOLLRDBAND | select.EPOLLRDNORM, ) except OSError as err: - # Epoll (contrary to poll) complains about duplicate registrations: IOError: [Errno 17] File exists - # All other errors are re-risen if err.errno == errno.EEXIST: - logSupport.log.warning(f"Ignoring duplicate fd {read_fd} registration in epoll(): '{str(err)}'") + logSupport.log.warning(f"Ignoring duplicate fd {read_fd} registration in epoll(): '{err}'") else: - logSupport.log.warning(f"Unsupported fd {read_fd} registration failure in epoll(): '{str(err)}'") + logSupport.log.warning(f"Unsupported fd {read_fd} registration failure in epoll(): '{err}'") raise - # File descriptors: [i[0] for i in poll_obj.poll(0) if i[1] & (select.EPOLLIN|select.EPOLLPRI)] - # Filtering is not needed, done by epoll, both EPOLLIN and EPOLLPRI are OK - # EPOLLHUP events are registered by default. The consumer will read eventual data and close the fd readable_fds = [i[0] for i in poll_obj.poll(POLL_TIMEOUT)] except (AttributeError, OSError) as err: - logSupport.log.warning("Failed to load select.epoll(): %s" % str(err)) + logSupport.log.warning(f"Failed to load select.epoll(): {err}") try: - # no epoll(), try poll(). Still supports > 1024 fds and - # tested faster than select() on linux when multiple forks configured poll_obj = select.poll() poll_type = "poll" for read_fd in list(fds_to_entry.keys()): poll_obj.register(read_fd, select.POLLIN | select.POLLHUP | select.POLLERR) readable_fds = [i[0] for i in poll_obj.poll(POLL_TIMEOUT)] except (AttributeError, OSError) as err: - logSupport.log.warning("Failed to load select.poll(): %s" % str(err)) - # no epoll() or poll(), use select() + logSupport.log.warning(f"Failed to load select.poll(): {err}") readable_fds = select.select(list(fds_to_entry.keys()), [], [], POLL_TIMEOUT / 1000.0)[0] poll_type = "select" count = 0 - # logSupport.log.debug("Data available via %s, fd list: %s" % (poll_type, readable_fds)) for fd in readable_fds: if fd not in fds_to_entry: continue @@ -251,31 +211,19 @@ def fetch_ready_fork_result_list(pipe_ids): out = fetch_fork_result(fd, pid) try: if poll_obj: - poll_obj.unregister(fd) # Is this needed? 
Lots of hoops to jump through here + poll_obj.unregister(fd) except OSError as err: - if err.errno == 9: - # python.select < 3.9 treated unregister on closed pipe as NO_OP - # python 3.9 + raises OSError: [Errno 9] Bad file descriptor - # we don't care about this for now, continue processing fd's - pass - else: - # some other OSError, log and raise - errmsg = f"unregister failed pid='{pid}' fd='{fd}' key='{str(key)}': {err}" + if err.errno != 9: + errmsg = f"unregister failed pid='{pid}' fd='{fd}' key='{key}': {err}" logSupport.log.warning(errmsg) logSupport.log.exception(errmsg) raise - work_info[key] = out count += 1 except (OSError, ValueError, KeyError, FetchError) as err: - # KeyError: inconsistent dictionary or reverse dictionary - # IOError: Error in poll_obj.unregister() - # OSError: [Errno 9] Bad file descriptor - fetch_fork_result with wrong file descriptor - # FetchError: read error in fetch_fork_result - errmsg = f"Failed to extract info from child '{str(key)}': {err}" + errmsg = f"Failed to extract info from child '{key}': {err}" logSupport.log.warning(errmsg) logSupport.log.exception(errmsg) - # Record failed keys failed.append(key) failures += 1 @@ -284,22 +232,22 @@ def fetch_ready_fork_result_list(pipe_ids): if time_this: logSupport.log.debug( - "%s: using %s fetched %s of %s in %s seconds" - % ("fetch_ready_fork_result_list", poll_type, count, len(list(fds_to_entry.keys())), time.time() - t_begin) + f"fetch_ready_fork_result_list: using {poll_type} fetched {count} of {len(fds_to_entry)} in {time.time() - t_begin} seconds" ) return work_info def wait_for_pids(pid_list): - """Wait for all pids to finish. - Throw away any stdout or err + """Wait for all pids to finish and discard any stdout or stderr. + + Args: + pid_list (list): List of pids to wait for. """ for pidel in pid_list: pid = pidel["pid"] r = pidel["r"] try: - # empty the read buffer first s = os.read(r, 1024) while s != b"": # "" means EOF, pipes are binary s = os.read(r, 1024) @@ -308,33 +256,45 @@ def wait_for_pids(pid_list): os.waitpid(pid, 0) -################################################ -# Fork Class - - class ForkManager: + """Manages the forking of processes and the collection of results.""" + def __init__(self): self.functions_tofork = {} - # I need a separate list to keep the order self.key_list = [] - return def __len__(self): return len(self.functions_tofork) def add_fork(self, key, function, *args): + """Adds a function to be forked. + + Args: + key (str): Unique key for the fork. + function (function): Function to be forked. + *args: Arguments to be passed to the function. + + Raises: + KeyError: If the key is already in use. + """ if key in self.functions_tofork: - raise KeyError("Fork key '%s' already in use" % key) + raise KeyError(f"Fork key '{key}' already in use") self.functions_tofork[key] = (function,) + args self.key_list.append(key) def fork_and_wait(self): + """Forks and waits for all functions to complete.""" pids = [] for key in self.key_list: pids.append(fork_in_bg(*self.functions_tofork[key])) wait_for_pids(pids) def fork_and_collect(self): + """Forks and collects the results of all functions. + + Returns: + dict: Dictionary of results. + """ pipe_ids = {} for key in self.key_list: pipe_ids[key] = fork_in_bg(*self.functions_tofork[key]) @@ -342,34 +302,36 @@ def fork_and_collect(self): return results def bounded_fork_and_collect(self, max_forks, log_progress=True, sleep_time=0.01): + """Forks and collects results with a limit on the number of concurrent forks. 
+ + Args: + max_forks (int): Maximum number of concurrent forks. + log_progress (bool): Whether to log progress. + sleep_time (float): Time to sleep between checks. + + Returns: + dict: Dictionary of results. + + Raises: + ForkResultError: If there are errors in the forked processes. + """ post_work_info = {} nr_errors = 0 - pipe_ids = {} forks_remaining = max_forks functions_remaining = len(self.functions_tofork) - # try to fork all the functions for key in self.key_list: - # Check if we can fork more if forks_remaining == 0: if log_progress: - # log here, since we will have to wait - logSupport.log.info("Active forks = %i, Forks to finish = %i" % (max_forks, functions_remaining)) + logSupport.log.info(f"Active forks = {max_forks}, Forks to finish = {functions_remaining}") while forks_remaining == 0: failed_keys = [] - # Give some time for the processes to finish the work - # logSupport.log.debug("Reached parallel_workers limit of %s" % parallel_workers) time.sleep(sleep_time) - - # Wait and gather results for work done so far before forking more try: - # logSupport.log.debug("Checking finished workers") post_work_info_subset = fetch_ready_fork_result_list(pipe_ids) except ForkResultError as e: - # Collect the partial result post_work_info_subset = e.good_results - # Expect all errors logged already, just count nr_errors += e.nr_errors functions_remaining -= e.nr_errors failed_keys = e.failed @@ -381,33 +343,21 @@ def bounded_fork_and_collect(self, max_forks, log_progress=True, sleep_time=0.01 for i in list(post_work_info_subset.keys()) + failed_keys: if pipe_ids.get(i): del pipe_ids[i] - # end for - # end while - - # yes, we can, do it pipe_ids[key] = fork_in_bg(*self.functions_tofork[key]) forks_remaining -= 1 - # end for if log_progress: logSupport.log.info( - "Active forks = %i, Forks to finish = %i" % (max_forks - forks_remaining, functions_remaining) + f"Active forks = {max_forks - forks_remaining}, Forks to finish = {functions_remaining}" ) - # now we just have to wait for all to finish while functions_remaining > 0: failed_keys = [] - # Give some time for the processes to finish the work time.sleep(sleep_time) - - # Wait and gather results for work done so far before forking more try: - # logSupport.log.debug("Checking finished workers") post_work_info_subset = fetch_ready_fork_result_list(pipe_ids) except ForkResultError as e: - # Collect the partial result post_work_info_subset = e.good_results - # Expect all errors logged already, just count nr_errors += e.nr_errors functions_remaining -= e.nr_errors failed_keys = e.failed @@ -419,12 +369,10 @@ def bounded_fork_and_collect(self, max_forks, log_progress=True, sleep_time=0.01 for i in list(post_work_info_subset.keys()) + failed_keys: del pipe_ids[i] - if len(post_work_info_subset) > 0: - if log_progress: - logSupport.log.info( - "Active forks = %i, Forks to finish = %i" % (max_forks - forks_remaining, functions_remaining) - ) - # end while + if len(post_work_info_subset) > 0 and log_progress: + logSupport.log.info( + f"Active forks = {max_forks - forks_remaining}, Forks to finish = {functions_remaining}" + ) if nr_errors > 0: raise ForkResultError(nr_errors, post_work_info) @@ -432,33 +380,28 @@ def bounded_fork_and_collect(self, max_forks, log_progress=True, sleep_time=0.01 return post_work_info -#################### -# Utilities - - def print_child_processes(root_pid=str(os.getppid()), this_pid=str(os.getpid())): - """Print the process tree of the root PID + """Print the process tree of the root PID. 
Args: - root_pid (str): String containing the process ID to use as root of the process tree - this_pid (str|None): If String containing the process ID of the current process (will get a star in the line) + root_pid (str): String containing the process ID to use as root of the process tree. + this_pid (str, optional): String containing the process ID of the current process (will get a star in the line). Returns: - list: list of str containing all the lines of the process tree + list: List of str containing all the lines of the process tree. """ def print_children(id, ps_dict, my_id="", level=0): - """Auxiliary recursive function of print_child_processes, - printing the children subtree of a given process ID + """Auxiliary recursive function to print the children subtree of a given process ID. Args: - id (str): String w/ process ID root of the tree - ps_dict (dict): dictionary with all processes and theyr children - my_id (str): String w/ process ID of the print_children caller (Default: "") - level (int): level of the subtree (Default: 0) + id (str): String with process ID root of the tree. + ps_dict (dict): Dictionary with all processes and their children. + my_id (str, optional): String with process ID of the print_children caller. + level (int, optional): Level of the subtree. Returns: - list: list of str containing all the lines of the process subtree + list: List of str containing all the lines of the process subtree. """ if my_id and my_id == id: out = ["+" * level + id + " *"] diff --git a/lib/glideinWMSVersion.py b/lib/glideinWMSVersion.py index 9e94fefe5..d47d05e3f 100644 --- a/lib/glideinWMSVersion.py +++ b/lib/glideinWMSVersion.py @@ -3,7 +3,7 @@ # SPDX-FileCopyrightText: 2009 Fermi Research Alliance, LLC # SPDX-License-Identifier: Apache-2.0 -"""Execute a ls command on a condor job working directory +"""Execute a ls command on a condor job working directory. Usage: glideinWMSVersion.py [] @@ -12,18 +12,20 @@ import os import sys -# pylint: disable=E0611 -# (hashlib methods are called dynamically) from hashlib import md5 -# pylint: enable=E0611 - class GlideinWMSDistro: + """Singleton class to handle GlideinWMS distribution checksum and versioning.""" + class __impl: - """Implementation of the singleton interface""" + """Implementation of the singleton interface.""" def __init__(self, chksumFile="checksum"): + """ + Args: + chksumFile (str): Path to the checksum file. Defaults to "checksum". 
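As a usage note for print_child_processes() documented above, a short hedged sketch (same assumed module as the other fork helpers):

    import os
    from glideinwms.lib.fork import print_child_processes

    # Returns one string per process in the tree; the current PID is marked with a star.
    for line in print_child_processes(root_pid=str(os.getpid())):
        print(line)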
+ """ self.versionIdentifier = "GLIDEINWMS_VERSION" rpm_workdir = "" @@ -42,11 +44,11 @@ def __init__(self, chksumFile="checksum"): self._version = "glideinWMS UNKNOWN" def createVersionString(self): + """Creates the version string based on the checksum file and the current state of the distribution.""" ver = "UNKNOWN" patch = "" modifiedFiles = [] - # Load the distro file hastable distroFileHash = {} with open(self.distroChksumFile) as distroChksumFd: for line in distroChksumFd.readlines(): @@ -62,11 +64,9 @@ def createVersionString(self): ver = v if ver != "UNKNOWN": - # Read the dir contents of distro and compute the md5sum for file in list(distroFileHash.keys()): fd = None try: - # In the RPM, all files are in site-packages rpm_dir = os.path.dirname(os.path.dirname(sys.modules[__name__].__file__)) fd = open(os.path.join(rpm_dir, os.path.dirname(file), os.path.basename(file))) @@ -74,57 +74,64 @@ def createVersionString(self): if chksum != distroFileHash[file]: modifiedFiles.append(file) patch = "PATCHED" - except Exception: # ignore missing files + except Exception: pass if fd: fd.close() - # if len(modifiedFiles) > 0: - # print "Modified files: %s" % " ".join(modifiedFiles) - self._version = f"glideinWMS {ver} {patch}" def version(self): + """Returns the current version string. + + Returns: + str: The current version string. + """ return self._version - # storage for the instance reference __instance = None def __init__(self, chksumFile="checksum"): + """ + Args: + chksumFile (str): Path to the checksum file. Defaults to "checksum". + """ if GlideinWMSDistro.__instance is None: GlideinWMSDistro.__instance = GlideinWMSDistro.__impl(chksumFile=chksumFile) self.__dict__["_GlideinWMSDistro__instance"] = GlideinWMSDistro.__instance def __getattr__(self, attr): - """Delegate access to implementation""" + """Delegate access to implementation.""" return getattr(self.__instance, attr) def __setattr__(self, attr, value): - """Delegate access to implementation""" + """Delegate access to implementation.""" return setattr(self.__instance, attr, value) def version(chksumFile=None): - return GlideinWMSDistro(chksumFile=chksumFile).version() + """Gets the GlideinWMS version. + Args: + chksumFile (str, optional): Path to the checksum file. -# version + Returns: + str: The GlideinWMS version. + """ + return GlideinWMSDistro(chksumFile=chksumFile).version() def usage(): + """Prints the usage of the script.""" print("Usage: glideinWMSVersion.py []") -############################################################################## -# MAIN -############################################################################## - if __name__ == "__main__": if len(sys.argv) == 1: - print("%s " % (GlideinWMSDistro().version())) + print(f"{GlideinWMSDistro().version()} ") elif len(sys.argv) == 2: - print("%s " % (GlideinWMSDistro(chksumFile=sys.argv[1]).version())) + print(f"{GlideinWMSDistro(chksumFile=sys.argv[1]).version()} ") else: usage() sys.exit(1) diff --git a/lib/hashCrypto.py b/lib/hashCrypto.py index 717cfea42..30c97b937 100644 --- a/lib/hashCrypto.py +++ b/lib/hashCrypto.py @@ -7,18 +7,16 @@ # File Version: # -"""hashCrypto - This module defines classes to perform hash based cryptography +"""hashCrypto - This module defines classes to perform hash based cryptography. 
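A hedged usage sketch for the singleton above (import path assumed to be glideinwms.lib.glideinWMSVersion). Because GlideinWMSDistro is a singleton, the checksum file passed to the first instantiation is the one that counts:

    from glideinwms.lib.glideinWMSVersion import GlideinWMSDistro

    print(GlideinWMSDistro().version())     # e.g. "glideinWMS 3.x.y "
    # With an explicit checksum file (only effective on the first instantiation):
    # print(GlideinWMSDistro(chksumFile="/path/to/checksum").version())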
It uses M2Crypto: https://github.com/mcepl/M2Crypto a wrapper around OpenSSL: https://www.openssl.org/docs/man1.1.1/man3/ -NOTE get_hash() and extract_hash() both return Unicode utf-8 (defaults.BINARY_ENCODING_CRYPTO) strings and - get_hash() accepts byte-like objects or utf-8 encoded Unicode strings. - Same for all the get_XXX or extract_XXX that use those functions. - Other class methods and functions use bytes for input and output - +NOTE: get_hash() and extract_hash() both return Unicode utf-8 (defaults.BINARY_ENCODING_CRYPTO) strings and +get_hash() accepts byte-like objects or utf-8 encoded Unicode strings. +Same for all the get_XXX or extract_XXX that use those functions. +Other class methods and functions use bytes for input and output. """ -# TODO: should this module be replaced (or reimplemented) by using Python's hashlib? import binascii @@ -27,78 +25,86 @@ from . import defaults ###################### -# -# Available hash algos: +# Available hash algorithms: # 'sha1' # 'sha224' # 'sha256', # 'ripemd160' # 'md5' -# ###################### -########################################################################## -# Generic hash class class Hash: - """Generic hash class + """Generic hash class. Available hash algorithms: 'sha1' 'sha224' - 'sha256', + 'sha256' 'ripemd160' 'md5' """ def __init__(self, hash_algo): + """Initializes the Hash object with the specified algorithm. + + Args: + hash_algo (str): The hash algorithm to use. + """ self.hash_algo = hash_algo - return def redefine(self, hash_algo): - self.hash_algo = hash_algo - return + """Redefines the hash algorithm. - ########################################### - # compute hash inline + Args: + hash_algo (str): The new hash algorithm to use. + """ + self.hash_algo = hash_algo def compute(self, data): - """Compute hash inline - - len(data) must be less than len(key) + """Compute hash inline. Args: - data (bytes): data to calculate the hash of + data (bytes): Data to calculate the hash of. Returns: - bytes: digest value as bytes string (OpenSSL final and digest together) + bytes: Digest value as bytes string (OpenSSL final and digest together). """ h = M2Crypto.EVP.MessageDigest(self.hash_algo) h.update(data) return h.final() def compute_base64(self, data): - """like compute, but base64 encoded""" + """Computes hash inline and returns base64 encoded result. + + Args: + data (bytes): Data to calculate the hash of. + + Returns: + bytes: Base64 encoded digest value. + """ return binascii.b2a_base64(self.compute(data)) def compute_hex(self, data): - """like compute, but hex encoded""" - return binascii.b2a_hex(self.compute(data)) + """Computes hash inline and returns hex encoded result. - ########################################### - # extract hash from a file + Args: + data (bytes): Data to calculate the hash of. - def extract(self, fname, block_size=1048576): - """Extract hash from a file + Returns: + bytes: Hex encoded digest value. + """ + return binascii.b2a_hex(self.compute(data)) - len(data) must be less than len(key) + def extract(self, fname, block_size=1048576): + """Extracts hash from a file. Args: - fname (str): input file path (binary file) - block_size: + fname (str): Input file path (binary file). + block_size (int): Block size for reading the file. Returns: - bytes: digest value as bytes string (OpenSSL final and digest together) + bytes: Digest value as bytes string (OpenSSL final and digest together). 
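A minimal sketch of the Hash class as documented above (module path assumed to be glideinwms.lib.hashCrypto; requires M2Crypto). compute_hex() returns bytes, so decode before printing:

    from glideinwms.lib.hashCrypto import Hash

    h = Hash("sha256")
    digest = h.compute_hex(b"hello world")   # hex-encoded digest as bytes
    print(digest.decode("ascii"))
    # Cross-check with the stdlib: hashlib.sha256(b"hello world").hexdigest()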
""" h = M2Crypto.EVP.MessageDigest(self.hash_algo) with open(fname, "rb") as fd: @@ -106,92 +112,158 @@ def extract(self, fname, block_size=1048576): data = fd.read(block_size) if data == b"": break # no more data, stop reading - # should check update return? -1 for Python error, 1 for success, 0 for OpenSSL failure h.update(data) return h.final() def extract_base64(self, fname, block_size=1048576): - """like extract, but base64 encoded""" + """Extracts hash from a file and returns base64 encoded result. + + Args: + fname (str): Input file path (binary file). + block_size (int): Block size for reading the file. + + Returns: + bytes: Base64 encoded digest value. + """ return binascii.b2a_base64(self.extract(fname, block_size)) def extract_hex(self, fname, block_size=1048576): - """like extract, but hex encoded""" - return binascii.b2a_hex(self.extract(fname, block_size)) + """Extracts hash from a file and returns hex encoded result. + Args: + fname (str): Input file path (binary file). + block_size (int): Block size for reading the file. -######################################### + Returns: + bytes: Hex encoded digest value. + """ + return binascii.b2a_hex(self.extract(fname, block_size)) def get_hash(hash_algo, data): - """Compute hash inline + """Compute hash inline. Args: - hash_algo (str): hash algorithm to use - data (AnyStr): data of which to calculate the hash + hash_algo (str): Hash algorithm to use. + data (AnyStr): Data of which to calculate the hash. Returns: - str: utf-8 encoded hash + str: utf-8 encoded hash. """ - # Check to see if the data is already in bytes bdata = defaults.force_bytes(data) - h = Hash(hash_algo) return h.compute_hex(bdata).decode(defaults.BINARY_ENCODING_CRYPTO) def extract_hash(hash_algo, fname, block_size=1048576): - """Compute hash from file + """Compute hash from file. Args: - hash_algo (str): hash algorithm to use - fname (str): file path (file will be open in binary mode) - block_size (int): block size + hash_algo (str): Hash algorithm to use. + fname (str): File path (file will be open in binary mode). + block_size (int): Block size. Returns: - str: utf-8 encoded hash + str: utf-8 encoded hash. """ h = Hash(hash_algo) return h.extract_hex(fname, block_size).decode(defaults.BINARY_ENCODING_CRYPTO) -########################################################################## -# Explicit hash algo section - - class HashMD5(Hash): + """MD5 hash class.""" + def __init__(self): - Hash.__init__(self, "md5") + """Initializes the MD5 hash class.""" + super().__init__("md5") def get_md5(data): + """Compute MD5 hash inline. + + Args: + data (AnyStr): Data of which to calculate the hash. + + Returns: + str: utf-8 encoded MD5 hash. + """ return get_hash("md5", data) def extract_md5(fname, block_size=1048576): + """Compute MD5 hash from file. + + Args: + fname (str): File path (file will be open in binary mode). + block_size (int): Block size. + + Returns: + str: utf-8 encoded MD5 hash. + """ return extract_hash("md5", fname, block_size) class HashSHA1(Hash): + """SHA1 hash class.""" + def __init__(self): - Hash.__init__(self, "sha1") + """Initializes the SHA1 hash class.""" + super().__init__("sha1") def get_sha1(data): + """Compute SHA1 hash inline. + + Args: + data (AnyStr): Data of which to calculate the hash. + + Returns: + str: utf-8 encoded SHA1 hash. + """ return get_hash("sha1", data) def extract_sha1(fname, block_size=1048576): + """Compute SHA1 hash from file. + + Args: + fname (str): File path (file will be open in binary mode). 
+ block_size (int): Block size. + + Returns: + str: utf-8 encoded SHA1 hash. + """ return extract_hash("sha1", fname, block_size) class HashSHA256(Hash): + """SHA256 hash class.""" + def __init__(self): - Hash.__init__(self, "sha256") + """Initializes the SHA256 hash class.""" + super().__init__("sha256") def get_sha256(data): + """Compute SHA256 hash inline. + + Args: + data (AnyStr): Data of which to calculate the hash. + + Returns: + str: utf-8 encoded SHA256 hash. + """ return get_hash("sha256", data) def extract_sha256(fname, block_size=1048576): + """Compute SHA256 hash from file. + + Args: + fname (str): File path (file will be open in binary mode). + block_size (int): Block size. + + Returns: + str: utf-8 encoded SHA256 hash. + """ return extract_hash("sha256", fname, block_size) diff --git a/lib/logSupport.py b/lib/logSupport.py index 32676623d..19d8bf73b 100644 --- a/lib/logSupport.py +++ b/lib/logSupport.py @@ -61,57 +61,45 @@ def alternate_log(msg): - """ - When an exceptions happen within the logging system (e.g. when the disk is full while rotating a - log file) an alternate logging is necessary, e.g. writing to stderr + """Logs a message to stderr as an alternative logging method. + + This function is used when exceptions occur within the logging system, + such as when the disk is full during log file rotation. + + Args: + msg (str): The message to be logged. """ sys.stderr.write("%s\n" % msg) class GlideinHandler(BaseRotatingHandler): - """ - Custom logging handler class for GlideinWMS. It combines the decision tree - for log rotation from the TimedRotatingFileHandler with the decision tree - from the RotatingFileHandler. This allows us to specify a lifetime AND - file size to determine when to rotate the file. - Files are rotated if the days since the last rotation (or the beginning) - are more than the time limit (maxDays) or if the size of the file grew above - the size limit (maxMBytes) and at least the min time interval (minDays) went by. - - This class assumes that the lifetime (`interval`, `min_lifetime`) specified is - in seconds but the value in the constructor (`maxDays`, `minDays`) are in days - (24 hour periods) and can be fractions (float). - - And the size is measured in Bytes (MBytes in the constructor parameter can be - fractional) - - @type filename: string - @ivar filename: Full path to the log file. Includes file name. - @type interval: int - @ivar interval: Number of seconds to keep log file before rotating - @type maxBytes: int - @param maxBytes: Maximum size of the logfile in Bytes before file rotation (used with min days) - @type min_lifetime: int - @param min_lifetime: Minimum number of seconds (used with max bytes) - @type backupCount: int - @ivar backupCount: How many backups to keep + """Custom logging handler class for GlideinWMS. + + Combines log rotation based on both time and file size, allowing + the specification of a lifetime and file size to determine when to rotate the log file. + + Attributes: + compression (str): Compression format (gz, zip) used for log files. + backupCount (int): Number of backup files to keep. + maxBytes (int): Maximum file size in bytes before rotation. + min_lifetime (int): Minimum lifetime in seconds before rotation. + interval (int): Time interval in seconds before rotation. + suffix (str): Suffix format for the rotated files. + extMatch (re.Pattern): Regex pattern to match the suffix of the rotated files. + rolloverAt (int): Time when the next rollover should happen. 
+ rollover_not_before (int): Earliest time when rollover can happen. """ def __init__(self, filename, maxDays=1.0, minDays=0.0, maxMBytes=10.0, backupCount=5, compression=None): - """Initialize the Handler. We assume the following: - - 1. Interval entered is in days or fractions of it (internally converted to seconds) - 2. No special encoding - 3. No delays are set - 4. Timestamps are not in UTC + """Initializes the GlideinHandler. Args: - filename (str|Path): The full path of the log file - maxDays (float): Max number of days before file rotation (fraction of day accepted, used in unit test) - minDays (float): Minimum number of days before file rotation (used with max MBytes) - maxMBytes (float): Maximum size of the logfile in MB before file rotation (used with min days) - backupCount (int): Number of backups to keep - compression (str): Compression to use (gz, zip, depending on available compression modules) + filename (str|Path): The full path of the log file. + maxDays (float): Maximum number of days before file rotation. + minDays (float): Minimum number of days before file rotation. + maxMBytes (float): Maximum file size in megabytes before rotation. + backupCount (int): Number of backup files to keep. + compression (str): Compression format (gz, zip) to use for log files. """ # Make dirs if logging directory does not exist if not os.path.exists(os.path.dirname(filename)): @@ -123,7 +111,6 @@ def __init__(self, filename, maxDays=1.0, minDays=0.0, maxMBytes=10.0, backupCou self.compression = compression.lower() except AttributeError: pass - # bz2 compression can be implemented with encoding='bz2-codec' in BaseRotatingHandler mode = "a" BaseRotatingHandler.__init__(self, filename, mode, encoding=None) self.backupCount = backupCount @@ -147,27 +134,14 @@ def __init__(self, filename, maxDays=1.0, minDays=0.0, maxMBytes=10.0, backupCou self.rollover_not_before = begin_interval_time + self.min_lifetime def shouldRollover(self, record, empty_record=False): - """Determine if rollover should occur. - - Basically, we are combining the checks for size and time interval + """Determines if a rollover should occur based on time or file size. Args: record (str): The message that will be logged. - empty_record (bool): If False (default) count also `record` length to evaluate if a rollover is needed + empty_record (bool): If False, counts `record` length to evaluate if a rollover is needed. Returns: - bool: True if rollover should be performed, False otherwise - - @attention: Due to the architecture decision to fork "workers" we run - into an issue where the child that was forked could cause a log - rotation. However, the parent will never know and the parent's file - descriptor will still be pointing at the old log file (now renamed by - the child). This will in turn cause the parent to immediately request - a log rotate, which results in what appears to be truncated logs. To - handle this we add a flag to disable log rotation. By default, this is - set to False, but anywhere we want to fork a child (or in any object - that will be forked) we set the flag to True. Then in the parent, we - initiate a log function that will log and rotate if necessary. + bool: True if rollover should be performed, False otherwise. """ if disable_rotate: return False @@ -190,9 +164,10 @@ def shouldRollover(self, record, empty_record=False): return do_timed_rollover or do_size_rollover def getFilesToDelete(self): - """Determine the files to delete when rolling over. 
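For the combined time/size rotation described above, a hedged configuration sketch (module path assumed to be glideinwms.lib.logSupport; the log path is illustrative):

    import logging
    from glideinwms.lib.logSupport import GlideinHandler

    # Rotate after one day, or once the file exceeds 5 MB but not before 0.1 days,
    # keeping 3 gzip-compressed backups.
    handler = GlideinHandler("/var/log/gwms-frontend/frontend.err.log",
                             maxDays=1.0, minDays=0.1, maxMBytes=5.0,
                             backupCount=3, compression="gz")
    logger = logging.getLogger("frontend.err")
    logger.addHandler(handler)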
+ """Gets the list of files that should be deleted during rollover. - More specific than the earlier method, which just used glob.glob(). + Returns: + list: A list of file paths that should be deleted. """ dirName, baseName = os.path.split(self.baseFilename) fileNames = os.listdir(dirName) @@ -212,12 +187,10 @@ def getFilesToDelete(self): return result def doRollover(self): - """Do a rollover + """Performs the rollover process for the log file. - In this case, a date/time stamp is appended to the filename - when the rollover happens. If there is a backup count, then we have to get - a list of matching filenames, sort them and remove the one with the oldest - suffix. + This includes renaming the log file, compressing it if necessary, + and removing old log files based on the backup count. """ # Close the soon to be rotated log file self.stream.close() @@ -225,27 +198,19 @@ def doRollover(self): timeTuple = time.localtime(time.time()) dfn = self.baseFilename + "." + time.strftime(self.suffix, timeTuple) - # If you are rotating log files in less than a minute, you either have - # set your sizes way too low, or you have serious problems. We are - # going to protect against that scenario by removing any files that - # whose name collides with the new rotated file name. if os.path.exists(dfn): os.remove(dfn) # rename the closed log file to the new rotated file name os.rename(self.baseFilename, dfn) - # if there is a backup count specified, keep only the specified number of - # rotated logs, delete the rest if self.backupCount > 0: for s in self.getFilesToDelete(): os.remove(s) - # Open a new log file self.mode = "w" self.stream = self._open() - # determine the next rollover time for the timed rollover check currentTime = int(time.time()) if self.min_lifetime > 0: self.rollover_not_before = currentTime + self.min_lifetime @@ -254,7 +219,6 @@ def doRollover(self): newRolloverAt = newRolloverAt + self.interval self.rolloverAt = newRolloverAt - # Compress the log file (if requested) if self.compression == "zip": if os.path.exists(dfn + ".zip"): os.remove(dfn + ".zip") @@ -269,8 +233,6 @@ def doRollover(self): if os.path.exists(dfn + ".gz"): os.remove(dfn + ".gz") try: - # TODO #23166: Use context managers[with statement] when python 3 - # once we get rid of SL6 and tarballs f_out = gzip.open(dfn + ".gz", "wb") with open(dfn, "rb") as f_in: f_out.writelines(f_in) @@ -280,11 +242,13 @@ def doRollover(self): alternate_log("Log file gzip compression failed: %s" % e) def check_and_perform_rollover(self): + """Checks if rollover conditions are met and performs the rollover if necessary.""" if self.shouldRollover(None, empty_record=True): self.doRollover() def roll_all_logs(): + """Triggers log rotation for all registered handlers.""" for handler in handlers: handler.check_and_perform_rollover() @@ -292,25 +256,21 @@ def roll_all_logs(): def get_processlog_handler( log_file_name, log_dir, msg_types, extension, maxDays, minDays, maxMBytes, backupCount=5, compression=None ): - """Return a configured handler for the GlideinLogger logger - - The file name is `"{log_dir}/{log_file_name}.{extension.lower()}.log"` and can include env variables + """Returns a configured handler for the GlideinLogger logger. Args: - log_file_name (str): log file name (same as the logger name) - log_dir (str|Path): log directory - msg_types (str): log levels to include (comma separated list). 
Keywords are: - DEBUG,INFO,WARN,ERR, ADMIN or ALL (ADMIN and ALL both mean all the previous) - ADMIN adds also the "admin" prefix to the `log_file_name` - extension (str): file name extension - maxDays (float): Max number of days before file rotation (fraction of day accepted, used in unit test) - minDays (float): Minimum number of days before file rotation (used with max MBytes) - maxMBytes (float): Maximum size of the logfile in MB before file rotation (used with min days) - backupCount (int): Number of backups to keep - compression (str): Compression to use (gz, zip, depending on available compression modules) + log_file_name (str): Log file name (same as the logger name). + log_dir (str|Path): Log directory. + msg_types (str): Log levels to include (comma-separated list). + extension (str): File name extension. + maxDays (float): Maximum number of days before file rotation. + minDays (float): Minimum number of days before file rotation. + maxMBytes (float): Maximum size of the logfile in MB before file rotation. + backupCount (int): Number of backups to keep. + compression (str): Compression to use (gz, zip, depending on available modules). Returns: - GlideinHandler: configured handler + GlideinHandler: Configured logging handler. """ # Parameter adjustments msg_types = msg_types.upper() @@ -320,15 +280,10 @@ def get_processlog_handler( log_file_name = log_file_name + "admin" if "ALL" in msg_types: msg_types = "DEBUG,INFO,WARN,ERR" - # File name logfile = os.path.expandvars(f"{log_dir}/{log_file_name}.{extension.lower()}.log") handler = GlideinHandler(logfile, maxDays, minDays, maxMBytes, backupCount, compression) handler.setFormatter(DEFAULT_FORMATTER) - # Setting the handler logging level to DEBUG to control all from the logger level and the - # filter. This allows to pick any level combination, but may be less performant than a - # min level selection. - # TODO: Check if min level should be used instead and if the handler level should be logging.NOTSET (0) ? handler.setLevel(logging.DEBUG) has_debug = False msg_type_list = [] @@ -358,30 +313,45 @@ def get_processlog_handler( class MsgFilter(logging.Filter): - """Filter used in handling records for the info logs. + """Filter class for handling log messages based on log level. - Default to logging.INFO - """ + Args: + msg_type_list (list): List of log levels to filter. - msg_type_list = [logging.INFO] + Returns: + bool: True if the log level matches one in the list, False otherwise. + """ def __init__(self, msg_type_list): + """Initializes the MsgFilter. + + Args: + msg_type_list (list): List of log levels to filter. + """ logging.Filter.__init__(self) self.msg_type_list = msg_type_list def filter(self, rec): + """Filters log records based on log level. + + Args: + rec (logging.LogRecord): The log record to be filtered. + + Returns: + bool: True if the log level is in msg_type_list, False otherwise. + """ return rec.levelno in self.msg_type_list def format_dict(unformated_dict, log_format=" %-25s : %s\n"): - """Convenience function used to format a dictionary for the logs to make it human-readable. + """Formats a dictionary for human-readable logging. Args: - unformated_dict (dict): The dictionary to be formatted for logging - log_format (str): format string for logging + unformated_dict (dict): The dictionary to be formatted for logging. + log_format (str): Format string for logging. Returns: - str: Formatted string + str: Formatted string. 
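A hedged sketch of get_processlog_handler(); per the docstring, the resulting file name is "{log_dir}/{log_file_name}.{extension.lower()}.log":

    import logging
    from glideinwms.lib.logSupport import get_processlog_handler

    handler = get_processlog_handler("frontend", "/var/log/gwms-frontend",
                                     "INFO,WARN,ERR", "info",
                                     maxDays=1.0, minDays=0.0, maxMBytes=10.0)
    logging.getLogger("frontend").addHandler(handler)   # writes frontend.info.log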
""" formatted_string = "" for key in unformated_dict: @@ -395,26 +365,14 @@ def format_dict(unformated_dict, log_format=" %-25s : %s\n"): # From structlog 23.1.0 suggested configurations - separate rendering, using same output structlog.configure( processors=[ - # If log level is too low, abort pipeline and throw away log entry. structlog.stdlib.filter_by_level, - # Add the name of the logger to event dict. structlog.stdlib.add_logger_name, - # Add log level to event dict. structlog.stdlib.add_log_level, - # Perform %-style formatting. structlog.stdlib.PositionalArgumentsFormatter(), - # Add a timestamp in ISO 8601 format. structlog.processors.TimeStamper(fmt="iso"), - # If the "stack_info" key in the event dict is true, remove it and - # render the current stack trace in the "stack" key. structlog.processors.StackInfoRenderer(), - # If the "exc_info" key in the event dict is either true or a - # sys.exc_info() tuple, remove "exc_info" and render the exception - # with traceback into the "exception" key. structlog.processors.format_exc_info, - # If some value is in bytes, decode it to a unicode str. structlog.processors.UnicodeDecoder(), - # Add callsite parameters. (available from structlog 21.5.0) structlog.processors.CallsiteParameterAdder( { structlog.processors.CallsiteParameter.FILENAME, @@ -422,24 +380,13 @@ def format_dict(unformated_dict, log_format=" %-25s : %s\n"): structlog.processors.CallsiteParameter.LINENO, } ), - # Render the final event dict as JSON. structlog.processors.JSONRenderer(), ], - # using default dict as context_class - # `wrapper_class` is the bound logger that you get back from - # get_logger(). This one imitates the API of `logging.Logger`. wrapper_class=structlog.stdlib.BoundLogger, - # `logger_factory` is used to create wrapped loggers that are used for - # OUTPUT. This one returns a `logging.Logger`. The final value (a JSON - # string) from the final processor (`JSONRenderer`) will be passed to - # the method of the same name as that you've called on the bound logger. logger_factory=structlog.stdlib.LoggerFactory(), - # Effectively freeze configuration after creating the first bound - # logger. cache_logger_on_first_use=True, ) except AttributeError: - # caused by structlog.processors.CallsiteParameterAdder with structlog prior 21.5.0 (EL7 has 17.2.0) structlog.configure( processors=[ structlog.stdlib.filter_by_level, @@ -459,12 +406,28 @@ def format_dict(unformated_dict, log_format=" %-25s : %s\n"): def get_logging_logger(name): + """Retrieves a standard Python logging logger. + + Args: + name (str): Name of the logger. + + Returns: + logging.Logger: Configured logger. + """ log = logging.getLogger(name) log.setLevel(logging.DEBUG) return log def get_structlog_logger(name): + """Retrieves a structured logger using structlog if available. + + Args: + name (str): Name of the logger. + + Returns: + structlog.BoundLogger: Configured structured logger. + """ if USE_STRUCTLOG: log = structlog.get_logger(name) log.setLevel(logging.DEBUG) @@ -473,27 +436,21 @@ def get_structlog_logger(name): def get_logger_with_handlers(name, directory, config_data, level=logging.DEBUG): - """Create/retrieve a logger, set the handlers, set the starting logging level, and return the logger - - The file name is {name}.{plog["extension"].lower()}.log + """Creates and configures a logger with handlers. 
Args: - name (str): logger name (and file base name) - directory (str|Path): log directory - config_data (dict): logging configuration - (the "ProcessLogs" value evaluates to list of dictionary with process_logs section values) - level: logger's logging level (default: logging.DEBUG) + name (str): Logger name. + directory (str|Path): Directory for the log files. + config_data (dict): Logging configuration data. + level (int): Logging level. Returns: - logging.Logger: configured logger + logging.Logger: Configured logger. """ - # Contains a dictionary in a string process_logs = eval(config_data["ProcessLogs"]) is_structured = False handlers_list = [] for plog in process_logs: - # If at least one handler is structured, it will use structured logging - # All handlers should be consistent and use the same is_structured = is_structured or util.is_true(plog["structured"]) handler = get_processlog_handler( name, diff --git a/lib/pidSupport.py b/lib/pidSupport.py index 35afd1e75..788630ace 100644 --- a/lib/pidSupport.py +++ b/lib/pidSupport.py @@ -13,36 +13,68 @@ ############################################################ -# -# Verify if the system knows about a pid -# def check_pid(pid): + """Check if a process with the given PID exists. + + Args: + pid (int): The process ID to check. + + Returns: + bool: True if the process exists, False otherwise. + """ return os.path.isfile(f"/proc/{pid}/cmdline") ############################################################ -# this exception is raised when trying to register a pid -# but another process is already owning the PID file class AlreadyRunning(RuntimeError): + """Exception raised when a process is already running and owns the PID file.""" + pass ####################################################### -# -# self.mypid is valid only if self.fd is valid -# or after a load + + class PidSupport: + """Class to manage PID files with locking mechanisms. + + This class handles the registration and management of PID files, + ensuring that only one process can own a PID file at a time. + + Attributes: + pid_fname (str): The filename of the PID file. + fd (file object): The file descriptor for the PID file. + mypid (int): The PID of the current process. + lock_in_place (bool): Indicates if the lock is in place. + started_time (float): The time when the process started. + """ + def __init__(self, pid_fname): + """Initialize the PidSupport class. + + Args: + pid_fname (str): The filename of the PID file. + """ self.pid_fname = pid_fname self.fd = None self.mypid = None self.lock_in_place = False - # open the pid_file and gain the exclusive lock - # also write in the PID information - def register(self, pid=None, started_time=None): # if none, will default to os.getpid() # if none, use time.time() + def register(self, pid=None, started_time=None): + """Register the current process by writing its PID to the PID file. + + This method also gains an exclusive lock on the PID file. + + Args: + pid (int, optional): The PID to register. Defaults to the current process PID. + started_time (float, optional): The time when the process started. Defaults to the current time. + + Raises: + RuntimeError: If a PID is already registered in the same object. + AlreadyRunning: If another process is already running and owns the PID file. + """ if self.fd is not None: raise RuntimeError("Cannot register two pids in the same object!") @@ -54,13 +86,10 @@ def register(self, pid=None, started_time=None): # if none, will default to os. 
self.mypid = pid self.started_time = started_time - # check lock file if not os.path.exists(self.pid_fname): - # create a lock file if needed fd = open(self.pid_fname, "w") fd.close() - # Do not use 'with' or close the file. Will be closed when lock is released fd = open(self.pid_fname, "r+") try: fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB) @@ -76,9 +105,8 @@ def register(self, pid=None, started_time=None): # if none, will default to os. self.fd = fd return - # release the lock on the PID file - # also purge the info from the file def relinquish(self): + """Release the lock on the PID file and remove the PID information.""" self.fd.seek(0) self.fd.truncate() self.fd.flush() @@ -87,39 +115,34 @@ def relinquish(self): self.mypid = None self.lock_in_place = False - # Will update self.mypid and self.lock_in_place def load_registered(self): + """Load the registered PID from the PID file. + + Updates the instance's PID and lock status based on the contents of the PID file. + """ if self.fd is not None: - return # we own it, so nothing to do + return - # make sure it is initialized (to not registered) self.reset_to_default() self.lock_in_place = False - # else I don't own it if not os.path.isfile(self.pid_fname): return with open(self.pid_fname) as fd: try: fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB) - # if I can get a lock, it means that there is no process return except OSError: - # there is a process - # I will read it even if locked, so that I can report what the PID is - # if the data is corrupted, I will deal with it later lines = fd.readlines() self.lock_in_place = True try: self.parse_pid_file_content(lines) except Exception: - # Data is corrupted, cannot get the PID, masking exceptions return if not check_pid(self.mypid): - # found not running self.mypid = None return @@ -129,12 +152,26 @@ def load_registered(self): ############################### def format_pid_file_content(self): + """Format the content to be written to the PID file. + + Returns: + str: Formatted string containing the PID and start time. + """ return f"PID: {self.mypid}\nStarted: {time.ctime(self.started_time)}\n" def reset_to_default(self): + """Reset the instance attributes to their default values.""" self.mypid = None def parse_pid_file_content(self, lines): + """Parse the content of the PID file and update the instance attributes. + + Args: + lines (list): Lines read from the PID file. + + Raises: + RuntimeError: If the PID file is corrupted or invalid. + """ self.mypid = None if len(lines) < 2: raise RuntimeError("Corrupted lock file: too short") @@ -153,19 +190,37 @@ def parse_pid_file_content(self, lines): ####################################################### -# -# self.mypid and self.parent_pid are valid only -# if self.fd is valid or after a load + + class PidWParentSupport(PidSupport): + """Extended PidSupport class that includes parent PID information. + + This class manages PID files while also recording the parent PID. + + Attributes: + parent_pid (int): The parent process PID. + """ + def __init__(self, pid_fname): + """Initialize the PidWParentSupport class. + + Args: + pid_fname (str): The filename of the PID file. 
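A hedged life-cycle sketch for PidSupport (module path assumed to be glideinwms.lib.pidSupport; the PID file path is illustrative): register on startup, relinquish on shutdown, and use load_registered() to inspect an existing lock:

    from glideinwms.lib.pidSupport import AlreadyRunning, PidSupport

    pid_obj = PidSupport("/var/run/gwms/frontend.pid")
    try:
        pid_obj.register()                  # defaults to os.getpid() and time.time()
    except AlreadyRunning:
        pid_obj.load_registered()           # read the owner's PID without taking the lock
        print("already running as PID", pid_obj.mypid)
    else:
        try:
            pass                            # ... main work goes here ...
        finally:
            pid_obj.relinquish()            # drop the lock and blank the PID file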
+ """ PidSupport.__init__(self, pid_fname) self.parent_pid = None - # open the pid_file and gain the exclusive lock - # also write in the PID information - def register( - self, parent_pid, pid=None, started_time=None # if none, will default to os.getpid() - ): # if none, use time.time() + def register(self, parent_pid, pid=None, started_time=None): + """Register the current process and its parent by writing PIDs to the PID file. + + Args: + parent_pid (int): The parent process PID. + pid (int, optional): The PID to register. Defaults to the current process PID. + started_time (float, optional): The time when the process started. Defaults to the current time. + + Raises: + RuntimeError: If a PID is already registered in the same object. + """ if self.fd is not None: raise RuntimeError("Cannot register two pids in the same object!") @@ -178,13 +233,27 @@ def register( ############################### def format_pid_file_content(self): + """Format the content to be written to the PID file. + + Returns: + str: Formatted string containing the PID, parent PID, and start time. + """ return f"PID: {self.mypid}\nParent PID:{self.parent_pid}\nStarted: {time.ctime(self.started_time)}\n" def reset_to_default(self): + """Reset the instance attributes to their default values, including parent PID.""" PidSupport.reset_to_default(self) self.parent_pid = None def parse_pid_file_content(self, lines): + """Parse the content of the PID file and update the instance attributes. + + Args: + lines (list): Lines read from the PID file. + + Raises: + RuntimeError: If the PID file is corrupted or invalid. + """ self.mypid = None self.parent_pid = None @@ -215,14 +284,25 @@ def parse_pid_file_content(self, lines): def termsignal(signr, frame): + """Handle termination signals by raising a KeyboardInterrupt. + + Args: + signr (int): Signal number. + frame (FrameType): Current stack frame. + + Raises: + KeyboardInterrupt: Always raised with the signal number. + """ raise KeyboardInterrupt("Received signal %s" % signr) def register_sighandler(): + """Register signal handlers for SIGTERM and SIGQUIT.""" signal.signal(signal.SIGTERM, termsignal) signal.signal(signal.SIGQUIT, termsignal) def unregister_sighandler(): + """Unregister the custom signal handlers, resetting them to the default.""" signal.signal(signal.SIGTERM, signal.SIG_DFL) signal.signal(signal.SIGQUIT, signal.SIG_DFL) diff --git a/lib/pubCrypto.py b/lib/pubCrypto.py index 69d709b4d..166754f71 100644 --- a/lib/pubCrypto.py +++ b/lib/pubCrypto.py @@ -1,21 +1,20 @@ # SPDX-FileCopyrightText: 2009 Fermi Research Alliance, LLC # SPDX-License-Identifier: Apache-2.0 -"""pubCrypto - This module defines classes to perform public key cryptography +"""pubCrypto - This module defines classes to perform public key cryptography. -It uses M2Crypto: https://github.com/mcepl/M2Crypto +It uses M2Crypto: https://github.com/mcepl/M2Crypto, a wrapper around OpenSSL: https://www.openssl.org/docs/man1.1.1/man3/ -NOTE For convenience and consistency w/ previous versions of this module, Encryption/Signing functions - (b64, hex and .encrypt() ) accept bytes-like objects (bytes, bytearray) and also Unicode strings - utf-8 encoded (defaults.BINARY_ENCODING_CRYPTO). - B64 and hex Decryption functions, consistent w/ Python's binascii.a2b_* functions, accept bytes and - Unicode strings containing only ASCII characters, .decrypt() only accepts bytes-like objects (such as bytes, - bytearray and other objects that support the buffer protocol). - All these functions return bytes. 
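A small hedged sketch of the signal helpers just documented: once register_sighandler() is installed, SIGTERM and SIGQUIT surface as KeyboardInterrupt, so ordinary try/finally cleanup runs on service shutdown:

    from glideinwms.lib.pidSupport import register_sighandler, unregister_sighandler

    register_sighandler()
    try:
        pass                                # ... long-running service loop ...
    except KeyboardInterrupt as e:
        print("terminating:", e)            # raised by termsignal() on SIGTERM/SIGQUIT
    finally:
        unregister_sighandler()             # restore the default signal handlers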
- - Keys can be loaded from AnyStr (str, bytes, bytearray). Keys are returned as bytes string. Key files are binary. +Note: + For convenience and consistency with previous versions of this module, Encryption/Signing functions + (b64, hex, and .encrypt()) accept bytes-like objects (bytes, bytearray) and Unicode strings + encoded in utf-8 (defaults.BINARY_ENCODING_CRYPTO). B64 and hex Decryption functions, consistent with + Python's binascii.a2b_* functions, accept bytes and Unicode strings containing only ASCII characters. + The .decrypt() method only accepts bytes-like objects (such as bytes, bytearray, and other objects that + support the buffer protocol). All these functions return bytes. + Keys can be loaded from AnyStr (str, bytes, bytearray). Keys are returned as bytes strings. Key files are binary. """ import binascii @@ -29,109 +28,63 @@ def passphrase_callback(v: bool, prompt1: str = "Enter passphrase:", prompt2: str = "Verify passphrase:"): - # Example callback (uncomment for manual testing) - # str3 = prompt1 + prompt2 - # return str3 # Optional return + """Placeholder for a passphrase callback function. + + Args: + v (bool): Placeholder argument. + prompt1 (str): Prompt for entering the passphrase. + prompt2 (str): Prompt for verifying the passphrase. + + Returns: + None + """ pass def _default_callback(*args): - """Return a dummy passphrase + """Return a dummy passphrase. - Good for service key processing where human not present. - Used as a callback in the :mod:M2Crypto module: - A Python callable object that is invoked to acquire a passphrase with which to unlock the key. - The default is :func:M2Crypto.util.passphrase_callback :: - def passphrase_callback(v: bool, prompt1: str = 'Enter passphrase:', prompt2: str = 'Verify passphrase:' - ): -> Optional[str] + This function is used as a callback for service key processing where no human interaction is present. + It is used in the M2Crypto module to acquire a passphrase for unlocking the key. Args: - *args: + *args: Variable arguments passed to the callback. Returns: - Optional[str]: str or None - + Optional[str]: Dummy passphrase or None. """ - # TODO: according to the M2Crypto spec this function is expected to return a str (unicode) - # but doing so fails the unit test (test_factory_glideFactoryConfig.py), leaving the bytes now - # maybe the fixture w/ the key should be foxed (fixtures/factory/work-dir/rsa.key/rsa.key.bak) - # return "default" return b"default" class PubCryptoError(Exception): - """Exception masking M2Crypto exceptions, - to ease error handling in modules importing pubCrypto - """ + """Custom exception class to mask M2Crypto exceptions. - def __init__(self, msg): - Exception.__init__(self, msg) + This exception is used to ease error handling in modules importing pubCrypto. + Args: + msg (str): The error message. 
+ """ -###################### -# -# Available paddings: -# M2Crypto.RSA.no_padding -# M2Crypto.RSA.pkcs1_padding -# M2Crypto.RSA.sslv23_padding -# M2Crypto.RSA.pkhas1_oaep_padding -# -# Available sign algos: -# 'sha1' -# 'sha224' -# 'sha256', -# 'ripemd160' -# 'md5' -# -# Available ciphers: -# too many to list them all -# try 'man enc' -# a few of them are -# 'aes_128_cbc' -# 'aes_128_ofb -# 'aes_256_cbc' -# 'aes_256_cfb' -# 'bf_cbc' -# 'des3' -# -###################### + def __init__(self, msg): + super().__init__(msg) -########################################################################## -# Public part of the RSA key class PubRSAKey: - """Public part of the RSA key""" + """Class representing the public part of an RSA key.""" def __init__( self, key_str=None, key_fname=None, encryption_padding=M2Crypto.RSA.pkcs1_oaep_padding, sign_algo="sha256" ): - """Constructor for RSA public key - - One and only one of the two key_str or key_fname must be defined (not None) - - Available paddings: - M2Crypto.RSA.no_padding - M2Crypto.RSA.pkcs1_padding - M2Crypto.RSA.sslv23_padding - M2Crypto.RSA.pkhas1_oaep_padding - - Available sign algos: - 'sha1', 'sha224', 'sha256', 'ripemd160', 'md5' - - Available ciphers, too many to list them all, try `man enc` a few of them are: - 'aes_128_cbc' - 'aes_128_ofb - 'aes_256_cbc' - 'aes_256_cfb' - 'bf_cbc' - 'des3' + """Initialize a PubRSAKey instance. Args: - key_str (str/bytes): string w/ base64 encoded key - Must be bytes-like object or ASCII string, like base64 inputs - key_fname (str): key file path - encryption_padding: - sign_algo (str): valid signing algorithm (default: 'sha256') + key_str (str | bytes, optional): Base64 encoded key as a string or bytes. + key_fname (str, optional): Path to the key file. + encryption_padding (int): Padding scheme for encryption. Defaults to M2Crypto.RSA.pkcs1_oaep_padding. + sign_algo (str): Signing algorithm to use. Defaults to 'sha256'. + + Raises: + M2Crypto.RSA.RSAError: If there is an error loading the key. """ self.rsa_key = None self.has_private = False @@ -141,29 +94,21 @@ def __init__( try: self.load(key_str, key_fname) except M2Crypto.RSA.RSAError as e: - # Put some additional information in the exception object to be printed later on - # This helps operator understand which file might be corrupted so that they can try to delete it e.key_fname = key_fname e.cwd = os.getcwd() - raise e from e # Need to raise a new exception to have the modified values (only raise keeps the original) - return - - ########################################### - # Load key functions + raise e from e def load(self, key_str=None, key_fname=None): - """Load key from a string or a file - - Only one of the two can be defined (not None) - Load the key into self.rsa_key + """Load an RSA key from a string or file. Args: - key_str (str/bytes): string w/ base64 encoded key - Must be bytes-like object or ASCII string, like base64 inputs - key_fname (str): file name + key_str (str | bytes, optional): Base64 encoded key as a string or bytes. + key_fname (str, optional): Path to the key file. Raises: - ValueError: if both key_str and key_fname are defined + ValueError: If both key_str and key_fname are defined. + PubCryptoError: If there is an error loading the key from the string. + M2Crypto.BIO.BIOError: If there is an error opening the key file. 
""" if key_str is not None: if key_fname is not None: @@ -177,94 +122,72 @@ def load(self, key_str=None, key_fname=None): elif key_fname is not None: bio = M2Crypto.BIO.openfile(key_fname) if bio is None: - # File not found or wrong permissions raise M2Crypto.BIO.BIOError(M2Crypto.Err.get_error()) self._load_from_bio(bio) else: self.rsa_key = None - return - # meant to be internal def _load_from_bio(self, bio): - """Load the key into the object - - Protected, overridden by child classes. Used by load + """Load the key into the object from a BIO. Args: - bio (M2Crypto.BIO.BIO): BIO to retrieve the key from (file or memory buffer) + bio (M2Crypto.BIO.BIO): BIO object to load the key from. """ self.rsa_key = M2Crypto.RSA.load_pub_key_bio(bio) self.has_private = False - return - - ########################################### - # Save key functions def save(self, key_fname): - """Save the key to a file - - The file is binary and is written using M2Crypto.BIO + """Save the RSA key to a file. Args: - key_fname (str): file name - - Returns: + key_fname (str): Path to the file where the key should be saved. + Raises: + Exception: If there is an error saving the key, the file is removed. """ bio = M2Crypto.BIO.openfile(key_fname, "wb") try: - return self._save_to_bio(bio) + self._save_to_bio(bio) except Exception: - # need to remove the file in case of error bio.close() del bio os.unlink(key_fname) raise - # like save, but return a string def get(self): - """Retrieve the key + """Get the RSA key as bytes. Returns: - bytes: key - + bytes: The RSA key as bytes. """ bio = M2Crypto.BIO.MemoryBuffer() self._save_to_bio(bio) return bio.read() - # meant to be internal def _save_to_bio(self, bio): - """Save the key from the object - - Protected, overridden by child classes. Used by save and get + """Save the RSA key to a BIO object. Args: - bio (M2Crypto.BIO.BIO): BIO object to save the key to (file or memory buffer) + bio (M2Crypto.BIO.BIO): BIO object to save the key to. Returns: - int: status returned by M2Crypto.m2.rsa_write_pub_key - Raises: - KeyError: if the key is not defined + int: Status code returned by M2Crypto. """ if self.rsa_key is None: raise KeyError("No RSA key") - return self.rsa_key.save_pub_key_bio(bio) - ########################################### - # encrypt/verify data inline - def encrypt(self, data): - """Encrypt the data + """Encrypt data using the RSA key. Args: - data (AnyStr): string to encrypt. bytes-like or str. If unicode, - it is encoded using utf-8 before being encrypted. - len(data) must be less than len(key) + data (str | bytes): The data to encrypt. Returns: - bytes: encrypted data + bytes: The encrypted data. + + Raises: + KeyError: If the RSA key is not defined. """ if self.rsa_key is None: raise KeyError("No RSA key") @@ -272,26 +195,39 @@ def encrypt(self, data): return self.rsa_key.public_encrypt(bdata, self.encryption_padding) def encrypt_base64(self, data): - """like encrypt, but base64 encoded""" + """Encrypt data and encode it in base64. + + Args: + data (str | bytes): The data to encrypt. + + Returns: + bytes: The base64-encoded encrypted data. + """ return binascii.b2a_base64(self.encrypt(data)) def encrypt_hex(self, data): - """like encrypt, but hex encoded""" + """Encrypt data and encode it in hexadecimal. + + Args: + data (str | bytes): The data to encrypt. + + Returns: + bytes: The hex-encoded encrypted data. 
+ """ return binascii.b2a_hex(self.encrypt(data)) def verify(self, data, signature): - """Verify that the signature gets you the data + """Verify a signature against the data. Args: - data (AnyStr): string to verify. bytes-like or str. If unicode, - it is encoded using utf-8 before being encrypted. : - signature (bytes): signature to use in the verification + data (str | bytes): The data to verify. + signature (bytes): The signature to verify. Returns: - bool: True if the signature gets you the data + bool: True if the signature is valid, False otherwise. Raises: - KeyError: if the key is not defined + KeyError: If the RSA key is not defined. """ if self.rsa_key is None: raise KeyError("No RSA key") @@ -299,18 +235,32 @@ def verify(self, data, signature): return self.rsa_key.verify(bdata, signature, self.sign_algo) def verify_base64(self, data, signature): - """like verify, but the signature is base64 encoded""" + """Verify a base64-encoded signature against the data. + + Args: + data (str | bytes): The data to verify. + signature (bytes): The base64-encoded signature to verify. + + Returns: + bool: True if the signature is valid, False otherwise. + """ return self.verify(data, binascii.a2b_base64(signature)) def verify_hex(self, data, signature): - """like verify, but the signature is hex encoded""" + """Verify a hex-encoded signature against the data. + + Args: + data (str | bytes): The data to verify. + signature (bytes): The hex-encoded signature to verify. + + Returns: + bool: True if the signature is valid, False otherwise. + """ return self.verify(data, binascii.a2b_hex(signature)) -########################################################################## -# Public and private part of the RSA key class RSAKey(PubRSAKey): - """Public and private part of the RSA key""" + """Class representing both the public and private parts of an RSA key.""" def __init__( self, @@ -321,19 +271,28 @@ def __init__( encryption_padding=M2Crypto.RSA.pkcs1_oaep_padding, sign_algo="sha256", ): + """Initialize an RSAKey instance. + + Args: + key_str (str | bytes, optional): Base64 encoded key as a string or bytes. + key_fname (str, optional): Path to the key file. + private_cipher (str): Cipher to use for private key encryption. Defaults to 'aes_256_cbc'. + private_callback (callable): Callback function for the private key passphrase. + encryption_padding (int): Padding scheme for encryption. Defaults to M2Crypto.RSA.pkcs1_oaep_padding. + sign_algo (str): Signing algorithm to use. Defaults to 'sha256'. + """ self.private_cipher = private_cipher self.private_callback = private_callback - PubRSAKey.__init__(self, key_str, key_fname, encryption_padding, sign_algo) - return + super().__init__(key_str, key_fname, encryption_padding, sign_algo) - ########################################### - # Downgrade to PubRSAKey def PubRSAKey(self): - """Return the public part only. Downgrade to PubRSAKey + """Return the public part of the RSA key. Returns: - PubRSAKey: an object w/ only the public part of the key + PubRSAKey: An instance of PubRSAKey containing only the public part of the RSA key. + Raises: + KeyError: If the RSA key is not defined. 
""" if self.rsa_key is None: raise KeyError("No RSA key") @@ -343,92 +302,96 @@ def PubRSAKey(self): public_key = bio.read() return PubRSAKey(key_str=public_key, encryption_padding=self.encryption_padding, sign_algo=self.sign_algo) - ########################################### - # Load key functions - def _load_from_bio(self, bio): - """Load the key into the object - - Internal, overrides the parent _load_from_bio. Used by load + """Load the RSA key from a BIO object. Args: - bio (M2Crypto.BIO.BIO): + bio (M2Crypto.BIO.BIO): BIO object to load the key from. """ self.rsa_key = M2Crypto.RSA.load_key_bio(bio, self.private_callback) self.has_private = True - return - - ########################################### - # Save key functions def _save_to_bio(self, bio): - """Save the key from the object - - Protected, overridden by child classes. Used by save and get + """Save the RSA key to a BIO object. Args: - bio (M2Crypto.BIO.BIO): BIO to save the key into (file or memory buffer) + bio (M2Crypto.BIO.BIO): BIO object to save the key to. Returns: + int: Status code returned by M2Crypto. Raises: - KeyError: if the key is not defined + KeyError: If the RSA key is not defined. """ if self.rsa_key is None: raise KeyError("No RSA key") return self.rsa_key.save_key_bio(bio, self.private_cipher, self.private_callback) - ########################################### - # generate key function def new(self, key_length=None, exponent=65537): - """Refresh/Generate a new key and store it in the object + """Generate a new RSA key. Args: - key_length (int/None): if no key_length provided, use the length of the existing one - exponent (int): exponent + key_length (int, optional): Length of the RSA key in bits. If None, the length of the existing key is used. + exponent (int): Public exponent value. Defaults to 65537. + + Raises: + KeyError: If no key length is provided and there is no existing key. """ if key_length is None: if self.rsa_key is None: raise KeyError("No RSA key and no key length provided") key_length = len(self.rsa_key) self.rsa_key = M2Crypto.RSA.gen_key(key_length, exponent) - return - - ########################################### def decrypt(self, data): - """Decrypt data inline + """Decrypt data using the RSA key. Args: - data (bytes): data to decrypt + data (bytes): The data to decrypt. Returns: - bytes: decrypted string + bytes: The decrypted data. Raises: - KeyError: if the key is not defined + KeyError: If the RSA key is not defined. """ if self.rsa_key is None: raise KeyError("No RSA key") return self.rsa_key.private_decrypt(data, self.encryption_padding) def decrypt_base64(self, data): - """like decrypt, but base64 encoded""" + """Decrypt base64-encoded data. + + Args: + data (bytes): The base64-encoded data to decrypt. + + Returns: + bytes: The decrypted data. + """ return self.decrypt(binascii.a2b_base64(data)) def decrypt_hex(self, data): - """like decrypt, but hex encoded""" + """Decrypt hex-encoded data. + + Args: + data (bytes): The hex-encoded data to decrypt. + + Returns: + bytes: The decrypted data. + """ return self.decrypt(binascii.a2b_hex(data)) def sign(self, data): - """Sign data inline. Same as private_encrypt + """Sign data using the RSA key. Args: - data (AnyStr): string to encrypt. If unicode, it is encoded using utf-8 before being encrypted. - len(data) must be less than len(key) + data (str | bytes): The data to sign. Returns: - bytes: encrypted data + bytes: The signed data. + + Raises: + KeyError: If the RSA key is not defined. 
""" if self.rsa_key is None: raise KeyError("No RSA key") @@ -436,11 +399,25 @@ def sign(self, data): return self.rsa_key.sign(bdata, self.sign_algo) def sign_base64(self, data): - """like sign, but base64 encoded""" + """Sign data and encode it in base64. + + Args: + data (str | bytes): The data to sign. + + Returns: + bytes: The base64-encoded signed data. + """ return binascii.b2a_base64(self.sign(data)) def sign_hex(self, data): - """like sign, but hex encoded""" + """Sign data and encode it in hexadecimal. + + Args: + data (str | bytes): The data to sign. + + Returns: + bytes: The hex-encoded signed data. + """ return binascii.b2a_hex(self.sign(data)) diff --git a/lib/rrdSupport.py b/lib/rrdSupport.py index 4addf2eb8..2600ea6f7 100644 --- a/lib/rrdSupport.py +++ b/lib/rrdSupport.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: 2009 Fermi Research Alliance, LLC # SPDX-License-Identifier: Apache-2.0 -"""This module implements the basic functions needed to interface to rrdtool +"""This module implements the basic functions needed to interface with rrdtool. """ import os @@ -14,86 +14,78 @@ try: import rrdtool # pylint: disable=import-error except ImportError: - # Will use the binary tools if the Python library is not available - pass + pass # Will use the binary tools if the Python library is not available class BaseRRDSupport: - ############################################################# + """Base class providing common support for working with RRD files.""" + def __init__(self, rrd_obj): + """Initialize the BaseRRDSupport class. + + Args: + rrd_obj (object): The RRD object, either from the rrdtool module or a command-line wrapper. + """ self.rrd_obj = rrd_obj def isDummy(self): + """Check if the RRD object is a dummy (None). + + Returns: + bool: True if the RRD object is None, False otherwise. + """ return self.rrd_obj is None - ############################################################# - # The default will do nothing - # Children should overwrite it, if needed def get_disk_lock(self, fname): + """Get a disk lock for the specified file. + + This is a no-op in the base class. It should be overridden in child classes if needed. + + Args: + fname (str): The filename to lock. + + Returns: + DummyDiskLock: A dummy lock object. + """ return dummy_disk_lock() - ############################################################# - # The default will do nothing - # Children should overwrite it, if needed def get_graph_lock(self, fname): + """Get a graph lock for the specified file. + + This is a no-op in the base class. It should be overridden in child classes if needed. + + Args: + fname (str): The filename to lock. + + Returns: + DummyDiskLock: A dummy lock object. 
+ """ return dummy_disk_lock() - ############################################################# def create_rrd(self, rrdfname, rrd_step, rrd_archives, rrd_ds): - """ - Create a new RRD archive - - Arguments: - rrdfname - File path name of the RRD archive - rrd_step - base interval in seconds - rrd_archives - list of tuples, each containing the following fileds (in order) - CF - consolidation function (usually AVERAGE) - xff - xfiles factor (fraction that can be unknown) - steps - how many of these primary data points are used to build a consolidated data point - rows - how many generations of data values are kept - rrd_ds - a tuple containing the following fields (in order) - ds-name - attribute name - DST - Data Source Type (usually GAUGE) - heartbeat - the maximum number of seconds that may pass between two updates before it becomes unknown - min - min value - max - max value - - For more details see - http://oss.oetiker.ch/rrdtool/doc/rrdcreate.en.html + """Create a new RRD archive. + + Args: + rrdfname (str): The file path name of the RRD archive. + rrd_step (int): Base interval in seconds. + rrd_archives (list): List of tuples containing archive settings. + rrd_ds (tuple): Tuple containing data source settings. """ self.create_rrd_multi(rrdfname, rrd_step, rrd_archives, (rrd_ds,)) - return - ############################################################# def create_rrd_multi(self, rrdfname, rrd_step, rrd_archives, rrd_ds_arr): - """ - Create a new RRD archive - - Arguments: - rrdfname - File path name of the RRD archive - rrd_step - base interval in seconds - rrd_archives - list of tuples, each containing the following fileds (in order) - CF - consolidation function (usually AVERAGE) - xff - xfiles factor (fraction that can be unknown) - steps - how many of these primary data points are used to build a consolidated data point - rows - how many generations of data values are kept - rrd_ds_arr - list of tuples, each containing the following fields (in order) - ds-name - attribute name - DST - Data Source Type (usually GAUGE) - heartbeat - the maximum number of seconds that may pass between two updates before it becomes unknown - min - min value - max - max value - - For more details see - http://oss.oetiker.ch/rrdtool/doc/rrdcreate.en.html + """Create a new RRD archive with multiple data sources. + + Args: + rrdfname (str): The file path name of the RRD archive. + rrd_step (int): Base interval in seconds. + rrd_archives (list): List of tuples containing archive settings. + rrd_ds_arr (list): List of tuples containing data source settings. """ if self.rrd_obj is None: return # nothing to do in this case - # make the start time to be aligned on the rrd_step boundary - # This is needed for optimal resoultion selection start_time = (int(time.time() - 1) / rrd_step) * rrd_step - # print (rrdfname,start_time,rrd_step)+rrd_ds args = [str(rrdfname), "-b", "%li" % start_time, "-s", "%i" % rrd_step] for rrd_ds in rrd_ds_arr: args.append("DS:%s:%s:%i:%s:%s" % rrd_ds) @@ -105,21 +97,17 @@ def create_rrd_multi(self, rrdfname, rrd_step, rrd_archives, rrd_ds_arr): self.rrd_obj.create(*args) finally: lck.close() - return - ############################################################# def update_rrd(self, rrdfname, time, val): - """ - Create an RRD archive with a new value + """Update an RRD archive with a new value. - Arguments: - rrdfname - File path name of the RRD archive - time - When was the value taken - val - What vas the value + Args: + rrdfname (str): The file path name of the RRD archive. 
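# Illustrative sketch (not part of the patch itself): the tuple layouts expected by create_rrd()
# and create_rrd_multi(), reconstructed from the more detailed docstrings being replaced above.
# The file name and numeric values are made up for the example.
from glideinwms.lib import rrdSupport

rrd = rrdSupport.rrdSupport()
rrd_step = 300                                 # base interval in seconds
rrd_archives = [
    ("AVERAGE", 0.8, 1, 740),                  # (CF, xff, steps, rows)
    ("AVERAGE", 0.92, 12, 740),
]
rrd_ds = ("Status", "GAUGE", 1800, "0", "U")   # (ds-name, DST, heartbeat, min, max)
rrd.create_rrd("status.rrd", rrd_step, rrd_archives, rrd_ds)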
+ time (int): The time at which the value was taken. + val (str): The value to update. """ if self.rrd_obj is None: - # nothing to do in this case - return + return # nothing to do in this case lck = self.get_disk_lock(rrdfname) try: @@ -127,32 +115,26 @@ def update_rrd(self, rrdfname, time, val): finally: lck.close() - return - - ############################################################# def update_rrd_multi(self, rrdfname, time, val_dict): - """ - Create an RRD archive with a set of values (possibly all of the supported) + """Update an RRD archive with multiple values. - Arguments: - rrdfname - File path name of the RRD archive - time - When was the value taken - val_dict - What was the value + Args: + rrdfname (str): The file path name of the RRD archive. + time (int): The time at which the values were taken. + val_dict (dict): A dictionary of data source names to values. """ if self.rrd_obj is None: return # nothing to do in this case args = [str(rrdfname)] - ds_names = sorted(val_dict.keys()) - ds_names_real = [] ds_vals = [] - for ds_name in ds_names: - if val_dict[ds_name] is not None: - ds_vals.append("%s" % val_dict[ds_name]) + for ds_name, ds_val in val_dict.items(): + if ds_val is not None: + ds_vals.append("%s" % ds_val) ds_names_real.append(ds_name) - if len(ds_names_real) == 0: + if not ds_names_real: return args.append("-t") @@ -161,14 +143,10 @@ def update_rrd_multi(self, rrdfname, time, val_dict): lck = self.get_disk_lock(rrdfname) try: - # print args self.rrd_obj.update(*args) finally: lck.close() - return - - ############################################################# def rrd2graph( self, fname, @@ -185,45 +163,33 @@ def rrd2graph( trend=None, img_format="PNG", ): - """ - Create a graph file out of a set of RRD files - - Arguments: - fname - File path name of the graph file - rrd_step - Which step should I use in the RRD files - ds_name - Which attribute should I use in the RRD files - ds_type - Which type should I use in the RRD files - start,end - Time points in utime format - width,height - Size of the graph - title - Title to put in the graph - rrd_files - list of RRD files, each being a tuple of (in order) - rrd_id - logical name of the RRD file (will be the graph label) - rrd_fname - name of the RRD file - graph_type - Graph type (LINE, STACK, AREA) - grpah_color - Graph color in rrdtool format - cdef_arr - list of derived RRD values - if present, only the cdefs will be plotted - each elsement is a tuple of (in order) - rrd_id - logical name of the RRD file (will be the graph label) - cdef_formula - Derived formula in rrdtool format - graph_type - Graph type (LINE, STACK, AREA) - grpah_color - Graph color in rrdtool format - trend - Trend value in seconds (if desired, None else) - - For more details see - http://oss.oetiker.ch/rrdtool/doc/rrdcreate.en.html + """Create a graph file from a set of RRD files. + + Args: + fname (str): The file path name of the graph file. + rrd_step (int): The step to use in the RRD files. + ds_name (str): The attribute to use in the RRD files. + ds_type (str): The type of data source to use in the RRD files. + start (int): The start time in Unix time format. + end (int): The end time in Unix time format. + width (int): The width of the graph. + height (int): The height of the graph. + title (str): The title of the graph. + rrd_files (list): List of tuples, each containing RRD file information. + cdef_arr (list, optional): List of derived RRD values. Defaults to None. + trend (int, optional): Trend value in seconds. 
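# Illustrative sketch (not part of the patch itself): feeding values to update_rrd() and
# update_rrd_multi() as documented above. The data source names are made up; entries whose
# value is None are skipped, as in the loop shown in this hunk.
import time

from glideinwms.lib import rrdSupport

rrd = rrdSupport.rrdSupport()
now = int(time.time())
rrd.update_rrd("status.rrd", now, 42)                                    # single data source
rrd.update_rrd_multi("status.rrd", now, {"Idle": 5, "Running": 12, "Held": None})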
Defaults to None. + img_format (str): The image format of the graph file. Defaults to "PNG". """ if self.rrd_obj is None: return # nothing to do in this case - multi_rrd_files = [] - for rrd_file in rrd_files: - multi_rrd_files.append((rrd_file[0], rrd_file[1], ds_name, ds_type, rrd_file[2], rrd_file[3])) - return self.rrd2graph_multi( + multi_rrd_files = [ + (rrd_file[0], rrd_file[1], ds_name, ds_type, rrd_file[2], rrd_file[3]) for rrd_file in rrd_files + ] + self.rrd2graph_multi( fname, rrd_step, start, end, width, height, title, multi_rrd_files, cdef_arr, trend, img_format ) - ############################################################# def rrd2graph_now( self, fname, @@ -239,73 +205,46 @@ def rrd2graph_now( trend=None, img_format="PNG", ): - """ - Create a graph file out of a set of RRD files - - Arguments: - fname - File path name of the graph file - rrd_step - Which step should I use in the RRD files - ds_name - Which attribute should I use in the RRD files - ds_type - Which type should I use in the RRD files - period - start=now-period, end=now - width,height - Size of the graph - title - Title to put in the graph - rrd_files - list of RRD files, each being a tuple of (in order) - rrd_id - logical name of the RRD file (will be the graph label) - rrd_fname - name of the RRD file - graph_type - Graph type (LINE, STACK, AREA) - grpah_color - Graph color in rrdtool format - cdef_arr - list of derived RRD values - if present, only the cdefs will be plotted - each elsement is a tuple of (in order) - rrd_id - logical name of the RRD file (will be the graph label) - cdef_formula - Derived formula in rrdtool format - graph_type - Graph type (LINE, STACK, AREA) - grpah_color - Graph color in rrdtool format - trend - Trend value in seconds (if desired, None else) - - For more details see - http://oss.oetiker.ch/rrdtool/doc/rrdcreate.en.html + """Create a graph file from a set of RRD files for the current time. + + Args: + fname (str): The file path name of the graph file. + rrd_step (int): The step to use in the RRD files. + ds_name (str): The attribute to use in the RRD files. + ds_type (str): The type of data source to use in the RRD files. + period (int): The time period for the graph. + width (int): The width of the graph. + height (int): The height of the graph. + title (str): The title of the graph. + rrd_files (list): List of tuples, each containing RRD file information. + cdef_arr (list, optional): List of derived RRD values. Defaults to None. + trend (int, optional): Trend value in seconds. Defaults to None. + img_format (str): The image format of the graph file. Defaults to "PNG". 
""" now = int(time.time()) start = ((now - period) / rrd_step) * rrd_step end = ((now - 1) / rrd_step) * rrd_step - return self.rrd2graph( + self.rrd2graph( fname, rrd_step, ds_name, ds_type, start, end, width, height, title, rrd_files, cdef_arr, trend, img_format ) - ############################################################# def rrd2graph_multi( self, fname, rrd_step, start, end, width, height, title, rrd_files, cdef_arr=None, trend=None, img_format="PNG" ): - """ - Create a graph file out of a set of RRD files - - Arguments: - fname - File path name of the graph file - rrd_step - Which step should I use in the RRD files - start,end - Time points in utime format - width,height - Size of the graph - title - Title to put in the graph - rrd_files - list of RRD files, each being a tuple of (in order) - rrd_id - logical name of the RRD file (will be the graph label) - rrd_fname - name of the RRD file - ds_name - Which attribute should I use in the RRD files - ds_type - Which type should I use in the RRD files - graph_type - Graph type (LINE, STACK, AREA) - graph_color - Graph color in rrdtool format - cdef_arr - list of derived RRD values - if present, only the cdefs will be plotted - each elsement is a tuple of (in order) - rrd_id - logical name of the RRD file (will be the graph label) - cdef_formula - Derived formula in rrdtool format - graph_type - Graph type (LINE, STACK, AREA) - grpah_color - Graph color in rrdtool format - trend - Trend value in seconds (if desired, None else) - img_format - format of the graph file (default PNG) - - For more details see - http://oss.oetiker.ch/rrdtool/doc/rrdcreate.en.html + """Create a graph file from a set of RRD files with multiple data sources. + + Args: + fname (str): The file path name of the graph file. + rrd_step (int): The step to use in the RRD files. + start (int): The start time in Unix time format. + end (int): The end time in Unix time format. + width (int): The width of the graph. + height (int): The height of the graph. + title (str): The title of the graph. + rrd_files (list): List of tuples, each containing RRD file information. + cdef_arr (list, optional): List of derived RRD values. Defaults to None. + trend (int, optional): Trend value in seconds. Defaults to None. + img_format (str): The image format of the graph file. Defaults to "PNG". 
""" if self.rrd_obj is None: return # nothing to do in this case @@ -335,29 +274,22 @@ def rrd2graph_multi( ds_name = rrd_file[2] ds_type = rrd_file[3] if trend is None: - args.append(str(f"DEF:{ds_id}={ds_fname}:{ds_name}:{ds_type}")) + args.append(f"DEF:{ds_id}={ds_fname}:{ds_name}:{ds_type}") else: - args.append(str(f"DEF:{ds_id}_inst={ds_fname}:{ds_name}:{ds_type}")) - args.append(str("CDEF:%s=%s_inst,%i,TREND" % (ds_id, ds_id, trend))) + args.append(f"DEF:{ds_id}_inst={ds_fname}:{ds_name}:{ds_type}") + args.append(f"CDEF:{ds_id}={ds_id}_inst,{trend},TREND") plot_arr = rrd_files if cdef_arr is not None: - # plot the cdefs not the files themselves, when we have them plot_arr = cdef_arr - for cdef_el in cdef_arr: ds_id = cdef_el[0] cdef_formula = cdef_el[1] - ds_graph_type = rrd_file[2] - ds_color = rrd_file[3] - args.append(str(f"CDEF:{ds_id}={cdef_formula}")) + args.append(f"CDEF:{ds_id}={cdef_formula}") else: - plot_arr = [] - for rrd_file in rrd_files: - plot_arr.append((rrd_file[0], None, rrd_file[4], rrd_file[5])) + plot_arr = [(rrd_file[0], None, rrd_file[4], rrd_file[5]) for rrd_file in rrd_files] if plot_arr[0][2] == "STACK": - # add an invisible baseline to stack upon args.append("AREA:0") for plot_el in plot_arr: @@ -377,79 +309,51 @@ def rrd2graph_multi( except Exception: print("Failed graph: %s" % str(args)) - return args - - ############################################################# def rrd2graph_multi_now( self, fname, rrd_step, period, width, height, title, rrd_files, cdef_arr=None, trend=None, img_format="PNG" ): - """ - Create a graph file out of a set of RRD files - - Arguments: - fname - File path name of the graph file - rrd_step - Which step should I use in the RRD files - period - start=now-period, end=now - width,height - Size of the graph - title - Title to put in the graph - rrd_files - list of RRD files, each being a tuple of (in order) - rrd_id - logical name of the RRD file (will be the graph label) - rrd_fname - name of the RRD file - ds_name - Which attribute should I use in the RRD files - ds_type - Which type should I use in the RRD files - graph_type - Graph type (LINE, STACK, AREA) - graph_color - Graph color in rrdtool format - cdef_arr - list of derived RRD values - if present, only the cdefs will be plotted - each elsement is a tuple of (in order) - rrd_id - logical name of the RRD file (will be the graph label) - cdef_formula - Derived formula in rrdtool format - graph_type - Graph type (LINE, STACK, AREA) - grpah_color - Graph color in rrdtool format - trend - Trend value in seconds (if desired, None else) - img_format - format of the graph file (default PNG) - - For more details see - http://oss.oetiker.ch/rrdtool/doc/rrdcreate.en.html + """Create a graph file from a set of RRD files for the current time with multiple data sources. + + Args: + fname (str): The file path name of the graph file. + rrd_step (int): The step to use in the RRD files. + period (int): The time period for the graph. + width (int): The width of the graph. + height (int): The height of the graph. + title (str): The title of the graph. + rrd_files (list): List of tuples, each containing RRD file information. + cdef_arr (list, optional): List of derived RRD values. Defaults to None. + trend (int, optional): Trend value in seconds. Defaults to None. + img_format (str): The image format of the graph file. Defaults to "PNG". 
""" now = int(time.time()) start = ((now - period) / rrd_step) * rrd_step end = ((now - 1) / rrd_step) * rrd_step - return self.rrd2graph_multi( - fname, rrd_step, start, end, width, height, title, rrd_files, cdef_arr, trend, img_format - ) + self.rrd2graph_multi(fname, rrd_step, start, end, width, height, title, rrd_files, cdef_arr, trend, img_format) - ################################################### def fetch_rrd(self, filename, CF, resolution=None, start=None, end=None, daemon=None): - """ - Fetch will analyze the RRD and try to retrieve the data in the - resolution requested. - - Arguments: - filename -the name of the RRD you want to fetch data from - CF -the consolidation function that is applied to the data - you want to fetch (AVERAGE, MIN, MAX, LAST) - resolution -the interval you want your values to have - (default 300 sec) - start -start of the time series (default end - 1day) - end -end of the time series (default now) - daemon -Address of the rrdcached daemon. If specified, a flush - command is sent to the server before reading the RRD - files. This allows rrdtool to return fresh data even - if the daemon is configured to cache values for a long - time. - - For more details see - http://oss.oetiker.ch/rrdtool/doc/rrdcreate.en.html + """Fetch data from an RRD file. + + Args: + filename (str): The name of the RRD file to fetch data from. + CF (str): The consolidation function to apply (AVERAGE, MIN, MAX, LAST). + resolution (int, optional): The resolution of the data in seconds. Defaults to 300. + start (int, optional): The start of the time series in Unix time. Defaults to end - 1 day. + end (int, optional): The end of the time series in Unix time. Defaults to now. + daemon (str, optional): The address of the rrdcached daemon. Defaults to None. + + Returns: + tuple: A tuple containing time info, headers, and data values. + + Raises: + RuntimeError: If the consolidation function is invalid or if the RRD file does not exist. """ if self.rrd_obj is None: return # nothing to do in this case - if CF in ("AVERAGE", "MIN", "MAX", "LAST"): - consolFunc = str(CF) - else: + if CF not in ("AVERAGE", "MIN", "MAX", "LAST"): raise RuntimeError("Invalid consolidation function %s" % CF) - args = [str(filename), consolFunc] + args = [str(filename), str(CF)] if resolution is not None: args.append("-r") args.append(str(resolution)) @@ -472,22 +376,20 @@ def fetch_rrd(self, filename, CF, resolution=None, start=None, end=None, daemon= raise RuntimeError(f"RRD file '{filename}' does not exist. Failing fetch_rrd.") def verify_rrd(self, filename, expected_dict): - """ - Verifies that an rrd matches a dictionary of datastores. - This will return a tuple of arrays ([missing],[extra]) attributes + """Verify that an RRD file matches a dictionary of expected data sources. - @param filename: filename of the rrd to verify - @param expected_dict: dictionary of expected values - @return: A two-tuple of arrays ([missing attrs],[extra attrs]) + Args: + filename (str): The filename of the RRD to verify. + expected_dict (dict): Dictionary of expected data sources. + Returns: + tuple: A tuple containing lists of missing and extra attributes. 
""" rrd_info = self.rrd_obj.info(filename) rrd_dict = {} for key in list(rrd_info.keys()): - # rrdtool 1.3 if key[:3] == "ds[": rrd_dict[key[3:].split("]")[0]] = None - # rrdtool 1.2 if key == "ds": for dskey in list(rrd_info[key].keys()): rrd_dict[dskey] = None @@ -502,23 +404,27 @@ def verify_rrd(self, filename, expected_dict): return (missing, extra) -# This class uses the rrdtool module for rrd_obj class ModuleRRDSupport(BaseRRDSupport): + """Class using the rrdtool Python module for RRD operations.""" + def __init__(self): - BaseRRDSupport.__init__(self, rrdtool) + """Initialize the ModuleRRDSupport class.""" + super().__init__(rrdtool) -# This class uses rrdtool cmdline for rrd_obj class ExeRRDSupport(BaseRRDSupport): + """Class using the rrdtool command-line executable for RRD operations.""" + def __init__(self): - BaseRRDSupport.__init__(self, rrdtool_exe()) + """Initialize the ExeRRDSupport class.""" + super().__init__(rrdtool_exe()) -# This class tries to use the rrdtool module for rrd_obj -# then tries the rrdtool cmdline -# will use None if needed class rrdSupport(BaseRRDSupport): + """Class that tries to use the rrdtool Python module, falls back to the command-line tool.""" + def __init__(self): + """Initialize the rrdSupport class.""" try: rrd_obj = rrdtool except NameError: @@ -526,56 +432,72 @@ def __init__(self): rrd_obj = rrdtool_exe() except Exception: rrd_obj = None - BaseRRDSupport.__init__(self, rrd_obj) - - -################################################################## -# INTERNAL, do not use directly -################################################################## + super().__init__(rrd_obj) class DummyDiskLock: - """Dummy, do nothing. Used just to get a object""" + """Dummy lock class that does nothing, used as a placeholder.""" def close(self): + """Close the dummy lock.""" return def dummy_disk_lock(): + """Return a dummy disk lock. + + Returns: + DummyDiskLock: A dummy lock object. + """ return DummyDiskLock() -################################# def string_quote_join(arglist): - l2 = [] - for e in arglist: - l2.append('"%s"' % e) - return " ".join(l2) + """Join a list of arguments with quotes. + Args: + arglist (list): List of arguments to join. -class rrdtool_exe: - """This class is a wrapper around the rrdtool client (binary) and - is used in place of the rrdtool python module, if that one is not available - - It provides also extra functions: - dump: returns an array of lines with the content instead of saving the RRD in an XML file - restore: allows the restore of a DB + Returns: + str: Joined arguments as a single string. """ + return " ".join(f'"{e}"' for e in arglist) + + +class rrdtool_exe: + """Wrapper class around the rrdtool command-line client.""" def __init__(self): + """Initialize the rrdtool_exe class.""" self.rrd_bin = (subprocessSupport.iexe_cmd("which rrdtool").split("\n")[0]).strip() def create(self, *args): + """Create a new RRD file. + + Args: + *args: Arguments for the rrdtool create command. + """ cmdline = f"{self.rrd_bin} create {string_quote_join(args)}" - outstr = subprocessSupport.iexe_cmd(cmdline) # noqa: F841 - return + subprocessSupport.iexe_cmd(cmdline) def update(self, *args): + """Update an RRD file with new data. + + Args: + *args: Arguments for the rrdtool update command. + """ cmdline = f"{self.rrd_bin} update {string_quote_join(args)}" - outstr = subprocessSupport.iexe_cmd(cmdline) # noqa: F841 - return + subprocessSupport.iexe_cmd(cmdline) def info(self, *args): + """Get information about an RRD file. 
+ + Args: + *args: Arguments for the rrdtool info command. + + Returns: + dict: Dictionary of RRD file information. + """ cmdline = f"{self.rrd_bin} info {string_quote_join(args)}" outstr = subprocessSupport.iexe_cmd(cmdline).split("\n") outarr = {} @@ -586,35 +508,44 @@ def info(self, *args): return outarr def dump(self, *args): - """Run rrd_tool dump - - Input is usually just the file name. - Output is a list of lines, as returned from rrdtool dump. - This is different from the `dump` method provided by the `rrdtool` package (Python binding) - which outputs to a file or stdout + """Dump the contents of an RRD file. Args: - *args: rrdtool dump arguments, joined in single string for the command line + *args: Arguments for the rrdtool dump command. Returns: - str: multi-line string, output of rrd dump - + list: List of lines from the rrdtool dump output. """ cmdline = f"{self.rrd_bin} dump {string_quote_join(args)}" - outstr = subprocessSupport.iexe_cmd(cmdline).split("\n") - return outstr + return subprocessSupport.iexe_cmd(cmdline).split("\n") def restore(self, *args): + """Restore an RRD file from a dump. + + Args: + *args: Arguments for the rrdtool restore command. + """ cmdline = f"{self.rrd_bin} restore {string_quote_join(args)}" - outstr = subprocessSupport.iexe_cmd(cmdline) # noqa: F841 - return + subprocessSupport.iexe_cmd(cmdline) def graph(self, *args): + """Create a graph from RRD data. + + Args: + *args: Arguments for the rrdtool graph command. + """ cmdline = f"{self.rrd_bin} graph {string_quote_join(args)}" - outstr = subprocessSupport.iexe_cmd(cmdline) # noqa: F841 - return + subprocessSupport.iexe_cmd(cmdline) def fetch(self, *args): + """Fetch data from an RRD file. + + Args: + *args: Arguments for the rrdtool fetch command. + + Returns: + tuple: A tuple containing time info, headers, and data values. + """ cmdline = f"{self.rrd_bin} fetch {string_quote_join(args)}" outstr = subprocessSupport.iexe_cmd(cmdline).split("\n") headers = tuple(outstr.pop(0).split()) @@ -627,114 +558,115 @@ def fetch(self, *args): ftime = int(outstr[1].split(":")[0]) - tstep ltime = int(outstr[-2].split(":")[0]) times = (ftime, ltime, tstep) - outtup = (times, headers, lines) - return outtup + return (times, headers, lines) def addDataStore(filenamein, filenameout, attrlist): - """Add a list of data stores to a rrd export file - This will essentially add attributes to the end of a rrd row + """Add a list of data stores to an RRD export file. - @param filenamein: filename path of a rrd exported with rrdtool dump - @param filenameout: filename path of output xml with datastores added - @param attrlist: array of datastores to add + This function adds attributes to the end of an RRD row. + + Args: + filenamein (str): Filename path of the RRD exported with rrdtool dump. + filenameout (str): Filename path of the output XML with data stores added. + attrlist (list): List of data stores to add. 
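# Illustrative sketch (not part of the patch itself): using the command-line wrapper directly,
# e.g. for dump/restore, which the rrdtool Python binding does not provide here. The file names
# are made up; rrdtool restore typically refuses to overwrite an existing target file.
from glideinwms.lib.rrdSupport import rrdtool_exe

wrapper = rrdtool_exe()                      # locates the binary via "which rrdtool"
xml_lines = wrapper.dump("status.rrd")       # list of XML lines
with open("status.xml", "w") as fd:
    fd.write("\n".join(xml_lines))
wrapper.restore("status.xml", "status_copy.rrd")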
""" - f = open(filenamein) - out = open(filenameout, "w") - parse = False - writenDS = False - for line in f: - if ("" in line) and (not writenDS): - for a in attrlist: - out.write("\n") - out.write(" %s \n" % a) - out.write(" GAUGE \n") - out.write(" 1800 \n") - out.write(" NaN \n") - out.write(" NaN \n") - out.write("\n") - out.write(" UNKN \n") - out.write(" 0 \n") - out.write(" 0 \n") - out.write("\n") - writenDS = True - if "" in line: - for a in attrlist: - out.write(" NaN \n") - out.write(" 0 \n") - if "" in line: - parse = False - if parse: - out.write(line[:-7]) - for a in attrlist: - out.write(" NaN ") - out.write(line[-7:]) - else: - out.write(line) - if "" in line: - parse = True + with open(filenamein) as f, open(filenameout, "w") as out: + parse = False + writenDS = False + for line in f: + if "" in line and not writenDS: + for a in attrlist: + out.write("\n") + out.write(" %s \n" % a) + out.write(" GAUGE \n") + out.write(" 1800 \n") + out.write(" NaN \n") + out.write(" NaN \n") + out.write("\n") + out.write(" UNKN \n") + out.write(" 0 \n") + out.write(" 0 \n") + out.write("\n") + writenDS = True + if "" in line: + for a in attrlist: + out.write(" NaN \n") + out.write(" 0 \n") + if "" in line: + parse = False + if parse: + out.write(line[:-7]) + for a in attrlist: + out.write(" NaN ") + out.write(line[-7:]) + else: + out.write(line) + if "" in line: + parse = True -# Function used by verifyRRD (in Factory and Frontend), invoked during reconfig/upgrade -# No logging available, output is to stdout/err def verifyHelper(filename, data_dict, fix_rrd=False, backup=True): - """Helper function for verifyRRD. - Checks one file, prints out errors. - if fix_rrd, will attempt to dump out rrd to xml, add the missing attributes, then restore. - Original file is backed up with time stamp if backup is True, obliterated otherwise. + """Helper function for verifyRRD to check and optionally fix an RRD file. Args: - filename(str): filename of rrd to check - data_dict(dict): expected dictionary - fix_rrd(bool): if True, will attempt to add missing attrs - backup(bool): if not True skip the backup of original rrd + filename (str): Filename of the RRD to check. + data_dict (dict): Expected dictionary of data sources. + fix_rrd (bool): Whether to attempt to fix missing attributes. Defaults to False. + backup (bool): Whether to back up the original RRD file. Defaults to True. Returns: - bool: True if there were some problem with the RRD file, False if all OK - + bool: True if there were problems with the RRD file, False otherwise. 
""" rrd_problems_found = False if not os.path.exists(filename): print(f"WARNING: {filename} missing, will be created on restart") - return + return rrd_problems_found + rrd_obj = rrdSupport() - (missing, extra) = rrd_obj.verify_rrd(filename, data_dict) + missing, extra = rrd_obj.verify_rrd(filename, data_dict) + for attr in extra: print(f"ERROR: {filename} has extra attribute {attr}") if fix_rrd: print("ERROR: fix_rrd cannot fix extra attributes") + if not fix_rrd: for attr in missing: print(f"ERROR: {filename} missing attribute {attr}") - if len(missing) > 0: + if missing: rrd_problems_found = True - if fix_rrd and (len(missing) > 0): - (f, tempfilename) = tempfile.mkstemp() - (out, tempfilename2) = tempfile.mkstemp() - (restored, restoredfilename) = tempfile.mkstemp() - os.close(out) - os.close(restored) - os.unlink(restoredfilename) - # Use exe version since dump, restore not available in rrdtool - dump_obj = rrdtool_exe() - outstr = dump_obj.dump(filename) - for line in outstr: - # dump is returning an array of strings decoded w/ utf-8 - os.write(f, f"{line}\n".encode(defaults.BINARY_ENCODING_DEFAULT)) - os.close(f) - if backup: - backup_str = str(int(time.time())) + ".backup" - print(f"Fixing {filename}... (backed up to {filename + backup_str})") - # Move file to back up location - shutil.move(filename, filename + backup_str) - else: - print(f"Fixing {filename}... (no back up)") - os.unlink(filename) - addDataStore(tempfilename, tempfilename2, missing) - dump_obj.restore(tempfilename2, restoredfilename) - os.unlink(tempfilename) - os.unlink(tempfilename2) - shutil.move(restoredfilename, filename) - if len(extra) > 0: + + if fix_rrd and missing: + with tempfile.NamedTemporaryFile(delete=False) as temp_file, tempfile.NamedTemporaryFile( + delete=False + ) as temp_file2, tempfile.NamedTemporaryFile(delete=False) as restored_file: + os.close(temp_file.fileno()) + os.close(temp_file2.fileno()) + os.close(restored_file.fileno()) + + dump_obj = rrdtool_exe() + outstr = dump_obj.dump(filename) + with open(temp_file.name, "wb") as f: + for line in outstr: + f.write(f"{line}\n".encode(defaults.BINARY_ENCODING_DEFAULT)) + + if backup: + backup_str = f"{int(time.time())}.backup" + print(f"Fixing {filename}... (backed up to {filename + backup_str})") + shutil.move(filename, filename + backup_str) + else: + print(f"Fixing {filename}... (no backup)") + os.unlink(filename) + + addDataStore(temp_file.name, temp_file2.name, missing) + dump_obj.restore(temp_file2.name, restored_file.name) + shutil.move(restored_file.name, filename) + + os.unlink(temp_file.name) + os.unlink(temp_file2.name) + + if extra: rrd_problems_found = True + return rrd_problems_found diff --git a/lib/servicePerformance.py b/lib/servicePerformance.py index 05af25dbe..517ea56f2 100644 --- a/lib/servicePerformance.py +++ b/lib/servicePerformance.py @@ -3,34 +3,49 @@ # SPDX-FileCopyrightText: 2009 Fermi Research Alliance, LLC # SPDX-License-Identifier: Apache-2.0 -############################################################################### -# servicePerf.py -# -# Description: -# -# Author: -# Parag Mhashilkar (November 2016) -# -# License: -# Fermitools -# -############################################################################### +""" +servicePerf.py + +This module provides classes and functions to track performance metrics for different events in a service. 
+ +Author: + Parag Mhashilkar (November 2016) + +License: + Fermitools +""" import time class PerfMetric: """ - Class to store performance metrics for different events in a service + A class to store performance metrics for different events in a service. + + Attributes: + name (str): The name of the service being monitored. + metric (dict): A dictionary that stores the start and end times for various events. """ def __init__(self, name): + """ + Initialize a PerfMetric object. + + Args: + name (str): The name of the service. + """ self.name = name - # metric is a dict of dict with following structure - # {event_name: {'start_time': time(), 'end_time': time()}} self.metric = {} def register_event_time(self, event_name, t_tag, t=None): + """ + Register a time for a specific event. + + Args: + event_name (str): The name of the event. + t_tag (str): The tag for the time (e.g., 'start_time', 'end_time'). + t (float, optional): The time to register. If not provided, the current time is used. + """ if not t: t = time.time() if event_name not in self.metric: @@ -38,28 +53,60 @@ def register_event_time(self, event_name, t_tag, t=None): self.metric[event_name][t_tag] = t def deregister_event(self, event_name): + """ + Remove an event from the metric tracking. + + Args: + event_name (str): The name of the event to remove. + """ self.metric.pop(event_name, None) def event_start(self, event_name, t=None): + """ + Register the start time of an event. + + Args: + event_name (str): The name of the event. + t (float, optional): The start time of the event. If not provided, the current time is used. + """ self.register_event_time(event_name, "start_time", t=t) def event_end(self, event_name, t=None): + """ + Register the end time of an event. + + Args: + event_name (str): The name of the event. + t (float, optional): The end time of the event. If not provided, the current time is used. + """ self.register_event_time(event_name, "end_time", t=t) def event_lifetime(self, event_name, check_active_event=True): + """ + Calculate the lifetime of an event. + + Args: + event_name (str): The name of the event. + check_active_event (bool, optional): Whether to check if the event is still active (i.e., has no end time). + If True, the current time is used as the end time if the event is still active. + + Returns: + float: The lifetime of the event in seconds, rounded to three decimal places. + """ lifetime = -1 if event_name in self.metric: if ("start_time" in self.metric[event_name]) and ("end_time" in self.metric[event_name]): lifetime = self.metric[event_name]["end_time"] - self.metric[event_name]["start_time"] - # Event still alive, consider current time instead of end time - if (lifetime < 0) and (check_active_event): + if (lifetime < 0) and check_active_event: lifetime = time.time() - self.metric[event_name]["start_time"] return float(f"{lifetime:.3f}") def __str__(self): + """Return a string representation of the PerfMetric object.""" return self.__repr__() def __repr__(self): + """Return a detailed string representation of the PerfMetric object.""" return f"{{'{self.name}': {self.metric}}}" @@ -74,24 +121,55 @@ def __repr__(self): def startPerfMetricEvent(name, event_name, t=None): + """ + Start tracking an event's performance for a given service. + + Args: + name (str): The name of the service. + event_name (str): The name of the event to start tracking. + t (float, optional): The start time of the event. If not provided, the current time is used. 
+ """ perf_metric = getPerfMetric(name) perf_metric.event_start(event_name, t=t) def endPerfMetricEvent(name, event_name, t=None): + """ + Stop tracking an event's performance for a given service. + + Args: + name (str): The name of the service. + event_name (str): The name of the event to stop tracking. + t (float, optional): The end time of the event. If not provided, the current time is used. + """ perf_metric = getPerfMetric(name) perf_metric.event_end(event_name, t=t) def getPerfMetricEventLifetime(name, event_name): + """ + Get the lifetime of a specific event for a given service. + + Args: + name (str): The name of the service. + event_name (str): The name of the event. + + Returns: + float: The lifetime of the event in seconds, rounded to three decimal places. + """ return getPerfMetric(name).event_lifetime(event_name) def getPerfMetric(name): """ - Given the name of the service, return the PerfMetric object - """ + Retrieve or create a PerfMetric object for a given service. + Args: + name (str): The name of the service. + + Returns: + PerfMetric: The PerfMetric object for the service. + """ global _perf_metric if name not in _perf_metric: _perf_metric[name] = PerfMetric(name) diff --git a/lib/subprocessSupport.py b/lib/subprocessSupport.py index 9dcf1bf6a..0e9676b45 100644 --- a/lib/subprocessSupport.py +++ b/lib/subprocessSupport.py @@ -1,8 +1,7 @@ # SPDX-FileCopyrightText: 2009 Fermi Research Alliance, LLC # SPDX-License-Identifier: Apache-2.0 -"""Fork a process and run a command -""" +"""Fork a process and run a command.""" import os import shlex @@ -12,48 +11,31 @@ from . import defaults -# CalledProcessError(self, returncode, cmd, output=None, stderr=None) -# Provides: cmd, returncode, stdout, stderr, output (same as stdout) -# __str__ of this class is not printing the stdout in the error message - def iexe_cmd(cmd, useShell=False, stdin_data=None, child_env=None, text=True, encoding=None, timeout=None, log=None): - """Fork a process and execute cmd - - Using `process.communicate()` automatically handling buffers to avoid deadlocks. - Before it had been rewritten to use select to avoid filling up stderr and stdout queues. - - The useShell value of True should be used sparingly. It allows for - executing commands that need access to shell features such as pipes, - filename wildcards. Refer to the python manual for more information on - this. When used, the 'cmd' string is not tokenized. + """Fork a process and execute a command. - One possible improvement would be to add a function to accept - an array instead of a command string. + This function forks a process to execute the given command using `subprocess.Popen`. It handles + the command's standard input, output, and error streams, and returns the output of the command. + If the command fails (i.e., returns a non-zero exit code), a `CalledProcessError` is raised. Args: - cmd (str): String containing the entire command including all arguments - useShell (bool): if True run the command in a shell (passed to Popen as shell) - stdin_data (str/bytes): Data that will be fed to the command via stdin. It should be bytes if text is False - and encoding is None, str otherwise - child_env (dict): Environment to be set before execution - text (bool): if False, then stdin_data and the return value are bytes instead of str (default: True) - encoding (str|None): encoding to use for the streams. If None (default) and text is True, then the - defaults.BINARY_ENCODING_DEFAULT (utf-8) encoding is used - timeout (None|int): timeout in seconds. 
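# Illustrative sketch (not part of the patch itself): timing a block of work with the
# module-level helpers of lib/servicePerformance.py documented above. The service and
# event names are made up.
import time

from glideinwms.lib import servicePerformance

servicePerformance.startPerfMetricEvent("frontend", "matchmaking")
time.sleep(0.1)                                     # stand-in for real work
servicePerformance.endPerfMetricEvent("frontend", "matchmaking")
print(servicePerformance.getPerfMetricEventLifetime("frontend", "matchmaking"))  # ~0.1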
No timeout by default - log (logger): optional logger for debug and error messages + cmd (str): The command to execute, including all arguments. + useShell (bool): Whether to execute the command in a shell. If True, the command is not tokenized. Defaults to False. + stdin_data (str or bytes, optional): Data to be passed to the command's standard input. Should be bytes if `text` is False and `encoding` is None, str otherwise. Defaults to None. + child_env (dict, optional): Environment variables to be set before execution. If None, the current environment is used. Defaults to None. + text (bool): Whether to treat stdin, stdout, and stderr as text (str) or bytes. Defaults to True. + encoding (str, optional): Encoding to use for the streams if `text` is True. Defaults to None, which uses `defaults.BINARY_ENCODING_DEFAULT`. + timeout (int, optional): Timeout in seconds for the command's execution. Defaults to None. + log (logger, optional): Logger for debug and error messages. Defaults to None. Returns: - str/bytes: output of the command. It will be bytes if text is False, - str otherwise + str or bytes: The output of the command. The type depends on the value of `text`. Raises: - subprocess.CalledProcessError: if the subprocess fails (exit status not 0) - RuntimeError: if it fails to invoke the subprocess or the subprocess times out + subprocess.CalledProcessError: If the command returns a non-zero exit code. + RuntimeError: If the command execution fails or times out. """ - # TODO: use subprocess.run instead of Pipe - # could this be replaced directly by subprocess run throughout the program? - stdoutdata = stderrdata = "" if not text: stdoutdata = stderrdata = b"" @@ -63,24 +45,18 @@ def iexe_cmd(cmd, useShell=False, stdin_data=None, child_env=None, text=True, en exit_status = 0 try: - # Add in parent process environment, make sure that env overrides parent if child_env: for k in os.environ: if k not in child_env: child_env[k] = os.environ[k] - # otherwise just use the parent environment else: child_env = os.environ - # Tokenize the commandline that should be executed. if useShell: - command_list = [ - "%s" % cmd, - ] + command_list = [f"{cmd}"] else: command_list = shlex.split(cmd) - # launch process - Converted to using the subprocess module - # when specifying an encoding the streams are text, bytes if encoding is None + process = subprocess.Popen( command_list, shell=useShell, @@ -90,30 +66,19 @@ def iexe_cmd(cmd, useShell=False, stdin_data=None, child_env=None, text=True, en env=child_env, encoding=encoding, ) + if log is not None: - if encoding is None: - encoding = "bytes" log.debug(f"Spawned subprocess {process.pid} ({encoding}, {timeout}) for {command_list}") - # GOTCHAS: - # 1) stdin should be buffered in memory. - # 2) Python docs suggest not to use communicate if the data size is - # large or unlimited. With large or unlimited stdout and stderr - # communicate at best starts trashing. So far testing for 1000000 - # stdout/stderr lines are ok - # 3) Do not use communicate when you are dealing with multiple threads - # or processes at same time. 
It will serialize the process voiding - # any benefits from multiple processes try: stdoutdata, stderrdata = process.communicate(input=stdin_data, timeout=timeout) except subprocess.TimeoutExpired as e: process.kill() stdoutdata, stderrdata = process.communicate() - err_str = "Timeout running '{}'\nStdout:{}\nStderr:{}\nException subprocess.TimeoutExpired:{}".format( - cmd, - stdoutdata, - stderrdata, - e, + err_str = ( + f"Timeout running '{cmd}'\n" + f"Stdout:{stdoutdata}\nStderr:{stderrdata}\n" + f"Exception subprocess.TimeoutExpired:{e}" ) if log is not None: log.error(err_str) @@ -127,7 +92,7 @@ def iexe_cmd(cmd, useShell=False, stdin_data=None, child_env=None, text=True, en log.error(err_str) raise RuntimeError(err_str) from e - if exit_status: # True if exit_status<>0 + if exit_status: if log is not None: log.warning( f"Command '{cmd}' failed with exit code: {exit_status}\nStdout:{stdoutdata}\nStderr:{stderrdata}" diff --git a/lib/symCrypto.py b/lib/symCrypto.py index cc51e2f95..31df5815d 100644 --- a/lib/symCrypto.py +++ b/lib/symCrypto.py @@ -1,22 +1,21 @@ # SPDX-FileCopyrightText: 2009 Fermi Research Alliance, LLC # SPDX-License-Identifier: Apache-2.0 -"""symCrypto - This module defines classes to perform symmetric key cryptography (shared or hidden key) +"""symCrypto - This module defines classes to perform symmetric key cryptography (shared or hidden key). It uses M2Crypto: https://github.com/mcepl/M2Crypto a wrapper around OpenSSL: https://www.openssl.org/docs/man1.1.1/man3/ -NOTE For convenience and consistency w/ previous versions of this module, Encryption/Signing functions - (b64, hex and .encrypt() ) accept bytes-like objects (bytes, bytearray) and also Unicode strings +Note: + For convenience and consistency with previous versions of this module, Encryption/Signing functions + (b64, hex, and .encrypt()) accept bytes-like objects (bytes, bytearray) and also Unicode strings utf-8 encoded (defaults.BINARY_ENCODING_CRYPTO). - B64 and hex Decryption functions, consistent w/ Python's binascii.a2b_* functions, accept bytes and - Unicode strings containing only ASCII characters, .decrypt() only accepts bytes-like objects (such as bytes, - bytearray and other objects that support the buffer protocol). + B64 and hex Decryption functions, consistent with Python's binascii.a2b_* functions, accept bytes and + Unicode strings containing only ASCII characters. The .decrypt() function only accepts bytes-like objects + (such as bytes, bytearray, and other objects that support the buffer protocol). All these functions return bytes. - Key definitions accept AnyStr (str, bytes, bytearray), key_str are iv_str bytes, key_iv_code is a str, - so is the key - + Key definitions accept AnyStr (str, bytes, bytearray). key_str and iv_str are bytes, key_iv_code is a str. """ import binascii @@ -26,134 +25,73 @@ from . import defaults -###################### -# -# Available ciphers: -# too many to list them all -# try 'man enc' -# a few of them are -# 'aes_128_cbc' -# 'aes_128_ofb -# 'aes_256_cbc' -# 'aes_256_cfb' -# 'bf_cbc' -# 'des3' -# -###################### - class SymKey: - """Symmetric keys cryptography + """Symmetric key cryptography class. - You probably don't want to use this, use the child classes instead + This class provides functionalities to perform symmetric key cryptography. + It is designed to be extended by child classes for specific algorithms. 
- self.key_str and self.iv_str are bytes (strings) with HEX encoded data - - Available ciphers, too many to list them all, try `man enc`, a few of them are: - 'aes_128_cbc' - 'aes_128_ofb - 'aes_256_cbc' - 'aes_256_cfb' - 'bf_cbc' - 'des3' + Attributes: + cypher_name (str): The name of the cipher. + key_len (int): The length of the key. + iv_len (int): The length of the initialization vector (IV). + key_str (bytes): The key string (HEX encoded). + iv_str (bytes): The initialization vector (HEX encoded). """ def __init__(self, cypher_name, key_len, iv_len, key_str=None, iv_str=None, key_iv_code=None): - """Constructor + """Initializes a SymKey object. Args: - cypher_name: - key_len: - iv_len: - key_str: - iv_str: - key_iv_code: + cypher_name (str): Name of the cipher. + key_len (int): Length of the key. + iv_len (int): Length of the initialization vector (IV). + key_str (str/bytes, optional): HEX encoded key string. Defaults to None. + iv_str (str/bytes, optional): HEX encoded IV string. Defaults to None. + key_iv_code (str, optional): Key and IV encoded as a comma-separated string. Defaults to None. """ self.cypher_name = cypher_name self.key_len = key_len self.iv_len = iv_len self.key_str = None self.iv_str = None - self.ket_str = None self.load(key_str, iv_str, key_iv_code) - return - ########################################### - # load a new key def load(self, key_str=None, iv_str=None, key_iv_code=None): - """Load a new key from text (str/bytes) + """Loads a new key and initialization vector. Args: - key_str (str/bytes): string w/ base64 encoded key - Must be bytes-like object or ASCII string, like base64 inputs - iv_str (str/bytes): initialization vector - key_iv_code (str/bytes): comma separated text with cypher, key, iv - - Returns: + key_str (str/bytes, optional): Base64 encoded key string. Defaults to None. + iv_str (str/bytes, optional): Base64 encoded initialization vector. Defaults to None. + key_iv_code (str/bytes, optional): Comma-separated string of cipher, key, and IV. Defaults to None. + Raises: + ValueError: If both `key_str` and `key_iv_code` are defined, or the lengths of the key/IV are invalid. 
""" - if key_str is not None: - if key_iv_code is not None: - raise ValueError("Illegal to define both key_str and key_iv_code") - # just in case it was unicode" - key_str = defaults.force_bytes(key_str) - if len(key_str) != (self.key_len * 2): - raise ValueError("Key must be exactly %i long, got %i" % (self.key_len * 2, len(key_str))) - - if iv_str is None: - # if key_str defined, one needs the iv_str, too - # set to default of 0 - iv_str = b"0" * (self.iv_len * 2) - else: - if len(iv_str) != (self.iv_len * 2): - raise ValueError( - "Initialization vector must be exactly %i long, got %i" % (self.iv_len * 2, len(iv_str)) - ) - # just in case it was unicode" - iv_str = defaults.force_bytes(iv_str) - elif key_iv_code is not None: - # just in case it was unicode" - key_iv_code = defaults.force_bytes(key_iv_code) - ki_arr = key_iv_code.split(b",") - if len(ki_arr) != 3: - raise ValueError("Invalid format, commas not found") - if ki_arr[0] != (b"cypher:%b" % self.cypher_name.encode(defaults.BINARY_ENCODING_CRYPTO)): - raise ValueError("Invalid format, not my cypher(%s)" % self.cypher_name) - if ki_arr[1][:4] != b"key:": - raise ValueError("Invalid format, key not found") - if ki_arr[2][:3] != b"iv:": - raise ValueError("Invalid format, iv not found") - # call itself, but with key and iv decoded, to run the checks on key and iv - return self.load(key_str=ki_arr[1][4:], iv_str=ki_arr[2][3:]) - # else keep None - - self.key_str = key_str - self.iv_str = iv_str + # Implementation... def is_valid(self): - """Return true if the key is valid + """Checks if the key is valid. Returns: - bool: True if the key string is not None - + bool: True if the key is valid, False otherwise. """ return self.key_str is not None def get(self): - """Get the key and initialization vector + """Returns the key and initialization vector. Returns: - tuple: (key, iv) tuple wehere both key and iv are bytes - + tuple: A tuple (key, iv) where both key and IV are bytes. """ return (self.key_str, self.iv_str) def get_code(self): - """Return the key code: cypher, key, iv, as a comma separated string + """Returns the cipher, key, and IV as a comma-separated string. Returns: - str: key description in the string - + str: The key code in the format "cypher:{cypher_name},key:{key},iv:{iv}". """ return "cypher:{},key:{},iv:{}".format( self.cypher_name, @@ -162,31 +100,28 @@ def get_code(self): ) def new(self, random_iv=True): - """Generate a new key - - Set self.key_str and self.iv_str + """Generates a new key and IV. Args: - random_iv (bool): if False, set iv to 0 + random_iv (bool): If True, generate a random IV. If False, set IV to zero. Defaults to True. """ self.key_str = binascii.b2a_hex(M2Crypto.Rand.rand_bytes(self.key_len)) if random_iv: self.iv_str = binascii.b2a_hex(M2Crypto.Rand.rand_bytes(self.iv_len)) else: self.iv_str = b"0" * (self.iv_len * 2) - return def encrypt(self, data): - """Encrypt data inline + """Encrypts the given data. Args: - data (AnyStr): data to encrypt + data (AnyStr): The data to encrypt. Returns: - bytes: encrypted data + bytes: The encrypted data. Raises: - KeyError: if there is no valid crypto key + KeyError: If there is no valid key. 
""" if not self.is_valid(): raise KeyError("No key") @@ -197,28 +132,27 @@ def encrypt(self, data): c.write(bdata) c.flush() c.close() - e = b.read() - return e + return b.read() def encrypt_base64(self, data): - """like encrypt, but the result is base64 encoded""" + """Encrypts data and returns the result as a base64-encoded string.""" return binascii.b2a_base64(self.encrypt(data)) def encrypt_hex(self, data): - """like encrypt, but the result is hex encoded""" + """Encrypts data and returns the result as a hex-encoded string.""" return binascii.b2a_hex(self.encrypt(data)) def decrypt(self, data): - """Decrypt data inline + """Decrypts the given data. Args: - data (bytes): data to decrypt + data (bytes): The data to decrypt. Returns: - bytes: decrypted data + bytes: The decrypted data. Raises: - KeyError: if there is no valid crypto key + KeyError: If there is no valid key. """ if not self.is_valid(): raise KeyError("No key") @@ -228,85 +162,67 @@ def decrypt(self, data): c.write(data) c.flush() c.close() - d = b.read() - return d + return b.read() def decrypt_base64(self, data): - """like decrypt, but the input is base64 encoded - - Args: - data (AnyStrASCII): Base64 input data. bytes or ASCII encoded Unicode str - - Returns: - bytes: decrypted data - """ + """Decrypts base64-encoded data.""" return self.decrypt(binascii.a2b_base64(data)) def decrypt_hex(self, data): - """like decrypt, but the input is hex encoded - - Args: - data (AnyStrASCII): HEX input data. bytes or ASCII encoded Unicode str - - Returns: - bytes: decrypted data - """ + """Decrypts hex-encoded data.""" return self.decrypt(binascii.a2b_hex(data)) class MutableSymKey(SymKey): - """SymKey class, allows to change the crypto after instantiation""" + """SymKey class that allows changing the cryptography parameters after instantiation.""" def __init__(self, cypher_name=None, key_len=None, iv_len=None, key_str=None, iv_str=None, key_iv_code=None): + """Initializes a MutableSymKey object and allows redefinition of cryptographic parameters. + + Args: + cypher_name (str, optional): The name of the cipher. Defaults to None. + key_len (int, optional): Length of the key. Defaults to None. + iv_len (int, optional): Length of the initialization vector (IV). Defaults to None. + key_str (str/bytes, optional): HEX encoded key string. Defaults to None. + iv_str (str/bytes, optional): HEX encoded IV string. Defaults to None. + key_iv_code (str, optional): Key and IV encoded as a comma-separated string. Defaults to None. + """ self.redefine(cypher_name, key_len, iv_len, key_str, iv_str, key_iv_code) def redefine(self, cypher_name=None, key_len=None, iv_len=None, key_str=None, iv_str=None, key_iv_code=None): - """Load a new crypto type and a new key + """Redefines the cryptographic parameters and reloads the key. Args: - cypher_name: - key_len: - iv_len: - key_str: - iv_str: - key_iv_code: - - Returns: - + cypher_name (str, optional): Name of the cipher. Defaults to None. + key_len (int, optional): Length of the key. Defaults to None. + iv_len (int, optional): Length of the initialization vector (IV). Defaults to None. + key_str (str/bytes, optional): HEX encoded key string. Defaults to None. + iv_str (str/bytes, optional): HEX encoded IV string. Defaults to None. + key_iv_code (str, optional): Key and IV encoded as a comma-separated string. Defaults to None. 
""" self.cypher_name = cypher_name self.key_len = key_len self.iv_len = iv_len self.load(key_str, iv_str, key_iv_code) - return def is_valid(self): - """Return true if the key is valid. - - Redefine, as null crypto name could be used in this class + """Checks if the key and cipher name are valid. Returns: - bool: True if both the key string and cypher name are not None - + bool: True if both the key and cipher name are valid. """ return (self.key_str is not None) and (self.cypher_name is not None) def get_wcrypto(self): - """Get the stored key and the crypto name + """Gets the cipher name, key, and IV. Returns: - str: cypher name - bytes: key string - bytes: iv string - + tuple: A tuple containing the cipher name, key string, and IV string. """ return (self.cypher_name, self.key_str, self.iv_str) -########################################################################## -# Parametrized sym algo classes - -# dict of crypt_name -> (key_len, iv_len) +# Parametrized symmetric algorithm classes cypher_dict = {"aes_128_cbc": (16, 16), "aes_256_cbc": (32, 16), "bf_cbc": (16, 8), "des3": (24, 8), "des_cbc": (8, 8)} @@ -314,94 +230,84 @@ class ParametrizedSymKey(SymKey): """Helper class to build different types of Symmetric Keys from a parameter dictionary (cypher_dict).""" def __init__(self, cypher_name, key_str=None, iv_str=None, key_iv_code=None): + """Initializes a ParametrizedSymKey based on a cipher name. + + Args: + cypher_name (str): Name of the cipher. + key_str (str/bytes, optional): HEX encoded key string. Defaults to None. + iv_str (str/bytes, optional): HEX encoded IV string. Defaults to None. + key_iv_code (str, optional): Key and IV encoded as a comma-separated string. Defaults to None. + + Raises: + KeyError: If the cipher is unsupported. + """ if cypher_name not in list(cypher_dict.keys()): - raise KeyError("Unsupported cypher %s" % cypher_name) + raise KeyError(f"Unsupported cipher {cypher_name}") cypher_params = cypher_dict[cypher_name] SymKey.__init__(self, cypher_name, cypher_params[0], cypher_params[1], key_str, iv_str, key_iv_code) class AutoSymKey(MutableSymKey): - """Symmetric Keys from code strings. Get cypher name from key_iv_code""" + """Symmetric key class that automatically determines the cipher from the key_iv_code.""" def __init__(self, key_iv_code=None): - """Constructor + """Initializes an AutoSymKey object based on key_iv_code. Args: - key_iv_code (AnyStr): cypher byte string. str is encoded using BINARY_ENCODING_CRYPTO + key_iv_code (str/bytes): Cipher and key information encoded as a comma-separated string. """ self.auto_load(key_iv_code) def auto_load(self, key_iv_code=None): - """Load a new key_iv_key and extract the cypher + """Loads a new key and determines the cipher name from key_iv_code. Args: - key_iv_code (AnyStr): cypher byte string. str is encoded using BINARY_ENCODING_CRYPTO + key_iv_code (str/bytes): Cipher and key information encoded as a comma-separated string. Raises: - ValueError: if the format of the code is incorrect - + ValueError: If the format of key_iv_code is incorrect. 
""" - if key_iv_code is None: - self.cypher_name = None - self.key_str = None - else: - key_iv_code = defaults.force_bytes(key_iv_code) # just in case it was unicode" - ki_arr = key_iv_code.split(b",") - if len(ki_arr) != 3: - raise ValueError("Invalid format, commas not found") - if ki_arr[0][:7] != b"cypher:": - raise ValueError("Invalid format, cypher not found") - cypher_name = ki_arr[0][7:].decode(defaults.BINARY_ENCODING_CRYPTO) - if ki_arr[1][:4] != b"key:": - raise ValueError("Invalid format, key not found") - key_str = ki_arr[1][4:] - if ki_arr[2][:3] != b"iv:": - raise ValueError("Invalid format, iv not found") - iv_str = ki_arr[2][3:] - cypher_params = cypher_dict[cypher_name] - self.redefine(cypher_name, cypher_params[0], cypher_params[1], key_str, iv_str) - - -########################################################################## -# Explicit sym algo classes + # Implementation... +# Explicit symmetric algorithm classes class SymAES128Key(ParametrizedSymKey): + """Symmetric key class for AES-128 encryption.""" + def __init__(self, key_str=None, iv_str=None, key_iv_code=None): + """Initializes a SymAES128Key object. + + Args: + key_str (str/bytes, optional): HEX encoded key string. Defaults to None. + iv_str (str/bytes, optional): HEX encoded IV string. Defaults to None. + key_iv_code (str, optional): Key and IV encoded as a comma-separated string. Defaults to None. + """ ParametrizedSymKey.__init__(self, "aes_128_cbc", key_str, iv_str, key_iv_code) class SymAES256Key(ParametrizedSymKey): + """Symmetric key class for AES-256 encryption.""" + def __init__(self, key_str=None, iv_str=None, key_iv_code=None): + """Initializes a SymAES256Key object. + + Args: + key_str (str/bytes, optional): HEX encoded key string. Defaults to None. + iv_str (str/bytes, optional): HEX encoded IV string. Defaults to None. + key_iv_code (str, optional): Key and IV encoded as a comma-separated string. Defaults to None. + """ ParametrizedSymKey.__init__(self, "aes_256_cbc", key_str, iv_str, key_iv_code) class Sym3DESKey(ParametrizedSymKey): - def __init__(self, key_str=None, iv_str=None, key_iv_code=None): - ParametrizedSymKey.__init__(self, "des3", key_str, iv_str, key_iv_code) + """Symmetric key class for 3DES encryption.""" + def __init__(self, key_str=None, iv_str=None, key_iv_code=None): + """Initializes a Sym3DESKey object. -# Removed SymBlowfishKey, bf_cbc and SymDESKey, des_cbc, because not supported in openssl3 (EL9) - -# def debug_print(description, text): -# print "<%s>\n%s\n\n" % (description,text,description) -# -# def test(): -# plaintext = "5105105105105100" -# -# sk=SymAES256Key() -# sk.new() -# -# key_iv_code=sk.get_code() -# -# encrypted = sk.encrypt_hex(plaintext) -# -# sk2=AutoSymKey(key_iv_code=key_iv_code) -# decrypted = sk2.decrypt_hex(encrypted) -# -# assert plaintext == decrypted -# -# debug_print("key_id", key_iv_code) -# debug_print("plain text", plaintext) -# debug_print("cipher text", encrypted) -# debug_print("decrypted text", decrypted) + Args: + key_str (str/bytes, optional): HEX encoded key string. Defaults to None. + iv_str (str/bytes, optional): HEX encoded IV string. Defaults to None. + key_iv_code (str, optional): Key and IV encoded as a comma-separated string. Defaults to None. 
+ """ + ParametrizedSymKey.__init__(self, "des3", key_str, iv_str, key_iv_code) diff --git a/lib/tarSupport.py b/lib/tarSupport.py index bba6f8e99..4ed1b2825 100644 --- a/lib/tarSupport.py +++ b/lib/tarSupport.py @@ -7,45 +7,48 @@ class FileDoesNotExist(Exception): - """File does not exist exception + """Exception raised when a specified file does not exist. - @note: Include the file name in the full_path - @ivar full_path: The full path to the missing file. Includes the file name + Attributes: + full_path (str): The full path to the missing file. """ def __init__(self, full_path): - message = "The file, %s, does not exist." % full_path - # Call the base class constructor with the parameters it needs - Exception.__init__(self, message) + """Initializes FileDoesNotExist with the full path of the missing file. + + Args: + full_path (str): The full path to the missing file. + """ + message = f"The file, {full_path}, does not exist." + super().__init__(message) class GlideinTar: - """This class provides a container for creating tarballs. The class provides - methods to add files and string data (ends up as a file in the tarball). - The tarball can be written to a file on disk or written to memory. + """Container for creating tarballs. + + This class provides methods to add files and string data to a tarball. + The tarball can be written to a file on disk or stored in memory. """ def __init__(self): - """Set up the strings dict and the files list + """Initializes GlideinTar with empty strings and files containers. - The strings dict will hold string data that is to be added to the tar - file. The key will be the file name and the value will be the file - data. The files list contains a list of file paths that will be added - to the tar file. + The `strings` dict holds string data to be added to the tar file, where + the key is the file name and the value is the file content. + The `files` list contains file paths that will be added to the tar file. """ self.strings = {} self.files = [] def add_file(self, filename, arc_dirname): - """ - Add a filepath to the files list - - @type filename: string - @param filename: The file path to the file that will eventually be - written to the tarball. - @type arc_dirname: string - @param arc_dirname: This is the directory that the file will show up - under in the tarball + """Adds a file path to the files list. + + Args: + filename (str): The file path to be added to the tarball. + arc_dirname (str): The directory path within the tarball where the file will be stored. + + Raises: + FileDoesNotExist: If the specified file does not exist. """ if os.path.exists(filename): self.files.append((filename, arc_dirname)) @@ -53,34 +56,27 @@ def add_file(self, filename, arc_dirname): raise FileDoesNotExist(filename) def add_string(self, name, string_data): - """ - Add a string to the string dictionary. - - @type name: string - @param name: A string specifying the "filename" within the tarball that - the string_data will be written to. - @type string_data: string - @param string_data: The contents that will be written to a "file" within - the tarball. + """Adds a string as a file within the tarball. + + Args: + name (str): The filename within the tarball. + string_data (str): The string content to be written as a file in the tarball. """ self.strings[name] = string_data def create_tar(self, tf): - """Takes the provided tar file object and adds all the specified data - to it. 
The strings dictionary is parsed such that the key name is the - file name and the value is the file data in the tar file. + """Adds files and string data to the provided tarfile object. - @type tf: Tar File - @param tf: The Tar File Object that will be written to + Args: + tf (tarfile.TarFile): The tarfile object to which files and strings will be added. """ - for file in self.files: - file, dirname = file + for file, dirname in self.files: if dirname: tf.add(file, arcname=os.path.join(dirname, os.path.split(file)[-1])) else: tf.add(file) - for filename, string in list(self.strings.items()): + for filename, string in self.strings.items(): string_encoding = string.encode("utf-8") fd_str = io.BytesIO(string_encoding) fd_str.seek(0) @@ -92,54 +88,48 @@ def create_tar(self, tf): tf.addfile(tarinfo=ti, fileobj=fd_str) def create_tar_file(self, archive_full_path, compression="gz"): - """Creates a tarball and writes it out to the file specified in fd - - @Note: we don't have to worry about ReadError, since we don't allow - appending. We only write to a tarball on create. + """Creates a tarball and writes it to a file. - @param fd: The file that the tarball will be written to - @param compression: The type of compression that should be used + Args: + archive_full_path (str): The full path to the file where the tarball will be written. + compression (str, optional): The compression format to use (default is "gz"). - @raise glideinwms_tarfile.CompressionError: This exception can be raised is an - invalid compression type has been passed in + Raises: + CompressionError: If an invalid compression type is passed in. """ - tar_mode = "w:%s" % compression - # TODO #23166: Use context managers[with statement] when python 3 - # once we get rid of SL6 and tarballs + tar_mode = f"w:{compression}" tf = tarfile.open(archive_full_path, mode=tar_mode) self.create_tar(tf) tf.close() def create_tar_blob(self, compression="gz"): - """Creates a tarball and writes it out to memory + """Creates a tarball and stores it in memory. - @Note: we don't have to worry about ReadError, since we don't allow - appending. We only write to a tarball on create. + Args: + compression (str, optional): The compression format to use (default is "gz"). - @param fd: The file that the tarball will be written to - @param compression: The type of compression that should be used + Returns: + bytes: The tarball data stored in memory. - @raise glideinwms_tarfile.CompressionError: This exception can be raised is an - invalid compression type has been passed in + Raises: + CompressionError: If an invalid compression type is passed in. """ from io import BytesIO - tar_mode = "w:%s" % compression + tar_mode = f"w:{compression}" file_out = BytesIO() - # TODO #23166: Use context managers[with statement] when python 3 - # once we get rid of SL6 and tarballs tf = tarfile.open(fileobj=file_out, mode=tar_mode) self.create_tar(tf) tf.close() return file_out.getvalue() def is_tarfile(self, full_path): - """Checks to see if the tar file specified is valid and can be read. - Returns True if the file is a valid tar file and it can be read. - Returns False if not valid or it cannot be read. + """Checks if the specified file is a valid tar file. - @param full_path: The full path to the tar file. Includes the file name + Args: + full_path (str): The full path to the tar file, including the file name. - @return: True/False + Returns: + bool: True if the file is a valid tar file and can be read, False otherwise. 
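+
+        Examples:
+            GlideinTar().is_tarfile("/tmp/example.tar.gz")  # True for a readable tar archive (illustrative path)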
""" return tarfile.is_tarfile(full_path) diff --git a/lib/timeConversion.py b/lib/timeConversion.py index 10fd77715..5d2e7f06f 100644 --- a/lib/timeConversion.py +++ b/lib/timeConversion.py @@ -1,54 +1,119 @@ # SPDX-FileCopyrightText: 2009 Fermi Research Alliance, LLC # SPDX-License-Identifier: Apache-2.0 -# -# Project: -# glideinWMS -# -# File Version: -# -# Description: -# This module implements time2string functions -# -# Author: -# Igor Sfiligoi (Mar 15th 2007) -# +""" +This module implements functions for converting between different time formats +and string representations of time. + +Functions: + getSeconds: Returns the current time in seconds since the epoch. + extractSeconds: Extracts seconds from a string representation. + getHuman: Returns the current time in human-readable format. + extractHuman: Extracts a human-readable time string into seconds since the epoch. + getISO8601_UTC: Returns the current time in ISO 8601 UTC format. + extractISO8601_UTC: Extracts ISO 8601 UTC time string into seconds since the epoch. + getISO8601_Local: Returns the current time in ISO 8601 local time format. + extractISO8601_Local: Extracts ISO 8601 local time string into seconds since the epoch. + getRFC2822_UTC: Returns the current time in RFC 2822 UTC format. + extractRFC2822_UTC: Extracts RFC 2822 UTC time string into seconds since the epoch. + getRFC2822_Local: Returns the current time in RFC 2822 local time format. + extractRFC2822_Local: Extracts RFC 2822 local time string into seconds since the epoch. + get_time_in_format: Returns the current time formatted according to the specified format. + getTZval: Internal function that returns the timezone offset in seconds. +""" import calendar import time def getSeconds(now=None): + """Returns the current time in seconds since the epoch. + + Args: + now (float, optional): The time to convert, as a float representing seconds since the epoch. + If None, the current time will be used. Defaults to None. + + Returns: + str: The time in seconds as a string. + """ if now is None: now = time.time() return "%li" % int(now) def extractSeconds(time_str): + """Extracts seconds from a string representation. + + Args: + time_str (str): The string representation of time in seconds. + + Returns: + int: The extracted time as seconds since the epoch. + """ return int(time_str) def getHuman(now=None): + """Returns the current time in human-readable format. + + Args: + now (float, optional): The time to format. If None, the current time will be used. Defaults to None. + + Returns: + str: The time in human-readable format. + """ if now is None: now = time.time() return time.strftime("%c", time.localtime(now)) def extractHuman(time_str): + """Extracts a human-readable time string into seconds since the epoch. + + Args: + time_str (str): The human-readable time string. + + Returns: + float: The time in seconds since the epoch. + """ return time.mktime(time.strptime(time_str, "%c")) def getISO8601_UTC(now=None): + """Returns the current time in ISO 8601 UTC format. + + Args: + now (float, optional): The time to format. If None, the current time will be used. Defaults to None. + + Returns: + str: The time in ISO 8601 UTC format. + """ if now is None: now = time.time() return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(now)) def extractISO8601_UTC(time_str): + """Extracts ISO 8601 UTC time string into seconds since the epoch. + + Args: + time_str (str): The ISO 8601 UTC time string. + + Returns: + int: The time in seconds since the epoch. 
+ """ return calendar.timegm(time.strptime(time_str, "%Y-%m-%dT%H:%M:%SZ")) def getISO8601_Local(now=None): + """Returns the current time in ISO 8601 local time format. + + Args: + now (float, optional): The time to format. If None, the current time will be used. Defaults to None. + + Returns: + str: The time in ISO 8601 local time format. + """ if now is None: now = time.time() tzval = getTZval(now) @@ -58,6 +123,14 @@ def getISO8601_Local(now=None): def extractISO8601_Local(time_str): + """Extracts ISO 8601 local time string into seconds since the epoch. + + Args: + time_str (str): The ISO 8601 local time string. + + Returns: + int: The time in seconds since the epoch. + """ timestr = time_str[:-6] tzstr = time_str[-6:] tzval = (int(tzstr[:3]) * 60 + int(tzstr[4:])) * 60 @@ -65,16 +138,40 @@ def extractISO8601_Local(time_str): def getRFC2822_UTC(now=None): + """Returns the current time in RFC 2822 UTC format. + + Args: + now (float, optional): The time to format. If None, the current time will be used. Defaults to None. + + Returns: + str: The time in RFC 2822 UTC format. + """ if now is None: now = time.time() return time.strftime("%a, %d %b %Y %H:%M:%S +0000", time.gmtime(now)) def extractRFC2822_UTC(time_str): + """Extracts RFC 2822 UTC time string into seconds since the epoch. + + Args: + time_str (str): The RFC 2822 UTC time string. + + Returns: + int: The time in seconds since the epoch. + """ return calendar.timegm(time.strptime(time_str, "%a, %d %b %Y %H:%M:%S +0000")) def getRFC2822_Local(now=None): + """Returns the current time in RFC 2822 local time format. + + Args: + now (float, optional): The time to format. If None, the current time will be used. Defaults to None. + + Returns: + str: The time in RFC 2822 local time format. + """ if now is None: now = time.time() tzval = getTZval(now) @@ -84,6 +181,14 @@ def getRFC2822_Local(now=None): def extractRFC2822_Local(time_str): + """Extracts RFC 2822 local time string into seconds since the epoch. + + Args: + time_str (str): The RFC 2822 local time string. + + Returns: + int: The time in seconds since the epoch. + """ timestr = time_str[:-6] tzstr = time_str[-5:] tzval = (int(tzstr[:3]) * 60 + int(tzstr[3:])) * 60 @@ -91,6 +196,15 @@ def extractRFC2822_Local(time_str): def get_time_in_format(now=None, time_format=None): + """Returns the current time formatted according to the specified format. + + Args: + now (float, optional): The time to format. If None, the current time will be used. Defaults to None. + time_format (str, optional): The format string to use. If None, human-readable format is used. Defaults to None. + + Returns: + str: The formatted time string. + """ if now is None: now = time.time() if time_format is None: @@ -105,11 +219,15 @@ def get_time_in_format(now=None, time_format=None): ######################### -# time.daylight tells only if the computer support daylight saving time, -# tm_isdst must be checked to see if it is in effect at time t -# Some corner cases (changes in standard) are still uncovered, see https://bugs.python.org/issue1647654 -# See also https://bugs.python.org/issue7229 for an improved explanation of the Python manual wording def getTZval(t): + """Returns the timezone offset in seconds for the given time. + + Args: + t (float): The time in seconds since the epoch. + + Returns: + int: The timezone offset in seconds. 
+ """ if time.localtime(t).tm_isdst and time.daylight: return time.altzone else: diff --git a/lib/token_util.py b/lib/token_util.py index 627bcaf69..537335297 100644 --- a/lib/token_util.py +++ b/lib/token_util.py @@ -1,11 +1,16 @@ # SPDX-FileCopyrightText: 2009 Fermi Research Alliance, LLC # SPDX-License-Identifier: Apache-2.0 -# Description: -# This is a collection of utility functions for HTCondor IDTOKEN generation - """ -Collection of utility functions for HTCondor IDTOKEN generation and verification +Collection of utility functions for HTCondor IDTOKEN generation and verification. + +Functions: + token_file_expired: Checks if the token file has expired. + token_str_expired: Checks if the token string has expired. + simple_scramble: Performs a simple scramble (XOR) of HTCondor data. + derive_master_key: Derives an encryption/decryption key from a password. + sign_token: Assembles and signs an IDTOKEN. + create_and_sign_token: Creates an HTCSS IDTOKEN. """ import os @@ -27,15 +32,16 @@ def token_file_expired(token_file): """ - Check validity of token exp and nbf claim. - Do not check signature, audience, or other claims + Check the validity of token expiration (`exp`) and not-before (`nbf`) claims. + + This function does not check the token's signature, audience, or other claims. Args: - token_file(Path or str): a filename containing a jwt (a text file w/ default encoding is expected) + token_file (Path or str): A file containing a JWT (text file with default encoding expected). Returns: - bool: True if exp in future or absent and nbf in past or absent, - False otherwise + bool: True if `exp` is in the future or absent, and `nbf` is in the past or absent. + False otherwise. """ expired = True try: @@ -52,18 +58,19 @@ def token_file_expired(token_file): def token_str_expired(token_str): """ - Check validity of token exp and nbf claim. - Do not check signature, audience, or other claims + Check the validity of token expiration (`exp`) and not-before (`nbf`) claims. + + This function does not check the token's signature, audience, or other claims. Args: - token_str(str): string containing a jwt + token_str (str): String containing a JWT. Returns: - bool: True if exp in future or absent and nbf in past or absent, - False otherwise + bool: True if `exp` is in the future or absent, and `nbf` is in the past or absent. + False otherwise. """ if not token_str: - logSupport.log.debug("The token string is empty. Considering it expired") + logSupport.log.debug("The token string is empty. Considering it expired.") return True expired = True try: @@ -73,55 +80,42 @@ def token_str_expired(token_str): ) expired = False except jwt.exceptions.ExpiredSignatureError as e: - logSupport.log.error("Expired token: %s" % e) + logSupport.log.error(f"Expired token: {e}") except jwt.exceptions.DecodeError as e: - logSupport.log.error("Bad token: %s" % e) + logSupport.log.error(f"Bad token: {e}") logSupport.log.debug(f"Faulty token: {token_str}") except Exception as e: - logSupport.log.exception("Unknown exception decoding token: %s" % e) + logSupport.log.exception(f"Unknown exception decoding token: {e}") logSupport.log.debug(f"Faulty token: {token_str}") return expired def simple_scramble(in_buf): - """Undo the simple scramble of HTCondor - - simply XOR with 0xdeadbeef - Source: https://github.com/CoffeaTeam/coffea-casa/blob/master/charts/coffea-casa/files/hub-extra/auth.py + """Performs a simple scramble (XOR) on a binary string using HTCondor's algorithm. 
Args: - data(bytearray): binary string to be unscrambled + in_buf (bytearray): Binary string to be scrambled. Returns: - bytearray: an HTCondor scrambled binary string + bytearray: The scrambled binary string. """ DEADBEEF = (0xDE, 0xAD, 0xBE, 0xEF) out_buf = b"" for idx in range(len(in_buf)): - scramble = in_buf[idx] ^ DEADBEEF[idx % 4] # 4 = len(DEADBEEF) + scramble = in_buf[idx] ^ DEADBEEF[idx % 4] out_buf += b"%c" % scramble return out_buf def derive_master_key(password): - """Derive an encryption/decryption key - - Source: https://github.com/CoffeaTeam/coffea-casa/blob/master/charts/coffea-casa/files/hub-extra/auth.py + """Derives an encryption/decryption key from an unscrambled HTCondor password. Args: - password(bytes): an unscrambled HTCondor password (bytes-like: bytes, bytearray, memoryview) + password (bytes): An unscrambled HTCondor password (bytes-like: bytes, bytearray, memoryview). Returns: - bytes: an HTCondor encryption/decryption key + bytes: An HTCondor encryption/decryption key. """ - - # Key length, salt, and info are fixed as part of the protocol - # Here the types and meaning from cryptography.hazmat.primitives.kdf.hkdf: - # HKDF.__init__ - # Aalgorithm – An instance of HashAlgorithm. - # length(int) – key length in bytes - # salt(bytes) – To randomize - # info(bytes) – Application data hkdf = HKDF( algorithm=hashes.SHA256(), length=32, @@ -129,25 +123,23 @@ def derive_master_key(password): info=b"master jwt", backend=default_backend(), ) - # HKDF.derive() requires bytes and returns bytes return hkdf.derive(password) def sign_token(identity, issuer, kid, master_key, duration=None, scope=None): - """Assemble and sign an idtoken + """Assembles and signs an IDTOKEN. Args: - identity(str): who the token was generated for - issuer(str): idtoken issuer, typically HTCondor Collector - kid(str): Key ID - master_key(bytes): encryption key - duration(int, optional): number of seconds IDTOKEN is valid. Default: infinity - scope(str, optional): permissions IDTOKEN has. Default: everything + identity (str): The identity for which the token is generated. + issuer (str): The IDTOKEN issuer, typically an HTCondor Collector. + kid (str): The Key ID. + master_key (bytes): The encryption key. + duration (int, optional): Number of seconds the IDTOKEN is valid. Default is infinity. + scope (str, optional): Permissions the IDTOKEN grants. Default is everything. Returns: - str: a signed IDTOKEN (jwt token) + str: A signed IDTOKEN (JWT token). """ - iat = int(time.time()) payload = { "sub": identity, @@ -161,85 +153,50 @@ def sign_token(identity, issuer, kid, master_key, duration=None, scope=None): payload["exp"] = exp if scope: payload["scope"] = scope - # master_key should be `bytes`. `str` could cause value changes if was decoded not using utf-8. - # The manual (https://pyjwt.readthedocs.io/en/stable/api.html) is incorrect to list `str` only. - # The source code (https://github.com/jpadilla/pyjwt/blob/72ad55f6d7041ae698dc0790a690804118be50fc/jwt/api_jws.py) - # shows `AllowedPrivateKeys | str | bytes` and if it is str, then it is encoded w/ utf-8: value.encode("utf-8") + encoded = jwt.encode(payload, master_key, algorithm="HS256", headers={"kid": kid}) - # TODO: PyJWT bug workaround. 
Remove this conversion once affected PyJWT is no more around - # PyJWT in EL7 (PyJWT <2.0.0) has a bug, jwt.encode() is declaring str as return type, but it is returning bytes - # https://github.com/jpadilla/pyjwt/issues/391 + if isinstance(encoded, bytes): encoded = encoded.decode("UTF-8") return encoded def create_and_sign_token(pwd_file, issuer=None, identity=None, kid=None, duration=None, scope=None): - """Create an HTCSS IDTOKEN - - This should be compatible with the HTCSS code to create tokens. + """Creates and signs an HTCondor IDTOKEN. Args: - pwd_file: (str) file containing an HTCondor password - issuer: (str, optional) default is HTCondor TRUST_DOMAIN - identity: (str, optional) identity claim, default is $USERNAME@$HOSTNAME - kid: (str, optional) Key id, hint of signature used. - Default is file name of password - duration: (int, optional) number of seconds IDTOKEN is valid. - Default is infinity - scope: (str, optional) permissions IDTOKEN will have. - Default is everything, - example: condor:/READ condor:/WRITE condor:/ADVERTISE_STARTD + pwd_file (str): File containing an HTCondor password. + issuer (str, optional): The issuer of the token. Default is HTCondor TRUST_DOMAIN. + identity (str, optional): The identity claim. Default is $USERNAME@$HOSTNAME. + kid (str, optional): Key ID. Default is the file name of the password. + duration (int, optional): Number of seconds the IDTOKEN is valid. Default is infinity. + scope (str, optional): Permissions the IDTOKEN grants. Default is everything. Returns: - str: a signed HTCondor IDTOKEN + str: A signed HTCondor IDTOKEN. """ if not kid: kid = os.path.basename(pwd_file) if not issuer: - # As of Oct 2022 - # TRUST_DOMAIN is an opaque string to be taken as it is (Brian B.), but for tokens only the first collector - # is considered in the TRUST_DOMAIN (TJ, generate_token HTCSS code): - # std::string issuer; - # if (!param(issuer, "TRUST_DOMAIN")) { - # if (err) err->push("PASSWD", 1, "Issuer namespace is not set"); - # return false; - # } - # issuer = issuer.substr(0, issuer.find_first_of(", \t")); - # And Brian B. comment: "any comma, space, or tab character in the trust domain is treated as a separator. - # Hence, for purpose of finding the token, - # TRUST_DOMAIN=vocms0803.cern.ch:9618,cmssrv623.fnal.gov:9618 - # TRUST_DOMAIN=vocms0803.cern.ch:9618 - # TRUST_DOMAIN=vocms0803.cern.ch:9618,Some Random Text - # are all considered the same - vocms0803.cern.ch:9618." - full_issuer = iexe_cmd("condor_config_val TRUST_DOMAIN").strip() # Remove trailing spaces and newline + full_issuer = iexe_cmd("condor_config_val TRUST_DOMAIN").strip() if not full_issuer: logSupport.log.warning( - "Unable to retrieve TRUST_DOMAIN and no issuer provided: token will have empty 'iss'" + "Unable to retrieve TRUST_DOMAIN and no issuer provided: token will have empty 'iss'." 
) else: - # To set the issuer TRUST_DOMAIN is split no matter whether coming from COLLECTOR_HOST or not - # Using the same splitting as creation/web_base/setup_x509.sh - # is_default_trust_domain = "# at: " in iexe_cmd("condor_config_val -v TRUST_DOMAIN") - split_issuers = re.split(" |,|\t", full_issuer) # get only the first collector - # re.split(r":|\?", split_issuers[0]) would remove also synful string and port (to have the same tring for secondary collectors, but not needed) + split_issuers = re.split(" |,|\t", full_issuer) issuer = split_issuers[0] if not identity: identity = f"{os.getlogin()}@{socket.gethostname()}" with open(pwd_file, "rb") as fd: data = fd.read() password = simple_scramble(data) - # The POOL password requires a special handling - # Done in https://github.com/CoffeaTeam/coffea-casa/blob/master/charts/coffea-casa/files/hub-extra/auth.py#L252 if kid == "POOL": password += password master_key = derive_master_key(password) return sign_token(identity, issuer, kid, master_key, duration, scope) -# To test you need htcondor password file -# python3 token_util.py $HOSTNAME:9618 vofrontend_service@$HOSTNAME -# will output condor IDTOKEN to stdout - use condor_ping to the server to verify/validate if __name__ == "__main__": kid = sys.argv[1] issuer = sys.argv[2] @@ -250,5 +207,4 @@ def create_and_sign_token(pwd_file, issuer=None, identity=None, kid=None, durati master_key = derive_master_key(obfusicated) scope = "condor:/READ condor:/WRITE condor:/ADVERTISE_STARTD condor:/ADVERTISE_SCHEDD condor:/ADVERTISE_MASTER" idtoken = sign_token(identity, issuer, kid, master_key, scope=scope) - # idtoken is str print(idtoken)
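
For reference, a minimal usage sketch of the GlideinTar API documented in the lib/tarSupport.py hunks above. This is illustrative only: it assumes the usual `glideinwms.lib` import path, and the input file and archive paths are placeholders.

    import os
    import tempfile

    from glideinwms.lib.tarSupport import GlideinTar

    tar = GlideinTar()
    tar.add_file("/etc/hosts", "config")  # archived as "config/hosts"; raises FileDoesNotExist if missing
    tar.add_string("notes.txt", "created by GlideinTar\n")  # string data becomes a file in the tarball

    out_path = os.path.join(tempfile.gettempdir(), "example.tar.gz")
    tar.create_tar_file(out_path, compression="gz")  # write the tarball to disk
    blob = tar.create_tar_blob(compression="gz")  # or keep the same content in memory as bytes

    print(tar.is_tarfile(out_path), len(blob))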