Skip to content

Commit

Permalink
Add enum validation to Result class.
Browse files Browse the repository at this point in the history
  • Loading branch information
brian-pond committed Feb 22, 2024
1 parent e27201f commit b01b8ef
Showing 1 changed file with 31 additions and 1 deletion.
32 changes: 31 additions & 1 deletion temporal/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import NamedTuple

import frappe

from temporal import validate_datatype
from temporal.helpers import dict_to_dateless_dict

Expand Down Expand Up @@ -36,8 +37,30 @@ def __friendly__(self):
# Define a Message
# ========

# NOTE: Define enum classes as both (str, Enum), so they are JSON serializable.
def validate_enum_value(any_value, enum_type, raise_on_errors=True):
"""
Ensure that an enumerated Value is a member of the specific enumerated data Type.
"""

# Validate the arguments -themselves- are of the correct types.
if not isinstance(any_value, (str, Enum)):
raise TypeError(f"Expected argument 'any_value' with value = '{any_value}' to be either String or Enum type.")
if not issubclass(enum_type, Enum):
raise TypeError(f"Argument '{enum_type}' must be a subclass of Enum.")

# Next, if the first argument is a String, verify it will successfully coerce into an Enum variant.
if isinstance(any_value, str):
try:
enum_type[any_value.upper()]
except Exception as ex:
message = f"Argument value '{any_value}' is not a type of Enum '{enum_type.__name__}'"
print(message)
if raise_on_errors:
raise TypeError(message) from ex
return False
return True

# NOTE: Should always define Enum classes as both (str, Enum), so they are JSON serializable.
class MessageLevel(str, Enum):
INFO = 'Info'
WARNING = 'Warning'
Expand Down Expand Up @@ -138,6 +161,9 @@ def as_json(self):
"messages": dict_to_dateless_dict(self._messages)
})

def __str__(self):
return self.as_json()

def should_raise_exceptions(self) -> bool:
if self.outcome in (OutcomeType.ERROR, OutcomeType.INTERNAL_ERROR):
return True
Expand All @@ -150,6 +176,9 @@ def should_raise_exceptions(self) -> bool:
# Message Functions
def add_message(self, audience, message_level, message_string, tags=None):

validate_enum_value(audience, MessageAudience)
validate_enum_value(message_level, MessageLevel)

# Validate the tags
if tags:
if isinstance(tags, str):
Expand Down Expand Up @@ -187,6 +216,7 @@ def add_result_to_crs(self, crs_instance):
crs_instance.add_data(key, value) # important to send as JSON, to convert things like Date and DateTime to string.

for each_message in self.get_all_messages():

if each_message.audience in (MessageAudience.ALL, MessageAudience.INTERNAL):
crs_instance.add_internal_message(str(each_message))
if each_message.audience in (MessageAudience.ALL, MessageAudience.CUSTOMER):
Expand Down

0 comments on commit b01b8ef

Please sign in to comment.