diff --git a/changelog.d/12924.feature b/changelog.d/12924.feature new file mode 100644 index 000000000000..c3a6838a2b5f --- /dev/null +++ b/changelog.d/12924.feature @@ -0,0 +1 @@ +Support for the new actions field on GET /login from [MSC3824](https://github.com/matrix-org/matrix-spec-proposals/pull/3824). \ No newline at end of file diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index cf4196ac0a2b..bf2a186546c4 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -71,6 +71,8 @@ class LoginRestServlet(RestServlet): JWT_TYPE = "org.matrix.login.jwt" APPSERVICE_TYPE = "m.login.application_service" REFRESH_TOKEN_PARAM = "refresh_token" + ACTION_LOGIN = "login" + ACTION_REGISTER = "register" def __init__(self, hs: "HomeServer"): super().__init__() @@ -114,6 +116,8 @@ def __init__(self, hs: "HomeServer"): burst_count=self.hs.config.ratelimiting.rc_login_account.burst_count, ) + self._registration_enabled = hs.config.registration.enable_registration + # ensure the CAS/SAML/OIDC handlers are loaded on this worker instance. # The reason for this is to ensure that the auth_provider_ids are registered # with SsoHandler, which in turn ensures that the login/registration prometheus @@ -152,7 +156,22 @@ def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: flows.extend({"type": t} for t in self.auth_handler.get_supported_login_types()) - flows.append({"type": LoginRestServlet.APPSERVICE_TYPE}) + # You can only login with app-service + flows.append( + { + "type": LoginRestServlet.APPSERVICE_TYPE, + "actions": [LoginRestServlet.ACTION_LOGIN], + } + ) + + actions: List[str] = [LoginRestServlet.ACTION_LOGIN] + if self._registration_enabled: + actions.append(LoginRestServlet.ACTION_REGISTER) + + # Set actions for all flows if not already specified + for flow in flows: + if "actions" not in flow: + flow["actions"] = actions return 200, {"flows": flows} diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 82b3bb3b735d..9ee88e65a859 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -32,7 +32,7 @@ # Login flows we expect to appear in the list after the normal ones. ADDITIONAL_LOGIN_FLOWS = [ - {"type": "m.login.application_service"}, + {"type": "m.login.application_service", "actions": ["login"]}, ] # a mock instance which the dummy auth providers delegate to, so we can see what's going @@ -183,7 +183,11 @@ def test_password_only_auth_progiver_login_legacy(self): def password_only_auth_provider_login_test_body(self): # login flows should only have m.login.password flows = self._get_login_flows() - self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS) + self.assertEqual( + flows, + [{"type": "m.login.password", "actions": ["login", "register"]}] + + ADDITIONAL_LOGIN_FLOWS, + ) # check_password must return an awaitable mock_password_provider.check_password.return_value = make_awaitable(True) @@ -400,7 +404,10 @@ def custom_auth_provider_login_test_body(self): flows = self._get_login_flows() self.assertEqual( flows, - [{"type": "m.login.password"}, {"type": "test.login_type"}] + [ + {"type": "m.login.password", "actions": ["login", "register"]}, + {"type": "test.login_type", "actions": ["login", "register"]}, + ] + ADDITIONAL_LOGIN_FLOWS, ) @@ -545,7 +552,11 @@ def custom_auth_password_disabled_test_body(self): self.register_user("localuser", "localpass") flows = self._get_login_flows() - self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS) + self.assertEqual( + flows, + [{"type": "test.login_type", "actions": ["login", "register"]}] + + ADDITIONAL_LOGIN_FLOWS, + ) # login shouldn't work and should be rejected with a 400 ("unknown login type") channel = self._send_password_login("localuser", "localpass") @@ -580,7 +591,11 @@ def custom_auth_password_disabled_localdb_enabled_test_body(self): self.register_user("localuser", "localpass") flows = self._get_login_flows() - self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS) + self.assertEqual( + flows, + [{"type": "test.login_type", "actions": ["login", "register"]}] + + ADDITIONAL_LOGIN_FLOWS, + ) # login shouldn't work and should be rejected with a 400 ("unknown login type") channel = self._send_password_login("localuser", "localpass") @@ -611,7 +626,11 @@ def password_custom_auth_password_disabled_login_test_body(self): self.register_user("localuser", "localpass") flows = self._get_login_flows() - self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS) + self.assertEqual( + flows, + [{"type": "test.login_type", "actions": ["login", "register"]}] + + ADDITIONAL_LOGIN_FLOWS, + ) # login shouldn't work and should be rejected with a 400 ("unknown login type") channel = self._send_password_login("localuser", "localpass") @@ -716,7 +735,11 @@ def custom_auth_no_local_user_fallback_test_body(self): self.register_user("localuser", "localpass") flows = self._get_login_flows() - self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS) + self.assertEqual( + flows, + [{"type": "test.login_type", "actions": ["login", "register"]}] + + ADDITIONAL_LOGIN_FLOWS, + ) # password login shouldn't work and should be rejected with a 400 # ("unknown login type") diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 4920468f7ab8..879a2ceae0a0 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -83,7 +83,7 @@ # Login flows we expect to appear in the list after the normal ones. ADDITIONAL_LOGIN_FLOWS = [ - {"type": "m.login.application_service"}, + {"type": "m.login.application_service", "actions": ["login"]}, ] @@ -473,7 +473,7 @@ def test_get_login_flows(self) -> None: "m.login.sso", "m.login.token", "m.login.password", - ] + [f["type"] for f in ADDITIONAL_LOGIN_FLOWS] + ] + [str(f["type"]) for f in ADDITIONAL_LOGIN_FLOWS] self.assertCountEqual( [f["type"] for f in channel.json_body["flows"]], expected_flow_types diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 9aebf1735a96..144ec7c6dfc1 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -35,6 +35,11 @@ from tests import unittest from tests.unittest import override_config +# Login flows we expect to appear in the list after the normal ones. +ADDITIONAL_LOGIN_FLOWS = [ + {"type": "m.login.application_service", "actions": ["login"]}, +] + class RegisterRestServletTestCase(unittest.HomeserverTestCase): @@ -50,6 +55,11 @@ def default_config(self) -> Dict[str, Any]: config["allow_guest_access"] = True return config + def _get_login_flows(self) -> JsonDict: + channel = self.make_request("GET", "/_matrix/client/r0/login") + self.assertEqual(channel.code, 200, channel.result) + return channel.json_body["flows"] + def test_POST_appservice_registration_valid(self) -> None: user_id = "@as_user_kermit:test" as_token = "i_am_an_app_service" @@ -121,6 +131,14 @@ def test_POST_bad_username(self) -> None: self.assertEqual(channel.json_body["error"], "Invalid username") def test_POST_user_valid(self) -> None: + # login flows should have actions "login" and "register" + flows = self._get_login_flows() + self.assertEqual( + flows, + [{"type": "m.login.password", "actions": ["login", "register"]}] + + ADDITIONAL_LOGIN_FLOWS, + ) + user_id = "@kermit:test" device_id = "frogfone" params = { @@ -142,6 +160,14 @@ def test_POST_user_valid(self) -> None: @override_config({"enable_registration": False}) def test_POST_disabled_registration(self) -> None: + # login flows should only have action "login" + flows = self._get_login_flows() + self.assertEqual( + flows, + [{"type": "m.login.password", "actions": ["login"]}] + + ADDITIONAL_LOGIN_FLOWS, + ) + request_data = json.dumps({"username": "kermit", "password": "monkey"}) self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)