Add dynamic timeout for Azure operations in machinery module #2233

Status: Open, wants to merge 2 commits into base: master
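This change replaces the single fixed AZURE_TIMEOUT wait on Azure long-running operations with a polling loop in _handle_poller_result that keeps extending the deadline, up to a 30-minute ceiling, while the scale set's VMs are still upgrading, and it makes both initialization and _add_machines_to_db wait for VMs to reach the "PowerState/running" state before they are used. The snippet below is only a simplified sketch of that dynamic-timeout idea, not code from the diff: wait_with_dynamic_timeout and still_upgrading are placeholder names, and it assumes an LROPoller-like object exposing done(), wait(), and result().

import time

def wait_with_dynamic_timeout(poller, still_upgrading, base_timeout=280, max_timeout=1800):
    # Sketch of the dynamic-timeout pattern; names and the error type are illustrative.
    start = time.time()
    while not poller.done():
        poller.wait(timeout=30)  # poll in 30-second slices instead of one fixed block
        elapsed = time.time() - start
        if elapsed >= max_timeout:
            # hard ceiling, mirroring the 30-minute max_timeout in the diff
            raise TimeoutError(f"operation still running after {round(elapsed)}s")
        if not still_upgrading() and elapsed >= base_timeout:
            # nothing is upgrading any more, so stop extending past the base budget
            raise TimeoutError(f"operation exceeded the base timeout of {base_timeout}s")
    return poller.result()

In the actual change the upgrade check is _are_machines_still_upgrading, which inspects each VM's instance view, and the error raised is CuckooMachineError rather than TimeoutError.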
modules/machinery/az.py: 191 changes (160 additions, 31 deletions)
@@ -14,12 +14,13 @@
from azure.identity import CertificateCredential, ClientSecretCredential
from azure.mgmt.compute import ComputeManagementClient, models
from azure.mgmt.network import NetworkManagementClient
from azure.mgmt.compute.models import InstanceViewTypes
from msrest.polling import LROPoller

HAVE_AZURE = True
except ImportError:
HAVE_AZURE = False
print("Missing machinery-required libraries.")
print("Missing AZURE machinery-required libraries.")
print("poetry run python -m pip install azure-identity msrest msrestazure azure-mgmt-compute azure-mgmt-network")

# Cuckoo-specific imports
@@ -49,7 +50,7 @@
log = logging.getLogger(__name__)

# Timeout used for calls that shouldn't take longer than 5 minutes but somehow do
AZURE_TIMEOUT = 120
AZURE_TIMEOUT = 220

# Global variable which will maintain details about each machine pool
machine_pools = {}
@@ -172,10 +173,63 @@ def _initialize_check(self):

# Initialize the VMSSs that we will be using and not using
self._set_vmss_stage()

self._wait_for_vms_running()
# Set the flag that indicates that the system is not initializing
self.initializing = False

def _are_all_vms_running(self, vmss_name):
"""
Check if all VMs in a specific VMSS are in 'Running' state
"""
try:
vms = Azure._azure_api_call(
self.options.az.sandbox_resource_group,
vmss_name,
operation=self.compute_client.virtual_machine_scale_set_vms.list,
)

for vm in vms:
instance_view = Azure._azure_api_call(
self.options.az.sandbox_resource_group,
vmss_name,
vm.instance_id,
expand=InstanceViewTypes.instance_view,
operation=self.compute_client.virtual_machine_scale_set_vms.get,
)

power_state = next((status for status in instance_view.instance_view.statuses if status.code.startswith("PowerState/")), None)
if not power_state or power_state.code != "PowerState/running":
return False

return True
except Exception as e:
log.error(f"Error checking VM status in VMSS {vmss_name}: {str(e)}")
return False

def _wait_for_vms_running(self):
"""
Wait for all VMs in all VMSSs to reach 'Running' state before completing initialization
"""
max_wait_time = 30 * 60 # 30 minutes maximum wait time
start_time = time.time()

while True:
all_running = True
for vmss_name in self.options.az.scale_sets:
if not self._are_all_vms_running(vmss_name):
all_running = False
break

if all_running:
log.info("All VMs are in 'Running' state. Initialization complete.")
break

if time.time() - start_time > max_wait_time:
raise CuckooCriticalError("Timeout waiting for all VMs to reach 'Running' state.")

log.info("Waiting for all VMs to reach 'Running' state...")
time.sleep(30) # Wait for 30 seconds before checking again
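The power-state check that _are_all_vms_running above and the _is_vm_running method added later in this diff both perform boils down to reading a VM's instance view and looking for a status whose code starts with "PowerState/". A condensed sketch of that lookup, using only the calls already present in this diff (the helper name itself is hypothetical, not part of the change):

def _get_power_state(self, vmss_name, instance_id):
    # Hypothetical helper: return the PowerState/* code for one VMSS VM, or None.
    instance_view = Azure._azure_api_call(
        self.options.az.sandbox_resource_group,
        vmss_name,
        instance_id,
        expand=InstanceViewTypes.instance_view,
        operation=self.compute_client.virtual_machine_scale_set_vms.get,
    ).instance_view
    status = next((s for s in instance_view.statuses if s.code.startswith("PowerState/")), None)
    return status.code if status else None

With a helper like this, _is_vm_running reduces to comparing the returned code against "PowerState/running".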

def _get_credentials(self):
"""
Used to instantiate the Azure ClientSecretCredential object.
@@ -482,6 +536,25 @@ def availables(self, label=None, platform=None, tags=None, arch=None, include_re
label=label, platform=platform, tags=tags, arch=arch, include_reserved=include_reserved, os_version=os_version
)

def _is_vm_running(self, vmss_name, instance_id):
"""
Check if a specific VM is in 'Running' state
"""
try:
instance_view = Azure._azure_api_call(
self.options.az.sandbox_resource_group,
vmss_name,
instance_id,
expand=InstanceViewTypes.instance_view,
operation=self.compute_client.virtual_machine_scale_set_vms.get,
)

power_state = next((status for status in instance_view.instance_view.statuses if status.code.startswith("PowerState/")), None)
return power_state and power_state.code == "PowerState/running"
except Exception as e:
log.error(f"Error checking VM {instance_id} status in VMSS {vmss_name}: {str(e)}")
return False

def _add_machines_to_db(self, vmss_name):
"""
Adding machines to database that did not exist there before.
@@ -553,21 +626,27 @@ def _add_machines_to_db(self, vmss_name):
log.error(f"The IP '{private_ip}' is already associated with a machine in the DB. Moving on...")
continue

# Add machine to DB.
# TODO: What is the point of name vs label?
self.db.add_machine(
name=vmss_vm.name,
label=vmss_vm.name,
ip=private_ip,
platform=platform,
tags=self.options.az.scale_sets[vmss_name].pool_tag,
arch=self.options.az.scale_sets[vmss_name].arch,
interface=self.options.az.interface,
snapshot=vmss_vm.storage_profile.image_reference.id,
resultserver_ip=self.options.az.resultserver_ip,
resultserver_port=self.options.az.resultserver_port,
reserved=False,
)
# Check if the VM is in 'Running' state before adding to DB
if self._is_vm_running(vmss_name, vmss_vm.instance_id):
# Add machine to DB.
# TODO: What is the point of name vs label?
self.db.add_machine(
name=vmss_vm.name,
label=vmss_vm.name,
ip=private_ip,
platform=platform,
tags=self.options.az.scale_sets[vmss_name].pool_tag,
arch=self.options.az.scale_sets[vmss_name].arch,
interface=self.options.az.interface,
snapshot=vmss_vm.storage_profile.image_reference.id,
resultserver_ip=self.options.az.resultserver_ip,
resultserver_port=self.options.az.resultserver_port,
reserved=False,
)
log.info(f"Added machine {vmss_vm.name} to the database.")
else:
log.warning(f"VM {vmss_vm.name} is not in 'Running' state. Skipping DB addition.")

# When we aren't initializing the system, the machine will immediately become available in DB
# When we are initializing, we're going to wait for the machine to have the Cuckoo agent all set up
if self.initializing and self.options.az.wait_for_agent_before_starting:
@@ -723,7 +802,7 @@ def _thr_create_vmss(self, vmss_name, vmss_image_ref, vmss_image_os):
os_disk=vmss_os_disk,
)
vmss_dns_settings = models.VirtualMachineScaleSetNetworkConfigurationDnsSettings(
dns_servers=self.options.az.dns_server_ips.strip().split(",")
dns_servers=[self.options.az.resultserver_ip]
)
vmss_ip_config = models.VirtualMachineScaleSetIPConfiguration(
name="vmss_ip_config",
@@ -1059,24 +1138,74 @@ def _scale_machine_pool(self, tag, per_platform=False):
log.error(repr(exc), exc_info=True)
log.debug(f"Scaling {vmss_name} has completed with errors {exc!r}.")

@staticmethod
def _handle_poller_result(lro_poller_object):

def _handle_poller_result(self, lro_poller_object):
"""
Provides method of handling Azure tasks that take too long to complete
@param lro_poller_object: An LRO Poller Object for an Async Azure Task
"""
start_time = timeit.default_timer()
# TODO: Azure disregards the timeout passed to it in most cases, unless it has a custom poller
start_time = time.time()
max_timeout = 1800 # 30 minutes maximum timeout
base_timeout = 280 # Initial timeout, same as the previous fixed timeout

while True:
try:
# Try to get the result with a short timeout
lro_poller_result = lro_poller_object.result(timeout=30)
# time_taken = time.time() - start_time
return lro_poller_result
except Exception as e:
log.error(e)
current_time = time.time()
time_elapsed = current_time - start_time

if time_elapsed >= max_timeout:
raise CuckooMachineError(f"The task took {round(time_elapsed)}s to complete! Exceeded maximum timeout.")

# Check if machines are still upgrading
if self._are_machines_still_upgrading(lro_poller_object):
# If machines are still upgrading, wait longer
time.sleep(30) # Wait for 30 seconds before checking again
else:
# If machines are not upgrading, use the base timeout
if time_elapsed >= base_timeout:
raise CuckooMachineError(f"The task took {round(time_elapsed)}s to complete! Base timeout exceeded.")
time.sleep(10) # Wait for 10 seconds before checking again

def _are_machines_still_upgrading(self, lro_poller_object):
"""
Check if any machines in the scale set are still upgrading
"""
try:
lro_poller_result = lro_poller_object.result(timeout=AZURE_TIMEOUT)
# Assuming the lro_poller_object contains information about the scale set
scale_set_name = lro_poller_object.resource.name
resource_group_name = self.options.az.sandbox_resource_group

# Get the list of VMs in the scale set
vms = self.compute_client.virtual_machine_scale_set_vms.list(
resource_group_name,
scale_set_name
)

for vm in vms:
# Get the instance view of the VM
instance_view = self.compute_client.virtual_machine_scale_set_vms.get(
resource_group_name,
scale_set_name,
vm.instance_id,
expand=InstanceViewTypes.instance_view
).instance_view

# Check if the VM is in an upgrading state
for status in instance_view.statuses:
if status.code.startswith("PowerState/") and status.code != "PowerState/running":
return True # Found a VM that's not in the running state

return False # All VMs are in the running state
except Exception as e:
raise CuckooMachineError(repr(e))
time_taken = timeit.default_timer() - start_time
if time_taken >= AZURE_TIMEOUT:
raise CuckooMachineError(f"The task took {round(time_taken)}s to complete! Bad Azure!")
else:
return lro_poller_result

log.error(f"Error checking machine upgrade status: {str(e)}")
return True # Assume upgrading if we can't check, to be safe
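Because _handle_poller_result is no longer a @staticmethod (it needs self to reach _are_machines_still_upgrading and the compute client), call sites now have to go through an instance rather than the class. A hedged sketch of the intended call pattern inside a method of this class, reusing the _azure_api_call wrapper seen elsewhere in this file; the chosen operation and vmss_parameters are placeholders, not taken from this diff:

# Illustrative only: kick off a long-running VMSS update and hand its poller
# to the dynamic handler instead of blocking on a fixed AZURE_TIMEOUT.
async_update = Azure._azure_api_call(
    self.options.az.sandbox_resource_group,
    vmss_name,
    vmss_parameters,
    operation=self.compute_client.virtual_machine_scale_sets.create_or_update,
)
vmss = self._handle_poller_result(async_update)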

def _get_number_of_relevant_tasks(self, tag, platform=None):
"""
Returns the number of relevant tasks for a tag or platform