diff --git a/kurl_proxy/cmd/main.go b/kurl_proxy/cmd/main.go index 48888e3af9..4298490875 100644 --- a/kurl_proxy/cmd/main.go +++ b/kurl_proxy/cmd/main.go @@ -118,12 +118,14 @@ func main() { m := cmux.New(listener) - httpsServer = getHttpsServer(upstream, dexUpstream, tlsSecretName, secrets, cert.acceptAnonymousUploads) + assetsDir := "/assets" + + httpsServer = getHttpsServer(upstream, dexUpstream, tlsSecretName, secrets, cert.acceptAnonymousUploads, assetsDir) tlsConfig := tlsconfig.ServerDefault() tlsConfig.Certificates = []tls.Certificate{cert.tlsCert} go httpsServer.Serve(tls.NewListener(m.Match(cmux.TLS()), tlsConfig)) - httpServer = getHttpServer(cert.fingerprint, cert.acceptAnonymousUploads) + httpServer = getHttpServer(cert.fingerprint, cert.acceptAnonymousUploads, assetsDir) go httpServer.Serve(m.Match(cmux.Any())) log.Printf("Kurl Proxy listening on :%s\n", nodePort) @@ -228,11 +230,13 @@ func getFingerprint(certData []byte) (string, error) { return strings.ToUpper(strings.Replace(fmt.Sprintf("% x", sha1.Sum(x509Cert.Raw)), " ", ":", -1)), nil } -func getHttpServer(fingerprint string, acceptAnonymousUploads bool) *http.Server { +func getHttpServer(fingerprint string, acceptAnonymousUploads bool, assetsDir string) *http.Server { r := gin.Default() - r.StaticFS("/assets", http.Dir("/assets")) - r.LoadHTMLGlob("/assets/*.html") + r.Use(CSPMiddleware) + + r.StaticFS("/assets", http.Dir(assetsDir)) + r.LoadHTMLGlob(fmt.Sprintf("%s/*.html", assetsDir)) r.GET("/", func(c *gin.Context) { if !acceptAnonymousUploads { @@ -275,13 +279,13 @@ func getHttpServer(fingerprint string, acceptAnonymousUploads bool) *http.Server } } -func getHttpsServer(upstream, dexUpstream *url.URL, tlsSecretName string, secrets corev1.SecretInterface, acceptAnonymousUploads bool) *http.Server { - mux := http.NewServeMux() - +func getHttpsServer(upstream, dexUpstream *url.URL, tlsSecretName string, secrets corev1.SecretInterface, acceptAnonymousUploads bool, assetsDir string) *http.Server { r := gin.Default() - mux.Handle("/tls/assets/", http.StripPrefix("/tls/assets/", http.FileServer(http.Dir("/assets")))) - r.LoadHTMLGlob("/assets/*.html") + r.Use(CSPMiddleware) + + r.StaticFS("/tls/assets", http.Dir(assetsDir)) + r.LoadHTMLGlob(fmt.Sprintf("%s/*.html", assetsDir)) r.GET("/tls", func(c *gin.Context) { if !acceptAnonymousUploads { @@ -430,26 +434,25 @@ func getHttpsServer(upstream, dexUpstream *url.URL, tlsSecretName string, secret } }() }) - mux.Handle("/tls", r) - mux.Handle("/tls/", r) - - // mux.Handle("/api/v1/kots/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // log.Println("Kots REST API not proxied.") - // http.Error(w, "Not found", http.StatusNotFound) - // })) if dexUpstream != nil { - dexReverseProxy := httputil.NewSingleHostReverseProxy(dexUpstream) - mux.Handle("/dex", dexReverseProxy) - mux.Handle("/dex/", dexReverseProxy) + r.Any("/dex/*path", gin.WrapH(httputil.NewSingleHostReverseProxy(dexUpstream))) } - mux.Handle("/", httputil.NewSingleHostReverseProxy(upstream)) + + r.NoRoute(gin.WrapH(httputil.NewSingleHostReverseProxy(upstream))) return &http.Server{ - Handler: mux, + Handler: r, } } +// CSPMiddleware adds Content-Security-Policy and X-Frame-Options headers to the response. +func CSPMiddleware(c *gin.Context) { + c.Writer.Header().Set("Content-Security-Policy", "frame-ancestors 'none';") + c.Writer.Header().Set("X-Frame-Options", "DENY") + c.Next() +} + func getUploadedCerts(c *gin.Context) ([]byte, []byte, error) { certHeader, err := c.FormFile("cert") if err != nil { diff --git a/kurl_proxy/cmd/main_test.go b/kurl_proxy/cmd/main_test.go index e86b6f6248..254f7bde51 100644 --- a/kurl_proxy/cmd/main_test.go +++ b/kurl_proxy/cmd/main_test.go @@ -2,6 +2,11 @@ package main import ( "fmt" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" "testing" ) @@ -143,3 +148,71 @@ func Test_getFingerprint(t *testing.T) { }) } } + +func Test_httpServerCSPHeaders(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "test-assets") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + indexFile := filepath.Join(tmpDir, "index.html") + if err := os.WriteFile(indexFile, []byte("hello world"), 0644); err != nil { + t.Fatalf("failed to write index.html: %v", err) + } + + tests := []struct { + name string + httpServer *http.Server + isHttps bool + path string + wantHeaders map[string]string + }{ + { + name: "returns the correct headers from the http server", + httpServer: getHttpServer("some-fingerprint", true, tmpDir), + path: "/assets/index.html", + wantHeaders: map[string]string{ + "Content-Security-Policy": "frame-ancestors 'none';", + "X-Frame-Options": "DENY", + }, + }, + { + name: "returns the correct headers from the https server", + httpServer: getHttpsServer(&url.URL{}, &url.URL{}, "some-tls-secret", nil, true, tmpDir), + isHttps: true, + path: "/tls/assets/index.html", + wantHeaders: map[string]string{ + "Content-Security-Policy": "frame-ancestors 'none';", + "X-Frame-Options": "DENY", + }, + }, + } + + for _, tt := range tests { + var ts *httptest.Server + if tt.isHttps { + ts = httptest.NewTLSServer(tt.httpServer.Handler) + } else { + ts = httptest.NewServer(tt.httpServer.Handler) + } + defer ts.Close() + + client := ts.Client() + resp, err := client.Get(ts.URL + tt.path) + if err != nil { + t.Fatalf("failed to get index.html: %v", err) + } + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status code %d, got %d", http.StatusOK, resp.StatusCode) + } + + for header, want := range tt.wantHeaders { + if got := resp.Header.Get(header); got != want { + t.Errorf("expected header %q to be %q, got %q", header, want, got) + } + } + } + +}