Skip to content

Commit

Permalink
chore: bump tls lib to v4.6 (#228)
Browse files Browse the repository at this point in the history
Signed-off-by: guillaume <[email protected]>
  • Loading branch information
gruyaume authored Sep 4, 2024
1 parent d7b1acf commit 3bd8193
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 58 deletions.
161 changes: 104 additions & 57 deletions lib/charms/tls_certificates_interface/v4/tls_certificates.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def _get_config_locality_name(self) -> Optional[str]:

# Increment this PATCH version before using `charmcraft publish-lib` or reset
# to 0 if you are raising the major API version
LIBPATCH = 5
LIBPATCH = 6

PYDEPS = ["cryptography", "pydantic"]

Expand Down Expand Up @@ -353,9 +353,10 @@ class Certificate:

raw: str
common_name: str
sans_dns: Optional[FrozenSet[str]] = None
sans_ip: Optional[FrozenSet[str]] = None
sans_oid: Optional[FrozenSet[str]] = None
is_ca: bool = False
sans_dns: Optional[FrozenSet[str]] = frozenset()
sans_ip: Optional[FrozenSet[str]] = frozenset()
sans_oid: Optional[FrozenSet[str]] = frozenset()
email_address: Optional[str] = None
organization: Optional[str] = None
organizational_unit: Optional[str] = None
Expand Down Expand Up @@ -412,9 +413,18 @@ def from_string(cls, certificate: str) -> "Certificate":
sans_oid = []
expiry_time = certificate_object.not_valid_after_utc
validity_start_time = certificate_object.not_valid_before_utc
is_ca = False
try:
is_ca = certificate_object.extensions.get_extension_for_oid(
ExtensionOID.BASIC_CONSTRAINTS
).value.ca # type: ignore[reportAttributeAccessIssue]
except x509.ExtensionNotFound:
pass

return cls(
raw=certificate.strip(),
common_name=str(common_name[0].value),
is_ca=is_ca,
country_name=str(country_name[0].value) if country_name else None,
state_or_province_name=str(state_or_province_name[0].value)
if state_or_province_name
Expand Down Expand Up @@ -446,7 +456,6 @@ class CertificateSigningRequest:
country_name: Optional[str] = None
state_or_province_name: Optional[str] = None
locality_name: Optional[str] = None
is_ca: bool = False

def __eq__(self, other: object) -> bool:
"""Check if two CertificateSigningRequest objects are equal."""
Expand All @@ -458,22 +467,6 @@ def __str__(self) -> str:
"""Return the CSR as a string."""
return self.raw

def to_certificate_request(self) -> "CertificateRequest":
"""Convert to a CertificateRequest object."""
return CertificateRequest(
common_name=self.common_name,
sans_dns=self.sans_dns,
sans_ip=self.sans_ip,
sans_oid=self.sans_oid,
email_address=self.email_address,
organization=self.organization,
organizational_unit=self.organizational_unit,
country_name=self.country_name,
state_or_province_name=self.state_or_province_name,
locality_name=self.locality_name,
is_ca=self.is_ca,
)

@classmethod
def from_string(cls, csr: str) -> "CertificateSigningRequest":
"""Create a CertificateSigningRequest object from a CSR."""
Expand Down Expand Up @@ -511,8 +504,8 @@ def from_string(cls, csr: str) -> "CertificateSigningRequest":
organization=str(organization_name[0].value) if organization_name else None,
email_address=str(email_address[0].value) if email_address else None,
sans_dns=sans_dns,
sans_ip=sans_ip if sans_ip else None,
sans_oid=sans_oid if sans_oid else None,
sans_ip=sans_ip,
sans_oid=sans_oid,
)

def matches_private_key(self, key: PrivateKey) -> bool:
Expand Down Expand Up @@ -577,9 +570,9 @@ class CertificateRequest:
"""

common_name: str
sans_dns: Optional[FrozenSet[str]] = None
sans_ip: Optional[FrozenSet[str]] = None
sans_oid: Optional[FrozenSet[str]] = None
sans_dns: Optional[FrozenSet[str]] = frozenset()
sans_ip: Optional[FrozenSet[str]] = frozenset()
sans_oid: Optional[FrozenSet[str]] = frozenset()
email_address: Optional[str] = None
organization: Optional[str] = None
organizational_unit: Optional[str] = None
Expand Down Expand Up @@ -620,6 +613,23 @@ def generate_csr(
locality_name=self.locality_name,
)

@classmethod
def from_csr(cls, csr: CertificateSigningRequest, is_ca: bool):
"""Create a CertificateRequest object from a CSR."""
return cls(
common_name=csr.common_name,
sans_dns=csr.sans_dns,
sans_ip=csr.sans_ip,
sans_oid=csr.sans_oid,
email_address=csr.email_address,
organization=csr.organization,
organizational_unit=csr.organizational_unit,
country_name=csr.country_name,
state_or_province_name=csr.state_or_province_name,
locality_name=csr.locality_name,
is_ca=is_ca,
)


@dataclass(frozen=True)
class ProviderCertificate:
Expand Down Expand Up @@ -656,6 +666,7 @@ class RequirerCSR:

relation_id: int
certificate_signing_request: CertificateSigningRequest
is_ca: bool


class CertificateAvailableEvent(EventBase):
Expand Down Expand Up @@ -761,7 +772,7 @@ def generate_private_key(
public_exponent: Public exponent.
Returns:
str: Private Key
PrivateKey: Private Key
"""
private_key = rsa.generate_private_key(
public_exponent=public_exponent,
Expand All @@ -778,9 +789,9 @@ def generate_private_key(
def generate_csr( # noqa: C901
private_key: PrivateKey,
common_name: str,
sans_dns: Optional[FrozenSet[str]] = None,
sans_ip: Optional[FrozenSet[str]] = None,
sans_oid: Optional[FrozenSet[str]] = None,
sans_dns: Optional[FrozenSet[str]] = frozenset(),
sans_ip: Optional[FrozenSet[str]] = frozenset(),
sans_oid: Optional[FrozenSet[str]] = frozenset(),
organization: Optional[str] = None,
organizational_unit: Optional[str] = None,
email_address: Optional[str] = None,
Expand Down Expand Up @@ -853,9 +864,9 @@ def generate_ca(
private_key: PrivateKey,
validity: int,
common_name: str,
sans_dns: Optional[FrozenSet[str]] = None,
sans_ip: Optional[FrozenSet[str]] = None,
sans_oid: Optional[FrozenSet[str]] = None,
sans_dns: Optional[FrozenSet[str]] = frozenset(),
sans_ip: Optional[FrozenSet[str]] = frozenset(),
sans_oid: Optional[FrozenSet[str]] = frozenset(),
organization: Optional[str] = None,
organizational_unit: Optional[str] = None,
email_address: Optional[str] = None,
Expand Down Expand Up @@ -1269,9 +1280,14 @@ def _private_key_generated(self) -> bool:
return False
return True

def _csr_matches_certificate_request(self, csr: CertificateSigningRequest) -> bool:
def _csr_matches_certificate_request(
self, certificate_signing_request: CertificateSigningRequest, is_ca: bool
) -> bool:
for certificate_request in self.certificate_requests:
if csr.to_certificate_request() == certificate_request:
if certificate_request == CertificateRequest.from_csr(
certificate_signing_request,
is_ca,
):
return True
return False

Expand All @@ -1281,19 +1297,23 @@ def _certificate_requested(self, certificate_request: CertificateRequest) -> boo
csr = self._certificate_requested_for_attributes(certificate_request)
if not csr:
return False
if not csr.matches_private_key(key=self.private_key):
if not csr.certificate_signing_request.matches_private_key(key=self.private_key):
return False
return True

def _certificate_requested_for_attributes(
self, certificate_request: CertificateRequest
) -> Optional[CertificateSigningRequest]:
self,
certificate_request: CertificateRequest,
) -> Optional[RequirerCSR]:
for requirer_csr in self.get_csrs_from_requirer_relation_data():
if requirer_csr.to_certificate_request() == certificate_request:
if certificate_request == CertificateRequest.from_csr(
requirer_csr.certificate_signing_request,
requirer_csr.is_ca,
):
return requirer_csr
return None

def get_csrs_from_requirer_relation_data(self) -> List[CertificateSigningRequest]:
def get_csrs_from_requirer_relation_data(self) -> List[RequirerCSR]:
"""Return list of requirer's CSRs from relation data."""
if self.mode == Mode.APP and not self.model.unit.is_leader():
logger.debug("Not a leader unit - Skipping")
Expand All @@ -1308,10 +1328,18 @@ def get_csrs_from_requirer_relation_data(self) -> List[CertificateSigningRequest
except DataValidationError:
logger.warning("Invalid relation data")
return []
return [
CertificateSigningRequest.from_string(csr.certificate_signing_request)
for csr in requirer_relation_data.certificate_signing_requests
]
requirer_csrs = []
for csr in requirer_relation_data.certificate_signing_requests:
requirer_csrs.append(
RequirerCSR(
relation_id=relation.id,
certificate_signing_request=CertificateSigningRequest.from_string(
csr.certificate_signing_request
),
is_ca=csr.ca if csr.ca else False,
)
)
return requirer_csrs

def get_provider_certificates(self) -> List[ProviderCertificate]:
"""Return list of certificates from the provider's relation data."""
Expand Down Expand Up @@ -1383,7 +1411,10 @@ def get_assigned_certificate(
) -> Tuple[ProviderCertificate | None, PrivateKey | None]:
"""Get the certificate that was assigned to the given certificate request."""
for requirer_csr in self.get_csrs_from_requirer_relation_data():
if certificate_request == requirer_csr.to_certificate_request():
if certificate_request == CertificateRequest.from_csr(
requirer_csr.certificate_signing_request,
requirer_csr.is_ca,
):
return self._find_certificate_in_relation_data(requirer_csr), self.private_key
return None, None

Expand All @@ -1396,11 +1427,14 @@ def get_assigned_certificates(self) -> Tuple[List[ProviderCertificate], PrivateK
return assigned_certificates, self.private_key

def _find_certificate_in_relation_data(
self, csr: CertificateSigningRequest
self, csr: RequirerCSR
) -> Optional[ProviderCertificate]:
"""Return the certificate that match the given CSR."""
for provider_certificate in self.get_provider_certificates():
if provider_certificate.certificate_signing_request == csr:
if (
provider_certificate.certificate_signing_request == csr.certificate_signing_request
and provider_certificate.certificate.is_ca == csr.is_ca
):
return provider_certificate
return None

Expand All @@ -1412,9 +1446,10 @@ def _find_available_certificates(self):
If a certificate is revoked, the secret will be removed and an event will be emitted.
"""
requirer_csrs = self.get_csrs_from_requirer_relation_data()
csrs = [csr.certificate_signing_request for csr in requirer_csrs]
provider_certificates = self.get_provider_certificates()
for provider_certificate in provider_certificates:
if provider_certificate.certificate_signing_request in requirer_csrs:
if provider_certificate.certificate_signing_request in csrs:
secret_label = self._get_csr_secret_label(
provider_certificate.certificate_signing_request
)
Expand All @@ -1428,7 +1463,8 @@ def _find_available_certificates(self):
secret.remove_all_revisions()
else:
if not self._csr_matches_certificate_request(
provider_certificate.certificate_signing_request
certificate_signing_request=provider_certificate.certificate_signing_request,
is_ca=provider_certificate.certificate.is_ca,
):
logger.debug("Certificate requested for different attributes - Skipping")
continue
Expand Down Expand Up @@ -1470,17 +1506,27 @@ def _cleanup_certificate_requests(self):
- The CSR public key does not match the private key.
"""
for requirer_csr in self.get_csrs_from_requirer_relation_data():
if not self._csr_matches_certificate_request(requirer_csr):
self._remove_requirer_csr_from_relation_data(requirer_csr)
if not self._csr_matches_certificate_request(
certificate_signing_request=requirer_csr.certificate_signing_request,
is_ca=requirer_csr.is_ca,
):
self._remove_requirer_csr_from_relation_data(
requirer_csr.certificate_signing_request
)
logger.info(
"Removed CSR from relation data because \
it did not match any certificate request"
"Removed CSR from relation data because it did not match any certificate request" # noqa: E501
)
elif (
self.private_key
and not requirer_csr.certificate_signing_request.matches_private_key(
self.private_key
)
):
self._remove_requirer_csr_from_relation_data(
requirer_csr.certificate_signing_request
)
elif self.private_key and not requirer_csr.matches_private_key(self.private_key):
self._remove_requirer_csr_from_relation_data(requirer_csr)
logger.info(
"Removed CSR from relation data because \
it did not match the private key"
"Removed CSR from relation data because it did not match the private key" # noqa: E501
)

def _get_next_secret_expiry_time(
Expand Down Expand Up @@ -1617,6 +1663,7 @@ def _load_requirer_databag(
certificate_signing_request=CertificateSigningRequest.from_string(
csr.certificate_signing_request
),
is_ca=csr.ca if csr.ca else False,
)
for csr in requirer_relation_data.certificate_signing_requests
]
Expand Down
2 changes: 1 addition & 1 deletion src/charm.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def _process_outstanding_certificate_requests(self) -> None:
for request in self.tls_certificates.get_outstanding_certificate_requests():
self._generate_self_signed_certificate(
csr=request.certificate_signing_request,
is_ca=request.certificate_signing_request.is_ca,
is_ca=request.is_ca,
relation_id=request.relation_id,
)

Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_charm_configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def test_given_outstanding_certificate_requests_when_config_changed_then_certifi
RequirerCSR(
relation_id=tls_relation.relation_id,
certificate_signing_request=requirer_csr,
is_ca=False,
),
]
patch_generate_certificate.return_value = certificate
Expand Down

0 comments on commit 3bd8193

Please sign in to comment.