# Fix shutdown on Ctrl+C for Python source stages (#1839)
* Ensure that Python sources which run indefinitely actually shut down when requested by the pipeline (e.g. when the user hits Ctrl+C).
* The `KafkaSourceStage` already handled this properly; move the `stop` method from this class into `SingleOutputSource`.
* Add `request_stop` and `is_stop_requested` methods to `SingleOutputSource` (see the sketch below).
* Update all existing source stages which run indefinitely.
* Add a new `should_stop_fn` constructor argument to `DirectoryWatcher`; sources which use the watcher (`AutoencoderSourceStage` & `AppShieldSourceStage`) pass in their `is_stop_requested` method, allowing the watcher to know when to shut down.
* Move the RSS source generator method from the RSS source module into the `RSSController`, and update the `RSSSourceStage` to use the `RSSController` directly rather than the RSS module. This avoids the problem where modules require all config values to be JSON serializable, which prevents a callback function from being passed to a module.
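A minimal sketch of the pattern this change standardizes, assuming only the new `SingleOutputSource` API described above; `ExamplePollingSource` and its `_fetch_batch` helper are hypothetical, and the other members a real stage requires are omitted:

```python
import collections.abc
import time

from morpheus.messages import MessageMeta
from morpheus.pipeline.single_output_source import SingleOutputSource


class ExamplePollingSource(SingleOutputSource):
    """Hypothetical source that polls an external system until a stop is requested."""

    def source_generator(self) -> collections.abc.Iterator[MessageMeta]:
        # `is_stop_requested()` becomes True once the pipeline calls `stop()`
        # (e.g. when the user hits Ctrl+C), so this loop exits instead of
        # running forever.
        while not self.is_stop_requested():
            df = self._fetch_batch()  # hypothetical helper returning a DataFrame or None
            if df is not None:
                yield MessageMeta(df=df)
            else:
                time.sleep(0.1)  # yield the CPU while there is no data to emit
```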

Closes #1837

## By Submitting this PR I confirm:
- I am familiar with the [Contributing Guidelines](https://github.com/nv-morpheus/Morpheus/blob/main/docs/source/developer_guide/contributing.md).
- When the PR is ready for review, new or existing tests cover these changes.
- When the PR is ready for review, the documentation is up to date with these changes.

Authors:
  - David Gardner (https://github.com/dagardner-nv)

Approvers:
  - Michael Demoret (https://github.com/mdemoret-nv)

URL: #1839
dagardner-nv authored Aug 14, 2024
1 parent f33776e commit 5fdbcb9
Showing 17 changed files with 170 additions and 153 deletions.
25 changes: 8 additions & 17 deletions docs/source/developer_guide/guides/2_real_world_phishing.md
````diff
@@ -761,20 +761,20 @@ def _build_source(self, builder: mrc.Builder) -> mrc.SegmentObject:
     return builder.make_source(self.unique_name, self.source_generator)
 ```
 
-The `source_generator` method is where most of the RabbitMQ-specific code exists. When we have a message that we wish to emit into the pipeline, we simply `yield` it.
+The `source_generator` method is where most of the RabbitMQ-specific code exists. When we have a message that we wish to emit into the pipeline, we simply `yield` it. We continue this process until the `is_stop_requested()` method returns `True`.
 
 ```python
 def source_generator(self) -> collections.abc.Iterator[MessageMeta]:
     try:
-        while not self._stop_requested:
-            (method_frame, header_frame, body) = self._channel.basic_get(self._queue_name)
+        while not self.is_stop_requested():
+            (method_frame, _, body) = self._channel.basic_get(self._queue_name)
             if method_frame is not None:
                 try:
                     buffer = StringIO(body.decode("utf-8"))
                     df = cudf.io.read_json(buffer, orient='records', lines=True)
                     yield MessageMeta(df=df)
                 except Exception as ex:
-                    logger.exception("Error occurred converting RabbitMQ message to Dataframe: {}".format(ex))
+                    logger.exception("Error occurred converting RabbitMQ message to Dataframe: %s", ex)
                 finally:
                     self._channel.basic_ack(method_frame.delivery_tag)
             else:
@@ -824,11 +824,11 @@ class RabbitMQSourceStage(PreallocatorMixin, SingleOutputSource):
         Hostname or IP of the RabbitMQ server.
     exchange : str
         Name of the RabbitMQ exchange to connect to.
-    exchange_type : str
+    exchange_type : str, optional
         RabbitMQ exchange type; defaults to `fanout`.
-    queue_name : str
+    queue_name : str, optional
         Name of the queue to listen to. If left blank, RabbitMQ will generate a random queue name bound to the exchange.
-    poll_interval : str
+    poll_interval : str, optional
         Amount of time between polling RabbitMQ for new messages
     """
@@ -854,9 +854,6 @@ class RabbitMQSourceStage(PreallocatorMixin, SingleOutputSource):
 
         self._poll_interval = pd.Timedelta(poll_interval)
 
-        # Flag to indicate whether or not we should stop
-        self._stop_requested = False
-
     @property
     def name(self) -> str:
         return "from-rabbitmq"
@@ -867,18 +864,12 @@ class RabbitMQSourceStage(PreallocatorMixin, SingleOutputSource):
     def compute_schema(self, schema: StageSchema):
        schema.output_schema.set_type(MessageMeta)
 
-    def stop(self):
-        # Indicate we need to stop
-        self._stop_requested = True
-
-        return super().stop()
-
     def _build_source(self, builder: mrc.Builder) -> mrc.SegmentObject:
         return builder.make_source(self.unique_name, self.source_generator)
 
     def source_generator(self) -> collections.abc.Iterator[MessageMeta]:
         try:
-            while not self._stop_requested:
+            while not self.is_stop_requested():
                 (method_frame, _, body) = self._channel.basic_get(self._queue_name)
                 if method_frame is not None:
                     try:
````
5 changes: 1 addition & 4 deletions docs/source/developer_guide/guides/4_source_cpp_stage.md
````diff
@@ -493,13 +493,10 @@ def __init__(self,
     self._exchange_type = exchange_type
     self._queue_name = queue_name
 
-    self._connection = None
+    self._connection: pika.BlockingConnection = None
     self._channel = None
 
    self._poll_interval = pd.Timedelta(poll_interval)
-
-    # Flag to indicate whether or not we should stop
-    self._stop_requested = False
 ```
 ```python
 def connect(self):
````
11 changes: 1 addition & 10 deletions examples/developer_guide/2_2_rabbitmq/rabbitmq_source_stage.py
````diff
@@ -77,9 +77,6 @@ def __init__(self,
 
         self._poll_interval = pd.Timedelta(poll_interval)
 
-        # Flag to indicate whether or not we should stop
-        self._stop_requested = False
-
     @property
     def name(self) -> str:
         return "from-rabbitmq"
@@ -90,18 +87,12 @@ def supports_cpp_node(self) -> bool:
     def compute_schema(self, schema: StageSchema):
         schema.output_schema.set_type(MessageMeta)
 
-    def stop(self):
-        # Indicate we need to stop
-        self._stop_requested = True
-
-        return super().stop()
-
     def _build_source(self, builder: mrc.Builder) -> mrc.SegmentObject:
         return builder.make_source(self.unique_name, self.source_generator)
 
     def source_generator(self) -> collections.abc.Iterator[MessageMeta]:
         try:
-            while not self._stop_requested:
+            while not self.is_stop_requested():
                 (method_frame, _, body) = self._channel.basic_get(self._queue_name)
                 if method_frame is not None:
                     try:
````
````diff
@@ -67,14 +67,11 @@ def __init__(self,
         self._exchange_type = exchange_type
         self._queue_name = queue_name
 
-        self._connection = None
+        self._connection: pika.BlockingConnection = None
         self._channel = None
 
         self._poll_interval = pd.Timedelta(poll_interval)
 
-        # Flag to indicate whether or not we should stop
-        self._stop_requested = False
-
     @property
     def name(self) -> str:
         return "from-rabbitmq"
@@ -117,7 +114,7 @@ def connect(self):
 
     def source_generator(self):
         try:
-            while not self._stop_requested:
+            while not self.is_stop_requested():
                 (method_frame, _, body) = self._channel.basic_get(self._queue_name)
                 if method_frame is not None:
                     try:
````
````diff
@@ -128,7 +128,7 @@ def _polling_generate_frames_fsspec(self) -> typing.Iterable[fsspec.core.OpenFiles]:
         curr_time = time.monotonic()
         next_update_epoch = curr_time
 
-        while (True):
+        while (not self.is_stop_requested()):
             # Before doing any work, find the next update epoch after the current time
             while (next_update_epoch <= curr_time):
                 # Only ever add `self._watch_interval` to next_update_epoch so all updates are at repeating intervals
````
71 changes: 70 additions & 1 deletion python/morpheus/morpheus/controllers/rss_controller.py
````diff
@@ -15,15 +15,21 @@
 import logging
 import os
 import time
+from collections.abc import Callable
+from collections.abc import Iterable
 from dataclasses import asdict
 from dataclasses import dataclass
 from datetime import datetime
+from datetime import timedelta
 from urllib.parse import urlparse
 
 import requests
 import requests_cache
 
 import cudf
 
+from morpheus.messages import MessageMeta
+
 logger = logging.getLogger(__name__)
 
 IMPORT_EXCEPTION = None
@@ -72,6 +78,12 @@ class RSSController:
         Request timeout in secs to fetch the feed.
     strip_markup : bool, optional, default = False
         When true, strip HTML & XML markup from the content, summary and title fields.
+    stop_after : int, optional, default = 0
+        Stops ingesting after emitting `stop_after` records (rows in the dataframe). Useful for testing. Disabled if `0`.
+    interval_secs : float, optional, default = 600
+        Interval in seconds between fetching new feed items.
+    should_stop_fn : Callable[[], bool], optional
+        Function that returns a boolean indicating whether the controller should stop processing.
     """
@@ -89,7 +101,10 @@ def __init__(self,
                  cache_dir: str = "./.cache/http",
                  cooldown_interval: int = 600,
                  request_timeout: float = 2.0,
-                 strip_markup: bool = False):
+                 strip_markup: bool = False,
+                 stop_after: int = 0,
+                 interval_secs: float = 600,
+                 should_stop_fn: Callable[[], bool] = None):
         if IMPORT_EXCEPTION is not None:
             raise ImportError(IMPORT_ERROR_MESSAGE) from IMPORT_EXCEPTION
 
@@ -104,6 +119,11 @@
         self._request_timeout = request_timeout
         self._strip_markup = strip_markup
 
+        if should_stop_fn is None:
+            self._should_stop_fn = lambda: False
+        else:
+            self._should_stop_fn = should_stop_fn
+
         # Validate feed_input
         for f in self._feed_input:
             if not RSSController.is_url(f) and not os.path.exists(f):
@@ -113,7 +133,14 @@
         # If feed_input is a URL, runs indefinitely
         run_indefinitely = any(RSSController.is_url(f) for f in self._feed_input)
 
+        if (stop_after > 0 and run_indefinitely):
+            raise ValueError("Cannot set both `stop_after` and `run_indefinitely` to True.")
+
+        self._stop_after = stop_after
         self._run_indefinitely = run_indefinitely
+        self._interval_secs = interval_secs
+        self._interval_td = timedelta(seconds=self._interval_secs)
+
         self._enable_cache = enable_cache
 
         if enable_cache:
@@ -381,3 +408,45 @@ def is_url(cls, feed_input: str) -> bool:
             return parsed_url.scheme != '' and parsed_url.netloc != ''
         except Exception:
             return False
+
+    def feed_generator(self) -> Iterable[MessageMeta]:
+        """
+        Fetch RSS feed entries and yield them as MessageMeta objects.
+        """
+        stop_requested = False
+        records_emitted = 0
+
+        while (not stop_requested and not self._should_stop_fn()):
+            try:
+                for df in self.fetch_dataframes():
+                    df_size = len(df)
+
+                    if logger.isEnabledFor(logging.DEBUG):
+                        logger.info("Received %d new entries...", df_size)
+                        logger.info("Emitted %d records so far.", records_emitted)
+
+                    yield MessageMeta(df=df)
+
+                    records_emitted += df_size
+
+                    if (0 < self._stop_after <= records_emitted):
+                        stop_requested = True
+                        logger.info("Stop limit reached... preparing to halt the source.")
+                        break
+
+            except Exception as exc:
+                if not self.run_indefinitely:
+                    logger.error("Failed either in the process of fetching or processing entries: %s.", exc)
+                    raise
+                logger.error("Failed either in the process of fetching or processing entries: %s.", exc)
+
+            if not self.run_indefinitely:
+                stop_requested = True
+                continue
+
+            logger.info("Waiting for %d seconds before fetching again...", self._interval_secs)
+            sleep_until = datetime.now() + self._interval_td
+            while (datetime.now() < sleep_until and not self._should_stop_fn()):
+                time.sleep(1)
+
+        logger.info("RSS source exhausted, stopping.")
````
51 changes: 5 additions & 46 deletions python/morpheus/morpheus/modules/input/rss_source.py
````diff
@@ -13,13 +13,11 @@
 # limitations under the License.
 
 import logging
-import time
 
 import mrc
 from pydantic import ValidationError
 
 from morpheus.controllers.rss_controller import RSSController
-from morpheus.messages import MessageMeta
 from morpheus.modules.schemas.rss_source_schema import RSSSourceSchema
 from morpheus.utils.module_utils import ModuleLoaderFactory
 from morpheus.utils.module_utils import register_module
@@ -57,6 +55,7 @@ def _rss_source(builder: mrc.Builder):
 
     module_config = builder.get_current_module_config()
     rss_config = module_config.get("rss_source", {})
+
     try:
         validated_config = RSSSourceSchema(**rss_config)
     except ValidationError as e:
@@ -74,50 +73,10 @@ def _rss_source(builder: mrc.Builder):
                               cache_dir=validated_config.cache_dir,
                               cooldown_interval=validated_config.cooldown_interval_sec,
                               request_timeout=validated_config.request_timeout_sec,
-                              strip_markup=validated_config.strip_markup)
-
-    stop_requested = False
-
-    def fetch_feeds() -> MessageMeta:
-        """
-        Fetch RSS feed entries and yield as MessageMeta object.
-        """
-        nonlocal stop_requested
-        records_emitted = 0
-
-        while (not stop_requested):
-            try:
-                for df in controller.fetch_dataframes():
-                    df_size = len(df)
-
-                    if logger.isEnabledFor(logging.DEBUG):
-                        logger.info("Received %d new entries...", df_size)
-                        logger.info("Emitted %d records so far.", records_emitted)
-
-                    yield MessageMeta(df=df)
-
-                    records_emitted += df_size
-
-                    if (0 < validated_config.stop_after_rec <= records_emitted):
-                        stop_requested = True
-                        logger.info("Stop limit reached... preparing to halt the source.")
-                        break
-
-            except Exception as exc:
-                if not controller.run_indefinitely:
-                    logger.error("Failed either in the process of fetching or processing entries: %s.", exc)
-                    raise
-                logger.error("Failed either in the process of fetching or processing entries: %s.", exc)
-
-            if not controller.run_indefinitely:
-                stop_requested = True
-                continue
-
-            logger.info("Waiting for %d seconds before fetching again...", validated_config.interval_sec)
-            time.sleep(validated_config.interval_sec)
-
-        logger.info("RSS source exhausted, stopping.")
+                              strip_markup=validated_config.strip_markup,
+                              stop_after=validated_config.stop_after_rec,
+                              interval_secs=validated_config.interval_sec)
 
-    node = builder.make_source("fetch_feeds", fetch_feeds)
+    node = builder.make_source("fetch_feeds", controller.feed_generator)
 
     builder.register_module_output("output", node)
````
32 changes: 32 additions & 0 deletions python/morpheus/morpheus/pipeline/single_output_source.py
````diff
@@ -40,6 +40,10 @@ def __init__(self, c: Config):
 
         self._create_ports(0, 1)
 
+        # Flag to indicate if we need to stop, subclasses should check this value periodically, typically at the start
+        # of a polling loop
+        self._stop_requested = False
+
     # pylint: disable=unused-argument
     def _post_build_single(self, builder: mrc.Builder, out_node: mrc.SegmentObject) -> mrc.SegmentObject:
         return out_node
@@ -74,3 +78,31 @@ def _post_build(self, builder: mrc.Builder, out_ports_nodes: list[mrc.SegmentObject]):
         logger.info("Added source: %s\n  └─> %s", self, pretty_print_type_name(self.output_ports[0].output_type))
 
         return [ret_val]
+
+    def stop(self):
+        """
+        This method is invoked by the pipeline whenever there is an unexpected shutdown.
+        Subclasses should override this method to perform any necessary cleanup operations.
+        """
+
+        # Indicate we need to stop
+        self.request_stop()
+
+        return super().stop()
+
+    def request_stop(self):
+        """
+        Request the source to stop processing data.
+        """
+        self._stop_requested = True
+
+    def is_stop_requested(self) -> bool:
+        """
+        Returns `True` if a stop has been requested.
+
+        Returns
+        -------
+        bool:
+            True if a stop has been requested, False otherwise.
+        """
+        return self._stop_requested
````
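As the `stop` docstring above notes, subclasses can override `stop` for their own cleanup; they should delegate to `super().stop()` so the stop flag is still set. A sketch under that assumption (`MyPollingSource` and its `_connection` attribute are hypothetical):

```python
from morpheus.pipeline.single_output_source import SingleOutputSource


class MyPollingSource(SingleOutputSource):
    """Hypothetical source holding an external connection."""

    def stop(self):
        # Close the hypothetical external connection first...
        if self._connection is not None:
            self._connection.close()

        # ...then delegate to the base class, which calls `request_stop()` so
        # `is_stop_requested()` returns True inside the source generator.
        return super().stop()
```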
````diff
@@ -108,7 +108,8 @@ def __init__(self,
                                         sort_glob=sort_glob,
                                         recursive=recursive,
                                         queue_max_size=queue_max_size,
-                                        batch_timeout=batch_timeout)
+                                        batch_timeout=batch_timeout,
+                                        should_stop_fn=self.is_stop_requested)
 
     @property
     def name(self) -> str:
````