diff --git a/experiment/measurer/datatypes.py b/experiment/measurer/datatypes.py
new file mode 100644
index 000000000..21415b336
--- /dev/null
+++ b/experiment/measurer/datatypes.py
@@ -0,0 +1,21 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Module for common data types shared under the measurer module."""
+import collections
+
+SnapshotMeasureRequest = collections.namedtuple(
+    'SnapshotMeasureRequest', ['fuzzer', 'benchmark', 'trial_id', 'cycle'])
+
+RetryRequest = collections.namedtuple(
+    'RetryRequest', ['fuzzer', 'benchmark', 'trial_id', 'cycle'])
diff --git a/experiment/measurer/measure_manager.py b/experiment/measurer/measure_manager.py
index f10e556c3..288148401 100644
--- a/experiment/measurer/measure_manager.py
+++ b/experiment/measurer/measure_manager.py
@@ -44,20 +44,20 @@ from database import models
 from experiment.build import build_utils
 from experiment.measurer import coverage_utils
+from experiment.measurer import measure_worker
 from experiment.measurer import run_coverage
 from experiment.measurer import run_crashes
 from experiment import scheduler
+import experiment.measurer.datatypes as measurer_datatypes
 
 logger = logs.Logger()
 
-SnapshotMeasureRequest = collections.namedtuple(
-    'SnapshotMeasureRequest', ['fuzzer', 'benchmark', 'trial_id', 'cycle'])
-
 NUM_RETRIES = 3
 RETRY_DELAY = 3
 FAIL_WAIT_SECONDS = 30
 SNAPSHOT_QUEUE_GET_TIMEOUT = 1
 SNAPSHOTS_BATCH_SAVE_SIZE = 100
+MEASUREMENT_LOOP_WAIT = 10
 
 
 def exists_in_experiment_filestore(path: pathlib.Path) -> bool:
@@ -75,10 +75,9 @@ def measure_main(experiment_config):
     experiment = experiment_config['experiment']
     max_total_time = experiment_config['max_total_time']
     measurers_cpus = experiment_config['measurers_cpus']
-    runners_cpus = experiment_config['runners_cpus']
     region_coverage = experiment_config['region_coverage']
-    measure_loop(experiment, max_total_time, measurers_cpus, runners_cpus,
-                 region_coverage)
+    measure_manager_loop(experiment, max_total_time, measurers_cpus,
+                         region_coverage)
 
     # Clean up resources.
     gc.collect()
@@ -104,18 +103,7 @@
     """Continuously measure trials for |experiment|."""
     logger.info('Start measure_loop.')
 
-    pool_args = ()
-    if measurers_cpus is not None and runners_cpus is not None:
-        local_experiment = experiment_utils.is_local_experiment()
-        if local_experiment:
-            cores_queue = multiprocessing.Queue()
-            logger.info('Scheduling measurers from core %d to %d.',
-                        runners_cpus, runners_cpus + measurers_cpus - 1)
-            for cpu in range(runners_cpus, runners_cpus + measurers_cpus):
-                cores_queue.put(cpu)
-            pool_args = (measurers_cpus, _process_init, (cores_queue,))
-        else:
-            pool_args = (measurers_cpus,)
+    pool_args = get_pool_args(measurers_cpus, runners_cpus)
 
     with multiprocessing.Pool(
             *pool_args) as pool, multiprocessing.Manager() as manager:
@@ -256,12 +244,13 @@ def _query_unmeasured_trials(experiment: str):
 
 
 def _get_unmeasured_first_snapshots(
-        experiment: str) -> List[SnapshotMeasureRequest]:
+        experiment: str) -> List[measurer_datatypes.SnapshotMeasureRequest]:
     """Returns a list of unmeasured SnapshotMeasureRequests that are the first
     snapshot for their trial. The trials are trials in |experiment|."""
     trials_without_snapshots = _query_unmeasured_trials(experiment)
     return [
-        SnapshotMeasureRequest(trial.fuzzer, trial.benchmark, trial.id, 0)
+        measurer_datatypes.SnapshotMeasureRequest(trial.fuzzer, trial.benchmark,
+                                                  trial.id, 0)
         for trial in trials_without_snapshots
     ]
 
@@ -289,7 +278,8 @@ def _query_measured_latest_snapshots(experiment: str):
 
 
 def _get_unmeasured_next_snapshots(
-        experiment: str, max_cycle: int) -> List[SnapshotMeasureRequest]:
+        experiment: str,
+        max_cycle: int) -> List[measurer_datatypes.SnapshotMeasureRequest]:
     """Returns a list of the latest unmeasured SnapshotMeasureRequests of
     trials in |experiment| that have been measured at least once in
     |experiment|. |max_total_time| is used to determine if a trial has another
@@ -305,16 +295,15 @@
         if next_cycle > max_cycle:
             continue
 
-        snapshot_with_cycle = SnapshotMeasureRequest(snapshot.fuzzer,
-                                                     snapshot.benchmark,
-                                                     snapshot.trial_id,
-                                                     next_cycle)
+        snapshot_with_cycle = measurer_datatypes.SnapshotMeasureRequest(
+            snapshot.fuzzer, snapshot.benchmark, snapshot.trial_id, next_cycle)
         next_snapshots.append(snapshot_with_cycle)
     return next_snapshots
 
 
-def get_unmeasured_snapshots(experiment: str,
-                             max_cycle: int) -> List[SnapshotMeasureRequest]:
+def get_unmeasured_snapshots(
+        experiment: str,
+        max_cycle: int) -> List[measurer_datatypes.SnapshotMeasureRequest]:
     """Returns a list of SnapshotMeasureRequests that need to be measured
     (assuming they have been saved already)."""
     # Measure the first snapshot of every started trial without any measured
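
The helpers above split snapshot selection into two cases: trials with no measured snapshot yet get a request for cycle 0, and trials measured at least once get a request for their next cycle, skipped once it would exceed `max_cycle`. For illustration only, here is a minimal, self-contained sketch of that selection rule with invented trial data, assuming the next cycle is simply the last measured cycle plus one (the real code derives all of this from the database queries above):

```python
import collections

# Mirrors measurer_datatypes.SnapshotMeasureRequest; the trial data is made up.
SnapshotMeasureRequest = collections.namedtuple(
    'SnapshotMeasureRequest', ['fuzzer', 'benchmark', 'trial_id', 'cycle'])

unmeasured_trials = [('afl', 'zlib', 1)]  # no snapshot measured yet
latest_measured_cycles = {
    ('afl', 'zlib', 2): 3,  # last measured at cycle 3
    ('afl', 'zlib', 3): 9,  # last measured at cycle 9
}
max_cycle = 5

# First snapshots: cycle 0 for every trial without any measurement.
requests = [
    SnapshotMeasureRequest(fuzzer, benchmark, trial_id, 0)
    for fuzzer, benchmark, trial_id in unmeasured_trials
]
# Next snapshots: last measured cycle plus one, capped at max_cycle.
for (fuzzer, benchmark, trial_id), cycle in latest_measured_cycles.items():
    next_cycle = cycle + 1
    if next_cycle > max_cycle:  # trial 3 is skipped: cycle 10 > 5
        continue
    requests.append(
        SnapshotMeasureRequest(fuzzer, benchmark, trial_id, next_cycle))

print(requests)  # cycle 0 for trial 1, cycle 4 for trial 2
```
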
@@ -683,6 +672,134 @@ def initialize_logs():
     })
 
 
+def consume_snapshots_from_response_queue(
+        response_queue, queued_snapshots) -> List[models.Snapshot]:
+    """Consumes the response queue, allows retry objects to be retried, and
+    returns all measured snapshots in a list."""
+    measured_snapshots = []
+    while True:
+        try:
+            response_object = response_queue.get_nowait()
+            if isinstance(response_object, measurer_datatypes.RetryRequest):
+                # Need to retry the measurement task, so remove the identifier
+                # from the set so it can be retried in the next loop iteration.
+                snapshot_identifier = (response_object.trial_id,
+                                       response_object.cycle)
+                queued_snapshots.remove(snapshot_identifier)
+                logger.info('Rescheduling task for trial %s and cycle %s',
+                            response_object.trial_id, response_object.cycle)
+            elif isinstance(response_object, models.Snapshot):
+                measured_snapshots.append(response_object)
+            else:
+                logger.error('Type of response object not mapped! %s',
+                             type(response_object))
+        except queue.Empty:
+            break
+    return measured_snapshots
+
+
+def measure_manager_inner_loop(experiment: str, max_cycle: int, request_queue,
+                               response_queue, queued_snapshots):
+    """Reads the database to determine which snapshots need measuring, writes
+    measurement tasks to the request queue, gets results from the response
+    queue, and writes measured snapshots to the database. Returns False if
+    there are no more snapshots left to be measured."""
+    initialize_logs()
+    # Read the database to determine which snapshots need measuring.
+    unmeasured_snapshots = get_unmeasured_snapshots(experiment, max_cycle)
+    logger.info('Retrieved %d unmeasured snapshots from measure manager',
+                len(unmeasured_snapshots))
+    # When there are no more snapshots left to be measured, break the loop.
+    if not unmeasured_snapshots:
+        return False
+
+    # Write measurement requests to the request queue.
+    for unmeasured_snapshot in unmeasured_snapshots:
+        # No need to insert fuzzer and benchmark info here as it's redundant
+        # (it can be retrieved through trial_id).
+        unmeasured_snapshot_identifier = (unmeasured_snapshot.trial_id,
+                                          unmeasured_snapshot.cycle)
+        # Check whether the snapshot was already queued so workers will not
+        # repeat the measurement for the same snapshot.
+        if unmeasured_snapshot_identifier not in queued_snapshots:
+            request_queue.put(unmeasured_snapshot)
+            queued_snapshots.add(unmeasured_snapshot_identifier)
+
+    # Read results from the response queue.
+    measured_snapshots = consume_snapshots_from_response_queue(
+        response_queue, queued_snapshots)
+    logger.info('Retrieved %d measured snapshots from response queue',
+                len(measured_snapshots))
+
+    # Save measured snapshots to the database.
+    if measured_snapshots:
+        db_utils.add_all(measured_snapshots)
+
+    return True
+
+
+def get_pool_args(measurers_cpus, runners_cpus):
+    """Return pool args based on measurer cpus and runner cpus arguments."""
+    if measurers_cpus is None or runners_cpus is None:
+        return ()
+
+    local_experiment = experiment_utils.is_local_experiment()
+    if not local_experiment:
+        return (measurers_cpus,)
+
+    cores_queue = multiprocessing.Queue()
+    logger.info('Scheduling measurers from core %d to %d.', runners_cpus,
+                runners_cpus + measurers_cpus - 1)
+    for cpu in range(runners_cpus, runners_cpus + measurers_cpus):
+        cores_queue.put(cpu)
+    return (measurers_cpus, _process_init, (cores_queue,))
+
+
+def measure_manager_loop(experiment: str,
+                         max_total_time: int,
+                         measurers_cpus=None,
+                         region_coverage=False):  # pylint: disable=too-many-locals
+    """Measure manager loop. Creates the request and response queues, requests
+    measurements from the workers, retrieves measurement results from the
+    response queue, and writes measured snapshots to the database."""
+    logger.info('Starting measure manager loop.')
+    if not measurers_cpus:
+        measurers_cpus = multiprocessing.cpu_count()
+        logger.info('Number of measurer CPUs not passed as argument, using %d',
+                    measurers_cpus)
+    with multiprocessing.Pool() as pool, multiprocessing.Manager() as manager:
+        logger.info('Setting up coverage binaries')
+        set_up_coverage_binaries(pool, experiment)
+        request_queue = manager.Queue()
+        response_queue = manager.Queue()
+
+        config = {
+            'request_queue': request_queue,
+            'response_queue': response_queue,
+            'region_coverage': region_coverage,
+        }
+        local_measure_worker = measure_worker.LocalMeasureWorker(config)
+
+        # Since each worker runs an infinite loop, we don't need the results
+        # returned here. The workers' lifetimes end automatically when there
+        # are no more snapshots left to measure.
+        logger.info('Starting measure worker loop for %d workers',
+                    measurers_cpus)
+        for _ in range(measurers_cpus):
+            _result = pool.apply_async(local_measure_worker.measure_worker_loop)
+
+        max_cycle = _time_to_cycle(max_total_time)
+        queued_snapshots = set()
+        while not scheduler.all_trials_ended(experiment):
+            continue_inner_loop = measure_manager_inner_loop(
+                experiment, max_cycle, request_queue, response_queue,
+                queued_snapshots)
+            if not continue_inner_loop:
+                break
+            time.sleep(MEASUREMENT_LOOP_WAIT)
+        logger.info('All trials ended. Ending measure manager loop')
+
+
 def main():
     """Measure the experiment."""
     initialize_logs()
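
The manager only ever talks to workers through the two queues: measure_manager_inner_loop en-queues SnapshotMeasureRequest objects keyed by (trial_id, cycle), and consume_snapshots_from_response_queue collects models.Snapshot results, while a RetryRequest drops its key from queued_snapshots so the next inner-loop iteration re-queues the task. A minimal, runnable sketch of that retry handshake, using stand-in namedtuples instead of the real database model:

```python
import collections
import queue

# Stand-ins for measurer_datatypes.RetryRequest and models.Snapshot.
RetryRequest = collections.namedtuple(
    'RetryRequest', ['fuzzer', 'benchmark', 'trial_id', 'cycle'])
Snapshot = collections.namedtuple('Snapshot', ['trial_id', 'cycle'])

response_queue = queue.Queue()
queued_snapshots = {(7, 2), (8, 1)}  # identifiers currently "in flight"

response_queue.put(Snapshot(trial_id=8, cycle=1))      # successful measurement
response_queue.put(RetryRequest('afl', 'zlib', 7, 2))  # failed, should be retried

measured = []
while True:
    try:
        obj = response_queue.get_nowait()
    except queue.Empty:
        break
    if isinstance(obj, RetryRequest):
        # Free the identifier so the next inner-loop iteration re-queues it.
        queued_snapshots.remove((obj.trial_id, obj.cycle))
    else:
        measured.append(obj)

assert measured == [Snapshot(trial_id=8, cycle=1)]
assert queued_snapshots == {(8, 1)}  # measured identifiers stay recorded
```
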
diff --git a/experiment/measurer/measure_worker.py b/experiment/measurer/measure_worker.py
new file mode 100644
index 000000000..cfa033d06
--- /dev/null
+++ b/experiment/measurer/measure_worker.py
@@ -0,0 +1,88 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Module for measurer workers logic."""
+import time
+from typing import Dict, Optional
+from common import logs
+from database.models import Snapshot
+import experiment.measurer.datatypes as measurer_datatypes
+from experiment.measurer import measure_manager
+
+MEASUREMENT_TIMEOUT = 1
+logger = logs.Logger()  # pylint: disable=invalid-name
+
+
+class BaseMeasureWorker:
+    """Base class for measure workers. Encapsulates the core methods that will
+    be implemented for local and Google Cloud measure workers."""
+
+    def __init__(self, config: Dict):
+        self.request_queue = config['request_queue']
+        self.response_queue = config['response_queue']
+        self.region_coverage = config['region_coverage']
+
+    def get_task_from_request_queue(self):
+        """Gets a task from the request queue."""
+        raise NotImplementedError
+
+    def put_result_in_response_queue(self, measured_snapshot, request):
+        """Saves the measurement result in the response queue, for the measure
+        manager to retrieve."""
+        raise NotImplementedError
+
+    def measure_worker_loop(self):
+        """Periodically retrieves a request from the request queue, measures
+        it, and puts the result in the response queue."""
+        logs.initialize(default_extras={
+            'component': 'measurer',
+            'subcomponent': 'worker',
+        })
+        logger.info('Starting one measure worker loop')
+        while True:
+            # 'SnapshotMeasureRequest', ['fuzzer', 'benchmark', 'trial_id',
+            # 'cycle']
+            request = self.get_task_from_request_queue()
+            logger.info(
+                'Measurer worker: Got request %s %s %d %d from request queue',
+                request.fuzzer, request.benchmark, request.trial_id,
+                request.cycle)
+            measured_snapshot = measure_manager.measure_snapshot_coverage(
+                request.fuzzer, request.benchmark, request.trial_id,
+                request.cycle, self.region_coverage)
+            self.put_result_in_response_queue(measured_snapshot, request)
+            time.sleep(MEASUREMENT_TIMEOUT)
+
+
+class LocalMeasureWorker(BaseMeasureWorker):
+    """Class that holds implementations of core methods for running a measure
+    worker locally."""
+
+    def get_task_from_request_queue(
+            self) -> measurer_datatypes.SnapshotMeasureRequest:
+        """Gets an item from the request multiprocessing queue, blocking if
+        necessary until an item is available."""
+        request = self.request_queue.get(block=True)
+        return request
+
+    def put_result_in_response_queue(
+            self, measured_snapshot: Optional[Snapshot],
+            request: measurer_datatypes.SnapshotMeasureRequest):
+        if measured_snapshot:
+            logger.info('Put measured snapshot in response_queue')
+            self.response_queue.put(measured_snapshot)
+        else:
+            retry_request = measurer_datatypes.RetryRequest(
+                request.fuzzer, request.benchmark, request.trial_id,
+                request.cycle)
+            self.response_queue.put(retry_request)
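
measure_worker_loop never returns: it blocks on the request queue, measures, answers, and sleeps MEASUREMENT_TIMEOUT seconds; measure_manager_loop just launches it with pool.apply_async and lets the workers die with the pool. A runnable toy of that round trip, using a daemon thread and a fake measure function in place of the process pool and measure_snapshot_coverage (both stand-ins are invented for illustration):

```python
import queue
import threading
from collections import namedtuple

# Stand-in for measurer_datatypes.SnapshotMeasureRequest.
SnapshotMeasureRequest = namedtuple('SnapshotMeasureRequest',
                                    ['fuzzer', 'benchmark', 'trial_id', 'cycle'])

request_queue: queue.Queue = queue.Queue()
response_queue: queue.Queue = queue.Queue()


def fake_measure(request):
    """Stand-in for measure_manager.measure_snapshot_coverage()."""
    return ('snapshot', request.trial_id, request.cycle)


def worker_loop():
    while True:
        # Same blocking get as LocalMeasureWorker.get_task_from_request_queue.
        request = request_queue.get(block=True)
        response_queue.put(fake_measure(request))


threading.Thread(target=worker_loop, daemon=True).start()  # dies with the main thread
request_queue.put(SnapshotMeasureRequest('afl', 'zlib', trial_id=1, cycle=0))
print(response_queue.get(timeout=5))  # ('snapshot', 1, 0)
```
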
diff --git a/experiment/measurer/test_measure_manager.py b/experiment/measurer/test_measure_manager.py
index 69e6400a6..7b6521869 100644
--- a/experiment/measurer/test_measure_manager.py
+++ b/experiment/measurer/test_measure_manager.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tests for measure_manager.py."""
-
 import os
 import shutil
 from unittest import mock
@@ -27,6 +26,7 @@
 from experiment.build import build_utils
 from experiment.measurer import measure_manager
 from test_libs import utils as test_utils
+import experiment.measurer.datatypes as measurer_datatypes
 
 TEST_DATA_PATH = os.path.join(os.path.dirname(__file__), 'test_data')
 
@@ -174,7 +174,7 @@ def test_measure_trial_coverage(mocked_measure_snapshot_coverage, mocked_queue,
     """Tests that measure_trial_coverage works as expected."""
     min_cycle = 1
     max_cycle = 10
-    measure_request = measure_manager.SnapshotMeasureRequest(
+    measure_request = measurer_datatypes.SnapshotMeasureRequest(
         FUZZER, BENCHMARK, TRIAL_NUM, min_cycle)
     measure_manager.measure_trial_coverage(measure_request, max_cycle,
                                            mocked_queue(), False)
@@ -409,3 +409,123 @@ def test_path_exists_in_experiment_filestore(mocked_execute, environ):
     mocked_execute.assert_called_with(
         ['gsutil', 'ls', 'gs://cloud-bucket/example-experiment'],
         expect_zero=False)
+
+
+def test_consume_unmapped_type_from_response_queue():
+    """Tests the scenario where an unmapped type is retrieved from the
+    response queue. This scenario is not expected to happen, so in this case
+    no snapshots are returned."""
+    # Use a normal queue here as a multiprocessing queue gives flaky tests.
+    response_queue = queue.Queue()
+    response_queue.put('unexpected string')
+    snapshots = measure_manager.consume_snapshots_from_response_queue(
+        response_queue, set())
+    assert not snapshots
+
+
+def test_consume_retry_type_from_response_queue():
+    """Tests the scenario where a retry object is retrieved from the
+    response queue. In this scenario, we want to remove the snapshot identifier
+    from the queued_snapshots set, as this allows the measurement task to be
+    retried in the future."""
+    # Use a normal queue here as a multiprocessing queue gives flaky tests.
+    response_queue = queue.Queue()
+    retry_request_object = measurer_datatypes.RetryRequest(
+        'fuzzer', 'benchmark', TRIAL_NUM, CYCLE)
+    snapshot_identifier = (TRIAL_NUM, CYCLE)
+    response_queue.put(retry_request_object)
+    queued_snapshots_set = set([snapshot_identifier])
+    snapshots = measure_manager.consume_snapshots_from_response_queue(
+        response_queue, queued_snapshots_set)
+    assert not snapshots
+    assert len(queued_snapshots_set) == 0
+
+
+def test_consume_snapshot_type_from_response_queue():
+    """Tests the scenario where a measured snapshot is retrieved from the
+    response queue. In this scenario, we want the snapshot to be returned by
+    the function."""
+    # Use a normal queue here as a multiprocessing queue gives flaky tests.
+    response_queue = queue.Queue()
+    snapshot_identifier = (TRIAL_NUM, CYCLE)
+    queued_snapshots_set = set([snapshot_identifier])
+    measured_snapshot = models.Snapshot(trial_id=TRIAL_NUM)
+    response_queue.put(measured_snapshot)
+    assert response_queue.qsize() == 1
+    snapshots = measure_manager.consume_snapshots_from_response_queue(
+        response_queue, queued_snapshots_set)
+    assert len(snapshots) == 1
+
+
+@mock.patch('experiment.measurer.measure_manager.get_unmeasured_snapshots')
+def test_measure_manager_inner_loop_break_condition(
+        mocked_get_unmeasured_snapshots):
+    """Tests that the measure manager inner loop returns False when there are
+    no more snapshots left to be measured."""
+    # An empty list means no more snapshots left to be measured.
+    mocked_get_unmeasured_snapshots.return_value = []
+    request_queue = queue.Queue()
+    response_queue = queue.Queue()
+    continue_inner_loop = measure_manager.measure_manager_inner_loop(
+        'experiment', 1, request_queue, response_queue, set())
+    assert not continue_inner_loop
+
+
+@mock.patch('experiment.measurer.measure_manager.get_unmeasured_snapshots')
+@mock.patch(
+    'experiment.measurer.measure_manager.consume_snapshots_from_response_queue')
+def test_measure_manager_inner_loop_writes_to_request_queue(
+        mocked_consume_snapshots_from_response_queue,
+        mocked_get_unmeasured_snapshots):
+    """Tests that the measure manager inner loop writes measurement tasks to
+    the request queue."""
+    mocked_get_unmeasured_snapshots.return_value = [
+        measurer_datatypes.SnapshotMeasureRequest('fuzzer', 'benchmark', 0, 0)
+    ]
+    mocked_consume_snapshots_from_response_queue.return_value = []
+    request_queue = queue.Queue()
+    response_queue = queue.Queue()
+    measure_manager.measure_manager_inner_loop('experiment', 1, request_queue,
+                                               response_queue, set())
+    assert request_queue.qsize() == 1
+
+
+@mock.patch('experiment.measurer.measure_manager.get_unmeasured_snapshots')
+@mock.patch(
+    'experiment.measurer.measure_manager.consume_snapshots_from_response_queue')
+@mock.patch('database.utils.add_all')
+def test_measure_manager_inner_loop_dont_write_to_db(
+        mocked_add_all, mocked_consume_snapshots_from_response_queue,
+        mocked_get_unmeasured_snapshots):
+    """Tests that the measure manager inner loop does not call add_all to
+    write to the database when there are no measured snapshots to be written."""
+    mocked_get_unmeasured_snapshots.return_value = [
+        measurer_datatypes.SnapshotMeasureRequest('fuzzer', 'benchmark', 0, 0)
+    ]
+    request_queue = queue.Queue()
+    response_queue = queue.Queue()
+    mocked_consume_snapshots_from_response_queue.return_value = []
+    measure_manager.measure_manager_inner_loop('experiment', 1, request_queue,
+                                               response_queue, set())
+    mocked_add_all.assert_not_called()
+
+
+@mock.patch('experiment.measurer.measure_manager.get_unmeasured_snapshots')
+@mock.patch(
+    'experiment.measurer.measure_manager.consume_snapshots_from_response_queue')
+@mock.patch('database.utils.add_all')
+def test_measure_manager_inner_loop_writes_to_db(
+        mocked_add_all, mocked_consume_snapshots_from_response_queue,
+        mocked_get_unmeasured_snapshots):
+    """Tests that the measure manager inner loop calls add_all to write to the
+    database when there are measured snapshots to be written."""
+    mocked_get_unmeasured_snapshots.return_value = [
+        measurer_datatypes.SnapshotMeasureRequest('fuzzer', 'benchmark', 0, 0)
+    ]
+    request_queue = queue.Queue()
+    response_queue = queue.Queue()
+    snapshot_model = models.Snapshot(trial_id=1)
+    mocked_consume_snapshots_from_response_queue.return_value = [snapshot_model]
+    measure_manager.measure_manager_inner_loop('experiment', 1, request_queue,
+                                               response_queue, set())
+    mocked_add_all.assert_called_with([snapshot_model])
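
The three consume_* tests above feed one object type through the queue at a time; for illustration, a combined sketch in the same style (same constants, imports, and plain queue.Queue assumed, not part of the change itself) exercises one pass over a mixed response queue:

```python
def test_consume_mixed_types_from_response_queue():
    """Sketch: one pass over a queue holding all three response types."""
    response_queue = queue.Queue()
    response_queue.put(models.Snapshot(trial_id=TRIAL_NUM))
    response_queue.put(
        measurer_datatypes.RetryRequest('fuzzer', 'benchmark', TRIAL_NUM, CYCLE))
    response_queue.put('unexpected string')
    queued_snapshots_set = set([(TRIAL_NUM, CYCLE)])
    snapshots = measure_manager.consume_snapshots_from_response_queue(
        response_queue, queued_snapshots_set)
    assert len(snapshots) == 1       # only the Snapshot is returned
    assert not queued_snapshots_set  # the RetryRequest freed its identifier
```
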
diff --git a/experiment/measurer/test_measure_worker.py b/experiment/measurer/test_measure_worker.py
new file mode 100644
index 000000000..4e5bd7b05
--- /dev/null
+++ b/experiment/measurer/test_measure_worker.py
@@ -0,0 +1,57 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for measure_worker.py."""
+import multiprocessing
+import pytest
+
+from database.models import Snapshot
+from experiment.measurer import measure_worker
+import experiment.measurer.datatypes as measurer_datatypes
+
+
+@pytest.fixture
+def local_measure_worker():
+    """Fixture for instantiating a local measure worker object."""
+    request_queue = multiprocessing.Queue()
+    response_queue = multiprocessing.Queue()
+    region_coverage = False
+    config = {
+        'request_queue': request_queue,
+        'response_queue': response_queue,
+        'region_coverage': region_coverage
+    }
+    return measure_worker.LocalMeasureWorker(config)
+
+
+def test_put_snapshot_in_response_queue(local_measure_worker):  # pylint: disable=redefined-outer-name
+    """Tests the scenario where measured_snapshot is not None, so the snapshot
+    is put in response_queue."""
+    request = measurer_datatypes.SnapshotMeasureRequest('fuzzer', 'benchmark',
+                                                        1, 0)
+    snapshot = Snapshot(trial_id=1)
+    local_measure_worker.put_result_in_response_queue(snapshot, request)
+    response_queue = local_measure_worker.response_queue
+    assert response_queue.qsize() == 1
+    assert isinstance(response_queue.get(), Snapshot)
+
+
+def test_put_retry_in_response_queue(local_measure_worker):  # pylint: disable=redefined-outer-name
+    """Tests the scenario where measured_snapshot is None, so the task needs
+    to be retried."""
+    request = measurer_datatypes.RetryRequest('fuzzer', 'benchmark', 1, 0)
+    snapshot = None
+    local_measure_worker.put_result_in_response_queue(snapshot, request)
+    response_queue = local_measure_worker.response_queue
+    assert response_queue.qsize() == 1
+    assert isinstance(response_queue.get(), measurer_datatypes.RetryRequest)
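
Both request and response objects cross a process boundary through the multiprocessing.Manager queues created in measure_manager_loop, so they must be picklable; keeping them as plain, module-level namedtuples in the shared datatypes module preserves that. A small round-trip sketch, assuming the repository layout introduced above is importable:

```python
import multiprocessing
import pickle

import experiment.measurer.datatypes as measurer_datatypes


def echo(request_queue, response_queue):
    """Runs in a child process: anything that arrives has survived pickling."""
    response_queue.put(request_queue.get())


if __name__ == '__main__':
    request = measurer_datatypes.SnapshotMeasureRequest('afl', 'zlib', 1, 0)
    # Module-level namedtuples pickle cleanly.
    assert pickle.loads(pickle.dumps(request)) == request
    with multiprocessing.Manager() as manager:
        request_queue, response_queue = manager.Queue(), manager.Queue()
        worker = multiprocessing.Process(target=echo,
                                         args=(request_queue, response_queue))
        worker.start()
        request_queue.put(request)
        print(response_queue.get(timeout=10))  # SnapshotMeasureRequest(fuzzer='afl', ...)
        worker.join()
```
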