diff --git a/changelog/344.bugfix.md b/changelog/344.bugfix.md new file mode 100644 index 000000000..8fd221485 --- /dev/null +++ b/changelog/344.bugfix.md @@ -0,0 +1 @@ +Fix a bug in the `FormValidationAction` which immediately deactivated the form. diff --git a/rasa_sdk/forms.py b/rasa_sdk/forms.py index 445709f4d..64e3d4ecf 100644 --- a/rasa_sdk/forms.py +++ b/rasa_sdk/forms.py @@ -681,12 +681,17 @@ def _get_entity_type_of_slot_to_fill( class FormValidationAction(Action, ABC): """A helper class for slot validations and extractions of custom slots.""" + def form_name(self) -> Text: + """Returns the form's name.""" + return self.name().replace("validate_", "", 1) + async def run( self, dispatcher: "CollectingDispatcher", tracker: "Tracker", domain: "DomainDict", ) -> List[EventType]: + """Runs the custom actions. Please the docstring of the parent class.""" extraction_events = await self.extract_custom_slots(dispatcher, tracker, domain) tracker.add_slots(extraction_events) @@ -812,7 +817,7 @@ def slots_mapped_in_domain(self, domain: "DomainDict") -> List[Text]: Returns: Slot names mapped in the domain. """ - return list(domain.get("forms", {}).get(self.name(), {}).keys()) + return list(domain.get("forms", {}).get(self.form_name(), {}).keys()) async def validate( self, @@ -832,7 +837,7 @@ async def validate( """ slots: Dict[Text, Any] = tracker.slots_to_validate() - for slot_name, slot_value in slots.items(): + for slot_name, slot_value in list(slots.items()): method_name = f"validate_{slot_name.replace('-','_')}" validate_method = getattr(self, method_name, None) diff --git a/tests/test_forms.py b/tests/test_forms.py index efb1574f4..b6878a601 100644 --- a/tests/test_forms.py +++ b/tests/test_forms.py @@ -1503,10 +1503,10 @@ def test_form_deprecation(): class TestFormValidationAction(FormValidationAction): def __init__(self, form_name: Text = "some_form") -> None: - self.form_name = form_name + self.name_of_form = form_name def name(self) -> Text: - return self.form_name + return self.name_of_form def validate_slot1( self, @@ -1643,6 +1643,48 @@ async def test_form_validation_without_validate_function(): ] +async def test_form_validation_changing_slots_during_validation(): + form_name = "some_form" + + class TestForm(FormValidationAction): + def name(self) -> Text: + return form_name + + async def validate_my_slot( + self, + slot_value: Any, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> Dict[Text, Any]: + return {"my_slot": None, "other_slot": "value"} + + form = TestForm() + + # tracker with active form + tracker = Tracker( + "default", + {}, + {}, + [SlotSet("my_slot", "correct_value")], + False, + None, + {"name": form_name, "is_interrupted": False, "rejected": False}, + "action_listen", + ) + + dispatcher = CollectingDispatcher() + events = await form.run( + dispatcher=dispatcher, + tracker=tracker, + domain={"forms": {form_name: {"my_slot": []}}}, + ) + assert events == [ + SlotSet("my_slot", None), + SlotSet("other_slot", "value"), + ] + + async def test_form_validation_dash_slot(): form_name = "some_form"