Skip to content

Commit

Permalink
Merge pull request #16 from fictivekin/11-peer-removal
Browse files Browse the repository at this point in the history
Simplified peer removal
  • Loading branch information
jnhmcknight authored Nov 6, 2023
2 parents 2da0b4c + ce90f74 commit 6607ec4
Show file tree
Hide file tree
Showing 2 changed files with 257 additions and 0 deletions.
153 changes: 153 additions & 0 deletions tests/test_peer_sets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@


import pytest

from wireguard import (
Peer,
Server,
)
from wireguard.peer import (
PeerSet,
)


def test_peer_set_removes():

server = Server(
'server1',
subnet='192.168.0.1/24',
)

peer1 = server.peer('peer1')
peer2 = server.peer('peer2')
peer3 = server.peer('peer3')
peer4 = server.peer('peer4')
peer5 = server.peer('peer5')

with pytest.raises(KeyError):
server.peers.remove_by_description('peer6')

with pytest.raises(KeyError):
server.peers.remove_by_ip('10.10.10.10')

with pytest.raises(KeyError):
server.peers.remove_by_private_key('wBdB54t1rUBQ3mc0OvKdhzzaD9MvKGrshLQyHw5CN1A=')

with pytest.raises(KeyError):
server.peers.remove_by_public_key('m5Tp7TvZOQYUnfmxsRN9TsmfEi5jssWpyjs5X6OP9k8=')

server.peers.remove(peer3)
for peer in server.peers:
assert peer.description != 'peer3'
assert len(server.peers) == 4

server.peers.remove_by_description('peer2')
for peer in server.peers:
assert peer.description != 'peer2'
assert len(server.peers) == 3

server.peers.remove_by_ip(peer1.ipv4)
for peer in server.peers:
assert peer.description != 'peer1'
assert peer.ipv4 != peer1.ipv4
assert len(server.peers) == 2

server.peers.remove_by_private_key(peer4.private_key)
for peer in server.peers:
assert peer.description != 'peer4'
assert peer.private_key != peer4.private_key
assert len(server.peers) == 1

server.peers.remove_by_public_key(peer5.public_key)
for peer in server.peers:
assert peer.description != 'peer5'
assert peer.public_key != peer5.public_key
assert len(server.peers) == 0


def test_peer_set_discards():

server = Server(
'server2',
subnet='192.168.0.1/24',
)

peer1 = server.peer('peer1')
peer2 = server.peer('peer2')
peer3 = server.peer('peer3')
peer4 = server.peer('peer4')
peer5 = server.peer('peer5')

server.peers.discard_by_description('peer6')
server.peers.discard_by_ip('10.10.10.10')
server.peers.discard_by_private_key('wBdB54t1rUBQ3mc0OvKdhzzaD9MvKGrshLQyHw5CN1A=')
server.peers.discard_by_public_key('m5Tp7TvZOQYUnfmxsRN9TsmfEi5jssWpyjs5X6OP9k8=')

server.peers.discard(peer4)
for peer in server.peers:
assert peer.description != 'peer4'
assert len(server.peers) == 4

server.peers.discard_by_description('peer5')
for peer in server.peers:
assert peer.description != 'peer5'
assert len(server.peers) == 3

server.peers.discard_by_ip(peer2.ipv4)
for peer in server.peers:
assert peer.description != 'peer2'
assert peer.ipv4 != peer2.ipv4
assert len(server.peers) == 2

server.peers.discard_by_private_key(peer3.private_key)
for peer in server.peers:
assert peer.description != 'peer3'
assert peer.private_key != peer3.private_key
assert len(server.peers) == 1

server.peers.discard_by_public_key(peer1.public_key)
for peer in server.peers:
assert peer.description != 'peer1'
assert peer.public_key != peer1.public_key
assert len(server.peers) == 0


def test_peer_bidirectional_removal():

server = Server(
'server3',
subnet='192.168.0.1/24',
)

peer1 = server.peer('peer1')
peer2 = server.peer('peer2')
peer3 = server.peer('peer3')
peer4 = server.peer('peer4')
peer5 = server.peer('peer5')

assert len(server.peers) == 5

assert len(peer4.peers) == 1
server.remove_peer(peer4)
assert len(peer4.peers) == 0
assert len(server.peers) == 4

assert len(peer3.peers) == 1
server.remove_peer(peer3, bidirectional=False)
assert len(peer3.peers) == 1
assert len(server.peers) == 3

assert len(peer3.peers) == 1
peer3.remove_peer(server)
assert len(peer3.peers) == 0
assert len(server.peers) == 3 # was already removed in the previous block

assert len(peer5.peers) == 1
peer5.remove_peer(server)
assert len(peer5.peers) == 0
assert len(server.peers) == 2

assert len(peer1.peers) == 1
peer1.remove_peer(server, bidirectional=False)
assert len(peer1.peers) == 0
assert len(server.peers) == 2
104 changes: 104 additions & 0 deletions wireguard/peer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,95 @@ def _coerce_value(self, value):

raise ValueError('Provided value must be an instance of Peer')

def discard_by_description(self, description):
"""
Discard a peer by description
"""

try:
self.remove_by_description(description)
except KeyError:
pass

def remove_by_description(self, description):
"""
Remove a peer by description
"""

for peer in self:
if peer.description == description:
self.remove(peer)
return

raise KeyError(description)

def discard_by_ip(self, ip):
"""
Discard a peer by ip
"""

try:
self.remove_by_ip(ip)
except KeyError:
pass

def remove_by_ip(self, ip):
"""
Remove a peer by ip
"""

chk_ip = ip_address(ip)
for peer in self:
if chk_ip in [peer.ipv6, peer.ipv4]:
self.remove(peer)
return

raise KeyError(ip)

def discard_by_private_key(self, key):
"""
Discard a peer by private key
"""

try:
self.remove_by_private_key(key)
except KeyError:
pass

def remove_by_private_key(self, key):
"""
Remove a peer by private key
"""

for peer in self:
if peer.private_key and peer.private_key == key:
self.remove(peer)
return

raise KeyError(key)

def discard_by_public_key(self, key):
"""
Discard a peer by public key
"""

try:
self.remove_by_public_key(key)
except KeyError:
pass

def remove_by_public_key(self, key):
"""
Remove a peer by public key
"""

for peer in self:
if peer.public_key == key:
self.remove(peer)
return

raise KeyError(key)


class Peer: # pylint: disable=too-many-instance-attributes
"""
Expand Down Expand Up @@ -266,6 +355,21 @@ def json(self, **kwargs):

return json.dumps(self, **kwargs)

def remove_peer(self, peer, *, bidirectional=True):
"""
Removes the given peer from this peer
Default behaviour removes this peer from the given peer as well. Passing
`bidirectional=False` will only perform the removal on this peer, leaving
the given peer unchanged.
"""

# Since we don't care if the peer is already gone, we are using `.discard()`
# instead of `.remove()` here.
self.peers.discard(peer)
if bidirectional:
peer.peers.discard(self)

@property
def comments(self):
"""
Expand Down

0 comments on commit 6607ec4

Please sign in to comment.