diff --git a/README.rst b/README.rst index 73677d7..d444c26 100644 --- a/README.rst +++ b/README.rst @@ -18,7 +18,7 @@ Setup a WireGuard server:: server = Server('myvpnserver.com', '192.168.24.0/24', address='192.168.24.1') # Write out the server config to the default location: /etc/wireguard/wg0.conf - server.config().write() + server.config.write() Create a client within the previously created server:: @@ -26,10 +26,10 @@ Create a client within the previously created server:: peer = server.peer('my-client') # Output this peer's config for copying to the peer device - print(peer.config().local_config) + print(peer.config.local_config) # Rewrite the server config file including the newly created peer - server.config().write() + server.config.write() Create a standalone client:: @@ -39,7 +39,7 @@ Create a standalone client:: peer = Peer('my-client', '192.168.24.0/24', address='192.168.24.45') # Write out the peer config to the default location: /etc/wireguard/wg0.conf - peer.config().write() + peer.config.write() **Note**: Both the server and peer config files are named the same by default. This is because diff --git a/tests/test_config.py b/tests/test_config.py index a8b0189..e282ca7 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -6,8 +6,6 @@ patch, ) -from subnet import ip_network, IPv4Network, IPv4Address - from wireguard import ( Config, ServerConfig, @@ -112,7 +110,7 @@ def test_write_server_config_no_params(): ) with patch('builtins.open', mock_open()) as mo: - server.config().write() + server.config.write() mo.assert_has_calls([ call('/etc/wireguard/wg0.conf', mode='w', encoding='utf-8'), @@ -139,7 +137,7 @@ def test_write_server_config(interface, path, full_path, peers_full_path): interface=interface ) - config = server.config() + config = server.config assert config.full_path(path) == full_path assert config.peers_full_path(path) == peers_full_path @@ -162,7 +160,7 @@ def test_write_peer_config_no_params(): ) with patch('builtins.open', mock_open()) as mo: - peer.config().write() + peer.config.write() mo.assert_has_calls([ call('/etc/wireguard/wg0.conf', mode='w', encoding='utf-8'), @@ -191,7 +189,7 @@ def test_write_peer_config(interface, path, full_path): assert config.full_path(path) == full_path with patch('builtins.open', mock_open()) as mo: - peer.config().write(path) + peer.config.write(path) mo.assert_has_calls([ call(full_path, mode='w', encoding='utf-8'), diff --git a/tests/test_config_attributes.py b/tests/test_config_attributes.py index aba08e1..be6931b 100644 --- a/tests/test_config_attributes.py +++ b/tests/test_config_attributes.py @@ -1,20 +1,10 @@ import pytest -from unittest.mock import ( - call, - mock_open, - patch, -) - -from subnet import ip_network, IPv4Network, IPv4Address from wireguard import ( Config, - ServerConfig, Peer, - Server, ) -from wireguard.utils import IPAddressSet def test_description(): @@ -285,7 +275,7 @@ def test_comments(): comments=comments, ) - config = peer.config() + config = peer.config for comment in comments: assert f'# {comment}' in config.local_config diff --git a/tests/test_customizations.py b/tests/test_customizations.py new file mode 100644 index 0000000..38acf4d --- /dev/null +++ b/tests/test_customizations.py @@ -0,0 +1,89 @@ + +import pytest + +from subnet import ( + ip_address, + IPv4Address, +) + +from wireguard import ( + Config, + Interface, + Peer, +) + + +class MyCustomInterface(Interface): + pass + + +class MyCustomConfig(Config): + pass + + +def test_peer_custom_config_cls(): + address = '192.168.0.2' + dns = '1.1.1.1' + + peer = Peer( + 'test-peer', + address=address, + dns=ip_address(dns), + config_cls=MyCustomConfig, + ) + + assert isinstance(peer.config, MyCustomConfig) + + +@pytest.mark.parametrize( + ('cls',), + [ + (MyCustomInterface,), + (IPv4Address,), + ], +) +def test_peer_invalid_custom_config_cls(cls): + address = '192.168.0.2' + dns = '1.1.1.1' + + with pytest.raises(ValueError): + peer = Peer( + 'test-peer', + address=address, + dns=ip_address(dns), + config_cls=cls, + ) + + +def test_peer_custom_service_cls(): + address = '192.168.0.2' + dns = '1.1.1.1' + + peer = Peer( + 'test-peer', + address=address, + dns=ip_address(dns), + service_cls=MyCustomInterface, + ) + + assert isinstance(peer.service, MyCustomInterface) + + +@pytest.mark.parametrize( + ('cls',), + [ + (MyCustomConfig,), + (IPv4Address,), + ], +) +def test_peer_invalid_custom_service_cls(cls): + address = '192.168.0.2' + dns = '1.1.1.1' + + with pytest.raises(ValueError): + peer = Peer( + 'test-peer', + address=address, + dns=ip_address(dns), + service_cls=cls, + ) diff --git a/tests/test_json.py b/tests/test_json.py index ca2b6b3..9f0effd 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -1,23 +1,11 @@ -import functools import json import pytest -from subnet import ( - ip_network, - IPv4Network, - IPv4Address, -) - from wireguard import ( - INTERFACE, - PORT, - Config, - ServerConfig, Peer, Server, ) -from wireguard.utils import generate_key, public_key def test_server_json_dump_ipv4(): diff --git a/tests/test_peers.py b/tests/test_peers.py index 55af1bb..554a554 100644 --- a/tests/test_peers.py +++ b/tests/test_peers.py @@ -11,9 +11,7 @@ INTERFACE, PORT, Config, - ServerConfig, Peer, - Server, ) from wireguard.utils import public_key @@ -69,7 +67,7 @@ def test_basic_peer(ipv4_address, ipv6_address): assert not peer.keepalive assert not peer.preshared_key - config = peer.config() + config = peer.config assert isinstance(config, Config) wg_config = config.local_config @@ -139,7 +137,7 @@ def test_peer_mtu(mtu): assert not peer.keepalive assert not peer.preshared_key - config = peer.config() + config = peer.config config_lines = config.local_config.split('\n') assert f'MTU = {mtu}' in config_lines @@ -210,7 +208,7 @@ def test_peer_dns(): assert not peer.keepalive assert not peer.preshared_key - config = peer.config() + config = peer.config config_lines = config.local_config.split('\n') assert f'DNS = {dns}' in config_lines diff --git a/tests/test_qrcode.py b/tests/test_qrcode.py index 76e2869..d64b100 100644 --- a/tests/test_qrcode.py +++ b/tests/test_qrcode.py @@ -1,20 +1,9 @@ import pytest -from unittest.mock import ( - call, - mock_open, - patch, -) - -from subnet import ip_network, IPv4Network, IPv4Address from wireguard import ( - Config, - ServerConfig, Peer, - Server, ) -from wireguard.utils import IPAddressSet def test_peer_qrcode(): @@ -29,7 +18,7 @@ def test_peer_qrcode(): address=address, ) - assert peer.config().qrcode + assert peer.config.qrcode def test_peer_qrcode_not_present(): @@ -49,6 +38,6 @@ def test_peer_qrcode_not_present(): # If qrcode is not present in the venv, test it fails appropriately. with pytest.raises(AttributeError) as exc: - peer.config().qrcode + peer.config.qrcode assert 'add the qrcode' in str(exc.value) diff --git a/tests/test_server.py b/tests/test_server.py index 277857c..16b5693 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,5 +1,4 @@ -import functools import pytest from subnet import ( @@ -54,7 +53,7 @@ def test_basic_server(): assert not server.preshared_key assert not server.keepalive - config = server.config() + config = server.config assert isinstance(config, ServerConfig) config_lines = config.local_config.split('\n') @@ -101,8 +100,8 @@ def test_server_with_a_peer(): assert peer not in peer.peers assert peer in server.peers - server_config = server.config() - peer_config = peer.config() + server_config = server.config + peer_config = peer.config assert isinstance(server_config, ServerConfig) assert isinstance(peer_config, Config) @@ -142,7 +141,7 @@ def test_server_nat_traversal(): for line in server.post_down: assert 'eth1' in line - config = server.config().local_config + config = server.config.local_config assert 'PostUp' in config assert 'PostDown' in config assert 'iptables' in config @@ -165,8 +164,8 @@ def test_dns_in_server_and_peer(): 'test-peer', ) - server_config = server.config() - peer_config = peer.config() + server_config = server.config + peer_config = peer.config assert isinstance(server_config, ServerConfig) assert isinstance(peer_config, Config) @@ -276,8 +275,8 @@ def test_server_preshared_key(psk): assert server.preshared_key == psk assert peer.preshared_key == psk - server_config = server.config() - peer_config = peer.config() + server_config = server.config + peer_config = peer.config assert isinstance(server_config, ServerConfig) assert isinstance(peer_config, Config) @@ -314,9 +313,9 @@ def test_server_preshared_key_single_peer(): assert peer.preshared_key == psk assert no_psk_peer.preshared_key is None - server_config = server.config() - peer_config = peer.config() - no_psk_peer_config = no_psk_peer.config() + server_config = server.config + peer_config = peer.config + no_psk_peer_config = no_psk_peer.config assert isinstance(server_config, ServerConfig) assert isinstance(peer_config, Config) assert isinstance(no_psk_peer_config, Config) @@ -350,12 +349,12 @@ def test_server_mismatched_preshared_key(): ) with pytest.raises(ValueError) as exc: - server_config = server.config().local_config + server_config = server.config.local_config assert 'keys do not match' in str(exc.value) with pytest.raises(ValueError) as exc: - peer_config = peer.config().local_config + peer_config = peer.config.local_config assert 'keys do not match' in str(exc.value) @@ -381,8 +380,8 @@ def test_server_keepalive(keepalive): assert server.keepalive == keepalive assert peer.keepalive == keepalive - server_config = server.config() - peer_config = peer.config() + server_config = server.config + peer_config = peer.config assert isinstance(server_config, ServerConfig) assert isinstance(peer_config, Config) @@ -414,8 +413,8 @@ def test_server_keepalive_single_peer(): assert server.keepalive is None assert peer.keepalive == keepalive - server_config = server.config() - peer_config = peer.config() + server_config = server.config + peer_config = peer.config assert isinstance(server_config, ServerConfig) assert isinstance(peer_config, Config) @@ -449,8 +448,8 @@ def test_server_mismatched_keepalive(): assert server.keepalive == server_keepalive assert peer.keepalive == peer_keepalive - server_config = server.config() - peer_config = peer.config() + server_config = server.config + peer_config = peer.config assert isinstance(server_config, ServerConfig) assert isinstance(peer_config, Config) @@ -482,8 +481,8 @@ def test_server_mtu(): assert server.mtu == mtu assert peer.mtu == mtu - server_config = server.config() - peer_config = peer.config() + server_config = server.config + peer_config = peer.config assert isinstance(server_config, ServerConfig) assert isinstance(peer_config, Config) @@ -543,8 +542,8 @@ def test_server_table(table): assert not peer.table - server_config = server.config() - peer_config = peer.config() + server_config = server.config + peer_config = peer.config assert isinstance(server_config, ServerConfig) assert isinstance(peer_config, Config) diff --git a/wireguard/__init__.py b/wireguard/__init__.py index b7e6592..04a5823 100644 --- a/wireguard/__init__.py +++ b/wireguard/__init__.py @@ -15,3 +15,6 @@ from .server import ( Server, ) +from .service import ( + Interface, +) diff --git a/wireguard/cli/config.py b/wireguard/cli/config.py index cac0305..2d76ce9 100644 --- a/wireguard/cli/config.py +++ b/wireguard/cli/config.py @@ -56,18 +56,18 @@ def server(endpoint, if nat_traversal_interface: obj.add_nat_traversal(nat_traversal_interface) - click.echo(obj.config()) + click.echo(obj.config) if write: # pylint: disable=no-member - if os.path.isfile(obj.config().full_path): - if not click.prompt(f'{obj.config().full_path} exists! Overwrite? [y/N]'): + if os.path.isfile(obj.config.full_path): + if not click.prompt(f'{obj.config.full_path} exists! Overwrite? [y/N]'): raise click.Abort() - if os.path.isfile(obj.config().peers_full_path): - if not click.prompt(f'{obj.config().peers_full_path} exists! Overwrite? [y/N]'): + if os.path.isfile(obj.config.peers_full_path): + if not click.prompt(f'{obj.config.peers_full_path} exists! Overwrite? [y/N]'): raise click.Abort() - obj.config().write() + obj.config.write() @cli.command() @@ -123,14 +123,14 @@ def peer(name, interface=interface, ) - click.echo(obj.config()) + click.echo(obj.config) if write: - if os.path.isfile(obj.config().full_path): - if not click.prompt(f'{obj.config().full_path} exists! Overwrite? [y/N]'): + if os.path.isfile(obj.config.full_path): + if not click.prompt(f'{obj.config.full_path} exists! Overwrite? [y/N]'): raise click.Abort() - obj.config().write() + obj.config.write() if __name__ == "__main__": diff --git a/wireguard/config.py b/wireguard/config.py index 897449d..ae82fc2 100644 --- a/wireguard/config.py +++ b/wireguard/config.py @@ -271,22 +271,21 @@ def peers(self): peers_data = '' for peer in self._peer.peers: - peer_config = peer.config() - peers_data += peer_config.remote_config + peers_data += peer.config.remote_config extras = [] # Need to take special measures when the preshared keys aren't identical # And there is no need for an `else` clause, as the value would already have # been included by the `remote_config` returned data for normal cases - if self.preshared_key != peer_config.preshared_key: + if self.preshared_key != peer.config.preshared_key: # When only the remote peer has a key set, we need to use it too if self.preshared_key is None: - extras.append(peer_config.preshared_key) + extras.append(peer.config.preshared_key) # When only this peer has a key set, the remote peer needs to use it too - elif peer_config.preshared_key is None: + elif peer.config.preshared_key is None: extras.append(self.preshared_key) # The keys have both been set, but are not a match. diff --git a/wireguard/peer.py b/wireguard/peer.py index ae030ca..c0b7528 100644 --- a/wireguard/peer.py +++ b/wireguard/peer.py @@ -8,6 +8,13 @@ IPv6Address, ) +from .config import Config +from .constants import ( + INTERFACE, + KEEPALIVE_MINIMUM, + PORT, +) +from .service import Interface from .utils import ( generate_key, find_ip_and_subnet, @@ -17,12 +24,6 @@ IPNetworkSet, JSONEncoder, ) -from .config import Config -from .constants import ( - INTERFACE, - KEEPALIVE_MINIMUM, - PORT, -) class PeerSet(ClassedSet): @@ -164,8 +165,12 @@ class Peer: # pylint: disable=too-many-instance-attributes _table = None _config = None + _service = None peers = None + _config_cls = None + _service_cls = None + # pylint: disable=too-many-locals,too-many-branches,too-many-statements,too-many-arguments def __init__(self, description, @@ -186,10 +191,11 @@ def __init__(self, pre_down=None, post_down=None, interface=None, - peers=None, - config_cls=None, mtu=None, table=None, + peers=None, + config_cls=None, + service_cls=None, ): self.allowed_ips = IPNetworkSet() @@ -300,7 +306,8 @@ def __init__(self, else: self.peers.add(peers) - self.config(config_cls) + self.config_cls = config_cls + self.service_cls = service_cls def __repr__(self): """ @@ -661,23 +668,72 @@ def table(self, value): self._table = value - def config(self, config_cls=None): + @property + def config_cls(self): """ - Return the wireguard config file for this peer + Returns the config_cls value + """ + + if not self._config_cls: + self._config_cls = Config + + return self._config_cls + + @config_cls.setter + def config_cls(self, value): """ + Sets the config_cls value + """ + + if value is not None and not issubclass(value, Config): + raise ValueError('Provided value must be a subclass of Config') + + self._config_cls = value + + @property + def service_cls(self): + """ + Returns the service_cls value + """ + + if not self._service_cls: + self._service_cls = Interface + + return self._service_cls - if config_cls in [None, False]: - config_cls = Config + @service_cls.setter + def service_cls(self, value): + """ + Sets the service_cls value + """ - if self._config is not None and isinstance(self._config, config_cls): - return self._config + if value is not None and not issubclass(value, Interface): + raise ValueError('Provided value must be a subclass of Interface') - if not callable(config_cls): - raise ValueError('Invalid value given for config_cls') + self._service_cls = value + + @property + def config(self): + """ + Return the wireguard config file for this peer + """ + + if not isinstance(self._config, self.config_cls.__class__): + self._config = self.config_cls(self) - self._config = config_cls(self) return self._config + @property + def service(self): + """ + Returns the service interface for this peer + """ + + if not isinstance(self._service, self.service_cls.__class__): + self._service = self.service_cls(self.interface) + + return self._service + def add_nat_traversal(self, outbound_interface): """ Adds appropriate PostUp/PostDown rules when this peer is acting as diff --git a/wireguard/server.py b/wireguard/server.py index fa74a23..4a1f7fc 100644 --- a/wireguard/server.py +++ b/wireguard/server.py @@ -11,7 +11,6 @@ ) from .config import ServerConfig from .peer import Peer -from .service import Interface from .utils import generate_key, public_key, find_ip_and_subnet @@ -126,13 +125,6 @@ def __iter__(self): yield from {'subnet': subnets}.items() yield from super().__iter__() - @property - def service(self): - """ - Returns the service interface for this server - """ - return Interface(self.interface) - def pubkey_exists(self, item): """ Checks a public key against the public keys already used by this server and it's peers