Skip to content

Commit

Permalink
Move validations around
Browse files Browse the repository at this point in the history
  • Loading branch information
abuabraham-ttd committed Dec 10, 2024
1 parent 937e7a2 commit 2b23ff0
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 45 deletions.
15 changes: 5 additions & 10 deletions scripts/aws/ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,7 @@ def __get_current_region(self) -> str:
except requests.RequestException as e:
raise RuntimeError(f"Failed to fetch region: {e}")

def __validate_configs(self, secret):
required_keys = ["api_token", "environment", "core_base_url", "optout_base_url"]
missing_keys = [key for key in required_keys if key not in secret]
if missing_keys:
raise ConfidentialComputeMissingConfigError(missing_keys)
def __validate_ec2_specific_config(self, secret):
if "enclave_memory_mb" in secret or "enclave_cpu_count" in secret:
max_capacity = self.__get_max_capacity()
for key in ["enclave_memory_mb", "enclave_cpu_count"]:
Expand All @@ -63,12 +59,12 @@ def _get_secret(self, secret_identifier: str) -> ConfidentialComputeConfig:
try:
client = boto3.client("secretsmanager", region_name=region)
except Exception as e:
raise RuntimeError("Please specify AWS secrets as env values, or use IAM instance profile for your instance")
raise RuntimeError("Please use IAM instance profile for your instance that has permission to access Secret Manager")
try:
secret = json.loads(client.get_secret_value(SecretId=secret_identifier)["SecretString"])
self.__validate_configs(secret)
self.__validate_ec2_specific_config(secret)
return self.__add_defaults(secret)
except ClientError as e:
except ClientError as _:
raise SecretNotFoundException(f"{secret_identifier} in {region}")

@staticmethod
Expand Down Expand Up @@ -144,7 +140,7 @@ def _setup_auxiliaries(self) -> None:

def _validate_auxiliaries(self) -> None:
"""Validates auxiliary services."""
self.validate_operator_key()
self.validate_configuration()
proxy = "socks5://127.0.0.1:3306"
config_url = "http://127.0.0.1:27015/getConfig"
try:
Expand All @@ -163,7 +159,6 @@ def run_compute(self) -> None:
"""Main execution flow for confidential compute."""
self._setup_auxiliaries()
self._validate_auxiliaries()
self.validate_connectivity()
command = [
"nitro-cli", "run-enclave",
"--eif-path", "/opt/uid2operator/uid2operator.eif",
Expand Down
78 changes: 43 additions & 35 deletions scripts/confidential_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,25 @@ class ConfidentialCompute(ABC):
def __init__(self):
self.configs: ConfidentialComputeConfig = {}

def validate_environment(self):
def validate_configuration(self):
""" Validates the paramters specified through configs/secret manager ."""

def validate_operator_key():
""" Validates the operator key format and its environment alignment."""
operator_key = self.configs.get("api_token")
if not operator_key:
raise ValueError("API token is missing from the configuration.")
pattern = r"^(UID2|EUID)-.\-(I|P)-\d+-\*$"
if re.match(pattern, operator_key):
env = self.configs.get("environment", "").lower()
debug_mode = self.configs.get("debug_mode", False)
expected_env = "I" if debug_mode or env == "integ" else "P"
if operator_key.split("-")[2] != expected_env:
raise ValueError(
f"Operator key does not match the expected environment ({expected_env})."
)
return True

def validate_url(url_key, environment):
"""URL should include environment except in prod"""
if environment != "prod" and environment not in self.configs[url_key]:
Expand All @@ -33,48 +51,38 @@ def validate_url(url_key, environment):
f"{url_key} is invalid. Ensure {self.configs[url_key]} follows HTTPS, and doesn't have any path specified."
)

def validate_connectivity(self) -> None:
""" Validates that the core and opt-out URLs are accessible."""
try:
core_url = self.configs["core_base_url"]
optout_url = self.configs["optout_base_url"]
core_ip = self.__resolve_hostname(core_url)
requests.get(core_url, timeout=5)
optout_ip = self.__resolve_hostname(optout_url)
requests.get(optout_url, timeout=5)
except (requests.ConnectionError, requests.Timeout) as e:
raise Exception(
f"Failed to reach required URLs. Consider enabling {core_ip}, {optout_ip} in the egress firewall."
)
except Exception as e:
raise Exception("Failed to reach the URLs.") from e

required_keys = ["api_token", "environment", "core_base_url", "optout_base_url"]
missing_keys = [key for key in required_keys if key not in self.configs]
if missing_keys:
raise ConfidentialComputeMissingConfigError(missing_keys)

environment = self.configs["environment"]

if self.configs.get("debug_mode") and environment == "prod":
raise ValueError("Debug mode cannot be enabled in the production environment.")

validate_url("core_base_url", environment)
validate_url("optout_base_url", environment)


def validate_operator_key(self):
""" Validates the operator key format and its environment alignment."""
operator_key = self.configs.get("api_token")
if not operator_key:
raise ValueError("API token is missing from the configuration.")
pattern = r"^(UID2|EUID)-.\-(I|P)-\d+-\*$"
if re.match(pattern, operator_key):
env = self.configs.get("environment", "").lower()
debug_mode = self.configs.get("debug_mode", False)
expected_env = "I" if debug_mode or env == "integ" else "P"
if operator_key.split("-")[2] != expected_env:
raise ValueError(
f"Operator key does not match the expected environment ({expected_env})."
)
return True

def validate_connectivity(self) -> None:
""" Validates that the core and opt-out URLs are accessible."""
try:
core_url = self.configs["core_base_url"]
optout_url = self.configs["optout_base_url"]
core_ip = self.__resolve_hostname(core_url)
requests.get(core_url, timeout=5)
optout_ip = self.__resolve_hostname(optout_url)
requests.get(optout_url, timeout=5)
except (requests.ConnectionError, requests.Timeout) as e:
raise Exception(
f"Failed to reach required URLs. Consider enabling {core_ip}, {optout_ip} in the egress firewall."
)
except Exception as e:
raise Exception("Failed to reach the URLs.") from e
validate_operator_key()
validate_connectivity()


@abstractmethod
def _get_secret(self, secret_identifier: str) -> ConfidentialComputeConfig:
"""
Expand Down

0 comments on commit 2b23ff0

Please sign in to comment.