From 47627db7b2f786cd816dae8f7e818e00f9b3d75f Mon Sep 17 00:00:00 2001 From: eric-forte-elastic Date: Wed, 25 Oct 2023 10:24:46 -0400 Subject: [PATCH] Cleanup ipv6 masking --- eql/functions.py | 86 ++++++++++++++++++++++++++---------------------- 1 file changed, 47 insertions(+), 39 deletions(-) diff --git a/eql/functions.py b/eql/functions.py index 11667af..3b256e8 100644 --- a/eql/functions.py +++ b/eql/functions.py @@ -282,21 +282,26 @@ def to_mask(cls, cidr_string): """Split an IP address plus cidr block to the mask.""" ip_string, size = cidr_string.split("/") size = int(size) - if cls.ipv4_compiled.match(ip_string): - ip_bytes = socket.inet_aton(ip_string) - (subnet_int,) = struct.unpack(">L", ip_bytes) - mask = cls.masks4[size] - elif cls.check_ipv6(ip_string): - ip_string = cls.expand_ipv6_address(ip_string) - ip_bytes = socket.inet_pton(socket.AF_INET6, ip_string) - # TODO remove x if possible - subnet_int, x = struct.unpack(">QQ", ip_bytes) - mask = cls.masks6[size] & cls.masks6[size] - else: - raise ValueError("Invalid IP address") + ip_bytes = socket.inet_aton(ip_string) + subnet_int, = struct.unpack(">L", ip_bytes) + + mask = cls.masks[size] return subnet_int & mask, mask + + @classmethod + def to_mask_ipv6(cls, cidr_string): + """Split an IP address plus cidr block to the mask.""" + ip_string, size = cidr_string.split("/") + size = int(size) + ip_string = cls.expand_ipv6_address(ip_string) + ip_bytes = socket.inet_pton(socket.AF_INET6, ip_string) + most, least = struct.unpack(">QQ", ip_bytes) + mask = cls.masks6[size] & cls.masks6[size] + + return most, least, mask + @classmethod def make_octet_re(cls, start, end): """Convert an octet-range into a regular expression.""" @@ -318,6 +323,20 @@ def make_octet_re(cls, start, end): return "(?:{})".format("|".join(combos)) + @classmethod + def make_hex_re(cls, start, end): + """Create a regex pattern for a range of hexadecimal values.""" + if start == end: + return "{:04x}".format(start) + + if start == 0 and end == 65535: + return r"[0-9a-fA-F]{1,4}" + + if start & 0xff == 0 and end & 0xff == 0xff: + return "{:02x}{:02x}".format(start >> 8, end >> 8) + + return "{:02x}{:02x}-{:02x}{:02x}".format(start >> 8, start & 0xff, end >> 8, end & 0xff) + @classmethod def make_cidr_regex(cls, cidr): """Convert a list of wildcards strings for matching a cidr.""" @@ -325,16 +344,8 @@ def make_cidr_regex(cls, cidr): min_octets, max_octets = cls.to_range(cidr) return r"\.".join(cls.make_octet_re(*pair) for pair in zip(min_octets, max_octets)) elif cls.check_ipv6_cidr(cidr): - cidr = cls.expand_ipv6(cidr) - h16s, prefix_len = cidr.split("/") - h16s = h16s.split(":") - prefix_len = int(prefix_len) - if len(h16s) < 8: - h16s += [""] * (8 - len(h16s)) - h16s = [int(h, 16) if h else 0 for h in h16s] - min_h16s = [h & (0xFFFF << (16 - prefix_len)) for h in h16s] - max_h16s = [h | (0xFFFF >> prefix_len) for h in min_h16s] - return ":".join(cls.make_octet_re(*pair) for pair in zip(min_h16s, max_h16s)) + min_octets, max_octets = cls.to_range(cidr) + return ":".join(cls.make_hex_re(*pair) for pair in zip(min_octets, max_octets)) else: raise ValueError("Invalid CIDR notation") @@ -349,22 +360,19 @@ def to_range(cls, cidr): max_octets = struct.unpack("BBBB", struct.pack(">L", max_ip_integer)) elif cls.check_ipv6_cidr(cidr): cidr = cls.expand_ipv6(cidr) - ip_integer, mask = cls.to_mask(cidr) - max_ip_integer = ip_integer | (MAX_IPV6 ^ mask) + most, least, mask = cls.to_mask_ipv6(cidr) + + subnet_ints = struct.unpack(">8H", struct.pack(">QQ", most & 0xFFFFFFFFFFFFFFFF, least >> 64)) - # Convert the subnet integer to a tuple of 8 16-bit integers - subnet_ints = struct.unpack( - ">8H", struct.pack(">QQ", (ip_integer >> 64), (ip_integer & 0xFFFFFFFFFFFFFFFF)) - ) # Convert the mask to a tuple of 8 16-bit integers mask_ints = struct.unpack(">8H", struct.pack(">QQ", (mask >> 64), (mask & 0xFFFFFFFFFFFFFFFF))) # Apply the mask to the subnet integer to get the network address network_ints = tuple([subnet_ints[i] & mask_ints[i] for i in range(8)]) # Calculate the maximum IP integer max_ip_ints = tuple([network_ints[i] | (0xFFFF ^ mask_ints[i]) for i in range(8)]) - # Convert the network and maximum IP integers to a tuple of 4 32-bit integers - min_octets = struct.unpack(">BBBB", struct.pack(">8H", *network_ints)) - max_octets = struct.unpack(">BBBB", struct.pack(">8H", *max_ip_ints)) + # Convert the network and maximum IP integers to a tuple of 8 16-bit integers + min_octets = struct.unpack(">8H", struct.pack(">8H", *network_ints)) + max_octets = struct.unpack(">8H", struct.pack(">8H", *max_ip_ints)) else: raise ValueError("Invalid CIDR notation") @@ -379,13 +387,12 @@ def get_callback(cls, _, *cidr_matches): if cls.cidrv4_compiled.match(cidr.value): ipv4_masks.append(cls.to_mask(cidr.value)) elif cls.check_ipv6_cidr(cidr.value): - ipv6_masks.append(cls.to_mask(cidr.value)) + ipv6_masks.append(cls.to_mask_ipv6(cidr.value)) def callback(source, *_): if is_string(source) and ( cls.ipv4_compiled.match(source) or - cls.ipv6_compiled.match(source) or - cls.ipv6_shorthand_compiled.match(source) + cls.check_ipv6(source) ): if cls.ipv4_compiled.match(source): ip_integer, _ = cls.to_mask(source + "/32") @@ -393,9 +400,9 @@ def callback(source, *_): if ip_integer & mask == subnet: return True elif cls.check_ipv6(source): - ip_integer, _ = cls.to_mask(source + "/128") - for subnet, mask in ipv6_masks: - if ip_integer & mask == subnet: + most, least, _ = cls.to_mask_ipv6(source + "/128") + for subnet_most, subnet_least, mask in ipv6_masks: + if most & mask == subnet_most and least & mask == subnet_least: return True return False @@ -451,12 +458,13 @@ def validate(cls, arguments): # Since it does match, we should also rewrite the string to align to the base of the subnet ip_address, size = text.split("/") - subnet_integer, _ = cls.to_mask(text) if cls.cidrv4_compiled.match(argument.node.value): + subnet_integer, _ = cls.to_mask(text) subnet_bytes = struct.pack(">L", subnet_integer) subnet_base = socket.inet_ntoa(subnet_bytes) elif cls.check_ipv6_cidr(argument.node.value): - subnet_bytes = struct.pack(">QQ", subnet_integer >> 64, subnet_integer & 0xFFFFFFFFFFFFFFFF) + most, least, _ = cls.to_mask_ipv6(text) + subnet_bytes = struct.pack(">QQ", most & 0xFFFFFFFFFFFFFFFF, least >> 64) subnet_base = socket.inet_ntop(socket.AF_INET6, subnet_bytes) # overwrite the original argument so it becomes the subnet