Skip to content

Commit

Permalink
Merge pull request #101 from rekby/devel
Browse files Browse the repository at this point in the history
Fix background renew
  • Loading branch information
rekby authored Dec 17, 2019
2 parents 22453ef + 2b36758 commit 264b098
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 27 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ addons:
- dos2unix

go:
- "1.12.4"
- "1.10"

services:
- docker
Expand Down
57 changes: 31 additions & 26 deletions internal/cert_manager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (resultCert *tls.Ce
}

logger = logger.With(logDomain(needDomain))

defer handlePanic(logger)

logger.Info("Get certificate", zap.String("original_domain", hello.ServerName))
if isTLSALPN01Hello(hello) {
return m.handleTLSALPN(logger, ctx, needDomain)
Expand All @@ -128,7 +131,7 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (resultCert *tls.Ce

defer func() {
if isNeedRenew(resultCert, now) {
go m.renewCertInBackground(ctx, certName)
go m.renewCertInBackground(ctx, needDomain, certName)
}
}()

Expand All @@ -137,7 +140,7 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (resultCert *tls.Ce
if cert != nil {
logger.Debug("Got certificate from local state", log.Cert(cert))

cert, err = validCertDer([]DomainName{needDomain}, cert.Certificate, cert.PrivateKey, certState.GetUseAsIs(), now)
cert, err = validCertTLS(cert, []DomainName{needDomain}, certState.GetUseAsIs(), now)
logger.Debug("Validate certificate from local state", zap.Error(err))
if err == nil {
return cert, nil
Expand Down Expand Up @@ -197,6 +200,7 @@ func (m *Manager) issueNewCert(ctx context.Context, needDomain DomainName, certN
}
certIssueContext, cancelFunc := context.WithTimeout(ctx, m.CertificateIssueTimeout)
defer cancelFunc()

domains := domainNamesFromCertificateName(certName)
domains, err = filterDomains(ctx, m.DomainChecker, domains, needDomain)
log.DebugError(logger, err, "Filter domains", logDomains(domains))
Expand Down Expand Up @@ -287,17 +291,19 @@ func (m *Manager) createCertificateForDomains(ctx context.Context, certName cert
logger := zc.L(ctx).With(logDomains(domainNames))
certState := m.certStateGet(ctx, certName)

if certState.StartIssue(ctx) {
// outer func need for get argument values in defer time
defer func() {
certState.FinishIssue(ctx, res, err)
}()
} else {
if !certState.StartIssue(ctx) {
waitTimeout, waitTimeoutCancel := context.WithTimeout(ctx, m.CertificateIssueTimeout)
defer waitTimeoutCancel()

logger.Debug("Certificate issue in process already - wait result")
return certState.WaitFinishIssue(waitTimeout)
}
// outer func need for get argument values in defer time
defer func() {
certState.FinishIssue(ctx, res, err)
}()

logger.Debug("Start issue process")

order, err := m.createOrderForDomains(ctx, domainNames...)
log.DebugWarning(logger, err, "Domains authorized")
Expand All @@ -306,11 +312,7 @@ func (m *Manager) createCertificateForDomains(ctx context.Context, certName cert
}

res, err = m.issueCertificate(ctx, certName, order)
if err == nil {
logger.Debug("Certificate created.")
} else {
logger.Warn("Can't issue certificate", zap.Error(err))
}
log.DebugWarning(logger, err, "Issue certificate")
return res, err
}

Expand Down Expand Up @@ -447,10 +449,11 @@ func (m *Manager) issueCertificate(ctx context.Context, certName certNameType, o
logger := zc.L(ctx).With(logDomains(domains))

key, err := m.certKeyGetOrCreate(ctx, certName, keyRSA)
log.DebugError(logger, err, "Get cert key")
if err != nil {
logger.Error("Can't get domain key", zap.Error(err))
return nil, err
}

csr, err := createCertRequest(key, domains[0], domains...)
log.DebugDPanic(logger, err, "Create certificate request")
if err != nil {
Expand All @@ -468,6 +471,7 @@ func (m *Manager) issueCertificate(ctx context.Context, certName certNameType, o
if err != nil {
return nil, err
}

err = storeCertificate(ctx, m.Cache, certName, cert)
log.DebugDPanic(logger, err, "Certificate stored")
if err != nil {
Expand All @@ -482,24 +486,17 @@ func (m *Manager) issueCertificate(ctx context.Context, certName certNameType, o
return cert, nil
}

func (m *Manager) renewCertInBackground(ctx context.Context, certName certNameType) {
func (m *Manager) renewCertInBackground(ctx context.Context, needDomain DomainName, certName certNameType) {
// detach from request lifetime, but save log context
logger := zc.L(ctx).Named("background")
ctx, ctxCancel := context.WithTimeout(context.Background(), m.CertificateIssueTimeout)
defer ctxCancel()

ctx = zc.WithLogger(ctx, logger)
certState := m.certStateGet(ctx, certName)
logger.Debug("Start reissue certificate in background")

if !certState.StartIssue(ctx) {
// already has other cert issue process
return
}
domains := domainNamesFromCertificateName(certName)
logger.Info("StartAutoRenew background certificate issue")
cert, err := m.createCertificateForDomains(ctx, certName, domains)
certState.FinishIssue(ctx, cert, err)
log.InfoError(logger, err, "Renew certificate in background finished", log.Cert(cert))
ctx = zc.WithLogger(ctx, logger)
_, err := m.issueNewCert(ctx, needDomain, certName)
log.DebugError(logger, err, "Cert reissue in background finished")
}

func (m *Manager) deactivatePendingAuthz(ctx context.Context, uries []string) {
Expand Down Expand Up @@ -838,6 +835,14 @@ func isNeedRenew(cert *tls.Certificate, now time.Time) bool {
return cert.Leaf.NotAfter.Add(-time.Hour * 24 * 30).Before(now)
}

// must called as defer handlepanic(logger)
func handlePanic(logger *zap.Logger) {
err := recover()
if err != nil {
logger.DPanic("Panic handled", zap.Any("panic", err))
}
}

func isCertLocked(ctx context.Context, storage cache.Bytes, certName certNameType) (bool, error) {
lockName := certName.String() + ".lock"

Expand Down
42 changes: 42 additions & 0 deletions internal/cert_manager/manager_semi_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,48 @@ func TestManager_GetCertificateHttp01(t *testing.T) {
td.CmpDeeply(t, cert2, cert)
}
})

t.Run("RenewSoonExpiredCert", func(t *testing.T) {
const domain = "soon-expired.com"

// issue certificate
cert, err := manager.GetCertificate(&tls.ClientHelloInfo{ServerName: domain})
if err != nil {
t.Errorf("cant issue certificate: %v", err)
return
}
certNumber := cert.Leaf.SerialNumber
newExpire := time.Now().Add(time.Hour)
cert.Leaf.NotAfter = newExpire

// get expired soon certificate and trigger reissue new
cert, err = manager.GetCertificate(&tls.ClientHelloInfo{ServerName: domain})
if err != nil {
t.Errorf("cant issue certificate: %v", err)
return
}
if certNumber.Cmp(cert.Leaf.SerialNumber) != 0 {
t.Error("Got other sertificate, need same.")
}
if !cert.Leaf.NotAfter.Equal(newExpire) {
t.Errorf("Bad expire time: '%v' instead of '%v'", cert.Leaf.NotAfter, newExpire)
}

time.Sleep(time.Second * 10)

// get renewed cert
cert, err = manager.GetCertificate(&tls.ClientHelloInfo{ServerName: domain})
if err != nil {
t.Errorf("cant issue certificate: %v", err)
return
}
if certNumber.Cmp(cert.Leaf.SerialNumber) == 0 {
t.Error("Need new certificate")
}
if !cert.Leaf.NotAfter.After(newExpire) {
t.Errorf("Bad expire time: %v", cert.Leaf.NotAfter)
}
})
}

func TestManager_GetCertificateTls(t *testing.T) {
Expand Down

0 comments on commit 264b098

Please sign in to comment.