Skip to content

Commit

Permalink
Merge pull request #345 from RasaHQ/slots_mapped_in_domain
Browse files Browse the repository at this point in the history
Slots mapped in domain
  • Loading branch information
wochinge authored Nov 17, 2020
2 parents 1807800 + 9190bef commit 08b902f
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 4 deletions.
1 change: 1 addition & 0 deletions changelog/344.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a bug in the `FormValidationAction` which immediately deactivated the form.
9 changes: 7 additions & 2 deletions rasa_sdk/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,12 +681,17 @@ def _get_entity_type_of_slot_to_fill(
class FormValidationAction(Action, ABC):
"""A helper class for slot validations and extractions of custom slots."""

def form_name(self) -> Text:
"""Returns the form's name."""
return self.name().replace("validate_", "", 1)

async def run(
self,
dispatcher: "CollectingDispatcher",
tracker: "Tracker",
domain: "DomainDict",
) -> List[EventType]:
"""Runs the custom actions. Please the docstring of the parent class."""
extraction_events = await self.extract_custom_slots(dispatcher, tracker, domain)
tracker.add_slots(extraction_events)

Expand Down Expand Up @@ -812,7 +817,7 @@ def slots_mapped_in_domain(self, domain: "DomainDict") -> List[Text]:
Returns:
Slot names mapped in the domain.
"""
return list(domain.get("forms", {}).get(self.name(), {}).keys())
return list(domain.get("forms", {}).get(self.form_name(), {}).keys())

async def validate(
self,
Expand All @@ -832,7 +837,7 @@ async def validate(
"""
slots: Dict[Text, Any] = tracker.slots_to_validate()

for slot_name, slot_value in slots.items():
for slot_name, slot_value in list(slots.items()):
method_name = f"validate_{slot_name.replace('-','_')}"
validate_method = getattr(self, method_name, None)

Expand Down
46 changes: 44 additions & 2 deletions tests/test_forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1503,10 +1503,10 @@ def test_form_deprecation():

class TestFormValidationAction(FormValidationAction):
def __init__(self, form_name: Text = "some_form") -> None:
self.form_name = form_name
self.name_of_form = form_name

def name(self) -> Text:
return self.form_name
return self.name_of_form

def validate_slot1(
self,
Expand Down Expand Up @@ -1643,6 +1643,48 @@ async def test_form_validation_without_validate_function():
]


async def test_form_validation_changing_slots_during_validation():
form_name = "some_form"

class TestForm(FormValidationAction):
def name(self) -> Text:
return form_name

async def validate_my_slot(
self,
slot_value: Any,
dispatcher: "CollectingDispatcher",
tracker: "Tracker",
domain: "DomainDict",
) -> Dict[Text, Any]:
return {"my_slot": None, "other_slot": "value"}

form = TestForm()

# tracker with active form
tracker = Tracker(
"default",
{},
{},
[SlotSet("my_slot", "correct_value")],
False,
None,
{"name": form_name, "is_interrupted": False, "rejected": False},
"action_listen",
)

dispatcher = CollectingDispatcher()
events = await form.run(
dispatcher=dispatcher,
tracker=tracker,
domain={"forms": {form_name: {"my_slot": []}}},
)
assert events == [
SlotSet("my_slot", None),
SlotSet("other_slot", "value"),
]


async def test_form_validation_dash_slot():
form_name = "some_form"

Expand Down

0 comments on commit 08b902f

Please sign in to comment.