Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Support for MSC3824 auth type actions #12924

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12924.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support for the new actions field on GET /login from [MSC3824](https://github.com/matrix-org/matrix-spec-proposals/pull/3824).
21 changes: 20 additions & 1 deletion synapse/rest/client/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class LoginRestServlet(RestServlet):
JWT_TYPE = "org.matrix.login.jwt"
APPSERVICE_TYPE = "m.login.application_service"
REFRESH_TOKEN_PARAM = "refresh_token"
ACTION_LOGIN = "login"
ACTION_REGISTER = "register"

def __init__(self, hs: "HomeServer"):
super().__init__()
Expand Down Expand Up @@ -114,6 +116,8 @@ def __init__(self, hs: "HomeServer"):
burst_count=self.hs.config.ratelimiting.rc_login_account.burst_count,
)

self._registration_enabled = hs.config.registration.enable_registration

# ensure the CAS/SAML/OIDC handlers are loaded on this worker instance.
# The reason for this is to ensure that the auth_provider_ids are registered
# with SsoHandler, which in turn ensures that the login/registration prometheus
Expand Down Expand Up @@ -152,7 +156,22 @@ def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:

flows.extend({"type": t} for t in self.auth_handler.get_supported_login_types())

flows.append({"type": LoginRestServlet.APPSERVICE_TYPE})
# You can only login with app-service
flows.append(
{
"type": LoginRestServlet.APPSERVICE_TYPE,
"actions": [LoginRestServlet.ACTION_LOGIN],
}
)

actions: List[str] = [LoginRestServlet.ACTION_LOGIN]
if self._registration_enabled:
actions.append(LoginRestServlet.ACTION_REGISTER)

# Set actions for all flows if not already specified
for flow in flows:
if "actions" not in flow:
flow["actions"] = actions

return 200, {"flows": flows}

Expand Down
37 changes: 30 additions & 7 deletions tests/handlers/test_password_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

# Login flows we expect to appear in the list after the normal ones.
ADDITIONAL_LOGIN_FLOWS = [
{"type": "m.login.application_service"},
{"type": "m.login.application_service", "actions": ["login"]},
]

# a mock instance which the dummy auth providers delegate to, so we can see what's going
Expand Down Expand Up @@ -183,7 +183,11 @@ def test_password_only_auth_progiver_login_legacy(self):
def password_only_auth_provider_login_test_body(self):
# login flows should only have m.login.password
flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
self.assertEqual(
flows,
[{"type": "m.login.password", "actions": ["login", "register"]}]
+ ADDITIONAL_LOGIN_FLOWS,
)

# check_password must return an awaitable
mock_password_provider.check_password.return_value = make_awaitable(True)
Expand Down Expand Up @@ -400,7 +404,10 @@ def custom_auth_provider_login_test_body(self):
flows = self._get_login_flows()
self.assertEqual(
flows,
[{"type": "m.login.password"}, {"type": "test.login_type"}]
[
{"type": "m.login.password", "actions": ["login", "register"]},
{"type": "test.login_type", "actions": ["login", "register"]},
]
+ ADDITIONAL_LOGIN_FLOWS,
)

Expand Down Expand Up @@ -545,7 +552,11 @@ def custom_auth_password_disabled_test_body(self):
self.register_user("localuser", "localpass")

flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
self.assertEqual(
flows,
[{"type": "test.login_type", "actions": ["login", "register"]}]
+ ADDITIONAL_LOGIN_FLOWS,
)

# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
Expand Down Expand Up @@ -580,7 +591,11 @@ def custom_auth_password_disabled_localdb_enabled_test_body(self):
self.register_user("localuser", "localpass")

flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
self.assertEqual(
flows,
[{"type": "test.login_type", "actions": ["login", "register"]}]
+ ADDITIONAL_LOGIN_FLOWS,
)

# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
Expand Down Expand Up @@ -611,7 +626,11 @@ def password_custom_auth_password_disabled_login_test_body(self):
self.register_user("localuser", "localpass")

flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
self.assertEqual(
flows,
[{"type": "test.login_type", "actions": ["login", "register"]}]
+ ADDITIONAL_LOGIN_FLOWS,
)

# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
Expand Down Expand Up @@ -716,7 +735,11 @@ def custom_auth_no_local_user_fallback_test_body(self):
self.register_user("localuser", "localpass")

flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
self.assertEqual(
flows,
[{"type": "test.login_type", "actions": ["login", "register"]}]
+ ADDITIONAL_LOGIN_FLOWS,
)

# password login shouldn't work and should be rejected with a 400
# ("unknown login type")
Expand Down
4 changes: 2 additions & 2 deletions tests/rest/client/test_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@

# Login flows we expect to appear in the list after the normal ones.
ADDITIONAL_LOGIN_FLOWS = [
{"type": "m.login.application_service"},
{"type": "m.login.application_service", "actions": ["login"]},
]


Expand Down Expand Up @@ -473,7 +473,7 @@ def test_get_login_flows(self) -> None:
"m.login.sso",
"m.login.token",
"m.login.password",
] + [f["type"] for f in ADDITIONAL_LOGIN_FLOWS]
] + [str(f["type"]) for f in ADDITIONAL_LOGIN_FLOWS]

self.assertCountEqual(
[f["type"] for f in channel.json_body["flows"]], expected_flow_types
Expand Down
26 changes: 26 additions & 0 deletions tests/rest/client/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@
from tests import unittest
from tests.unittest import override_config

# Login flows we expect to appear in the list after the normal ones.
ADDITIONAL_LOGIN_FLOWS = [
{"type": "m.login.application_service", "actions": ["login"]},
]


class RegisterRestServletTestCase(unittest.HomeserverTestCase):

Expand All @@ -50,6 +55,11 @@ def default_config(self) -> Dict[str, Any]:
config["allow_guest_access"] = True
return config

def _get_login_flows(self) -> JsonDict:
channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)
return channel.json_body["flows"]

def test_POST_appservice_registration_valid(self) -> None:
user_id = "@as_user_kermit:test"
as_token = "i_am_an_app_service"
Expand Down Expand Up @@ -121,6 +131,14 @@ def test_POST_bad_username(self) -> None:
self.assertEqual(channel.json_body["error"], "Invalid username")

def test_POST_user_valid(self) -> None:
# login flows should have actions "login" and "register"
flows = self._get_login_flows()
self.assertEqual(
flows,
[{"type": "m.login.password", "actions": ["login", "register"]}]
+ ADDITIONAL_LOGIN_FLOWS,
)

user_id = "@kermit:test"
device_id = "frogfone"
params = {
Expand All @@ -142,6 +160,14 @@ def test_POST_user_valid(self) -> None:

@override_config({"enable_registration": False})
def test_POST_disabled_registration(self) -> None:
# login flows should only have action "login"
flows = self._get_login_flows()
self.assertEqual(
flows,
[{"type": "m.login.password", "actions": ["login"]}]
+ ADDITIONAL_LOGIN_FLOWS,
)

request_data = json.dumps({"username": "kermit", "password": "monkey"})
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)

Expand Down