From 401beedff2ecaf128e99fe1c5ee5114f301a882a Mon Sep 17 00:00:00 2001 From: Simon Bihel Date: Mon, 29 Apr 2024 17:35:22 +0100 Subject: [PATCH] Split out parsing method from initializer (#61) --------- Co-authored-by: LeaveMyYard Co-authored-by: Pavel Zhukov <33721692+LeaveMyYard@users.noreply.github.com> --- README.md | 16 ++++++++-------- siwe/siwe.py | 23 ++++++++--------------- tests/test_siwe.py | 18 +++++++++--------- 3 files changed, 25 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index e80de0f..43348ef 100644 --- a/README.md +++ b/README.md @@ -18,22 +18,22 @@ SIWE provides a `SiweMessage` class which implements EIP-4361. Parsing is done by initializing a `SiweMessage` object with an EIP-4361 formatted string: -``` python +```python from siwe import SiweMessage -message: SiweMessage = SiweMessage(message=eip_4361_string) +message = SiweMessage.from_message(message=eip_4361_string) ``` -Alternatively, initialization of a `SiweMessage` object can be done with a dictionary containing expected attributes: +Or to initialize a `SiweMessage` as a `pydantic.BaseModel` right away: -``` python -message: SiweMessage = SiweMessage(message={"domain": "login.xyz", "address": "0x1234...", ...}) +```python +message = SiweMessage(domain="login.xyz", address="0x1234...", ...) ``` ### Verifying and Authenticating a SIWE Message Verification and authentication is performed via EIP-191, using the `address` field of the `SiweMessage` as the expected signer. The validate method checks message structural integrity, signature address validity, and time-based validity attributes. -``` python +```python try: message.verify(signature="0x...") # You can also specify other checks (e.g. the nonce or domain expected). @@ -45,7 +45,7 @@ except siwe.ValidationError: `SiweMessage` instances can also be serialized as their EIP-4361 string representations via the `prepare_message` method: -``` python +```python print(message.prepare_message()) ``` @@ -53,7 +53,7 @@ print(message.prepare_message()) Parsing and verifying a `SiweMessage` is easy: -``` python +```python try: message: SiweMessage = SiweMessage(message=eip_4361_string) message.verify(signature, nonce="abcdef", domain="example.com"): diff --git a/siwe/siwe.py b/siwe/siwe.py index 2442016..c4afab5 100644 --- a/siwe/siwe.py +++ b/siwe/siwe.py @@ -4,7 +4,7 @@ import string from datetime import datetime, timezone from enum import Enum -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Iterable, List, Optional import eth_utils from eth_account.messages import SignableMessage, _hash_eip191_message, encode_defunct @@ -216,23 +216,16 @@ def address_is_checksum_address(cls, v: str) -> str: raise ValueError("Message `address` must be in EIP-55 format") return v - def __init__(self, message: Union[str, Dict[str, Any]], abnf: bool = True): - """Construct or parse a message.""" - if isinstance(message, str): - if abnf: - parsed_message = ABNFParsedMessage(message=message) - else: - parsed_message = RegExpParsedMessage(message=message) - message_dict = parsed_message.__dict__ - - elif isinstance(message, dict): - message_dict = message - + @classmethod + def from_message(cls, message: str, abnf: bool = True) -> "SiweMessage": + """Parse a message in its EIP-4361 format.""" + if abnf: + parsed_message = ABNFParsedMessage(message=message) else: - raise TypeError(f"Unhandable message type: '{type(message)}'.") + parsed_message = RegExpParsedMessage(message=message) # TODO There is some redundancy in the checks when deserialising a message. - super().__init__(**message_dict) + return cls(**parsed_message.__dict__) def prepare_message(self) -> str: """Serialize to the EIP-4361 format for signing. diff --git a/tests/test_siwe.py b/tests/test_siwe.py index edb9d08..82efe54 100644 --- a/tests/test_siwe.py +++ b/tests/test_siwe.py @@ -30,7 +30,7 @@ class TestMessageParsing: [(test_name, test) for test_name, test in parsing_positive.items()], ) def test_valid_message(self, abnf, test_name, test): - siwe_message = SiweMessage(message=test["message"], abnf=abnf) + siwe_message = SiweMessage.from_message(message=test["message"], abnf=abnf) for key, value in test["fields"].items(): v = getattr(siwe_message, key) if not (isinstance(v, int) or isinstance(v, list) or v is None): @@ -44,7 +44,7 @@ def test_valid_message(self, abnf, test_name, test): ) def test_invalid_message(self, abnf, test_name, test): with pytest.raises(ValueError): - SiweMessage(message=test, abnf=abnf) + SiweMessage.from_message(message=test, abnf=abnf) @pytest.mark.parametrize( "test_name,test", @@ -52,7 +52,7 @@ def test_invalid_message(self, abnf, test_name, test): ) def test_invalid_object_message(self, test_name, test): with pytest.raises(ValidationError): - SiweMessage(message=test) + SiweMessage(**test) class TestMessageGeneration: @@ -61,7 +61,7 @@ class TestMessageGeneration: [(test_name, test) for test_name, test in parsing_positive.items()], ) def test_valid_message(self, test_name, test): - siwe_message = SiweMessage(message=test["fields"]) + siwe_message = SiweMessage(**test["fields"]) assert siwe_message.prepare_message() == test["message"] @@ -71,7 +71,7 @@ class TestMessageVerification: [(test_name, test) for test_name, test in verification_positive.items()], ) def test_valid_message(self, test_name, test): - siwe_message = SiweMessage(message=test) + siwe_message = SiweMessage(**test) timestamp = datetime_from_iso8601_string(test["time"]) if "time" in test else None siwe_message.verify(test["signature"], timestamp=timestamp) @@ -81,7 +81,7 @@ def test_valid_message(self, test_name, test): ) def test_eip1271_message(self, test_name, test): provider = HTTPProvider(endpoint_uri="https://cloudflare-eth.com") - siwe_message = SiweMessage(message=test["message"]) + siwe_message = SiweMessage.from_message(message=test["message"]) siwe_message.verify(test["signature"], provider=provider) @pytest.mark.parametrize( @@ -98,9 +98,9 @@ def test_invalid_message(self, provider, test_name, test): "invalidissued_at", ]: with pytest.raises(ValidationError): - siwe_message = SiweMessage(message=test) + siwe_message = SiweMessage(**test) return - siwe_message = SiweMessage(message=test) + siwe_message = SiweMessage(**test) domain_binding = test.get("domain_binding") match_nonce = test.get("match_nonce") timestamp = datetime_from_iso8601_string(test["time"]) if "time" in test else None @@ -122,7 +122,7 @@ class TestMessageRoundTrip: [(test_name, test) for test_name, test in parsing_positive.items()], ) def test_message_round_trip(self, test_name, test): - message = SiweMessage(test["fields"]) + message = SiweMessage(**test["fields"]) message.address = self.account.address signature = self.account.sign_message( messages.encode_defunct(text=message.prepare_message())