Skip to content

Commit

Permalink
Transfers: use topology in poller and finisher
Browse files (browse the repository at this point in the history)
  • Loading branch information
Radu Carpa committed Nov 10, 2023
1 parent 4380d02 commit 666f0f0
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 22 deletions.
18 changes: 10 additions & 8 deletions lib/rucio/core/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,16 +621,17 @@ def fetch_paths(request_id, *, session: "Session"):
@METRICS.time_it
@transactional_session
def get_and_mark_next(
rse_collection: "RseCollection",
request_type,
state,
processed_by: Optional[str] = None,
processed_at_delay: int = 600,
limit=100,
older_than=None,
rse_id=None,
activity=None,
total_workers=0,
worker_number=0,
limit: int = 100,
older_than: "Optional[datetime.datetime]" = None,
rse_id: "Optional[str]" = None,
activity: "Optional[str]" = None,
total_workers: int = 0,
worker_number: int = 0,
mode_all=False,
hash_variable='id',
activity_shares=None,
Expand All @@ -643,6 +644,7 @@ def get_and_mark_next(
Retrieve the next requests matching the request type and state.
Workers are balanced via hashing to reduce concurrency on database.
:param rse_collection: the RSE collection being used
:param request_type: Type of the request as a string or list of strings.
:param state: State of the request as a string or list of strings.
:param processed_by: the daemon/executable running this query
Expand Down Expand Up @@ -748,8 +750,8 @@ def get_and_mark_next(

dst_id = res_dict['dest_rse_id']
src_id = res_dict['source_rse_id']
res_dict['dest_rse'] = get_rse_name(rse_id=dst_id, session=session) if dst_id is not None else None
res_dict['source_rse'] = get_rse_name(rse_id=src_id, session=session) if src_id is not None else None
res_dict['dst_rse'] = rse_collection[dst_id].ensure_loaded(load_name=True)
res_dict['src_rse'] = rse_collection[src_id].ensure_loaded(load_name=True) if src_id is not None else None

result.append(res_dict)
else:
Expand Down
35 changes: 21 additions & 14 deletions lib/rucio/daemons/conveyor/finisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@
from rucio.core import request as request_core, replica as replica_core
from rucio.core.monitor import MetricManager
from rucio.core.rse import list_rses
from rucio.core.transfer import ProtocolFactory
from rucio.core.topology import Topology, ExpiringObjectCache
from rucio.daemons.common import db_workqueue, ProducerConsumerDaemon
from rucio.db.sqla.constants import RequestState, RequestType, ReplicaState, BadFilesStatus
from rucio.db.sqla.session import transactional_session
from rucio.rse import rsemanager

if TYPE_CHECKING:
from rucio.daemons.common import HeartbeatHandler
Expand All @@ -60,15 +61,19 @@
def _fetch_requests(
db_bulk,
set_last_processed_by: bool,
cached_topology,
heartbeat_handler,
activity,
):
worker_number, total_workers, logger = heartbeat_handler.live()

logger(logging.DEBUG, 'Working on activity %s', activity)

topology = cached_topology.get() if cached_topology else Topology()

get_requests_fnc = functools.partial(
request_core.get_and_mark_next,
rse_collection=topology,
request_type=[RequestType.TRANSFER, RequestType.STAGEIN, RequestType.STAGEOUT],
processed_by=heartbeat_handler.short_executable if set_last_processed_by else None,
limit=db_bulk,
Expand Down Expand Up @@ -100,17 +105,18 @@ def _fetch_requests(
if len(reqs) < db_bulk / 2:
logger(logging.INFO, "Only %s transfers, which is less than half of the bulk %s", len(reqs), db_bulk)
must_sleep = True
return must_sleep, reqs
return must_sleep, (reqs, topology)


def _handle_requests(
reqs,
batch,
bulk,
suspicious_patterns,
retry_protocol_mismatches,
*,
logger=logging.log,
):
reqs, topology = batch
if not reqs:
return

Expand All @@ -122,7 +128,7 @@ def _handle_requests(
for chunk in chunks(reqs, bulk):
try:
stopwatch = Stopwatch()
_finish_requests(chunk, suspicious_patterns, retry_protocol_mismatches, logger=logger)
_finish_requests(topology, chunk, suspicious_patterns, retry_protocol_mismatches, logger=logger)
METRICS.timer('handle_requests_time').observe(stopwatch.elapsed / (len(chunk) or 1))
METRICS.counter('handle_requests').inc(len(chunk))
except Exception as error:
Expand All @@ -144,6 +150,7 @@ def finisher(
bulk=100,
db_bulk=1000,
partition_wait_time=10,
cached_topology=None,
total_threads=1,
):
"""
Expand Down Expand Up @@ -172,14 +179,15 @@ def finisher(
# Producer callback for a single activity: delegates to _fetch_requests and
# returns its result unchanged.
# db_bulk, cached_topology and once are not parameters of this function; they
# are taken from the enclosing scope.
def _db_producer(*, activity: str, heartbeat_handler: "HeartbeatHandler"):
return _fetch_requests(
db_bulk=db_bulk,
# May be None; _fetch_requests then builds a fresh Topology() itself.
cached_topology=cached_topology,
activity=activity,
# Evaluates to False when `once` is truthy, True otherwise.
set_last_processed_by=not once,
heartbeat_handler=heartbeat_handler,
)

def _consumer(reqs):
def _consumer(batch):
return _handle_requests(
reqs=reqs,
batch=batch,
bulk=bulk,
suspicious_patterns=suspicious_patterns,
retry_protocol_mismatches=retry_protocol_mismatches,
Expand Down Expand Up @@ -209,17 +217,19 @@ def run(once=False, total_threads=1, sleep_time=60, activities=None, bulk=100, d
if rucio.db.sqla.util.is_old_db():
raise DatabaseException('Database was not updated, daemon won\'t start')

cached_topology = ExpiringObjectCache(ttl=300, new_obj_fnc=lambda: Topology())
finisher(
once=once,
activities=activities,
bulk=bulk,
db_bulk=db_bulk,
sleep_time=sleep_time,
cached_topology=cached_topology,
total_threads=total_threads
)


def _finish_requests(reqs, suspicious_patterns, retry_protocol_mismatches, logger=logging.log):
def _finish_requests(topology: "Topology", reqs, suspicious_patterns, retry_protocol_mismatches, logger=logging.log):
"""
Used by finisher to handle terminated requests,
Expand All @@ -231,7 +241,7 @@ def _finish_requests(reqs, suspicious_patterns, retry_protocol_mismatches, logge
failed_during_submission = [RequestState.SUBMITTING, RequestState.SUBMISSION_FAILED, RequestState.LOST]
failed_no_submission_attempts = [RequestState.NO_SOURCES, RequestState.ONLY_TAPE_SOURCES, RequestState.MISMATCH_SCHEME]
undeterministic_rses = __get_undeterministic_rses(logger=logger)
rses_info, protocols = {}, {}
protocol_factory = ProtocolFactory()
replicas = {}
for req in reqs:
try:
Expand All @@ -252,14 +262,11 @@ def _finish_requests(reqs, suspicious_patterns, retry_protocol_mismatches, logge

# for TAPE, replica path is needed
if req['request_type'] in (RequestType.TRANSFER, RequestType.STAGEIN) and req['dest_rse_id'] in undeterministic_rses:
if req['dest_rse_id'] not in rses_info:
rses_info[req['dest_rse_id']] = rsemanager.get_rse_info(rse_id=req['dest_rse_id'])
dst_rse = topology[req['dest_rse_id']].ensure_loaded(load_info=True)
pfn = req['dest_url']
scheme = urlparse(pfn).scheme
dest_rse_id_scheme = '%s_%s' % (req['dest_rse_id'], scheme)
if dest_rse_id_scheme not in protocols:
protocols[dest_rse_id_scheme] = rsemanager.create_protocol(rses_info[req['dest_rse_id']], 'write', scheme)
path = protocols[dest_rse_id_scheme].parse_pfns([pfn])[pfn]['path']
protocol = protocol_factory.protocol(dst_rse, scheme, 'write')
path = protocol.parse_pfns([pfn])[pfn]['path']
replica['path'] = os.path.join(path, os.path.basename(pfn))

# replica should not be added to replicas until all info are filled
Expand Down
9 changes: 9 additions & 0 deletions lib/rucio/daemons/conveyor/poller.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from rucio.common.utils import dict_chunks
from rucio.core import transfer as transfer_core, request as request_core
from rucio.core.monitor import MetricManager
from rucio.core.topology import Topology, ExpiringObjectCache
from rucio.daemons.common import db_workqueue, ProducerConsumerDaemon
from rucio.db.sqla.constants import RequestState, RequestType
from rucio.transfertool.fts3 import FTS3Transfertool
Expand All @@ -63,13 +64,17 @@ def _fetch_requests(
activity_shares,
transfertool,
filter_transfertool,
cached_topology,
activity,
heartbeat_handler
):
worker_number, total_workers, logger = heartbeat_handler.live()

logger(logging.DEBUG, 'Start to poll transfers older than %i seconds for activity %s using transfer tool: %s' % (older_than, activity, filter_transfertool))

topology = cached_topology.get() if cached_topology else Topology()
transfs = request_core.get_and_mark_next(
rse_collection=topology,
request_type=[RequestType.TRANSFER, RequestType.STAGEIN, RequestType.STAGEOUT],
state=[RequestState.SUBMITTED],
limit=db_bulk,
Expand Down Expand Up @@ -151,6 +156,7 @@ def poller(
partition_wait_time: int = 10,
transfertool: Optional[str] = TRANSFER_TOOL,
filter_transfertool: Optional[str] = FILTER_TRANSFERTOOL,
cached_topology=None,
total_threads: int = 1,
):
"""
Expand Down Expand Up @@ -189,6 +195,7 @@ def _db_producer(*, activity: str, heartbeat_handler: "HeartbeatHandler"):
activity_shares=activity_shares,
transfertool=transfertool,
filter_transfertool=filter_transfertool,
cached_topology=cached_topology,
activity=activity,
heartbeat_handler=heartbeat_handler,
)
Expand Down Expand Up @@ -256,6 +263,7 @@ def run(
parsed_activity_shares.update((share, int(percentage * db_bulk)) for share, percentage in parsed_activity_shares.items())
logging.info('activity shares enabled: %s' % parsed_activity_shares)

cached_topology = ExpiringObjectCache(ttl=300, new_obj_fnc=lambda: Topology())
poller(
once=once,
fts_bulk=fts_bulk,
Expand All @@ -264,6 +272,7 @@ def run(
sleep_time=sleep_time,
activities=activities,
activity_shares=parsed_activity_shares,
cached_topology=cached_topology,
total_threads=total_threads,
)

Expand Down

0 comments on commit 666f0f0

Please sign in to comment.