diff --git a/ircrobots/interface.py b/ircrobots/interface.py index f680f5f..db66353 100644 --- a/ircrobots/interface.py +++ b/ircrobots/interface.py @@ -6,6 +6,7 @@ from irctokens import Line, Hostmask from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy +from .security import TLS class ITCPReader(object): async def read(self, byte_count: int): @@ -24,11 +25,10 @@ async def close(self): class ITCPTransport(object): async def connect(self, - hostname: str, - port: int, - tls: bool, - tls_verify: bool=True, - bindhost: Optional[str]=None + hostname: str, + port: int, + tls: Optional[TLS], + bindhost: Optional[str]=None ) -> Tuple[ITCPReader, ITCPWriter]: pass diff --git a/ircrobots/ircv3.py b/ircrobots/ircv3.py index b26ab8f..359ca12 100644 --- a/ircrobots/ircv3.py +++ b/ircrobots/ircv3.py @@ -8,6 +8,7 @@ from .matching import Response, ANY from .interface import ICapability from .params import ConnectionParams, STSPolicy, ResumePolicy +from .security import TLS_VERIFYCHAIN class Capability(ICapability): def __init__(self, @@ -101,12 +102,12 @@ def _cap_dict(s: str) -> Dict[str, str]: return d async def sts_transmute(params: ConnectionParams): - if not params.sts is None and not params.tls: + if not params.sts is None and params.tls is None: now = time() since = (now-params.sts.created) if since <= params.sts.duration: params.port = params.sts.port - params.tls = True + params.tls = TLS_VERIFYCHAIN async def resume_transmute(params: ConnectionParams): if params.resume is not None: params.host = params.resume.address @@ -182,7 +183,7 @@ async def _sts(self, tokens: Dict[str, str]): if not params.tls: if "port" in sts_dict: params.port = int(sts_dict["port"]) - params.tls = True + params.tls = TLS_VERIFYCHAIN await self.server.bot.disconnect(self.server) await self.server.bot.add_server(self.server.name, params) diff --git a/ircrobots/params.py b/ircrobots/params.py index 9c8ea97..e52d6d6 100644 --- a/ircrobots/params.py +++ b/ircrobots/params.py @@ -1,6 +1,9 @@ +from re import compile as re_compile from typing import List, Optional from dataclasses import dataclass, field +from .security import TLS, TLS_NOVERIFY, TLS_VERIFYCHAIN + class SASLParams(object): mechanism: str @@ -28,19 +31,24 @@ class ResumePolicy(object): address: str token: str +RE_IPV6HOST = re_compile("\[([a-fA-F0-9:]+)\]") + +_TLS_TYPES = { + "+": TLS_VERIFYCHAIN, + "~": TLS_NOVERIFY +} @dataclass class ConnectionParams(object): nickname: str host: str port: int - tls: bool + tls: Optional[TLS] = TLS_VERIFYCHAIN username: Optional[str] = None realname: Optional[str] = None bindhost: Optional[str] = None password: Optional[str] = None - tls_verify: bool = True sasl: Optional[SASLParams] = None sts: Optional[STSPolicy] = None @@ -57,15 +65,19 @@ def from_hoststring( hoststring: str ) -> "ConnectionParams": - host, _, port_s = hoststring.strip().partition(":") + ipv6host = RE_IPV6HOST.search(hoststring) + if ipv6host is not None and ipv6host.start() == 0: + host = ipv6host.group(1) + port_s = hoststring[ipv6host.end()+1:] + else: + host, _, port_s = hoststring.strip().partition(":") - if port_s.startswith("+"): - tls = True - port_s = port_s.lstrip("+") or "6697" - elif not port_s: - tls = False + tls_type: Optional[TLS] = None + if not port_s: port_s = "6667" else: - tls = False + tls_type = _TLS_TYPES.get(port_s[0], None) + if tls_type is not None: + port_s = port_s[1:] or "6697" - return ConnectionParams(nickname, host, int(port_s), tls) + return ConnectionParams(nickname, host, int(port_s), tls_type) diff --git a/ircrobots/security.py b/ircrobots/security.py index 7b65236..f10b700 100644 --- a/ircrobots/security.py +++ b/ircrobots/security.py @@ -1,4 +1,28 @@ import ssl +class TLS: + pass + +# tls without verification +class TLSNoVerify(TLS): + pass +TLS_NOVERIFY = TLSNoVerify() + +# verify via CAs +class TLSVerifyChain(TLS): + pass +TLS_VERIFYCHAIN = TLSVerifyChain() + +# verify by a pinned hash +class TLSVerifyHash(TLSNoVerify): + def __init__(self, sum: str): + self.sum = sum.lower() +class TLSVerifySHA512(TLSVerifyHash): + pass + def tls_context(verify: bool=True) -> ssl.SSLContext: - return ssl.create_default_context() + ctx = ssl.create_default_context() + if not verify: + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + return ctx diff --git a/ircrobots/server.py b/ircrobots/server.py index 39d2e04..d29946d 100644 --- a/ircrobots/server.py +++ b/ircrobots/server.py @@ -124,9 +124,8 @@ async def connect(self, reader, writer = await transport.connect( params.host, params.port, - tls =params.tls, - tls_verify=params.tls_verify, - bindhost =params.bindhost) + tls =params.tls, + bindhost =params.bindhost) self._reader = reader self._writer = writer diff --git a/ircrobots/transport.py b/ircrobots/transport.py index 291409c..a7cb330 100644 --- a/ircrobots/transport.py +++ b/ircrobots/transport.py @@ -1,10 +1,12 @@ +from hashlib import sha512 from ssl import SSLContext from typing import Optional, Tuple from asyncio import StreamReader, StreamWriter from async_stagger import open_connection from .interface import ITCPTransport, ITCPReader, ITCPWriter -from .security import tls_context +from .security import (tls_context, TLS, TLSNoVerify, TLSVerifyHash, + TLSVerifySHA512) class TCPReader(ITCPReader): def __init__(self, reader: StreamReader): @@ -32,16 +34,15 @@ async def close(self): class TCPTransport(ITCPTransport): async def connect(self, - hostname: str, - port: int, - tls: bool, - tls_verify: bool=True, - bindhost: Optional[str]=None + hostname: str, + port: int, + tls: Optional[TLS], + bindhost: Optional[str]=None ) -> Tuple[ITCPReader, ITCPWriter]: cur_ssl: Optional[SSLContext] = None - if tls: - cur_ssl = tls_context(tls_verify) + if tls is not None: + cur_ssl = tls_context(not isinstance(tls, TLSNoVerify)) local_addr: Optional[Tuple[str, int]] = None if not bindhost is None: @@ -55,5 +56,20 @@ async def connect(self, server_hostname=server_hostname, ssl =cur_ssl, local_addr =local_addr) + + if isinstance(tls, TLSVerifyHash): + cert: bytes = writer.transport.get_extra_info( + "ssl_object" + ).getpeercert(True) + if isinstance(tls, TLSVerifySHA512): + sum = sha512(cert).hexdigest() + else: + raise ValueError(f"unknown hash pinning {type(tls)}") + + if not sum == tls.sum: + raise ValueError( + f"pinned hash for {hostname} does not match ({sum})" + ) + return (TCPReader(reader), TCPWriter(writer))