diff --git a/conf/default/cuckoo.conf.default b/conf/default/cuckoo.conf.default index f8e26b023cb..4ea6cac1dd4 100644 --- a/conf/default/cuckoo.conf.default +++ b/conf/default/cuckoo.conf.default @@ -32,11 +32,6 @@ machinery_screenshots = off scaling_semaphore = off # A configurable wait time between updating the limit value of the scaling bounded semaphore scaling_semaphore_update_timer = 10 -# Allow more than one task scheduled to be assigned at once for better scaling -# A switch to allow batch task assignment, a method that can more efficiently assign tasks to available machines -batch_scheduling = off -# The maximum number of tasks assigned to machines per batch, optimal value dependent on deployment -max_batch_count = 20 # Enable creation of memory dump of the analysis machine before shutting # down. Even if turned off, this functionality can also be enabled at diff --git a/cuckoo.py b/cuckoo.py index d649805ca2c..5ecc4017d6b 100644 --- a/cuckoo.py +++ b/cuckoo.py @@ -9,6 +9,8 @@ import sys from pathlib import Path +from lib.cuckoo.core.database import Database, init_database + if sys.version_info[:2] < (3, 8): sys.exit("You are running an incompatible version of Python, please use >= 3.8") @@ -56,6 +58,7 @@ def cuckoo_init(quiet=False, debug=False, artwork=False, test=False): check_working_directory() check_configs() create_structure() + init_database() if artwork: import time @@ -78,7 +81,8 @@ def cuckoo_init(quiet=False, debug=False, artwork=False, test=False): check_webgui_mongo() init_modules() - init_tasks() + with Database().session.begin(): + init_tasks() init_rooter() init_routing() check_tcpdump_permissions() @@ -113,7 +117,7 @@ def cuckoo_main(max_analysis_count=0): parser.add_argument("-v", "--version", action="version", version="You are running Cuckoo Sandbox {0}".format(CUCKOO_VERSION)) parser.add_argument("-a", "--artwork", help="Show artwork", action="store_true", required=False) parser.add_argument("-t", "--test", help="Test startup", action="store_true", required=False) - parser.add_argument("-m", "--max-analysis-count", help="Maximum number of analyses", type=int, required=False) + parser.add_argument("-m", "--max-analysis-count", help="Maximum number of analyses", type=int, required=False, default=0) parser.add_argument( "-s", "--stop", diff --git a/lib/cuckoo/common/abstracts.py b/lib/cuckoo/common/abstracts.py index 0e3a4dee93a..0638473fe10 100644 --- a/lib/cuckoo/common/abstracts.py +++ b/lib/cuckoo/common/abstracts.py @@ -4,7 +4,6 @@ # See the file 'docs/LICENSE' for copying permission. import datetime -import inspect import io import logging import os @@ -38,7 +37,7 @@ from lib.cuckoo.common.path_utils import path_exists, path_mkdir from lib.cuckoo.common.url_validate import url as url_validator from lib.cuckoo.common.utils import create_folder, get_memdump_path, load_categories -from lib.cuckoo.core.database import Database +from lib.cuckoo.core.database import Database, Machine, _Database try: import re2 as re @@ -107,42 +106,44 @@ class Machinery: # Default label used in machinery configuration file to supply virtual # machine name/label/vmx path. Override it if you dubbed it in another # way. - LABEL = "label" + LABEL: str = "label" + + # This must be defined in sub-classes. + module_name: str def __init__(self): - self.module_name = "" self.options = None # Database pointer. - self.db = Database() - # Machine table is cleaned to be filled from configuration file - # at each start. - self.db.clean_machines() + self.db: _Database = Database() + self.set_options(self.read_config()) + + def read_config(self) -> None: + return Config(self.module_name) - def set_options(self, options: dict): + def set_options(self, options: dict) -> None: """Set machine manager options. @param options: machine manager options dict. """ self.options = options + mmanager_opts = self.options.get(self.module_name) + if not isinstance(mmanager_opts["machines"], list): + mmanager_opts["machines"] = str(mmanager_opts["machines"]).strip().split(",") + + def initialize(self) -> None: + """Read, load, and verify machines configuration.""" + # Machine table is cleaned to be filled from configuration file + # at each start. + self.db.clean_machines() - def initialize(self, module_name): - """Read, load, and verify machines configuration. - @param module_name: module name. - """ # Load. - self._initialize(module_name) + self._initialize() # Run initialization checks. self._initialize_check() - def _initialize(self, module_name): - """Read configuration. - @param module_name: module name. - """ - self.module_name = module_name - mmanager_opts = self.options.get(module_name) - if not isinstance(mmanager_opts["machines"], list): - mmanager_opts["machines"] = str(mmanager_opts["machines"]).strip().split(",") - + def _initialize(self) -> None: + """Read configuration.""" + mmanager_opts = self.options.get(self.module_name) for machine_id in mmanager_opts["machines"]: try: machine_opts = self.options.get(machine_id.strip()) @@ -198,7 +199,7 @@ def _initialize(self, module_name): log.warning("Configuration details about machine %s are missing: %s", machine_id.strip(), e) continue - def _initialize_check(self): + def _initialize_check(self) -> None: """Runs checks against virtualization software when a machine manager is initialized. @note: in machine manager modules you may override or superclass his method. @raise CuckooMachineError: if a misconfiguration or a unkown vm state is found. @@ -208,20 +209,24 @@ def _initialize_check(self): except NotImplementedError: return + self.shutdown_running_machines(configured_vms) + self.check_screenshot_support() + + if not cfg.timeouts.vm_state: + raise CuckooCriticalError("Virtual machine state change timeout setting not found, please add it to the config file") + + def check_screenshot_support(self) -> None: # If machinery_screenshots are enabled, check the machinery supports it. - if cfg.cuckoo.machinery_screenshots: - # inspect function members available on the machinery class - cls_members = inspect.getmembers(self.__class__, predicate=inspect.isfunction) - for name, function in cls_members: - if name != Machinery.screenshot.__name__: - continue - if Machinery.screenshot == function: - msg = f"machinery {self.module_name} does not support machinery screenshots" - raise CuckooCriticalError(msg) - break - else: - raise NotImplementedError(f"missing machinery method: {Machinery.screenshot.__name__}") + if not cfg.cuckoo.machinery_screenshots: + return + + # inspect function members available on the machinery class + func = getattr(self.__class__, "screenshot") + if func == Machinery.screenshot: + msg = f"machinery {self.module_name} does not support machinery screenshots" + raise CuckooCriticalError(msg) + def shutdown_running_machines(self, configured_vms: List[str]) -> None: for machine in self.machines(): # If this machine is already in the "correct" state, then we # go on to the next machine. @@ -236,16 +241,13 @@ def _initialize_check(self): msg = f"Please update your configuration. Unable to shut '{machine.label}' down or find the machine in its proper state: {e}" raise CuckooCriticalError(msg) from e - if not cfg.timeouts.vm_state: - raise CuckooCriticalError("Virtual machine state change timeout setting not found, please add it to the config file") - def machines(self): """List virtual machines. @return: virtual machines list """ return self.db.list_machines(include_reserved=True) - def availables(self, label=None, platform=None, tags=None, arch=None, include_reserved=False, os_version=[]): + def availables(self, label=None, platform=None, tags=None, arch=None, include_reserved=False, os_version=None): """How many (relevant) machines are free. @param label: machine ID. @param platform: machine platform. @@ -257,32 +259,15 @@ def availables(self, label=None, platform=None, tags=None, arch=None, include_re label=label, platform=platform, tags=tags, arch=arch, include_reserved=include_reserved, os_version=os_version ) - def acquire(self, machine_id=None, platform=None, tags=None, arch=None, os_version=[], need_scheduled=False): - """Acquire a machine to start analysis. - @param machine_id: machine ID. - @param platform: machine platform. - @param tags: machine tags - @param arch: machine arch - @param os_version: tags to filter per OS version. Ex: winxp, win7, win10, win11 - @param need_scheduled: should the result be filtered on 'scheduled' machine status - @return: machine or None. - """ - if machine_id: - return self.db.lock_machine(label=machine_id, need_scheduled=need_scheduled) - elif platform: - return self.db.lock_machine( - platform=platform, tags=tags, arch=arch, os_version=os_version, need_scheduled=need_scheduled - ) - return self.db.lock_machine(tags=tags, arch=arch, os_version=os_version, need_scheduled=need_scheduled) - - def get_machines_scheduled(self): - return self.db.get_machines_scheduled() + def scale_pool(self, machine: Machine) -> None: + """This can be overridden in sub-classes to scale the pool of machines once one has been acquired.""" + return - def release(self, label=None): + def release(self, machine: Machine) -> Machine: """Release a machine. @param label: machine name. """ - self.db.unlock_machine(label) + return self.db.unlock_machine(machine) def running(self): """Returns running virtual machines. @@ -290,6 +275,9 @@ def running(self): """ return self.db.list_machines(locked=True) + def running_count(self): + return self.db.count_machines_running() + def screenshot(self, label, path): """Screenshot a running virtual machine. @param label: machine name @@ -302,9 +290,10 @@ def shutdown(self): """Shutdown the machine manager. Kills all alive machines. @raise CuckooMachineError: if unable to stop machine. """ - if len(self.running()) > 0: - log.info("Still %d guests still alive, shutting down...", len(self.running())) - for machine in self.running(): + running = self.running() + if len(running) > 0: + log.info("Still %d guests still alive, shutting down...", len(running)) + for machine in running: try: self.stop(machine.label) except CuckooMachineError as e: @@ -389,23 +378,12 @@ class LibVirtMachinery(Machinery): ABORTED = "abort" def __init__(self): - - if not categories_need_VM: - return - if not HAVE_LIBVIRT: raise CuckooDependencyError( "Unable to import libvirt. Ensure that you properly installed it by running: cd /opt/CAPEv2/ ; sudo -u cape poetry run extra/libvirt_installer.sh" ) - super(LibVirtMachinery, self).__init__() - - def initialize(self, module): - """Initialize machine manager module. Override default to set proper - connection string. - @param module: machine manager module - """ - super(LibVirtMachinery, self).initialize(module) + super().__init__() def _initialize_check(self): """Runs all checks when a machine manager is initialized. @@ -420,7 +398,7 @@ def _initialize_check(self): # Base checks. Also attempts to shutdown any machines which are # currently still active. - super(LibVirtMachinery, self)._initialize_check() + super()._initialize_check() def start(self, label): """Starts a virtual machine. @@ -429,14 +407,17 @@ def start(self, label): """ log.debug("Starting machine %s", label) + vm_info = self.db.view_machine_by_label(label) + if vm_info is None: + msg = f"Unable to find machine with label {label} in database." + raise CuckooMachineError(msg) + if self._status(label) != self.POWEROFF: msg = f"Trying to start a virtual machine that has not been turned off {label}" raise CuckooMachineError(msg) conn = self._connect(label) - vm_info = self.db.view_machine_by_label(label) - snapshot_list = self.vms[label].snapshotListNames(flags=0) # If a snapshot is configured try to use it. diff --git a/lib/cuckoo/common/config.py b/lib/cuckoo/common/config.py index a39f0f908aa..d60a46d0871 100644 --- a/lib/cuckoo/common/config.py +++ b/lib/cuckoo/common/config.py @@ -28,6 +28,13 @@ def parse_options(options: str) -> Dict[str, str]: class _BaseConfig: """Configuration file parser.""" + def __init__(self): + self.fullconfig = None + self.refresh() + + def _get_files_to_read(self): + raise NotImplementedError + def get(self, section): """Get options for the given section. @param section: section to fetch. @@ -42,6 +49,9 @@ def get(self, section): def get_config(self): return self.fullconfig + def refresh(self): + self.fullconfig = self._read_files(self._get_files_to_read())._sections + def _read_files(self, files: Iterable[str]): # Escape the percent signs so that ConfigParser doesn't try to do # interpolation of the value as well. @@ -60,8 +70,6 @@ def _read_files(self, files: Iterable[str]): ) raise - self.fullconfig = config._sections - for section in config.sections(): dct = Dictionary() for name, value in config.items(section): @@ -83,6 +91,8 @@ def _read_files(self, files: Iterable[str]): setattr(dct, name, value) setattr(self, section, dct) + return config + class ConfigMeta(type): """Only create one instance of a Config for each (non-analysis) config file.""" @@ -96,31 +106,35 @@ def __call__(self, fname_base: str = "cuckoo"): return self.configs[fname_base] @classmethod - def reset(cls): + def refresh(cls): """This should really only be needed for testing.""" - cls.configs.clear() + for config in cls.configs.values(): + config.refresh() class Config(_BaseConfig, metaclass=ConfigMeta): def __init__(self, fname_base: str = "cuckoo"): - files = self._get_files_to_read(fname_base) - self._read_files(files) + self._fname_base = fname_base + super().__init__() - def _get_files_to_read(self, fname_base): + def _get_files_to_read(self): # Allows test workflows to ignore custom root configs include_root_configs = "CAPE_DISABLE_ROOT_CONFIGS" not in os.environ - files = [os.path.join(CUCKOO_ROOT, "conf", "default", f"{fname_base}.conf.default")] + files = [os.path.join(CUCKOO_ROOT, "conf", "default", f"{self._fname_base}.conf.default")] if include_root_configs: - files.append(os.path.join(CUCKOO_ROOT, "conf", f"{fname_base}.conf")) - files.extend(sorted(glob.glob(os.path.join(CUCKOO_ROOT, "conf", f"{fname_base}.conf.d", "*.conf")))) - files.append(os.path.join(CUSTOM_CONF_DIR, f"{fname_base}.conf")) - files.extend(sorted(glob.glob(os.path.join(CUSTOM_CONF_DIR, f"{fname_base}.conf.d", "*.conf")))) + files.append(os.path.join(CUCKOO_ROOT, "conf", f"{self._fname_base}.conf")) + files.extend(sorted(glob.glob(os.path.join(CUCKOO_ROOT, "conf", f"{self._fname_base}.conf.d", "*.conf")))) + files.append(os.path.join(CUSTOM_CONF_DIR, f"{self._fname_base}.conf")) + files.extend(sorted(glob.glob(os.path.join(CUSTOM_CONF_DIR, f"{self._fname_base}.conf.d", "*.conf")))) if not files: - raise CuckooCriticalError(f"No {fname_base} config files could be found!") + raise CuckooCriticalError(f"No {self._fname_base} config files could be found!") return files class AnalysisConfig(_BaseConfig): def __init__(self, cfg="analysis.conf"): - files = (cfg,) - self._read_files(files) + self._cfg = cfg + super().__init__() + + def _get_files_to_read(self): + return (self._cfg,) diff --git a/lib/cuckoo/common/exceptions.py b/lib/cuckoo/common/exceptions.py index 57986b682bb..1e04be88912 100644 --- a/lib/cuckoo/common/exceptions.py +++ b/lib/cuckoo/common/exceptions.py @@ -15,6 +15,11 @@ class CuckooStartupError(CuckooCriticalError): pass +class CuckooDatabaseInitializationError(CuckooCriticalError): + def __str__(self): + return "The database has not been initialized yet. You must call init_database before attempting to use it." + + class CuckooDatabaseError(CuckooCriticalError): """Cuckoo database error.""" @@ -33,6 +38,12 @@ class CuckooOperationalError(Exception): pass +class CuckooUnserviceableTaskError(CuckooOperationalError): + """There are no machines in the pool that can service the task.""" + + pass + + class CuckooMachineError(CuckooOperationalError): """Error managing analysis machine.""" diff --git a/lib/cuckoo/common/utils.py b/lib/cuckoo/common/utils.py index a3545ccded8..4945719e886 100644 --- a/lib/cuckoo/common/utils.py +++ b/lib/cuckoo/common/utils.py @@ -5,9 +5,7 @@ import contextlib import errno import fcntl -import inspect import logging -import multiprocessing import os import random import shutil @@ -16,13 +14,12 @@ import struct import sys import tempfile -import threading import time import xmlrpc.client import zipfile from datetime import datetime from io import BytesIO -from typing import Tuple, Union +from typing import Final, List, Tuple, Union from data.family_detection_names import family_detection_names from lib.cuckoo.common import utils_dicts @@ -92,10 +89,12 @@ def arg_name_clscontext(arg_val): sanitize_len = config.cuckoo.get("sanitize_len", 32) sanitize_to_len = config.cuckoo.get("sanitize_to_len", 24) +CATEGORIES_NEEDING_VM: Final[Tuple[str]] = ("file", "url") -def load_categories(): + +def load_categories() -> Tuple[List[str], bool]: analyzing_categories = [category.strip() for category in config.cuckoo.categories.split(",")] - needs_VM = any([category in analyzing_categories for category in ("file", "url")]) + needs_VM = any(category in analyzing_categories for category in CATEGORIES_NEEDING_VM) return analyzing_categories, needs_VM @@ -126,9 +125,7 @@ def make_bytes(value: Union[str, bytes], encoding: str = "latin-1") -> bytes: def is_text_file(file_info, destination_folder, buf, file_data=False): - if any(file_type in file_info.get("type", "") for file_type in texttypes): - extracted_path = os.path.join( destination_folder, file_info.get( @@ -855,38 +852,6 @@ def default_converter(v): return v -def classlock(f): - """Classlock decorator (created for database.Database). - Used to put a lock to avoid sqlite errors. - """ - - def inner(self, *args, **kwargs): - curframe = inspect.currentframe() - calframe = inspect.getouterframes(curframe, 2) - - if calframe[1][1].endswith("database.py"): - return f(self, *args, **kwargs) - - with self._lock: - return f(self, *args, **kwargs) - - return inner - - -class SuperLock: - def __init__(self): - self.tlock = threading.Lock() - self.mlock = multiprocessing.Lock() - - def __enter__(self): - self.tlock.acquire() - self.mlock.acquire() - - def __exit__(self, type, value, traceback): - self.mlock.release() - self.tlock.release() - - def get_options(optstring: str): """Get analysis options. @return: options dict. diff --git a/lib/cuckoo/common/web_utils.py b/lib/cuckoo/common/web_utils.py index 0d3b042c4fc..a5376024cf3 100644 --- a/lib/cuckoo/common/web_utils.py +++ b/lib/cuckoo/common/web_utils.py @@ -5,12 +5,14 @@ import os import sys import tempfile +import threading import time from collections import OrderedDict from contextlib import suppress from datetime import datetime, timedelta from pathlib import Path from random import choice +from typing import Dict, List, Optional import magic import requests @@ -177,45 +179,55 @@ def my_rate_minutes(group, request): return rpm -def load_vms_exits(): - all_exits = {} - if HAVE_DIST and dist_conf.distributed.enabled: - try: - db = dist_session() - for node in db.query(Node).all(): - if hasattr(node, "exitnodes"): - for exit in node.exitnodes: - all_exits.setdefault(exit.name, []).append(node.name) - db.close() - except Exception as e: - print(e) +_all_nodes_exits: Optional[Dict[str, List[str]]] = None +_load_vms_exits_lock = threading.Lock() - return all_exits +def load_vms_exits(force=False): + global _all_nodes_exits + with _load_vms_exits_lock: + if _all_nodes_exits is not None and not force: + return _all_nodes_exits + _all_nodes_exits = {} + if HAVE_DIST and dist_conf.distributed.enabled: + try: + db = dist_session() + for node in db.query(Node).all(): + if hasattr(node, "exitnodes"): + for exit in node.exitnodes: + _all_nodes_exits.setdefault(exit.name, []).append(node.name) + db.close() + except Exception as e: + print(e) -def load_vms_tags(): - all_tags = [] - if HAVE_DIST and dist_conf.distributed.enabled: - try: - db = dist_session() - for vm in db.query(Machine).all(): - all_tags += vm.tags - all_tags = sorted(filter(None, all_tags)) - db.close() - except Exception as e: - print(e) + return _all_nodes_exits - for machine in Database().list_machines(): - all_tags += [tag.name for tag in machine.tags if tag not in all_tags] - return list(set(all_tags)) +_all_vms_tags: Optional[List[str]] = None +_load_vms_tags_lock = threading.Lock() -all_nodes_exits = load_vms_exits() -all_nodes_exits_list = list(all_nodes_exits.keys()) +def load_vms_tags(force=False): + global _all_vms_tags + with _load_vms_tags_lock: + if _all_vms_tags is not None and not force: + return _all_vms_tags + all_tags = [] + if HAVE_DIST and dist_conf.distributed.enabled: + try: + db = dist_session() + for vm in db.query(Machine).all(): + all_tags += vm.tags + all_tags = sorted(filter(None, all_tags)) + db.close() + except Exception as e: + print(e) + + for machine in Database().list_machines(include_reserved=True): + all_tags += [tag.name for tag in machine.tags if tag not in all_tags] -all_vms_tags = load_vms_tags() -all_vms_tags_str = ",".join(all_vms_tags) + _all_vms_tags = list(sorted(set(all_tags))) + return _all_vms_tags def top_asn(date_since: datetime = False, results_limit: int = 20) -> dict: @@ -401,12 +413,14 @@ def statistics(s_days: int) -> dict: details[module_name.split(".")[-1]].setdefault(name, entry) top_samples = {} - session = db.Session() added_tasks = ( - session.query(Task).join(Sample, Task.sample_id == Sample.id).filter(Task.added_on.between(date_since, date_till)).all() + db.session.query(Task).join(Sample, Task.sample_id == Sample.id).filter(Task.added_on.between(date_since, date_till)).all() ) tasks = ( - session.query(Task).join(Sample, Task.sample_id == Sample.id).filter(Task.completed_on.between(date_since, date_till)).all() + db.session.query(Task) + .join(Sample, Task.sample_id == Sample.id) + .filter(Task.completed_on.between(date_since, date_till)) + .all() ) details["total"] = len(tasks) details["average"] = f"{round(details['total'] / s_days, 2):.2f}" @@ -475,7 +489,6 @@ def statistics(s_days: int) -> dict: details["detections"] = top_detections(date_since=date_since) details["asns"] = top_asn(date_since=date_since) - session.close() return details @@ -676,9 +689,10 @@ def download_file(**kwargs): route = socks5s_random if tags: + all_vms_tags = load_vms_tags() if not all([tag.strip() in all_vms_tags for tag in tags.split(",")]): return "error", { - "error": f"Check Tags help, you have introduced incorrect tag(s). Your tags: {tags} - Supported tags: {all_vms_tags_str}" + "error": f"Check Tags help, you have introduced incorrect tag(s). Your tags: {tags} - Supported tags: {','.join(all_vms_tags)}" } elif all([tag in tags for tag in ("x64", "x86")]): return "error", {"error": "Check Tags help, you have introduced x86 and x64 tags for the same task, choose only 1"} @@ -722,13 +736,14 @@ def download_file(**kwargs): return "error", {"error": f"Error writing {kwargs['service']} storing/download file to temporary path"} # Distribute task based on route support by worker - if route and route not in ("none", "None") and all_nodes_exits_list: + all_nodes_exits = load_vms_exits() + if route and route not in ("none", "None") and all_nodes_exits: parsed_options = get_options(kwargs["options"]) node = parsed_options.get("node") if node and node not in all_nodes_exits.get(route): return "error", {"error": f"Specified worker {node} doesn't support this route: {route}"} - elif route not in all_nodes_exits_list: + elif route not in all_nodes_exits: return "error", {"error": "Specified route doesn't exist on any worker"} if not node: diff --git a/lib/cuckoo/core/analysis_manager.py b/lib/cuckoo/core/analysis_manager.py new file mode 100644 index 00000000000..d7c951b57b5 --- /dev/null +++ b/lib/cuckoo/core/analysis_manager.py @@ -0,0 +1,673 @@ +import contextlib +import functools +import logging +import os +import queue +import shutil +import threading +from typing import Any, Callable, Generator, MutableMapping, Optional, Tuple + +from lib.cuckoo.common.config import Config +from lib.cuckoo.common.constants import CUCKOO_ROOT +from lib.cuckoo.common.exceptions import ( + CuckooCriticalError, + CuckooGuestCriticalTimeout, + CuckooGuestError, + CuckooMachineError, + CuckooOperationalError, +) +from lib.cuckoo.common.integrations.parse_pe import PortableExecutable +from lib.cuckoo.common.objects import File +from lib.cuckoo.common.path_utils import path_delete, path_exists, path_mkdir +from lib.cuckoo.common.utils import convert_to_printable, create_folder, free_space_monitor, get_memdump_path +from lib.cuckoo.core.database import TASK_COMPLETED, TASK_PENDING, TASK_RUNNING, Database, Guest, Machine, Task, _Database +from lib.cuckoo.core.guest import GuestManager +from lib.cuckoo.core.machinery_manager import MachineryManager +from lib.cuckoo.core.plugins import RunAuxiliary +from lib.cuckoo.core.resultserver import ResultServer +from lib.cuckoo.core.rooter import _load_socks5_operational, rooter, vpns + +log = logging.getLogger(__name__) + +# os.listdir('/sys/class/net/') +HAVE_NETWORKIFACES = False +try: + import psutil + + network_interfaces = list(psutil.net_if_addrs().keys()) + HAVE_NETWORKIFACES = True +except ImportError: + print("Missed dependency: pip3 install psutil") + +latest_symlink_lock = threading.Lock() + + +class CuckooDeadMachine(Exception): + """Exception thrown when a machine turns dead. + + When this exception has been thrown, the analysis task will start again, + and will try to use another machine, when available. + """ + + def __init__(self, machine_name: str): + super().__init__() + self.machine_name = machine_name + + def __str__(self) -> str: + return f"{self.machine_name} is dead!" + + +def main_thread_only(func): + # Since most methods of the AnalysisManager class will be called within a child + # thread, let's decorate ones that must only be called from the main thread so + # that it's easy to differentiate between them. + @functools.wraps(func) + def inner(*args, **kwargs): + if threading.current_thread() is not threading.main_thread(): + raise AssertionError(f"{func.__name__} must only be called from the main thread") + return func(*args, **kwargs) + + return inner + + +class AnalysisLogger(logging.LoggerAdapter): + """This class will be used by AnalysisManager so that all of its log entries + will include the task ID, without having to explicitly include it in the log message. + """ + + def process(self, msg: str, kwargs: MutableMapping[str, Any]) -> Tuple[str, MutableMapping[str, Any]]: + task_id: Optional[int] = self.extra.get("task_id") if self.extra is not None else None + if task_id is not None: + msg = f"Task #{task_id}: {msg}" + return msg, kwargs + + +class AnalysisManager(threading.Thread): + """Analysis Manager. + + This class handles the full analysis process for a given task. It takes + care of selecting the analysis machine, preparing the configuration and + interacting with the guest agent and analyzer components to launch and + complete the analysis and store, process and report its results. + """ + + def __init__( + self, + task: Task, + *, + machine: Optional[Machine] = None, + machinery_manager: Optional[MachineryManager] = None, + error_queue: Optional[queue.Queue] = None, + done_callback: Optional[Callable[["AnalysisManager"], None]] = None, + ): + """@param task: task object containing the details for the analysis.""" + super().__init__(name=f"task-{task.id}", daemon=True) + self.db: _Database = Database() + self.task = task + self.log = AnalysisLogger(log, {"task_id": self.task.id}) + self.machine = machine + self.machinery_manager = machinery_manager + self.error_queue = error_queue + self.done_callback = done_callback + self.guest: Optional[Guest] = None + self.cfg = Config() + self.aux_cfg = Config("auxiliary") + self.storage = os.path.join(CUCKOO_ROOT, "storage", "analyses", str(self.task.id)) + self.screenshot_path = os.path.join(self.storage, "shots") + self.num_screenshots = 0 + self.binary = "" + self.interface = None + self.rt_table = None + self.route = None + self.rooter_response = "" + self.reject_segments = None + self.reject_hostports = None + + @main_thread_only + def prepare_task_and_machine_to_start(self) -> None: + """If the task doesn't use a machine, just set its state to running. + Otherwise, update the task and machine in the database so that the + task is running, the machine is locked and assigned to the task, and + create a Guest row for the analysis. + """ + self.db.set_task_status(self.task, TASK_RUNNING) + if self.machine and self.machinery_manager: + self.db.assign_machine_to_task(self.task, self.machine) + self.db.lock_machine(self.machine) + self.guest = self.db.create_guest( + self.machine, + self.machinery_manager.machinery.__class__.__name__, + self.task, + ) + + def init_storage(self): + """Initialize analysis storage folder.""" + # If the analysis storage folder already exists, we need to abort the + # analysis or previous results will be overwritten and lost. + if path_exists(self.storage): + self.log.error("Analysis results folder already exists at path '%s', analysis aborted", self.storage) + return False + + # If we're not able to create the analysis storage folder, we have to + # abort the analysis. + try: + create_folder(folder=self.storage) + except CuckooOperationalError: + self.log.error("Unable to create analysis folder %s", self.storage) + return False + + return True + + def check_file(self, sha256): + """Checks the integrity of the file to be analyzed.""" + sample = self.db.view_sample(self.task.sample_id) + + if sample and sha256 != sample.sha256: + self.log.error("Target file has been modified after submission: '%s'", convert_to_printable(self.task.target)) + return False + + return True + + def store_file(self, sha256): + """Store a copy of the file being analyzed.""" + if not path_exists(self.task.target): + self.log.error( + "The file to analyze does not exist at path '%s', analysis aborted", convert_to_printable(self.task.target) + ) + return False + + binaries_dir = os.path.join(CUCKOO_ROOT, "storage", "binaries") + self.binary = os.path.join(binaries_dir, sha256) + + if path_exists(self.binary): + self.log.info("File already exists at '%s'", self.binary) + else: + path_mkdir(binaries_dir, exist_ok=True) + # TODO: do we really need to abort the analysis in case we are not able to store a copy of the file? + try: + shutil.copy(self.task.target, self.binary) + except (IOError, shutil.Error): + self.log.error( + "Unable to store file from '%s' to '%s', analysis aborted", + self.task.target, + self.binary, + ) + return False + + try: + new_binary_path = os.path.join(self.storage, "binary") + if hasattr(os, "symlink"): + os.symlink(self.binary, new_binary_path) + else: + shutil.copy(self.binary, new_binary_path) + except (AttributeError, OSError) as e: + self.log.error("Unable to create symlink/copy from '%s' to '%s': %s", self.binary, self.storage, e) + + return True + + def screenshot_machine(self): + if not self.cfg.cuckoo.machinery_screenshots: + return + if self.machinery_manager is None or self.machine is None: + self.log.error("screenshot not possible, no machine is used for this analysis") + return + + # same format and filename approach here as VM-based screenshots + self.num_screenshots += 1 + screenshot_filename = f"{str(self.num_screenshots).rjust(4, '0')}.jpg" + screenshot_path = os.path.join(self.screenshot_path, screenshot_filename) + try: + self.machinery_manager.machinery.screenshot(self.machine.label, screenshot_path) + except Exception as err: + self.log.warning("Failed to take screenshot of %s: %s", self.machine.label, err) + self.num_screenshots -= 1 + + def build_options(self): + """Generate analysis options. + @return: options dict. + """ + options = { + "id": self.task.id, + "ip": self.machine.resultserver_ip, + "port": self.machine.resultserver_port, + "category": self.task.category, + "target": self.task.target, + "package": self.task.package, + "options": self.get_machine_specific_options(self.task.options), + "enforce_timeout": self.task.enforce_timeout, + "clock": self.task.clock, + "terminate_processes": self.cfg.cuckoo.terminate_processes, + "upload_max_size": self.cfg.resultserver.upload_max_size, + "do_upload_max_size": int(self.cfg.resultserver.do_upload_max_size), + "enable_trim": int(Config("web").general.enable_trim), + "timeout": self.task.timeout or self.cfg.timeouts.default, + } + + if self.task.category == "file": + file_obj = File(self.task.target) + options["file_name"] = file_obj.get_name() + options["file_type"] = file_obj.get_type() + # if it's a PE file, collect export information to use in more smartly determining the right package to use + options["exports"] = PortableExecutable(self.task.target).get_dll_exports() + del file_obj + + # options from auxiliary.conf + for plugin in self.aux_cfg.auxiliary_modules.keys(): + options[plugin] = self.aux_cfg.auxiliary_modules[plugin] + + return options + + def category_checks(self) -> Optional[bool]: + if self.task.category in ("file", "pcap", "static"): + sha256 = File(self.task.target).get_sha256() + # Check whether the file has been changed for some unknown reason. + # And fail this analysis if it has been modified. + if not self.check_file(sha256): + self.log.debug("check file") + return False + + # Store a copy of the original file. + if not self.store_file(sha256): + self.log.debug("store file") + return False + + if self.task.category in ("pcap", "static"): + if self.task.category == "pcap": + if hasattr(os, "symlink"): + os.symlink(self.binary, os.path.join(self.storage, "dump.pcap")) + else: + shutil.copy(self.binary, os.path.join(self.storage, "dump.pcap")) + # create the logs/files directories as + # normally the resultserver would do it + dirnames = ["logs", "files", "aux"] + for dirname in dirnames: + try: + path_mkdir(os.path.join(self.storage, dirname)) + except Exception: + self.log.debug("Failed to create folder %s", dirname) + return True + + return None + + @contextlib.contextmanager + def machine_running(self) -> Generator[None, None, None]: + assert self.machinery_manager and self.machine and self.guest + + try: + with self.db.session.begin(): + self.machinery_manager.start_machine(self.machine) + + yield + + # Take a memory dump of the machine before shutting it off. + self.dump_machine_memory() + + except (CuckooMachineError, CuckooGuestCriticalTimeout) as e: + # This machine has turned dead, so we'll throw an exception + # which informs the AnalysisManager that it should analyze + # this task again with another available machine. + self.log.exception(str(e)) + + # Remove the guest from the database, so that we can assign a + # new guest when the task is being analyzed with another machine. + with self.db.session.begin(): + self.db.guest_remove(self.guest.id) + self.db.assign_machine_to_task(self.task, None) + self.machinery_manager.machinery.delete_machine(self.machine.name) + + # Remove the analysis directory that has been created so + # far, as perform_analysis() is going to be doing that again. + shutil.rmtree(self.storage) + + raise CuckooDeadMachine(self.machine.name) from e + + with self.db.session.begin(): + try: + self.machinery_manager.stop_machine(self.machine) + except CuckooMachineError as e: + self.log.warning("Unable to stop machine %s: %s", self.machine.label, e) + # Explicitly rollback since we don't re-raise the exception. + self.db.session.rollback() + + try: + # Release the analysis machine, but only if the machine is not dead. + with self.db.session.begin(): + self.machinery_manager.machinery.release(self.machine) + except CuckooMachineError as e: + self.log.error( + "Unable to release machine %s, reason %s. You might need to restore it manually", + self.machine.label, + e, + ) + + def dump_machine_memory(self) -> None: + if not self.cfg.cuckoo.memory_dump and not self.task.memory: + return + + assert self.machinery_manager and self.machine + + try: + dump_path = get_memdump_path(self.task.id) + need_space, space_available = free_space_monitor(os.path.dirname(dump_path), return_value=True) + if need_space: + self.log.error("Not enough free disk space! Could not dump ram (Only %d MB!)", space_available) + else: + self.machinery_manager.machinery.dump_memory(self.machine.label, dump_path) + except NotImplementedError: + self.log.error("The memory dump functionality is not available for the current machine manager") + + except CuckooMachineError as e: + self.log.exception(str(e)) + + @contextlib.contextmanager + def result_server(self) -> Generator[None, None, None]: + try: + ResultServer().add_task(self.task, self.machine) + except Exception as e: + self.log.exception("Failed to add task to result-server") + if self.error_queue: + self.error_queue.put(e) + raise + try: + yield + finally: + # After all this, we can make the ResultServer forget about the + # internal state for this analysis task. + ResultServer().del_task(self.task, self.machine) + + @contextlib.contextmanager + def network_routing(self) -> Generator[None, None, None]: + self.route_network() + try: + yield + finally: + # Drop the network routing rules if any. + self.unroute_network() + + @contextlib.contextmanager + def run_auxiliary(self) -> Generator[None, None, None]: + aux = RunAuxiliary(task=self.task, machine=self.machine) + + with self.db.session.begin(): + aux.start() + + try: + yield + finally: + with self.db.session.begin(): + aux.stop() + + def run_analysis_on_guest(self) -> None: + # Generate the analysis configuration file. + options = self.build_options() + + guest_manager = GuestManager(self.machine.name, self.machine.ip, self.machine.platform, self.task.id, self) + + with self.db.session.begin(): + if Config("web").guacamole.enabled and hasattr(self.machinery, "store_vnc_port"): + self.machinery.store_vnc_port(self.machine.label, self.task.id) + options["clock"] = self.db.update_clock(self.task.id) + self.db.guest_set_status(self.task.id, "starting") + guest_manager.start_analysis(options) + if guest_manager.get_status_from_db() == "starting": + guest_manager.set_status_in_db("running") + guest_manager.wait_for_completion() + + guest_manager.set_status_in_db("stopping") + + return + + def perform_analysis(self) -> bool: + """Start analysis.""" + succeeded = False + self.socks5s = _load_socks5_operational() + + # Initialize the analysis folders. + if not self.init_storage(): + self.log.debug("Failed to initialize the analysis folder") + return False + + with self.db.session.begin(): + category_early_escape = self.category_checks() + if isinstance(category_early_escape, bool): + return category_early_escape + + # At this point, we're sure that this analysis requires a machine. + assert self.machinery_manager and self.machine and self.guest + + with self.db.session.begin(): + self.machinery_manager.scale_pool(self.machine) + + self.log.info("Starting analysis of %s '%s'", self.task.category.upper(), convert_to_printable(self.task.target)) + + with self.machine_running(), self.result_server(), self.network_routing(), self.run_auxiliary(): + try: + self.run_analysis_on_guest() + except CuckooGuestError as e: + self.log.exception(str(e)) + else: + succeeded = True + finally: + with self.db.session.begin(): + self.db.guest_stop(self.guest.id) + + return succeeded + + def launch_analysis(self) -> None: + success = False + try: + success = self.perform_analysis() + except CuckooDeadMachine: + with self.db.session.begin(): + # Put the task back in pending so that the schedule can attempt to + # choose a new machine. + self.db.set_status(self.task.id, TASK_PENDING) + raise + else: + with self.db.session.begin(): + self.db.set_status(self.task.id, TASK_COMPLETED) + self.log.info("Completed analysis %ssuccessfully.", "" if success else "un") + + self.update_latest_symlink() + + def update_latest_symlink(self): + # We make a symbolic link ("latest") which links to the latest + # analysis - this is useful for debugging purposes. This is only + # supported under systems that support symbolic links. + if not hasattr(os, "symlink"): + return + + latest = os.path.join(CUCKOO_ROOT, "storage", "analyses", "latest") + + # First we have to remove the existing symbolic link, then we have to create the new one. + # Deal with race conditions using a lock. + with latest_symlink_lock: + try: + # As per documentation, lexists() returns True for dead symbolic links. + if os.path.lexists(latest): + path_delete(latest) + + os.symlink(self.storage, latest) + except OSError as e: + self.log.warning("Error pointing latest analysis symlink: %s", e) + + def run(self): + """Run manager thread.""" + try: + self.launch_analysis() + except Exception: + self.log.exception("failure in AnalysisManager.run") + else: + self.log.info("analysis procedure completed") + finally: + if self.done_callback: + self.done_callback(self) + + def _rooter_response_check(self): + if self.rooter_response and self.rooter_response["exception"] is not None: + raise CuckooCriticalError(f"Error execution rooter command: {self.rooter_response['exception']}") + + def route_network(self): + """Enable network routing if desired.""" + # Determine the desired routing strategy (none, internet, VPN). + routing = Config("routing") + self.route = routing.routing.route + + if self.task.route: + self.route = self.task.route + + if self.route in ("none", "None", "drop", "false"): + self.interface = None + self.rt_table = None + elif self.route == "inetsim": + self.interface = routing.inetsim.interface + elif self.route == "tor": + self.interface = routing.tor.interface + elif self.route == "internet" and routing.routing.internet != "none": + self.interface = routing.routing.internet + self.rt_table = routing.routing.rt_table + if routing.routing.reject_segments != "none": + self.reject_segments = routing.routing.reject_segments + if routing.routing.reject_hostports != "none": + self.reject_hostports = str(routing.routing.reject_hostports) + elif self.route in vpns: + self.interface = vpns[self.route].interface + self.rt_table = vpns[self.route].rt_table + elif self.route in self.socks5s: + self.interface = "" + else: + self.log.warning("Unknown network routing destination specified, ignoring routing for this analysis: %s", self.route) + self.interface = None + self.rt_table = None + + # Check if the network interface is still available. If a VPN dies for + # some reason, its tunX interface will no longer be available. + if self.interface and not rooter("nic_available", self.interface): + self.log.error( + "The network interface '%s' configured for this analysis is " + "not available at the moment, switching to route=none mode", + self.interface, + ) + self.route = "none" + self.interface = None + self.rt_table = None + + if self.route == "inetsim": + self.rooter_response = rooter( + "inetsim_enable", + self.machine.ip, + str(routing.inetsim.server), + str(routing.inetsim.dnsport), + str(self.cfg.resultserver.port), + str(routing.inetsim.ports), + ) + + elif self.route == "tor": + self.rooter_response = rooter( + "socks5_enable", + self.machine.ip, + str(self.cfg.resultserver.port), + str(routing.tor.dnsport), + str(routing.tor.proxyport), + ) + + elif self.route in self.socks5s: + self.rooter_response = rooter( + "socks5_enable", + self.machine.ip, + str(self.cfg.resultserver.port), + str(self.socks5s[self.route]["dnsport"]), + str(self.socks5s[self.route]["port"]), + ) + + elif self.route in ("none", "None", "drop"): + self.rooter_response = rooter("drop_enable", self.machine.ip, str(self.cfg.resultserver.port)) + + self._rooter_response_check() + + # check if the interface is up + if HAVE_NETWORKIFACES and routing.routing.verify_interface and self.interface and self.interface not in network_interfaces: + self.log.info("Network interface {} not found, falling back to dropping network traffic", self.interface) + self.interface = None + self.rt_table = None + self.route = "drop" + + if self.interface: + self.rooter_response = rooter("forward_enable", self.machine.interface, self.interface, self.machine.ip) + self._rooter_response_check() + if self.reject_segments: + self.rooter_response = rooter( + "forward_reject_enable", self.machine.interface, self.interface, self.machine.ip, self.reject_segments + ) + self._rooter_response_check() + if self.reject_hostports: + self.rooter_response = rooter( + "hostports_reject_enable", self.machine.interface, self.machine.ip, self.reject_hostports + ) + self._rooter_response_check() + + self.log.info("Enabled route '%s'.", self.route) + + if self.rt_table: + self.rooter_response = rooter("srcroute_enable", self.rt_table, self.machine.ip) + self._rooter_response_check() + + def unroute_network(self): + routing = Config("routing") + if self.interface: + self.rooter_response = rooter("forward_disable", self.machine.interface, self.interface, self.machine.ip) + self._rooter_response_check() + if self.reject_segments: + self.rooter_response = rooter( + "forward_reject_disable", self.machine.interface, self.interface, self.machine.ip, self.reject_segments + ) + self._rooter_response_check() + if self.reject_hostports: + self.rooter_response = rooter( + "hostports_reject_disable", self.machine.interface, self.machine.ip, self.reject_hostports + ) + self._rooter_response_check() + self.log.info("Disabled route '%s'", self.route) + + if self.rt_table: + self.rooter_response = rooter("srcroute_disable", self.rt_table, self.machine.ip) + self._rooter_response_check() + + if self.route == "inetsim": + self.rooter_response = rooter( + "inetsim_disable", + self.machine.ip, + routing.inetsim.server, + str(routing.inetsim.dnsport), + str(self.cfg.resultserver.port), + str(routing.inetsim.ports), + ) + + elif self.route == "tor": + self.rooter_response = rooter( + "socks5_disable", + self.machine.ip, + str(self.cfg.resultserver.port), + str(routing.tor.dnsport), + str(routing.tor.proxyport), + ) + + elif self.route in self.socks5s: + self.rooter_response = rooter( + "socks5_disable", + self.machine.ip, + str(self.cfg.resultserver.port), + str(self.socks5s[self.route]["dnsport"]), + str(self.socks5s[self.route]["port"]), + ) + + elif self.route in ("none", "None", "drop"): + self.rooter_response = rooter("drop_disable", self.machine.ip, str(self.cfg.resultserver.port)) + + self._rooter_response_check() + + def set_machine_specific_options(self): + """This function may be used to update self.task.options based on the machine + that has been selected (self.machine). + """ + return diff --git a/lib/cuckoo/core/database.py b/lib/cuckoo/core/database.py index 0aaf4c2b8ca..c5b2b7c0fc5 100644 --- a/lib/cuckoo/core/database.py +++ b/lib/cuckoo/core/database.py @@ -12,6 +12,7 @@ import sys from contextlib import suppress from datetime import datetime, timedelta +from typing import Any, List, Optional, Union, cast # Sflock does a good filetype recon from sflock.abstracts import File as SflockFile @@ -22,11 +23,17 @@ from lib.cuckoo.common.config import Config from lib.cuckoo.common.constants import CUCKOO_ROOT from lib.cuckoo.common.demux import demux_sample -from lib.cuckoo.common.exceptions import CuckooDatabaseError, CuckooDependencyError, CuckooOperationalError +from lib.cuckoo.common.exceptions import ( + CuckooDatabaseError, + CuckooDatabaseInitializationError, + CuckooDependencyError, + CuckooOperationalError, + CuckooUnserviceableTaskError, +) from lib.cuckoo.common.integrations.parse_pe import PortableExecutable from lib.cuckoo.common.objects import PCAP, URL, File, Static from lib.cuckoo.common.path_utils import path_delete, path_exists -from lib.cuckoo.common.utils import Singleton, SuperLock, classlock, create_folder +from lib.cuckoo.common.utils import create_folder try: from sqlalchemy import ( @@ -44,14 +51,13 @@ event, func, not_, - or_, select, ) - from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError - from sqlalchemy.orm import backref, declarative_base, joinedload, relationship, sessionmaker + from sqlalchemy.exc import IntegrityError, SQLAlchemyError + from sqlalchemy.orm import Query, backref, declarative_base, joinedload, relationship, scoped_session, sessionmaker Base = declarative_base() -except ImportError: +except ImportError: # pragma: no cover raise CuckooDependencyError("Unable to import sqlalchemy (install with `poetry run pip install sqlalchemy`)") @@ -150,7 +156,6 @@ ) MACHINE_RUNNING = "running" -MACHINE_SCHEDULED = "scheduled" # Secondary table used in association Machine - Tag. machines_tags = Table( @@ -228,7 +233,7 @@ class Machine(Base): reserved = Column(Boolean(), nullable=False, default=False) def __repr__(self): - return f"" + return f"" def to_dict(self): """Converts object to dict. @@ -274,7 +279,7 @@ class Tag(Base): name = Column(String(255), nullable=False, unique=True) def __repr__(self): - return f"" + return f"" def __init__(self, name): self.name = name @@ -296,7 +301,7 @@ class Guest(Base): task_id = Column(Integer, ForeignKey("tasks.id"), nullable=False, unique=True) def __repr__(self): - return f"" + return f"" def to_dict(self): """Converts object to dict. @@ -317,11 +322,12 @@ def to_json(self): """ return json.dumps(self.to_dict()) - def __init__(self, name, label, platform, manager): + def __init__(self, name, label, platform, manager, task_id): self.name = name self.label = label self.platform = platform self.manager = manager + self.task_id = task_id class Sample(Base): @@ -347,7 +353,7 @@ class Sample(Base): ) def __repr__(self): - return f"" + return f"" def to_dict(self): """Converts object to dict. @@ -417,7 +423,7 @@ def __init__(self, message, task_id): self.task_id = task_id def __repr__(self): - return f"" + return f"" class Task(Base): @@ -538,7 +544,7 @@ def __init__(self, target=None): self.target = target def __repr__(self): - return f"" + return f"" class AlembicVersion(Base): @@ -549,7 +555,7 @@ class AlembicVersion(Base): version_num = Column(String(32), nullable=False, primary_key=True) -class Database(object, metaclass=Singleton): +class _Database: """Analysis queue database. This class handles the creation of the database user for internal queue @@ -560,7 +566,6 @@ def __init__(self, dsn=None, schema_check=True): """@param dsn: database connection string. @param schema_check: disable or enable the db schema version check """ - self._lock = SuperLock() self.cfg = conf if dsn: @@ -569,7 +574,7 @@ def __init__(self, dsn=None, schema_check=True): self._connect_database(self.cfg.database.connection) else: file_path = os.path.join(CUCKOO_ROOT, "db", "cuckoo.db") - if not path_exists(file_path): + if not path_exists(file_path): # pragma: no cover db_dir = os.path.dirname(file_path) if not path_exists(db_dir): try: @@ -589,39 +594,37 @@ def __init__(self, dsn=None, schema_check=True): # Create schema. try: Base.metadata.create_all(self.engine) - except SQLAlchemyError as e: + except SQLAlchemyError as e: # pragma: no cover raise CuckooDatabaseError(f"Unable to create or connect to database: {e}") # Get db session. - self.Session = sessionmaker(bind=self.engine) + self.session = scoped_session(sessionmaker(bind=self.engine, expire_on_commit=False)) - @event.listens_for(self.Session, "after_flush") + # There should be a better way to clean up orphans. This runs after every flush, which is crazy. + @event.listens_for(self.session, "after_flush") def delete_tag_orphans(session, ctx): session.query(Tag).filter(~Tag.tasks.any()).filter(~Tag.machines.any()).delete(synchronize_session=False) # Deal with schema versioning. # TODO: it's a little bit dirty, needs refactoring. - tmp_session = self.Session() - if not tmp_session.query(AlembicVersion).count(): - # Set database schema version. - tmp_session.add(AlembicVersion(version_num=SCHEMA_VERSION)) - try: - tmp_session.commit() - except SQLAlchemyError as e: - tmp_session.rollback() - raise CuckooDatabaseError(f"Unable to set schema version: {e}") - finally: - tmp_session.close() - else: - # Check if db version is the expected one. + with self.session() as tmp_session: last = tmp_session.query(AlembicVersion).first() - tmp_session.close() - if last.version_num != SCHEMA_VERSION and schema_check: - print( - f"DB schema version mismatch: found {last.version_num}, expected {SCHEMA_VERSION}. Try to apply all migrations" - ) - print(red("cd utils/db_migration/ && poetry run alembic upgrade head")) - sys.exit() + if last is None: + # Set database schema version. + tmp_session.add(AlembicVersion(version_num=SCHEMA_VERSION)) + try: + tmp_session.commit() + except SQLAlchemyError as e: # pragma: no cover + tmp_session.rollback() + raise CuckooDatabaseError(f"Unable to set schema version: {e}") + else: + # Check if db version is the expected one. + if last.version_num != SCHEMA_VERSION and schema_check: # pragma: no cover + print( + f"DB schema version mismatch: found {last.version_num}, expected {SCHEMA_VERSION}. Try to apply all migrations" + ) + print(red("cd utils/db_migration/ && poetry run alembic upgrade head")) + sys.exit() def __del__(self): """Disconnects pool.""" @@ -645,24 +648,24 @@ def _connect_database(self, connection_string): ) else: self.engine = create_engine(connection_string) - except ImportError as e: + except ImportError as e: # pragma: no cover lib = e.message.rsplit(maxsplit=1)[-1] raise CuckooDependencyError(f"Missing database driver, unable to import {lib} (install with `pip install {lib}`)") - def _get_or_create(self, session, model, **kwargs): + def _get_or_create(self, model, **kwargs): """Get an ORM instance or create it if not exist. @param session: SQLAlchemy session object @param model: model to query @return: row instance """ - instance = session.query(model).filter_by(**kwargs).first() + instance = self.session.query(model).filter_by(**kwargs).first() if instance: return instance else: instance = model(**kwargs) + self.session.add(instance) return instance - @classlock def drop(self): """Drop all tables.""" try: @@ -670,42 +673,29 @@ def drop(self): except SQLAlchemyError as e: raise CuckooDatabaseError(f"Unable to create or connect to database: {e}") - @classlock def clean_machines(self): """Clean old stored machines and related tables.""" # Secondary table. # TODO: this is better done via cascade delete. # self.engine.execute(machines_tags.delete()) - with self.Session() as session: - session.execute(machines_tags.delete()) - try: - session.query(Machine).delete() - session.commit() - except SQLAlchemyError as e: - log.debug("Database error cleaning machines: %s", e) - session.rollback() + self.session.execute(machines_tags.delete()) + self.session.query(Machine).delete() - @classlock def delete_machine(self, name) -> bool: """Delete a single machine entry from DB.""" - with self.Session() as session: - try: - machine = session.query(Machine).filter_by(name=name).first() - if machine: - session.delete(machine) - session.commit() - return True - else: - log.warning(f"{name} does not exist in the database.") - return False - except SQLAlchemyError as e: - log.debug("Database error deleting machine: %s", e) - session.rollback() - - @classlock - def add_machine(self, name, label, arch, ip, platform, tags, interface, snapshot, resultserver_ip, resultserver_port, reserved): + machine = self.session.query(Machine).filter_by(name=name).first() + if machine: + self.session.delete(machine) + return True + else: + log.warning(f"{name} does not exist in the database.") + return False + + def add_machine( + self, name, label, arch, ip, platform, tags, interface, snapshot, resultserver_ip, resultserver_port, reserved, locked=False + ) -> Machine: """Add a guest machine. @param name: machine id @param label: machine label @@ -719,139 +709,88 @@ def add_machine(self, name, label, arch, ip, platform, tags, interface, snapshot @param resultserver_port: port of the Result Server @param reserved: True if the machine can only be used when specifically requested """ - with self.Session() as session: - machine = Machine( - name=name, - label=label, - arch=arch, - ip=ip, - platform=platform, - interface=interface, - snapshot=snapshot, - resultserver_ip=resultserver_ip, - resultserver_port=resultserver_port, - reserved=reserved, - ) - # Deal with tags format (i.e., foo,bar,baz) - if tags: - for tag in tags.replace(" ", "").split(","): - machine.tags.append(self._get_or_create(session, Tag, name=tag)) - session.add(machine) - try: - session.commit() - except SQLAlchemyError as e: - print(e) - log.debug("Database error adding machine: %s", e) - session.rollback() + machine = Machine( + name=name, + label=label, + arch=arch, + ip=ip, + platform=platform, + interface=interface, + snapshot=snapshot, + resultserver_ip=resultserver_ip, + resultserver_port=resultserver_port, + reserved=reserved, + ) + # Deal with tags format (i.e., foo,bar,baz) + if tags: + for tag in tags.replace(" ", "").split(","): + machine.tags.append(self._get_or_create(Tag, name=tag)) + if locked: + machine.locked = True + self.session.add(machine) + return machine - @classlock def set_machine_interface(self, label, interface): - with self.Session() as session: - try: - machine = session.query(Machine).filter_by(label=label).first() - if machine is None: - log.debug("Database error setting interface: %s not found", label) - return - machine.interface = interface - session.commit() - - except SQLAlchemyError as e: - log.debug("Database error setting interface: %s", e) - session.rollback() - - @classlock - def set_vnc_port(self, task_id: int, port: int): - with self.Session() as session: - try: - task = session.query(Task).filter_by(id=task_id).first() - if task is None: - log.debug("Database error setting VPN port: For task %s", task_id) - return - if task.options: - task.options += f",vnc_port={port}" - else: - task.options = f"vnc_port={port}" - session.commit() + machine = self.session.query(Machine).filter_by(label=label).first() + if machine is None: + log.debug("Database error setting interface: %s not found", label) + return + machine.interface = interface + return machine - except SQLAlchemyError as e: - log.debug("Database error setting interface: %s", e) - session.rollback() + def set_vnc_port(self, task_id: int, port: int): + task = self.session.query(Task).filter_by(id=task_id).first() + if task is None: + log.debug("Database error setting VPN port: For task %s", task_id) + return + if task.options: + task.options += f",vnc_port={port}" + else: + task.options = f"vnc_port={port}" - @classlock def update_clock(self, task_id): - with self.Session() as session: - try: - row = session.get(Task, task_id) + row = self.session.get(Task, task_id) - if not row: - return + if not row: + return - if row.clock == datetime.utcfromtimestamp(0): - if row.category == "file": - row.clock = datetime.utcnow() + timedelta(days=self.cfg.cuckoo.daydelta) - else: - row.clock = datetime.utcnow() - session.commit() - return row.clock - except SQLAlchemyError as e: - log.debug("Database error setting clock: %s", e) - session.rollback() - - @classlock - def set_status(self, task_id, status): + if row.clock == datetime.utcfromtimestamp(0): + if row.category == "file": + row.clock = datetime.utcnow() + timedelta(days=self.cfg.cuckoo.daydelta) + else: + row.clock = datetime.utcnow() + return row.clock + + def set_task_status(self, task: Task, status) -> Task: + if status != TASK_DISTRIBUTED_COMPLETED: + task.status = status + + if status in (TASK_RUNNING, TASK_DISTRIBUTED): + task.started_on = datetime.now() + elif status in (TASK_COMPLETED, TASK_DISTRIBUTED_COMPLETED): + task.completed_on = datetime.now() + + self.session.add(task) + return task + + def set_status(self, task_id: int, status) -> Optional[Task]: """Set task status. @param task_id: task identifier @param status: status string @return: operation status """ - with self.Session() as session: - try: - row = session.get(Task, task_id) - - if not row: - return + task = self.session.get(Task, task_id) - if status != TASK_DISTRIBUTED_COMPLETED: - row.status = status - - if status in (TASK_RUNNING, TASK_DISTRIBUTED): - row.started_on = datetime.now() - elif status in (TASK_COMPLETED, TASK_DISTRIBUTED_COMPLETED): - row.completed_on = datetime.now() + if not task: + return None - session.commit() - except SQLAlchemyError as e: - log.debug("Database error setting status: %s", e) - session.rollback() + return self.set_task_status(task, status) - @classlock - def set_task_vm_and_guest_start(self, task_id, vmname, vmlabel, vmplatform, vm_id, manager): - """Set task status and logs guest start. - @param task_id: task identifier - @param vmname: virtual vm name - @param label: vm label - @param manager: vm manager - @return: guest row id - """ - with self.Session() as session: - guest = Guest(vmname, vmlabel, vmplatform, manager) - try: - guest.status = "init" - row = session.get(Task, task_id) - - if not row: - return - - row.guest = guest - row.machine = vmname - row.machine_id = vm_id - session.commit() - session.refresh(guest) - return guest.id - except SQLAlchemyError as e: - log.debug("Database error setting task vm and logging guest start: %s", e) - session.rollback() - return None + def create_guest(self, machine: Machine, manager: str, task: Task) -> Guest: + guest = Guest(machine.name, machine.label, machine.platform, manager, task.id) + guest.status = "init" + self.session.add(guest) + return guest def _package_vm_requires_check(self, package: str) -> list: """ @@ -866,204 +805,103 @@ def _task_arch_tags_helper(self, task: Task): return task_archs, task_tags - def validate_task_parameters(self, label: str, platform: str, tags: list) -> bool: - """Checks if a task is invalid based on parameters mismatch - @param label: label of the machine asked for by the task - @param platform: platform of the machine asked for by the task - @param tags: tags of task - @return: boolean indicating if a task is valid - """ - # Preventive checks. - if label and platform: - # Wrong usage. - return False - elif label and tags: - # Also wrong usage. - return False - return True - - @classlock - def is_relevant_machine_available(self, task: Task, set_status: bool = True) -> bool: - """Checks if a machine that is relevant to the given task is available - @param task: task to validate - @param set_status: boolean which indicate if the status of the task should be changed to TASK_RUNNING in the DB. - @return: boolean indicating if a relevant machine is available + def find_machine_to_service_task(self, task: Task) -> Optional[Machine]: + """Find a machine that is able to service the given task. + Returns: The Machine if an available machine was found; None if there is at least 1 machine + that *could* service it, but they are all currently in use. + Raises: CuckooUnserviceableTaskError if there are no machines in the pool that would be able + to service it. """ task_archs, task_tags = self._task_arch_tags_helper(task) os_version = self._package_vm_requires_check(task.package) - vms = self.list_machines( - locked=False, - label=task.machine, - platform=task.platform, - tags=task_tags, - arch=task_archs, - os_version=os_version, - include_scheduled=False, - ) - if len(vms) > 0: - # There are? Awesome! - if set_status: - self.set_status(task_id=task.id, status=TASK_RUNNING) - assigned = vms[0] # Take the first vm which could be assigned - self.set_machine_status(assigned.label, MACHINE_SCHEDULED) - return True - return False - @classlock - def map_tasks_to_available_machines(self, tasks: list) -> list: - """Map tasks to available_machines to schedule in batch and prevent double spending of machines - @param tasks: List of tasks to filter - @return: list of tasks that should be started by the scheduler - """ - results = [] - assigned_machines = [] - for task in tasks: - task_archs, task_tags = self._task_arch_tags_helper(task) - os_version = self._package_vm_requires_check(task.package) - machine = None - if not self.validate_task_parameters(label=task.machine, platform=task.platform, tags=task_tags): - continue - with self.Session() as session: - try: - machines = session.query(Machine).options(joinedload(Machine.tags)).filter_by(locked=False) - machines = self.filter_machines_to_task( - machines=machines, - label=task.machine, - platform=task.platform, - tags=task_tags, - archs=task_archs, - os_version=os_version, - ) - # This loop is there in order to prevent double spending of machines by filtering - # out already mapped machines - for assigned in assigned_machines: - machines = machines.filter(Machine.label.notlike(assigned.label)) - machines = machines.filter(or_(Machine.status.notlike(MACHINE_SCHEDULED), Machine.status == None)) # noqa: E711 - # Get the first free machine. - machine = machines.first() - if machine: - assigned_machines.append(machine) - self.set_status(task_id=task.id, status=TASK_RUNNING) - results.append(task) - except SQLAlchemyError as e: - log.debug("Database error batch scheduling machines: %s", e) - return [] - for assigned in assigned_machines: - self.set_machine_status(assigned.label, MACHINE_SCHEDULED) - return results - - @classlock - def is_serviceable(self, task: Task) -> bool: - """Checks if the task is serviceable. - - This method is useful when there are tasks that will never be serviced - by any of the machines available. This allows callers to decide what to - do when tasks like this are created. - - @return: boolean indicating if any machine could service the task in the future - """ - task_archs, task_tags = self._task_arch_tags_helper(task) - os_version = self._package_vm_requires_check(task.package) - vms = self.list_machines(label=task.machine, platform=task.platform, tags=task_tags, arch=task_archs, os_version=os_version) - if len(vms) > 0: - return True - return False + def get_first_machine(query: Query) -> Optional[Machine]: + # Select for update a machine, preferring one that is available and was the one that was used the + # longest time ago. This will give us a machine that can get locked or, if there are none that are + # currently available, we'll at least know that the task is serviceable. + return cast( + Optional[Machine], query.order_by(Machine.locked, Machine.locked_changed_on).with_for_update(of=Machine).first() + ) + + machines = self.session.query(Machine).options(joinedload(Machine.tags)) + filter_kwargs = { + "machines": machines, + "label": task.machine, + "platform": task.platform, + "tags": task_tags, + "archs": task_archs, + "os_version": os_version, + } + filtered_machines = self.filter_machines_to_task(include_reserved=False, **filter_kwargs) + machine = get_first_machine(filtered_machines) + if machine is None and not task.machine and task_tags: + # The task was given at least 1 tag, but there are no non-reserved machines + # that could satisfy the request. So let's see if there are any "reserved" + # machines that can satisfy it. + filtered_machines = self.filter_machines_to_task(include_reserved=True, **filter_kwargs) + machine = get_first_machine(filtered_machines) + + if machine is None: + raise CuckooUnserviceableTaskError + if machine.locked: + # There aren't any machines that can service the task NOW, but there is at least one in the pool + # that could service it once it's available. + return None + return machine - @classlock - def fetch_task(self, categories: list = []): + def fetch_task(self, categories: list = None): """Fetches a task waiting to be processed and locks it for running. @return: None or task """ - with self.Session() as session: - row = None - try: - row = ( - session.query(Task) - .filter_by(status=TASK_PENDING) - .order_by(Task.priority.desc(), Task.added_on) - # distributed cape - .filter(not_(Task.options.contains("node="))) - ) + row = ( + self.session.query(Task) + .filter_by(status=TASK_PENDING) + .order_by(Task.priority.desc(), Task.added_on) + # distributed cape + .filter(not_(Task.options.contains("node="))) + ) - if categories: - row = row.filter(Task.category.in_(categories)) - row = row.first() + if categories: + row = row.filter(Task.category.in_(categories)) + row = row.first() - if not row: - return None + if not row: + return None - self.set_status(task_id=row.id, status=TASK_RUNNING) - session.refresh(row) + self.set_status(task_id=row.id, status=TASK_RUNNING) - return row - except SQLAlchemyError as e: - log.debug("Database error fetching task: %s", e) - log.debug(red("Ensure that your database schema version is correct")) - session.rollback() + return row - @classlock def guest_get_status(self, task_id): """Log guest start. @param task_id: task id @return: guest status """ - with self.Session() as session: - try: - guest = session.query(Guest).filter_by(task_id=task_id).first() - return guest.status if guest else None - except SQLAlchemyError as e: - log.exception("Database error logging guest start: %s", e) - session.rollback() - return - - @classlock + guest = self.session.query(Guest).filter_by(task_id=task_id).first() + return guest.status if guest else None + def guest_set_status(self, task_id, status): """Log guest start. @param task_id: task identifier @param status: status """ - with self.Session() as session: - try: - guest = session.query(Guest).filter_by(task_id=task_id).first() - if guest is not None: - guest.status = status - session.commit() - session.refresh(guest) - except SQLAlchemyError as e: - log.exception("Database error logging guest start: %s", e) - session.rollback() - return None + guest = self.session.query(Guest).filter_by(task_id=task_id).first() + if guest is not None: + guest.status = status - @classlock def guest_remove(self, guest_id): """Removes a guest start entry.""" - with self.Session() as session: - try: - guest = session.get(Guest, guest_id) - session.delete(guest) - session.commit() - except SQLAlchemyError as e: - log.debug("Database error logging guest remove: %s", e) - session.rollback() - return None + guest = self.session.get(Guest, guest_id) + if guest: + self.session.delete(guest) - @classlock def guest_stop(self, guest_id): """Logs guest stop. @param guest_id: guest log entry id """ - with self.Session() as session: - try: - guest = session.get(Guest, guest_id) - if guest: - guest.shutdown_on = datetime.now() - session.commit() - except SQLAlchemyError as e: - log.debug("Database error logging guest stop: %s", e) - session.rollback() - except TypeError: - log.warning("Data inconsistency in guests table detected, it might be a crash leftover. Continue") - session.rollback() + guest = self.session.get(Guest, guest_id) + if guest: + guest.shutdown_on = datetime.now() @staticmethod def filter_machines_by_arch(machines, arch): @@ -1079,10 +917,10 @@ def filter_machines_by_arch(machines, arch): return machines def filter_machines_to_task( - self, machines: list, label=None, platform=None, tags=None, archs=None, os_version=[], include_reserved=False - ) -> list: + self, machines: Query, label=None, platform=None, tags=None, archs=None, os_version=None, include_reserved=False + ) -> Query: """Add filters to the given query based on the task - @param machines: List of machines where the filter will be applied + @param machines: Query object for the machines @param label: label of the machine(s) expected for the task @param platform: platform of the machine(s) expected for the task @param tags: tags of the machine(s) expected for the task @@ -1105,18 +943,16 @@ def filter_machines_to_task( machines = machines.filter(Machine.tags.any(Tag.name.in_(os_version))) return machines - @classlock def list_machines( self, locked=None, label=None, platform=None, - tags=[], + tags=None, arch=None, include_reserved=False, - os_version=[], - include_scheduled=True, - ): + os_version=None, + ) -> List[Machine]: """Lists virtual machines. @return: list of virtual machines """ @@ -1126,105 +962,54 @@ def list_machines( 77 | cape1 | win7 | x86 | 78 | cape2 | win10 | x64 | """ - with self.Session() as session: - try: - machines = session.query(Machine).options(joinedload(Machine.tags)) - if locked is not None and isinstance(locked, bool): - machines = machines.filter_by(locked=locked) - machines = self.filter_machines_to_task( - machines=machines, - label=label, - platform=platform, - tags=tags, - archs=arch, - os_version=os_version, - include_reserved=include_reserved, - ) - if not include_scheduled: - machines = machines.filter(or_(Machine.status.notlike(MACHINE_SCHEDULED), Machine.status == None)) # noqa: E711 - return machines.all() - except SQLAlchemyError as e: - print(e) - log.debug("Database error listing machines: %s", e) - return [] - - @classlock - def lock_machine(self, label=None, platform=None, tags=None, arch=None, os_version=[], need_scheduled=False): + machines = self.session.query(Machine).options(joinedload(Machine.tags)) + if locked is not None and isinstance(locked, bool): + machines = machines.filter_by(locked=locked) + machines = self.filter_machines_to_task( + machines=machines, + label=label, + platform=platform, + tags=tags, + archs=arch, + os_version=os_version, + include_reserved=include_reserved, + ) + return machines.all() + + def assign_machine_to_task(self, task: Task, machine: Optional[Machine]) -> Task: + if machine: + task.machine = machine.label + task.machine_id = machine.id + else: + task.machine = None + task.machine_id = None + self.session.add(task) + return task + + def lock_machine(self, machine: Machine) -> Machine: """Places a lock on a free virtual machine. - @param label: optional virtual machine label - @param platform: optional virtual machine platform - @param tags: optional tags required (list) - @param arch: optional virtual machine arch - @param os_version: tags to filter per OS version. Ex: winxp, win7, win10, win11 - @param need_scheduled: should the result be filtered on 'scheduled' machine status + @param machine: the Machine to lock @return: locked machine """ - if not self.validate_task_parameters(label=label, platform=platform, tags=tags): - return None + machine.locked = True + machine.locked_changed_on = datetime.now() + self.set_machine_status(machine, MACHINE_RUNNING) + self.session.add(machine) - with self.Session() as session: - - try: - machines = session.query(Machine) - machines = self.filter_machines_to_task( - machines=machines, label=label, platform=platform, tags=tags, archs=arch, os_version=os_version - ) - # Check if there are any machines that satisfy the - # selection requirements. - if not machines.count(): - raise CuckooOperationalError( - "No machines match selection criteria of label: '%s', platform: '%s', arch: '%s', tags: '%s'" - % (label, platform, arch, tags) - ) - if need_scheduled: - machines = machines.filter(Machine.status.like(MACHINE_SCHEDULED)) - # Get the first free machine. - machine = machines.filter_by(locked=False).first() - except SQLAlchemyError as e: - log.debug("Database error locking machine: %s", e) - return None - - if machine: - machine.locked = True - machine.locked_changed_on = datetime.now() - try: - session.commit() - session.refresh(machine) - except SQLAlchemyError as e: - log.debug("Database error locking machine: %s", e) - session.rollback() - return None - self.set_machine_status(machine.label, MACHINE_RUNNING) return machine - @classlock - def unlock_machine(self, label): - """Remove lock form a virtual machine. - @param label: virtual machine label + def unlock_machine(self, machine: Machine) -> Machine: + """Remove lock from a virtual machine. + @param machine: The Machine to unlock. @return: unlocked machine """ - with self.Session() as session: - try: - machine = session.query(Machine).filter_by(label=label).first() - except SQLAlchemyError as e: - log.debug("Database error unlocking machine: %s", e) - return None - - if machine: - machine.locked = False - machine.locked_changed_on = datetime.now() - try: - session.commit() - session.refresh(machine) - except SQLAlchemyError as e: - log.debug("Database error locking machine: %s", e) - session.rollback() - return None + machine.locked = False + machine.locked_changed_on = datetime.now() + self.session.add(machine) return machine - @classlock - def count_machines_available(self, label=None, platform=None, tags=None, arch=None, include_reserved=False, os_version=[]): + def count_machines_available(self, label=None, platform=None, tags=None, arch=None, include_reserved=False, os_version=None): """How many (relevant) virtual machines are ready for analysis. @param label: machine ID. @param platform: machine platform. @@ -1233,120 +1018,66 @@ def count_machines_available(self, label=None, platform=None, tags=None, arch=No @param include_reserved: include 'reserved' machines in the result, regardless of whether or not a 'label' was provided. @return: free virtual machines count """ - with self.Session() as session: - try: - machines = session.query(Machine).filter_by(locked=False) - machines = self.filter_machines_to_task( - machines=machines, - label=label, - platform=platform, - tags=tags, - archs=arch, - os_version=os_version, - include_reserved=include_reserved, - ) - return machines.count() - except SQLAlchemyError as e: - log.debug("Database error counting machines: %s", e) - return 0 + machines = self.session.query(Machine).filter_by(locked=False) + machines = self.filter_machines_to_task( + machines=machines, + label=label, + platform=platform, + tags=tags, + archs=arch, + os_version=os_version, + include_reserved=include_reserved, + ) + return machines.count() - @classlock - def get_available_machines(self): + def get_available_machines(self) -> List[Machine]: """Which machines are available @return: free virtual machines """ - with self.Session() as session: - try: - machines = session.query(Machine).options(joinedload(Machine.tags)).filter_by(locked=False).all() - return machines - except SQLAlchemyError as e: - log.debug("Database error getting available machines: %s", e) - return [] - - @classlock - def get_machines_scheduled(self): - with self.Session() as session: - try: - machines = session.query(Machine) - machines = machines.filter(Machine.status.like(MACHINE_SCHEDULED)) - result = machines.count() - except SQLAlchemyError as e: - log.debug("Database error getting machine scheduled: %s", e) - return 0 - return result - - @classlock - def set_machine_status(self, label, status): + machines = self.session.query(Machine).options(joinedload(Machine.tags)).filter_by(locked=False).all() + return machines + + def count_machines_running(self) -> int: + machines = self.session.query(Machine) + machines = machines.filter_by(locked=True) + return machines.count() + + def set_machine_status(self, machine_or_label: Union[str, Machine], status): """Set status for a virtual machine. @param label: virtual machine label @param status: new virtual machine status """ - with self.Session() as session: - try: - machine = session.query(Machine).filter_by(label=label).first() - except SQLAlchemyError as e: - log.debug("Database error setting machine status: %s", e) - session.close() - return - - if machine: - machine.status = status - machine.status_changed_on = datetime.now() - try: - session.commit() - session.refresh(machine) - except SQLAlchemyError as e: - log.debug("Database error setting machine status: %s", e) - session.rollback() - - @classlock - def check_machines_scheduled_timeout(self): - with self.Session() as session: - try: - machines = session.query(Machine) - machines = machines.filter(Machine.status.like(MACHINE_SCHEDULED)) - except SQLAlchemyError as e: - log.debug("Database error setting machine status: %s", e) - session.close() - return - - for machine in machines: - if machine.status_changed_on + timedelta(seconds=30) < datetime.now(): - self.set_machine_status(machine.label, MACHINE_RUNNING) - - @classlock + if isinstance(machine_or_label, str): + machine = self.session.query(Machine).filter_by(label=machine_or_label).first() + else: + machine = machine_or_label + if machine: + machine.status = status + machine.status_changed_on = datetime.now() + self.session.add(machine) + def add_error(self, message, task_id): """Add an error related to a task. @param message: error message @param task_id: ID of the related task """ - with self.Session() as session: - error = Error(message=message, task_id=task_id) - session.add(error) - try: - session.commit() - except SQLAlchemyError as e: - log.debug("Database error adding error log: %s", e) - session.rollback() + error = Error(message=message, task_id=task_id) + # Use a separate session so that, regardless of the state of a transaction going on + # outside of this function, the error will always be committed to the database. + with self.session.session_factory() as sess, sess.begin(): + sess.add(error) # The following functions are mostly used by external utils. - @classlock def register_sample(self, obj, source_url=False): - sample_id = None if isinstance(obj, (File, PCAP, Static)): - with self.Session() as session: - fileobj = File(obj.file_path) - file_type = fileobj.get_type() - file_md5 = fileobj.get_md5() - sample = None - # check if hash is known already - try: - sample = session.query(Sample).filter_by(md5=file_md5).first() - except SQLAlchemyError as e: - log.debug("Error querying sample for hash: %s", e) - - if not sample: + fileobj = File(obj.file_path) + file_type = fileobj.get_type() + file_md5 = fileobj.get_md5() + sample = None + # check if hash is known already + try: + with self.session.begin_nested(): sample = Sample( md5=file_md5, crc32=fileobj.get_crc32(), @@ -1359,30 +1090,17 @@ def register_sample(self, obj, source_url=False): # parent=sample_parent_id, source_url=source_url, ) - session.add(sample) + self.session.add(sample) + except IntegrityError: + sample = self.session.query(Sample).filter_by(md5=file_md5).first() - try: - session.commit() - except IntegrityError: - session.rollback() - try: - sample = session.query(Sample).filter_by(md5=file_md5).first() - except SQLAlchemyError as e: - log.debug("Error querying sample for hash: %s", e) - return None - except SQLAlchemyError as e: - log.debug("Database error adding task: %s", e) - return None - finally: - sample_id = sample.id - - return sample_id + return sample.id return None - @classlock def add( self, obj, + *, timeout=0, package="", options="", @@ -1433,26 +1151,19 @@ def add( @param username: username for custom auth @return: cursor or None. """ - with self.Session() as session: - - # Convert empty strings and None values to a valid int - if not timeout: - timeout = 0 - if not priority: - priority = 1 - - if isinstance(obj, (File, PCAP, Static)): - fileobj = File(obj.file_path) - file_type = fileobj.get_type() - file_md5 = fileobj.get_md5() - sample = None - # check if hash is known already - try: - sample = session.query(Sample).filter_by(md5=file_md5).first() - except SQLAlchemyError as e: - log.debug("Error querying sample for hash: %s", e) + # Convert empty strings and None values to a valid int + if not timeout: + timeout = 0 + if not priority: + priority = 1 - if not sample: + if isinstance(obj, (File, PCAP, Static)): + fileobj = File(obj.file_path) + file_type = fileobj.get_type() + file_md5 = fileobj.get_md5() + # check if hash is known already + try: + with self.session.begin_nested(): sample = Sample( md5=file_md5, crc32=fileobj.get_crc32(), @@ -1465,115 +1176,95 @@ def add( parent=sample_parent_id, source_url=source_url, ) - session.add(sample) - - try: - session.commit() - except IntegrityError: - session.rollback() - """ - try: - sample = session.query(Sample).filter_by(md5=file_md5).first() - except SQLAlchemyError as e: - log.debug("Error querying sample for hash: %s", e) - session.close() - return None - """ - except SQLAlchemyError as e: - log.debug("Database error adding task: %s", e) - session.close() - return None - - if DYNAMIC_ARCH_DETERMINATION: - # Assign architecture to task to fetch correct VM type - # This isn't 100% full proof - if "PE32+" in file_type or "64-bit" in file_type or package.endswith("_x64"): - if tags: - tags += ",x64" - else: - tags = "x64" + self.session.add(sample) + except IntegrityError: + sample = self.session.query(Sample).filter_by(md5=file_md5).first() + + if DYNAMIC_ARCH_DETERMINATION: + # Assign architecture to task to fetch correct VM type + # This isn't 100% full proof + if "PE32+" in file_type or "64-bit" in file_type or package.endswith("_x64"): + if tags: + tags += ",x64" else: - if LINUX_ENABLED and platform == "linux": - linux_arch = _get_linux_vm_tag(file_type) - if linux_arch: - if tags: - tags += f",{linux_arch}" - else: - tags = linux_arch - else: + tags = "x64" + else: + if LINUX_ENABLED and platform == "linux": + linux_arch = _get_linux_vm_tag(file_type) + if linux_arch: if tags: - tags += ",x86" + tags += f",{linux_arch}" else: - tags = "x86" + tags = linux_arch + else: + if tags: + tags += ",x86" + else: + tags = "x86" + task = Task(obj.file_path) + task.sample_id = sample.id + + if isinstance(obj, (PCAP, Static)): + # since no VM will operate on this PCAP + task.started_on = datetime.now() + + elif isinstance(obj, URL): + task = Task(obj.url) + tags = "x64,x86" + + else: + return None + + task.category = obj.__class__.__name__.lower() + task.timeout = timeout + task.package = package + task.options = options + task.priority = priority + task.custom = custom + task.machine = machine + task.platform = platform + task.memory = bool(memory) + task.enforce_timeout = enforce_timeout + task.shrike_url = shrike_url + task.shrike_msg = shrike_msg + task.shrike_sid = shrike_sid + task.shrike_refer = shrike_refer + task.parent_id = parent_id + task.tlp = tlp + task.route = route + task.cape = cape + task.tags_tasks = tags_tasks + # Deal with tags format (i.e., foo,bar,baz) + if tags: + for tag in tags.split(","): + tag_name = tag.strip() + if tag_name and tag_name not in [tag.name for tag in task.tags]: + # "Task" object is being merged into a Session along the backref cascade path for relationship "Tag.tasks"; in SQLAlchemy 2.0, this reverse cascade will not take place. + # Set cascade_backrefs to False in either the relationship() or backref() function for the 2.0 behavior; or to set globally for the whole Session, set the future=True flag + # (Background on this error at: https://sqlalche.me/e/14/s9r1) (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9) + task.tags.append(self._get_or_create(Tag, name=tag_name)) + + if clock: + if isinstance(clock, str): try: - task = Task(obj.file_path) - task.sample_id = sample.id - except OperationalError: - return None - - if isinstance(obj, (PCAP, Static)): - # since no VM will operate on this PCAP - task.started_on = datetime.now() - - elif isinstance(obj, URL): - task = Task(obj.url) - tags = "x64,x86" - - task.category = obj.__class__.__name__.lower() - task.timeout = timeout - task.package = package - task.options = options - task.priority = priority - task.custom = custom - task.machine = machine - task.platform = platform - task.memory = bool(memory) - task.enforce_timeout = enforce_timeout - task.shrike_url = shrike_url - task.shrike_msg = shrike_msg - task.shrike_sid = shrike_sid - task.shrike_refer = shrike_refer - task.parent_id = parent_id - task.tlp = tlp - task.route = route - task.cape = cape - task.tags_tasks = tags_tasks - # Deal with tags format (i.e., foo,bar,baz) - if tags: - for tag in tags.split(","): - if tag.strip(): - # "Task" object is being merged into a Session along the backref cascade path for relationship "Tag.tasks"; in SQLAlchemy 2.0, this reverse cascade will not take place. - # Set cascade_backrefs to False in either the relationship() or backref() function for the 2.0 behavior; or to set globally for the whole Session, set the future=True flag - # (Background on this error at: https://sqlalche.me/e/14/s9r1) (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9) - task.tags.append(self._get_or_create(session, Tag, name=tag)) - - if clock: - if isinstance(clock, str): - try: - task.clock = datetime.strptime(clock, "%m-%d-%Y %H:%M:%S") - except ValueError: - log.warning("The date you specified has an invalid format, using current timestamp") - task.clock = datetime.utcfromtimestamp(0) + task.clock = datetime.strptime(clock, "%m-%d-%Y %H:%M:%S") + except ValueError: + log.warning("The date you specified has an invalid format, using current timestamp") + task.clock = datetime.utcfromtimestamp(0) - else: - task.clock = clock else: - task.clock = datetime.utcfromtimestamp(0) - - task.user_id = user_id - task.username = username + task.clock = clock + else: + task.clock = datetime.utcfromtimestamp(0) - session.add(task) + task.user_id = user_id + task.username = username - try: - session.commit() - task_id = task.id - except SQLAlchemyError as e: - log.debug("Database error adding task: %s", e) - session.rollback() - return None + # Use a nested transaction so that we can return an ID. + with self.session.begin_nested(): + self.session.add(task) - return task_id + return task.id def add_path( self, @@ -1641,24 +1332,24 @@ def add_path( return self.add( File(file_path), - timeout, - package, - options, - priority, - custom, - machine, - platform, - tags, - memory, - enforce_timeout, - clock, - shrike_url, - shrike_msg, - shrike_sid, - shrike_refer, - parent_id, - sample_parent_id, - tlp, + timeout=timeout, + package=package, + options=options, + priority=priority, + custom=custom, + machine=machine, + platform=platform, + tags=tags, + memory=memory, + enforce_timeout=enforce_timeout, + clock=clock, + shrike_url=shrike_url, + shrike_msg=shrike_msg, + shrike_sid=shrike_sid, + shrike_refer=shrike_refer, + parent_id=parent_id, + sample_parent_id=sample_parent_id, + tlp=tlp, source_url=source_url, route=route, cape=cape, @@ -1831,7 +1522,6 @@ def demux_sample_and_add_to_db( # this is aim to return custom data, think of this as kwargs return task_ids, details - @classlock def add_pcap( self, file_path, @@ -1857,28 +1547,27 @@ def add_pcap( ): return self.add( PCAP(file_path.decode()), - timeout, - package, - options, - priority, - custom, - machine, - platform, - tags, - memory, - enforce_timeout, - clock, - shrike_url, - shrike_msg, - shrike_sid, - shrike_refer, - parent_id, - tlp, - user_id, - username, + timeout=timeout, + package=package, + options=options, + priority=priority, + custom=custom, + machine=machine, + platform=platform, + tags=tags, + memory=memory, + enforce_timeout=enforce_timeout, + clock=clock, + shrike_url=shrike_url, + shrike_msg=shrike_msg, + shrike_sid=shrike_sid, + shrike_refer=shrike_refer, + parent_id=parent_id, + tlp=tlp, + user_id=user_id, + username=username, ) - @classlock def add_static( self, file_path, @@ -1918,21 +1607,21 @@ def add_static( for file, platform in extracted_files: task_id = self.add( Static(file.decode()), - timeout, - package, - options, - priority, - custom, - machine, - platform, - tags, - memory, - enforce_timeout, - clock, - shrike_url, - shrike_msg, - shrike_sid, - shrike_refer, + timeout=timeout, + package=package, + options=options, + priority=priority, + custom=custom, + machine=machine, + platform=platform, + tags=tags, + memory=memory, + enforce_timeout=enforce_timeout, + clock=clock, + shrike_url=shrike_url, + shrike_msg=shrike_msg, + shrike_sid=shrike_sid, + shrike_refer=shrike_refer, tlp=tlp, static=static, sample_parent_id=sample_parent_id, @@ -1944,7 +1633,6 @@ def add_static( return task_ids - @classlock def add_url( self, url, @@ -2002,23 +1690,23 @@ def add_url( return self.add( URL(url), - timeout, - package, - options, - priority, - custom, - machine, - platform, - tags, - memory, - enforce_timeout, - clock, - shrike_url, - shrike_msg, - shrike_sid, - shrike_refer, - parent_id, - tlp, + timeout=timeout, + package=package, + options=options, + priority=priority, + custom=custom, + machine=machine, + platform=platform, + tags=tags, + memory=memory, + enforce_timeout=enforce_timeout, + clock=clock, + shrike_url=shrike_url, + shrike_msg=shrike_msg, + shrike_sid=shrike_sid, + shrike_refer=shrike_refer, + parent_id=parent_id, + tlp=tlp, route=route, cape=cape, tags_tasks=tags_tasks, @@ -2026,7 +1714,6 @@ def add_url( username=username, ) - @classlock def reschedule(self, task_id): """Reschedule a task. @param task_id: ID of the task to reschedule. @@ -2047,89 +1734,75 @@ def reschedule(self, task_id): add = self.add_static # Change status to recovered. - with self.Session() as session: - session.get(Task, task_id).status = TASK_RECOVERED - try: - session.commit() - except SQLAlchemyError as e: - log.debug("Database error rescheduling task: %s", e) - session.rollback() - return False - - # Normalize tags. - if task.tags: - tags = ",".join(tag.name for tag in task.tags) - else: - tags = task.tags - - def _ensure_valid_target(task): - if task.category == "url": - # URL tasks always have valid targets, return it as-is. - return task.target - - # All other task types have a "target" pointing to a temp location, - # so get a stable path "target" based on the sample hash. - paths = self.sample_path_by_hash(task.sample.sha256, task_id) - paths = [file_path for file_path in paths if path_exists(file_path)] - if not paths: - return None - - if task.category == "pcap": - # PCAP task paths are represented as bytes - return paths[0].encode() - return paths[0] - - task_target = _ensure_valid_target(task) - if not task_target: - log.warning("Unable to find valid target for task: %s", task_id) - return - - new_task_id = None - if task.category in ("file", "url"): - new_task_id = add( - task_target, - task.timeout, - task.package, - task.options, - task.priority, - task.custom, - task.machine, - task.platform, - tags, - task.memory, - task.enforce_timeout, - task.clock, - tlp=task.tlp, - route=task.route, - ) - elif task.category in ("pcap", "static"): - new_task_id = add( - task_target, - task.timeout, - task.package, - task.options, - task.priority, - task.custom, - task.machine, - task.platform, - tags, - task.memory, - task.enforce_timeout, - task.clock, - tlp=task.tlp, - ) + self.session.get(Task, task_id).status = TASK_RECOVERED - session.get(Task, task_id).custom = f"Recovery_{new_task_id}" - try: - session.commit() - except SQLAlchemyError as e: - log.debug("Database error rescheduling task: %s", e) - session.rollback() - return False + # Normalize tags. + if task.tags: + tags = ",".join(tag.name for tag in task.tags) + else: + tags = task.tags + + def _ensure_valid_target(task): + if task.category == "url": + # URL tasks always have valid targets, return it as-is. + return task.target + + # All other task types have a "target" pointing to a temp location, + # so get a stable path "target" based on the sample hash. + paths = self.sample_path_by_hash(task.sample.sha256, task_id) + paths = [file_path for file_path in paths if path_exists(file_path)] + if not paths: + return None - return new_task_id + if task.category == "pcap": + # PCAP task paths are represented as bytes + return paths[0].encode() + return paths[0] + + task_target = _ensure_valid_target(task) + if not task_target: + log.warning("Unable to find valid target for task: %s", task_id) + return + + new_task_id = None + if task.category in ("file", "url"): + new_task_id = add( + task_target, + task.timeout, + task.package, + task.options, + task.priority, + task.custom, + task.machine, + task.platform, + tags, + task.memory, + task.enforce_timeout, + task.clock, + tlp=task.tlp, + route=task.route, + ) + elif task.category in ("pcap", "static"): + new_task_id = add( + task_target, + task.timeout, + task.package, + task.options, + task.priority, + task.custom, + task.machine, + task.platform, + tags, + task.memory, + task.enforce_timeout, + task.clock, + tlp=task.tlp, + ) + + self.session.get(Task, task_id).custom = f"Recovery_{new_task_id}" + + return new_task_id - @classlock def count_matching_tasks(self, category=None, status=None, not_status=None): """Retrieve list of task. @param category: filter by category @@ -2137,81 +1810,70 @@ def count_matching_tasks(self, category=None, status=None, not_status=None): @param not_status: exclude this task status from filter @return: number of tasks. """ - with self.Session() as session: - try: - search = session.query(Task) - - if status: - search = search.filter_by(status=status) - if not_status: - search = search.filter(Task.status != not_status) - if category: - search = search.filter_by(category=category) - - tasks = search.count() - return tasks - except SQLAlchemyError as e: - log.debug("Database error counting tasks: %s", e) - return [] - - @classlock + search = self.session.query(Task) + + if status: + search = search.filter_by(status=status) + if not_status: + search = search.filter(Task.status != not_status) + if category: + search = search.filter_by(category=category) + + tasks = search.count() + return tasks + def check_file_uniq(self, sha256: str, hours: int = 0): + # TODO This function is poorly named. It returns True if a sample with the given + # sha256 already exists in the database, rather than returning True if the given + # sha256 is unique. uniq = False - with self.Session() as session: - try: - if hours and sha256: - date_since = datetime.now() - timedelta(hours=hours) - date_till = datetime.now() - uniq = ( - session.query(Task) - .join(Sample, Task.sample_id == Sample.id) - .filter(Sample.sha256 == sha256, Task.added_on.between(date_since, date_till)) - .first() - ) - else: - if not Database.find_sample(self, sha256=sha256): - uniq = False - else: - uniq = True - except SQLAlchemyError as e: - log.debug("Database error counting tasks: %s", e) + if hours and sha256: + date_since = datetime.now() - timedelta(hours=hours) + date_till = datetime.now() + uniq = ( + self.session.query(Task) + .join(Sample, Task.sample_id == Sample.id) + .filter(Sample.sha256 == sha256, Task.added_on.between(date_since, date_till)) + .first() + ) + else: + if not self.find_sample(sha256=sha256): + uniq = False + else: + uniq = True return uniq - @classlock def list_sample_parent(self, sample_id=False, task_id=False): """ Retrieve parent sample details by sample_id or task_id @param sample_id: Sample id @param task_id: Task id """ + # This function appears to only be used in one specific case, and task_id is + # the only parameter that gets passed--sample_id is never provided. + # TODO Pull sample_id as an argument. It's dead code. parent_sample = {} parent = False - with self.Session() as session: - try: - if sample_id: - parent = session.query(Sample.parent).filter(Sample.id == int(sample_id)).first() - if parent: - parent = parent[0] - elif task_id: - _, parent = ( - session.query(Task.sample_id, Sample.parent) - .join(Sample, Sample.id == Task.sample_id) - .filter(Task.id == task_id) - .first() - ) - - if parent: - parent_sample = session.query(Sample).filter(Sample.id == parent).first().to_dict() + if sample_id: # pragma: no cover + parent = self.session.query(Sample.parent).filter(Sample.id == int(sample_id)).first() + if parent: + parent = parent[0] + elif task_id: + result = ( + self.session.query(Task.sample_id, Sample.parent) + .join(Sample, Sample.id == Task.sample_id) + .filter(Task.id == task_id) + .first() + ) + if result is not None: + parent = result[1] - except SQLAlchemyError as e: - log.debug("Database error listing tasks: %s", e) - except TypeError: - pass + if parent: + parent_sample = self.session.query(Sample).filter(Sample.id == parent).first().to_dict() return parent_sample - @classlock def list_tasks( self, limit=None, @@ -2231,8 +1893,9 @@ def list_tasks( tags_tasks_like=False, task_ids=False, include_hashes=False, - user_id=False, - ): + user_id=None, + for_update=False, + ) -> List[Task]: """Retrieve list of task. @param limit: specify a limit of entries. @param details: if details about must be included @@ -2252,260 +1915,179 @@ def list_tasks( @param task_ids: list of task_id @param include_hashes: return task+samples details @param user_id: list of tasks submitted by user X + @param for_update: If True, use "SELECT FOR UPDATE" in order to create a row-level lock on the selected tasks. @return: list of tasks. """ - with self.Session() as session: - try: - # Can we remove "options(joinedload)" it is here due to next error - # sqlalchemy.orm.exc.DetachedInstanceError: Parent instance is not bound to a Session; lazy load operation of attribute 'tags' cannot proceed - # ToDo this is inefficient but it fails if we don't join. Need to fix this - search = session.query(Task).options(joinedload(Task.guest), joinedload(Task.errors), joinedload(Task.tags)) - if include_hashes: - search = search.join(Sample, Task.sample_id == Sample.id) - if status: - if "|" in status: - search = search.filter(Task.status.in_(status.split("|"))) - else: - search = search.filter(Task.status == status) - if not_status: - search = search.filter(Task.status != not_status) - if category: - search = search.filter(Task.category.in_([category] if isinstance(category, str) else category)) - if details: - search = search.options(joinedload(Task.guest), joinedload(Task.errors), joinedload(Task.tags)) - if sample_id is not None: - search = search.filter(Task.sample_id == sample_id) - if id_before is not None: - search = search.filter(Task.id < id_before) - if id_after is not None: - search = search.filter(Task.id > id_after) - if completed_after: - search = search.filter(Task.completed_on > completed_after) - if added_before: - search = search.filter(Task.added_on < added_before) - if options_like: - # Replace '*' wildcards with wildcard for sql - options_like = options_like.replace("*", "%") - search = search.filter(Task.options.like(f"%{options_like}%")) - if options_not_like: - # Replace '*' wildcards with wildcard for sql - options_not_like = options_not_like.replace("*", "%") - search = search.filter(Task.options.notlike(f"%{options_not_like}%")) - if tags_tasks_like: - search = search.filter(Task.tags_tasks.like(f"%{tags_tasks_like}%")) - if task_ids: - search = search.filter(Task.id.in_(task_ids)) - if user_id: - search = search.filter(Task.user_id == user_id) - if order_by is not None and isinstance(order_by, tuple): - search = search.order_by(*order_by) - elif order_by is not None: - search = search.order_by(order_by) - else: - search = search.order_by(Task.added_on.desc()) - - tasks = search.limit(limit).offset(offset).all() - session.expunge_all() - return tasks - except RuntimeError as e: - # RuntimeError: number of values in row (1) differ from number of column processors (62) - log.debug("Database RuntimeError error: %s", e) - except AttributeError as e: - # '_NoResultMetaData' object has no attribute '_indexes_for_keys' - log.debug("Database AttributeError error: %s", e) - except SQLAlchemyError as e: - log.debug("Database error listing tasks: %s", e) - except Exception as e: - # psycopg2.DatabaseError - log.exception(e) + tasks: List[Task] = [] + # Can we remove "options(joinedload)" it is here due to next error + # sqlalchemy.orm.exc.DetachedInstanceError: Parent instance is not bound to a Session; lazy load operation of attribute 'tags' cannot proceed + # ToDo this is inefficient but it fails if we don't join. Need to fix this + search = self.session.query(Task).options(joinedload(Task.guest), joinedload(Task.errors), joinedload(Task.tags)) + if include_hashes: # pragma: no cover + # This doesn't work, but doesn't seem to get used anywhere. + search = search.options(joinedload(Sample)) + if status: + if "|" in status: + search = search.filter(Task.status.in_(status.split("|"))) + else: + search = search.filter(Task.status == status) + if not_status: + search = search.filter(Task.status != not_status) + if category: + search = search.filter(Task.category.in_([category] if isinstance(category, str) else category)) + # We're currently always returning details. See the comment at the top of this 'try' block. + # if details: + # search = search.options(joinedload(Task.guest), joinedload(Task.errors), joinedload(Task.tags)) + if sample_id is not None: + search = search.filter(Task.sample_id == sample_id) + if id_before is not None: + search = search.filter(Task.id < id_before) + if id_after is not None: + search = search.filter(Task.id > id_after) + if completed_after: + search = search.filter(Task.completed_on > completed_after) + if added_before: + search = search.filter(Task.added_on < added_before) + if options_like: + # Replace '*' wildcards with wildcard for sql + options_like = options_like.replace("*", "%") + search = search.filter(Task.options.like(f"%{options_like}%")) + if options_not_like: + # Replace '*' wildcards with wildcard for sql + options_not_like = options_not_like.replace("*", "%") + search = search.filter(Task.options.notlike(f"%{options_not_like}%")) + if tags_tasks_like: + search = search.filter(Task.tags_tasks.like(f"%{tags_tasks_like}%")) + if task_ids: + search = search.filter(Task.id.in_(task_ids)) + if user_id is not None: + search = search.filter(Task.user_id == user_id) + if order_by is not None and isinstance(order_by, tuple): + search = search.order_by(*order_by) + elif order_by is not None: + search = search.order_by(order_by) + else: + search = search.order_by(Task.added_on.desc()) + + search = search.limit(limit).offset(offset) + if for_update: + search = search.with_for_update(of=Task) + tasks = search.all() - return [] + return tasks def minmax_tasks(self): """Find tasks minimum and maximum @return: unix timestamps of minimum and maximum """ - with self.Session() as session: - try: - _min = session.query(func.min(Task.started_on).label("min")).first() - _max = session.query(func.max(Task.completed_on).label("max")).first() - if _min and _max and _min[0] and _max[0]: - return int(_min[0].strftime("%s")), int(_max[0].strftime("%s")) - except SQLAlchemyError as e: - log.debug("Database error counting tasks: %s", e) + _min = self.session.query(func.min(Task.started_on).label("min")).first() + _max = self.session.query(func.max(Task.completed_on).label("max")).first() + if _min and _max and _min[0] and _max[0]: + return int(_min[0].strftime("%s")), int(_max[0].strftime("%s")) return 0, 0 - @classlock def get_tlp_tasks(self): """ Retrieve tasks with TLP """ - with self.Session() as session: - try: - tasks = session.query(Task).filter(Task.tlp == "true").all() - if tasks: - return [task.id for task in tasks] - else: - return [] - except SQLAlchemyError as e: - log.debug("Database error listing tasks: %s", e) - return [] + tasks = self.session.query(Task).filter(Task.tlp == "true").all() + if tasks: + return [task.id for task in tasks] + else: + return [] - @classlock def get_file_types(self): """Get sample filetypes @return: A list of all available file types """ - with self.Session() as session: - try: - unfiltered = session.query(Sample.file_type).group_by(Sample.file_type) - res = [asample[0] for asample in unfiltered.all()] - res.sort() - except SQLAlchemyError as e: - log.debug("Database error getting file_types: %s", e) - return 0 + unfiltered = self.session.query(Sample.file_type).group_by(Sample.file_type) + res = [asample[0] for asample in unfiltered.all()] + res.sort() return res - @classlock def get_tasks_status_count(self): """Count all tasks in the database @return: dict with status and number of tasks found example: {'failed_analysis': 2, 'running': 100, 'reported': 400} """ - with self.Session() as session: - try: - tasks_dict_count = session.query(Task.status, func.count(Task.status)).group_by(Task.status).all() - return dict(tasks_dict_count) - except SQLAlchemyError as e: - log.debug("Database error counting all tasks: %s", e) - - return {} + tasks_dict_count = self.session.query(Task.status, func.count(Task.status)).group_by(Task.status).all() + return dict(tasks_dict_count) - @classlock def count_tasks(self, status=None, mid=None): """Count tasks in the database @param status: apply a filter according to the task status @param mid: Machine id to filter for @return: number of tasks found """ - with self.Session() as session: - try: - unfiltered = session.query(Task) - if mid: - unfiltered = unfiltered.filter_by(machine_id=mid) - if status: - unfiltered = unfiltered.filter_by(status=status) - tasks_count = get_count(unfiltered, Task.id) - return tasks_count - except SQLAlchemyError as e: - log.debug("Database error counting tasks: %s", e) - return 0 - - @classlock - def view_task(self, task_id, details=False): + unfiltered = self.session.query(Task) + # It doesn't look like "mid" ever gets passed to this function. + if mid: # pragma: no cover + unfiltered = unfiltered.filter_by(machine_id=mid) + if status: + unfiltered = unfiltered.filter_by(status=status) + tasks_count = get_count(unfiltered, Task.id) + return tasks_count + + def view_task(self, task_id, details=False) -> Optional[Task]: """Retrieve information on a task. @param task_id: ID of the task to query. @return: details on the task. """ - with self.Session() as session: - try: - if details: - task = ( - select(Task) - .where(Task.id == task_id) - .options(joinedload(Task.guest), joinedload(Task.errors), joinedload(Task.tags), joinedload(Task.sample)) - ) - task = session.execute(task).first() - else: - query = select(Task).where(Task.id == task_id).options(joinedload(Task.tags), joinedload(Task.sample)) - task = session.execute(query).first() - if task: - task = task[0] - session.expunge(task) - return task - except SQLAlchemyError as e: - print(e) - log.debug("Database error viewing task: %s", e) - - @classlock - def add_statistics_to_task(self, task_id, details): + query = select(Task).where(Task.id == task_id) + if details: + query = query.options(joinedload(Task.guest), joinedload(Task.errors), joinedload(Task.tags), joinedload(Task.sample)) + else: + query = query.options(joinedload(Task.tags), joinedload(Task.sample)) + task = self.session.execute(query).first() + if task: + task = task[0] + + return task + + # TODO This function doesn't appear to be used. Can we pull it? + def add_statistics_to_task(self, task_id, details): # pragma: no cover """add statistic to task @param task_id: ID of the task to query. @param: details statistic. @return true of false. """ - with self.Session() as session: - try: - task = session.get(Task, task_id) - if task: - task.dropped_files = details["dropped_files"] - task.running_processes = details["running_processes"] - task.api_calls = details["api_calls"] - task.domains = details["domains"] - task.signatures_total = details["signatures_total"] - task.signatures_alert = details["signatures_alert"] - task.files_written = details["files_written"] - task.registry_keys_modified = details["registry_keys_modified"] - task.crash_issues = details["crash_issues"] - task.anti_issues = details["anti_issues"] - session.commit() - session.refresh(task) - except SQLAlchemyError as e: - log.debug("Database error deleting task: %s", e) - session.rollback() - return False + task = self.session.get(Task, task_id) + if task: + task.dropped_files = details["dropped_files"] + task.running_processes = details["running_processes"] + task.api_calls = details["api_calls"] + task.domains = details["domains"] + task.signatures_total = details["signatures_total"] + task.signatures_alert = details["signatures_alert"] + task.files_written = details["files_written"] + task.registry_keys_modified = details["registry_keys_modified"] + task.crash_issues = details["crash_issues"] + task.anti_issues = details["anti_issues"] return True - @classlock def delete_task(self, task_id): """Delete information on a task. @param task_id: ID of the task to query. @return: operation status. """ - with self.Session() as session: - try: - task = session.get(Task, task_id) - session.delete(task) - session.commit() - except SQLAlchemyError as e: - log.debug("Database error deleting task: %s", e) - session.rollback() - return False + task = self.session.get(Task, task_id) + if task is None: + return False + self.session.delete(task) return True - # classlock def delete_tasks(self, ids): - with self.Session() as session: - try: - _ = session.query(Task).filter(Task.id.in_(ids)).delete(synchronize_session=False) - except SQLAlchemyError as e: - log.debug("Database error deleting task: %s", e) - session.rollback() - return False + self.session.query(Task).filter(Task.id.in_(ids)).delete(synchronize_session=False) return True - @classlock def view_sample(self, sample_id): """Retrieve information on a sample given a sample id. @param sample_id: ID of the sample to query. @return: details on the sample used in sample: sample_id. """ - with self.Session() as session: - try: - sample = session.get(Sample, sample_id) - except AttributeError: - return None - except SQLAlchemyError as e: - log.debug("Database error viewing task: %s", e) - return None - else: - if sample: - session.expunge(sample) - - return sample + return self.session.get(Sample, sample_id) - @classlock def find_sample(self, md5=None, sha1=None, sha256=None, parent=None, task_id: int = None, sample_id: int = None): """Search samples by MD5, SHA1, or SHA256. @param md5: md5 string @@ -2517,55 +2099,45 @@ def find_sample(self, md5=None, sha1=None, sha256=None, parent=None, task_id: in @return: matches list """ sample = False - with self.Session() as session: - try: - if md5: - sample = session.query(Sample).filter_by(md5=md5).first() - elif sha1: - sample = session.query(Sample).filter_by(sha1=sha1).first() - elif sha256: - sample = session.query(Sample).filter_by(sha256=sha256).first() - elif parent: - sample = session.query(Sample).filter_by(parent=parent).all() - elif sample_id: - sample = session.query(Sample).filter_by(id=sample_id).all() - elif task_id: - sample = ( - session.query(Task) - .options(joinedload(Task.sample)) - .filter(Task.id == task_id) - .filter(Sample.id == Task.sample_id) - .all() - ) - except SQLAlchemyError as e: - log.debug("Database error searching sample: %s", e) - return None - else: - if sample: - session.expunge_all() + if md5: + sample = self.session.query(Sample).filter_by(md5=md5).first() + elif sha1: + sample = self.session.query(Sample).filter_by(sha1=sha1).first() + elif sha256: + sample = self.session.query(Sample).filter_by(sha256=sha256).first() + elif parent: + sample = self.session.query(Sample).filter_by(parent=parent).all() + elif sample_id: + sample = self.session.query(Sample).filter_by(id=sample_id).all() + elif task_id: + # If task_id is passed, then a list of Task objects is returned--not Samples. + sample = ( + self.session.query(Task) + .options(joinedload(Task.sample)) + .filter(Task.id == task_id) + .filter(Sample.id == Task.sample_id) + .all() + ) return sample - @classlock def sample_still_used(self, sample_hash: str, task_id: int): """Retrieve information if sample is used by another task(s). - @param hash: md5/sha1/sha256/sha256. + @param sample_hash: sha256. @param task_id: task_id @return: bool """ - with self.Session() as session: - db_sample = ( - session.query(Sample) - # .options(joinedload(Task.sample)) - .filter(Sample.sha256 == sample_hash) - .filter(Task.id != task_id) - .filter(Sample.id == Task.sample_id) - .filter(Task.status.in_((TASK_PENDING, TASK_RUNNING, TASK_DISTRIBUTED))) - .first() - ) - still_used = bool(db_sample) - return still_used + db_sample = ( + self.session.query(Sample) + # .options(joinedload(Task.sample)) + .filter(Sample.sha256 == sample_hash) + .filter(Task.id != task_id) + .filter(Sample.id == Task.sample_id) + .filter(Task.status.in_((TASK_PENDING, TASK_RUNNING, TASK_DISTRIBUTED))) + .first() + ) + still_used = bool(db_sample) + return still_used - @classlock def sample_path_by_hash(self, sample_hash: str = False, task_id: int = False): """Retrieve information on a sample location by given hash. @param hash: md5/sha1/sha256/sha256. @@ -2597,12 +2169,10 @@ def sample_path_by_hash(self, sample_hash: str = False, task_id: int = False): if path_exists(file_path): return [file_path] - session = False # binary also not stored in binaries, perform hash lookup if task_id and not sample_hash: - session = self.Session() db_sample = ( - session.query(Sample) + self.session.query(Sample) # .options(joinedload(Task.sample)) .filter(Task.id == task_id) .filter(Sample.id == Task.sample_id) @@ -2622,29 +2192,64 @@ def sample_path_by_hash(self, sample_hash: str = False, task_id: int = False): sample = [] # check storage/binaries if query_filter: - try: - if not session: - session = self.Session() - db_sample = session.query(Sample).filter(query_filter == sample_hash).first() - if db_sample is not None: - path = os.path.join(CUCKOO_ROOT, "storage", "binaries", db_sample.sha256) - if path_exists(path): - sample = [path] + db_sample = self.session.query(Sample).filter(query_filter == sample_hash).first() + if db_sample is not None: + path = os.path.join(CUCKOO_ROOT, "storage", "binaries", db_sample.sha256) + if path_exists(path): + sample = [path] + + if not sample: + if repconf.mongodb.enabled: + tasks = mongo_find( + "analysis", + {f"CAPE.payloads.{sizes_mongo.get(len(sample_hash), '')}": sample_hash}, + {"CAPE.payloads": 1, "_id": 0, "info.id": 1}, + ) + elif repconf.elasticsearchdb.enabled: + tasks = [ + d["_source"] + for d in es.search( + index=get_analysis_index(), + body={"query": {"match": {f"CAPE.payloads.{sizes_mongo.get(len(sample_hash), '')}": sample_hash}}}, + _source=["CAPE.payloads", "info.id"], + )["hits"]["hits"] + ] + else: + tasks = [] - if not sample: + if tasks: + for task in tasks: + for block in task.get("CAPE", {}).get("payloads", []) or []: + if block[sizes_mongo.get(len(sample_hash), "")] == sample_hash: + file_path = os.path.join( + CUCKOO_ROOT, + "storage", + "analyses", + str(task["info"]["id"]), + folders.get("CAPE"), + block["sha256"], + ) + if path_exists(file_path): + sample = [file_path] + break + if sample: + break + + for category in ("dropped", "procdump"): + # we can't filter more if query isn't sha256 if repconf.mongodb.enabled: tasks = mongo_find( "analysis", - {"CAPE.payloads.file_ref": sample_hash}, - {"CAPE.payloads": 1, "_id": 0, "info.id": 1}, + {f"{category}.{sizes_mongo.get(len(sample_hash), '')}": sample_hash}, + {category: 1, "_id": 0, "info.id": 1}, ) elif repconf.elasticsearchdb.enabled: tasks = [ d["_source"] for d in es.search( index=get_analysis_index(), - body={"query": {"match": {"CAPE.payloads.file_ref": sample_hash}}}, - _source=["CAPE.payloads", "info.id"], + body={"query": {"match": {f"{category}.{sizes_mongo.get(len(sample_hash), '')}": sample_hash}}}, + _source=["info.id", category], )["hits"]["hits"] ] else: @@ -2652,14 +2257,14 @@ def sample_path_by_hash(self, sample_hash: str = False, task_id: int = False): if tasks: for task in tasks: - for block in task.get("CAPE", {}).get("payloads", []) or []: + for block in task.get(category, []) or []: if block[sizes_mongo.get(len(sample_hash), "")] == sample_hash: file_path = os.path.join( CUCKOO_ROOT, "storage", "analyses", str(task["info"]["id"]), - folders.get("CAPE"), + folders.get(category), block["sha256"], ) if path_exists(file_path): @@ -2668,157 +2273,80 @@ def sample_path_by_hash(self, sample_hash: str = False, task_id: int = False): if sample: break - for category in ("dropped", "procdump"): - # we can't filter more if query isn't sha256 - if repconf.mongodb.enabled: - tasks = mongo_find( - "analysis", - {f"{category}.{sizes_mongo.get(len(sample_hash), '')}": sample_hash}, - {category: 1, "_id": 0, "info.id": 1}, - ) - elif repconf.elasticsearchdb.enabled: - tasks = [ - d["_source"] - for d in es.search( - index=get_analysis_index(), - body={"query": {"match": {f"{category}.{sizes_mongo.get(len(sample_hash), '')}": sample_hash}}}, - _source=["info.id", category], - )["hits"]["hits"] - ] - else: - tasks = [] - - if tasks: - for task in tasks: - for block in task.get(category, []) or []: - if block[sizes_mongo.get(len(sample_hash), "")] == sample_hash: - file_path = os.path.join( - CUCKOO_ROOT, - "storage", - "analyses", - str(task["info"]["id"]), - folders.get(category), - block["sha256"], - ) - if path_exists(file_path): - sample = [file_path] - break - if sample: - break + if not sample: + # search in temp folder if not found in binaries + db_sample = ( + self.session.query(Task).join(Sample, Task.sample_id == Sample.id).filter(query_filter == sample_hash).all() + ) - if not sample: - # search in temp folder if not found in binaries - db_sample = ( - session.query(Task).join(Sample, Task.sample_id == Sample.id).filter(query_filter == sample_hash).all() + if db_sample is not None: + samples = [_f for _f in [tmp_sample.to_dict().get("target", "") for tmp_sample in db_sample] if _f] + # hash validation and if exist + samples = [file_path for file_path in samples if path_exists(file_path)] + for path in samples: + with open(path, "rb") as f: + if sample_hash == sizes[len(sample_hash)](f.read()).hexdigest(): + sample = [path] + break + + if not sample: + # search in Suricata files folder + if repconf.mongodb.enabled: + tasks = mongo_find( + "analysis", {"suricata.files.sha256": sample_hash}, {"suricata.files.file_info.path": 1, "_id": 0} ) + elif repconf.elasticsearchdb.enabled: + tasks = [ + d["_source"] + for d in es.search( + index=get_analysis_index(), + body={"query": {"match": {"suricata.files.sha256": sample_hash}}}, + _source="suricata.files.file_info.path", + )["hits"]["hits"] + ] + else: + tasks = [] - if db_sample is not None: - samples = [_f for _f in [tmp_sample.to_dict().get("target", "") for tmp_sample in db_sample] if _f] - # hash validation and if exist - samples = [file_path for file_path in samples if path_exists(file_path)] - for path in samples: - with open(path, "rb") as f: - if sample_hash == sizes[len(sample_hash)](f.read()).hexdigest(): - sample = [path] + if tasks: + for task in tasks: + for item in task["suricata"]["files"] or []: + file_path = item.get("file_info", {}).get("path", "") + if sample_hash in file_path: + if path_exists(file_path): + sample = [file_path] break - if not sample: - # search in Suricata files folder - if repconf.mongodb.enabled: - tasks = mongo_find( - "analysis", {"suricata.files.sha256": sample_hash}, {"suricata.files.file_info.path": 1, "_id": 0} - ) - elif repconf.elasticsearchdb.enabled: - tasks = [ - d["_source"] - for d in es.search( - index=get_analysis_index(), - body={"query": {"match": {"suricata.files.sha256": sample_hash}}}, - _source="suricata.files.file_info.path", - )["hits"]["hits"] - ] - else: - tasks = [] - - if tasks: - for task in tasks: - for item in task["suricata"]["files"] or []: - file_path = item.get("file_info", {}).get("path", "") - if sample_hash in file_path: - if path_exists(file_path): - sample = [file_path] - break - - except AttributeError: - pass - except SQLAlchemyError as e: - log.debug("Database error viewing task: %s", e) - finally: - session.close() - return sample - @classlock - def count_samples(self): + def count_samples(self) -> int: """Counts the amount of samples in the database.""" - with self.Session() as session: - try: - sample_count = session.query(Sample).count() - except SQLAlchemyError as e: - log.debug("Database error counting samples: %s", e) - return 0 + sample_count = self.session.query(Sample).count() return sample_count - @classlock - def view_machine(self, name): + def view_machine(self, name) -> Optional[Machine]: """Show virtual machine. @params name: virtual machine name @return: virtual machine's details """ - with self.Session() as session: - try: - machine = session.query(Machine).options(joinedload(Machine.tags)).filter(Machine.name == name).first() - # machine = session.execute(select(Machine).filter(Machine.name == name).options(joinedload(Tag))).first() - except SQLAlchemyError as e: - log.debug("Database error viewing machine: %s", e) - return None - else: - if machine: - session.expunge(machine) + machine = self.session.query(Machine).options(joinedload(Machine.tags)).filter(Machine.name == name).first() return machine - @classlock - def view_machine_by_label(self, label): + def view_machine_by_label(self, label) -> Optional[Machine]: """Show virtual machine. @params label: virtual machine label @return: virtual machine's details """ - with self.Session() as session: - try: - machine = session.query(Machine).options(joinedload(Machine.tags)).filter(Machine.label == label).first() - except SQLAlchemyError as e: - log.debug("Database error viewing machine by label: %s", e) - return None - else: - if machine: - session.expunge(machine) + machine = self.session.query(Machine).options(joinedload(Machine.tags)).filter(Machine.label == label).first() return machine - @classlock def view_errors(self, task_id): """Get all errors related to a task. @param task_id: ID of task associated to the errors @return: list of errors. """ - with self.Session() as session: - try: - errors = session.query(Error).filter_by(task_id=task_id).all() - except SQLAlchemyError as e: - log.debug("Database error viewing errors: %s", e) - return [] + errors = self.session.query(Error).filter_by(task_id=task_id).all() return errors - @classlock def get_source_url(self, sample_id=False): """ Retrieve url from where sample was downloaded @@ -2826,37 +2354,26 @@ def get_source_url(self, sample_id=False): @param task_id: Task id """ source_url = False - with self.Session() as session: - try: - if sample_id: - source_url = session.query(Sample.source_url).filter(Sample.id == int(sample_id)).first() - if source_url: - source_url = source_url[0] - except SQLAlchemyError as e: - log.debug("Database error listing tasks: %s", e) - except TypeError: - pass + try: + if sample_id: + source_url = self.session.query(Sample.source_url).filter(Sample.id == int(sample_id)).first() + if source_url: + source_url = source_url[0] + except TypeError: + pass return source_url - @classlock def ban_user_tasks(self, user_id: int): """ Ban all tasks submitted by user_id @param user_id: user id """ - with self.Session() as session: - _ = ( - session.query(Task) - .filter(Task.user_id == user_id) - .filter(Task.status == TASK_PENDING) - .update({Task.status: TASK_BANNED}, synchronize_session=False) - ) - session.commit() - session.close() + self.session.query(Task).filter(Task.user_id == user_id).filter(Task.status == TASK_PENDING).update( + {Task.status: TASK_BANNED}, synchronize_session=False + ) - @classlock def tasks_reprocess(self, task_id: int): """common func for api and views""" task = self.view_task(task_id) @@ -2877,5 +2394,36 @@ def tasks_reprocess(self, task_id: int): }: return True, f"Task ID {task_id} cannot be reprocessed in status {task.status}", task.status + # Save the old_status, because otherwise, in the call to set_status(), + # sqlalchemy will use the cached Task object that `task` is already a reference + # to and update that in place. That would result in `task.status` in this + # function being set to TASK_COMPLETED and we don't want to return that. + old_status = task.status self.set_status(task_id, TASK_COMPLETED) - return False, "", task.status + return False, "", old_status + + +_DATABASE: Optional[_Database] = None + + +class Database: + def __getattr__(self, attr: str) -> Any: + if _DATABASE is None: + raise CuckooDatabaseInitializationError + return getattr(_DATABASE, attr) + + +def init_database(*args, exists_ok=False, **kwargs) -> _Database: + global _DATABASE + if _DATABASE is not None: + if exists_ok: + return _DATABASE + raise RuntimeError("The database has already been initialized!") + _DATABASE = _Database(*args, **kwargs) + return _DATABASE + + +def reset_database_FOR_TESTING_ONLY(): + """Used for testing.""" + global _DATABASE + _DATABASE = None diff --git a/lib/cuckoo/core/guest.py b/lib/cuckoo/core/guest.py index ccf92fd1d5b..075b80704d2 100644 --- a/lib/cuckoo/core/guest.py +++ b/lib/cuckoo/core/guest.py @@ -117,6 +117,14 @@ def get(self, method, *args, **kwargs): do_raise and r.raise_for_status() return r + def get_status_from_db(self) -> str: + with db.session.begin(): + return db.guest_get_status(self.task_id) + + def set_status_in_db(self, status: str): + with db.session.begin(): + db.guest_set_status(self.task_id, status) + def post(self, method, *args, **kwargs): """Simple wrapper around requests.post().""" url = f"http://{self.ipaddr}:{self.port}{method}" @@ -140,7 +148,7 @@ def wait_available(self): """Wait until the Virtual Machine is available for usage.""" start = timeit.default_timer() - while db.guest_get_status(self.task_id) == "starting" and self.do_run: + while self.do_run and self.get_status_from_db() == "starting": try: socket.create_connection((self.ipaddr, self.port), 1).close() break @@ -248,7 +256,7 @@ def start_analysis(self, options): # Could be beautified a bit, but basically we have to perform the # same check here as we did in wait_available(). - if db.guest_get_status(self.task_id) != "starting": + if self.get_status_from_db() != "starting": return r = self.get("/", do_raise=False) @@ -262,7 +270,7 @@ def start_analysis(self, options): r.status_code, json.dumps(dict(r.headers)), ) - db.guest_set_status(self.task_id, "failed") + self.set_status_in_db("failed") return try: @@ -276,7 +284,7 @@ def start_analysis(self, options): "go through the documentation once more and otherwise inform " "the Cuckoo Developers of your issue" ) - db.guest_set_status(self.task_id, "failed") + self.set_status_in_db("failed") return log.info("Task #%s: Guest is running CAPE Agent %s (id=%s, ip=%s)", self.task_id, version, self.vmid, self.ipaddr) @@ -302,7 +310,8 @@ def start_analysis(self, options): # Lookup file if current doesn't exist in TMP anymore alternative_path = False if not path_exists(options["target"]): - path_found = db.sample_path_by_hash(task_id=options["id"]) + with db.session.begin(): + path_found = db.sample_path_by_hash(task_id=options["id"]) if path_found: alternative_path = path_found[0] @@ -345,22 +354,20 @@ def start_analysis(self, options): self.post("/execute", data=data) def wait_for_completion(self): - count = 0 start = timeit.default_timer() - while db.guest_get_status(self.task_id) == "running" and self.do_run: + while self.do_run and self.get_status_from_db() == "running": + time.sleep(1) + if cfg.cuckoo.machinery_screenshots: if count == 0: # indicate screenshot captures have started log.info("Task #%s: Started capturing screenshots for %s", self.task_id, self.vmid) self.analysis_manager.screenshot_machine() - if count >= 5: - log.debug("Task #%s: Analysis is still running (id=%s, ip=%s)", self.task_id, self.vmid, self.ipaddr) - count = 0 - count += 1 - time.sleep(1) + if count % 5 == 0: + log.debug("Task #%s: Analysis is still running (id=%s, ip=%s)", self.task_id, self.vmid, self.ipaddr) # If the analysis hits the critical timeout, just return straight # away and try to recover the analysis results from the guest. @@ -387,7 +394,7 @@ def wait_for_completion(self): if status["status"] in ("complete", "failed"): completed_as = "completed successfully" if status["status"] == "complete" else "failed" log.info("Task #%s: Analysis %s (id=%s, ip=%s)", completed_as, self.task_id, self.vmid, self.ipaddr) - db.guest_set_status(self.task_id, "complete") + self.set_status_in_db("complete") return elif status["status"] == "exception": log.warning( @@ -397,5 +404,5 @@ def wait_for_completion(self): self.ipaddr, status["description"], ) - db.guest_set_status(self.task_id, "failed") + self.set_status_in_db("failed") return diff --git a/lib/cuckoo/core/machinery_manager.py b/lib/cuckoo/core/machinery_manager.py new file mode 100644 index 00000000000..5f36dadf2a3 --- /dev/null +++ b/lib/cuckoo/core/machinery_manager.py @@ -0,0 +1,307 @@ +import logging +import threading +import time +from time import monotonic as _time +from typing import Optional, Union + +from lib.cuckoo.common.abstracts import Machinery +from lib.cuckoo.common.config import Config +from lib.cuckoo.common.exceptions import CuckooCriticalError, CuckooMachineError +from lib.cuckoo.core.database import Database, Machine, Task, _Database +from lib.cuckoo.core.plugins import list_plugins +from lib.cuckoo.core.rooter import rooter, vpns + +log = logging.getLogger(__name__) + +routing = Config("routing") + + +class ScalingBoundedSemaphore(threading.Semaphore): + """Implements a dynamic bounded semaphore. + + A bounded semaphore checks to make sure its current value doesn't exceed its + limit value. If it does, ValueError is raised. In most situations + semaphores are used to guard resources with limited capacity. + + If the semaphore is released too many times it's a sign of a bug. If not + given, value defaults to 1. + + Like regular semaphores, bounded semaphores manage a counter representing + the number of release() calls minus the number of acquire() calls, plus a + limit value. The acquire() method blocks if necessary until it can return + without making the counter negative. If not given, value defaults to 1. + + In this version of semaphore there is an upper limit where its limit value + can never reach when it is changed. The idea behind it is that in machinery + documentation there is a limit of machines that can be available so there is + no point having it higher than that. + """ + + def __init__(self, value=1, upper_limit=1): + threading.Semaphore.__init__(self, value) + self._limit_value = value + self._upper_limit = upper_limit + + def acquire(self, blocking=True, timeout=None): + """Acquire a semaphore, decrementing the internal counter by one. + + When invoked without arguments: if the internal counter is larger than + zero on entry, decrement it by one and return immediately. If it is zero + on entry, block, waiting until some other thread has called release() to + make it larger than zero. This is done with proper interlocking so that + if multiple acquire() calls are blocked, release() will wake exactly one + of them up. The implementation may pick one at random, so the order in + which blocked threads are awakened should not be relied on. There is no + return value in this case. + + When invoked with blocking set to true, do the same thing as when called + without arguments, and return true. + + When invoked with blocking set to false, do not block. If a call without + an argument would block, return false immediately; otherwise, do the + same thing as when called without arguments, and return true. + + When invoked with a timeout other than None, it will block for at + most timeout seconds. If acquire does not complete successfully in + that interval, return false. Return true otherwise. + + """ + if not blocking and timeout is not None: + raise ValueError("Cannot specify timeout for non-blocking acquire()") + rc = False + endtime = None + with self._cond: + while self._value == 0: + if not blocking: + break + if timeout is not None: + if endtime is None: + endtime = _time() + timeout + else: + timeout = endtime - _time() + if timeout <= 0: + break + self._cond.wait(timeout) + else: + self._value -= 1 + rc = True + return rc + + __enter__ = acquire + + def release(self): + """Release a semaphore, incrementing the internal counter by one. + + When the counter is zero on entry and another thread is waiting for it + to become larger than zero again, wake up that thread. + + If the number of releases exceeds the number of acquires, + raise a ValueError. + + """ + with self._cond: + if self._value > self._upper_limit: + raise ValueError("Semaphore released too many times") + if self._value >= self._limit_value: + self._value = self._limit_value + self._cond.notify() + return + self._value += 1 + self._cond.notify() + + def __exit__(self, t, v, tb): + self.release() + + def update_limit(self, value): + """Update the limit value for the semaphore + + This limit value is the bounded limit, and proposed limit values + are validated against the upper limit. + + """ + if 0 < value < self._upper_limit: + self._limit_value = value + if self._value > value: + self._value = value + + def check_for_starvation(self, available_count: int): + """Check for preventing starvation from coming up after updating the limit. + Take no parameter. + Return true on starvation. + """ + if self._value == 0 and available_count == self._limit_value: + self._value = self._limit_value + return True + # Resync of the lock value + if abs(self._value - available_count) > 0: + self._value = available_count + return True + return False + + +MachineryLockType = Union[threading.Lock, threading.BoundedSemaphore, ScalingBoundedSemaphore] + + +class MachineryManager: + def __init__(self): + self.cfg = Config() + self.db: _Database = Database() + self.machinery_name: str = self.cfg.cuckoo.machinery + self.machinery: Machinery = self.create_machinery() + self.pool_scaling_lock = threading.Lock() + if self.machinery.module_name != self.machinery_name: + raise CuckooCriticalError( + f"Incorrect machinery module was imported. " + f"Should've been {self.machinery_name} but was {self.machinery.module_name}" + ) + log.info( + "Using %s with max_machines_count=%d", + self, + self.cfg.cuckoo.max_machines_count, + ) + self.machine_lock: MachineryLockType = self.create_machine_lock() + + def __str__(self): + return f"{self.__class__.__name__}[{self.machinery_name}]" + + def create_machine_lock(self) -> MachineryLockType: + retval: MachineryLockType = threading.Lock() + + # You set this value if you are using a machinery that is NOT auto-scaling + max_vmstartup_count = self.cfg.cuckoo.max_vmstartup_count + if max_vmstartup_count: + # The BoundedSemaphore is used to prevent CPU starvation when starting up multiple VMs + log.info("max_vmstartup_count for BoundedSemaphore = %d", max_vmstartup_count) + retval = threading.BoundedSemaphore(max_vmstartup_count) + + # You set this value if you are using a machinery that IS auto-scaling + elif self.cfg.cuckoo.scaling_semaphore: + # If the user wants to use the scaling bounded semaphore, check what machinery is specified, and then + # grab the required configuration key for setting the upper limit + machinery_opts = self.machinery.options.get(self.machinery_name) + machines_limit: int = 0 + if self.machinery_name == "az": + machines_limit = machinery_opts.get("total_machines_limit") + elif self.machinery_name == "aws": + machines_limit = machinery_opts.get("dynamic_machines_limit") + if machines_limit: + # The ScalingBoundedSemaphore is used to keep feeding available machines from the pending tasks queue + log.info("upper limit for ScalingBoundedSemaphore = %d", machines_limit) + retval = ScalingBoundedSemaphore(value=len(machinery_opts["machines"]), upper_limit=machines_limit) + else: + log.warning( + "scaling_semaphore is set but the %s machinery does not set the machines limit. Ignoring scaling semaphore.", + self.machinery_name, + ) + + return retval + + @staticmethod + def create_machinery() -> Machinery: + # Get registered class name. Only one machine manager is imported, + # therefore there should be only one class in the list. + plugin = list_plugins("machinery")[0] + machinery: Machinery = plugin() + + return machinery + + def find_machine_to_service_task(self, task: Task) -> Optional[Machine]: + machine = self.db.find_machine_to_service_task(task) + if machine: + log.info( + "Task #%s: found useable machine %s (arch=%s, platform=%s)", + task.id, + machine.name, + machine.arch, + machine.platform, + ) + else: + log.debug( + "Task #%s: no machine available yet for task requiring machine '%s', platform '%s' or tags '%s'.", + task.id, + task.machine, + task.platform, + task.tags, + ) + + return machine + + def initialize_machinery(self) -> None: + """Initialize the machines in the database and initialize routing for them.""" + try: + self.machinery.initialize() + except CuckooMachineError as e: + raise CuckooCriticalError("Error initializing machines") from e + + # At this point all the available machines should have been identified + # and added to the list. If none were found, Cuckoo needs to abort the + # execution. + available_machines = list(self.machinery.machines()) + if not len(available_machines): + raise CuckooCriticalError("No machines available") + else: + log.info("Loaded %d machine%s", len(available_machines), "s" if len(available_machines) != 1 else "") + + if len(available_machines) > 1 and self.db.engine.name == "sqlite": + log.warning( + "As you've configured CAPE to execute parallel analyses, we recommend you to switch to a PostgreSQL database as SQLite might cause some issues" + ) + + # Drop all existing packet forwarding rules for each VM. Just in case + # Cuckoo was terminated for some reason and various forwarding rules + # have thus not been dropped yet. + for machine in available_machines: + rooter("inetsim_disable", machine.ip) + if not machine.interface: + log.info( + "Unable to determine the network interface for VM with name %s, Cape will not be able to give it " + "full internet access or route it through a VPN! Please define a default network interface for the " + "machinery or define a network interface for each VM", + machine.name, + ) + continue + + # Drop forwarding rule to each VPN. + for vpn in vpns.values(): + rooter("forward_disable", machine.interface, vpn.interface, machine.ip) + + # Drop forwarding rule to the internet / dirty line. + if routing.routing.internet != "none": + rooter("forward_disable", machine.interface, routing.routing.internet, machine.ip) + + threading.Thread(target=self.thr_maintain_scaling_bounded_semaphore, daemon=True) + + def running_machines_max_reached(self) -> bool: + """Return true if we've reached the maximum number of running machines.""" + return 0 < self.cfg.cuckoo.max_machines_count <= self.machinery.running_count() + + def scale_pool(self, machine: Machine) -> None: + """For machinery backends that support auto-scaling, make sure that enough machines + are spun up. For other types of machinery, this is basically a noop. This is called + from the AnalysisManager (i.e. child) thread, so we use a lock to make sure that + it doesn't get called multiple times simultaneously. We don't want to call it from + the main thread as that would block the scheduler while machines are spun up. + Note that the az machinery maintains its own thread to monitor to size of the pool. + """ + with self.pool_scaling_lock: + self.machinery.scale_pool(machine) + + def start_machine(self, machine: Machine) -> None: + with self.machine_lock: + self.machinery.start(machine.label) + + def stop_machine(self, machine: Machine) -> None: + self.machinery.stop(machine.label) + + def thr_maintain_scaling_bounded_semaphore(self) -> None: + """Maintain the limit of the ScalingBoundedSemaphore if one is being used.""" + if not isinstance(self.machine_lock, ScalingBoundedSemaphore) or not self.cfg.cuckoo.scaling_semaphore_update_timer: + return + + while True: + with self.db.session.begin(): + # Here be dragons! Making these calls on the ScalingBoundedSemaphore is not + # thread safe. + self.machine_lock.update_limit(len(self.machinery.machines())) + self.machine_lock.check_for_starvation(self.machinery.availables(include_reserved=True)) + time.sleep(self.cfg.cuckoo.scaling_semaphore_update_timer) diff --git a/lib/cuckoo/core/scheduler.py b/lib/cuckoo/core/scheduler.py index 59cdad280b7..719880c057b 100644 --- a/lib/cuckoo/core/scheduler.py +++ b/lib/cuckoo/core/scheduler.py @@ -2,60 +2,27 @@ # This file is part of Cuckoo Sandbox - http://www.cuckoosandbox.org # See the file 'docs/LICENSE' for copying permission. +import contextlib import enum import logging import os import queue -import shutil import signal import threading import time from collections import defaultdict -from time import monotonic as _time +from typing import DefaultDict, List, Optional, Tuple from lib.cuckoo.common.config import Config from lib.cuckoo.common.constants import CUCKOO_ROOT -from lib.cuckoo.common.exceptions import ( - CuckooCriticalError, - CuckooGuestCriticalTimeout, - CuckooGuestError, - CuckooMachineError, - CuckooOperationalError, -) -from lib.cuckoo.common.integrations.parse_pe import PortableExecutable -from lib.cuckoo.common.objects import File -from lib.cuckoo.common.path_utils import path_delete, path_exists, path_mkdir -from lib.cuckoo.common.utils import convert_to_printable, create_folder, free_space_monitor, get_memdump_path, load_categories -from lib.cuckoo.core.database import TASK_COMPLETED, TASK_FAILED_ANALYSIS, TASK_PENDING, Database, Task -from lib.cuckoo.core.guest import GuestManager -from lib.cuckoo.core.log import task_log_stop -from lib.cuckoo.core.plugins import RunAuxiliary, list_plugins -from lib.cuckoo.core.resultserver import ResultServer -from lib.cuckoo.core.rooter import _load_socks5_operational, rooter, vpns - -# os.listdir('/sys/class/net/') -HAVE_NETWORKIFACES = False -try: - import psutil - - network_interfaces = list(psutil.net_if_addrs().keys()) - HAVE_NETWORKIFACES = True -except ImportError: - print("Missed dependency: pip3 install psutil") +from lib.cuckoo.common.exceptions import CuckooUnserviceableTaskError +from lib.cuckoo.common.utils import CATEGORIES_NEEDING_VM, free_space_monitor, load_categories +from lib.cuckoo.core.analysis_manager import AnalysisManager +from lib.cuckoo.core.database import TASK_FAILED_ANALYSIS, TASK_PENDING, Database, Machine, Task, _Database +from lib.cuckoo.core.machinery_manager import MachineryManager log = logging.getLogger(__name__) -machinery = None -machine_lock = None -latest_symlink_lock = threading.Lock() -routing = Config("routing") -web_cfg = Config("web") -enable_trim = int(web_cfg.general.enable_trim) -expose_vnc_port = web_cfg.guacamole.enabled - -active_analysis_count = 0 -active_analysis_count_lock = threading.Lock() - class LoopState(enum.IntEnum): """Enum that represents the state of the main scheduler loop.""" @@ -66,796 +33,245 @@ class LoopState(enum.IntEnum): INACTIVE = 4 -class ScalingBoundedSemaphore(threading.Semaphore): - """Implements a dynamic bounded semaphore. - - A bounded semaphore checks to make sure its current value doesn't exceed its - limit value. If it does, ValueError is raised. In most situations - semaphores are used to guard resources with limited capacity. - - If the semaphore is released too many times it's a sign of a bug. If not - given, value defaults to 1. - - Like regular semaphores, bounded semaphores manage a counter representing - the number of release() calls minus the number of acquire() calls, plus a - limit value. The acquire() method blocks if necessary until it can return - without making the counter negative. If not given, value defaults to 1. - - In this version of semaphore there is an upper limit where its limit value - can never reach when it is changed. The idea behind it is that in machinery - documentation there is a limit of machines that can be available so there is - no point having it higher than that. - """ - - def __init__(self, value=1, upper_limit=1): - threading.Semaphore.__init__(self, value) - self._limit_value = value - self._upper_limit = upper_limit - - def acquire(self, blocking=True, timeout=None): - """Acquire a semaphore, decrementing the internal counter by one. - - When invoked without arguments: if the internal counter is larger than - zero on entry, decrement it by one and return immediately. If it is zero - on entry, block, waiting until some other thread has called release() to - make it larger than zero. This is done with proper interlocking so that - if multiple acquire() calls are blocked, release() will wake exactly one - of them up. The implementation may pick one at random, so the order in - which blocked threads are awakened should not be relied on. There is no - return value in this case. - - When invoked with blocking set to true, do the same thing as when called - without arguments, and return true. - - When invoked with blocking set to false, do not block. If a call without - an argument would block, return false immediately; otherwise, do the - same thing as when called without arguments, and return true. - - When invoked with a timeout other than None, it will block for at - most timeout seconds. If acquire does not complete successfully in - that interval, return false. Return true otherwise. - - """ - if not blocking and timeout is not None: - raise ValueError("Cannot specify timeout for non-blocking acquire()") - rc = False - endtime = None - with self._cond: - while self._value == 0: - if not blocking: - break - if timeout is not None: - if endtime is None: - endtime = _time() + timeout - else: - timeout = endtime - _time() - if timeout <= 0: - break - self._cond.wait(timeout) - else: - self._value -= 1 - rc = True - return rc +class SchedulerCycleDelay(enum.IntEnum): + SUCCESS = 0 + NO_PENDING_TASKS = 1 + MAX_MACHINES_RUNNING = 1 + SCHEDULER_PAUSED = 5 + FAILURE = 5 + LOW_DISK_SPACE = 30 - __enter__ = acquire - - def release(self): - """Release a semaphore, incrementing the internal counter by one. - - When the counter is zero on entry and another thread is waiting for it - to become larger than zero again, wake up that thread. - - If the number of releases exceeds the number of acquires, - raise a ValueError. - - """ - with self._cond: - if self._value > self._upper_limit: - raise ValueError("Semaphore released too many times") - if self._value >= self._limit_value: - self._value = self._limit_value - self._cond.notify() - return - self._value += 1 - self._cond.notify() - - def __exit__(self, t, v, tb): - self.release() - - def update_limit(self, value): - """Update the limit value for the semaphore - - This limit value is the bounded limit, and proposed limit values - are validated against the upper limit. - - """ - if value < self._upper_limit and value > 0: - self._limit_value = value - if self._value > value: - self._value = value - - def check_for_starvation(self, available_count: int): - """Check for preventing starvation from coming up after updating the limit. - Take no parameter. - Return true on starvation. - """ - if self._value == 0 and available_count == self._limit_value: - self._value = self._limit_value - return True - # Resync of the lock value - if abs(self._value - available_count) > 0: - self._value = available_count - return True - return False - - -class CuckooDeadMachine(Exception): - """Exception thrown when a machine turns dead. - - When this exception has been thrown, the analysis task will start again, - and will try to use another machine, when available. - """ - pass - - -class AnalysisManager(threading.Thread): - """Analysis Manager. +class Scheduler: + """Tasks Scheduler. - This class handles the full analysis process for a given task. It takes - care of selecting the analysis machine, preparing the configuration and - interacting with the guest agent and analyzer components to launch and - complete the analysis and store, process and report its results. + This class is responsible for the main execution loop of the tool. It + prepares the analysis machines and keep waiting and loading for new + analysis tasks. + Whenever a new task is available, it launches AnalysisManager which will + take care of running the full analysis process and operating with the + assigned analysis machine. """ - def __init__(self, task, error_queue): - """@param task: task object containing the details for the analysis.""" - threading.Thread.__init__(self) - self.task = task - self.errors = error_queue + def __init__(self, maxcount=0): + self.loop_state = LoopState.INACTIVE self.cfg = Config() - self.aux_cfg = Config("auxiliary") - self.storage = "" - self.screenshot_path = "" - self.num_screenshots = 0 - self.binary = "" - self.machine = None - self.db = Database() - self.interface = None - self.rt_table = None - self.route = None - self.rooter_response = "" - self.reject_segments = None - self.reject_hostports = None + self.db: _Database = Database() + self.max_analysis_count: int = maxcount or self.cfg.cuckoo.max_analysis_count + self.analysis_threads_lock = threading.Lock() + self.total_analysis_count: int = 0 + self.analysis_threads: List[AnalysisManager] = [] + self.analyzing_categories, categories_need_VM = load_categories() + self.machinery_manager = MachineryManager() if categories_need_VM else None + log.info("Creating scheduler with max_analysis_count=%s", self.max_analysis_count or "unlimited") + + @property + def active_analysis_count(self) -> int: + with self.analysis_threads_lock: + return len(self.analysis_threads) + + def analysis_finished(self, analysis_manager: AnalysisManager): + with self.analysis_threads_lock: + try: + self.analysis_threads.remove(analysis_manager) + except ValueError: + pass - def init_storage(self): - """Initialize analysis storage folder.""" - self.storage = os.path.join(CUCKOO_ROOT, "storage", "analyses", str(self.task.id)) - self.screenshot_path = os.path.join(self.storage, "shots") + def do_main_loop_work(self, error_queue: queue.Queue) -> SchedulerCycleDelay: + """Return the number of seconds to sleep after returning.""" + if self.loop_state == LoopState.STOPPING: + # This blocks the main loop until the analyses are finished. + self.wait_for_running_analyses_to_finish() + self.loop_state = LoopState.INACTIVE + return SchedulerCycleDelay.SUCCESS - # If the analysis storage folder already exists, we need to abort the - # analysis or previous results will be overwritten and lost. - if path_exists(self.storage): - log.error("Task #%s: Analysis results folder already exists at path '%s', analysis aborted", self.task.id, self.storage) - return False + if self.loop_state == LoopState.PAUSED: + log.debug("scheduler is paused, send '%s' to process %d to resume", signal.SIGUSR2, os.getpid()) + return SchedulerCycleDelay.SCHEDULER_PAUSED - # If we're not able to create the analysis storage folder, we have to - # abort the analysis. - try: - create_folder(folder=self.storage) - except CuckooOperationalError: - log.error("Task #%s: Unable to create analysis folder %s", self.task.id, self.storage) - return False + if 0 < self.max_analysis_count <= self.total_analysis_count: + log.info("Maximum analysis count has been reached, shutting down.") + self.stop() + return SchedulerCycleDelay.SUCCESS - return True + if self.is_short_on_disk_space(): + return SchedulerCycleDelay.LOW_DISK_SPACE - def check_file(self, sha256): - """Checks the integrity of the file to be analyzed.""" - sample = self.db.view_sample(self.task.sample_id) + analysis_manager: Optional[AnalysisManager] = None + with self.db.session.begin(): + if self.machinery_manager and self.machinery_manager.running_machines_max_reached(): + return SchedulerCycleDelay.MAX_MACHINES_RUNNING - if sample and sha256 != sample.sha256: - log.error( - "Task #%s: Target file has been modified after submission: '%s'", - self.task.id, - convert_to_printable(self.task.target), + try: + task, machine = self.find_next_serviceable_task() + except Exception: + log.exception("Failed to find next serviceable task") + # Explicitly call rollback since we're not re-raising the exception and letting the + # begin() context manager handle rolling back the transaction. + self.db.session.rollback() + return SchedulerCycleDelay.FAILURE + + if task is None: + # There are no pending tasks so try again in 1 second. + return SchedulerCycleDelay.NO_PENDING_TASKS + + log.info("Task #%s: Processing task", task.id) + self.total_analysis_count += 1 + analysis_manager = AnalysisManager( + task, + machine=machine, + machinery_manager=self.machinery_manager, + error_queue=error_queue, + done_callback=self.analysis_finished, ) - return False + analysis_manager.prepare_task_and_machine_to_start() + self.db.session.expunge_all() - return True + with self.analysis_threads_lock: + self.analysis_threads.append(analysis_manager) + analysis_manager.start() - def store_file(self, sha256): - """Store a copy of the file being analyzed.""" - if not path_exists(self.task.target): - log.error( - "Task #%s: The file to analyze does not exist at path '%s', analysis aborted", - self.task.id, - convert_to_printable(self.task.target), - ) - return False + return SchedulerCycleDelay.SUCCESS - self.binary = os.path.join(CUCKOO_ROOT, "storage", "binaries", sha256) + def find_next_serviceable_task(self) -> Tuple[Optional[Task], Optional[Machine]]: + task: Optional[Task] = None + machine: Optional[Machine] = None - if path_exists(self.binary): - log.info("Task #%s: File already exists at '%s'", self.task.id, self.binary) + if self.machinery_manager: + task, machine = self.find_pending_task_to_service() else: - # TODO: do we really need to abort the analysis in case we are not able to store a copy of the file? - try: - shutil.copy(self.task.target, self.binary) - except (IOError, shutil.Error): - log.error( - "Task #%s: Unable to store file from '%s' to '%s', analysis aborted", - self.task.id, - self.task.target, - self.binary, - ) - return False - - try: - new_binary_path = os.path.join(self.storage, "binary") - if hasattr(os, "symlink"): - os.symlink(self.binary, new_binary_path) - else: - shutil.copy(self.binary, new_binary_path) - except (AttributeError, OSError) as e: - log.error("Task #%s: Unable to create symlink/copy from '%s' to '%s': %s", self.task.id, self.binary, self.storage, e) - - return True - - def screenshot_machine(self): - if not self.cfg.cuckoo.machinery_screenshots: - return - if self.machine is None: - log.error("Task #%s: screenshot not possible, no machine acquired yet", self.task.id) - return - - # same format and filename approach here as VM-based screenshots - self.num_screenshots += 1 - screenshot_filename = f"{str(self.num_screenshots).rjust(4, '0')}.jpg" - screenshot_path = os.path.join(self.screenshot_path, screenshot_filename) - machinery.screenshot(self.machine.label, screenshot_path) - - def acquire_machine(self): - """Acquire an analysis machine from the pool of available ones.""" - machine = None - orphan = False - # Start a loop to acquire a machine to run the analysis on. - while True: - machine_lock.acquire() - - # If the user specified a specific machine ID, a platform to be - # used or machine tags acquire the machine accordingly. - task_archs, task_tags = self.db._task_arch_tags_helper(self.task) - os_version = self.db._package_vm_requires_check(self.task.package) + task = self.find_pending_task_not_requiring_machinery() + + return task, machine + + def find_pending_task_not_requiring_machinery(self) -> Optional[Task]: + # This function must only be called when we're configured to not process any tasks + # that require machinery. + assert not self.machinery_manager + + task: Optional[Task] = None + tasks = self.db.list_tasks( + category=[category for category in self.analyzing_categories if category not in CATEGORIES_NEEDING_VM], + status=TASK_PENDING, + order_by=(Task.priority.desc(), Task.added_on), + options_not_like="node=", + limit=1, + for_update=True, + ) + if tasks: + task = tasks[0] + return task + + def find_pending_task_to_service(self) -> Tuple[Optional[Task], Optional[Machine]]: + # This function must only be called when we have the ability to use machinery. + assert self.machinery_manager + + task: Optional[Task] = None + machine: Optional[Machine] = None + # Get the list of all pending tasks in the order that they should be processed. + for task_candidate in self.db.list_tasks( + status=TASK_PENDING, + order_by=(Task.priority.desc(), Task.added_on), + options_not_like="node=", + for_update=True, + ): + if task_candidate.category not in CATEGORIES_NEEDING_VM: + # This task can definitely be processed because it doesn't need a machine. + task = task_candidate + break - # In some cases it's possible that we enter this loop without having any available machines. We should make sure this is not - # such case, or the analysis task will fail completely. - if not machinery.availables( - label=self.task.machine, platform=self.task.platform, tags=task_tags, arch=task_archs, os_version=os_version - ): - machine_lock.release() - log.debug( - "Task #%s: no machine available yet for machine '%s', platform '%s' or tags '%s'.", - self.task.id, - self.task.machine, - self.task.platform, - self.task.tags, - ) - time.sleep(1) + try: + machine = self.machinery_manager.find_machine_to_service_task(task_candidate) + except CuckooUnserviceableTaskError: + if self.cfg.cuckoo.fail_unserviceable: + log.info("Task #%s: Failing unserviceable task", task_candidate.id) + self.db.set_status(task_candidate.id, TASK_FAILED_ANALYSIS) + else: + log.info("Task #%s: Unserviceable task", task_candidate.id) continue - if self.cfg.cuckoo.batch_scheduling and not orphan: - machine = machinery.acquire( - machine_id=self.task.machine, - platform=self.task.platform, - tags=task_tags, - arch=task_archs, - os_version=os_version, - need_scheduled=True, - ) - else: - machine = machinery.acquire( - machine_id=self.task.machine, - platform=self.task.platform, - tags=task_tags, - arch=task_archs, - os_version=os_version, - need_scheduled=True, - ) - # If no machine is available at this moment, wait for one second and try again. - if not machine: - machine_lock.release() - log.debug( - "Task #%s: no machine available yet for machine '%s', platform '%s' or tags '%s'.", - self.task.id, - self.task.machine, - self.task.platform, - self.task.tags, - ) - time.sleep(1) - orphan = True - else: - log.info( - "Task #%s: acquired machine %s (label=%s, arch=%s, platform=%s)", - self.task.id, - machine.name, - machine.label, - machine.arch, - machine.platform, - ) + if machine: + task = task_candidate break - self.machine = machine - - def build_options(self): - """Generate analysis options. - @return: options dict. - """ - options = { - "id": self.task.id, - "ip": self.machine.resultserver_ip, - "port": self.machine.resultserver_port, - "category": self.task.category, - "target": self.task.target, - "package": self.task.package, - "options": self.task.options, - "enforce_timeout": self.task.enforce_timeout, - "clock": self.task.clock, - "terminate_processes": self.cfg.cuckoo.terminate_processes, - "upload_max_size": self.cfg.resultserver.upload_max_size, - "do_upload_max_size": int(self.cfg.resultserver.do_upload_max_size), - "enable_trim": enable_trim, - "timeout": self.task.timeout or self.cfg.timeouts.default, - } + return task, machine - if self.task.category == "file": - file_obj = File(self.task.target) - options["file_name"] = file_obj.get_name() - options["file_type"] = file_obj.get_type() - # if it's a PE file, collect export information to use in more smartly determining the right package to use - options["exports"] = PortableExecutable(self.task.target).get_dll_exports() - del file_obj + def get_available_machine_stats(self) -> DefaultDict[str, int]: + available_machine_stats = defaultdict(int) + for machine in self.db.get_available_machines(): + for tag in machine.tags: + if tag: + available_machine_stats[tag.name] += 1 + if machine.platform: + available_machine_stats[machine.platform] += 1 - # options from auxiliary.conf - for plugin in self.aux_cfg.auxiliary_modules.keys(): - options[plugin] = self.aux_cfg.auxiliary_modules[plugin] + return available_machine_stats - return options + def get_locked_machine_stats(self) -> DefaultDict[str, int]: + locked_machine_stats = defaultdict(int) + for machine in self.db.list_machines(locked=True): + for tag in machine.tags: + if tag: + locked_machine_stats[tag.name] += 1 + if machine.platform: + locked_machine_stats[machine.platform] += 1 - def category_checks(self): - if self.task.category in ("file", "pcap", "static"): - sha256 = File(self.task.target).get_sha256() - # Check whether the file has been changed for some unknown reason. - # And fail this analysis if it has been modified. - if not self.check_file(sha256): - log.debug("check file") - return False + return locked_machine_stats - # Store a copy of the original file. - if not self.store_file(sha256): - log.debug("store file") - return False + def get_pending_task_stats(self) -> DefaultDict[str, int]: + pending_task_stats = defaultdict(int) + for task in self.db.list_tasks(status=TASK_PENDING): + for tag in task.tags: + if tag: + pending_task_stats[tag.name] += 1 + if task.platform: + pending_task_stats[task.platform] += 1 + if task.machine: + pending_task_stats[task.machine] += 1 - if self.task.category in ("pcap", "static"): - if self.task.category == "pcap": - if hasattr(os, "symlink"): - os.symlink(self.binary, os.path.join(self.storage, "dump.pcap")) - else: - shutil.copy(self.binary, os.path.join(self.storage, "dump.pcap")) - # create the logs/files directories as - # normally the resultserver would do it - dirnames = ["logs", "files", "aux"] - for dirname in dirnames: - try: - path_mkdir(os.path.join(self.storage, dirname)) - except Exception: - log.debug("Failed to create folder %s", dirname) - return True + return pending_task_stats - def launch_analysis(self): - """Start analysis.""" - global active_analysis_count - succeeded = False - dead_machine = False - self.socks5s = _load_socks5_operational() - aux = False - # Initialize the analysis folders. - if not self.init_storage(): - log.debug("Failed to initialize the analysis folder") + def is_short_on_disk_space(self): + """If not enough free disk space is available, then we print an + error message and wait another round. This check is ignored + when the freespace configuration variable is set to zero. + """ + if not self.cfg.cuckoo.freespace: return False - category_early_escape = self.category_checks() - if isinstance(category_early_escape, bool): - return category_early_escape - - log.info( - "Task #%s: Starting analysis of %s '%s'", - self.task.id, - self.task.category.upper(), - convert_to_printable(self.task.target), - ) - - # Acquire analysis machine. - try: - self.acquire_machine() - guest_log = self.db.set_task_vm_and_guest_start( - self.task.id, - self.machine.name, - self.machine.label, - self.task.platform, - self.machine.id, - machinery.__class__.__name__, + # Resolve the full base path to the analysis folder, just in + # case somebody decides to make a symbolic link out of it. + dir_path = os.path.join(CUCKOO_ROOT, "storage", "analyses") + need_space, space_available = free_space_monitor(dir_path, return_value=True, analysis=True) + if need_space: + log.error( + "Not enough free disk space! (Only %d MB!). You can change limits it in cuckoo.conf -> freespace", space_available ) - # At this point we can tell the ResultServer about it. - except CuckooOperationalError as e: - machine_lock.release() - log.error("Task #%s: Cannot acquire machine: %s", self.task.id, e, exc_info=True) - return False + return need_space + @contextlib.contextmanager + def loop_signals(self): + signals_to_handle = (signal.SIGHUP, signal.SIGTERM, signal.SIGUSR1, signal.SIGUSR2) + for sig in signals_to_handle: + signal.signal(sig, self.signal_handler) try: - unlocked = False - - # Mark the selected analysis machine in the database as started. - # Start the machine. - machinery.start(self.machine.label) - - # By the time start returns it will have fully started the Virtual - # Machine. We can now safely release the machine lock. - machine_lock.release() - unlocked = True - - # Generate the analysis configuration file. - options = self.build_options() - - if expose_vnc_port and hasattr(machinery, "store_vnc_port"): - machinery.store_vnc_port(self.machine.label, self.task.id) - - try: - ResultServer().add_task(self.task, self.machine) - except Exception as e: - machinery.release(self.machine.label) - log.exception(e, exc_info=True) - self.errors.put(e) - - aux = RunAuxiliary(task=self.task, machine=self.machine) - - # Enable network routing. - self.route_network() - - aux.start() - - # Initialize the guest manager. - guest = GuestManager(self.machine.name, self.machine.ip, self.machine.platform, self.task.id, self) - - options["clock"] = self.db.update_clock(self.task.id) - self.db.guest_set_status(self.task.id, "starting") - # Start the analysis. - guest.start_analysis(options) - if self.db.guest_get_status(self.task.id) == "starting": - self.db.guest_set_status(self.task.id, "running") - guest.wait_for_completion() - self.db.guest_set_status(self.task.id, "stopping") - succeeded = True - except (CuckooMachineError, CuckooGuestCriticalTimeout) as e: - if not unlocked: - machine_lock.release() - log.error(str(e), extra={"task_id": self.task.id}, exc_info=True) - dead_machine = True - except CuckooGuestError as e: - if not unlocked: - machine_lock.release() - log.error(str(e), extra={"task_id": self.task.id}, exc_info=True) + yield finally: - # Stop Auxiliary modules. - if aux: - aux.stop() - - # Take a memory dump of the machine before shutting it off. - if self.cfg.cuckoo.memory_dump or self.task.memory: - try: - dump_path = get_memdump_path(self.task.id) - need_space, space_available = free_space_monitor(os.path.dirname(dump_path), return_value=True) - if need_space: - log.error("Not enough free disk space! Could not dump ram (Only %d MB!)", space_available) - else: - machinery.dump_memory(self.machine.label, dump_path) - except NotImplementedError: - log.error("The memory dump functionality is not available for the current machine manager") - - except CuckooMachineError as e: - log.error(e, exc_info=True) - - try: - # Stop the analysis machine. - machinery.stop(self.machine.label) - - except CuckooMachineError as e: - log.warning("Task #%s: Unable to stop machine %s: %s", self.task.id, self.machine.label, e) - - # Mark the machine in the database as stopped. Unless this machine - # has been marked as dead, we just keep it as "started" in the - # database so it'll not be used later on in this session. - self.db.guest_stop(guest_log) - - # After all this, we can make the ResultServer forget about the - # internal state for this analysis task. - ResultServer().del_task(self.task, self.machine) - - # Drop the network routing rules if any. - self.unroute_network() - - if dead_machine: - # Remove the guest from the database, so that we can assign a - # new guest when the task is being analyzed with another machine. - self.db.guest_remove(guest_log) - machinery.delete_machine(self.machine.name) - - # Remove the analysis directory that has been created so - # far, as launch_analysis() is going to be doing that again. - shutil.rmtree(self.storage) - - # This machine has turned dead, so we throw an exception here - # which informs the AnalysisManager that it should analyze - # this task again with another available machine. - raise CuckooDeadMachine() + for sig in signals_to_handle: + signal.signal(sig, signal.SIG_DFL) - try: - # Release the analysis machine. But only if the machine has not turned dead yet. - machinery.release(self.machine.label) - - except CuckooMachineError as e: - log.error( - "Task #%s: Unable to release machine %s, reason %s. You might need to restore it manually", - self.task.id, - self.machine.label, - e, - ) - - return succeeded - - def run(self): - """Run manager thread.""" - global active_analysis_count - active_analysis_count_lock.acquire() - active_analysis_count += 1 - active_analysis_count_lock.release() - try: - while True: - try: - success = self.launch_analysis() - except CuckooDeadMachine as e: - log.exception(e) - continue - - break - - self.db.set_status(self.task.id, TASK_COMPLETED) - - # If the task is still available in the database, update our task - # variable with what's in the database, as otherwise we're missing - # out on the status and completed_on change. This would then in - # turn thrown an exception in the analysisinfo processing module. - self.task = self.db.view_task(self.task.id) or self.task - - log.debug("Task #%s: Released database task with status %s", self.task.id, success) - - # We make a symbolic link ("latest") which links to the latest - # analysis - this is useful for debugging purposes. This is only - # supported under systems that support symbolic links. - if hasattr(os, "symlink"): - latest = os.path.join(CUCKOO_ROOT, "storage", "analyses", "latest") - - # First we have to remove the existing symbolic link, then we have to create the new one. - # Deal with race conditions using a lock. - latest_symlink_lock.acquire() - try: - # As per documentation, lexists() returns True for dead symbolic links. - if os.path.lexists(latest): - path_delete(latest) - - os.symlink(self.storage, latest) - except OSError as e: - log.warning("Task #%s: Error pointing latest analysis symlink: %s", self.task.id, e) - finally: - latest_symlink_lock.release() - - log.info("Task #%s: analysis procedure completed", self.task.id) - except Exception as e: - log.exception("Task #%s: Failure in AnalysisManager.run: %s", self.task.id, e) - finally: - self.db.set_status(self.task.id, TASK_COMPLETED) - task_log_stop(self.task.id) - active_analysis_count_lock.acquire() - active_analysis_count -= 1 - active_analysis_count_lock.release() - - def _rooter_response_check(self): - if self.rooter_response and self.rooter_response["exception"] is not None: - raise CuckooCriticalError(f"Error execution rooter command: {self.rooter_response['exception']}") - - def route_network(self): - """Enable network routing if desired.""" - # Determine the desired routing strategy (none, internet, VPN). - self.route = routing.routing.route - - if self.task.route: - self.route = self.task.route - - if self.route in ("none", "None", "drop", "false"): - self.interface = None - self.rt_table = None - elif self.route == "inetsim": - self.interface = routing.inetsim.interface - elif self.route == "tor": - self.interface = routing.tor.interface - elif self.route == "internet" and routing.routing.internet != "none": - self.interface = routing.routing.internet - self.rt_table = routing.routing.rt_table - if routing.routing.reject_segments != "none": - self.reject_segments = routing.routing.reject_segments - if routing.routing.reject_hostports != "none": - self.reject_hostports = str(routing.routing.reject_hostports) - elif self.route in vpns: - self.interface = vpns[self.route].interface - self.rt_table = vpns[self.route].rt_table - elif self.route in self.socks5s: - self.interface = "" - else: - log.warning("Unknown network routing destination specified, ignoring routing for this analysis: %s", self.route) - self.interface = None - self.rt_table = None - - # Check if the network interface is still available. If a VPN dies for - # some reason, its tunX interface will no longer be available. - if self.interface and not rooter("nic_available", self.interface): - log.error( - "The network interface '%s' configured for this analysis is " - "not available at the moment, switching to route=none mode", - self.interface, - ) - self.route = "none" - self.interface = None - self.rt_table = None - - if self.route == "inetsim": - self.rooter_response = rooter( - "inetsim_enable", - self.machine.ip, - str(routing.inetsim.server), - str(routing.inetsim.dnsport), - str(self.cfg.resultserver.port), - str(routing.inetsim.ports), - ) - - elif self.route == "tor": - self.rooter_response = rooter( - "socks5_enable", - self.machine.ip, - str(self.cfg.resultserver.port), - str(routing.tor.dnsport), - str(routing.tor.proxyport), - ) - - elif self.route in self.socks5s: - self.rooter_response = rooter( - "socks5_enable", - self.machine.ip, - str(self.cfg.resultserver.port), - str(self.socks5s[self.route]["dnsport"]), - str(self.socks5s[self.route]["port"]), - ) - - elif self.route in ("none", "None", "drop"): - self.rooter_response = rooter("drop_enable", self.machine.ip, str(self.cfg.resultserver.port)) - - self._rooter_response_check() - - # check if the interface is up - if HAVE_NETWORKIFACES and routing.routing.verify_interface and self.interface and self.interface not in network_interfaces: - log.info("Network interface {} not found, falling back to dropping network traffic", self.interface) - self.interface = None - self.rt_table = None - self.route = "drop" - - if self.interface: - self.rooter_response = rooter("forward_enable", self.machine.interface, self.interface, self.machine.ip) - self._rooter_response_check() - if self.reject_segments: - self.rooter_response = rooter( - "forward_reject_enable", self.machine.interface, self.interface, self.machine.ip, self.reject_segments - ) - self._rooter_response_check() - if self.reject_hostports: - self.rooter_response = rooter( - "hostports_reject_enable", self.machine.interface, self.machine.ip, self.reject_hostports - ) - self._rooter_response_check() - - log.info("Enabled route '%s'.", self.route) - - if self.rt_table: - self.rooter_response = rooter("srcroute_enable", self.rt_table, self.machine.ip) - self._rooter_response_check() - - def unroute_network(self): - if self.interface: - self.rooter_response = rooter("forward_disable", self.machine.interface, self.interface, self.machine.ip) - self._rooter_response_check() - if self.reject_segments: - self.rooter_response = rooter( - "forward_reject_disable", self.machine.interface, self.interface, self.machine.ip, self.reject_segments - ) - self._rooter_response_check() - if self.reject_hostports: - self.rooter_response = rooter( - "hostports_reject_disable", self.machine.interface, self.machine.ip, self.reject_hostports - ) - self._rooter_response_check() - log.info("Disabled route '%s'", self.route) - - if self.rt_table: - self.rooter_response = rooter("srcroute_disable", self.rt_table, self.machine.ip) - self._rooter_response_check() - - if self.route == "inetsim": - self.rooter_response = rooter( - "inetsim_disable", - self.machine.ip, - routing.inetsim.server, - str(routing.inetsim.dnsport), - str(self.cfg.resultserver.port), - str(routing.inetsim.ports), - ) - - elif self.route == "tor": - self.rooter_response = rooter( - "socks5_disable", - self.machine.ip, - str(self.cfg.resultserver.port), - str(routing.tor.dnsport), - str(routing.tor.proxyport), - ) - - elif self.route in self.socks5s: - self.rooter_response = rooter( - "socks5_disable", - self.machine.ip, - str(self.cfg.resultserver.port), - str(self.socks5s[self.route]["dnsport"]), - str(self.socks5s[self.route]["port"]), - ) - - elif self.route in ("none", "None", "drop"): - self.rooter_response = rooter("drop_disable", self.machine.ip, str(self.cfg.resultserver.port)) - - self._rooter_response_check() - - -class Scheduler: - """Tasks Scheduler. - - This class is responsible for the main execution loop of the tool. It - prepares the analysis machines and keep waiting and loading for new - analysis tasks. - Whenever a new task is available, it launches AnalysisManager which will - take care of running the full analysis process and operating with the - assigned analysis machine. - """ - - def __init__(self, maxcount=None): - self.loop_state = LoopState.INACTIVE - self.cfg = Config() - self.db = Database() - self.maxcount = maxcount - self.total_analysis_count = 0 - self.analyzing_categories, self.categories_need_VM = load_categories() - self.analysis_threads = [] + def shutdown_machinery(self): + """Shutdown machine manager (used to kill machines that still alive).""" + if self.machinery_manager: + with self.db.session.begin(): + self.machinery_manager.machinery.shutdown() def signal_handler(self, signum, frame): """Scheduler signal handler""" sig = signal.Signals(signum) if sig in (signal.SIGHUP, signal.SIGTERM): log.info("received signal '%s', waiting for remaining analysis to finish before stopping", sig.name) - self.loop_state = LoopState.STOPPING + self.stop() elif sig == signal.SIGUSR1: log.info("received signal '%s', pausing new detonations, running detonations will continue until completion", sig.name) self.loop_state = LoopState.PAUSED @@ -865,345 +281,91 @@ def signal_handler(self, signum, frame): else: log.info("received signal '%s', nothing to do", sig.name) - def initialize(self): - """Initialize the machine manager.""" - global machinery, machine_lock - - machinery_name = self.cfg.cuckoo.machinery - if not self.categories_need_VM: - return - - # Get registered class name. Only one machine manager is imported, - # therefore there should be only one class in the list. - plugin = list_plugins("machinery")[0] - # Initialize the machine manager. - machinery = plugin() - - # Provide a dictionary with the configuration options to the - # machine manager instance. - machinery.set_options(Config(machinery_name)) - - # Initialize the machine manager. - try: - machinery.initialize(machinery_name) - except CuckooMachineError as e: - raise CuckooCriticalError(f"Error initializing machines: {e}") - # If the user wants to use the scaling bounded semaphore, check what machinery is specified, and then - # grab the required configuration key for setting the upper limit - if self.cfg.cuckoo.scaling_semaphore: - machinery_opts = machinery.options.get(machinery_name) - if machinery_name == "az": - machines_limit = machinery_opts.get("total_machines_limit") - elif machinery_name == "aws": - machines_limit = machinery_opts.get("dynamic_machines_limit") - # You set this value if you are using a machinery that is NOT auto-scaling - max_vmstartup_count = self.cfg.cuckoo.max_vmstartup_count - if max_vmstartup_count: - # The BoundedSemaphore is used to prevent CPU starvation when starting up multiple VMs - machine_lock = threading.BoundedSemaphore(max_vmstartup_count) - # You set this value if you are using a machinery that IS auto-scaling - elif self.cfg.cuckoo.scaling_semaphore and machines_limit: - # The ScalingBoundedSemaphore is used to keep feeding available machines from the pending tasks queue - machine_lock = ScalingBoundedSemaphore(value=len(machinery.machines()), upper_limit=machines_limit) - else: - machine_lock = threading.Lock() - - log.info( - 'Using "%s" machine manager with max_analysis_count=%d, max_machines_count=%d, and max_vmstartup_count=%d', - machinery_name, - self.cfg.cuckoo.max_analysis_count, - self.cfg.cuckoo.max_machines_count, - self.cfg.cuckoo.max_vmstartup_count, - ) - - # At this point all the available machines should have been identified - # and added to the list. If none were found, Cuckoo needs to abort the - # execution. - - if not len(machinery.machines()): - raise CuckooCriticalError("No machines available") - else: - log.info("Loaded %d machine/s", len(machinery.machines())) - - if len(machinery.machines()) > 1 and self.db.engine.name == "sqlite": - log.warning( - "As you've configured CAPE to execute parallelanalyses, we recommend you to switch to a PostgreSQL database as SQLite might cause some issues" - ) - - # Drop all existing packet forwarding rules for each VM. Just in case - # Cuckoo was terminated for some reason and various forwarding rules - # have thus not been dropped yet. - for machine in machinery.machines(): - if not machine.interface: - log.info( - "Unable to determine the network interface for VM with name %s, Cuckoo will not be able to give it " - "full internet access or route it through a VPN! Please define a default network interface for the " - "machinery or define a network interface for each VM", - machine.name, - ) - continue - - # Drop forwarding rule to each VPN. - for vpn in vpns.values(): - rooter("forward_disable", machine.interface, vpn.interface, machine.ip) - - # Drop forwarding rule to the internet / dirty line. - if routing.routing.internet != "none": - rooter("forward_disable", machine.interface, routing.routing.internet, machine.ip) - - def stop(self): - """Set loop state to stopping.""" - self.loop_state = LoopState.STOPPING - - def shutdown_machinery(self): - """Shutdown machine manager (used to kill machines that still alive).""" - if self.categories_need_VM: - machinery.shutdown() - def start(self): """Start scheduler.""" - self.initialize() - - log.info("Waiting for analysis tasks") - - # Handle interrupts - for _signal in [signal.SIGHUP, signal.SIGTERM, signal.SIGUSR1, signal.SIGUSR2]: - signal.signal(_signal, self.signal_handler) + if self.machinery_manager: + with self.db.session.begin(): + self.machinery_manager.initialize_machinery() # Message queue with threads to transmit exceptions (used as IPC). - errors = queue.Queue() - - # Command-line overrides the configuration file. - if self.maxcount is None: - self.maxcount = self.cfg.cuckoo.max_analysis_count + error_queue = queue.Queue() # Start the logger which grabs database information if self.cfg.cuckoo.periodic_log: - self._thr_periodic_log() - # Update timer for semaphore limit value if enabled - if self.cfg.cuckoo.scaling_semaphore and not self.cfg.cuckoo.max_vmstartup_count: - # Note that this variable only exists under these conditions - scaling_semaphore_timer = time.time() - - if self.cfg.cuckoo.batch_scheduling: - max_batch_scheduling_count = ( - self.cfg.cuckoo.max_batch_count if self.cfg.cuckoo.max_batch_count and self.cfg.cuckoo.max_batch_count > 1 else 5 - ) - # This loop runs forever. + threading.Thread(target=self.thr_periodic_log, name="periodic_log", daemon=True).start() - self.loop_state = LoopState.RUNNING - while self.loop_state in (LoopState.RUNNING, LoopState.PAUSED, LoopState.STOPPING): - if self.loop_state == LoopState.STOPPING: - # Wait for analyses to finish before stopping - while self.analysis_threads: - thread = self.analysis_threads.pop() - log.debug("Waiting for analysis PID %d", thread.native_id) - thread.join() - break - if self.loop_state == LoopState.PAUSED: - log.debug("scheduler is paused, send '%s' to process %d to resume", signal.SIGUSR2, os.getpid()) - time.sleep(5) - continue - # Update scaling bounded semaphore limit value, if enabled, based on the number of machines - # Wait until the machine lock is not locked. This is only the case - # when all machines are fully running, rather than "about to start" - # or "still busy starting". This way we won't have race conditions - # with finding out there are no available machines in the analysis - # manager or having two analyses pick the same machine. + with self.loop_signals(): + log.info("Waiting for analysis tasks") + self.loop_state = LoopState.RUNNING + try: + while self.loop_state in (LoopState.RUNNING, LoopState.PAUSED, LoopState.STOPPING): + sleep_time = self.do_main_loop_work(error_queue) + time.sleep(sleep_time) + try: + raise error_queue.get(block=False) + except queue.Empty: + pass + finally: + self.loop_state = LoopState.INACTIVE - # Update semaphore limit value if enabled based on the number of machines - if self.cfg.cuckoo.scaling_semaphore and not self.cfg.cuckoo.max_vmstartup_count: - # Every x seconds, update the semaphore limit. This requires a database call to machinery.availables(), - # hence waiting a bit between calls - if scaling_semaphore_timer + int(self.cfg.cuckoo.scaling_semaphore_update_timer) < time.time(): - machine_lock.update_limit(len(machinery.machines())) - # Prevent full starvation, very unlikely to ever happen. - machine_lock.check_for_starvation(machinery.availables()) - # Note that this variable only exists under these conditions - scaling_semaphore_timer = time.time() + def stop(self): + """Set loop state to stopping.""" + self.loop_state = LoopState.STOPPING - if self.categories_need_VM: - if not machine_lock.acquire(False): - continue - machine_lock.release() + def thr_periodic_log(self, oneshot=False): + # Ordinarily, this is the entry-point for a child thread. The oneshot parameter makes + # it easier for testing. + if not log.isEnabledFor(logging.DEBUG): + # The only purpose of this function is to log a debug message, so if debug + # logging is disabled, don't bother making all the database queries every 10 + # seconds--just return. + return - # If not enough free disk space is available, then we print an - # error message and wait another round (this check is ignored - # when the freespace configuration variable is set to zero). - if self.cfg.cuckoo.freespace: - # Resolve the full base path to the analysis folder, just in - # case somebody decides to make a symbolic link out of it. - dir_path = os.path.join(CUCKOO_ROOT, "storage", "analyses") - need_space, space_available = free_space_monitor(dir_path, return_value=True, analysis=True) - if need_space: - log.error( - "Not enough free disk space! (Only %d MB!). You can change limits it in cuckoo.conf -> freespace", - space_available, + while True: + # Since we know we'll be logging the resulting message, just use f-strings + # because they're faster and easier to read than using %s/%d and params to + # log.debug(). + msgs = [f"# Active analysis: {self.active_analysis_count}"] + + with self.db.session.begin(): + pending_task_count = self.db.count_tasks(status=TASK_PENDING) + pending_task_stats = self.get_pending_task_stats() + msgs.extend( + ( + f"# Pending Tasks: {pending_task_count}", + f"# Specific Pending Tasks: {dict(pending_task_stats)}", ) - continue - - # Have we limited the number of concurrently executing machines? - if self.cfg.cuckoo.max_machines_count > 0 and self.categories_need_VM: - # Are too many running? - if len(machinery.running()) >= self.cfg.cuckoo.max_machines_count: - continue - - # If no machines are available, it's pointless to fetch for pending tasks. Loop over. - # But if we analyze pcaps/static only it's fine - if self.categories_need_VM and not machinery.availables(include_reserved=True): - continue - # Exits if max_analysis_count is defined in the configuration - # file and has been reached. - if self.maxcount and self.total_analysis_count >= self.maxcount: - if active_analysis_count <= 0: - log.info("Maximum analysis count has been reached, shutting down.") - self.stop() - else: - if self.cfg.cuckoo.batch_scheduling: - tasks_to_create = [] - if self.categories_need_VM: - # First things first, are there pending tasks? - if not self.db.count_tasks(status=TASK_PENDING): - continue - # There are? Great, let's get them, ordered by priority and then oldest to newest - tasks_with_relevant_machine_available = [] - for task in self.db.list_tasks( - status=TASK_PENDING, order_by=(Task.priority.desc(), Task.added_on), options_not_like="node=" - ): - # Can this task ever be serviced? - if not self.db.is_serviceable(task): - if self.cfg.cuckoo.fail_unserviceable: - log.debug("Task #%s: Failing unserviceable task", task.id) - self.db.set_status(task.id, TASK_FAILED_ANALYSIS) - continue - log.debug("Task #%s: Unserviceable task", task.id) - if self.db.is_relevant_machine_available(task=task, set_status=False): - tasks_with_relevant_machine_available.append(task) - # The batching number is the number of tasks that will be considered to mapping to machines for starting - # Max_batch_scheduling_count is referring to the batch_scheduling config however this number - # is the maximum and capped for each usage by the number of locks available which refer to - # the number of expected available machines. - batching_number = ( - max_batch_scheduling_count if machine_lock._value > max_batch_scheduling_count else machine_lock._value + ) + if self.machinery_manager: + available_machine_count = self.db.count_machines_available() + available_machine_stats = self.get_available_machine_stats() + locked_machine_count = len(self.db.list_machines(locked=True)) + locked_machine_stats = self.get_locked_machine_stats() + total_machine_count = len(self.db.list_machines()) + msgs.extend( + ( + f"# Available Machines: {available_machine_count}", + f"# Available Specific Machines: {dict(available_machine_stats)}", + f"# Locked Machines: {locked_machine_count}", + f"# Specific Locked Machines: {dict(locked_machine_stats)}", + f"# Total Machines: {total_machine_count}", ) - if len(tasks_with_relevant_machine_available) > batching_number: - tasks_with_relevant_machine_available = tasks_with_relevant_machine_available[:batching_number] - tasks_to_create = self.db.map_tasks_to_available_machines(tasks_with_relevant_machine_available) - else: - tasks_to_create = [] - while True: - task = self.db.fetch_task(self.analyzing_categories) - if not task: - break - else: - tasks_to_create.append(task) - for task in tasks_to_create: - task = self.db.view_task(task.id) - log.debug("Task #%s: Processing task", task.id) - self.total_analysis_count += 1 - # Initialize and start the analysis manager. - analysis = AnalysisManager(task, errors) - analysis.daemon = True - analysis.start() - self.analysis_threads.append(analysis) - # We only want to keep track of active threads - self.analysis_threads = [t for t in self.analysis_threads if t.is_alive()] - else: - if self.categories_need_VM: - # First things first, are there pending tasks? - if not self.db.count_tasks(status=TASK_PENDING): - continue - relevant_machine_is_available = False - # There are? Great, let's get them, ordered by priority and then oldest to newest - for task in self.db.list_tasks( - status=TASK_PENDING, order_by=(Task.priority.desc(), Task.added_on), options_not_like="node=" - ): - # Can this task ever be serviced? - if not self.db.is_serviceable(task): - if self.cfg.cuckoo.fail_unserviceable: - log.debug("Task #%s: Failing unserviceable task", task.id) - self.db.set_status(task.id, TASK_FAILED_ANALYSIS) - continue - log.debug("Task #%s: Unserviceable task", task.id) - relevant_machine_is_available = self.db.is_relevant_machine_available(task) - if relevant_machine_is_available: - break - if not relevant_machine_is_available: - task = None - else: - task = self.db.view_task(task.id) - else: - task = self.db.fetch_task(self.analyzing_categories) - if task: - log.debug("Task #%s: Processing task", task.id) - self.total_analysis_count += 1 - # Initialize and start the analysis manager. - analysis = AnalysisManager(task, errors) - analysis.daemon = True - analysis.start() - self.analysis_threads.append(analysis) - # We only want to keep track of active threads - self.analysis_threads = [t for t in self.analysis_threads if t.is_alive()] + ) + if self.cfg.cuckoo.scaling_semaphore: + lock_value = ( + f"{self.machinery_manager.machine_lock._value}/{self.machinery_manager.machine_lock._limit_value}" + ) + msgs.append(f"# Lock value: {lock_value}") + log.debug("; ".join(msgs)) - # Deal with errors. - try: - raise errors.get(block=False) - except queue.Empty: - pass - self.loop_state = LoopState.INACTIVE + if oneshot: + break - def _thr_periodic_log(self): - global active_analysis_count - specific_available_machine_counts = defaultdict(int) - for machine in self.db.get_available_machines(): - for tag in machine.tags: - if tag: - specific_available_machine_counts[tag.name] += 1 - if machine.platform: - specific_available_machine_counts[machine.platform] += 1 - specific_pending_task_counts = defaultdict(int) - for task in self.db.list_tasks(status=TASK_PENDING): - for tag in task.tags: - if tag: - specific_pending_task_counts[tag.name] += 1 - if task.platform: - specific_pending_task_counts[task.platform] += 1 - if task.machine: - specific_pending_task_counts[task.machine] += 1 - specific_locked_machine_counts = defaultdict(int) - for machine in self.db.list_machines(locked=True): - for tag in machine.tags: - if tag: - specific_locked_machine_counts[tag.name] += 1 - if machine.platform: - specific_locked_machine_counts[machine.platform] += 1 - if self.cfg.cuckoo.scaling_semaphore: - number_of_machine_scheduled = machinery.get_machines_scheduled() - log.debug( - "# Pending Tasks: %d; # Specific Pending Tasks: %s; # Available Machines: %d; # Available Specific Machines: %s; # Locked Machines: %d; # Specific Locked Machines: %s; # Total Machines: %d; Lock value: %d/%d; # Active analysis: %d; # Machines scheduled: %d", - self.db.count_tasks(status=TASK_PENDING), - dict(specific_pending_task_counts), - self.db.count_machines_available(), - dict(specific_available_machine_counts), - len(self.db.list_machines(locked=True)), - dict(specific_locked_machine_counts), - len(self.db.list_machines()), - machine_lock._value, - machine_lock._limit_value, - active_analysis_count, - number_of_machine_scheduled, - ) - else: - log.debug( - "# Pending Tasks: %d; # Specific Pending Tasks: %s; # Available Machines: %d; # Available Specific Machines: %s; # Locked Machines: %d; # Specific Locked Machines: %s; # Total Machines: %d", - self.db.count_tasks(status=TASK_PENDING), - dict(specific_pending_task_counts), - self.db.count_machines_available(), - dict(specific_available_machine_counts), - len(self.db.list_machines(locked=True)), - dict(specific_locked_machine_counts), - len(self.db.list_machines()), - ) - thr = threading.Timer(10, self._thr_periodic_log) - thr.daemon = True - thr.start() + time.sleep(10) + + def wait_for_running_analyses_to_finish(self) -> None: + log.info("Waiting for running analyses to finish.") + while self.analysis_threads: + thread = self.analysis_threads.pop() + log.debug("Waiting for analysis thread (%r)", thread) + thread.join() diff --git a/lib/cuckoo/core/startup.py b/lib/cuckoo/core/startup.py index 5b179bfb827..01240c300ff 100644 --- a/lib/cuckoo/core/startup.py +++ b/lib/cuckoo/core/startup.py @@ -152,17 +152,6 @@ def create_structure(): ) -class DatabaseHandler(logging.Handler): - """Logging to database handler. - Used to log errors related to tasks in database. - """ - - def emit(self, record): - if hasattr(record, "task_id"): - db = Database() - db.add_error(record.msg, int(record.task_id)) - - class ConsoleHandler(logging.StreamHandler): """Logging to console handler.""" diff --git a/modules/machinery/aws.py b/modules/machinery/aws.py index a8709eeac2b..a9471d43641 100644 --- a/modules/machinery/aws.py +++ b/modules/machinery/aws.py @@ -2,6 +2,8 @@ import sys import time +from lib.cuckoo.core.database import Machine + try: import boto3 except ImportError: @@ -30,8 +32,7 @@ class AWS(Machinery): AUTOSCALE_CUCKOO = "AUTOSCALE_CUCKOO" - def __init__(self): - super(AWS, self).__init__() + module_name = "aws" """override Machinery method""" @@ -176,13 +177,11 @@ def _allocate_new_machine(self): """override Machinery method""" - def acquire(self, machine_id=None, platform=None, tags=None, arch=None, os_version=None, need_scheduled=False): + def scale_pool(self, machine: Machine) -> None: """ override Machinery method to utilize the auto scale option """ - base_class_return_value = super(AWS, self).acquire(machine_id, platform, tags, need_scheduled=need_scheduled) self._start_or_create_machines() # prepare another machine - return base_class_return_value def _start_or_create_machines(self): """ @@ -295,14 +294,15 @@ def stop(self, label): """override Machinery method""" - def release(self, label=None): + def release(self, machine: Machine) -> Machine: """ we override it to have the ability to run start_or_create_machines() after unlocking the last machine Release a machine. @param label: machine label. """ - super(AWS, self).release(label) + retval = super(AWS, self).release(machine) self._start_or_create_machines() + return retval def _create_instance(self, tags): """ diff --git a/modules/machinery/az.py b/modules/machinery/az.py index e7c02fa7012..2074e242c3e 100644 --- a/modules/machinery/az.py +++ b/modules/machinery/az.py @@ -88,6 +88,7 @@ class Azure(Machinery): + module_name = "az" # Resource tag that indicates auto-scaling. AUTO_SCALE_CAPE_KEY = "AUTO_SCALE_CAPE" @@ -103,15 +104,14 @@ class Azure(Machinery): WINDOWS_PLATFORM = "windows" LINUX_PLATFORM = "linux" - def _initialize(self, module_name): + def _initialize(self): """ Overloading abstracts.py:_initialize() Read configuration. @param module_name: module name @raise CuckooDependencyError: if there is a problem with the dependencies call """ - self.module_name = module_name - mmanager_opts = self.options.get(module_name) + mmanager_opts = self.options.get(self.module_name) if not isinstance(mmanager_opts["scale_sets"], list): mmanager_opts["scale_sets"] = mmanager_opts["scale_sets"].strip().split(",") @@ -150,7 +150,6 @@ def _initialize_check(self): """ Overloading abstracts.py:_initialize_check() Running checks against Azure that the configuration is correct. - @param module_name: module name, currently not used be required @raise CuckooDependencyError: if there is a problem with the dependencies call """ if not HAVE_AZURE: @@ -483,31 +482,6 @@ def availables(self, label=None, platform=None, tags=None, arch=None, include_re label=label, platform=platform, tags=tags, arch=arch, include_reserved=include_reserved, os_version=os_version ) - def acquire(self, machine_id=None, platform=None, tags=None, arch=None, os_version=[], need_scheduled=False): - """ - Overloading abstracts.py:acquire() to utilize the auto-scale option. - @param machine_id: the name of the machine to be acquired - @param platform: the platform of the machine's operating system to be acquired - @param tags: any tags that are associated with the machine to be acquired - @param arch: the architecture of the operating system - @return: dict representing machine object from DB - """ - base_class_return_value = super(Azure, self).acquire( - machine_id=machine_id, platform=platform, tags=tags, arch=arch, os_version=os_version, need_scheduled=need_scheduled - ) - if base_class_return_value and base_class_return_value.name: - vmss_name, _ = base_class_return_value.name.split("_") - - # Get the VMSS name by the tag - if not machine_pools[vmss_name]["is_scaling"]: - # Start it and forget about it - threading.Thread( - target=self._thr_scale_machine_pool, - args=(self.options.az.scale_sets[vmss_name].pool_tag, True if platform else False), - ).start() - - return base_class_return_value - def _add_machines_to_db(self, vmss_name): """ Adding machines to database that did not exist there before. @@ -810,7 +784,8 @@ def _thr_create_vmss(self, vmss_name, vmss_image_ref, vmss_image_os): "is_scaling_down": False, "wait": False, } - self._add_machines_to_db(vmss_name) + with self.db.session.begin(): + self._add_machines_to_db(vmss_name) def _thr_reimage_vmss(self, vmss_name): """ @@ -840,7 +815,8 @@ def _thr_reimage_vmss(self, vmss_name): else: log.error(repr(e), exc_info=True) raise - self._add_machines_to_db(vmss_name) + with self.db.session.begin(): + self._add_machines_to_db(vmss_name) def _thr_scale_machine_pool(self, tag, per_platform=False): """ @@ -849,6 +825,10 @@ def _thr_scale_machine_pool(self, tag, per_platform=False): @param per_platform: A boolean flag indicating that we should scale machine pools "per platform" vs. "per tag" @return: Ends method call """ + with self.db.session.begin(): + return self._scale_machine_pool(tag, per_platform=per_platform) + + def _scale_machine_pool(self, tag, per_platform=False): global machine_pools, is_platform_scaling, current_vmss_operations platform = None diff --git a/modules/machinery/esx.py b/modules/machinery/esx.py index 23289cd2f92..4e348b5c5fb 100644 --- a/modules/machinery/esx.py +++ b/modules/machinery/esx.py @@ -12,6 +12,8 @@ class ESX(LibVirtMachinery): """Virtualization layer for ESXi/ESX based on python-libvirt.""" + module_name = "esx" + def _initialize_check(self): """Runs all checks when a machine manager is initialized. @raise CuckooMachineError: if configuration is invalid diff --git a/modules/machinery/kvm.py b/modules/machinery/kvm.py index f167279a34f..d3b3637c807 100644 --- a/modules/machinery/kvm.py +++ b/modules/machinery/kvm.py @@ -11,6 +11,8 @@ class KVM(LibVirtMachinery): """Virtualization layer for KVM based on python-libvirt.""" + module_name = "kvm" + def _initialize_check(self): """Runs all checks when a machine manager is initialized. @raise CuckooMachineError: if configuration is invalid diff --git a/modules/machinery/multi.py b/modules/machinery/multi.py index 6c8a7f34163..333a412105a 100644 --- a/modules/machinery/multi.py +++ b/modules/machinery/multi.py @@ -28,6 +28,8 @@ def import_plugin(name): class MultiMachinery(Machinery): + module_name = "multi" + LABEL = "mm_label" _machineries = {} diff --git a/modules/machinery/physical.py b/modules/machinery/physical.py index 267862b5017..495fe906b8e 100644 --- a/modules/machinery/physical.py +++ b/modules/machinery/physical.py @@ -27,6 +27,8 @@ class Physical(Machinery): """Manage physical sandboxes.""" + module_name = "physical" + # Physical machine states. RUNNING = "running" STOPPED = "stopped" diff --git a/modules/machinery/proxmox.py b/modules/machinery/proxmox.py index 6f981c6f366..b4387d2a86a 100644 --- a/modules/machinery/proxmox.py +++ b/modules/machinery/proxmox.py @@ -25,8 +25,10 @@ class Proxmox(Machinery): """Manage Proxmox sandboxes.""" + module_name = "proxmox" + def __init__(self): - super(Proxmox, self).__init__() + super().__init__() self.timeout = int(cfg.timeouts.vm_state) def _initialize_check(self): diff --git a/modules/machinery/qemu.py b/modules/machinery/qemu.py index 55e72b032ea..6061c1a7e48 100644 --- a/modules/machinery/qemu.py +++ b/modules/machinery/qemu.py @@ -323,13 +323,15 @@ class QEMU(Machinery): """Virtualization layer for QEMU (non-KVM).""" + module_name = "qemu" + # VM states. RUNNING = "running" STOPPED = "stopped" ERROR = "machete" def __init__(self): - super(QEMU, self).__init__() + super().__init__() self.state = {} def _initialize_check(self): diff --git a/modules/machinery/virtualbox.py b/modules/machinery/virtualbox.py index fe7bb016562..71c343ef3e5 100644 --- a/modules/machinery/virtualbox.py +++ b/modules/machinery/virtualbox.py @@ -23,6 +23,8 @@ class VirtualBox(Machinery): """Virtualization layer for VirtualBox.""" + module_name = "virtualbox" + # VM states. SAVED = "saved" RUNNING = "running" diff --git a/modules/machinery/vmware.py b/modules/machinery/vmware.py index ba3c7555a64..222f1ed1a12 100644 --- a/modules/machinery/vmware.py +++ b/modules/machinery/vmware.py @@ -20,6 +20,8 @@ class VMware(Machinery): """Virtualization layer for VMware Workstation using vmrun utility.""" + module_name = "vmware" + LABEL = "vmx_path" def _initialize_check(self): diff --git a/modules/machinery/vmwarerest.py b/modules/machinery/vmwarerest.py index 37c180ffcca..9131b4f413b 100644 --- a/modules/machinery/vmwarerest.py +++ b/modules/machinery/vmwarerest.py @@ -18,6 +18,8 @@ class VMwareREST(Machinery): """Virtualization layer for remote VMware REST Server.""" + module_name = "vmwarerest" + LABEL = "id" def _initialize_check(self): diff --git a/modules/machinery/vmwareserver.py b/modules/machinery/vmwareserver.py index 1b6aeee180d..7473b2be241 100644 --- a/modules/machinery/vmwareserver.py +++ b/modules/machinery/vmwareserver.py @@ -16,6 +16,8 @@ class VMwareServer(Machinery): """Virtualization layer for remote VMware Workstation Server using vmrun utility.""" + module_name = "vmwareserver" + LABEL = "vmx_path" def _initialize_check(self): diff --git a/modules/machinery/vsphere.py b/modules/machinery/vsphere.py index 44d1484eb04..600d19499f4 100644 --- a/modules/machinery/vsphere.py +++ b/modules/machinery/vsphere.py @@ -31,6 +31,8 @@ class vSphere(Machinery): """vSphere/ESXi machinery class based on pyVmomi Python SDK.""" + module_name = "vsphere" + # VM states RUNNING = "poweredOn" POWEROFF = "poweredOff" @@ -41,13 +43,13 @@ def __init__(self): if not HAVE_PYVMOMI: raise CuckooDependencyError("Couldn't import pyVmomi. Please install using 'pip3 install --upgrade pyvmomi'") - super(vSphere, self).__init__() + super().__init__() - def _initialize(self, module_name): + def _initialize(self): """Read configuration. @param module_name: module name. """ - super(vSphere, self)._initialize(module_name) + super(vSphere, self)._initialize() # Initialize random number generator random.seed() diff --git a/poetry.lock b/poetry.lock index c54d2466256..951416b8d4a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1057,6 +1057,20 @@ wcwidth = "0.2.13" build = ["build (==1.0.3)", "pyinstaller (==6.3.0)", "setuptools (==69.0.3)"] dev = ["black (==24.1.1)", "flake8 (==7.0.0)", "flake8-bugbear (==24.1.17)", "flake8-comprehensions (==3.14.0)", "flake8-copyright (==0.2.4)", "flake8-encodings (==0.5.1)", "flake8-logging-format (==0.9.0)", "flake8-no-implicit-concat (==0.3.5)", "flake8-print (==5.0.0)", "flake8-simplify (==0.21.0)", "flake8-todos (==0.3.0)", "flake8-use-pathlib (==0.3.0)", "isort (==5.13.2)", "mypy (==1.8.0)", "mypy-protobuf (==3.5.0)", "pre-commit (==3.5.0)", "psutil (==5.9.2)", "pytest (==8.0.0)", "pytest-cov (==4.1.0)", "pytest-instafail (==0.5.0)", "pytest-sugar (==0.9.7)", "requests (==2.31.0)", "ruff (==0.1.14)", "stix2 (==3.0.1)", "types-PyYAML (==6.0.8)", "types-backports (==0.1.3)", "types-colorama (==0.4.15.11)", "types-protobuf (==4.23.0.3)", "types-psutil (==5.8.23)", "types-requests (==2.31.0.20240125)", "types-tabulate (==0.9.0.20240106)", "types-termcolor (==1.1.4)"] +[[package]] +name = "freezegun" +version = "1.4.0" +description = "Let your Python tests travel through time" +optional = false +python-versions = ">=3.7" +files = [ + {file = "freezegun-1.4.0-py3-none-any.whl", hash = "sha256:55e0fc3c84ebf0a96a5aa23ff8b53d70246479e9a68863f1fcac5a3e52f19dd6"}, + {file = "freezegun-1.4.0.tar.gz", hash = "sha256:10939b0ba0ff5adaecf3b06a5c2f73071d9678e507c5eaedb23c761d56ac774b"}, +] + +[package.dependencies] +python-dateutil = ">=2.7" + [[package]] name = "func-timeout" version = "4.3.5" @@ -2814,6 +2828,21 @@ pytest = ">=5.4.0" docs = ["sphinx", "sphinx-rtd-theme"] testing = ["Django", "django-configurations (>=2.0)"] +[[package]] +name = "pytest-freezer" +version = "0.4.8" +description = "Pytest plugin providing a fixture interface for spulec/freezegun" +optional = false +python-versions = ">= 3.6" +files = [ + {file = "pytest_freezer-0.4.8-py3-none-any.whl", hash = "sha256:644ce7ddb8ba52b92a1df0a80a699bad2b93514c55cf92e9f2517b68ebe74814"}, + {file = "pytest_freezer-0.4.8.tar.gz", hash = "sha256:8ee2f724b3ff3540523fa355958a22e6f4c1c819928b78a7a183ae4248ce6ee6"}, +] + +[package.dependencies] +freezegun = ">=1.0" +pytest = ">=3.6" + [[package]] name = "pytest-mock" version = "3.7.0" @@ -4243,4 +4272,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "908d9df7f4820e03a05f2bba26a896dd37ec500b7e1f451bc7735d00b5a91ac3" +content-hash = "6f52771a4b3ee2e76b9208a7ce70d04e3417839c08a22305d126fb0cffed14e6" diff --git a/pyproject.toml b/pyproject.toml index d4b15280de3..19fad93fb5b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,6 +101,7 @@ pytest-django = "4.5.2" pytest_asyncio = "0.18.3" pytest-xdist = "3.0.2" pytest-asyncio = "0.18.3" +pytest-freezer = "0.4.8" tenacity = "8.1.0" httpretty = "^1.1.4" func-timeout = "^4.3.5" diff --git a/tests/conftest.py b/tests/conftest.py index 6bcc2418018..c31ef56f9fc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,21 +3,36 @@ import pytest import lib.cuckoo.common.config +import lib.cuckoo.core.analysis_manager import lib.cuckoo.core.database from lib.cuckoo.common.config import ConfigMeta +from lib.cuckoo.core.database import Database, init_database, reset_database_FOR_TESTING_ONLY + + +@pytest.fixture +def db(): + reset_database_FOR_TESTING_ONLY() + try: + init_database(dsn="sqlite://") + retval = Database() + retval.engine.echo = True + yield retval + finally: + reset_database_FOR_TESTING_ONLY() @pytest.fixture def tmp_cuckoo_root(monkeypatch, tmp_path): monkeypatch.setattr(lib.cuckoo.core.database, "CUCKOO_ROOT", str(tmp_path)) + monkeypatch.setattr(lib.cuckoo.core.analysis_manager, "CUCKOO_ROOT", str(tmp_path)) yield tmp_path @pytest.fixture(autouse=True) def custom_conf_path(request, monkeypatch, tmp_cuckoo_root): - ConfigMeta.reset() monkeypatch.setenv("CAPE_DISABLE_ROOT_CONFIGS", "1") path: pathlib.Path = tmp_cuckoo_root / "custom" / "conf" path.mkdir(mode=0o755, parents=True) monkeypatch.setattr(lib.cuckoo.common.config, "CUSTOM_CONF_DIR", str(path)) + ConfigMeta.refresh() yield path diff --git a/tests/test_abstracts.py b/tests/test_abstracts.py index b096b9a34f1..59649c07e31 100644 --- a/tests/test_abstracts.py +++ b/tests/test_abstracts.py @@ -66,9 +66,15 @@ def test_not_implemented_run(self, rep): rep.run() +@pytest.mark.usefixtures("db") class TestScreenshotMachinery: def test_missing_screenshot_method(self): class MockMachinery(abstracts.Machinery): + module_name = "mock" + + def read_config(self): + return {"mock": {"machines": []}} + def _list(self): pass @@ -77,13 +83,17 @@ def _list(self): mock_cuckoo_cfg.machinery_screenshots = True with patch.object(abstracts.cfg, "cuckoo", mock_cuckoo_cfg): - with patch("lib.cuckoo.common.abstracts.inspect.getmembers", return_value=[]): - # calling _initialize_check() raises without any methods - with pytest.raises(NotImplementedError): - tm._initialize_check() + # calling _initialize_check() raises without any methods + with pytest.raises(CuckooCriticalError): + tm._initialize_check() def test_machinery_missing_screenshot_support(self): class MockMachinery(abstracts.Machinery): + module_name = "mock" + + def read_config(self): + return {"mock": {"machines": []}} + def _list(self): pass @@ -98,6 +108,11 @@ def _list(self): def test_machinery_screenshot_support(self): class MockMachinery(abstracts.Machinery): + module_name = "mock" + + def read_config(self): + return {"mock": {"machines": []}} + def _list(self): pass diff --git a/tests/test_analysis_manager.py b/tests/test_analysis_manager.py new file mode 100644 index 00000000000..68c5c8915e2 --- /dev/null +++ b/tests/test_analysis_manager.py @@ -0,0 +1,480 @@ +import datetime +import os +import pathlib +from typing import Generator + +import pytest +from pytest_mock import MockerFixture + +from lib.cuckoo.common.abstracts import Machinery +from lib.cuckoo.common.config import ConfigMeta +from lib.cuckoo.core.analysis_manager import AnalysisManager +from lib.cuckoo.core.database import TASK_RUNNING, Guest, Machine, Task, _Database +from lib.cuckoo.core.machinery_manager import MachineryManager +from lib.cuckoo.core.scheduler import Scheduler + + +class MockMachinery(Machinery): + module_name = "mock" + + def read_config(self): + return { + "mock": { + "machines": "name0", + }, + "name0": { + "label": "label0", + "platform": "windows", + "arch": "x64", + "ip": "1.2.3.4", + }, + } + + def _list(self): + return ["name0"] + + +@pytest.fixture +def machinery() -> Generator[MockMachinery, None, None]: + yield MockMachinery() + + +@pytest.mark.usefixtures("db") +@pytest.fixture +def machinery_manager( + custom_conf_path: pathlib.Path, monkeypatch, machinery: MockMachinery +) -> Generator[MachineryManager, None, None]: + confd_path = custom_conf_path / "cuckoo.conf.d" + confd_path.mkdir(0o755, parents=True, exist_ok=True) + with open(confd_path / "machinery_manager.conf", "wt") as fil: + print("[cuckoo]", file=fil) + print(f"machinery = {MockMachinery.module_name}", file=fil) + ConfigMeta.refresh() + monkeypatch.setattr(MachineryManager, "create_machinery", lambda self: machinery) + yield MachineryManager() + + +@pytest.mark.usefixtures("db") +@pytest.fixture +def scheduler(): + return Scheduler() + + +@pytest.fixture +def task(db: _Database, tmp_path) -> Generator[Task, None, None]: + sample_path = tmp_path / "sample.py" + with open(sample_path, "wt") as fil: + print("#!/usr/bin/env python\nprint('hello world')", file=fil) + with db.session.begin(): + db.add_path(str(sample_path)) + task = db.list_tasks()[0] + db.session.expunge_all() + + yield task + + +@pytest.fixture +def machine(db: _Database) -> Generator[Machine, None, None]: + with db.session.begin(): + machine = db.add_machine( + name="name0", + label="label0", + arch="x64", + ip="1.2.3.4", + platform="windows", + tags="tag1,x64", + interface="int0", + snapshot="snap0", + resultserver_ip="5.6.7.8", + resultserver_port="2043", + reserved=False, + ) + db.session.expunge_all() + yield machine + + +def get_test_object_path(relpath: str): + result = pathlib.Path(__file__).absolute().parent / relpath + if not result.exists(): + pytest.skip("Required data file is not present") + return result + + +@pytest.mark.usefixtures("db") +class TestAnalysisManager: + def test_init(self, task: Task): + mgr = AnalysisManager(task=task) + + assert mgr.cfg.cuckoo == { + "categories": "static, pcap, url, file", + "freespace": 50000, + "delete_original": False, + "tmppath": "/tmp", + "terminate_processes": False, + "memory_dump": False, + "delete_bin_copy": False, + "max_machines_count": 10, + "reschedule": False, + "rooter": "/tmp/cuckoo-rooter", + "machinery": "kvm", + "machinery_screenshots": False, + "delete_archive": True, + "max_vmstartup_count": 5, + "daydelta": 0, + "max_analysis_count": 0, + "max_len": 196, + "sanitize_len": 32, + "sanitize_to_len": 24, + "scaling_semaphore": False, + "scaling_semaphore_update_timer": 10, + "freespace_processing": 15000, + "periodic_log": False, + "fail_unserviceable": True, + } + + assert mgr.task.id == task.id + + def test_logger(self, task: Task, caplog: pytest.LogCaptureFixture): + mgr = AnalysisManager(task=task) + mgr.log.info("Test") + assert any((record.message == f"Task #{task.id}: Test") for record in caplog.records) + + def test_prepare_task_and_machine_to_start_no_machinery(self, db: _Database, task: Task): + mgr = AnalysisManager(task=task) + assert task.status != TASK_RUNNING + with db.session.begin(): + mgr.prepare_task_and_machine_to_start() + with db.session.begin(): + db.session.refresh(task) + assert task.status == TASK_RUNNING + + def test_prepare_task_and_machine_to_start_with_machinery( + self, db: _Database, task: Task, machine: Machine, machinery_manager: MachineryManager + ): + mgr = AnalysisManager(task=task, machine=machine, machinery_manager=machinery_manager) + with db.session.begin(): + mgr.prepare_task_and_machine_to_start() + with db.session.begin(): + db.session.refresh(task) + db.session.refresh(machine) + guest: Guest = db.session.query(Guest).first() + assert task.status == TASK_RUNNING + assert task.machine == machine.label + assert task.machine_id == machine.id + assert machine.locked + assert guest is not None + assert guest.name == machine.name + assert guest.label == machine.label + assert guest.manager == "MockMachinery" + assert guest.task_id == task.id + + def test_init_storage(self, task: Task, tmp_cuckoo_root: pathlib.Path): + analysis_man = AnalysisManager(task=task) + assert analysis_man.init_storage() is True + assert (tmp_cuckoo_root / "storage" / "analyses" / str(task.id)).exists() + + def test_init_storage_already_exists(self, task: Task, tmp_cuckoo_root: pathlib.Path, caplog: pytest.LogCaptureFixture): + analysis_man = AnalysisManager(task=task) + (tmp_cuckoo_root / "storage" / "analyses" / str(task.id)).mkdir(parents=True) + assert analysis_man.init_storage() is False + assert "already exists at path" in caplog.text + + def test_init_storage_other_error(self, task: Task, mocker: MockerFixture, caplog: pytest.LogCaptureFixture): + mocker.patch("lib.cuckoo.common.path_utils.Path.mkdir", side_effect=OSError) + analysis_man = AnalysisManager(task=task) + assert analysis_man.init_storage() is False + assert "Unable to create analysis folder" in caplog.text + + def test_check_file(self, task: Task, mocker: MockerFixture): + class mock_sample: + sha256 = "e3b" + + analysis_man = AnalysisManager(task=task) + mocker.patch("lib.cuckoo.core.database._Database.view_sample", return_value=mock_sample()) + assert analysis_man.check_file("e3b") is True + + def test_check_file_err(self, task: Task, mocker: MockerFixture): + class mock_sample: + sha256 = "different_sha_256" + + analysis_man = AnalysisManager(task=task) + mocker.patch("lib.cuckoo.core.database._Database.view_sample", return_value=mock_sample()) + assert analysis_man.check_file("e3b") is False + + def test_store_file(self, task: Task, tmp_cuckoo_root: pathlib.Path): + analysis_man = AnalysisManager(task=task) + analysis_man.init_storage() + assert analysis_man.store_file(sha256="e3") is True + assert (tmp_cuckoo_root / "storage" / "binaries" / "e3").exists() + binary_symlink = tmp_cuckoo_root / "storage" / "analyses" / str(task.id) / "binary" + assert binary_symlink.is_symlink() + assert os.readlink(binary_symlink) == str(tmp_cuckoo_root / "storage" / "binaries" / "e3") + + def test_store_file_no_dir(self, task: Task, mocker: MockerFixture, caplog: pytest.LogCaptureFixture): + analysis_man = AnalysisManager(task=task) + analysis_man.init_storage() + mocker.patch("lib.cuckoo.core.analysis_manager.shutil.copy", side_effect=IOError) + assert analysis_man.store_file(sha256="e3be3b") is False + assert "Unable to store file" in caplog.text + + def test_store_file_wrong_path(self, task: Task, caplog: pytest.LogCaptureFixture): + task.target = "idontexist" + analysis_man = AnalysisManager(task=task) + analysis_man.init_storage() + analysis_man.store_file(sha256="e3be3b") is False + assert "analysis aborted" in caplog.text + + def test_store_file_binary_already_exists(self, task: Task, tmp_cuckoo_root: pathlib.Path, caplog: pytest.LogCaptureFixture): + analysis_man = AnalysisManager(task=task) + analysis_man.init_storage() + bin_path = tmp_cuckoo_root / "storage" / "binaries" / "sha256" + bin_path.parent.mkdir() + bin_path.touch() + assert analysis_man.store_file(sha256="sha256") is True + assert "File already exists" in caplog.text + assert os.readlink(tmp_cuckoo_root / "storage" / "analyses" / str(task.id) / "binary") == str(bin_path) + + def test_screenshot_machine( + self, + task: Task, + machine: Machine, + machinery_manager: MachineryManager, + tmp_cuckoo_root: pathlib.Path, + custom_conf_path: pathlib.Path, + monkeypatch, + ): + screenshot_called = False + with open(custom_conf_path / "cuckoo.conf", "wt") as fil: + print("[cuckoo]\nmachinery_screenshots = on", file=fil) + ConfigMeta.refresh() + + def screenshot(self2, label, path): + nonlocal screenshot_called + screenshot_called = True + assert label == machine.label + assert path == str(tmp_cuckoo_root / "storage" / "analyses" / str(task.id) / "shots" / "0001.jpg") + + monkeypatch.setattr(MockMachinery, "screenshot", screenshot) + + analysis_man = AnalysisManager(task=task, machine=machine, machinery_manager=machinery_manager) + analysis_man.init_storage() + analysis_man.screenshot_machine() + assert screenshot_called + + def test_screenshot_machine_disabled( + self, task: Task, machine: Machine, machinery_manager: MachineryManager, custom_conf_path: pathlib.Path, monkeypatch + ): + def screenshot(self2, label, path): + raise RuntimeError("This should not get called.") + + monkeypatch.setattr(MockMachinery, "screenshot", screenshot) + + analysis_man = AnalysisManager(task=task, machine=machine, machinery_manager=machinery_manager) + analysis_man.init_storage() + analysis_man.screenshot_machine() + + def test_screenshot_machine_no_machine( + self, task: Task, custom_conf_path: pathlib.Path, monkeypatch, caplog: pytest.LogCaptureFixture + ): + with open(custom_conf_path / "cuckoo.conf", "wt") as fil: + print("[cuckoo]\nmachinery_screenshots = on", file=fil) + ConfigMeta.refresh() + + def screenshot(self2, label, path): + raise RuntimeError("This should not get called.") + + monkeypatch.setattr(MockMachinery, "screenshot", screenshot) + + analysis_man = AnalysisManager(task=task) + analysis_man.init_storage() + analysis_man.screenshot_machine() + assert "no machine is used" in caplog.text + + def test_build_options( + self, db: _Database, tmp_path: pathlib.Path, task: Task, machine: Machine, machinery_manager: MachineryManager + ): + with db.session.begin(): + task = db.session.merge(task) + task.package = "foo" + task.options = "foo=bar" + task.enforce_timeout = 1 + task.clock = datetime.datetime.strptime("01-01-2099 09:01:01", "%m-%d-%Y %H:%M:%S") + task.timeout = 10 + + analysis_man = AnalysisManager(task=task, machine=machine, machinery_manager=machinery_manager) + opts = analysis_man.build_options() + assert opts == { + "amsi": False, + "browser": True, + "category": "file", + "clock": datetime.datetime(2099, 1, 1, 9, 1, 1), + "curtain": False, + "digisig": True, + "disguise": True, + "do_upload_max_size": 0, + "during_script": False, + "enable_trim": 0, + "enforce_timeout": 1, + "evtx": False, + "exports": "", + "file_name": "sample.py", + "file_pickup": False, + "file_type": "Python script, ASCII text executable", + "human_linux": False, + "human_windows": True, + "id": task.id, + "ip": "5.6.7.8", + "options": "foo=bar", + "package": "foo", + "permissions": False, + "port": "2043", + "pre_script": False, + "procmon": False, + "recentfiles": False, + "screenshots_linux": True, + "screenshots_windows": True, + "stap": False, + "stds_view": True, + "sysmon_linux": False, + "sysmon_windows": False, + "target": str(tmp_path / "sample.py"), + "terminate_processes": False, + "timeout": 10, + "tlsdump": True, + "upload_max_size": 100000000, + "usage": False, + "windows_static_route": False, + } + + def test_build_options_pe( + self, db: _Database, tmp_path: pathlib.Path, task: Task, machine: Machine, machinery_manager: MachineryManager + ): + sample_location = get_test_object_path(pathlib.Path("test_samples_sources/HelloWorld32.exe")) + with db.session.begin(): + task = db.session.merge(task) + task.package = "file" + task.enforce_timeout = 1 + task.clock = datetime.datetime.strptime("01-01-2099 09:01:01", "%m-%d-%Y %H:%M:%S") + task.timeout = 10 + task.target = str(sample_location) + + analysis_man = AnalysisManager(task=task, machine=machine, machinery_manager=machinery_manager) + opts = analysis_man.build_options() + assert opts == { + "amsi": False, + "browser": True, + "category": "file", + "clock": datetime.datetime(2099, 1, 1, 9, 1, 1), + "curtain": False, + "digisig": True, + "disguise": True, + "do_upload_max_size": 0, + "during_script": False, + "enable_trim": 0, + "enforce_timeout": 1, + "evtx": False, + "exports": "", + "file_name": "HelloWorld32.exe", + "file_pickup": False, + "file_type": "PE32 executable (console) Intel 80386, for MS Windows", + "human_linux": False, + "human_windows": True, + "id": task.id, + "ip": "5.6.7.8", + "options": "", + "package": "file", + "permissions": False, + "port": "2043", + "pre_script": False, + "procmon": False, + "recentfiles": False, + "screenshots_linux": True, + "screenshots_windows": True, + "stap": False, + "stds_view": True, + "sysmon_linux": False, + "sysmon_windows": False, + "target": str(sample_location), + "terminate_processes": False, + "timeout": 10, + "tlsdump": True, + "upload_max_size": 100000000, + "usage": False, + "windows_static_route": False, + } + + def test_category_checks( + self, db: _Database, task: Task, machine: Machine, machinery_manager: MachineryManager, mocker: MockerFixture + ): + sample_sha256 = "05ec45d230d2a55b059f0ba7a0f0b4085f8ab6a73c1ffa33c7693f5ef48e22e5" + + class mock_sample: + sha256 = sample_sha256 + + sample_location = get_test_object_path(pathlib.Path("test_samples") / sample_sha256) + with db.session.begin(): + task = db.session.merge(task) + task.target = str(sample_location) + + analysis_man = AnalysisManager(task=task, machine=machine, machinery_manager=machinery_manager) + assert analysis_man.init_storage() is True + mocker.patch("lib.cuckoo.core.database._Database.view_sample", return_value=mock_sample()) + + assert analysis_man.category_checks() is None + + def test_category_checks_modified_file( + self, db: _Database, task: Task, machine: Machine, machinery_manager: MachineryManager, mocker: MockerFixture + ): + sample_sha256 = "05ec45d230d2a55b059f0ba7a0f0b4085f8ab6a73c1ffa33c7693f5ef48e22e5" + + class mock_sample: + sha256 = "123" + + sample_location = get_test_object_path(pathlib.Path("test_samples") / sample_sha256) + with db.session.begin(): + task = db.session.merge(task) + task.target = str(sample_location) + + analysis_man = AnalysisManager(task=task, machine=machine, machinery_manager=machinery_manager) + assert analysis_man.init_storage() is True + mocker.patch("lib.cuckoo.core.database._Database.view_sample", return_value=mock_sample()) + + assert analysis_man.category_checks() is False + + def test_category_checks_no_store_file( + self, db: _Database, task: Task, machine: Machine, machinery_manager: MachineryManager, mocker: MockerFixture + ): + sample_sha256 = "05ec45d230d2a55b059f0ba7a0f0b4085f8ab6a73c1ffa33c7693f5ef48e22e5" + + class mock_sample: + sha256 = sample_sha256 + + sample_location = get_test_object_path(pathlib.Path("test_samples") / sample_sha256) + with db.session.begin(): + task = db.session.merge(task) + task.target = str(sample_location) + + analysis_man = AnalysisManager(task=task, machine=machine, machinery_manager=machinery_manager) + assert analysis_man.init_storage() is True + mocker.patch("lib.cuckoo.core.database._Database.view_sample", return_value=mock_sample()) + mocker.patch("lib.cuckoo.core.scheduler.AnalysisManager.store_file", return_value=False) + assert analysis_man.category_checks() is False + + def test_category_checks_pcap( + self, db: _Database, task: Task, machine: Machine, machinery_manager: MachineryManager, mocker: MockerFixture + ): + sample_sha256 = "05ec45d230d2a55b059f0ba7a0f0b4085f8ab6a73c1ffa33c7693f5ef48e22e5" + + class mock_sample: + sha256 = sample_sha256 + + sample_location = get_test_object_path(pathlib.Path("test_samples") / sample_sha256) + with db.session.begin(): + task = db.session.merge(task) + task.target = str(sample_location) + task.category = "pcap" + + analysis_man = AnalysisManager(task=task, machine=machine, machinery_manager=machinery_manager) + assert analysis_man.init_storage() is True + mocker.patch("lib.cuckoo.core.database._Database.view_sample", return_value=mock_sample()) + assert analysis_man.category_checks() is True diff --git a/tests/test_config.py b/tests/test_config.py index d4a9a48aafc..807abed0091 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -7,7 +7,7 @@ import pytest -from lib.cuckoo.common.config import AnalysisConfig, Config +from lib.cuckoo.common.config import AnalysisConfig, Config, ConfigMeta from lib.cuckoo.common.exceptions import CuckooOperationalError from lib.cuckoo.common.path_utils import path_write_file @@ -73,6 +73,7 @@ def test_option_override(self, custom_conf_path): """, ) config = Config("cuckoo") + ConfigMeta.refresh() # This was overridden in the custom config. assert config.get("cuckoo")["machinery_screenshots"] is True @@ -107,6 +108,7 @@ def test_subdirs(self, custom_conf_path): url = {url} """, ) + ConfigMeta.refresh() config = Config("api") # This is set to 'no' in the default api.conf and not overridden. assert config.get("api")["token_auth_enabled"] is False @@ -142,6 +144,7 @@ def test_environment_interpolation(self, custom_conf_path, monkeypatch): custom_secret = "MyReallySecretKeyWithAPercent(%)InIt" monkeypatch.setenv("DLINTELKEY", custom_secret) config = Config("processing") + ConfigMeta.refresh() section = config.get("virustotal") # Inherited from default config assert section.enabled is True @@ -164,4 +167,5 @@ def test_missing_environment_interpolation(self, custom_conf_path, monkeypatch): ) with pytest.raises(configparser.InterpolationMissingOptionError): + ConfigMeta.refresh() _ = Config("processing") diff --git a/tests/test_database.py b/tests/test_database.py index d44b67eb7b9..362e0e0a246 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -3,21 +3,48 @@ # See the file 'docs/LICENSE' for copying permission. import base64 +import dataclasses +import datetime +import hashlib import os +import pathlib import shutil from tempfile import NamedTemporaryFile import pytest -from sqlalchemy import delete, inspect from sqlalchemy.exc import SQLAlchemyError -from lib.cuckoo.common.exceptions import CuckooOperationalError +from lib.cuckoo.common.exceptions import CuckooUnserviceableTaskError from lib.cuckoo.common.path_utils import path_mkdir from lib.cuckoo.common.utils import store_temp_file -from lib.cuckoo.core.database import Database, Error, Machine, Tag, Task, machines_tags, tasks_tags +from lib.cuckoo.core import database +from lib.cuckoo.core.database import ( + TASK_BANNED, + TASK_COMPLETED, + TASK_PENDING, + TASK_REPORTED, + TASK_RUNNING, + Error, + Guest, + Machine, + Sample, + Tag, + Task, + _Database, + machines_tags, +) -@pytest.fixture(autouse=True) +@dataclasses.dataclass +class StorageLayout: + tmp_path: pathlib.Path + storage: str + binary_storage: str + analyses_storage: str + tmpdir: str + + +@pytest.fixture def storage(tmp_path, request): storage = tmp_path / "storage" binaries = storage / "binaries" @@ -26,89 +53,102 @@ def storage(tmp_path, request): analyses.mkdir(mode=0o755, parents=True) tmpdir = tmp_path / "tmp" tmpdir.mkdir(mode=0o755, parents=True) - request.instance.tmp_path = tmp_path - request.instance.storage = str(storage) - request.instance.binary_storage = str(binaries) - request.instance.analyses_storage = str(analyses) - request.instance.tmpdir = str(tmpdir) + yield StorageLayout( + tmp_path=tmp_path, + storage=str(storage), + binary_storage=str(binaries), + analyses_storage=str(analyses), + tmpdir=str(tmpdir), + ) + shutil.rmtree(str(tmp_path)) + + +@pytest.fixture +def temp_filename(storage: StorageLayout): + with NamedTemporaryFile(mode="w+", delete=False, dir=storage.storage) as f: + f.write("hehe") + yield f.name + + +@pytest.fixture +def temp_pcap(temp_filename: str, storage: StorageLayout): + pcap_header_base64 = b"1MOyoQIABAAAAAAAAAAAAAAABAABAAAA" + pcap_bytes = base64.b64decode(pcap_header_base64) + yield store_temp_file(pcap_bytes, "%s.pcap" % temp_filename, storage.tmpdir.encode()) -@pytest.mark.usefixtures("tmp_cuckoo_root") +@pytest.mark.usefixtures("tmp_cuckoo_root", "storage") class TestDatabaseEngine: """Test database stuff.""" URI = None - def setup_method(self, method): - with NamedTemporaryFile(mode="w+", delete=False, dir=self.storage) as f: - f.write("hehe") - self.temp_filename = f.name - pcap_header_base64 = b"1MOyoQIABAAAAAAAAAAAAAAABAABAAAA" - pcap_bytes = base64.b64decode(pcap_header_base64) - self.temp_pcap = store_temp_file(pcap_bytes, "%s.pcap" % f.name, self.tmpdir.encode()) - self.d = Database(dsn="sqlite://") - # self.d.connect(dsn=self.URI) - self.session = self.d.Session() - # This need to be done before each tests as sticky tags have been found to corrupt results - inspector = inspect(self.d.engine) - if inspector.get_table_names(): - stmt = delete(Machine) - stmt2 = delete(Task) - stmt3 = delete(machines_tags) - stmt4 = delete(tasks_tags) - stmt5 = delete(Tag) - stmt6 = delete(Error) - self.session.execute(stmt) - self.session.execute(stmt2) - self.session.execute(stmt3) - self.session.execute(stmt4) - self.session.execute(stmt5) - self.session.execute(stmt6) - self.session.commit() - - def teardown_method(self): - del self.d - shutil.rmtree(str(self.tmp_path)) - - def add_url(self, url, priority=1, status="pending"): - task_id = self.d.add_url(url, priority=priority) - self.d.set_status(task_id, status) - return task_id - - def test_add_tasks(self): + def add_machine(self, db: _Database, **kwargs) -> Machine: + dflt = dict( + name="name0", + label="label0", + ip="1.2.3.0", + platform="windows", + tags="tag1,x64", + interface="int0", + snapshot="snap0", + resultserver_ip="5.6.7.8", + resultserver_port="2043", + arch="x64", + reserved=False, + ) + dflt.update(kwargs) + return db.add_machine(**dflt) + + def test_add_tasks(self, db: _Database, temp_filename: str): # Add task. - count = self.session.query(Task).count() - self.d.add_path(self.temp_filename) - assert self.session.query(Task).count() == count + 1 + with db.session.begin(): + assert db.session.query(Task).count() == 0 + with db.session.begin(): + db.add_path(temp_filename) + with db.session.begin(): + assert db.session.query(Task).count() == 1 # Add url. - self.d.add_url("http://foo.bar") - assert self.session.query(Task).count() == count + 2 - - def test_error_exists(self): - task_id = self.add_url("http://google.com/") - self.d.add_error("A" * 1024, task_id) - assert len(self.d.view_errors(task_id)) == 1 - self.d.add_error("A" * 1024, task_id) - assert len(self.d.view_errors(task_id)) == 2 - - def test_long_error(self): - self.add_url("http://google.com/") - self.d.add_error("A" * 1024, 1) - err = self.d.view_errors(1) - assert err and len(err[0].message) == 1024 - - def test_task_set_options(self): - assert self.d.add_path(self.temp_filename, options={"foo": "bar"}) is None - t1 = self.d.add_path(self.temp_filename, options="foo=bar") - assert self.d.view_task(t1).options == "foo=bar" - - def test_task_tags_str(self): - t1 = self.d.add_path(self.temp_filename, tags="foo,,bar") - t2 = self.d.add_path(self.temp_filename, tags="boo,,far") - - t1_tag_list = [str(x.name) for x in list(self.d.view_task(t1).tags)] - t2_tag_list = [str(x.name) for x in list(self.d.view_task(t2).tags)] + with db.session.begin(): + db.add_url("http://foo.bar") + with db.session.begin(): + assert db.session.query(Task).count() == 2 + + def test_error_exists(self, db: _Database): + err_msg = "A" * 1024 + with db.session.begin(): + task_id = db.add_url("http://google.com/") + db.add_error(err_msg, task_id) + with db.session.begin(): + errs = db.view_errors(task_id) + assert len(errs) == 1 + assert errs[0].message == err_msg + with db.session.begin(): + db.add_error(err_msg, task_id) + with db.session.begin(): + assert len(db.view_errors(task_id)) == 2 + + def test_task_set_options(self, db: _Database, temp_filename: str): + with pytest.raises(SQLAlchemyError): + with db.session.begin(): + # Make sure options passed in as a dict are not allowed. + db.add_path(temp_filename, options={"foo": "bar"}) + + with db.session.begin(): + t1 = db.add_path(temp_filename, options="foo=bar") + + with db.session.begin(): + assert db.view_task(t1).options == "foo=bar" + + def test_task_tags_str(self, db: _Database, temp_filename: str): + with db.session.begin(): + t1 = db.add_path(temp_filename, tags="foo,,bar") + t2 = db.add_path(temp_filename, tags="boo,,far") + + with db.session.begin(): + t1_tag_list = [str(x.name) for x in list(db.view_task(t1).tags)] + t2_tag_list = [str(x.name) for x in list(db.view_task(t2).tags)] t1_tag_list.sort() t2_tag_list.sort() @@ -116,518 +156,884 @@ def test_task_tags_str(self): assert t1_tag_list == ["bar", "foo", "x86"] assert t2_tag_list == ["boo", "far", "x86"] - def test_reschedule_file(self): - count = self.session.query(Task).count() - task_id = self.d.add_path(self.temp_filename) - assert self.session.query(Task).count() == count + 1 - task = self.d.view_task(task_id) - assert task is not None + def test_truncate_error_msg(self, monkeypatch): + monkeypatch.setattr(Error, "MAX_LENGTH", 20) + err = Error("abcdefghijklmnopqrstuvwxyz", 1) + assert err.message == "abcdefgh...rstuvwxyz" + + def test_reschedule_file(self, db: _Database, temp_filename: str, storage: StorageLayout): + with db.session.begin(): + task_id = db.add_path(temp_filename) + with db.session.begin(): + assert db.session.query(Task).count() == 1 + task = db.view_task(task_id) + assert task is not None + db.session.expunge(task) + assert task.category == "file" # write a real sample to storage - task_path = os.path.join(self.analyses_storage, str(task.id)) + task_path = os.path.join(storage.analyses_storage, str(task.id)) path_mkdir(task_path) - shutil.copy(self.temp_filename, os.path.join(task_path, "binary")) + shutil.copy(temp_filename, os.path.join(task_path, "binary")) - new_task_id = self.d.reschedule(task_id) + with db.session.begin(): + new_task_id = db.reschedule(task_id) assert new_task_id is not None - new_task = self.d.view_task(new_task_id) + + with db.session.begin(): + new_task = db.view_task(new_task_id) assert new_task.category == "file" - def test_reschedule_static(self): - count = self.session.query(Task).count() - task_ids = self.d.add_static(self.temp_filename) + def test_reschedule_static(self, db: _Database, temp_filename: str, storage: StorageLayout): + with db.session.begin(): + task_ids = db.add_static(temp_filename) assert len(task_ids) == 1 task_id = task_ids[0] - assert self.session.query(Task).count() == count + 1 - task = self.d.view_task(task_id) - assert task is not None + with db.session.begin(): + assert db.session.query(Task).count() == 1 + task = db.view_task(task_id) + assert task is not None + db.session.expunge_all() assert task.category == "static" # write a real sample to storage - static_path = os.path.join(self.binary_storage, task.sample.sha256) - shutil.copy(self.temp_filename, static_path) + static_path = os.path.join(storage.binary_storage, task.sample.sha256) + shutil.copy(temp_filename, static_path) - new_task_id = self.d.reschedule(task_id) - assert new_task_id is not None - new_task = self.d.view_task(new_task_id[0]) - assert new_task.category == "static" - - def test_reschedule_pcap(self): - count = self.session.query(Task).count() - task_id = self.d.add_pcap(self.temp_pcap) - assert self.session.query(Task).count() == count + 1 - task = self.d.view_task(task_id) - assert task is not None + with db.session.begin(): + new_task_id = db.reschedule(task_id) + assert new_task_id is not None + with db.session.begin(): + new_task = db.view_task(new_task_id[0]) + assert new_task.category == "static" + + def test_reschedule_pcap(self, db: _Database, temp_pcap: str, storage: StorageLayout): + with db.session.begin(): + task_id = db.add_pcap(temp_pcap) + with db.session.begin(): + assert db.session.query(Task).count() == 1 + task = db.view_task(task_id) + assert task is not None + db.session.expunge_all() assert task.category == "pcap" # write a real sample to storage - pcap_path = os.path.join(self.binary_storage, task.sample.sha256) - shutil.copy(self.temp_pcap, pcap_path) + pcap_path = os.path.join(storage.binary_storage, task.sample.sha256) + shutil.copy(temp_pcap, pcap_path) # reschedule the PCAP task - new_task_id = self.d.reschedule(task_id) + with db.session.begin(): + new_task_id = db.reschedule(task_id) assert new_task_id is not None - new_task = self.d.view_task(new_task_id) - assert new_task.category == "pcap" + with db.session.begin(): + new_task = db.view_task(new_task_id) + assert new_task.category == "pcap" - def test_reschedule_url(self): + def test_reschedule_url(self, db: _Database): # add a URL task - count = self.session.query(Task).count() - task_id = self.d.add_url("test_reschedule_url") - assert self.session.query(Task).count() == count + 1 - task = self.d.view_task(task_id) - assert task is not None - assert task.category == "url" + with db.session.begin(): + task_id = db.add_url("test_reschedule_url") + with db.session.begin(): + assert db.session.query(Task).count() == 1 + task = db.view_task(task_id) + assert task is not None + assert task.category == "url" # reschedule the URL task - new_task_id = self.d.reschedule(task_id) + with db.session.begin(): + new_task_id = db.reschedule(task_id) assert new_task_id is not None - new_task = self.d.view_task(new_task_id) - assert new_task.category == "url" - - def test_add_machine(self): - self.d.add_machine( - name="name1", - label="label1", - ip="1.2.3.4", - platform="windows", - tags="tag1 tag2", - interface="int0", - snapshot="snap0", - resultserver_ip="5.6.7.8", - resultserver_port=2043, - arch="x64", - reserved=False, - ) - self.d.add_machine( - name="name2", - label="label2", - ip="1.2.3.4", - platform="windows", - tags="tag1 tag2", - interface="int0", - snapshot="snap0", - resultserver_ip="5.6.7.8", - resultserver_port=2043, - arch="x64", - reserved=True, - ) - m1 = self.d.view_machine("name1") - m2 = self.d.view_machine("name2") - - assert m1.to_dict() == { - "status": None, - "locked": False, - "name": "name1", - "resultserver_ip": "5.6.7.8", - "ip": "1.2.3.4", - "tags": ["tag1tag2"], - "label": "label1", - "locked_changed_on": None, - "platform": "windows", - "snapshot": "snap0", - "interface": "int0", - "status_changed_on": None, - "id": 1, - "resultserver_port": "2043", - "arch": "x64", - "reserved": False, - } - - assert m2.to_dict() == { - "id": 2, - "interface": "int0", - "ip": "1.2.3.4", - "label": "label2", - "locked": False, - "locked_changed_on": None, - "name": "name2", - "platform": "windows", - "resultserver_ip": "5.6.7.8", - "resultserver_port": "2043", - "snapshot": "snap0", - "status": None, - "status_changed_on": None, - "tags": ["tag1tag2"], - "arch": "x64", - "reserved": True, - } + with db.session.begin(): + new_task = db.view_task(new_task_id) + assert new_task.category == "url" + + def test_add_machine(self, db: _Database): + with db.session.begin(): + self.add_machine(db, name="name1", label="label1", tags="tag1 tag2,tag3") + self.add_machine(db, name="name2", label="label2", tags="tag1 tag2", reserved=True) + with db.session.begin(): + m1 = db.view_machine("name1") + m2 = db.view_machine("name2") + + assert m1.to_dict() == { + "status": None, + "locked": False, + "name": "name1", + "resultserver_ip": "5.6.7.8", + "ip": "1.2.3.0", + "tags": ["tag1tag2", "tag3"], + "label": "label1", + "locked_changed_on": None, + "platform": "windows", + "snapshot": "snap0", + "interface": "int0", + "status_changed_on": None, + "id": 1, + "resultserver_port": "2043", + "arch": "x64", + "reserved": False, + } + + assert m2.to_dict() == { + "id": 2, + "interface": "int0", + "ip": "1.2.3.0", + "label": "label2", + "locked": False, + "locked_changed_on": None, + "name": "name2", + "platform": "windows", + "resultserver_ip": "5.6.7.8", + "resultserver_port": "2043", + "snapshot": "snap0", + "status": None, + "status_changed_on": None, + "tags": ["tag1tag2"], + "arch": "x64", + "reserved": True, + } + + def test_find_machine_to_service_task_tags_reserved(self, db: _Database): + with db.session.begin(): + self.add_machine(db, name="name0", label="label0", tags="tag1,x64", reserved=False) + self.add_machine(db, name="name1", label="label1", tags="tag1,x64", reserved=True) + self.add_machine(db, name="name2", label="label2", tags="tag1,tag2,x64", reserved=True) + self.add_machine(db, name="name3", label="label3", tags="tag3,x64", locked=True) + self.add_machine(db, name="name4", label="label4", tags="tag4,x64", locked=True, reserved=True) + task1 = Task() + task1.tags = [Tag("tag1")] + task2 = Task() + task2.tags = [Tag("tag2")] + task3 = Task() + task3.machine = "label0" + task4 = Task() + task4.machine = "label1" + task5 = Task() + task5.tags = [Tag("tag3")] + task6 = Task() + task6.tags = [Tag("tag4")] + task7 = Task() + task7.tags = [Tag("idontexist")] + with db.session.begin(): + # A task with a tag. An unreserved, unlocked machine exists. + assert db.find_machine_to_service_task(task1).name == "name0" + # A task with a tag. A reserved, unlocked machine exists. + assert db.find_machine_to_service_task(task2).name == "name2" + # A task with a requested machine that is unreserved and unlocked. + assert db.find_machine_to_service_task(task3).name == "name0" + # A task with a requested machine that is reserved and unlocked. + assert db.find_machine_to_service_task(task4).name == "name1" + # A task with a tag. An unreserved, locked machine exists. + assert db.find_machine_to_service_task(task5) is None + # A task with a tag. A reserved, locked machine exists. + assert db.find_machine_to_service_task(task6) is None + # A task with a tag that doesn't match any machines. + with pytest.raises(CuckooUnserviceableTaskError): + db.find_machine_to_service_task(task7) + + def test_clean_machines(self, db: _Database): + """Add a couple machines and make sure that clean_machines removes them and their tags.""" + with db.session.begin(): + for i, tags in ((1, "tag1"), (2, None)): + self.add_machine( + db, + name=f"name{i}", + label=f"label{i}", + ip=f"1.2.3.{i}", + tags=tags, + ) + with db.session.begin(): + db.clean_machines() + + with db.session.begin(): + assert db.session.query(Machine).count() == 0 + assert db.session.query(Tag).count() == 1 + assert db.session.query(machines_tags).count() == 0 + + def test_delete_machine(self, db: _Database): + machines = [] + with db.session.begin(): + for i, tags in ((1, "tag1"), (2, None)): + machines.append(f"name{i}") + self.add_machine( + db, + name=machines[-1], + label=f"label{i}", + ip=f"1.2.3.{i}", + tags=tags, + ) + with db.session.begin(): + assert db.delete_machine(machines[0]) + assert db.session.query(Machine).count() == 1 + # Attempt to delete the same machine. + assert not db.delete_machine(machines[0]) + assert db.session.query(Machine).count() == 1 + assert db.delete_machine(machines[1]) + assert db.session.query(Machine).count() == 0 + + def test_set_machine_interface(self, db: _Database): + intf = "newintf" + with db.session.begin(): + self.add_machine(db) + assert db.set_machine_interface("label0", intf) is not None + assert db.set_machine_interface("idontexist", intf) is None + + with db.session.begin(): + assert db.session.query(Machine).filter_by(label="label0").one().interface == intf + + def test_set_vnc_port(self, db: _Database): + with db.session.begin(): + id1 = db.add_url("http://foo.bar") + id2 = db.add_url("http://foo.bar", options="nomonitor=1") + with db.session.begin(): + db.set_vnc_port(id1, 6001) + db.set_vnc_port(id2, 6002) + # Make sure that it doesn't fail if giving a task ID that doesn't exist. + db.set_vnc_port(id2 + 1, 6003) + with db.session.begin(): + t1 = db.session.query(Task).filter_by(id=id1).first() + assert t1.options == "vnc_port=6001" + t2 = db.session.query(Task).filter_by(id=id2).first() + assert t2.options == "nomonitor=1,vnc_port=6002" + + def test_update_clock_file(self, db: _Database, temp_filename: str, monkeypatch, freezer): + with db.session.begin(): + # This task ID doesn't exist. + assert db.update_clock(1) is None + + task_id = db.add_path(temp_filename) + now = datetime.datetime.utcnow() + monkeypatch.setattr(db.cfg.cuckoo, "daydelta", 1) + new_clock = now + datetime.timedelta(days=1) + assert db.update_clock(task_id) == new_clock + with db.session.begin(): + assert db.session.query(Task).one().clock == new_clock + + def test_update_clock_url(self, db: _Database, monkeypatch, freezer): + with db.session.begin(): + task_id = db.add_url("https://www.google.com") + now = datetime.datetime.utcnow() + monkeypatch.setattr(database.datetime, "utcnow", lambda: now) + # URL's are unaffected by the daydelta setting. + monkeypatch.setattr(db.cfg.cuckoo, "daydelta", 1) + assert db.update_clock(task_id) == now + with db.session.begin(): + assert db.session.query(Task).one().clock == now + + def test_set_status(self, db: _Database, freezer): + with db.session.begin(): + assert db.set_status(1, TASK_COMPLETED) is None + task_id = db.add_url("https://www.google.com") + with db.session.begin(): + task = db.session.query(Task).filter_by(id=task_id).one() + assert task.started_on is None + assert task.completed_on is None + now = datetime.datetime.utcnow() + freezer.move_to(now) + db.set_status(task_id, TASK_RUNNING) + task = db.session.query(Task).filter_by(id=task_id).one() + assert task.status == TASK_RUNNING + assert task.started_on == now + assert task.completed_on is None + + new_now = now + datetime.timedelta(seconds=1) + freezer.move_to(new_now) + db.set_status(task_id, TASK_COMPLETED) + task = db.session.query(Task).filter_by(id=task_id).one() + assert task.status == TASK_COMPLETED + assert task.started_on == now + assert task.completed_on == new_now + + def test_create_guest(self, db: _Database): + with db.session.begin(): + machine = self.add_machine(db) + task_id = db.add_url("http://foo.bar") + with db.session.begin(): + task = db.session.query(Task).filter_by(id=task_id).first() + guest = db.create_guest(machine, "kvm", task) + assert guest.name == "name0" + assert guest.label == "label0" + assert guest.manager == "kvm" + assert guest.task_id == task_id + assert guest.status == "init" + with db.session.begin(): + assert guest == db.session.query(Guest).first() @pytest.mark.parametrize( - "reserved,requested_platform,requested_machine,is_serviceable", + "kwargs,expected_machines", ( - (False, "windows", None, True), - (False, "linux", None, False), - (False, "windows", "label", True), - (False, "linux", "label", False), - (True, "windows", None, False), - (True, "linux", None, False), - (True, "windows", "label", True), - (True, "linux", "label", False), + ({"locked": True}, {"n2"}), + ({"locked": False}, {"n1", "n4", "n5", "n6"}), + # Make sure providing a label overrides "include_reserved" + ({"label": "l3"}, {"n3"}), + ({"label": "foo"}, set()), + ({"platform": "windows"}, {"n1", "n2", "n5", "n6"}), + ({"platform": "osx"}, set()), + ({"tags": ["tag1"]}, {"n1", "n2", "n4", "n6"}), + ({"tags": ["foo"]}, set()), + ({"arch": ["x86"]}, {"n1", "n2", "n4", "n5", "n6"}), + ({"arch": ["x64"]}, {"n1", "n2", "n4", "n5"}), + ({"arch": ["xy"]}, set()), + ({"os_version": ["win10"]}, {"n5"}), + ({"os_version": ["winxp"]}, set()), + ({"include_reserved": True}, {"n1", "n2", "n3", "n4", "n5", "n6"}), ), ) - def test_serviceability(self, reserved, requested_platform, requested_machine, is_serviceable): - self.d.add_machine( - name="win10-x64-1", - label="label", - ip="1.2.3.4", - platform="windows", - tags="tag1", - interface="int0", - snapshot="snap0", - resultserver_ip="5.6.7.8", - resultserver_port=2043, - arch="x64", - reserved=reserved, - ) - task = Task() - task.platform = requested_platform - task.machine = requested_machine - task.tags = [Tag("tag1")] - # tasks matching the available machines are serviceable - assert self.d.is_serviceable(task) is is_serviceable + def test_list_machines(self, db: _Database, kwargs, expected_machines): + with db.session.begin(): + self.add_machine(db, name="n1", label="l1") + m = self.add_machine(db, name="n2", label="l2") + m.locked = True + self.add_machine(db, name="n3", label="l3", reserved=True) + self.add_machine(db, name="n4", label="l4", platform="linux") + self.add_machine(db, name="n5", label="l5", tags="win10") + self.add_machine(db, name="n6", label="l6", arch="x86") + with db.session.begin(): + actual_machines = [machine.name for machine in db.list_machines(**kwargs)] + if kwargs == {"arch": ["x86"]}: + # This is the only parameter that causes the returned value to be in any + # guaranteed order. + assert actual_machines[0] == "n6" + actual_machines_set = set(actual_machines) + assert actual_machines_set == expected_machines + + def test_assign_machine_to_task(self, db: _Database): + with db.session.begin(): + t1 = db.add_url("http://one.com") + t2 = db.add_url("http://two.com") + m1 = self.add_machine(db) + with db.session.begin(): + task1 = db.session.get(Task, t1) + task2 = db.session.get(Task, t2) + db.assign_machine_to_task(task1, m1) + db.assign_machine_to_task(task2, None) + with db.session.begin(): + task1 = db.session.get(Task, t1) + task2 = db.session.get(Task, t2) + assert task1.machine == "label0" + assert task1.machine_id == m1.id + assert task2.machine is None + assert task2.machine_id is None + + def test_lock_machine(self, db: _Database, freezer): + with db.session.begin(): + m1 = self.add_machine(db) + with db.session.begin(): + db.lock_machine(m1) + with db.session.begin(): + m1 = db.session.get(Machine, m1.id) + assert m1.locked + assert m1.locked_changed_on == datetime.datetime.now() + assert m1.status == "running" + freezer.move_to(datetime.datetime.now() + datetime.timedelta(minutes=5)) + with db.session.begin(): + assert db.count_machines_running() == 1 + db.unlock_machine(m1) + with db.session.begin(): + m1 = db.session.get(Machine, m1.id) + assert not m1.locked + assert m1.locked_changed_on == datetime.datetime.now() + with db.session.begin(): + assert db.count_machines_running() == 0 @pytest.mark.parametrize( - "task_instructions,machine_instructions,expected_results", - # @param task_instructions : list of tasks to be created, each tuple represent the tag to associate to tasks and the number of such tasks to create - # @param machine_instructions : list of machines to be created, each collections represent the parameters to associate to machines and the number of such machines to create - # @param expected_results : dictionary of expected tasks to be mapped to machines numbered by their tags + "kwargs,expected_retval", ( - # No tasks, no machines - ( - [], - [], - {}, - ), - # Assign 10 tasks with the same tag to 10 available machines with that tag - ( - [("tag1", 10)], - [("windows", "x64", "tag1", 10)], - {"tag1": 10}, - ), - # Assign 10 tasks (8 with one tag, 2 with another) to 8 available machines with that tag and 2 available machines with the other tag - ( - [("tag1", 8), ("tag2", 2)], - [("windows", "x64", "tag1,", 8), ("windows", "x86", "tag2,", 2)], - {"tag1": 8, "tag2": 2}, - ), - # Assign 43 tasks total containing a variety of tags to 40/80 available machines with the first tag, 2/2 available machines with the second tag and 1/2 available machines with the third tag - ( - [("tag1", 40), ("tag2", 2), ("tag3", 1)], - [("windows", "x64", "tag1", 80), ("windows", "x86", "tag2", 2), ("linux", "x64", "tag3", 2)], - {"tag1": 40, "tag2": 2, "tag3": 1}, - ), + ({"machine": "l1"}, None), # The specified machine is in use. + ({"machine": "l2"}, "n2"), # The specified machine is not in use. + ({"machine": "l3"}, "n3"), # The specific machine is reserved but not in use. + ({"machine": "foo"}, CuckooUnserviceableTaskError), # No such machine exists. + ({"platform": "windows"}, "n2"), + ({"platform": "osx"}, CuckooUnserviceableTaskError), + ({"tags": "tag1"}, "n2"), + ({"tags": "foo"}, CuckooUnserviceableTaskError), + ({"tags": "x64"}, "n2"), + ({"tags": ""}, "n2"), + ({"tags": "arm"}, CuckooUnserviceableTaskError), + # msix requires a machine with the win10 or win11 tag. + ({"package": "msix"}, CuckooUnserviceableTaskError), + ({"package": "foo"}, "n2"), ), ) - def test_map_tasks_to_available_machines(self, task_instructions, machine_instructions, expected_results): - tasks = [] - machines = [] - cleanup_tasks = [] - - # Parsing machine instructions - for machine_instruction in machine_instructions: - platform, archs, tags, num_of_machines = machine_instruction - for i in range(num_of_machines): - machine_name = str(platform) + str(archs) + str(i) - self.d.add_machine( - name=machine_name, - label=machine_name, - ip="1.2.3.4", - platform=platform, - tags=tags, - interface="int0", - snapshot="snap0", - resultserver_ip="5.6.7.8", - resultserver_port=2043, - arch=archs, - reserved=False, - ) - machines.append((machine_name, tags)) - for machine, machine_tag in machines: - self.d.set_machine_status(machine, "running") - # Parsing tasks instructions - for task_instruction in task_instructions: - tags, num_of_tasks = task_instruction - for i in range(num_of_tasks): - task_id = "Sample_%s_%s" % (tags, i) - with open(task_id, "w") as f: - f.write(task_id) - cleanup_tasks.append(task_id) - task = self.d.add_path(file_path=task_id, tags=tags) - task = self.d.view_task(task) - tasks.append(task) - - # Parse the expected results - total_task_to_be_assigned = 0 - total_task_to_be_assigned = sum(expected_results.values()) - - total_task_assigned = 0 - results = [] - for tag in expected_results.keys(): - results.append([tag, 0]) - relevant_tasks = self.d.map_tasks_to_available_machines(tasks) - for task in relevant_tasks: - tags = [tag.name for tag in task.tags] - for result in results: - if result[0] == tags[0]: - result[1] += 1 - break - total_task_assigned += len(relevant_tasks) - - # Cleanup - for file in cleanup_tasks: - os.remove(file) - - # Test results - assert total_task_assigned == total_task_to_be_assigned - for tag in expected_results.keys(): - for result in results: - if tag == result[0]: - assert expected_results[tag] == result[1] + def test_find_machine_to_service_task(self, db: _Database, temp_filename: str, kwargs, expected_retval): + with db.session.begin(): + self.add_machine(db, name="n1", label="l1", locked=True) + self.add_machine(db, name="n2", label="l2", tags="tag1,x64") + self.add_machine(db, name="n3", label="l3", reserved=True) + + task_id = db.add_path(temp_filename, **kwargs) + with db.session.begin(): + task = db.session.get(Task, task_id) + if isinstance(expected_retval, type) and issubclass(expected_retval, Exception): + with pytest.raises(expected_retval): + db.find_machine_to_service_task(task) + else: + result = db.find_machine_to_service_task(task) + if expected_retval is None: + assert result is None + else: + assert result.name == expected_retval @pytest.mark.parametrize( - "task,machine,expected_results", - # @param task : dictionary describing the task to be created - # @param machine : dictionary describing the machine to be created - # @param expected_results : list of expected locked machines after attempting the test in the format of {number_of_expected_available_machines, should_raise_exception, should_be_locked} + "categories,expected_task", ( - # Generic task with no contrains - ( - {"label": "task0", "machine": None, "platform": None, "tags": None, "os_version": None}, - { - "label": "machine1", - "reserved": False, - "platform": "windows", - "arch": "x64", - "tags": "tag1", - "locked": False, - }, - (0, False, True), - ), - # Suitable task which is going to be locking this machine - ( - {"label": "task1", "machine": None, "platform": "windows", "tags": "tag1", "os_version": None}, - { - "label": "machine1", - "reserved": False, - "platform": "windows", - "arch": "x64", - "tags": "tag1", - "locked": False, - }, - (0, False, True), - ), - # Suitable task which is going to be locking this machine from the label - ( - {"label": "task2", "machine": "machine1", "platform": None, "tags": None, "os_version": None}, - { - "label": "machine1", - "reserved": False, - "platform": "windows", - "arch": "x64", - "tags": "tag1", - "locked": False, - }, - (0, False, True), - ), - # Unsuitable task which is going to make the function fail the locking (label + platform) - ( - {"label": "task3", "machine": "machine1", "platform": "windows", "tags": None, "os_version": None}, - { - "label": "machine1", - "reserved": False, - "platform": "windows", - "arch": "x64", - "tags": "tag1", - "locked": False, - }, - (1, False, False), - ), - # Unsuitable task which is going to make the function fail the locking (label + tags) - ( - {"label": "task4", "machine": "machine1", "platform": None, "tags": "tag1", "os_version": None}, - { - "label": "machine1", - "reserved": False, - "platform": "windows", - "arch": "x64", - "tags": "tag1", - "locked": False, - }, - (1, False, False), - ), - # Unsuitable task which is going to make the function fail the locking (label + platform + tags) - ( - {"label": "task5", "machine": "machine1", "platform": "windows", "tags": "tag1", "os_version": None}, - { - "label": "machine1", - "reserved": False, - "platform": "windows", - "arch": "x64", - "tags": "tag1", - "locked": False, - }, - (1, False, False), - ), - # Suitable task which is going to fail locking the machine as the machine is already locked - ( - {"label": "task6", "machine": None, "platform": "windows", "tags": "tag1", "os_version": None}, - { - "label": "machine1", - "reserved": False, - "platform": "windows", - "arch": "x64", - "tags": "tag1", - "locked": True, - }, - (0, False, True), - ), - # Suitable task which is going to fail locking the machine because the machine is reserved - ( - {"label": "task7", "machine": None, "platform": "windows", "tags": "tag1", "os_version": None}, - { - "label": "machine1", - "reserved": True, - "platform": "windows", - "arch": "x64", - "tags": "tag1", - "locked": False, - }, - (1, True, False), - ), - # Suitable task which is going to not locked the machine as it is not compatible (tags) - ( - {"label": "task8", "machine": None, "platform": "windows", "tags": "tag1", "os_version": None}, - { - "label": "machine1", - "reserved": False, - "platform": "windows", - "arch": "x64", - "tags": "tag2", - "locked": False, - }, - (1, True, False), - ), - # Suitable task which is going to not locked the machine as it is not compatible (platform) - ( - {"label": "task9", "machine": None, "platform": "windows", "tags": "tag1", "os_version": None}, - { - "label": "machine1", - "reserved": False, - "platform": "linux", - "arch": "x64", - "tags": "tag1", - "locked": False, - }, - (1, True, False), - ), - # Suitable task which is going to not locked the machine as it is not compatible (os_version) - ( - {"label": "task10", "machine": None, "platform": "windows", "tags": "tag1", "os_version": ["win10"]}, - { - "label": "machine1", - "reserved": False, - "platform": "windows", - "arch": "x64", - "tags": "tag1,win7", - "locked": False, - }, - (1, True, False), - ), - # Suitable task which is going to be locking the machine as the os_version is compatible (os_version) - ( - {"label": "task11", "machine": None, "platform": "windows", "tags": "tag1", "os_version": ["win10"]}, - { - "label": "machine1", - "reserved": False, - "platform": "windows", - "arch": "x64", - "tags": "tag1,win10", - "locked": False, - }, - (0, False, True), - ), + (None, "t3"), + (["url"], "t3"), + (["file"], "t4"), + (["other"], None), ), ) - def test_lock_machine(self, task, machine, expected_results): - if machine["tags"] is not None: - machine_name = str(machine["label"]) + "_" + str(machine["tags"].replace(",", "_")) - else: - machine_name = str(machine["label"]) - self.d.add_machine( - name=machine_name, - label=machine["label"], - ip="1.2.3.4", - platform=machine["platform"], - tags=machine["tags"], - arch=machine["arch"], - interface="int0", - snapshot="snap0", - resultserver_ip="5.6.7.8", - resultserver_port=2043, - reserved=machine["reserved"], - ) - if machine["locked"]: - try: - queried_machine = self.session.query(Machine).filter_by(label=machine["label"]) - if queried_machine: - queried_machine.locked = True - try: - self.session.commit() - self.session.refresh(machine) - except SQLAlchemyError: - self.session.rollback() - pass - except SQLAlchemyError: - pass - task_id = "Sample_%s_%s" % (task["label"], task["tags"]) - with open(task_id, "w") as f: - f.write(task_id) - queried_task = self.d.add_path( - file_path=task_id, - machine=task["machine"], - platform=task["platform"], - tags=task["tags"], + def test_fetch_task(self, db: _Database, temp_filename, categories, expected_task): + with db.session.begin(): + tasks = dict( + t1=db.add_url("https://www.google.com"), + t2=db.add_url("https://www.google.com"), + t3=db.add_url("https://www.google.com", priority=2), + t4=db.add_path(temp_filename), + ) + db.set_status(tasks["t2"], TASK_RUNNING) + with db.session.begin(): + task = db.fetch_task(categories) + if expected_task is None: + assert task is None + else: + assert task.id == tasks[expected_task] + assert task.status == TASK_RUNNING + + def test_guest(self, db: _Database, freezer): + with db.session.begin(): + machine = self.add_machine(db) + task_id = db.add_url("http://foo.bar") + task = db.session.query(Task).filter_by(id=task_id).first() + guest = db.create_guest(machine, "kvm", task) + with db.session.begin(): + db.guest_set_status(task_id, "completed") + # Make sure it doesn't fall over when given a task that doesn't exist. + db.guest_set_status(task_id + 1, "completed") + with db.session.begin(): + guest_id = guest.id + assert db.session.query(Guest).first().status == "completed" + assert db.guest_get_status(task_id) == "completed" + assert db.guest_get_status(task_id + 1) is None + db.guest_stop(guest_id) + with db.session.begin(): + assert db.session.query(Guest).first().shutdown_on == datetime.datetime.now() + db.guest_stop(guest_id + 1) + db.guest_remove(guest_id) + with db.session.begin(): + assert db.session.query(Guest).first() is None + db.guest_remove(guest_id + 1) + + @pytest.mark.parametrize( + "kwargs,expected_retval", + ( + ({"label": "l1"}, 0), + ({"label": "l2"}, 1), + ({"label": "l3"}, 1), + ({"label": "foo"}, 0), + ({"platform": "windows"}, 2), + ({"platform": "osx"}, 0), + ({"tags": ["tag1"]}, 2), + ({"tags": ["foo"]}, 0), + ({"arch": ["x64"]}, 1), + ({"arch": ["x86"]}, 2), + ({"arch": ["arm"]}, 0), + # msix requires a machine with the win10 or win11 tag. + ({"os_version": ["win10"]}, 1), + ({"os_version": ["foo"]}, 0), + ({"include_reserved": True}, 3), + ), + ) + def test_count_machines_available(self, db: _Database, kwargs, expected_retval): + with db.session.begin(): + m = self.add_machine(db, name="n1", label="l1") + m.locked = True + self.add_machine(db, name="n2", label="l2", tags="tag1,x64") + self.add_machine(db, name="n3", label="l3", reserved=True) + self.add_machine(db, name="n4", label="l4", tags="tag1,win10", arch="x86") + with db.session.begin(): + assert db.count_machines_available(**kwargs) == expected_retval + + def test_get_available_machines(self, db: _Database): + with db.session.begin(): + m = self.add_machine(db, name="n1", label="l1") + m.locked = True + self.add_machine(db, name="n2", label="l2", tags="tag1,x64") + self.add_machine(db, name="n3", label="l3", reserved=True) + with db.session.begin(): + assert set(m.label for m in db.get_available_machines()) == {"l2", "l3"} + + def test_set_machine_status(self, db: _Database, freezer): + with db.session.begin(): + self.add_machine(db, name="n1", label="l1") + self.add_machine(db, name="n2", label="l2") + with db.session.begin(): + db.set_machine_status("l2", "running") + with db.session.begin(): + machine = db.session.query(Machine).filter_by(label="l2").one() + assert machine.status == "running" + assert machine.status_changed_on == datetime.datetime.now() + + machine = db.session.query(Machine).filter_by(label="l1").one() + assert machine.status != "running" + + @pytest.mark.parametrize( + "kwargs,expected_count", + ( + ({}, 3), + ({"category": "url"}, 2), + ({"category": "foo"}, 0), + ({"status": "running"}, 1), + ({"status": "foo"}, 0), + ({"not_status": "running"}, 2), + ({"not_status": "foo"}, 3), + ({"status": "running", "not_status": "running"}, 0), + ), + ) + def test_count_matching_tasks(self, db: _Database, temp_filename, kwargs, expected_count): + with db.session.begin(): + db.add_path(temp_filename) + db.add_url("https://www.google.com") + t3 = db.add_url("https://www.bing.com") + db.set_status(t3, "running") + with db.session.begin(): + assert db.count_matching_tasks(**kwargs) == expected_count + + def test_check_file_uniq(self, db: _Database, temp_filename, freezer): + with db.session.begin(): + assert not db.check_file_uniq("a") + db.add_path(temp_filename) + with db.session.begin(): + with open(temp_filename, "rb") as fil: + sha256 = hashlib.sha256(fil.read()).hexdigest() + assert db.check_file_uniq(sha256) + freezer.move_to(datetime.datetime.now() + datetime.timedelta(hours=2)) + assert not db.check_file_uniq(sha256, hours=1) + + def test_list_sample_parent(self, db: _Database, temp_filename): + dct = dict( + md5="md5", + crc32="crc32", + sha1="sha1", + sha256="sha256", + sha512="sha512", + file_size=100, + file_type="file_type", + ssdeep="ssdeep", + source_url="source_url", ) - queried_task = self.d.view_task(queried_task) - queried_task_archs = [tag.name for tag in queried_task.tags if tag.name in ("x86", "x64")] - queried_task_tags = [tag.name for tag in queried_task.tags if tag.name not in queried_task_archs] - number_of_expected_available_machines, should_raise_exception, should_be_locked = expected_results - if should_raise_exception: - with pytest.raises(CuckooOperationalError): - returned_machine = self.d.lock_machine( - label=queried_task.machine, - platform=queried_task.platform, - tags=queried_task_tags, - arch=queried_task_archs, - os_version=task["os_version"], + with db.session.begin(): + with db.session.begin_nested(): + sample = Sample(**dct) + db.session.add(sample) + sample_id = sample.id + task_id = db.add_path(temp_filename) + sample2 = db.session.query(Sample).filter(Sample.id != sample.id).one() + sample2.parent = sample_id + + with db.session.begin(): + exp_val = dict(**dct, parent=None, id=sample_id) + assert db.list_sample_parent(task_id=task_id) == exp_val + assert db.list_sample_parent(task_id=task_id + 1) == {} + + def test_list_tasks(self, db: _Database, temp_filename, freezer): + with db.session.begin(): + t1 = db.add_path(temp_filename, options="minhook=1") + t2 = db.add_url("https://2.com", tags_tasks="tag1") + t3 = db.add_url("https://3.com", user_id=5) + start = datetime.datetime.now() + with db.session.begin(): + + def get_ids(**kwargs): + return [t.id for t in db.list_tasks(**kwargs)] + + assert get_ids(limit=1) == [t3] + assert get_ids(category="url") == [t3, t2] + assert get_ids(offset=1) == [t2, t1] + with db.session.begin_nested() as nested: + now = start + datetime.timedelta(minutes=1) + freezer.move_to(now) + db.set_status(t2, TASK_COMPLETED) + db.session.query(Task).get(t1).added_on = start + db.session.query(Task).get(t2).added_on = start + datetime.timedelta(seconds=1) + db.session.query(Task).get(t3).added_on = now + assert get_ids(status=TASK_COMPLETED) == [t2] + assert get_ids(not_status=TASK_COMPLETED) == [t3, t1] + assert get_ids(completed_after=start) == [t2] + assert get_ids(order_by=(Task.completed_on, Task.id)) == [t1, t3, t2] + assert get_ids(order_by=(Task.id)) == [t1, t2, t3] + assert get_ids(added_before=now) == [t2, t1] + nested.rollback() + assert get_ids(sample_id=1) == [t1] + assert get_ids(id_before=t3) == [t2, t1] + assert get_ids(id_after=t2) == [t3] + assert get_ids(options_like="minhook") == [t1] + assert get_ids(options_not_like="minhook") == [t3, t2] + assert get_ids(tags_tasks_like="1") == [t2] + assert get_ids(task_ids=(t1, t2)) == [t2, t1] + assert get_ids(task_ids=(t3 + 1,)) == [] + assert get_ids(user_id=5) == [t3] + assert get_ids(user_id=0) == [t2, t1] + + def test_minmax_tasks(self, db: _Database, freezer): + with db.session.begin(): + assert db.minmax_tasks() == (0, 0) + + start_time = datetime.datetime.now() + with db.session.begin(): + t1 = db.add_url("https://1.com") + t2 = db.add_url("https://2.com") + t3 = db.add_url("https://3.com") + t4 = db.add_url("https://4.com") + _t5 = db.add_url("https://5.com") + t2_started = start_time + freezer.move_to(t2_started) + db.set_status(t2, TASK_RUNNING) + freezer.move_to(start_time + datetime.timedelta(minutes=1)) + db.set_status(t1, TASK_RUNNING) + freezer.move_to(start_time + datetime.timedelta(minutes=2)) + db.set_status(t3, TASK_RUNNING) + freezer.move_to(start_time + datetime.timedelta(minutes=3)) + db.set_status(t4, TASK_RUNNING) + # t5 has not started + + freezer.move_to(start_time + datetime.timedelta(minutes=4)) + db.set_status(t1, TASK_COMPLETED) + # t2 is still running + freezer.move_to(start_time + datetime.timedelta(minutes=5)) + db.set_status(t4, TASK_COMPLETED) + t3_completed = start_time + datetime.timedelta(minutes=6) + freezer.move_to(t3_completed) + db.set_status(t3, TASK_COMPLETED) + with db.session.begin(): + assert db.minmax_tasks() == (int(t2_started.timestamp()), int(t3_completed.timestamp())) + + def test_get_tlp_tasks(self, db: _Database): + with db.session.begin(): + db.add_url("https://1.com") + with db.session.begin(): + assert db.get_tlp_tasks() == [] + with db.session.begin(): + t2 = db.add_url("https://2.com", tlp="true") + with db.session.begin(): + assert db.get_tlp_tasks() == [t2] + + def test_get_file_types(self, db: _Database, temp_filename): + with db.session.begin(): + assert db.get_file_types() == [] + with db.session.begin(): + for i in range(2): + db.session.add( + Sample( + md5=f"md5_{i}", + sha1=f"sha1_{i}", + crc32=f"crc32_{i}", + sha256=f"sha256_{i}", + sha512=f"sha512_{i}", + file_size=100 + i, + file_type=f"file_type_{i}", + ) ) - assert returned_machine is None - else: - returned_machine = self.d.lock_machine( - label=queried_task.machine, - platform=queried_task.platform, - tags=queried_task_tags, - arch=queried_task_archs, - os_version=task["os_version"], - ) - output_machine = self.d.list_machines() - if output_machine and returned_machine is not None: - output_machine = output_machine[0] - # Normalizing the output in order to remove the joined tags in one of the output - output_machine.__dict__.pop("tags", None) - output_machine.__dict__.pop("_sa_instance_state", None) - returned_machine.__dict__.pop("_sa_instance_state", None) - output_machine.__dict__.pop("status", None) - output_machine.__dict__.pop("status_changed_on", None) - returned_machine.__dict__.pop("status", None) - returned_machine.__dict__.pop("status_changed_on", None) - assert output_machine.locked == should_be_locked - assert returned_machine.__dict__ == output_machine.__dict__ - # cleanup - os.remove(task_id) - assert len(self.d.get_available_machines()) == number_of_expected_available_machines + with db.session.begin(): + assert db.get_file_types() == ["file_type_0", "file_type_1"] + + def test_get_tasks_status_count(self, db: _Database): + with db.session.begin(): + assert db.get_tasks_status_count() == {} + with db.session.begin(): + _t1 = db.add_url("https://1.com") + t2 = db.add_url("https://2.com") + t3 = db.add_url("https://3.com") + db.set_status(t2, TASK_RUNNING) + db.set_status(t3, TASK_RUNNING) + with db.session.begin(): + assert db.get_tasks_status_count() == { + TASK_PENDING: 1, + TASK_RUNNING: 2, + } + + def test_count_tasks(self, db: _Database): + with db.session.begin(): + assert db.count_tasks() == 0 + with db.session.begin(): + _t1 = db.add_url("https://1.com") + t2 = db.add_url("https://2.com") + t3 = db.add_url("https://3.com") + db.set_status(t2, TASK_RUNNING) + db.set_status(t3, TASK_RUNNING) + with db.session.begin(): + assert db.count_tasks() == 3 + assert db.count_tasks(status=TASK_RUNNING) == 2 + assert db.count_tasks(status=TASK_COMPLETED) == 0 + + def test_delete_task(self, db: _Database, temp_filename): + with db.session.begin(): + t1 = db.add_url("https://1.com") + t2 = db.add_path(temp_filename, tags="x86") + with db.session.begin(): + db.delete_task(t2) + with db.session.begin(): + tasks = db.session.query(Task).all() + assert len(tasks) == 1 + assert tasks[0].id == t1 + assert not db.delete_task(t2) + + def test_delete_tasks(self, db: _Database, temp_filename): + with db.session.begin(): + t1 = db.add_url("https://1.com") + t2 = db.add_path(temp_filename, tags="x86") + t3 = db.add_url("https://3.com") + with db.session.begin(): + assert db.delete_tasks([]) + assert db.delete_tasks([t1, t2, t3 + 1]) + tasks = db.session.query(Task).all() + assert len(tasks) == 1 + assert tasks[0].id == t3 + assert db.delete_tasks([t1, t2]) + tasks = db.session.query(Task).all() + assert len(tasks) == 1 + assert tasks[0].id == t3 + + def test_view_sample(self, db: _Database): + with db.session.begin(): + samples = [] + for i in range(2): + samples.append( + Sample( + md5=f"md5_{i}", + sha1=f"sha1_{i}", + crc32=f"crc32_{i}", + sha256=f"sha256_{i}", + sha512=f"sha512_{i}", + file_size=100 + i, + file_type=f"file_type_{i}", + ) + ) + with db.session.begin_nested(): + db.session.add(samples[-1]) + db.session.expunge(samples[-1]) + with db.session.begin(): + assert db.view_sample(samples[-1].id).to_dict() == samples[-1].to_dict() + assert db.view_sample(samples[-1].id + 1) is None + + def test_find_sample(self, db: _Database, temp_filename): + with db.session.begin(): + samples = [] + parent_id = None + for i in range(2): + sample = Sample( + md5=f"md5_{i}", + sha1=f"sha1_{i}", + crc32=f"crc32_{i}", + sha256=f"sha256_{i}", + sha512=f"sha512_{i}", + file_size=100 + i, + file_type=f"file_type_{i}", + parent=parent_id, + ) + with db.session.begin_nested(): + db.session.add(sample) + parent_id = sample.id + samples.append(sample.id) + t1 = db.add_path(temp_filename) + with open(temp_filename, "rb") as fil: + sha256 = hashlib.sha256(fil.read()).hexdigest() + task_sample = db.session.query(Sample).filter_by(sha256=sha256).one().id + with db.session.begin(): + assert db.find_sample() is False + assert db.find_sample(md5="md5_1").id == samples[1] + assert db.find_sample(sha1="sha1_1").id == samples[1] + assert db.find_sample(sha256="sha256_0").id == samples[0] + assert [s.id for s in db.find_sample(parent=samples[0])] == samples[1:] + assert [s.id for s in db.find_sample(parent=samples[1])] == [] + # When a task_id is passed, find_sample returns Task objects instead of Sample objects. + assert [t.sample.id for t in db.find_sample(task_id=t1)] == [task_sample] + assert [s.id for s in db.find_sample(sample_id=samples[1])] == [samples[1]] + + def test_sample_still_used(self, db: _Database, temp_filename): + with db.session.begin(): + t1 = db.add_path(temp_filename) + with open(temp_filename, "rb") as fil: + sha256 = hashlib.sha256(fil.read()).hexdigest() + with db.session.begin(): + # No other tasks are associated with this sample. + assert not db.sample_still_used(sha256, t1) + with db.session.begin(): + # Add another task for the sample. + t2 = db.add_path(temp_filename) + with db.session.begin(): + # So now it IS still being used. + assert db.sample_still_used(sha256, t1) + with db.session.begin(): + # Mark the second task as completed... + db.set_status(t2, TASK_COMPLETED) + with db.session.begin(): + # So it is no longer "used". + assert not db.sample_still_used(sha256, t1) + + def test_count_samples(self, db: _Database, temp_filename): + with db.session.begin(): + assert db.count_samples() == 0 + db.add_path(temp_filename) + with db.session.begin(): + assert db.count_samples() == 1 + + def test_view_machine_by_label(self, db: _Database): + with db.session.begin(): + m0 = self.add_machine(db, name="name0", label="label0") + self.add_machine(db, name="name1", label="label1") + db.session.refresh(m0) + db.session.expunge_all() + with db.session.begin(): + assert db.view_machine_by_label("foo") is None + m0_dict = db.session.query(Machine).get(m0.id).to_dict() + assert db.view_machine_by_label("label0").to_dict() == m0_dict + + def test_get_source_url(self, db: _Database, temp_filename): + with db.session.begin(): + assert db.get_source_url() is False + assert db.get_source_url(1) is None + db.add_path(temp_filename) + with open(temp_filename, "a") as fil: + fil.write("a") + db.add_path(temp_filename) + url = "https://badguys.com" + db.session.query(Sample).get(1).source_url = url + with db.session.begin(): + assert db.get_source_url(1) == url + assert db.get_source_url(2) is None + + def test_ban_user_tasks(self, db: _Database): + with db.session.begin(): + t1 = db.add_url("https://1.com", user_id=0) + t2 = db.add_url("https://2.com", user_id=1) + t3 = db.add_url("https://3.com", user_id=1) + t4 = db.add_url("https://3.com", user_id=1) + db.set_status(t4, TASK_COMPLETED) + with db.session.begin(): + db.ban_user_tasks(1) + assert db.session.query(Task).get(t1).status == TASK_PENDING + assert db.session.query(Task).get(t2).status == TASK_BANNED + assert db.session.query(Task).get(t3).status == TASK_BANNED + assert db.session.query(Task).get(t4).status == TASK_COMPLETED + + def test_tasks_reprocess(self, db: _Database): + with db.session.begin(): + err, _msg, old_status = db.tasks_reprocess(1) + assert err is True + assert old_status == "" + t1 = db.add_url("https://1.com") + with db.session.begin(): + err, _msg, old_status = db.tasks_reprocess(t1) + assert err is True + assert old_status == TASK_PENDING + db.set_status(t1, TASK_REPORTED) + with db.session.begin(): + err, _msg, old_status = db.tasks_reprocess(t1) + assert err is False + assert old_status == TASK_REPORTED + assert db.session.query(Task).get(t1).status == TASK_COMPLETED @pytest.mark.parametrize( "task,machines,expected_result", @@ -972,33 +1378,34 @@ def test_lock_machine(self, task, machine, expected_results): ), ), ) - def test_filter_machines_to_task(self, task, machines, expected_result): - for machine in machines: - machine_name = ( - str(machine["label"]) + str(machine["platform"]) + str(machine["arch"]) + str(task["label"].replace("task", "")) - ) - self.d.add_machine( - name=machine_name, - label=machine["label"], - ip="1.2.3.4", - platform=machine["platform"], - tags=machine["tags"], - interface="int0", - snapshot="snap0", - resultserver_ip="5.6.7.8", - resultserver_port=2043, - arch=machine["arch"], - reserved=machine["reserved"], - ) + def test_filter_machines_to_task(self, task, machines, expected_result, db: _Database): + with db.session.begin(): + for machine in machines: + machine_name = ( + str(machine["label"]) + str(machine["platform"]) + str(machine["arch"]) + str(task["label"].replace("task", "")) + ) + db.add_machine( + name=machine_name, + label=machine["label"], + ip="1.2.3.4", + platform=machine["platform"], + tags=machine["tags"], + interface="int0", + snapshot="snap0", + resultserver_ip="5.6.7.8", + resultserver_port="2043", + arch=machine["arch"], + reserved=machine["reserved"], + ) if task["tags"] is not None: task_archs = [tag for tag in task["tags"].split(",") if tag in ("x86", "x64")] task_tags = [tag for tag in task["tags"].split(",") if tag not in task_archs] else: task_archs = None task_tags = None - with self.session as session: - created_machines = session.query(Machine) - output_machines = self.d.filter_machines_to_task( + with db.session.begin(): + created_machines = db.session.query(Machine) + output_machines = db.filter_machines_to_task( machines=created_machines, label=task["machine"], platform=task["platform"], @@ -1011,29 +1418,3 @@ def test_filter_machines_to_task(self, task, machines, expected_result): assert len(output_machines) == expected_result else: assert output_machines.count() == expected_result - - @pytest.mark.parametrize( - "task,expected_result", - # @param task : dictionary describing the task to be validated - # @param expected_result : expected_result of the function tested - ( - # No parameters - ({"label": None, "platform": None, "tags": None}, True), - # Only label - ({"label": "task1", "platform": None, "tags": None}, True), - # Only platform - ({"label": None, "platform": "windows", "tags": None}, True), - # Only tags - ({"label": None, "platform": None, "tags": "tag1"}, True), - # Label and platform - ({"label": "task1", "platform": "windows", "tags": None}, False), - # Label and tags - ({"label": "task1", "platform": None, "tags": "tag1"}, False), - # Platform and tags - ({"label": None, "platform": "windows", "tags": "tag1"}, True), - # Label, Platform and tags - ({"label": "task1", "platform": "windows", "tags": "tag1"}, False), - ), - ) - def test_validate_task_parameters(self, task, expected_result): - assert self.d.validate_task_parameters(label=task["label"], platform=task["platform"], tags=task["tags"]) == expected_result diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 9108d3d1f3b..1fc772db0e5 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -162,20 +162,22 @@ def test_init_storage_other_error(self, clean_init_storage, mocker, caplog): assert analysis_man.init_storage() is False assert "Unable to create analysis folder" in caplog.text + @pytest.mark.usefixtures("db") def test_check_file(self, mocker): class mock_sample: sha256 = "e3b" analysis_man = AnalysisManager(task=mock_task(), error_queue=queue.Queue()) - mocker.patch("lib.cuckoo.core.database.Database.view_sample", return_value=mock_sample()) + mocker.patch("lib.cuckoo.core.database._Database.view_sample", return_value=mock_sample()) assert analysis_man.check_file("e3b") is True + @pytest.mark.usefixtures("db") def test_check_file_err(self, mocker): class mock_sample: sha256 = "f3b" analysis_man = AnalysisManager(task=mock_task(), error_queue=queue.Queue()) - mocker.patch("lib.cuckoo.core.database.Database.view_sample", return_value=mock_sample()) + mocker.patch("lib.cuckoo.core.database._Database.view_sample", return_value=mock_sample()) assert analysis_man.check_file("e3b") is False @pytest.mark.skip(reason="TODO") @@ -208,6 +210,7 @@ def test_store_file_symlink_err(self, symlink, caplog): analysis_man.store_file(sha256="e3be3b") assert "Unable to create symlink/copy" in caplog.text + @pytest.mark.usefixtures("db") def test_acquire_machine(self, setup_machinery, setup_machine_lock): class mock_machinery: def availables(self, label, platform, tags, arch, os_version): @@ -486,7 +489,7 @@ class mock_sample: analysis_man = AnalysisManager(task=mock_task_cat, error_queue=queue.Queue()) assert analysis_man.init_storage() is True - mocker.patch("lib.cuckoo.core.scheduler.Database.view_sample", return_value=mock_sample()) + mocker.patch("lib.cuckoo.core.database._Database.view_sample", return_value=mock_sample()) assert analysis_man.category_checks() is None @@ -497,7 +500,7 @@ class mock_sample: analysis_man = AnalysisManager(task=mock_task(), error_queue=queue.Queue()) assert analysis_man.init_storage() is True - mocker.patch("lib.cuckoo.core.scheduler.Database.view_sample", return_value=mock_sample()) + mocker.patch("lib.cuckoo.core.database._Database.view_sample", return_value=mock_sample()) assert analysis_man.category_checks() is False @@ -511,7 +514,7 @@ class mock_sample: mock_task_cat.target = sample_location analysis_man = AnalysisManager(task=mock_task_cat, error_queue=queue.Queue()) assert analysis_man.init_storage() is True - mocker.patch("lib.cuckoo.core.scheduler.Database.view_sample", return_value=mock_sample()) + mocker.patch("lib.cuckoo.core.database._Database.view_sample", return_value=mock_sample()) mocker.patch("lib.cuckoo.core.scheduler.AnalysisManager.store_file", return_value=False) assert analysis_man.category_checks() is False @@ -528,6 +531,6 @@ class mock_sample: analysis_man = AnalysisManager(task=mock_task_cat, error_queue=queue.Queue()) assert analysis_man.init_storage() is True - mocker.patch("lib.cuckoo.core.scheduler.Database.view_sample", return_value=mock_sample()) + mocker.patch("lib.cuckoo.core.database._Database.view_sample", return_value=mock_sample()) assert analysis_man.category_checks() is True diff --git a/tests/web/test_apiv2.py b/tests/web/test_apiv2.py index 389b7815976..796eefa6949 100644 --- a/tests/web/test_apiv2.py +++ b/tests/web/test_apiv2.py @@ -1,8 +1,10 @@ +import pathlib from unittest.mock import patch import pytest from django.test import SimpleTestCase +from lib.cuckoo.common.config import ConfigMeta from lib.cuckoo.core.database import ( TASK_BANNED, TASK_COMPLETED, @@ -19,13 +21,19 @@ ) -@pytest.mark.usefixtures("tmp_cuckoo_root") -class ReprocessTask(SimpleTestCase): +@pytest.fixture +def taskreprocess_enabled(custom_conf_path: pathlib.Path): + with open(custom_conf_path / "api.conf", "wt") as fil: + print("[taskreprocess]\nenabled = yes", file=fil) + ConfigMeta.refresh() + yield + +@pytest.mark.usefixtures("db", "tmp_cuckoo_root") +class ReprocessTask(SimpleTestCase): taskprocess_config = "lib.cuckoo.common.web_utils.apiconf.taskreprocess" """API configuration to patch in each test case.""" - @patch.dict(taskprocess_config, {"enabled": False}) def test_api_disabled(self): response = self.client.get("/apiv2/tasks/reprocess/1/") assert response.status_code == 200 @@ -33,9 +41,9 @@ def test_api_disabled(self): json_body = {"error": True, "error_value": "Task Reprocess API is Disabled"} assert response.json() == json_body - @patch.dict(taskprocess_config, {"enabled": True}) + @pytest.mark.usefixtures("taskreprocess_enabled") def test_task_does_not_exist(self): - patch_target = "lib.cuckoo.core.database.Database.view_task" + patch_target = "lib.cuckoo.core.database._Database.view_task" with patch(patch_target, return_value=None): response = self.client.get("/apiv2/tasks/reprocess/1/") assert response.status_code == 200 @@ -43,11 +51,11 @@ def test_task_does_not_exist(self): json_body = {"error": True, "error_value": "Task ID does not exist in the database"} assert response.json() == json_body - @patch.dict(taskprocess_config, {"enabled": True}) + @pytest.mark.usefixtures("taskreprocess_enabled") def test_can_reprocess(self): task = Task() valid_status = (TASK_REPORTED, TASK_RECOVERED, TASK_FAILED_PROCESSING, TASK_FAILED_REPORTING) - patch_target = "lib.cuckoo.core.database.Database.view_task" + patch_target = "lib.cuckoo.core.database._Database.view_task" with patch(patch_target, return_value=task): for status in valid_status: expected_response = {"error": False, "data": f"Task ID 1 with status {status} marked for reprocessing"} @@ -58,7 +66,7 @@ def test_can_reprocess(self): assert response.headers["content-type"] == "application/json" assert response.json() == expected_response - @patch.dict(taskprocess_config, {"enabled": True}) + @pytest.mark.usefixtures("taskreprocess_enabled") def test_cant_reprocess(self): task = Task() invalid_status = ( @@ -70,7 +78,7 @@ def test_cant_reprocess(self): TASK_RUNNING, TASK_DISTRIBUTED, ) - patch_target = "lib.cuckoo.core.database.Database.view_task" + patch_target = "lib.cuckoo.core.database._Database.view_task" with patch(patch_target, return_value=task): for status in invalid_status: expected_response = {"error": True, "error_value": f"Task ID 1 cannot be reprocessed in status {status}"} diff --git a/utils/cleaners.py b/utils/cleaners.py index 927fcbf9f58..a456edce1ae 100644 --- a/utils/cleaners.py +++ b/utils/cleaners.py @@ -9,6 +9,7 @@ sys.path.append(CUCKOO_ROOT) from lib.cuckoo.common.cleaners_utils import execute_cleanup +from lib.cuckoo.core.database import init_database if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -93,4 +94,5 @@ "-bt", "--before-time", help="Manage all pending jobs before N hours.", action="store", required=False, type=int ) args = parser.parse_args() + init_database() execute_cleanup(vars(args)) diff --git a/utils/db_migration/env.py b/utils/db_migration/env.py index 74eb881a210..5bd4dd969e2 100644 --- a/utils/db_migration/env.py +++ b/utils/db_migration/env.py @@ -21,10 +21,10 @@ curdir = os.path.abspath(os.path.dirname(__file__)) sys.path.append(os.path.join(curdir, "..", "..")) -from lib.cuckoo.core.database import Base, Database +from lib.cuckoo.core.database import Base, _Database # Get database connection string from cuckoo configuration. -url = Database(schema_check=False).engine.url.__to_string__(hide_password=False) +url = _Database(schema_check=False).engine.url.__to_string__(hide_password=False) target_metadata = Base.metadata diff --git a/utils/db_migration/versions/from_1_1_to_1_2-extend_file_type.py b/utils/db_migration/versions/from_1_1_to_1_2-extend_file_type.py index 18c03332837..7f5577cdce7 100644 --- a/utils/db_migration/versions/from_1_1_to_1_2-extend_file_type.py +++ b/utils/db_migration/versions/from_1_1_to_1_2-extend_file_type.py @@ -55,7 +55,7 @@ def _perform(upgrade): "postgresql": "tasks_sample_id_fkey", } - fkey = fkey_name.get(db.Database(schema_check=False).engine.name) + fkey = fkey_name.get(db._Database(schema_check=False).engine.name) # First drop the foreign key. if fkey: diff --git a/utils/dist.py b/utils/dist.py index 41b659e8369..979cb45f1f9 100644 --- a/utils/dist.py +++ b/utils/dist.py @@ -52,6 +52,7 @@ Database, ) from lib.cuckoo.core.database import Task as MD_Task +from lib.cuckoo.core.database import init_database dist_conf = Config("distributed") main_server_name = dist_conf.distributed.get("main_server_name", "master") @@ -1554,6 +1555,7 @@ def init_logging(debug=False): args = p.parse_args() log = init_logging(args.debug) + init_database() if args.enable_clean: cron_cleaner(args.clean_hours) @@ -1595,6 +1597,7 @@ def init_logging(debug=False): app.run(host=args.host, port=args.port, debug=args.debug, use_reloader=False) else: + init_database(exists_ok=True) app = create_app(database_connection=dist_conf.distributed.db) # this allows run it with gunicorn/uwsgi diff --git a/utils/process.py b/utils/process.py index b1087929d77..23583fe6749 100644 --- a/utils/process.py +++ b/utils/process.py @@ -37,7 +37,7 @@ from lib.cuckoo.common.constants import CUCKOO_ROOT from lib.cuckoo.common.path_utils import path_delete, path_exists, path_mkdir from lib.cuckoo.common.utils import free_space_monitor, get_options -from lib.cuckoo.core.database import TASK_COMPLETED, TASK_FAILED_PROCESSING, TASK_REPORTED, Database, Task +from lib.cuckoo.core.database import TASK_COMPLETED, TASK_FAILED_PROCESSING, TASK_REPORTED, Database, Task, init_database from lib.cuckoo.core.plugins import RunProcessing, RunReporting, RunSignatures from lib.cuckoo.core.startup import ConsoleHandler, check_linux_dist, init_modules @@ -119,7 +119,8 @@ def process( if memory_debugging: gc.collect() log.info("(2) GC object counts: %d, %d", len(gc.get_objects()), len(gc.garbage)) - RunProcessing(task=task_dict, results=results).run() + with db.session.begin(): + RunProcessing(task=task_dict, results=results).run() if memory_debugging: gc.collect() log.info("(3) GC object counts: %d, %d", len(gc.get_objects()), len(gc.garbage)) @@ -136,7 +137,8 @@ def process( reprocess = report RunReporting(task=task.to_dict(), results=results, reprocess=reprocess).run() - Database().set_status(task_id, TASK_REPORTED) + with db.session.begin(): + db.set_status(task_id, TASK_REPORTED) if auto: # Is ok to delete original file, but we need to lookup on delete_bin_copy if no more pendings tasks @@ -145,8 +147,11 @@ def process( if cfg.cuckoo.delete_bin_copy: copy_path = os.path.join(CUCKOO_ROOT, "storage", "binaries", sample_sha256) - if path_exists(copy_path) and not db.sample_still_used(sample_sha256, task_id): - path_delete(copy_path) + if path_exists(copy_path): + with db.session.begin(): + is_still_used = db.sample_still_used(sample_sha256, task_id) + if not is_still_used: + path_delete(copy_path) if memory_debugging: gc.collect() @@ -162,6 +167,8 @@ def process( def init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN) + # See https://docs.sqlalchemy.org/en/14/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork + db.engine.dispose(close=False) def get_formatter_fmt(task_id=None, main_task_id=None): @@ -181,7 +188,6 @@ def set_formatter_fmt(task_id=None, main_task_id=None): def init_logging(tid=0, debug=False): - # Pyattck creates root logger which we don't want. So we must use this dirty hack to remove it # If basicConfig was already called by something and had a StreamHandler added, # replace it with a ConsoleHandler. @@ -248,18 +254,19 @@ def init_logging(tid=0, debug=False): def processing_finished(future): task_id = pending_future_map.get(future) - try: - _ = future.result() - log.info("Reports generation completed") - except TimeoutError as error: - log.error("Processing Timeout %s. Function: %s", error, error.args[1]) - Database().set_status(task_id, TASK_FAILED_PROCESSING) - except pebble.ProcessExpired as error: - log.error("Exception when processing task: %s", error, exc_info=True) - Database().set_status(task_id, TASK_FAILED_PROCESSING) - except Exception as error: - log.error("Exception when processing task: %s", error, exc_info=True) - Database().set_status(task_id, TASK_FAILED_PROCESSING) + with db.session.begin(): + try: + _ = future.result() + log.info("Reports generation completed") + except TimeoutError as error: + log.error("Processing Timeout %s. Function: %s", error, error.args[1]) + db.set_status(task_id, TASK_FAILED_PROCESSING) + except pebble.ProcessExpired as error: + log.error("Exception when processing task: %s", error, exc_info=True) + db.set_status(task_id, TASK_FAILED_PROCESSING) + except Exception as error: + log.error("Exception when processing task: %s", error, exc_info=True) + db.set_status(task_id, TASK_FAILED_PROCESSING) pending_future_map.pop(future) pending_task_id_map.pop(task_id) @@ -280,7 +287,6 @@ def autoprocess( with pebble.ProcessPool(max_workers=parallel, max_tasks=maxtasksperchild, initializer=init_worker) as pool: # CAUTION - big ugly loop ahead. while count < maxcount or not maxcount: - # If not enough free disk space is available, then we print an # error message and wait another round (this check is ignored # when the freespace configuration variable is set to zero). @@ -294,10 +300,14 @@ def autoprocess( if len(pending_task_id_map) >= parallel: time.sleep(5) continue - if failed_processing: - tasks = db.list_tasks(status=TASK_FAILED_PROCESSING, limit=parallel, order_by=Task.completed_on.asc()) - else: - tasks = db.list_tasks(status=TASK_COMPLETED, limit=parallel, order_by=Task.completed_on.asc()) + with db.session.begin(): + if failed_processing: + tasks = db.list_tasks(status=TASK_FAILED_PROCESSING, limit=parallel, order_by=Task.completed_on.asc()) + else: + tasks = db.list_tasks(status=TASK_COMPLETED, limit=parallel, order_by=Task.completed_on.asc()) + # Make sure the tasks are available as normal objects after the transaction ends, so that + # sqlalchemy doesn't auto-initiate a new transaction the next time they are accessed. + db.session.expunge_all() added = False # For loop to add only one, nice. (reason is that we shouldn't overshoot maxcount) for task in tasks: @@ -308,9 +318,10 @@ def autoprocess( log.info("Processing analysis data for Task #%d", task.id) sample_hash = "" if task.category != "url": - sample = db.view_sample(task.sample_id) - if sample: - sample_hash = sample.sha256 + with db.session.begin(): + sample = db.view_sample(task.sample_id) + if sample: + sample_hash = sample.sha256 args = task.target, sample_hash kwargs = dict(report=True, auto=True, task=task, memory_debugging=memory_debugging, debug=debug) @@ -352,7 +363,6 @@ def autoprocess( def _load_report(task_id: int): - if repconf.mongodb.enabled: analysis = mongo_find_one("analysis", {"info.id": task_id}, sort=[("_id", -1)]) for process in analysis.get("behavior", {}).get("processes", []): @@ -459,6 +469,7 @@ def main(): ) args = parser.parse_args() + init_database() init_modules() if args.id == "auto": autoprocess( @@ -477,18 +488,22 @@ def main(): if not path_exists(os.path.join(CUCKOO_ROOT, "storage", "analyses", str(num))): print(red(f"\n[{num}] Analysis folder doesn't exist anymore\n")) continue - task = Database().view_task(num) - if task is None: - task = {"id": num, "target": None} - print("Task not in database") - else: - # Add sample lookup as we point to sample from TMP. Case when delete_original=on - if not path_exists(task.target): - samples = Database().sample_path_by_hash(task_id=task.id) - for sample in samples: - if path_exists(sample): - task.__setattr__("target", sample) - break + with db.session.begin(): + task = db.view_task(num) + if task is None: + task = {"id": num, "target": None} + print("Task not in database") + else: + # Add sample lookup as we point to sample from TMP. Case when delete_original=on + if not path_exists(task.target): + samples = db.sample_path_by_hash(task_id=task.id) + for sample in samples: + if path_exists(sample): + task.__setattr__("target", sample) + break + # Make sure that SQLAlchemy doesn't auto-begin a new transaction the next time + # these objects are accessed. + db.session.expunge_all() if args.signatures: report = False diff --git a/utils/sample_path.py b/utils/sample_path.py index 2e11c9e2b85..519e51c14f3 100644 --- a/utils/sample_path.py +++ b/utils/sample_path.py @@ -7,17 +7,17 @@ from lib.cuckoo.common.config import Config from lib.cuckoo.common.path_utils import path_exists -from lib.cuckoo.core.database import Database +from lib.cuckoo.core.database import Database, init_database repconf = Config("reporting") if "__main__" == __name__: - parser = argparse.ArgumentParser() parser.add_argument("--hash", help="Hash to lookup", default=None, action="store", required=False) parser.add_argument("--id", help="Get hash by sample_id from task", default=None, action="store", required=False) args = parser.parse_args() + init_database() paths = Database().sample_path_by_hash(sample_hash=args.hash, task_id=args.id) if paths: paths = [path for path in paths if path_exists(path)] diff --git a/utils/submit.py b/utils/submit.py index a2ed86863f0..d3cab9e76c1 100644 --- a/utils/submit.py +++ b/utils/submit.py @@ -23,7 +23,7 @@ from lib.cuckoo.common.objects import File from lib.cuckoo.common.path_utils import path_exists from lib.cuckoo.common.utils import sanitize_filename, store_temp_file, to_unicode -from lib.cuckoo.core.database import Database +from lib.cuckoo.core.database import Database, init_database from lib.cuckoo.core.startup import check_user_permissions check_user_permissions(os.getenv("CAPE_AS_ROOT", False)) @@ -131,6 +131,7 @@ def main(): if args.quiet: logging.disable(logging.WARNING) + init_database() db = Database() target = to_unicode(args.target) diff --git a/web/apiv2/views.py b/web/apiv2/views.py index 253db68005c..5ebc3c4abbf 100644 --- a/web/apiv2/views.py +++ b/web/apiv2/views.py @@ -53,7 +53,7 @@ statistics, validate_task, ) -from lib.cuckoo.core.database import TASK_COMPLETED, TASK_RECOVERED, TASK_RUNNING, Database, Task +from lib.cuckoo.core.database import TASK_RECOVERED, TASK_RUNNING, Database, Task, _Database from lib.cuckoo.core.rooter import _load_socks5_operational, vpns try: @@ -99,7 +99,7 @@ es_as_db = True es = elastic_handler -db = Database() +db: _Database = Database() # Conditional decorator for web authentication @@ -1019,7 +1019,6 @@ def tasks_reprocess(request, task_id): if error: return Response({"error": True, "error_value": msg}) - db.set_status(task_id, TASK_COMPLETED) return Response({"error": error, "data": f"Task ID {task_id} with status {task_status} marked for reprocessing"}) diff --git a/web/submission/views.py b/web/submission/views.py index a41b78bfff7..777c6e31376 100644 --- a/web/submission/views.py +++ b/web/submission/views.py @@ -23,12 +23,12 @@ from lib.cuckoo.common.saztopcap import saz_to_pcap from lib.cuckoo.common.utils import get_options, get_user_filename, sanitize_filename, store_temp_file from lib.cuckoo.common.web_utils import ( - all_nodes_exits_list, - all_vms_tags, download_file, download_from_bazaar, download_from_vt, get_file_content, + load_vms_exits, + load_vms_tags, parse_request_arguments, perform_search, process_new_dlnexec_task, @@ -122,7 +122,6 @@ def get_platform(magic): def index(request, task_id=None, resubmit_hash=None): remote_console = False if request.method == "POST": - ( static, package, @@ -530,6 +529,8 @@ def index(request, task_id=None, resubmit_hash=None): enabledconf["pre_script"] = web_conf.pre_script.enabled enabledconf["during_script"] = web_conf.during_script.enabled + all_vms_tags = load_vms_tags() + if all_vms_tags: enabledconf["tags"] = True @@ -618,9 +619,9 @@ def index(request, task_id=None, resubmit_hash=None): "tor": routing.tor.enabled, "config": enabledconf, "resubmit": resubmit_hash, - "tags": sorted(list(set(all_vms_tags))), + "tags": all_vms_tags, "existent_tasks": existent_tasks, - "all_exitnodes": all_nodes_exits_list, + "all_exitnodes": list(sorted(load_vms_exits())), }, ) diff --git a/web/web/middleware/__init__.py b/web/web/middleware/__init__.py new file mode 100644 index 00000000000..45e59dff342 --- /dev/null +++ b/web/web/middleware/__init__.py @@ -0,0 +1,2 @@ +from .custom_auth import CustomAuth # noqa +from .db_transaction import DBTransactionMiddleware # noqa diff --git a/web/web/middleware.py b/web/web/middleware/custom_auth.py similarity index 100% rename from web/web/middleware.py rename to web/web/middleware/custom_auth.py diff --git a/web/web/middleware/db_transaction.py b/web/web/middleware/db_transaction.py new file mode 100644 index 00000000000..c4700ee9fe2 --- /dev/null +++ b/web/web/middleware/db_transaction.py @@ -0,0 +1,10 @@ +from lib.cuckoo.core.database import Database + + +class DBTransactionMiddleware: + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + with Database().session.begin(): + return self.get_response(request) diff --git a/web/web/settings.py b/web/web/settings.py index 61d2684c284..97ead65cdce 100644 --- a/web/web/settings.py +++ b/web/web/settings.py @@ -218,8 +218,9 @@ #'web.middleware.ExceptionMiddleware', "django.contrib.auth.middleware.AuthenticationMiddleware", # 'django_otp.middleware.OTPMiddleware', - # in case you want custom auth, place logic in web/web/middleware.py + # in case you want custom auth, place logic in web/web/middleware/custom_auth.py # "web.middleware.CustomAuth", + "web.middleware.DBTransactionMiddleware", ] OTP_TOTP_ISSUER = "CAPE Sandbox" @@ -507,3 +508,7 @@ except NameError: with suppress(ImportError): from .local_settings import * # noqa: F403 + +from lib.cuckoo.core.database import init_database + +init_database()