diff --git a/README.md b/README.md index 527acb9..a41c58b 100644 --- a/README.md +++ b/README.md @@ -89,6 +89,15 @@ applications. To enable this, add the `--tls` flag when deploying an instance: kamal-proxy deploy service1 --target web-1:3000 --host app1.example.com --tls +### Custom TLS certificate + +When you obtained your TLS certificate manually, manage your own certificate authority, +or need to install Cloudflare origin certificate, you can manually specify path to +your certificate file and the corresponding private key: + + kamal-proxy deploy service1 --target web-1:3000 --host app1.example.com --tls --tls-certificate-path cert.pem --tls-private-key-path key.pem + + ## Specifying `run` options with environment variables In some environments, like when running a Docker container, it can be convenient diff --git a/internal/cmd/deploy.go b/internal/cmd/deploy.go index bb23f56..82ae5c4 100644 --- a/internal/cmd/deploy.go +++ b/internal/cmd/deploy.go @@ -31,6 +31,8 @@ func newDeployCommand() *deployCommand { deployCommand.cmd.Flags().BoolVar(&deployCommand.args.ServiceOptions.TLSEnabled, "tls", false, "Configure TLS for this target (requires a non-empty host)") deployCommand.cmd.Flags().BoolVar(&deployCommand.tlsStaging, "tls-staging", false, "Use Let's Encrypt staging environment for certificate provisioning") + deployCommand.cmd.Flags().StringVar(&deployCommand.args.ServiceOptions.TLSCertificatePath, "tls-certificate-path", "", "Configure custom TLS certificate path (PEM format)") + deployCommand.cmd.Flags().StringVar(&deployCommand.args.ServiceOptions.TLSPrivateKeyPath, "tls-private-key-path", "", "Configure custom TLS private key path (PEM format)") deployCommand.cmd.Flags().DurationVar(&deployCommand.args.DeployTimeout, "deploy-timeout", server.DefaultDeployTimeout, "Maximum time to wait for the new target to become healthy") deployCommand.cmd.Flags().DurationVar(&deployCommand.args.DrainTimeout, "drain-timeout", server.DefaultDrainTimeout, "Maximum time to allow existing connections to drain before removing old target") @@ -53,6 +55,7 @@ func newDeployCommand() *deployCommand { deployCommand.cmd.Flags().BoolVar(&deployCommand.args.TargetOptions.ForwardHeaders, "forward-headers", false, "Forward X-Forwarded headers to target (default false if TLS enabled; otherwise true)") deployCommand.cmd.MarkFlagRequired("target") + deployCommand.cmd.MarkFlagsRequiredTogether("tls-certificate-path", "tls-private-key-path") return deployCommand } diff --git a/internal/server/cert.go b/internal/server/cert.go new file mode 100644 index 0000000..493a5ff --- /dev/null +++ b/internal/server/cert.go @@ -0,0 +1,61 @@ +package server + +import ( + "crypto/tls" + "log/slog" + "net/http" + "sync" +) + +type CertManager interface { + GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) + HTTPHandler(handler http.Handler) http.Handler +} + +// StaticCertManager is a certificate manager that loads certificates from disk. +type StaticCertManager struct { + tlsCertificateFilePath string + tlsPrivateKeyFilePath string + cert *tls.Certificate + lock sync.RWMutex +} + +func NewStaticCertManager(tlsCertificateFilePath, tlsPrivateKeyFilePath string) *StaticCertManager { + return &StaticCertManager{ + tlsCertificateFilePath: tlsCertificateFilePath, + tlsPrivateKeyFilePath: tlsPrivateKeyFilePath, + } +} + +func (m *StaticCertManager) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) { + m.lock.RLock() + if m.cert != nil { + defer m.lock.RUnlock() + return m.cert, nil + } + m.lock.RUnlock() + + m.lock.Lock() + defer m.lock.Unlock() + if m.cert != nil { // Double-check locking + return m.cert, nil + } + + slog.Info( + "Loading custom TLS certificate", + "tls-certificate-path", m.tlsCertificateFilePath, + "tls-private-key-path", m.tlsPrivateKeyFilePath, + ) + + cert, err := tls.LoadX509KeyPair(m.tlsCertificateFilePath, m.tlsPrivateKeyFilePath) + if err != nil { + return nil, err + } + m.cert = &cert + + return m.cert, nil +} + +func (m *StaticCertManager) HTTPHandler(handler http.Handler) http.Handler { + return handler +} diff --git a/internal/server/cert_test.go b/internal/server/cert_test.go new file mode 100644 index 0000000..96ed045 --- /dev/null +++ b/internal/server/cert_test.go @@ -0,0 +1,105 @@ +package server + +import ( + "crypto/tls" + "os" + "path" + "testing" + + "github.com/stretchr/testify/require" +) + +const certPem = `-----BEGIN CERTIFICATE----- +MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw +DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow +EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d +7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B +5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr +BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1 +NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l +Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc +6MF9+Yw1Yy0t +-----END CERTIFICATE-----` + +const keyPem = `-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49 +AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q +EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA== +-----END EC PRIVATE KEY-----` + +func TestCertificateLoading(t *testing.T) { + certPath, keyPath, err := prepareTestCertificateFiles(t) + require.NoError(t, err) + + manager := NewStaticCertManager(certPath, keyPath) + cert, err := manager.GetCertificate(&tls.ClientHelloInfo{}) + require.NoError(t, err) + require.NotNil(t, cert) +} + +func TestCertificateLoadingRaceCondition(t *testing.T) { + certPath, keyPath, err := prepareTestCertificateFiles(t) + require.NoError(t, err) + + manager := NewStaticCertManager(certPath, keyPath) + go func() { + _, err2 := manager.GetCertificate(&tls.ClientHelloInfo{}) + require.NoError(t, err2) + }() + cert, err := manager.GetCertificate(&tls.ClientHelloInfo{}) + require.NoError(t, err) + require.NotNil(t, cert) +} + +func TestCachesLoadedCertificate(t *testing.T) { + certPath, keyPath, err := prepareTestCertificateFiles(t) + require.NoError(t, err) + + manager := NewStaticCertManager(certPath, keyPath) + cert1, err := manager.GetCertificate(&tls.ClientHelloInfo{}) + require.NoError(t, err) + require.NotNil(t, cert1) + + require.Nil(t, os.Remove(certPath)) + require.Nil(t, os.Remove(keyPath)) + + cert2, err := manager.GetCertificate(&tls.ClientHelloInfo{}) + require.Equal(t, cert1, cert2) +} + +func TestErrorWhenFileDoesNotExist(t *testing.T) { + manager := NewStaticCertManager("testdata/cert.pem", "testdata/key.pem") + cert1, err := manager.GetCertificate(&tls.ClientHelloInfo{}) + require.ErrorContains(t, err, "no such file or directory") + require.Nil(t, cert1) +} + +func TestErrorWhenKeyFormatIsInvalid(t *testing.T) { + certPath, keyPath, err := prepareTestCertificateFiles(t) + require.NoError(t, err) + + manager := NewStaticCertManager(keyPath, certPath) + cert1, err := manager.GetCertificate(&tls.ClientHelloInfo{}) + require.ErrorContains(t, err, "failed to find certificate PEM data in certificate input") + require.Nil(t, cert1) +} + +func prepareTestCertificateFiles(t *testing.T) (string, string, error) { + t.Helper() + + dir := t.TempDir() + certFile := path.Join(dir, "example-cert.pem") + keyFile := path.Join(dir, "example-key.pem") + + err := os.WriteFile(certFile, []byte(certPem), 0644) + if err != nil { + return "", "", err + } + + err = os.WriteFile(keyFile, []byte(keyPem), 0644) + if err != nil { + return "", "", err + } + + return certFile, keyFile, nil +} diff --git a/internal/server/service.go b/internal/server/service.go index 1c96de0..fe32427 100644 --- a/internal/server/service.go +++ b/internal/server/service.go @@ -60,10 +60,12 @@ type HealthCheckConfig struct { } type ServiceOptions struct { - TLSEnabled bool `json:"tls_enabled"` - ACMEDirectory string `json:"acme_directory"` - ACMECachePath string `json:"acme_cache_path"` - ErrorPagePath string `json:"error_page_path"` + TLSEnabled bool `json:"tls_enabled"` + TLSCertificatePath string `json:"tls_certificate_path"` + TLSPrivateKeyPath string `json:"tls_private_key_path"` + ACMEDirectory string `json:"acme_directory"` + ACMECachePath string `json:"acme_cache_path"` + ErrorPagePath string `json:"error_page_path"` } func (so ServiceOptions) ScopedCachePath() string { @@ -90,7 +92,7 @@ type Service struct { pauseController *PauseController rolloutController *RolloutController - certManager *autocert.Manager + certManager CertManager middleware http.Handler } @@ -284,11 +286,15 @@ func (s *Service) initialize() { s.middleware = s.createMiddleware() } -func (s *Service) createCertManager() *autocert.Manager { +func (s *Service) createCertManager() CertManager { if !s.options.TLSEnabled { return nil } + if s.options.TLSCertificatePath != "" && s.options.TLSPrivateKeyPath != "" { + return NewStaticCertManager(s.options.TLSCertificatePath, s.options.TLSPrivateKeyPath) + } + return &autocert.Manager{ Prompt: autocert.AcceptTOS, Cache: autocert.DirCache(s.options.ScopedCachePath()), diff --git a/internal/server/service_test.go b/internal/server/service_test.go index 8ae8b0c..ae82027 100644 --- a/internal/server/service_test.go +++ b/internal/server/service_test.go @@ -42,6 +42,21 @@ func TestService_RedirectToHTTPWhenTLSRequired(t *testing.T) { require.Equal(t, http.StatusOK, w.Result().StatusCode) } +func TestService_UseStaticTLSCertificateWhenConfigured(t *testing.T) { + service := testCreateService( + t, + []string{"example.com"}, + ServiceOptions{ + TLSEnabled: true, + TLSCertificatePath: "cert.pem", + TLSPrivateKeyPath: "key.pem", + }, + defaultTargetOptions, + ) + + require.IsType(t, &StaticCertManager{}, service.certManager) +} + func TestService_RejectTLSRequestsWhenNotConfigured(t *testing.T) { service := testCreateService(t, defaultEmptyHosts, defaultServiceOptions, defaultTargetOptions)