Skip to content

Commit

Permalink
feat: duplicate_action option for groups and policies
Browse files Browse the repository at this point in the history
  • Loading branch information
0x6f677548 committed Oct 25, 2023
1 parent 2e95e78 commit 687a335
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 41 deletions.
4 changes: 4 additions & 0 deletions src/ca_pwt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,7 @@
get_groups_in_policies,
cleanup_policies,
)

from ca_pwt.helpers.graph_api import (
DuplicateActionEnum
)
26 changes: 18 additions & 8 deletions src/ca_pwt/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
cleanup_groups,
import_groups,
)
from ca_pwt.helpers.graph_api import DuplicateActionEnum

from ca_pwt.policies_mappings import replace_keys_by_values_in_policies, replace_values_by_keys_in_policies

Expand Down Expand Up @@ -60,6 +61,13 @@
"Duplicates are checked by comparing the displayName of Policies/Groups",
)

_duplicate_action_option = click.option(
"--duplicate_action",
type=click.Choice([action.value for action in DuplicateActionEnum], case_sensitive=True),
help="The action to take when a duplicate is found (default is ignore). ",
default=DuplicateActionEnum.ignore.value,
)


def _exit_with_exception(exception: Exception, exit_code: int = 1, fg: str = "red"):
"""Exit the program with an exception and exit code"""
Expand Down Expand Up @@ -356,13 +364,12 @@ def cleanup_groups_cmd(ctx: click.Context, input_file: str, output_file: str):
@click.pass_context
@_access_token_option
@_input_file_option
@_allow_duplicates_option
@_duplicate_action_option
def import_policies_cmd(
ctx: click.Context,
input_file: str,
access_token: str | None = None,
*,
allow_duplicates: bool = False,
duplicate_action: DuplicateActionEnum = DuplicateActionEnum.ignore,
):
"""Imports CA policies from a file"""
try:
Expand All @@ -378,7 +385,11 @@ def import_policies_cmd(
click.echo(f"Input file: {input_file}")

policies = load_policies(input_file)
created_policies = import_policies(access_token, policies, allow_duplicates=allow_duplicates)

created_policies = import_policies(
access_token=access_token, policies=policies, duplicate_action=duplicate_action
)

click.echo("Successfully created policies:")
for policy in created_policies:
click.echo(f"{policy[0]}: {policy[1]}")
Expand Down Expand Up @@ -427,13 +438,12 @@ def export_groups_cmd(
@click.pass_context
@_access_token_option
@_input_file_option
@_allow_duplicates_option
@_duplicate_action_option
def import_groups_cmd(
ctx: click.Context,
input_file: str,
access_token: str | None = None,
*,
allow_duplicates: bool = False,
duplicate_action: DuplicateActionEnum = DuplicateActionEnum.ignore,
):
"""Imports groups from a file"""
try:
Expand All @@ -453,7 +463,7 @@ def import_groups_cmd(
input_file = _get_from_ctx_if_none(ctx, "output_file", input_file, lambda: click.prompt("The input file"))
click.echo(f"Input file: {input_file}")
groups = load_groups(input_file)
created_groups = import_groups(access_token, groups, allow_duplicates=allow_duplicates)
created_groups = import_groups(access_token=access_token, groups=groups, duplicate_action=duplicate_action)
click.echo("Successfully created groups:")
for group in created_groups:
click.echo(f"{group[0]}: {group[1]}")
Expand Down
21 changes: 8 additions & 13 deletions src/ca_pwt/groups.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import requests
import logging
from ca_pwt.helpers.graph_api import APIResponse, EntityAPI, _REQUEST_TIMEOUT
from ca_pwt.helpers.graph_api import APIResponse, EntityAPI, _REQUEST_TIMEOUT, DuplicateActionEnum
from ca_pwt.helpers.utils import assert_condition, cleanup_odata_dict, remove_element_from_dict, ensure_list

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -94,26 +94,21 @@ def get_groups_by_ids(access_token: str, group_ids: list[str], *, ignore_not_fou
return result


def import_groups(access_token: str, groups: list[dict], *, allow_duplicates: bool = False) -> list[tuple[str, str]]:
def import_groups(
access_token: str, groups: list[dict], duplicate_action: DuplicateActionEnum = DuplicateActionEnum.ignore
) -> list[tuple[str, str]]:
"""Imports groups from the specified dictionary.
Returns a list of tuples with the group id and name of the imported groups.
"""
_logger.info("Importing groups...")
groups_api = GroupsAPI(access_token=access_token)
groups = cleanup_groups(groups)
result: list[tuple[str, str]] = []
for group in groups:
group_name = group["displayName"]

if not allow_duplicates:
existing_group = groups_api.get_by_display_name(group_name)
if existing_group.success:
_logger.warning(f"Group with display name {group_name} already exists. Skipping...")
continue

group_response = groups_api.create(group)
group_response.assert_success()
group_detail = group_response.json()
group_id = group_detail["id"]
response = groups_api.create_checking_duplicates(group, f"displayName eq '{group_name}'", duplicate_action)
response.assert_success()
group_id = response.json()["id"]
result.append((group_id, group_name))
_logger.info(f"Imported group {group_name} with id {group_id}")
return result
56 changes: 51 additions & 5 deletions src/ca_pwt/helpers/graph_api.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
import requests
import logging
from ca_pwt.helpers.utils import assert_condition
from abc import ABC, abstractmethod
import logging
from enum import StrEnum

_REQUEST_TIMEOUT = 500


class DuplicateActionEnum(StrEnum):
ignore = "ignore"
replace = "replace"
duplicate = "duplicate"
fail = "fail"


class APIResponse:
"""A class to represent an API response"""

Expand Down Expand Up @@ -36,7 +44,7 @@ def json(self):

def assert_success(self):
"""Asserts that the request was successful"""
assert_condition(self.success, f"Request failed with status code {self.status_code}; {self.response.json()}")
assert_condition(self.success, f"Request failed with status code {self.status_code}; {self.response}")


class EntityAPI(ABC):
Expand Down Expand Up @@ -119,10 +127,15 @@ def get_by_display_name(self, display_name: str) -> APIResponse:
"""Gets the top entity found with the given display name
Returns an API_Response object and the entity is in the json property of the API_Response object
"""
assert_condition(display_name, "display_name cannot be None")
return self.get_top_entity(f"displayName eq '{display_name}'")

response = self.get_all(odata_filter=f"displayName eq '{display_name}'", odata_top=1)
def get_top_entity(self, odata_filter: str) -> APIResponse:
"""Gets the top entity found with the given filter
Returns an API_Response object and the entity is in the json property of the API_Response object
"""

assert_condition(odata_filter, "odata_filter cannot be None")
response = self.get_all(odata_filter=odata_filter, odata_top=1)
# if the request was successful, transform the response to a dict
if response.success:
# move the value property to the response property
Expand All @@ -134,14 +147,47 @@ def get_by_display_name(self, display_name: str) -> APIResponse:
response.response = "No results found"
else:
response.response = value[0]

return response

def create(self, entity: dict) -> APIResponse:
"""Creates an entity"""
assert_condition(entity, "entity cannot be None")
return self._request_post(self.entity_url, entity)

def create_checking_duplicates(
self, entity: dict, odata_filter: str, duplicate_action: DuplicateActionEnum = DuplicateActionEnum.ignore
) -> APIResponse:
"""Creates an entity checking for duplicates first and taking the specified action if a duplicate is found
A duplicate is determined by the odata_filter parameter, getting the top entity with the specified filter"""
assert_condition(entity, "entity cannot be None")
assert_condition(odata_filter, "odata_filter cannot be None")

# if duplicate_action is not duplicate, check if the entity already exists
if duplicate_action != DuplicateActionEnum.duplicate:
existing_entity = self.get_top_entity(odata_filter)
if existing_entity.success:
if duplicate_action == DuplicateActionEnum.ignore:
self._logger.warning(
f"Entity {self._get_entity_path()} with filter {odata_filter} already exists. Skipping..."
)
return existing_entity
elif duplicate_action == DuplicateActionEnum.replace:
existing_entity_id = existing_entity.json()["id"]
self._logger.warning(f"Replacing entity {self._get_entity_path()} with id {existing_entity_id}...")
response = self.update(existing_entity_id, entity)
response.assert_success()
# response should be a "204 No Content" or "200 OK" response
# we need to return the existing_entity_id in the response body
response.response = {"id": existing_entity_id}
return response
elif duplicate_action == DuplicateActionEnum.fail:
msg = f"Entity {self._get_entity_path()} with filter {odata_filter} already exists."
raise ValueError(msg)
else:
msg = f"Invalid duplicate_action: {duplicate_action}"
raise ValueError(msg)
return self.create(entity)

def delete(self, entity_id: str) -> APIResponse:
"""Deletes an entity by its ID"""
assert_condition(entity_id, "entity_id cannot be None")
Expand Down
19 changes: 4 additions & 15 deletions src/ca_pwt/policies.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from ca_pwt.helpers.utils import remove_element_from_dict, cleanup_odata_dict, ensure_list
from ca_pwt.helpers.graph_api import EntityAPI
from ca_pwt.helpers.graph_api import EntityAPI, DuplicateActionEnum
from ca_pwt.policies_mappings import replace_values_by_keys_in_policies
from ca_pwt.groups import get_groups_by_ids

Expand Down Expand Up @@ -70,8 +70,7 @@ def export_policies(access_token: str, odata_filter: str | None = None) -> list[
def import_policies(
access_token: str,
policies: list[dict],
*,
allow_duplicates: bool = False,
duplicate_action: DuplicateActionEnum = DuplicateActionEnum.ignore,
) -> list[tuple[str, str]]:
"""Imports the specified policies. If allow_duplicates is False,
it will skip policies that already exist (using the display name as
Expand All @@ -88,20 +87,10 @@ def import_policies(
for policy in policies:
display_name: str = str(policy.get("displayName"))

# check if the policy already exists
if not allow_duplicates:
existing_policy = policies_api.get_by_display_name(display_name)
if existing_policy.success:
_logger.warning(f"Policy with display name {display_name} already exists. Skipping...")
continue

_logger.info(f"Creating policy {display_name}...")
_logger.debug(f"Policy: {policy}")
response = policies_api.create(policy)
response = policies_api.create_checking_duplicates(policy, f"displayName eq '{display_name}'", duplicate_action)
response.assert_success()

policy_id = response.json()["id"]
created_policies.append((display_name, policy_id))
created_policies.append((policy_id, display_name))
_logger.info("Policy created successfully with id %s", policy_id)
return created_policies

Expand Down

0 comments on commit 687a335

Please sign in to comment.