diff --git a/modules/machinery/az.py b/modules/machinery/az.py index 2074e242c3e..fdc1b7f3934 100644 --- a/modules/machinery/az.py +++ b/modules/machinery/az.py @@ -14,12 +14,13 @@ from azure.identity import CertificateCredential, ClientSecretCredential from azure.mgmt.compute import ComputeManagementClient, models from azure.mgmt.network import NetworkManagementClient + from azure.mgmt.compute.models import InstanceViewTypes from msrest.polling import LROPoller HAVE_AZURE = True except ImportError: HAVE_AZURE = False - print("Missing machinery-required libraries.") + print("Missing AZURE machinery-required libraries.") print("poetry run python -m pip install azure-identity msrest msrestazure azure-mgmt-compute azure-mgmt-network") # Cuckoo-specific imports @@ -49,7 +50,7 @@ log = logging.getLogger(__name__) # Timeout used for calls that shouldn't take longer than 5 minutes but somehow do -AZURE_TIMEOUT = 120 +AZURE_TIMEOUT = 220 # Global variable which will maintain details about each machine pool machine_pools = {} @@ -172,10 +173,63 @@ def _initialize_check(self): # Initialize the VMSSs that we will be using and not using self._set_vmss_stage() - + self._wait_for_vms_running() # Set the flag that indicates that the system is not initializing self.initializing = False + def _are_all_vms_running(self, vmss_name): + """ + Check if all VMs in a specific VMSS are in 'Running' state + """ + try: + vms = Azure._azure_api_call( + self.options.az.sandbox_resource_group, + vmss_name, + operation=self.compute_client.virtual_machine_scale_set_vms.list, + ) + + for vm in vms: + instance_view = Azure._azure_api_call( + self.options.az.sandbox_resource_group, + vmss_name, + vm.instance_id, + expand=InstanceViewTypes.instance_view, + operation=self.compute_client.virtual_machine_scale_set_vms.get, + ) + + power_state = next((status for status in instance_view.instance_view.statuses if status.code.startswith("PowerState/")), None) + if not power_state or power_state.code != "PowerState/running": + return False + + return True + except Exception as e: + log.error(f"Error checking VM status in VMSS {vmss_name}: {str(e)}") + return False + + def _wait_for_vms_running(self): + """ + Wait for all VMs in all VMSSs to reach 'Running' state before completing initialization + """ + max_wait_time = 30 * 60 # 30 minutes maximum wait time + start_time = time.time() + + while True: + all_running = True + for vmss_name in self.options.az.scale_sets: + if not self._are_all_vms_running(vmss_name): + all_running = False + break + + if all_running: + log.info("All VMs are in 'Running' state. Initialization complete.") + break + + if time.time() - start_time > max_wait_time: + raise CuckooCriticalError("Timeout waiting for all VMs to reach 'Running' state.") + + log.info("Waiting for all VMs to reach 'Running' state...") + time.sleep(30) # Wait for 30 seconds before checking again + def _get_credentials(self): """ Used to instantiate the Azure ClientSecretCredential object. @@ -482,6 +536,25 @@ def availables(self, label=None, platform=None, tags=None, arch=None, include_re label=label, platform=platform, tags=tags, arch=arch, include_reserved=include_reserved, os_version=os_version ) + def _is_vm_running(self, vmss_name, instance_id): + """ + Check if a specific VM is in 'Running' state + """ + try: + instance_view = Azure._azure_api_call( + self.options.az.sandbox_resource_group, + vmss_name, + instance_id, + expand=InstanceViewTypes.instance_view, + operation=self.compute_client.virtual_machine_scale_set_vms.get, + ) + + power_state = next((status for status in instance_view.instance_view.statuses if status.code.startswith("PowerState/")), None) + return power_state and power_state.code == "PowerState/running" + except Exception as e: + log.error(f"Error checking VM {instance_id} status in VMSS {vmss_name}: {str(e)}") + return False + def _add_machines_to_db(self, vmss_name): """ Adding machines to database that did not exist there before. @@ -553,21 +626,27 @@ def _add_machines_to_db(self, vmss_name): log.error(f"The IP '{private_ip}' is already associated with a machine in the DB. Moving on...") continue - # Add machine to DB. - # TODO: What is the point of name vs label? - self.db.add_machine( - name=vmss_vm.name, - label=vmss_vm.name, - ip=private_ip, - platform=platform, - tags=self.options.az.scale_sets[vmss_name].pool_tag, - arch=self.options.az.scale_sets[vmss_name].arch, - interface=self.options.az.interface, - snapshot=vmss_vm.storage_profile.image_reference.id, - resultserver_ip=self.options.az.resultserver_ip, - resultserver_port=self.options.az.resultserver_port, - reserved=False, - ) + # Check if the VM is in 'Running' state before adding to DB + if self._is_vm_running(vmss_name, vmss_vm.instance_id): + # Add machine to DB. + # TODO: What is the point of name vs label? + self.db.add_machine( + name=vmss_vm.name, + label=vmss_vm.name, + ip=private_ip, + platform=platform, + tags=self.options.az.scale_sets[vmss_name].pool_tag, + arch=self.options.az.scale_sets[vmss_name].arch, + interface=self.options.az.interface, + snapshot=vmss_vm.storage_profile.image_reference.id, + resultserver_ip=self.options.az.resultserver_ip, + resultserver_port=self.options.az.resultserver_port, + reserved=False, + ) + log.info(f"Added machine {vmss_vm.name} to the database.") + else: + log.warning(f"VM {vmss_vm.name} is not in 'Running' state. Skipping DB addition.") + # When we aren't initializing the system, the machine will immediately become available in DB # When we are initializing, we're going to wait for the machine to be have the Cuckoo agent all set up if self.initializing and self.options.az.wait_for_agent_before_starting: @@ -723,7 +802,7 @@ def _thr_create_vmss(self, vmss_name, vmss_image_ref, vmss_image_os): os_disk=vmss_os_disk, ) vmss_dns_settings = models.VirtualMachineScaleSetNetworkConfigurationDnsSettings( - dns_servers=self.options.az.dns_server_ips.strip().split(",") + dns_servers=[self.options.az.resultserver_ip] ) vmss_ip_config = models.VirtualMachineScaleSetIPConfiguration( name="vmss_ip_config", @@ -1059,24 +1138,74 @@ def _scale_machine_pool(self, tag, per_platform=False): log.error(repr(exc), exc_info=True) log.debug(f"Scaling {vmss_name} has completed with errors {exc!r}.") - @staticmethod - def _handle_poller_result(lro_poller_object): + + def _handle_poller_result(self, lro_poller_object): """ Provides method of handling Azure tasks that take too long to complete @param lro_poller_object: An LRO Poller Object for an Async Azure Task """ - start_time = timeit.default_timer() - # TODO: Azure disregards the timeout passed to it in most cases, unless it has a custom poller + start_time = time.time() + max_timeout = 1800 # 30 minutes maximum timeout + base_timeout = 280 # Initial timeout, same as the previous fixed timeout + + while True: + try: + # Try to get the result with a short timeout + lro_poller_result = lro_poller_object.result(timeout=30) + # time_taken = time.time() - start_time + return lro_poller_result + except Exception as e: + log.error(e) + current_time = time.time() + time_elapsed = current_time - start_time + + if time_elapsed >= max_timeout: + raise CuckooMachineError(f"The task took {round(time_elapsed)}s to complete! Exceeded maximum timeout.") + + # Check if machines are still upgrading + if self._are_machines_still_upgrading(lro_poller_object): + # If machines are still upgrading, wait longer + time.sleep(30) # Wait for 30 seconds before checking again + else: + # If machines are not upgrading, use the base timeout + if time_elapsed >= base_timeout: + raise CuckooMachineError(f"The task took {round(time_elapsed)}s to complete! Base timeout exceeded.") + time.sleep(10) # Wait for 10 seconds before checking again + + def _are_machines_still_upgrading(self, lro_poller_object): + """ + Check if any machines in the scale set are still upgrading + """ try: - lro_poller_result = lro_poller_object.result(timeout=AZURE_TIMEOUT) + # Assuming the lro_poller_object contains information about the scale set + scale_set_name = lro_poller_object.resource.name + resource_group_name = self.options.az.sandbox_resource_group + + # Get the list of VMs in the scale set + vms = self.compute_client.virtual_machine_scale_set_vms.list( + resource_group_name, + scale_set_name + ) + + for vm in vms: + # Get the instance view of the VM + instance_view = self.compute_client.virtual_machine_scale_set_vms.get( + resource_group_name, + scale_set_name, + vm.instance_id, + expand=InstanceViewTypes.instance_view + ).instance_view + + # Check if the VM is in an upgrading state + for status in instance_view.statuses: + if status.code.startswith("PowerState/") and status.code != "PowerState/running": + return True # Found a VM that's not in the running state + + return False # All VMs are in the running state except Exception as e: - raise CuckooMachineError(repr(e)) - time_taken = timeit.default_timer() - start_time - if time_taken >= AZURE_TIMEOUT: - raise CuckooMachineError(f"The task took {round(time_taken)}s to complete! Bad Azure!") - else: - return lro_poller_result - + log.error(f"Error checking machine upgrade status: {str(e)}") + return True # Assume upgrading if we can't check, to be safe + def _get_number_of_relevant_tasks(self, tag, platform=None): """ Returns the number of relevant tasks for a tag or platform