From 176cee511dfcbb3c8f47fb9dd699655765e21ee5 Mon Sep 17 00:00:00 2001 From: Arjaan Buijk Date: Tue, 17 Nov 2020 13:37:14 -0500 Subject: [PATCH 1/4] Strip `validate_` from the name to get the form name --- rasa_sdk/forms.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rasa_sdk/forms.py b/rasa_sdk/forms.py index 445709f4d..04603e95b 100644 --- a/rasa_sdk/forms.py +++ b/rasa_sdk/forms.py @@ -812,7 +812,9 @@ def slots_mapped_in_domain(self, domain: "DomainDict") -> List[Text]: Returns: Slot names mapped in the domain. """ - return list(domain.get("forms", {}).get(self.name(), {}).keys()) + return list( + domain.get("forms", {}).get(self.name().strip("validate_"), {}).keys() + ) async def validate( self, From 1071c845e4928039067a961cbd216f1f3e5aa58e Mon Sep 17 00:00:00 2001 From: Arjaan Buijk Date: Tue, 17 Nov 2020 13:37:52 -0500 Subject: [PATCH 2/4] The slots dictionary can change during the loop. --- rasa_sdk/forms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa_sdk/forms.py b/rasa_sdk/forms.py index 04603e95b..ffdb2647c 100644 --- a/rasa_sdk/forms.py +++ b/rasa_sdk/forms.py @@ -834,7 +834,7 @@ async def validate( """ slots: Dict[Text, Any] = tracker.slots_to_validate() - for slot_name, slot_value in slots.items(): + for slot_name, slot_value in list(slots.items()): method_name = f"validate_{slot_name.replace('-','_')}" validate_method = getattr(self, method_name, None) From 3d3609cc3c863516aa5388be9ef8fc64498525a1 Mon Sep 17 00:00:00 2001 From: Tobias Wochinger Date: Tue, 17 Nov 2020 20:40:40 +0100 Subject: [PATCH 3/4] implement review comments --- changelog/344.bugfix.md | 1 + rasa_sdk/forms.py | 7 ++++--- tests/test_forms.py | 46 +++++++++++++++++++++++++++++++++++++++-- 3 files changed, 49 insertions(+), 5 deletions(-) create mode 100644 changelog/344.bugfix.md diff --git a/changelog/344.bugfix.md b/changelog/344.bugfix.md new file mode 100644 index 000000000..8fd221485 --- /dev/null +++ b/changelog/344.bugfix.md @@ -0,0 +1 @@ +Fix a bug in the `FormValidationAction` which immediately deactivated the form. diff --git a/rasa_sdk/forms.py b/rasa_sdk/forms.py index ffdb2647c..764c9a625 100644 --- a/rasa_sdk/forms.py +++ b/rasa_sdk/forms.py @@ -681,6 +681,9 @@ def _get_entity_type_of_slot_to_fill( class FormValidationAction(Action, ABC): """A helper class for slot validations and extractions of custom slots.""" + def form_name(self) -> Text: + return self.name().replace("validate_", "", 1) + async def run( self, dispatcher: "CollectingDispatcher", @@ -812,9 +815,7 @@ def slots_mapped_in_domain(self, domain: "DomainDict") -> List[Text]: Returns: Slot names mapped in the domain. """ - return list( - domain.get("forms", {}).get(self.name().strip("validate_"), {}).keys() - ) + return list(domain.get("forms", {}).get(self.form_name(), {}).keys()) async def validate( self, diff --git a/tests/test_forms.py b/tests/test_forms.py index efb1574f4..b6878a601 100644 --- a/tests/test_forms.py +++ b/tests/test_forms.py @@ -1503,10 +1503,10 @@ def test_form_deprecation(): class TestFormValidationAction(FormValidationAction): def __init__(self, form_name: Text = "some_form") -> None: - self.form_name = form_name + self.name_of_form = form_name def name(self) -> Text: - return self.form_name + return self.name_of_form def validate_slot1( self, @@ -1643,6 +1643,48 @@ async def test_form_validation_without_validate_function(): ] +async def test_form_validation_changing_slots_during_validation(): + form_name = "some_form" + + class TestForm(FormValidationAction): + def name(self) -> Text: + return form_name + + async def validate_my_slot( + self, + slot_value: Any, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> Dict[Text, Any]: + return {"my_slot": None, "other_slot": "value"} + + form = TestForm() + + # tracker with active form + tracker = Tracker( + "default", + {}, + {}, + [SlotSet("my_slot", "correct_value")], + False, + None, + {"name": form_name, "is_interrupted": False, "rejected": False}, + "action_listen", + ) + + dispatcher = CollectingDispatcher() + events = await form.run( + dispatcher=dispatcher, + tracker=tracker, + domain={"forms": {form_name: {"my_slot": []}}}, + ) + assert events == [ + SlotSet("my_slot", None), + SlotSet("other_slot", "value"), + ] + + async def test_form_validation_dash_slot(): form_name = "some_form" From 9190bef5ae99db9351d259e91ebe31cdeae8f6ac Mon Sep 17 00:00:00 2001 From: Tobias Wochinger Date: Tue, 17 Nov 2020 20:54:54 +0100 Subject: [PATCH 4/4] add docstrings --- rasa_sdk/forms.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/rasa_sdk/forms.py b/rasa_sdk/forms.py index 764c9a625..64e3d4ecf 100644 --- a/rasa_sdk/forms.py +++ b/rasa_sdk/forms.py @@ -682,6 +682,7 @@ class FormValidationAction(Action, ABC): """A helper class for slot validations and extractions of custom slots.""" def form_name(self) -> Text: + """Returns the form's name.""" return self.name().replace("validate_", "", 1) async def run( @@ -690,6 +691,7 @@ async def run( tracker: "Tracker", domain: "DomainDict", ) -> List[EventType]: + """Runs the custom actions. Please the docstring of the parent class.""" extraction_events = await self.extract_custom_slots(dispatcher, tracker, domain) tracker.add_slots(extraction_events)