Skip to content

Commit

Permalink
[aws][fix] Handle custom command arguments correctly (#1302)
Browse files Browse the repository at this point in the history
[aws][fix] Handle custom command arguments correctly
  • Loading branch information
aquamatthias authored Nov 22, 2022
1 parent 80d0c78 commit 78045c6
Show file tree
Hide file tree
Showing 12 changed files with 58 additions and 46 deletions.
38 changes: 32 additions & 6 deletions plugins/aws/resoto_plugin_aws/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from resotolib.graph import Graph
from resotolib.logger import log, setup_logger
from resotolib.types import JsonElement, Json
from resotolib.utils import log_runtime, NoExitArgumentParser, chunks
from resotolib.utils import log_runtime, NoExitArgumentParser
from .collector import AwsAccountCollector
from .configuration import AwsConfig
from .resource.base import AwsAccount, AwsResource
Expand Down Expand Up @@ -175,6 +175,36 @@ def adjust_shape(o: str, shape: Optional[Shape]) -> Any:
# map and structure types are currently not supported
raise ValueError(f"Cannot convert {o} to {shape}")

def coerce_args(fn_args: List[str], om: OperationModel) -> Dict[str, Any]:
members: Dict[str, Shape] = om.input_shape.members if isinstance(om.input_shape, StructureShape) else {}
param_name: Optional[str] = None
param_shape: Optional[Shape] = None
arg_val: Dict[str, Any] = {}
for arg in fn_args:
if arg.startswith("--"):
name = pascalcase(arg.removeprefix("--"))
param_name = name
param_shape = members.get(name)
bool_value = True
if param_shape is None and arg.startswith("--no-"):
name = name[2:]
param_name = name
param_shape = members.get(name)
bool_value = False
if param_shape is None:
raise ValueError(f"AWS: Unknown parameter {arg}")
if param_shape.name == "Boolean" or param_shape.type_name == "Boolean":
arg_val[name] = bool_value
param_shape = None
param_name = None
elif param_name is not None:
arg_val[param_name] = adjust_shape(arg, param_shape)
param_name = None
param_shape = None
else:
raise ValueError(f"AWS: Unexpected argument {arg}")
return arg_val

def create_client() -> AwsClient:
role = p.role or cfg.role
region = p.region or (cfg.region[0] if cfg.region else None)
Expand All @@ -189,11 +219,7 @@ def create_client() -> AwsClient:
service_model = client.service_model(p.service)
op: OperationModel = service_model.operation_model(pascalcase(p.operation))
output_shape = op.output_shape.type_name
members: Dict[str, Shape] = op.input_shape.members if isinstance(op.input_shape, StructureShape) else {}
func_args = {}
for arg, arg_value in chunks(remaining, 2):
name = pascalcase(arg.removeprefix("--"))
func_args[name] = adjust_shape(arg_value, members.get(name))
func_args = coerce_args(remaining, op)

result: List[Json] = client.call_single(p.service, p.operation, None, **func_args) # type: ignore
# Remove the "ResponseMetadata" from the result
Expand Down
2 changes: 0 additions & 2 deletions plugins/aws/resoto_plugin_aws/aws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,6 @@ def call_single(
result.extend(child)
elif child is not None:
result.append(child)
else:
raise AttributeError(f"Expected result under key '{result_name}'")
log.debug(f"[Aws] called service={aws_service} action={action}{arg_info}: {len(result)} results.")
return result
else:
Expand Down
14 changes: 7 additions & 7 deletions plugins/aws/resoto_plugin_aws/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,12 @@ def collect(self) -> None:
# The shared executor is used to parallelize the collection of resources "as fast as possible"
# It should only be used in scenarios, where it is safe to do so.
# This executor is shared between all regions.
shared_queue = ExecutorQueue(executor, self.account.name)
shared_queue = ExecutorQueue(executor, self.account.safe_name)
shared_queue.submit_work(self.update_account)
global_builder = GraphBuilder(
self.graph, self.cloud, self.account, self.global_region, self.client, shared_queue, self.core_feedback
)
global_builder.core_feedback.progress_done(self.global_region.name, 0, 1)
global_builder.core_feedback.progress_done(self.global_region.safe_name, 0, 1)
global_builder.add_node(self.global_region)

log.info(f"[Aws:{self.account.id}] Collect global resources.")
Expand All @@ -130,7 +130,7 @@ def collect(self) -> None:
if self.config.should_collect(resource.kind):
resource.collect_resources(global_builder)
shared_queue.wait_for_submitted_work()
global_builder.core_feedback.progress_done(self.global_region.name, 1, 1)
global_builder.core_feedback.progress_done(self.global_region.safe_name, 1, 1)

log.info(f"[Aws:{self.account.id}] Collect regional resources.")

Expand Down Expand Up @@ -169,7 +169,7 @@ def collect_region(self, region: AwsRegion, regional_builder: GraphBuilder) -> N
def collect_resource(resource: Type[AwsResource], rb: GraphBuilder) -> None:
try:
resource.collect_resources(rb)
log.info(f"[Aws:{self.account.id}:{region.name}] finished collecting: {resource.kind}")
log.info(f"[Aws:{self.account.id}:{region.safe_name}] finished collecting: {resource.kind}")
except ClientError as e:
code = e.response["Error"]["Code"]
if code == "UnauthorizedOperation":
Expand All @@ -186,15 +186,15 @@ def collect_resource(resource: Type[AwsResource], rb: GraphBuilder) -> None:
with ThreadPoolExecutor(
thread_name_prefix=regional_thread_name, max_workers=self.config.region_resources_pool_size
) as executor:
regional_builder.core_feedback.progress_done(region.name, 0, 1)
regional_builder.core_feedback.progress_done(region.safe_name, 0, 1)
# In case an exception is thrown for any resource, we should give up as quick as possible.
queue = ExecutorQueue(executor, region.name, fail_on_first_exception=True)
queue = ExecutorQueue(executor, region.safe_name, fail_on_first_exception=True)
regional_builder.add_node(region)
for res in regional_resources:
if self.config.should_collect(res.kind):
queue.submit_work(collect_resource, res, regional_builder)
queue.wait_for_submitted_work()
regional_builder.core_feedback.progress_done(region.name, 1, 1)
regional_builder.core_feedback.progress_done(region.safe_name, 1, 1)
except Exception as e:
msg = f"Error collecting resources in account {self.account.id} region {region.id}: {e} - skipping region"
self.core_feedback.error(msg, log)
Expand Down
6 changes: 3 additions & 3 deletions plugins/aws/resoto_plugin_aws/resource/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,12 +436,12 @@ def resources_of(self, resource_type: Type[AwsResourceType]) -> List[AwsResource
def instance_type(self, instance_type: str) -> Optional[Any]:
if (global_type := self.global_instance_types.get(instance_type)) is None:
return None # instance type not found
price = AwsPricingPrice.instance_type_price(self.client, instance_type, self.region.name)
price = AwsPricingPrice.instance_type_price(self.client, instance_type, self.region.safe_name)
return evolve(global_type, region=self.region, ondemand_cost=price.on_demand_price_usd if price else None)

@lru_cache(maxsize=None)
def volume_type(self, volume_type: str) -> Optional[Any]:
price = AwsPricingPrice.volume_type_price(self.client, volume_type, self.region.name)
price = AwsPricingPrice.volume_type_price(self.client, volume_type, self.region.safe_name)
vt = AwsEc2VolumeType(
id=volume_type,
name=volume_type,
Expand All @@ -457,7 +457,7 @@ def for_region(self, region: AwsRegion) -> GraphBuilder:
self.cloud,
self.account,
region,
self.client.for_region(region.name),
self.client.for_region(region.safe_name),
self.executor,
self.core_feedback,
self.global_instance_types,
Expand Down
2 changes: 1 addition & 1 deletion plugins/aws/resoto_plugin_aws/resource/ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def collect(cls: Type[AwsResource], json: List[Json], builder: GraphBuilder) ->
it = AwsEc2InstanceType.from_api(js)
# only store this information in the builder, not directly in the graph
# reason: pricing is region-specific - this is enriched in the builder on demand
builder.global_instance_types[it.name] = it
builder.global_instance_types[it.safe_name] = it


# endregion
Expand Down
6 changes: 3 additions & 3 deletions plugins/aws/resoto_plugin_aws/resource/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,13 @@ def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None:
def pre_delete_resource(self, client: AwsClient, graph: Graph) -> bool:
for predecessor in self.predecessors(graph=graph, edge_type=EdgeType.default):
if isinstance(predecessor, AwsEcsService):
predecessor.purge_capacity_provider(client=client, capacity_provider_name=self.name)
predecessor.purge_capacity_provider(client=client, capacity_provider_name=self.safe_name)
if isinstance(predecessor, AwsEcsCluster):
predecessor.disassociate_capacity_provider(client=client, capacity_provider_name=self.name)
predecessor.disassociate_capacity_provider(client=client, capacity_provider_name=self.safe_name)
return True

def delete_resource(self, client: AwsClient) -> bool:
client.call("ecs", "delete-capacity-provider", None, capacityProvider=self.name)
client.call("ecs", "delete-capacity-provider", None, capacityProvider=self.safe_name)
return True


Expand Down
2 changes: 1 addition & 1 deletion plugins/aws/resoto_plugin_aws/resource/route53.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def _keys(self) -> tuple[str, str, str, str, str, str, str, str, str, Optional[s
self.zone().id,
self.dns_zone().id,
self.id,
self.name,
self.safe_name,
self.record_type,
self.record_set_identifier,
)
Expand Down
2 changes: 1 addition & 1 deletion plugins/aws/resoto_plugin_aws/resource/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def add_tags(bucket: AwsS3Bucket, client: AwsClient) -> None:

for js in json:
bucket = cls.from_api(js)
bucket.set_arn(builder=builder, region="", account="", resource=bucket.name)
bucket.set_arn(builder=builder, region="", account="", resource=bucket.safe_name)
builder.add_node(bucket, js)
builder.submit_work(add_tags, bucket, builder.client)

Expand Down
4 changes: 2 additions & 2 deletions plugins/aws/resoto_plugin_aws/resource/service_quotas.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,9 @@ def match(self, quota: AwsServiceQuota) -> bool:
if self.quota_name is None:
return False
elif isinstance(self.quota_name, Pattern):
return self.quota_name.match(quota.name) is not None
return self.quota_name.match(quota.safe_name) is not None
else:
return self.quota_name == quota.name
return self.quota_name == quota.safe_name


RegionalQuotas = {
Expand Down
6 changes: 3 additions & 3 deletions resotocore/resotocore/cli/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -2922,8 +2922,8 @@ class WorkerCustomCommand:
def to_template(self) -> AliasTemplate:
allowed_kind = f" --allowed-on {self.allowed_on_kind}" if self.allowed_on_kind else ""
result_flag = "" if self.expect_node_result else " --no-node-result"
command = f'--command "{self.name}"'
args = '--arg "{{args}}"'
command = f"--command '{self.name}'"
args = "--arg '{{args}}'"
return AliasTemplate(
name=self.name,
info=self.info or "",
Expand Down Expand Up @@ -3008,7 +3008,7 @@ def update_single(item: Json) -> Tuple[str, Dict[str, str], Json]:

return update_single

formatter, variables = ctx.formatter_with_variables(double_quoted_or_simple_string_dp.parse(args))
formatter, variables = ctx.formatter_with_variables(args or "")
fn = call_function(lambda item: {"args": args_parts_unquoted_parser.parse(formatter(item)), "node": item})

def setup_stream(in_stream: Stream) -> Stream:
Expand Down
2 changes: 1 addition & 1 deletion resotocore/tests/resotocore/cli/command_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,7 +973,7 @@ async def test_execute_task(cli: CLI) -> None:
assert command.info == "info"
assert command.description == "description"
assert command.args_description == {"a": "b"}
assert command.template == 'execute-task --no-node-result --command "name" --arg "{{args}}"'
assert command.template == "execute-task --no-node-result --command 'name' --arg '{{args}}'"

# execute-task in source position
source_result = await cli.execute_cli_command(
Expand Down
20 changes: 4 additions & 16 deletions resotolib/resotolib/baseresources.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class BaseResource(ABC):

id: str
tags: Dict[str, Optional[str]] = Factory(dict)
name: str = None
name: Optional[str] = field(default=None)
_cloud: "Optional[BaseCloud]" = field(default=None, repr=False)
_account: "Optional[BaseAccount]" = field(default=None, repr=False)
_region: "Optional[BaseRegion]" = field(default=None, repr=False)
Expand Down Expand Up @@ -172,14 +172,6 @@ def __attrs_post_init__(self) -> None:
if not hasattr(self, "_mtime"):
self._mtime = None

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}('{self.id}', name='{self.name}',"
f" region='{self.region().name}', zone='{self.zone().name}',"
f" account='{self.account().dname}', kind='{self.kind}',"
f" ctime={self.ctime!r}, chksum={self.chksum})"
)

def _keys(self) -> tuple:
"""Return a tuple of all keys that make this resource unique
Expand All @@ -201,13 +193,9 @@ def _keys(self) -> tuple:
self.name,
)

# def __hash__(self):
# return hash(self._keys())

# def __eq__(self, other):
# if isinstance(other, type(self)):
# return self._keys() == other._keys()
# return NotImplemented
@property
def safe_name(self) -> str:
return self.name or self.id

@property
def dname(self) -> str:
Expand Down

0 comments on commit 78045c6

Please sign in to comment.