Skip to content

Commit

Permalink
add csp headers to kurl_proxy (#3725)
Browse files Browse the repository at this point in the history
* add csp headers to kurl_proxy

* add unit tests for csp headers in kurl_proxy
  • Loading branch information
Craig O'Donnell authored Mar 9, 2023
1 parent ef0e8f1 commit 1938146
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 22 deletions.
47 changes: 25 additions & 22 deletions kurl_proxy/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,14 @@ func main() {

m := cmux.New(listener)

httpsServer = getHttpsServer(upstream, dexUpstream, tlsSecretName, secrets, cert.acceptAnonymousUploads)
assetsDir := "/assets"

httpsServer = getHttpsServer(upstream, dexUpstream, tlsSecretName, secrets, cert.acceptAnonymousUploads, assetsDir)
tlsConfig := tlsconfig.ServerDefault()
tlsConfig.Certificates = []tls.Certificate{cert.tlsCert}
go httpsServer.Serve(tls.NewListener(m.Match(cmux.TLS()), tlsConfig))

httpServer = getHttpServer(cert.fingerprint, cert.acceptAnonymousUploads)
httpServer = getHttpServer(cert.fingerprint, cert.acceptAnonymousUploads, assetsDir)
go httpServer.Serve(m.Match(cmux.Any()))

log.Printf("Kurl Proxy listening on :%s\n", nodePort)
Expand Down Expand Up @@ -228,11 +230,13 @@ func getFingerprint(certData []byte) (string, error) {
return strings.ToUpper(strings.Replace(fmt.Sprintf("% x", sha1.Sum(x509Cert.Raw)), " ", ":", -1)), nil
}

func getHttpServer(fingerprint string, acceptAnonymousUploads bool) *http.Server {
func getHttpServer(fingerprint string, acceptAnonymousUploads bool, assetsDir string) *http.Server {
r := gin.Default()

r.StaticFS("/assets", http.Dir("/assets"))
r.LoadHTMLGlob("/assets/*.html")
r.Use(CSPMiddleware)

r.StaticFS("/assets", http.Dir(assetsDir))
r.LoadHTMLGlob(fmt.Sprintf("%s/*.html", assetsDir))

r.GET("/", func(c *gin.Context) {
if !acceptAnonymousUploads {
Expand Down Expand Up @@ -275,13 +279,13 @@ func getHttpServer(fingerprint string, acceptAnonymousUploads bool) *http.Server
}
}

func getHttpsServer(upstream, dexUpstream *url.URL, tlsSecretName string, secrets corev1.SecretInterface, acceptAnonymousUploads bool) *http.Server {
mux := http.NewServeMux()

func getHttpsServer(upstream, dexUpstream *url.URL, tlsSecretName string, secrets corev1.SecretInterface, acceptAnonymousUploads bool, assetsDir string) *http.Server {
r := gin.Default()

mux.Handle("/tls/assets/", http.StripPrefix("/tls/assets/", http.FileServer(http.Dir("/assets"))))
r.LoadHTMLGlob("/assets/*.html")
r.Use(CSPMiddleware)

r.StaticFS("/tls/assets", http.Dir(assetsDir))
r.LoadHTMLGlob(fmt.Sprintf("%s/*.html", assetsDir))

r.GET("/tls", func(c *gin.Context) {
if !acceptAnonymousUploads {
Expand Down Expand Up @@ -430,26 +434,25 @@ func getHttpsServer(upstream, dexUpstream *url.URL, tlsSecretName string, secret
}
}()
})
mux.Handle("/tls", r)
mux.Handle("/tls/", r)

// mux.Handle("/api/v1/kots/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// log.Println("Kots REST API not proxied.")
// http.Error(w, "Not found", http.StatusNotFound)
// }))

if dexUpstream != nil {
dexReverseProxy := httputil.NewSingleHostReverseProxy(dexUpstream)
mux.Handle("/dex", dexReverseProxy)
mux.Handle("/dex/", dexReverseProxy)
r.Any("/dex/*path", gin.WrapH(httputil.NewSingleHostReverseProxy(dexUpstream)))
}
mux.Handle("/", httputil.NewSingleHostReverseProxy(upstream))

r.NoRoute(gin.WrapH(httputil.NewSingleHostReverseProxy(upstream)))

return &http.Server{
Handler: mux,
Handler: r,
}
}

// CSPMiddleware adds Content-Security-Policy and X-Frame-Options headers to the response.
func CSPMiddleware(c *gin.Context) {
c.Writer.Header().Set("Content-Security-Policy", "frame-ancestors 'none';")
c.Writer.Header().Set("X-Frame-Options", "DENY")
c.Next()
}

func getUploadedCerts(c *gin.Context) ([]byte, []byte, error) {
certHeader, err := c.FormFile("cert")
if err != nil {
Expand Down
73 changes: 73 additions & 0 deletions kurl_proxy/cmd/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ package main

import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"testing"
)

Expand Down Expand Up @@ -143,3 +148,71 @@ func Test_getFingerprint(t *testing.T) {
})
}
}

func Test_httpServerCSPHeaders(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "test-assets")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)

indexFile := filepath.Join(tmpDir, "index.html")
if err := os.WriteFile(indexFile, []byte("hello world"), 0644); err != nil {
t.Fatalf("failed to write index.html: %v", err)
}

tests := []struct {
name string
httpServer *http.Server
isHttps bool
path string
wantHeaders map[string]string
}{
{
name: "returns the correct headers from the http server",
httpServer: getHttpServer("some-fingerprint", true, tmpDir),
path: "/assets/index.html",
wantHeaders: map[string]string{
"Content-Security-Policy": "frame-ancestors 'none';",
"X-Frame-Options": "DENY",
},
},
{
name: "returns the correct headers from the https server",
httpServer: getHttpsServer(&url.URL{}, &url.URL{}, "some-tls-secret", nil, true, tmpDir),
isHttps: true,
path: "/tls/assets/index.html",
wantHeaders: map[string]string{
"Content-Security-Policy": "frame-ancestors 'none';",
"X-Frame-Options": "DENY",
},
},
}

for _, tt := range tests {
var ts *httptest.Server
if tt.isHttps {
ts = httptest.NewTLSServer(tt.httpServer.Handler)
} else {
ts = httptest.NewServer(tt.httpServer.Handler)
}
defer ts.Close()

client := ts.Client()
resp, err := client.Get(ts.URL + tt.path)
if err != nil {
t.Fatalf("failed to get index.html: %v", err)
}

if resp.StatusCode != http.StatusOK {
t.Errorf("expected status code %d, got %d", http.StatusOK, resp.StatusCode)
}

for header, want := range tt.wantHeaders {
if got := resp.Header.Get(header); got != want {
t.Errorf("expected header %q to be %q, got %q", header, want, got)
}
}
}

}

0 comments on commit 1938146

Please sign in to comment.