diff --git a/lib/charms/tls_certificates_interface/v4/tls_certificates.py b/lib/charms/tls_certificates_interface/v4/tls_certificates.py index 62037a3..967ee03 100644 --- a/lib/charms/tls_certificates_interface/v4/tls_certificates.py +++ b/lib/charms/tls_certificates_interface/v4/tls_certificates.py @@ -185,7 +185,7 @@ def _get_config_locality_name(self) -> Optional[str]: # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 5 +LIBPATCH = 6 PYDEPS = ["cryptography", "pydantic"] @@ -353,9 +353,10 @@ class Certificate: raw: str common_name: str - sans_dns: Optional[FrozenSet[str]] = None - sans_ip: Optional[FrozenSet[str]] = None - sans_oid: Optional[FrozenSet[str]] = None + is_ca: bool = False + sans_dns: Optional[FrozenSet[str]] = frozenset() + sans_ip: Optional[FrozenSet[str]] = frozenset() + sans_oid: Optional[FrozenSet[str]] = frozenset() email_address: Optional[str] = None organization: Optional[str] = None organizational_unit: Optional[str] = None @@ -412,9 +413,18 @@ def from_string(cls, certificate: str) -> "Certificate": sans_oid = [] expiry_time = certificate_object.not_valid_after_utc validity_start_time = certificate_object.not_valid_before_utc + is_ca = False + try: + is_ca = certificate_object.extensions.get_extension_for_oid( + ExtensionOID.BASIC_CONSTRAINTS + ).value.ca # type: ignore[reportAttributeAccessIssue] + except x509.ExtensionNotFound: + pass + return cls( raw=certificate.strip(), common_name=str(common_name[0].value), + is_ca=is_ca, country_name=str(country_name[0].value) if country_name else None, state_or_province_name=str(state_or_province_name[0].value) if state_or_province_name @@ -446,7 +456,6 @@ class CertificateSigningRequest: country_name: Optional[str] = None state_or_province_name: Optional[str] = None locality_name: Optional[str] = None - is_ca: bool = False def __eq__(self, other: object) -> bool: """Check if two CertificateSigningRequest objects are equal.""" @@ -458,22 +467,6 @@ def __str__(self) -> str: """Return the CSR as a string.""" return self.raw - def to_certificate_request(self) -> "CertificateRequest": - """Convert to a CertificateRequest object.""" - return CertificateRequest( - common_name=self.common_name, - sans_dns=self.sans_dns, - sans_ip=self.sans_ip, - sans_oid=self.sans_oid, - email_address=self.email_address, - organization=self.organization, - organizational_unit=self.organizational_unit, - country_name=self.country_name, - state_or_province_name=self.state_or_province_name, - locality_name=self.locality_name, - is_ca=self.is_ca, - ) - @classmethod def from_string(cls, csr: str) -> "CertificateSigningRequest": """Create a CertificateSigningRequest object from a CSR.""" @@ -511,8 +504,8 @@ def from_string(cls, csr: str) -> "CertificateSigningRequest": organization=str(organization_name[0].value) if organization_name else None, email_address=str(email_address[0].value) if email_address else None, sans_dns=sans_dns, - sans_ip=sans_ip if sans_ip else None, - sans_oid=sans_oid if sans_oid else None, + sans_ip=sans_ip, + sans_oid=sans_oid, ) def matches_private_key(self, key: PrivateKey) -> bool: @@ -577,9 +570,9 @@ class CertificateRequest: """ common_name: str - sans_dns: Optional[FrozenSet[str]] = None - sans_ip: Optional[FrozenSet[str]] = None - sans_oid: Optional[FrozenSet[str]] = None + sans_dns: Optional[FrozenSet[str]] = frozenset() + sans_ip: Optional[FrozenSet[str]] = frozenset() + sans_oid: Optional[FrozenSet[str]] = frozenset() email_address: Optional[str] = None organization: Optional[str] = None organizational_unit: Optional[str] = None @@ -620,6 +613,23 @@ def generate_csr( locality_name=self.locality_name, ) + @classmethod + def from_csr(cls, csr: CertificateSigningRequest, is_ca: bool): + """Create a CertificateRequest object from a CSR.""" + return cls( + common_name=csr.common_name, + sans_dns=csr.sans_dns, + sans_ip=csr.sans_ip, + sans_oid=csr.sans_oid, + email_address=csr.email_address, + organization=csr.organization, + organizational_unit=csr.organizational_unit, + country_name=csr.country_name, + state_or_province_name=csr.state_or_province_name, + locality_name=csr.locality_name, + is_ca=is_ca, + ) + @dataclass(frozen=True) class ProviderCertificate: @@ -656,6 +666,7 @@ class RequirerCSR: relation_id: int certificate_signing_request: CertificateSigningRequest + is_ca: bool class CertificateAvailableEvent(EventBase): @@ -761,7 +772,7 @@ def generate_private_key( public_exponent: Public exponent. Returns: - str: Private Key + PrivateKey: Private Key """ private_key = rsa.generate_private_key( public_exponent=public_exponent, @@ -778,9 +789,9 @@ def generate_private_key( def generate_csr( # noqa: C901 private_key: PrivateKey, common_name: str, - sans_dns: Optional[FrozenSet[str]] = None, - sans_ip: Optional[FrozenSet[str]] = None, - sans_oid: Optional[FrozenSet[str]] = None, + sans_dns: Optional[FrozenSet[str]] = frozenset(), + sans_ip: Optional[FrozenSet[str]] = frozenset(), + sans_oid: Optional[FrozenSet[str]] = frozenset(), organization: Optional[str] = None, organizational_unit: Optional[str] = None, email_address: Optional[str] = None, @@ -853,9 +864,9 @@ def generate_ca( private_key: PrivateKey, validity: int, common_name: str, - sans_dns: Optional[FrozenSet[str]] = None, - sans_ip: Optional[FrozenSet[str]] = None, - sans_oid: Optional[FrozenSet[str]] = None, + sans_dns: Optional[FrozenSet[str]] = frozenset(), + sans_ip: Optional[FrozenSet[str]] = frozenset(), + sans_oid: Optional[FrozenSet[str]] = frozenset(), organization: Optional[str] = None, organizational_unit: Optional[str] = None, email_address: Optional[str] = None, @@ -1269,9 +1280,14 @@ def _private_key_generated(self) -> bool: return False return True - def _csr_matches_certificate_request(self, csr: CertificateSigningRequest) -> bool: + def _csr_matches_certificate_request( + self, certificate_signing_request: CertificateSigningRequest, is_ca: bool + ) -> bool: for certificate_request in self.certificate_requests: - if csr.to_certificate_request() == certificate_request: + if certificate_request == CertificateRequest.from_csr( + certificate_signing_request, + is_ca, + ): return True return False @@ -1281,19 +1297,23 @@ def _certificate_requested(self, certificate_request: CertificateRequest) -> boo csr = self._certificate_requested_for_attributes(certificate_request) if not csr: return False - if not csr.matches_private_key(key=self.private_key): + if not csr.certificate_signing_request.matches_private_key(key=self.private_key): return False return True def _certificate_requested_for_attributes( - self, certificate_request: CertificateRequest - ) -> Optional[CertificateSigningRequest]: + self, + certificate_request: CertificateRequest, + ) -> Optional[RequirerCSR]: for requirer_csr in self.get_csrs_from_requirer_relation_data(): - if requirer_csr.to_certificate_request() == certificate_request: + if certificate_request == CertificateRequest.from_csr( + requirer_csr.certificate_signing_request, + requirer_csr.is_ca, + ): return requirer_csr return None - def get_csrs_from_requirer_relation_data(self) -> List[CertificateSigningRequest]: + def get_csrs_from_requirer_relation_data(self) -> List[RequirerCSR]: """Return list of requirer's CSRs from relation data.""" if self.mode == Mode.APP and not self.model.unit.is_leader(): logger.debug("Not a leader unit - Skipping") @@ -1308,10 +1328,18 @@ def get_csrs_from_requirer_relation_data(self) -> List[CertificateSigningRequest except DataValidationError: logger.warning("Invalid relation data") return [] - return [ - CertificateSigningRequest.from_string(csr.certificate_signing_request) - for csr in requirer_relation_data.certificate_signing_requests - ] + requirer_csrs = [] + for csr in requirer_relation_data.certificate_signing_requests: + requirer_csrs.append( + RequirerCSR( + relation_id=relation.id, + certificate_signing_request=CertificateSigningRequest.from_string( + csr.certificate_signing_request + ), + is_ca=csr.ca if csr.ca else False, + ) + ) + return requirer_csrs def get_provider_certificates(self) -> List[ProviderCertificate]: """Return list of certificates from the provider's relation data.""" @@ -1383,7 +1411,10 @@ def get_assigned_certificate( ) -> Tuple[ProviderCertificate | None, PrivateKey | None]: """Get the certificate that was assigned to the given certificate request.""" for requirer_csr in self.get_csrs_from_requirer_relation_data(): - if certificate_request == requirer_csr.to_certificate_request(): + if certificate_request == CertificateRequest.from_csr( + requirer_csr.certificate_signing_request, + requirer_csr.is_ca, + ): return self._find_certificate_in_relation_data(requirer_csr), self.private_key return None, None @@ -1396,11 +1427,14 @@ def get_assigned_certificates(self) -> Tuple[List[ProviderCertificate], PrivateK return assigned_certificates, self.private_key def _find_certificate_in_relation_data( - self, csr: CertificateSigningRequest + self, csr: RequirerCSR ) -> Optional[ProviderCertificate]: """Return the certificate that match the given CSR.""" for provider_certificate in self.get_provider_certificates(): - if provider_certificate.certificate_signing_request == csr: + if ( + provider_certificate.certificate_signing_request == csr.certificate_signing_request + and provider_certificate.certificate.is_ca == csr.is_ca + ): return provider_certificate return None @@ -1412,9 +1446,10 @@ def _find_available_certificates(self): If a certificate is revoked, the secret will be removed and an event will be emitted. """ requirer_csrs = self.get_csrs_from_requirer_relation_data() + csrs = [csr.certificate_signing_request for csr in requirer_csrs] provider_certificates = self.get_provider_certificates() for provider_certificate in provider_certificates: - if provider_certificate.certificate_signing_request in requirer_csrs: + if provider_certificate.certificate_signing_request in csrs: secret_label = self._get_csr_secret_label( provider_certificate.certificate_signing_request ) @@ -1428,7 +1463,8 @@ def _find_available_certificates(self): secret.remove_all_revisions() else: if not self._csr_matches_certificate_request( - provider_certificate.certificate_signing_request + certificate_signing_request=provider_certificate.certificate_signing_request, + is_ca=provider_certificate.certificate.is_ca, ): logger.debug("Certificate requested for different attributes - Skipping") continue @@ -1470,17 +1506,27 @@ def _cleanup_certificate_requests(self): - The CSR public key does not match the private key. """ for requirer_csr in self.get_csrs_from_requirer_relation_data(): - if not self._csr_matches_certificate_request(requirer_csr): - self._remove_requirer_csr_from_relation_data(requirer_csr) + if not self._csr_matches_certificate_request( + certificate_signing_request=requirer_csr.certificate_signing_request, + is_ca=requirer_csr.is_ca, + ): + self._remove_requirer_csr_from_relation_data( + requirer_csr.certificate_signing_request + ) logger.info( - "Removed CSR from relation data because \ - it did not match any certificate request" + "Removed CSR from relation data because it did not match any certificate request" # noqa: E501 + ) + elif ( + self.private_key + and not requirer_csr.certificate_signing_request.matches_private_key( + self.private_key + ) + ): + self._remove_requirer_csr_from_relation_data( + requirer_csr.certificate_signing_request ) - elif self.private_key and not requirer_csr.matches_private_key(self.private_key): - self._remove_requirer_csr_from_relation_data(requirer_csr) logger.info( - "Removed CSR from relation data because \ - it did not match the private key" + "Removed CSR from relation data because it did not match the private key" # noqa: E501 ) def _get_next_secret_expiry_time( @@ -1617,6 +1663,7 @@ def _load_requirer_databag( certificate_signing_request=CertificateSigningRequest.from_string( csr.certificate_signing_request ), + is_ca=csr.ca if csr.ca else False, ) for csr in requirer_relation_data.certificate_signing_requests ] diff --git a/src/charm.py b/src/charm.py index 2117c93..9a7fe5c 100755 --- a/src/charm.py +++ b/src/charm.py @@ -231,7 +231,7 @@ def _process_outstanding_certificate_requests(self) -> None: for request in self.tls_certificates.get_outstanding_certificate_requests(): self._generate_self_signed_certificate( csr=request.certificate_signing_request, - is_ca=request.certificate_signing_request.is_ca, + is_ca=request.is_ca, relation_id=request.relation_id, ) diff --git a/tests/unit/test_charm_configure.py b/tests/unit/test_charm_configure.py index 616f307..0470a31 100644 --- a/tests/unit/test_charm_configure.py +++ b/tests/unit/test_charm_configure.py @@ -220,6 +220,7 @@ def test_given_outstanding_certificate_requests_when_config_changed_then_certifi RequirerCSR( relation_id=tls_relation.relation_id, certificate_signing_request=requirer_csr, + is_ca=False, ), ] patch_generate_certificate.return_value = certificate