diff --git a/go.mod b/go.mod index 0220af2..2ee8f39 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.22.0 require ( github.com/golang-jwt/jwt/v5 v5.2.1 + github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 github.com/lestrrat-go/jwx v1.2.29 github.com/stretchr/testify v1.9.0 ) diff --git a/go.sum b/go.sum index 5648a45..5991fca 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,8 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/lestrrat-go/backoff/v2 v2.0.8 h1:oNb5E5isby2kiro9AgdHLv5N5tint1AnDVVf2E2un5A= github.com/lestrrat-go/backoff/v2 v2.0.8/go.mod h1:rHP/q/r9aT27n24JQLa7JhSQZCKBBOiM/uP402WwN8Y= github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k= diff --git a/main.go b/main.go index 8d3d8c7..626e195 100644 --- a/main.go +++ b/main.go @@ -9,12 +9,17 @@ import ( "net/http" "os" "os/exec" - "strings" "github.com/golang-jwt/jwt/v5" + "github.com/google/shlex" "github.com/lestrrat-go/jwx/jwk" ) +func init() { + // call getArgs early to fail on a bad config + getArgs() +} + func main() { if os.Getenv("JWKS_URI") == "" { log.Fatal("JWKS_URI is required. e.g. JWKS_URI=https://gitlab.com/oauth/discovery/keys") @@ -84,11 +89,7 @@ func Rollout(w http.ResponseWriter, r *http.Request) { if name == "" { name = "/bin/bash" } - args := os.Getenv("ROLLOUT_ARGS") - if args == "" { - args = "/rollout.sh" - } - cmd := exec.Command(name, strings.Split(args, " ")...) + cmd := exec.Command(name, getArgs()...) var stdOut, stdErr bytes.Buffer cmd.Stdout = &stdOut @@ -162,3 +163,16 @@ func strInSlice(e string, s []string) bool { } return false } + +func getArgs() []string { + args := os.Getenv("ROLLOUT_ARGS") + if args == "" { + args = "/rollout.sh" + } + rolloutArgs, err := shlex.Split(args) + if err != nil { + log.Fatalf("Error parsing ROLLOUT_ARGS %s: %v", args, err) + } + + return rolloutArgs +} diff --git a/main_test.go b/main_test.go index ad5075a..f3ee9b0 100644 --- a/main_test.go +++ b/main_test.go @@ -18,6 +18,11 @@ import ( "github.com/stretchr/testify/assert" ) +var ( + kid, claim, aud string + privateKey *rsa.PrivateKey +) + // createJWKS creates a JWKS JSON representation with a single RSA key. func mockJWKS(pub *rsa.PublicKey, kid string) (string, error) { jwks := struct { @@ -80,6 +85,23 @@ func setupMockJwksServer(pub *rsa.PublicKey, kid string) *httptest.Server { return httptest.NewServer(handler) } +func createMockJwksServer() *httptest.Server { + var publicKey *rsa.PublicKey + var err error + + os.Setenv("JWT_AUD", "test-success") + kid = "no-kidding" + aud = os.Getenv("JWT_AUD") + claim = "bar" + privateKey, publicKey, err = GenerateRSAKeys() + if err != nil { + log.Fatalf("Unable to generate RSA keys: %v", err) + } + testServer := setupMockJwksServer(publicKey, kid) + os.Setenv("JWKS_URI", fmt.Sprintf("%s/oauth/discovery/keys", testServer.URL)) + return testServer +} + func CreateSignedJWT(kid, aud, claim string, exp int64, privateKey *rsa.PrivateKey) (string, error) { // Define the claims of the token. You can add more claims based on your needs. claims := jwt.MapClaims{ @@ -114,6 +136,8 @@ func createRequest(authHeader string) *http.Request { // TestRollout tests the Rollout function with various scenarios func TestRollout(t *testing.T) { testFile := "/tmp/rollout-test.txt" + + // have our test rollout cmd just touch a file os.Setenv("ROLLOUT_CMD", "touch") os.Setenv("ROLLOUT_ARGS", testFile) @@ -123,19 +147,8 @@ func TestRollout(t *testing.T) { log.Fatalf("Unable to cleanup test file: %v", err) } - // mock the JWKS server response - os.Setenv("JWT_AUD", "test-success") - kid := "no-kidding" - aud := os.Getenv("JWT_AUD") - claim := "bar" - privateKey, publicKey, err := GenerateRSAKeys() - if err != nil { - log.Fatalf("Unable to generate RSA keys: %v", err) - } - server := setupMockJwksServer(publicKey, kid) - defer server.Close() - jwkURL := fmt.Sprintf("%s/oauth/discovery/keys", server.URL) - os.Setenv("JWKS_URI", jwkURL) + s := createMockJwksServer() + defer s.Close() // get a valid token exp := time.Now().Add(time.Hour * 1).Unix() @@ -179,13 +192,13 @@ func TestRollout(t *testing.T) { t.Fatalf("Unable to create a JWT with our test key: %v", err) } - // Define test cases tests := []struct { name string authHeader string expectedStatus int expectedBody string claim map[string]string + cmdArgs string }{ { name: "No Authorization Header", @@ -235,6 +248,13 @@ func TestRollout(t *testing.T) { expectedStatus: http.StatusOK, expectedBody: "Rollout complete\n", }, + { + name: "Rollout cmd with quotes parsed correctly", + authHeader: "Bearer " + jwtToken, + expectedStatus: http.StatusOK, + cmdArgs: `/tmp/rollout-shlex-test /tmp/"rollout test filename wrapped in quotes"`, + expectedBody: "Rollout complete\n", + }, { name: "Valid Token and Successful Command", authHeader: "Bearer " + jwtToken, @@ -250,8 +270,11 @@ func TestRollout(t *testing.T) { os.Setenv("CUSTOM_CLAIMS", "") } else { os.Setenv("CUSTOM_CLAIMS", `{"foo": "bar"}`) - } + if tt.cmdArgs != "" { + os.Setenv("ROLLOUT_ARGS", tt.cmdArgs) + } + Rollout(recorder, request) assert.Equal(t, tt.expectedStatus, recorder.Code) @@ -259,16 +282,24 @@ func TestRollout(t *testing.T) { }) } - // make sure the rollout command actually ran the command - _, err = os.Stat(testFile) - if err != nil && os.IsNotExist(err) { - t.Errorf("The successful test did not create the expected file") + testFiles := []string{ + testFile, + "/tmp/rollout-shlex-test", + `/tmp/rollout test filename wrapped in quotes`, } + for _, f := range testFiles { + // make sure the rollout command actually ran the command + // which creates the file + _, err = os.Stat(f) + if err != nil && os.IsNotExist(err) { + t.Errorf("The successful test did not create the expected file %s", f) + } - // cleanup - err = RemoveFileIfExists(testFile) - if err != nil { - log.Fatalf("Unable to cleanup test file: %v", err) + // cleanup + err = RemoveFileIfExists(f) + if err != nil { + log.Fatalf("Unable to cleanup test file: %v", err) + } } }