Skip to content

Commit

Permalink
Cleanup ipv6 masking
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-forte-elastic committed Oct 25, 2023
1 parent 6a7772d commit 47627db
Showing 1 changed file with 47 additions and 39 deletions.
86 changes: 47 additions & 39 deletions eql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,21 +282,26 @@ def to_mask(cls, cidr_string):
"""Split an IP address plus cidr block to the mask."""
ip_string, size = cidr_string.split("/")
size = int(size)
if cls.ipv4_compiled.match(ip_string):
ip_bytes = socket.inet_aton(ip_string)
(subnet_int,) = struct.unpack(">L", ip_bytes)
mask = cls.masks4[size]
elif cls.check_ipv6(ip_string):
ip_string = cls.expand_ipv6_address(ip_string)
ip_bytes = socket.inet_pton(socket.AF_INET6, ip_string)
# TODO remove x if possible
subnet_int, x = struct.unpack(">QQ", ip_bytes)
mask = cls.masks6[size] & cls.masks6[size]
else:
raise ValueError("Invalid IP address")
ip_bytes = socket.inet_aton(ip_string)
subnet_int, = struct.unpack(">L", ip_bytes)

mask = cls.masks[size]

return subnet_int & mask, mask


@classmethod
def to_mask_ipv6(cls, cidr_string):
"""Split an IP address plus cidr block to the mask."""
ip_string, size = cidr_string.split("/")
size = int(size)
ip_string = cls.expand_ipv6_address(ip_string)
ip_bytes = socket.inet_pton(socket.AF_INET6, ip_string)
most, least = struct.unpack(">QQ", ip_bytes)
mask = cls.masks6[size] & cls.masks6[size]

return most, least, mask

@classmethod
def make_octet_re(cls, start, end):
"""Convert an octet-range into a regular expression."""
Expand All @@ -318,23 +323,29 @@ def make_octet_re(cls, start, end):

return "(?:{})".format("|".join(combos))

@classmethod
def make_hex_re(cls, start, end):
"""Create a regex pattern for a range of hexadecimal values."""
if start == end:
return "{:04x}".format(start)

if start == 0 and end == 65535:
return r"[0-9a-fA-F]{1,4}"

if start & 0xff == 0 and end & 0xff == 0xff:
return "{:02x}{:02x}".format(start >> 8, end >> 8)

return "{:02x}{:02x}-{:02x}{:02x}".format(start >> 8, start & 0xff, end >> 8, end & 0xff)

@classmethod
def make_cidr_regex(cls, cidr):
"""Convert a list of wildcards strings for matching a cidr."""
if cls.cidrv4_compiled.match(cidr):
min_octets, max_octets = cls.to_range(cidr)
return r"\.".join(cls.make_octet_re(*pair) for pair in zip(min_octets, max_octets))
elif cls.check_ipv6_cidr(cidr):
cidr = cls.expand_ipv6(cidr)
h16s, prefix_len = cidr.split("/")
h16s = h16s.split(":")
prefix_len = int(prefix_len)
if len(h16s) < 8:
h16s += [""] * (8 - len(h16s))
h16s = [int(h, 16) if h else 0 for h in h16s]
min_h16s = [h & (0xFFFF << (16 - prefix_len)) for h in h16s]
max_h16s = [h | (0xFFFF >> prefix_len) for h in min_h16s]
return ":".join(cls.make_octet_re(*pair) for pair in zip(min_h16s, max_h16s))
min_octets, max_octets = cls.to_range(cidr)
return ":".join(cls.make_hex_re(*pair) for pair in zip(min_octets, max_octets))
else:
raise ValueError("Invalid CIDR notation")

Expand All @@ -349,22 +360,19 @@ def to_range(cls, cidr):
max_octets = struct.unpack("BBBB", struct.pack(">L", max_ip_integer))
elif cls.check_ipv6_cidr(cidr):
cidr = cls.expand_ipv6(cidr)
ip_integer, mask = cls.to_mask(cidr)
max_ip_integer = ip_integer | (MAX_IPV6 ^ mask)
most, least, mask = cls.to_mask_ipv6(cidr)

subnet_ints = struct.unpack(">8H", struct.pack(">QQ", most & 0xFFFFFFFFFFFFFFFF, least >> 64))

# Convert the subnet integer to a tuple of 8 16-bit integers
subnet_ints = struct.unpack(
">8H", struct.pack(">QQ", (ip_integer >> 64), (ip_integer & 0xFFFFFFFFFFFFFFFF))
)
# Convert the mask to a tuple of 8 16-bit integers
mask_ints = struct.unpack(">8H", struct.pack(">QQ", (mask >> 64), (mask & 0xFFFFFFFFFFFFFFFF)))
# Apply the mask to the subnet integer to get the network address
network_ints = tuple([subnet_ints[i] & mask_ints[i] for i in range(8)])
# Calculate the maximum IP integer
max_ip_ints = tuple([network_ints[i] | (0xFFFF ^ mask_ints[i]) for i in range(8)])
# Convert the network and maximum IP integers to a tuple of 4 32-bit integers
min_octets = struct.unpack(">BBBB", struct.pack(">8H", *network_ints))
max_octets = struct.unpack(">BBBB", struct.pack(">8H", *max_ip_ints))
# Convert the network and maximum IP integers to a tuple of 8 16-bit integers
min_octets = struct.unpack(">8H", struct.pack(">8H", *network_ints))
max_octets = struct.unpack(">8H", struct.pack(">8H", *max_ip_ints))
else:
raise ValueError("Invalid CIDR notation")

Expand All @@ -379,23 +387,22 @@ def get_callback(cls, _, *cidr_matches):
if cls.cidrv4_compiled.match(cidr.value):
ipv4_masks.append(cls.to_mask(cidr.value))
elif cls.check_ipv6_cidr(cidr.value):
ipv6_masks.append(cls.to_mask(cidr.value))
ipv6_masks.append(cls.to_mask_ipv6(cidr.value))

def callback(source, *_):
if is_string(source) and (
cls.ipv4_compiled.match(source) or
cls.ipv6_compiled.match(source) or
cls.ipv6_shorthand_compiled.match(source)
cls.check_ipv6(source)
):
if cls.ipv4_compiled.match(source):
ip_integer, _ = cls.to_mask(source + "/32")
for subnet, mask in ipv4_masks:
if ip_integer & mask == subnet:
return True
elif cls.check_ipv6(source):
ip_integer, _ = cls.to_mask(source + "/128")
for subnet, mask in ipv6_masks:
if ip_integer & mask == subnet:
most, least, _ = cls.to_mask_ipv6(source + "/128")
for subnet_most, subnet_least, mask in ipv6_masks:
if most & mask == subnet_most and least & mask == subnet_least:
return True

return False
Expand Down Expand Up @@ -451,12 +458,13 @@ def validate(cls, arguments):

# Since it does match, we should also rewrite the string to align to the base of the subnet
ip_address, size = text.split("/")
subnet_integer, _ = cls.to_mask(text)
if cls.cidrv4_compiled.match(argument.node.value):
subnet_integer, _ = cls.to_mask(text)
subnet_bytes = struct.pack(">L", subnet_integer)
subnet_base = socket.inet_ntoa(subnet_bytes)
elif cls.check_ipv6_cidr(argument.node.value):
subnet_bytes = struct.pack(">QQ", subnet_integer >> 64, subnet_integer & 0xFFFFFFFFFFFFFFFF)
most, least, _ = cls.to_mask_ipv6(text)
subnet_bytes = struct.pack(">QQ", most & 0xFFFFFFFFFFFFFFFF, least >> 64)
subnet_base = socket.inet_ntop(socket.AF_INET6, subnet_bytes)

# overwrite the original argument so it becomes the subnet
Expand Down

0 comments on commit 47627db

Please sign in to comment.