diff --git a/lib/rucio/core/request.py b/lib/rucio/core/request.py
index e6584f7b8c..f3a61cf704 100644
--- a/lib/rucio/core/request.py
+++ b/lib/rucio/core/request.py
@@ -621,16 +621,17 @@ def fetch_paths(request_id, *, session: "Session"):
 @METRICS.time_it
 @transactional_session
 def get_and_mark_next(
+        rse_collection: "RseCollection",
         request_type,
         state,
         processed_by: Optional[str] = None,
         processed_at_delay: int = 600,
-        limit=100,
-        older_than=None,
-        rse_id=None,
-        activity=None,
-        total_workers=0,
-        worker_number=0,
+        limit: int = 100,
+        older_than: "Optional[datetime.datetime]" = None,
+        rse_id: "Optional[str]" = None,
+        activity: "Optional[str]" = None,
+        total_workers: int = 0,
+        worker_number: int = 0,
         mode_all=False,
         hash_variable='id',
         activity_shares=None,
@@ -643,6 +644,7 @@ def get_and_mark_next(
     Retrieve the next requests matching the request type and state.
     Workers are balanced via hashing to reduce concurrency on database.
 
+    :param rse_collection: the RSE collection being used
     :param request_type: Type of the request as a string or list of strings.
     :param state: State of the request as a string or list of strings.
     :param processed_by: the daemon/executable running this query
@@ -748,8 +750,8 @@ def get_and_mark_next(
 
                 dst_id = res_dict['dest_rse_id']
                 src_id = res_dict['source_rse_id']
-                res_dict['dest_rse'] = get_rse_name(rse_id=dst_id, session=session) if dst_id is not None else None
-                res_dict['source_rse'] = get_rse_name(rse_id=src_id, session=session) if src_id is not None else None
+                res_dict['dst_rse'] = rse_collection[dst_id].ensure_loaded(load_name=True)
+                res_dict['src_rse'] = rse_collection[src_id].ensure_loaded(load_name=True) if src_id is not None else None
 
                 result.append(res_dict)
             else:
diff --git a/lib/rucio/daemons/conveyor/finisher.py b/lib/rucio/daemons/conveyor/finisher.py
index 13d7018b86..1c847d44a1 100644
--- a/lib/rucio/daemons/conveyor/finisher.py
+++ b/lib/rucio/daemons/conveyor/finisher.py
@@ -41,10 +41,11 @@
 from rucio.core import request as request_core, replica as replica_core
 from rucio.core.monitor import MetricManager
 from rucio.core.rse import list_rses
+from rucio.core.transfer import ProtocolFactory
+from rucio.core.topology import Topology, ExpiringObjectCache
 from rucio.daemons.common import db_workqueue, ProducerConsumerDaemon
 from rucio.db.sqla.constants import RequestState, RequestType, ReplicaState, BadFilesStatus
 from rucio.db.sqla.session import transactional_session
-from rucio.rse import rsemanager
 
 if TYPE_CHECKING:
     from rucio.daemons.common import HeartbeatHandler
@@ -60,6 +61,7 @@
 def _fetch_requests(
         db_bulk,
         set_last_processed_by: bool,
+        cached_topology,
         heartbeat_handler,
         activity,
 ):
@@ -67,8 +69,11 @@ def _fetch_requests(
 
     logger(logging.DEBUG, 'Working on activity %s', activity)
 
+    topology = cached_topology.get() if cached_topology else Topology()
+
     get_requests_fnc = functools.partial(
         request_core.get_and_mark_next,
+        rse_collection=topology,
         request_type=[RequestType.TRANSFER, RequestType.STAGEIN, RequestType.STAGEOUT],
         processed_by=heartbeat_handler.short_executable if set_last_processed_by else None,
         limit=db_bulk,
@@ -100,17 +105,18 @@ def _fetch_requests(
 
     if len(reqs) < db_bulk / 2:
         logger(logging.INFO, "Only %s transfers, which is less than half of the bulk %s", len(reqs), db_bulk)
         must_sleep = True
-    return must_sleep, reqs
+    return must_sleep, (reqs, topology)
 
 
 def _handle_requests(
-        reqs,
+        batch,
         bulk,
         suspicious_patterns,
         retry_protocol_mismatches,
         *,
         logger=logging.log,
 ):
+    reqs, topology = batch
     if not reqs:
         return
@@ -122,7 +128,7 @@
     for chunk in chunks(reqs, bulk):
         try:
             stopwatch = Stopwatch()
-            _finish_requests(chunk, suspicious_patterns, retry_protocol_mismatches, logger=logger)
+            _finish_requests(topology, chunk, suspicious_patterns, retry_protocol_mismatches, logger=logger)
             METRICS.timer('handle_requests_time').observe(stopwatch.elapsed / (len(chunk) or 1))
             METRICS.counter('handle_requests').inc(len(chunk))
         except Exception as error:
@@ -144,6 +150,7 @@
         bulk=100,
         db_bulk=1000,
         partition_wait_time=10,
+        cached_topology=None,
         total_threads=1,
 ):
     """
@@ -172,14 +179,15 @@
     def _db_producer(*, activity: str, heartbeat_handler: "HeartbeatHandler"):
         return _fetch_requests(
             db_bulk=db_bulk,
+            cached_topology=cached_topology,
             activity=activity,
             set_last_processed_by=not once,
             heartbeat_handler=heartbeat_handler,
         )
 
-    def _consumer(reqs):
+    def _consumer(batch):
         return _handle_requests(
-            reqs=reqs,
+            batch=batch,
             bulk=bulk,
             suspicious_patterns=suspicious_patterns,
             retry_protocol_mismatches=retry_protocol_mismatches,
@@ -209,17 +217,19 @@ def run(once=False, total_threads=1, sleep_time=60, activities=None, bulk=100, d
 
     if rucio.db.sqla.util.is_old_db():
         raise DatabaseException('Database was not updated, daemon won\'t start')
 
+    cached_topology = ExpiringObjectCache(ttl=300, new_obj_fnc=lambda: Topology())
     finisher(
         once=once,
         activities=activities,
         bulk=bulk,
         db_bulk=db_bulk,
         sleep_time=sleep_time,
+        cached_topology=cached_topology,
         total_threads=total_threads
     )
 
 
-def _finish_requests(reqs, suspicious_patterns, retry_protocol_mismatches, logger=logging.log):
+def _finish_requests(topology: "Topology", reqs, suspicious_patterns, retry_protocol_mismatches, logger=logging.log):
     """
     Used by finisher to handle terminated requests,
@@ -231,7 +241,7 @@ def _finish_requests(reqs, suspicious_patterns, retry_protocol_mismatches, logge
     failed_during_submission = [RequestState.SUBMITTING, RequestState.SUBMISSION_FAILED, RequestState.LOST]
     failed_no_submission_attempts = [RequestState.NO_SOURCES, RequestState.ONLY_TAPE_SOURCES, RequestState.MISMATCH_SCHEME]
     undeterministic_rses = __get_undeterministic_rses(logger=logger)
-    rses_info, protocols = {}, {}
+    protocol_factory = ProtocolFactory()
     replicas = {}
     for req in reqs:
         try:
@@ -252,14 +262,11 @@
 
             # for TAPE, replica path is needed
             if req['request_type'] in (RequestType.TRANSFER, RequestType.STAGEIN) and req['dest_rse_id'] in undeterministic_rses:
-                if req['dest_rse_id'] not in rses_info:
-                    rses_info[req['dest_rse_id']] = rsemanager.get_rse_info(rse_id=req['dest_rse_id'])
+                dst_rse = topology[req['dest_rse_id']].ensure_loaded(load_info=True)
                 pfn = req['dest_url']
                 scheme = urlparse(pfn).scheme
-                dest_rse_id_scheme = '%s_%s' % (req['dest_rse_id'], scheme)
-                if dest_rse_id_scheme not in protocols:
-                    protocols[dest_rse_id_scheme] = rsemanager.create_protocol(rses_info[req['dest_rse_id']], 'write', scheme)
-                path = protocols[dest_rse_id_scheme].parse_pfns([pfn])[pfn]['path']
+                protocol = protocol_factory.protocol(dst_rse, scheme, 'write')
+                path = protocol.parse_pfns([pfn])[pfn]['path']
                 replica['path'] = os.path.join(path, os.path.basename(pfn))
 
             # replica should not be added to replicas until all info are filled
diff --git a/lib/rucio/daemons/conveyor/poller.py b/lib/rucio/daemons/conveyor/poller.py
index 5e139ac3af..2bab0760d5 100644
--- a/lib/rucio/daemons/conveyor/poller.py
+++ b/lib/rucio/daemons/conveyor/poller.py
@@ -40,6 +40,7 @@
 from rucio.common.utils import dict_chunks
 from rucio.core import transfer as transfer_core, request as request_core
 from rucio.core.monitor import MetricManager
+from rucio.core.topology import Topology, ExpiringObjectCache
 from rucio.daemons.common import db_workqueue, ProducerConsumerDaemon
 from rucio.db.sqla.constants import RequestState, RequestType
 from rucio.transfertool.fts3 import FTS3Transfertool
@@ -63,13 +64,17 @@ def _fetch_requests(
         activity_shares,
         transfertool,
         filter_transfertool,
+        cached_topology,
         activity,
         heartbeat_handler
 ):
     worker_number, total_workers, logger = heartbeat_handler.live()
     logger(logging.DEBUG, 'Start to poll transfers older than %i seconds for activity %s using transfer tool: %s' % (older_than, activity, filter_transfertool))
+
+    topology = cached_topology.get() if cached_topology else Topology()
     transfs = request_core.get_and_mark_next(
+        rse_collection=topology,
         request_type=[RequestType.TRANSFER, RequestType.STAGEIN, RequestType.STAGEOUT],
         state=[RequestState.SUBMITTED],
         limit=db_bulk,
@@ -151,6 +156,7 @@ def poller(
         partition_wait_time: int = 10,
         transfertool: Optional[str] = TRANSFER_TOOL,
         filter_transfertool: Optional[str] = FILTER_TRANSFERTOOL,
+        cached_topology=None,
         total_threads: int = 1,
 ):
     """
@@ -189,6 +195,7 @@ def _db_producer(*, activity: str, heartbeat_handler: "HeartbeatHandler"):
             activity_shares=activity_shares,
             transfertool=transfertool,
             filter_transfertool=filter_transfertool,
+            cached_topology=cached_topology,
             activity=activity,
             heartbeat_handler=heartbeat_handler,
         )
@@ -256,6 +263,7 @@
         parsed_activity_shares.update((share, int(percentage * db_bulk)) for share, percentage in parsed_activity_shares.items())
         logging.info('activity shares enabled: %s' % parsed_activity_shares)
 
+    cached_topology = ExpiringObjectCache(ttl=300, new_obj_fnc=lambda: Topology())
     poller(
         once=once,
         fts_bulk=fts_bulk,
         db_bulk=db_bulk,
         sleep_time=sleep_time,
         activities=activities,
         activity_shares=parsed_activity_shares,
+        cached_topology=cached_topology,
         total_threads=total_threads,
     )
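For reference, the calling pattern introduced by this patch can be summarised outside the daemons as follows. This is a minimal sketch, not part of the patch: the poll_once wrapper and its db_bulk default are illustrative only, and it assumes that ExpiringObjectCache.get() keeps returning the same Topology instance until the 300-second TTL expires, which is what the daemon code above relies on.

# Minimal sketch (illustrative, not part of the patch): reuse one cached
# Topology across polling iterations and pass it to get_and_mark_next().
from rucio.core import request as request_core
from rucio.core.topology import Topology, ExpiringObjectCache
from rucio.db.sqla.constants import RequestState, RequestType

# Assumption: the cache rebuilds the Topology at most once per ttl seconds and
# otherwise hands back the same object, so RSE data loaded via ensure_loaded()
# is shared between calls instead of being re-fetched for every request.
cached_topology = ExpiringObjectCache(ttl=300, new_obj_fnc=lambda: Topology())


def poll_once(db_bulk=1000):  # hypothetical helper, not a Rucio function
    topology = cached_topology.get()
    # get_and_mark_next() now resolves source/destination RSEs through the
    # passed collection (returned under the 'dst_rse'/'src_rse' keys) rather
    # than calling get_rse_name() once per request.
    return request_core.get_and_mark_next(
        rse_collection=topology,
        request_type=[RequestType.TRANSFER, RequestType.STAGEIN, RequestType.STAGEOUT],
        state=[RequestState.SUBMITTED],
        limit=db_bulk,
    )

The same cache object is shared between all threads of a daemon, which is why finisher() and poller() accept it as the cached_topology argument instead of each worker constructing its own Topology.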