Skip to content

Commit

Permalink
Update HTEX.Interchange to listen on a single interface (#2828)
Browse files Browse the repository at this point in the history
This PR updates the HTEX.Interchange to listen for connections from the managers only on a specific interface rather than the current default of binding to all interfaces. Binding to all interfaces is generally frowned upon on login nodes, when all we need is to allow connections from the internal network.

Here are the changes:

Pass HighThroughputExecutor.address to Interchange.interchange_address
Interchange will bind only to the interchange_address if it is specified instead of the binding to zmq:* or 0.0.0.0 which is the current default.
Adding tests for the Interchange
Please note that configs which specify HTEX(address="localhost") or similar where a non IPv4 address will now fail
  • Loading branch information
yadudoc authored Jul 19, 2023
1 parent 3fc9293 commit c39700b
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 60 deletions.
8 changes: 5 additions & 3 deletions parsl/executors/high_throughput/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,10 @@ class HighThroughputExecutor(BlockProviderExecutor, RepresentationMixin):
address : string
An address to connect to the main Parsl process which is reachable from the network in which
workers will be running. This can be either a hostname as returned by ``hostname`` or an
IP address. Most login nodes on clusters have several network interfaces available, only
some of which can be reached from the compute nodes.
workers will be running. This field expects an IPv4 address (xxx.xxx.xxx.xxx).
Most login nodes on clusters have several network interfaces available, only some of which
can be reached from the compute nodes. This field can be used to limit the executor to listen
only on a specific interface, and limiting connections to the internal network.
By default, the executor will attempt to enumerate and connect through all possible addresses.
Setting an address here overrides the default behavior.
default=None
Expand Down Expand Up @@ -470,6 +471,7 @@ def _start_local_interchange_process(self):
kwargs={"client_ports": (self.outgoing_q.port,
self.incoming_q.port,
self.command_client.port),
"interchange_address": self.address,
"worker_ports": self.worker_ports,
"worker_port_range": self.worker_port_range,
"hub_address": self.hub_address,
Expand Down
72 changes: 15 additions & 57 deletions parsl/executors/high_throughput/interchange.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python
import argparse
import zmq
import os
import sys
Expand All @@ -14,7 +13,7 @@
import threading
import json

from typing import cast, Any, Dict, Set
from typing import cast, Any, Dict, Set, Optional

from parsl.utils import setproctitle
from parsl.version import VERSION as PARSL_VERSION
Expand All @@ -29,6 +28,9 @@
HEARTBEAT_CODE = (2 ** 32) - 1
PKL_HEARTBEAT_CODE = pickle.dumps((2 ** 32) - 1)

LOGGER_NAME = "interchange"
logger = logging.getLogger(LOGGER_NAME)


class ManagerLost(Exception):
''' Task lost due to manager loss. Manager is considered lost when multiple heartbeats
Expand Down Expand Up @@ -66,7 +68,7 @@ class Interchange:
"""
def __init__(self,
client_address="127.0.0.1",
interchange_address="127.0.0.1",
interchange_address: Optional[str] = None,
client_ports=(50055, 50056, 50057),
worker_ports=None,
worker_port_range=(54000, 55000),
Expand All @@ -83,8 +85,9 @@ def __init__(self,
client_address : str
The ip address at which the parsl client can be reached. Default: "127.0.0.1"
interchange_address : str
The ip address at which the workers will be able to reach the Interchange. Default: "127.0.0.1"
interchange_address : Optional str
If specified the interchange will only listen on this address for connections from workers
else, it binds to all addresses.
client_ports : triple(int, int, int)
The ports at which the client can be reached
Expand Down Expand Up @@ -125,7 +128,7 @@ def __init__(self,
logger.debug("Initializing Interchange process")

self.client_address = client_address
self.interchange_address = interchange_address
self.interchange_address: str = interchange_address or "*"
self.poll_period = poll_period

logger.info("Attempting connection to client at {} on ports: {},{},{}".format(
Expand Down Expand Up @@ -160,14 +163,14 @@ def __init__(self,
self.worker_task_port = self.worker_ports[0]
self.worker_result_port = self.worker_ports[1]

self.task_outgoing.bind("tcp://*:{}".format(self.worker_task_port))
self.results_incoming.bind("tcp://*:{}".format(self.worker_result_port))
self.task_outgoing.bind(f"tcp://{self.interchange_address}:{self.worker_task_port}")
self.results_incoming.bind(f"tcp://{self.interchange_address}:{self.worker_result_port}")

else:
self.worker_task_port = self.task_outgoing.bind_to_random_port('tcp://*',
self.worker_task_port = self.task_outgoing.bind_to_random_port(f"tcp://{self.interchange_address}",
min_port=worker_port_range[0],
max_port=worker_port_range[1], max_tries=100)
self.worker_result_port = self.results_incoming.bind_to_random_port('tcp://*',
self.worker_result_port = self.results_incoming.bind_to_random_port(f"tcp://{self.interchange_address}",
min_port=worker_port_range[0],
max_port=worker_port_range[1], max_tries=100)

Expand Down Expand Up @@ -574,16 +577,14 @@ def expire_bad_managers(self, interesting_managers, hub_channel):
interesting_managers.remove(manager_id)


def start_file_logger(filename, name='interchange', level=logging.DEBUG, format_string=None):
def start_file_logger(filename, level=logging.DEBUG, format_string=None):
"""Add a stream log handler.
Parameters
---------
filename: string
Name of the file to write logs to. Required.
name: string
Logger name. Default="parsl.executors.interchange"
level: logging.LEVEL
Set the logging level. Default=logging.DEBUG
- format_string (string): Set the format string
Expand All @@ -598,7 +599,7 @@ def start_file_logger(filename, name='interchange', level=logging.DEBUG, format_
format_string = "%(asctime)s.%(msecs)03d %(name)s:%(lineno)d %(processName)s(%(process)d) %(threadName)s %(funcName)s [%(levelname)s] %(message)s"

global logger
logger = logging.getLogger(name)
logger = logging.getLogger(LOGGER_NAME)
logger.setLevel(level)
handler = logging.FileHandler(filename)
handler.setLevel(level)
Expand All @@ -619,46 +620,3 @@ def starter(comm_q, *args, **kwargs):
comm_q.put((ic.worker_task_port,
ic.worker_result_port))
ic.start()


if __name__ == '__main__':

parser = argparse.ArgumentParser()
parser.add_argument("-c", "--client_address",
help="Client address")
parser.add_argument("-l", "--logdir", default="parsl_worker_logs",
help="Parsl worker log directory")
parser.add_argument("-t", "--task_url",
help="REQUIRED: ZMQ url for receiving tasks")
parser.add_argument("-r", "--result_url",
help="REQUIRED: ZMQ url for posting results")
parser.add_argument("-p", "--poll_period",
help="REQUIRED: poll period used for main thread")
parser.add_argument("--worker_ports", default=None,
help="OPTIONAL, pair of workers ports to listen on, eg --worker_ports=50001,50005")
parser.add_argument("-d", "--debug", action='store_true',
help="Count of apps to launch")

args = parser.parse_args()

# Setup logging
global logger
format_string = "%(asctime)s %(name)s:%(lineno)d [%(levelname)s] %(message)s"

logger = logging.getLogger("interchange")
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler()
handler.setLevel('DEBUG' if args.debug is True else 'INFO')
formatter = logging.Formatter(format_string, datefmt='%Y-%m-%d %H:%M:%S')
handler.setFormatter(formatter)
logger.addHandler(handler)

logger.debug("Starting Interchange")

optionals = {}

if args.worker_ports:
optionals['worker_ports'] = [int(i) for i in args.worker_ports.split(',')]

ic = Interchange(**optionals)
ic.start()
Empty file.
46 changes: 46 additions & 0 deletions parsl/tests/test_htex/test_htex_zmq_binding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import logging

import psutil
import pytest
import zmq

from parsl.executors.high_throughput.interchange import Interchange


def test_interchange_binding_no_address():
ix = Interchange()
assert ix.interchange_address == "*"


def test_interchange_binding_with_address():
# Using loopback address
address = "127.0.0.1"
ix = Interchange(interchange_address=address)
assert ix.interchange_address == address


def test_interchange_binding_with_non_ipv4_address():
# Confirm that a ipv4 address is required
address = "localhost"
with pytest.raises(zmq.error.ZMQError):
Interchange(interchange_address=address)


def test_interchange_binding_bad_address():
""" Confirm that we raise a ZMQError when a bad address is supplied"""
address = "550.0.0.0"
with pytest.raises(zmq.error.ZMQError):
Interchange(interchange_address=address)


def test_limited_interface_binding():
""" When address is specified the worker_port would be bound to it rather than to 0.0.0.0"""
address = "127.0.0.1"
ix = Interchange(interchange_address=address)
ix.worker_result_port
proc = psutil.Process()
conns = proc.connections(kind="tcp")

matched_conns = [conn for conn in conns if conn.laddr.port == ix.worker_result_port]
assert len(matched_conns) == 1
assert matched_conns[0].laddr.ip == address

0 comments on commit c39700b

Please sign in to comment.