From 8477db4d9ee1a6d17e4f1ff9f0b2cf3807b7efa0 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sat, 2 Mar 2024 08:51:06 -0600 Subject: [PATCH 01/45] Refactor the HTTP test module When the `file_cache` module will be added, the `launcher` module will depend on the `file_cache` module which depends on the client. That introduces a circular dependency since the tests are in the client module itself. This splits the pieces that need the launcher into a separate `test` module that breaks the dependency; it works as the launcher-integrated pieces only depend on the public interfaces of the client. --- client/fed_test.go | 855 +++++++++++++++++++++++++++++++++++++ client/handle_http_test.go | 819 ----------------------------------- 2 files changed, 855 insertions(+), 819 deletions(-) create mode 100644 client/fed_test.go diff --git a/client/fed_test.go b/client/fed_test.go new file mode 100644 index 000000000..cdb152c06 --- /dev/null +++ b/client/fed_test.go @@ -0,0 +1,855 @@ +//go:build !windows + +/*************************************************************** + * + * Copyright (C) 2023, University of Nebraska-Lincoln + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You may + * obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + ***************************************************************/ + +package client_test + +import ( + "context" + "encoding/json" + "io" + "net/http" + "os" + "path/filepath" + "strconv" + "testing" + "time" + + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + + "github.com/pelicanplatform/pelican/client" + "github.com/pelicanplatform/pelican/config" + "github.com/pelicanplatform/pelican/launchers" + "github.com/pelicanplatform/pelican/param" + "github.com/pelicanplatform/pelican/server_utils" + "github.com/pelicanplatform/pelican/test_utils" + "github.com/pelicanplatform/pelican/token_scopes" + "github.com/pelicanplatform/pelican/utils" +) + +func generateFileTestScitoken() (string, error) { + // Issuer is whichever server that initiates the test, so it's the server itself + issuerUrl, err := config.GetServerIssuerURL() + if err != nil { + return "", err + } + if issuerUrl == "" { // if empty, then error + return "", errors.New("Failed to create token: Invalid iss, Server_ExternalWebUrl is empty") + } + + scopes := []token_scopes.TokenScope{} + readScope, err := token_scopes.Storage_Read.Path("/") + if err != nil { + return "", errors.Wrap(err, "failed to create 'read' scope for file test token:") + } + scopes = append(scopes, readScope) + modScope, err := token_scopes.Storage_Modify.Path("/") + if err != nil { + return "", errors.Wrap(err, "failed to create 'modify' scope for file test token:") + } + scopes = append(scopes, modScope) + + fTestTokenCfg := utils.TokenConfig{ + TokenProfile: utils.WLCG, + Lifetime: time.Minute, + Issuer: issuerUrl, + Audience: []string{config.GetServerAudience()}, + Version: "1.0", + Subject: "origin", + } + fTestTokenCfg.AddScopes(scopes) + + // CreateToken also handles validation for us + tok, err := fTestTokenCfg.CreateToken() + if err != nil { + return "", errors.Wrap(err, 
"failed to create file test token:") + } + + return tok, nil +} + +func TestFullUpload(t *testing.T) { + // Setup our test federation + ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) + defer func() { require.NoError(t, egrp.Wait()) }() + defer cancel() + + viper.Reset() + + modules := config.ServerType(0) + modules.Set(config.OriginType) + modules.Set(config.DirectorType) + modules.Set(config.RegistryType) + + // Create our own temp directory (for some reason t.TempDir() does not play well with xrootd) + tmpPathPattern := "XRootD-Test_Origin*" + tmpPath, err := os.MkdirTemp("", tmpPathPattern) + require.NoError(t, err) + + // Need to set permissions or the xrootd process we spawn won't be able to write PID/UID files + permissions := os.FileMode(0755) + err = os.Chmod(tmpPath, permissions) + require.NoError(t, err) + + viper.Set("ConfigDir", tmpPath) + + // Increase the log level; otherwise, its difficult to debug failures + viper.Set("Logging.Level", "Debug") + config.InitConfig() + + originDir, err := os.MkdirTemp("", "Origin") + assert.NoError(t, err) + + // Change the permissions of the temporary directory + permissions = os.FileMode(0777) + err = os.Chmod(originDir, permissions) + require.NoError(t, err) + + viper.Set("Origin.ExportVolume", originDir+":/test") + viper.Set("Origin.Mode", "posix") + // Disable functionality we're not using (and is difficult to make work on Mac) + viper.Set("Origin.EnableCmsd", false) + viper.Set("Origin.EnableMacaroons", false) + viper.Set("Origin.EnableVoms", false) + viper.Set("Origin.EnableWrite", true) + viper.Set("TLSSkipVerify", true) + viper.Set("Server.EnableUI", false) + viper.Set("Registry.DbLocation", filepath.Join(t.TempDir(), "ns-registry.sqlite")) + viper.Set("Origin.RunLocation", tmpPath) + viper.Set("Registry.RequireOriginApproval", false) + viper.Set("Registry.RequireCacheApproval", false) + viper.Set("Logging.Origin.Scitokens", "debug") + viper.Set("Origin.Port", 0) + + err = 
config.InitServer(ctx, modules) + require.NoError(t, err) + + fedCancel, err := launchers.LaunchModules(ctx, modules) + defer fedCancel() + if err != nil { + log.Errorln("Failure in fedServeInternal:", err) + require.NoError(t, err) + } + + desiredURL := param.Server_ExternalWebUrl.GetString() + "/.well-known/openid-configuration" + err = server_utils.WaitUntilWorking(ctx, "GET", desiredURL, "director", 200) + require.NoError(t, err) + + httpc := http.Client{ + Transport: config.GetTransport(), + } + resp, err := httpc.Get(desiredURL) + require.NoError(t, err) + + assert.Equal(t, resp.StatusCode, http.StatusOK) + + responseBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + expectedResponse := struct { + JwksUri string `json:"jwks_uri"` + }{} + err = json.Unmarshal(responseBody, &expectedResponse) + require.NoError(t, err) + + assert.NotEmpty(t, expectedResponse.JwksUri) + + t.Run("testFullUpload", func(t *testing.T) { + testFileContent := "test file content" + + // Create the temporary file to upload + tempFile, err := os.CreateTemp(t.TempDir(), "test") + assert.NoError(t, err, "Error creating temp file") + defer os.Remove(tempFile.Name()) + _, err = tempFile.WriteString(testFileContent) + assert.NoError(t, err, "Error writing to temp file") + tempFile.Close() + + // Create a token file + token, err := generateFileTestScitoken() + assert.NoError(t, err) + tempToken, err := os.CreateTemp(t.TempDir(), "token") + assert.NoError(t, err, "Error creating temp token file") + defer os.Remove(tempToken.Name()) + _, err = tempToken.WriteString(token) + assert.NoError(t, err, "Error writing to temp token file") + tempToken.Close() + + // Upload the file + tempPath := tempFile.Name() + fileName := filepath.Base(tempPath) + uploadURL := "stash:///test/" + fileName + + transferResults, err := client.DoCopy(ctx, tempFile.Name(), uploadURL, false, client.WithTokenLocation(tempToken.Name())) + assert.NoError(t, err, "Error uploading file") + assert.Equal(t, 
int64(len(testFileContent)), transferResults[0].TransferredBytes, "Uploaded file size does not match") + + // Upload an osdf file + uploadURL = "pelican:///test/stuff/blah.txt" + assert.NoError(t, err, "Error parsing upload URL") + transferResults, err = client.DoCopy(ctx, tempFile.Name(), uploadURL, false, client.WithTokenLocation(tempToken.Name())) + assert.NoError(t, err, "Error uploading file") + assert.Equal(t, int64(len(testFileContent)), transferResults[0].TransferredBytes, "Uploaded file size does not match") + }) + t.Cleanup(func() { + os.RemoveAll(tmpPath) + os.RemoveAll(originDir) + }) + + cancel() + fedCancel() + assert.NoError(t, egrp.Wait()) + viper.Reset() +} + +type FedTest struct { + T *testing.T + TmpPath string + OriginDir string + Output *os.File + Cancel context.CancelFunc + FedCancel context.CancelFunc + ErrGroup *errgroup.Group +} + +func (f *FedTest) Spinup() { + //////////////////////////////Setup our test federation////////////////////////////////////////// + ctx, cancel, egrp := test_utils.TestContext(context.Background(), f.T) + + modules := config.ServerType(0) + modules.Set(config.OriginType) + modules.Set(config.DirectorType) + modules.Set(config.RegistryType) + + // Create our own temp directory (for some reason t.TempDir() does not play well with xrootd) + tmpPathPattern := "XRootD-Test_Origin*" + tmpPath, err := os.MkdirTemp("", tmpPathPattern) + require.NoError(f.T, err) + f.TmpPath = tmpPath + + // Need to set permissions or the xrootd process we spawn won't be able to write PID/UID files + permissions := os.FileMode(0755) + err = os.Chmod(tmpPath, permissions) + require.NoError(f.T, err) + + viper.Set("ConfigDir", tmpPath) + + config.InitConfig() + // Create a file to capture output from commands + output, err := os.CreateTemp(f.T.TempDir(), "output") + assert.NoError(f.T, err) + f.Output = output + viper.Set("Logging.LogLocation", output.Name()) + + originDir, err := os.MkdirTemp("", "Origin") + assert.NoError(f.T, err) + 
f.OriginDir = originDir + + // Change the permissions of the temporary origin directory + permissions = os.FileMode(0777) + err = os.Chmod(originDir, permissions) + require.NoError(f.T, err) + + viper.Set("Origin.ExportVolume", originDir+":/test") + viper.Set("Origin.Mode", "posix") + viper.Set("Origin.EnableFallbackRead", true) + // Disable functionality we're not using (and is difficult to make work on Mac) + viper.Set("Origin.EnableCmsd", false) + viper.Set("Origin.EnableMacaroons", false) + viper.Set("Origin.EnableVoms", false) + viper.Set("Origin.EnableWrite", true) + viper.Set("TLSSkipVerify", true) + viper.Set("Server.EnableUI", false) + viper.Set("Registry.DbLocation", filepath.Join(f.T.TempDir(), "ns-registry.sqlite")) + viper.Set("Origin.Port", 0) + viper.Set("Server.WebPort", 0) + viper.Set("Origin.RunLocation", tmpPath) + + err = config.InitServer(ctx, modules) + require.NoError(f.T, err) + + viper.Set("Registry.RequireOriginApproval", false) + viper.Set("Registry.RequireCacheApproval", false) + + f.FedCancel, err = launchers.LaunchModules(ctx, modules) + if err != nil { + f.T.Fatalf("Failure in fedServeInternal: %v", err) + } + + desiredURL := param.Server_ExternalWebUrl.GetString() + "/.well-known/openid-configuration" + err = server_utils.WaitUntilWorking(ctx, "GET", desiredURL, "director", 200) + require.NoError(f.T, err) + + httpc := http.Client{ + Transport: config.GetTransport(), + } + resp, err := httpc.Get(desiredURL) + require.NoError(f.T, err) + + assert.Equal(f.T, resp.StatusCode, http.StatusOK) + + responseBody, err := io.ReadAll(resp.Body) + require.NoError(f.T, err) + expectedResponse := struct { + JwksUri string `json:"jwks_uri"` + }{} + err = json.Unmarshal(responseBody, &expectedResponse) + require.NoError(f.T, err) + + f.Cancel = cancel + f.ErrGroup = egrp +} + +func (f *FedTest) Teardown() { + os.RemoveAll(f.TmpPath) + os.RemoveAll(f.OriginDir) + f.Cancel() + f.FedCancel() + assert.NoError(f.T, f.ErrGroup.Wait()) + viper.Reset() +} + 
+// A test that spins up a federation, and tests object get and put +func TestGetAndPutAuth(t *testing.T) { + // Create instance of test federation + ctx, _, _ := test_utils.TestContext(context.Background(), t) + + viper.Reset() + fed := FedTest{T: t} + fed.Spinup() + defer fed.Teardown() + + // Other set-up items: + testFileContent := "test file content" + // Create the temporary file to upload + tempFile, err := os.CreateTemp(t.TempDir(), "test") + assert.NoError(t, err, "Error creating temp file") + defer os.Remove(tempFile.Name()) + _, err = tempFile.WriteString(testFileContent) + assert.NoError(t, err, "Error writing to temp file") + tempFile.Close() + + issuer, err := config.GetServerIssuerURL() + require.NoError(t, err) + audience := config.GetServerAudience() + + // Create a token file + tokenConfig := utils.TokenConfig{ + TokenProfile: utils.WLCG, + Lifetime: time.Minute, + Issuer: issuer, + Audience: []string{audience}, + Subject: "origin", + } + + scopes := []token_scopes.TokenScope{} + readScope, err := token_scopes.Storage_Read.Path("/") + assert.NoError(t, err) + scopes = append(scopes, readScope) + modScope, err := token_scopes.Storage_Modify.Path("/") + assert.NoError(t, err) + scopes = append(scopes, modScope) + tokenConfig.AddScopes(scopes) + token, err := tokenConfig.CreateToken() + assert.NoError(t, err) + tempToken, err := os.CreateTemp(t.TempDir(), "token") + assert.NoError(t, err, "Error creating temp token file") + defer os.Remove(tempToken.Name()) + _, err = tempToken.WriteString(token) + assert.NoError(t, err, "Error writing to temp token file") + tempToken.Close() + // Disable progress bars to not reuse the same mpb instance + viper.Set("Logging.DisableProgressBars", true) + + // This tests object get/put with a pelican:// url + t.Run("testPelicanObjectPutAndGetWithPelicanUrl", func(t *testing.T) { + config.SetPreferredPrefix("pelican") + // Set path for object to upload/download + tempPath := tempFile.Name() + fileName := 
filepath.Base(tempPath) + uploadURL := "pelican:///test/" + fileName + + // Upload the file with PUT + transferResultsUpload, err := client.DoPut(ctx, tempFile.Name(), uploadURL, false, client.WithTokenLocation(tempToken.Name())) + assert.NoError(t, err) + if err == nil { + assert.Equal(t, transferResultsUpload[0].TransferredBytes, int64(17)) + } + + // Download that same file with GET + transferResultsDownload, err := client.DoGet(ctx, uploadURL, t.TempDir(), false, client.WithTokenLocation(tempToken.Name())) + assert.NoError(t, err) + if err == nil { + assert.Equal(t, transferResultsDownload[0].TransferredBytes, transferResultsUpload[0].TransferredBytes) + } + }) + + // This tests pelican object get/put with an osdf url + t.Run("testPelicanObjectPutAndGetWithOSDFUrl", func(t *testing.T) { + config.SetPreferredPrefix("pelican") + // Set path for object to upload/download + tempPath := tempFile.Name() + fileName := filepath.Base(tempPath) + // Minimal fix of test as it is soon to be replaced + uploadURL := "pelican:///test/" + fileName + + // Upload the file with PUT + transferResultsUpload, err := client.DoPut(ctx, tempFile.Name(), uploadURL, false, client.WithTokenLocation(tempToken.Name())) + assert.NoError(t, err) + if err == nil { + assert.Equal(t, transferResultsUpload[0].TransferredBytes, int64(17)) + } + + // Download that same file with GET + transferResultsDownload, err := client.DoGet(ctx, uploadURL, t.TempDir(), false, client.WithTokenLocation(tempToken.Name())) + assert.NoError(t, err) + if err == nil { + assert.Equal(t, transferResultsDownload[0].TransferredBytes, transferResultsUpload[0].TransferredBytes) + } + }) + + // This tests object get/put with a pelican:// url + t.Run("testOsdfObjectPutAndGetWithPelicanUrl", func(t *testing.T) { + config.SetPreferredPrefix("osdf") + // Set path for object to upload/download + tempPath := tempFile.Name() + fileName := filepath.Base(tempPath) + uploadURL := "pelican:///test/" + fileName + + // Upload the file 
with PUT + transferResultsUpload, err := client.DoPut(ctx, tempFile.Name(), uploadURL, false, client.WithTokenLocation(tempToken.Name())) + assert.NoError(t, err) + if err == nil { + assert.Equal(t, transferResultsUpload[0].TransferredBytes, int64(17)) + } + + // Download that same file with GET + transferResultsDownload, err := client.DoGet(ctx, uploadURL, t.TempDir(), false, client.WithTokenLocation(tempToken.Name())) + assert.NoError(t, err) + if err == nil { + assert.Equal(t, transferResultsDownload[0].TransferredBytes, transferResultsUpload[0].TransferredBytes) + } + }) + + // This tests pelican object get/put with an osdf url + t.Run("testOsdfObjectPutAndGetWithOSDFUrl", func(t *testing.T) { + config.SetPreferredPrefix("osdf") + // Set path for object to upload/download + tempPath := tempFile.Name() + fileName := filepath.Base(tempPath) + // Minimal fix of test as it is soon to be replaced + uploadURL := "pelican:///test/" + fileName + + // Upload the file with PUT + transferResultsUpload, err := client.DoPut(ctx, tempFile.Name(), uploadURL, false, client.WithTokenLocation(tempToken.Name())) + assert.NoError(t, err) + if err == nil { + assert.Equal(t, transferResultsUpload[0].TransferredBytes, int64(17)) + } + + // Download that same file with GET + transferResultsDownload, err := client.DoGet(ctx, uploadURL, t.TempDir(), false, client.WithTokenLocation(tempToken.Name())) + assert.NoError(t, err) + if err == nil { + assert.Equal(t, transferResultsDownload[0].TransferredBytes, transferResultsUpload[0].TransferredBytes) + } + }) +} + +// A test that spins up the federation, where the origin is in EnablePublicReads mode. 
Then GET a file from the origin without a token +func TestGetPublicRead(t *testing.T) { + ctx, _, _ := test_utils.TestContext(context.Background(), t) + viper.Reset() + viper.Set("Origin.EnablePublicReads", true) + fed := FedTest{T: t} + fed.Spinup() + defer fed.Teardown() + t.Run("testPubObjGet", func(t *testing.T) { + testFileContent := "test file content" + // Drop the testFileContent into the origin directory + tempFile, err := os.Create(filepath.Join(fed.OriginDir, "test.txt")) + assert.NoError(t, err, "Error creating temp file") + defer os.Remove(tempFile.Name()) + _, err = tempFile.WriteString(testFileContent) + assert.NoError(t, err, "Error writing to temp file") + tempFile.Close() + + viper.Set("Logging.DisableProgressBars", true) + + // Set path for object to upload/download + tempPath := tempFile.Name() + fileName := filepath.Base(tempPath) + uploadURL := "pelican:///test/" + fileName + + // Download the file with GET. Shouldn't need a token to succeed + transferResults, err := client.DoGet(ctx, uploadURL, t.TempDir(), false) + assert.NoError(t, err) + if err == nil { + assert.Equal(t, transferResults[0].TransferredBytes, int64(17)) + } + }) +} + +func TestRecursiveUploadsAndDownloads(t *testing.T) { + // Create instance of test federation + ctx, _, _ := test_utils.TestContext(context.Background(), t) + + viper.Reset() + fed := FedTest{T: t} + fed.Spinup() + defer fed.Teardown() + + //////////////////////////SETUP/////////////////////////// + // Create a token file + issuer, err := config.GetServerIssuerURL() + require.NoError(t, err) + audience := config.GetServerAudience() + + tokenConfig := utils.TokenConfig{ + TokenProfile: utils.WLCG, + Lifetime: time.Minute, + Issuer: issuer, + Audience: []string{audience}, + Subject: "origin", + } + scopes := []token_scopes.TokenScope{} + readScope, err := token_scopes.Storage_Read.Path("/") + assert.NoError(t, err) + scopes = append(scopes, readScope) + modScope, err := token_scopes.Storage_Modify.Path("/") + 
assert.NoError(t, err) + scopes = append(scopes, modScope) + tokenConfig.AddScopes(scopes) + token, err := tokenConfig.CreateToken() + assert.NoError(t, err) + tempToken, err := os.CreateTemp(t.TempDir(), "token") + assert.NoError(t, err, "Error creating temp token file") + defer os.Remove(tempToken.Name()) + _, err = tempToken.WriteString(token) + assert.NoError(t, err, "Error writing to temp token file") + tempToken.Close() + + // Disable progress bars to not reuse the same mpb instance + viper.Set("Logging.DisableProgressBars", true) + + // Make our test directories and files + tempDir, err := os.MkdirTemp("", "UploadDir") + assert.NoError(t, err) + defer os.RemoveAll(tempDir) + permissions := os.FileMode(0777) + err = os.Chmod(tempDir, permissions) + require.NoError(t, err) + + testFileContent1 := "test file content" + testFileContent2 := "more test file content!" + tempFile1, err := os.CreateTemp(tempDir, "test1") + assert.NoError(t, err, "Error creating temp1 file") + tempFile2, err := os.CreateTemp(tempDir, "test1") + assert.NoError(t, err, "Error creating temp2 file") + defer os.Remove(tempFile1.Name()) + defer os.Remove(tempFile2.Name()) + _, err = tempFile1.WriteString(testFileContent1) + assert.NoError(t, err, "Error writing to temp1 file") + tempFile1.Close() + _, err = tempFile2.WriteString(testFileContent2) + assert.NoError(t, err, "Error writing to temp2 file") + tempFile2.Close() + + t.Run("testPelicanRecursiveGetAndPutOsdfURL", func(t *testing.T) { + config.SetPreferredPrefix("pelican") + // Set path for object to upload/download + tempPath := tempDir + dirName := filepath.Base(tempPath) + // Note: minimally fixing this test as it is soon to be replaced + uploadURL := "pelican://" + param.Server_Hostname.GetString() + ":" + strconv.Itoa(param.Server_WebPort.GetInt()) + "/test/" + dirName + + ////////////////////////////////////////////////////////// + + // Upload the file with PUT + transferDetailsUpload, err := client.DoPut(ctx, tempDir, 
uploadURL, true, client.WithTokenLocation(tempToken.Name())) + require.NoError(t, err) + if err == nil && len(transferDetailsUpload) == 2 { + countBytes17 := 0 + countBytes23 := 0 + // Verify we got the correct files back (have to do this since files upload in different orders at times) + for _, transfer := range transferDetailsUpload { + transferredBytes := transfer.TransferredBytes + switch transferredBytes { + case int64(17): + countBytes17++ + continue + case int64(23): + countBytes23++ + continue + default: + // We got a byte amount we are not expecting + t.Fatal("did not upload proper amount of bytes") + } + } + if countBytes17 != 1 || countBytes23 != 1 { + // We would hit this case if 1 counter got hit twice for some reason + t.Fatal("One of the files was not uploaded correctly") + } + } else if len(transferDetailsUpload) != 2 { + t.Fatalf("Amount of transfers results returned for upload was not correct. Transfer details returned: %d", len(transferDetailsUpload)) + } + + // Download the files we just uploaded + transferDetailsDownload, err := client.DoGet(ctx, uploadURL, t.TempDir(), true, client.WithTokenLocation(tempToken.Name())) + assert.NoError(t, err) + if err == nil && len(transferDetailsUpload) == 2 { + countBytesUploadIdx0 := 0 + countBytesUploadIdx1 := 0 + // Verify we got the correct files back (have to do this since files upload in different orders at times) + // In this case, we want to match them to the sizes of the uploaded files + for _, transfer := range transferDetailsUpload { + transferredBytes := transfer.TransferredBytes + switch transferredBytes { + case transferDetailsUpload[0].TransferredBytes: + countBytesUploadIdx0++ + continue + case transferDetailsUpload[1].TransferredBytes: + countBytesUploadIdx1++ + continue + default: + // We got a byte amount we are not expecting + t.Fatal("did not download proper amount of bytes") + } + } + if countBytesUploadIdx0 != 1 || countBytesUploadIdx1 != 1 { + // We would hit this case if 1 counter 
got hit twice for some reason + t.Fatal("One of the files was not downloaded correctly") + } else if len(transferDetailsDownload) != 2 { + t.Fatalf("Amount of transfers results returned for download was not correct. Transfer details returned: %d", len(transferDetailsDownload)) + } + } + }) + + t.Run("testPelicanRecursiveGetAndPutPelicanURL", func(t *testing.T) { + config.SetPreferredPrefix("pelican") + // Set path for object to upload/download + tempPath := tempDir + dirName := filepath.Base(tempPath) + uploadURL := "pelican:///test/" + dirName + + ////////////////////////////////////////////////////////// + + // Upload the file with PUT + transferDetailsUpload, err := client.DoPut(ctx, tempDir, uploadURL, true, client.WithTokenLocation(tempToken.Name())) + assert.NoError(t, err) + if err == nil && len(transferDetailsUpload) == 2 { + countBytes17 := 0 + countBytes23 := 0 + // Verify we got the correct files back (have to do this since files upload in different orders at times) + for _, transfer := range transferDetailsUpload { + transferredBytes := transfer.TransferredBytes + switch transferredBytes { + case int64(17): + countBytes17++ + continue + case int64(23): + countBytes23++ + continue + default: + // We got a byte amount we are not expecting + t.Fatal("did not upload proper amount of bytes") + } + } + if countBytes17 != 1 || countBytes23 != 1 { + // We would hit this case if 1 counter got hit twice for some reason + t.Fatal("One of the files was not uploaded correctly") + } + } else if len(transferDetailsUpload) != 2 { + t.Fatalf("Amount of transfers results returned for upload was not correct. 
Transfer details returned: %d", len(transferDetailsUpload)) + } + + // Download the files we just uploaded + transferDetailsDownload, err := client.DoGet(ctx, uploadURL, t.TempDir(), true, client.WithTokenLocation(tempToken.Name())) + assert.NoError(t, err) + if err == nil && len(transferDetailsUpload) == 2 { + countBytesUploadIdx0 := 0 + countBytesUploadIdx1 := 0 + // Verify we got the correct files back (have to do this since files upload in different orders at times) + // In this case, we want to match them to the sizes of the uploaded files + for _, transfer := range transferDetailsUpload { + transferredBytes := transfer.TransferredBytes + switch transferredBytes { + case transferDetailsUpload[0].TransferredBytes: + countBytesUploadIdx0++ + continue + case transferDetailsUpload[1].TransferredBytes: + countBytesUploadIdx1++ + continue + default: + // We got a byte amount we are not expecting + t.Fatal("did not download proper amount of bytes") + } + } + if countBytesUploadIdx0 != 1 || countBytesUploadIdx1 != 1 { + // We would hit this case if 1 counter got hit twice for some reason + t.Fatal("One of the files was not downloaded correctly") + } else if len(transferDetailsDownload) != 2 { + t.Fatalf("Amount of transfers results returned for download was not correct. 
Transfer details returned: %d", len(transferDetailsDownload)) + } + } + }) + + t.Run("testOsdfRecursiveGetAndPutOsdfURL", func(t *testing.T) { + config.SetPreferredPrefix("osdf") + // Set path for object to upload/download + tempPath := tempDir + dirName := filepath.Base(tempPath) + // Note: minimally fixing this test as it is soon to be replaced + uploadURL := "pelican://" + param.Server_Hostname.GetString() + ":" + strconv.Itoa(param.Server_WebPort.GetInt()) + "/test/" + dirName + + ////////////////////////////////////////////////////////// + + // Upload the file with PUT + transferDetailsUpload, err := client.DoPut(ctx, tempDir, uploadURL, true, client.WithTokenLocation(tempToken.Name())) + assert.NoError(t, err) + if err == nil && len(transferDetailsUpload) == 2 { + countBytes17 := 0 + countBytes23 := 0 + // Verify we got the correct files back (have to do this since files upload in different orders at times) + for _, transfer := range transferDetailsUpload { + transferredBytes := transfer.TransferredBytes + switch transferredBytes { + case int64(17): + countBytes17++ + continue + case int64(23): + countBytes23++ + continue + default: + // We got a byte amount we are not expecting + t.Fatal("did not upload proper amount of bytes") + } + } + if countBytes17 != 1 || countBytes23 != 1 { + // We would hit this case if 1 counter got hit twice for some reason + t.Fatal("One of the files was not uploaded correctly") + } + } else if len(transferDetailsUpload) != 2 { + t.Fatalf("Amount of transfers results returned for upload was not correct. 
Transfer details returned: %d", len(transferDetailsUpload)) + } + + // Download the files we just uploaded + transferDetailsDownload, err := client.DoGet(ctx, uploadURL, t.TempDir(), true, client.WithTokenLocation(tempToken.Name())) + assert.NoError(t, err) + if err == nil && len(transferDetailsUpload) == 2 { + countBytesUploadIdx0 := 0 + countBytesUploadIdx1 := 0 + // Verify we got the correct files back (have to do this since files upload in different orders at times) + // In this case, we want to match them to the sizes of the uploaded files + for _, transfer := range transferDetailsUpload { + transferredBytes := transfer.TransferredBytes + switch transferredBytes { + case transferDetailsUpload[0].TransferredBytes: + countBytesUploadIdx0++ + continue + case transferDetailsUpload[1].TransferredBytes: + countBytesUploadIdx1++ + continue + default: + // We got a byte amount we are not expecting + t.Fatal("did not download proper amount of bytes") + } + } + if countBytesUploadIdx0 != 1 || countBytesUploadIdx1 != 1 { + // We would hit this case if 1 counter got hit twice for some reason + t.Fatal("One of the files was not downloaded correctly") + } else if len(transferDetailsDownload) != 2 { + t.Fatalf("Amount of transfers results returned for download was not correct. 
Transfer details returned: %d", len(transferDetailsDownload)) + } + } + }) + + t.Run("testOsdfRecursiveGetAndPutPelicanURL", func(t *testing.T) { + config.SetPreferredPrefix("osdf") + // Set path for object to upload/download + tempPath := tempDir + dirName := filepath.Base(tempPath) + uploadURL := "pelican:///test/" + dirName + + ////////////////////////////////////////////////////////// + + // Upload the file with PUT + transferDetailsUpload, err := client.DoPut(ctx, tempDir, uploadURL, true, client.WithTokenLocation(tempToken.Name())) + assert.NoError(t, err) + if err == nil && len(transferDetailsUpload) == 2 { + countBytes17 := 0 + countBytes23 := 0 + // Verify we got the correct files back (have to do this since files upload in different orders at times) + for _, transfer := range transferDetailsUpload { + transferredBytes := transfer.TransferredBytes + switch transferredBytes { + case int64(17): + countBytes17++ + continue + case int64(23): + countBytes23++ + continue + default: + // We got a byte amount we are not expecting + t.Fatal("did not upload proper amount of bytes") + } + } + if countBytes17 != 1 || countBytes23 != 1 { + // We would hit this case if 1 counter got hit twice for some reason + t.Fatal("One of the files was not uploaded correctly") + } + } else if len(transferDetailsUpload) != 2 { + t.Fatalf("Amount of transfers results returned for upload was not correct. 
Transfer details returned: %d", len(transferDetailsUpload)) + } + + // Download the files we just uploaded + transferDetailsDownload, err := client.DoGet(ctx, uploadURL, t.TempDir(), true, client.WithTokenLocation(tempToken.Name())) + assert.NoError(t, err) + if err == nil && len(transferDetailsUpload) == 2 { + countBytesUploadIdx0 := 0 + countBytesUploadIdx1 := 0 + // Verify we got the correct files back (have to do this since files upload in different orders at times) + // In this case, we want to match them to the sizes of the uploaded files + for _, transfer := range transferDetailsUpload { + transferredBytes := transfer.TransferredBytes + switch transferredBytes { + case transferDetailsUpload[0].TransferredBytes: + countBytesUploadIdx0++ + continue + case transferDetailsUpload[1].TransferredBytes: + countBytesUploadIdx1++ + continue + default: + // We got a byte amount we are not expecting + t.Fatal("did not download proper amount of bytes") + } + } + if countBytesUploadIdx0 != 1 || countBytesUploadIdx1 != 1 { + // We would hit this case if 1 counter got hit twice for some reason + t.Fatal("One of the files was not downloaded correctly") + } else if len(transferDetailsDownload) != 2 { + t.Fatalf("Amount of transfers results returned for download was not correct. 
Transfer details returned: %d", len(transferDetailsDownload)) + } + } + }) +} diff --git a/client/handle_http_test.go b/client/handle_http_test.go index 6b32ae9bc..32992df1a 100644 --- a/client/handle_http_test.go +++ b/client/handle_http_test.go @@ -23,8 +23,6 @@ package client import ( "bytes" "context" - "encoding/json" - "io" "net" "net/http" "net/http/httptest" @@ -32,26 +30,16 @@ import ( "net/url" "os" "path/filepath" - "strconv" "strings" "testing" "time" - "github.com/pkg/errors" - log "github.com/sirupsen/logrus" "github.com/spf13/viper" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/sync/errgroup" "github.com/pelicanplatform/pelican/config" - "github.com/pelicanplatform/pelican/launchers" "github.com/pelicanplatform/pelican/namespaces" - "github.com/pelicanplatform/pelican/param" - "github.com/pelicanplatform/pelican/server_utils" "github.com/pelicanplatform/pelican/test_utils" - "github.com/pelicanplatform/pelican/token_scopes" - "github.com/pelicanplatform/pelican/utils" ) func TestMain(m *testing.M) { @@ -400,810 +388,3 @@ func TestFailedUpload(t *testing.T) { assert.Fail(t, "Timeout while waiting for response") } } - -func generateFileTestScitoken() (string, error) { - // Issuer is whichever server that initiates the test, so it's the server itself - issuerUrl, err := config.GetServerIssuerURL() - if err != nil { - return "", err - } - if issuerUrl == "" { // if empty, then error - return "", errors.New("Failed to create token: Invalid iss, Server_ExternalWebUrl is empty") - } - - scopes := []token_scopes.TokenScope{} - readScope, err := token_scopes.Storage_Read.Path("/") - if err != nil { - return "", errors.Wrap(err, "failed to create 'read' scope for file test token:") - } - scopes = append(scopes, readScope) - modScope, err := token_scopes.Storage_Modify.Path("/") - if err != nil { - return "", errors.Wrap(err, "failed to create 'modify' scope for file test token:") - } - scopes = append(scopes, 
modScope) - - fTestTokenCfg := utils.TokenConfig{ - TokenProfile: utils.WLCG, - Lifetime: time.Minute, - Issuer: issuerUrl, - Audience: []string{config.GetServerAudience()}, - Version: "1.0", - Subject: "origin", - } - fTestTokenCfg.AddScopes(scopes) - - // CreateToken also handles validation for us - tok, err := fTestTokenCfg.CreateToken() - if err != nil { - return "", errors.Wrap(err, "failed to create file test token:") - } - - return tok, nil -} - -func TestFullUpload(t *testing.T) { - // Setup our test federation - ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) - defer func() { require.NoError(t, egrp.Wait()) }() - defer cancel() - - viper.Reset() - - modules := config.ServerType(0) - modules.Set(config.OriginType) - modules.Set(config.DirectorType) - modules.Set(config.RegistryType) - - // Create our own temp directory (for some reason t.TempDir() does not play well with xrootd) - tmpPathPattern := "XRootD-Test_Origin*" - tmpPath, err := os.MkdirTemp("", tmpPathPattern) - require.NoError(t, err) - - // Need to set permissions or the xrootd process we spawn won't be able to write PID/UID files - permissions := os.FileMode(0755) - err = os.Chmod(tmpPath, permissions) - require.NoError(t, err) - - viper.Set("ConfigDir", tmpPath) - - // Increase the log level; otherwise, its difficult to debug failures - viper.Set("Logging.Level", "Debug") - config.InitConfig() - - originDir, err := os.MkdirTemp("", "Origin") - assert.NoError(t, err) - - // Change the permissions of the temporary directory - permissions = os.FileMode(0777) - err = os.Chmod(originDir, permissions) - require.NoError(t, err) - - viper.Set("Origin.ExportVolume", originDir+":/test") - viper.Set("Origin.Mode", "posix") - // Disable functionality we're not using (and is difficult to make work on Mac) - viper.Set("Origin.EnableCmsd", false) - viper.Set("Origin.EnableMacaroons", false) - viper.Set("Origin.EnableVoms", false) - viper.Set("Origin.EnableWrite", true) - 
viper.Set("TLSSkipVerify", true) - viper.Set("Server.EnableUI", false) - viper.Set("Registry.DbLocation", filepath.Join(t.TempDir(), "ns-registry.sqlite")) - viper.Set("Origin.RunLocation", tmpPath) - viper.Set("Registry.RequireOriginApproval", false) - viper.Set("Registry.RequireCacheApproval", false) - viper.Set("Logging.Origin.Scitokens", "debug") - viper.Set("Origin.Port", 0) - - err = config.InitServer(ctx, modules) - require.NoError(t, err) - - fedCancel, err := launchers.LaunchModules(ctx, modules) - defer fedCancel() - if err != nil { - log.Errorln("Failure in fedServeInternal:", err) - require.NoError(t, err) - } - - desiredURL := param.Server_ExternalWebUrl.GetString() + "/.well-known/openid-configuration" - err = server_utils.WaitUntilWorking(ctx, "GET", desiredURL, "director", 200) - require.NoError(t, err) - - httpc := http.Client{ - Transport: config.GetTransport(), - } - resp, err := httpc.Get(desiredURL) - require.NoError(t, err) - - assert.Equal(t, resp.StatusCode, http.StatusOK) - - responseBody, err := io.ReadAll(resp.Body) - require.NoError(t, err) - expectedResponse := struct { - JwksUri string `json:"jwks_uri"` - }{} - err = json.Unmarshal(responseBody, &expectedResponse) - require.NoError(t, err) - - assert.NotEmpty(t, expectedResponse.JwksUri) - - t.Run("testFullUpload", func(t *testing.T) { - testFileContent := "test file content" - - // Create the temporary file to upload - tempFile, err := os.CreateTemp(t.TempDir(), "test") - assert.NoError(t, err, "Error creating temp file") - defer os.Remove(tempFile.Name()) - _, err = tempFile.WriteString(testFileContent) - assert.NoError(t, err, "Error writing to temp file") - tempFile.Close() - - // Create a token file - token, err := generateFileTestScitoken() - assert.NoError(t, err) - tempToken, err := os.CreateTemp(t.TempDir(), "token") - assert.NoError(t, err, "Error creating temp token file") - defer os.Remove(tempToken.Name()) - _, err = tempToken.WriteString(token) - assert.NoError(t, err, 
"Error writing to temp token file") - tempToken.Close() - - // Upload the file - tempPath := tempFile.Name() - fileName := filepath.Base(tempPath) - uploadURL := "stash:///test/" + fileName - - transferResults, err := DoCopy(ctx, tempFile.Name(), uploadURL, false, WithTokenLocation(tempToken.Name())) - assert.NoError(t, err, "Error uploading file") - assert.Equal(t, int64(len(testFileContent)), transferResults[0].TransferredBytes, "Uploaded file size does not match") - - // Upload an osdf file - uploadURL = "pelican:///test/stuff/blah.txt" - assert.NoError(t, err, "Error parsing upload URL") - transferResults, err = DoCopy(ctx, tempFile.Name(), uploadURL, false, WithTokenLocation(tempToken.Name())) - assert.NoError(t, err, "Error uploading file") - assert.Equal(t, int64(len(testFileContent)), transferResults[0].TransferredBytes, "Uploaded file size does not match") - }) - t.Cleanup(func() { - os.RemoveAll(tmpPath) - os.RemoveAll(originDir) - }) - - cancel() - fedCancel() - assert.NoError(t, egrp.Wait()) - viper.Reset() -} - -type FedTest struct { - T *testing.T - TmpPath string - OriginDir string - Output *os.File - Cancel context.CancelFunc - FedCancel context.CancelFunc - ErrGroup *errgroup.Group -} - -func (f *FedTest) Spinup() { - //////////////////////////////Setup our test federation////////////////////////////////////////// - ctx, cancel, egrp := test_utils.TestContext(context.Background(), f.T) - - modules := config.ServerType(0) - modules.Set(config.OriginType) - modules.Set(config.DirectorType) - modules.Set(config.RegistryType) - - // Create our own temp directory (for some reason t.TempDir() does not play well with xrootd) - tmpPathPattern := "XRootD-Test_Origin*" - tmpPath, err := os.MkdirTemp("", tmpPathPattern) - require.NoError(f.T, err) - f.TmpPath = tmpPath - - // Need to set permissions or the xrootd process we spawn won't be able to write PID/UID files - permissions := os.FileMode(0755) - err = os.Chmod(tmpPath, permissions) - 
require.NoError(f.T, err) - - viper.Set("ConfigDir", tmpPath) - - config.InitConfig() - // Create a file to capture output from commands - output, err := os.CreateTemp(f.T.TempDir(), "output") - assert.NoError(f.T, err) - f.Output = output - viper.Set("Logging.LogLocation", output.Name()) - - originDir, err := os.MkdirTemp("", "Origin") - assert.NoError(f.T, err) - f.OriginDir = originDir - - // Change the permissions of the temporary origin directory - permissions = os.FileMode(0777) - err = os.Chmod(originDir, permissions) - require.NoError(f.T, err) - - viper.Set("Origin.ExportVolume", originDir+":/test") - viper.Set("Origin.Mode", "posix") - viper.Set("Origin.EnableFallbackRead", true) - // Disable functionality we're not using (and is difficult to make work on Mac) - viper.Set("Origin.EnableCmsd", false) - viper.Set("Origin.EnableMacaroons", false) - viper.Set("Origin.EnableVoms", false) - viper.Set("Origin.EnableWrite", true) - viper.Set("TLSSkipVerify", true) - viper.Set("Server.EnableUI", false) - viper.Set("Registry.DbLocation", filepath.Join(f.T.TempDir(), "ns-registry.sqlite")) - viper.Set("Origin.Port", 0) - viper.Set("Server.WebPort", 0) - viper.Set("Origin.RunLocation", tmpPath) - - err = config.InitServer(ctx, modules) - require.NoError(f.T, err) - - viper.Set("Registry.RequireOriginApproval", false) - viper.Set("Registry.RequireCacheApproval", false) - - f.FedCancel, err = launchers.LaunchModules(ctx, modules) - if err != nil { - f.T.Fatalf("Failure in fedServeInternal: %v", err) - } - - desiredURL := param.Server_ExternalWebUrl.GetString() + "/.well-known/openid-configuration" - err = server_utils.WaitUntilWorking(ctx, "GET", desiredURL, "director", 200) - require.NoError(f.T, err) - - httpc := http.Client{ - Transport: config.GetTransport(), - } - resp, err := httpc.Get(desiredURL) - require.NoError(f.T, err) - - assert.Equal(f.T, resp.StatusCode, http.StatusOK) - - responseBody, err := io.ReadAll(resp.Body) - require.NoError(f.T, err) - 
expectedResponse := struct { - JwksUri string `json:"jwks_uri"` - }{} - err = json.Unmarshal(responseBody, &expectedResponse) - require.NoError(f.T, err) - - f.Cancel = cancel - f.ErrGroup = egrp -} - -func (f *FedTest) Teardown() { - os.RemoveAll(f.TmpPath) - os.RemoveAll(f.OriginDir) - f.Cancel() - f.FedCancel() - assert.NoError(f.T, f.ErrGroup.Wait()) - viper.Reset() -} - -// A test that spins up a federation, and tests object get and put -func TestGetAndPutAuth(t *testing.T) { - // Create instance of test federation - ctx, _, _ := test_utils.TestContext(context.Background(), t) - - viper.Reset() - fed := FedTest{T: t} - fed.Spinup() - defer fed.Teardown() - - // Other set-up items: - testFileContent := "test file content" - // Create the temporary file to upload - tempFile, err := os.CreateTemp(t.TempDir(), "test") - assert.NoError(t, err, "Error creating temp file") - defer os.Remove(tempFile.Name()) - _, err = tempFile.WriteString(testFileContent) - assert.NoError(t, err, "Error writing to temp file") - tempFile.Close() - - issuer, err := config.GetServerIssuerURL() - require.NoError(t, err) - audience := config.GetServerAudience() - - // Create a token file - tokenConfig := utils.TokenConfig{ - TokenProfile: utils.WLCG, - Lifetime: time.Minute, - Issuer: issuer, - Audience: []string{audience}, - Subject: "origin", - } - - scopes := []token_scopes.TokenScope{} - readScope, err := token_scopes.Storage_Read.Path("/") - assert.NoError(t, err) - scopes = append(scopes, readScope) - modScope, err := token_scopes.Storage_Modify.Path("/") - assert.NoError(t, err) - scopes = append(scopes, modScope) - tokenConfig.AddScopes(scopes) - token, err := tokenConfig.CreateToken() - assert.NoError(t, err) - tempToken, err := os.CreateTemp(t.TempDir(), "token") - assert.NoError(t, err, "Error creating temp token file") - defer os.Remove(tempToken.Name()) - _, err = tempToken.WriteString(token) - assert.NoError(t, err, "Error writing to temp token file") - tempToken.Close() - 
// Disable progress bars to not reuse the same mpb instance - viper.Set("Logging.DisableProgressBars", true) - - // This tests object get/put with a pelican:// url - t.Run("testPelicanObjectPutAndGetWithPelicanUrl", func(t *testing.T) { - config.SetPreferredPrefix("pelican") - // Set path for object to upload/download - tempPath := tempFile.Name() - fileName := filepath.Base(tempPath) - uploadURL := "pelican:///test/" + fileName - - // Upload the file with PUT - transferResultsUpload, err := DoPut(ctx, tempFile.Name(), uploadURL, false, WithTokenLocation(tempToken.Name())) - assert.NoError(t, err) - if err == nil { - assert.Equal(t, transferResultsUpload[0].TransferredBytes, int64(17)) - } - - // Download that same file with GET - transferResultsDownload, err := DoGet(ctx, uploadURL, t.TempDir(), false, WithTokenLocation(tempToken.Name())) - assert.NoError(t, err) - if err == nil { - assert.Equal(t, transferResultsDownload[0].TransferredBytes, transferResultsUpload[0].TransferredBytes) - } - }) - - // This tests pelican object get/put with an osdf url - t.Run("testPelicanObjectPutAndGetWithOSDFUrl", func(t *testing.T) { - config.SetPreferredPrefix("pelican") - // Set path for object to upload/download - tempPath := tempFile.Name() - fileName := filepath.Base(tempPath) - // Minimal fix of test as it is soon to be replaced - uploadURL := "pelican:///test/" + fileName - - // Upload the file with PUT - transferResultsUpload, err := DoPut(ctx, tempFile.Name(), uploadURL, false, WithTokenLocation(tempToken.Name())) - assert.NoError(t, err) - if err == nil { - assert.Equal(t, transferResultsUpload[0].TransferredBytes, int64(17)) - } - - // Download that same file with GET - transferResultsDownload, err := DoGet(ctx, uploadURL, t.TempDir(), false, WithTokenLocation(tempToken.Name())) - assert.NoError(t, err) - if err == nil { - assert.Equal(t, transferResultsDownload[0].TransferredBytes, transferResultsUpload[0].TransferredBytes) - } - }) - - // This tests object get/put 
with a pelican:// url - t.Run("testOsdfObjectPutAndGetWithPelicanUrl", func(t *testing.T) { - config.SetPreferredPrefix("osdf") - // Set path for object to upload/download - tempPath := tempFile.Name() - fileName := filepath.Base(tempPath) - uploadURL := "pelican:///test/" + fileName - - // Upload the file with PUT - transferResultsUpload, err := DoPut(ctx, tempFile.Name(), uploadURL, false, WithTokenLocation(tempToken.Name())) - assert.NoError(t, err) - if err == nil { - assert.Equal(t, transferResultsUpload[0].TransferredBytes, int64(17)) - } - - // Download that same file with GET - transferResultsDownload, err := DoGet(ctx, uploadURL, t.TempDir(), false, WithTokenLocation(tempToken.Name())) - assert.NoError(t, err) - if err == nil { - assert.Equal(t, transferResultsDownload[0].TransferredBytes, transferResultsUpload[0].TransferredBytes) - } - }) - - // This tests pelican object get/put with an osdf url - t.Run("testOsdfObjectPutAndGetWithOSDFUrl", func(t *testing.T) { - config.SetPreferredPrefix("osdf") - // Set path for object to upload/download - tempPath := tempFile.Name() - fileName := filepath.Base(tempPath) - // Minimal fix of test as it is soon to be replaced - uploadURL := "pelican:///test/" + fileName - - // Upload the file with PUT - transferResultsUpload, err := DoPut(ctx, tempFile.Name(), uploadURL, false, WithTokenLocation(tempToken.Name())) - assert.NoError(t, err) - if err == nil { - assert.Equal(t, transferResultsUpload[0].TransferredBytes, int64(17)) - } - - // Download that same file with GET - transferResultsDownload, err := DoGet(ctx, uploadURL, t.TempDir(), false, WithTokenLocation(tempToken.Name())) - assert.NoError(t, err) - if err == nil { - assert.Equal(t, transferResultsDownload[0].TransferredBytes, transferResultsUpload[0].TransferredBytes) - } - }) -} - -// A test that spins up the federation, where the origin is in EnablePublicReads mode. 
Then GET a file from the origin without a token -func TestGetPublicRead(t *testing.T) { - ctx, _, _ := test_utils.TestContext(context.Background(), t) - viper.Reset() - viper.Set("Origin.EnablePublicReads", true) - fed := FedTest{T: t} - fed.Spinup() - defer fed.Teardown() - t.Run("testPubObjGet", func(t *testing.T) { - testFileContent := "test file content" - // Drop the testFileContent into the origin directory - tempFile, err := os.Create(filepath.Join(fed.OriginDir, "test.txt")) - assert.NoError(t, err, "Error creating temp file") - defer os.Remove(tempFile.Name()) - _, err = tempFile.WriteString(testFileContent) - assert.NoError(t, err, "Error writing to temp file") - tempFile.Close() - - viper.Set("Logging.DisableProgressBars", true) - - // Set path for object to upload/download - tempPath := tempFile.Name() - fileName := filepath.Base(tempPath) - uploadURL := "pelican:///test/" + fileName - - // Download the file with GET. Shouldn't need a token to succeed - transferResults, err := DoGet(ctx, uploadURL, t.TempDir(), false) - assert.NoError(t, err) - if err == nil { - assert.Equal(t, transferResults[0].TransferredBytes, int64(17)) - } - }) -} - -func TestRecursiveUploadsAndDownloads(t *testing.T) { - // Create instance of test federation - ctx, _, _ := test_utils.TestContext(context.Background(), t) - - viper.Reset() - fed := FedTest{T: t} - fed.Spinup() - defer fed.Teardown() - - //////////////////////////SETUP/////////////////////////// - // Create a token file - issuer, err := config.GetServerIssuerURL() - require.NoError(t, err) - audience := config.GetServerAudience() - - tokenConfig := utils.TokenConfig{ - TokenProfile: utils.WLCG, - Lifetime: time.Minute, - Issuer: issuer, - Audience: []string{audience}, - Subject: "origin", - } - scopes := []token_scopes.TokenScope{} - readScope, err := token_scopes.Storage_Read.Path("/") - assert.NoError(t, err) - scopes = append(scopes, readScope) - modScope, err := token_scopes.Storage_Modify.Path("/") - 
assert.NoError(t, err) - scopes = append(scopes, modScope) - tokenConfig.AddScopes(scopes) - token, err := tokenConfig.CreateToken() - assert.NoError(t, err) - tempToken, err := os.CreateTemp(t.TempDir(), "token") - assert.NoError(t, err, "Error creating temp token file") - defer os.Remove(tempToken.Name()) - _, err = tempToken.WriteString(token) - assert.NoError(t, err, "Error writing to temp token file") - tempToken.Close() - - // Disable progress bars to not reuse the same mpb instance - viper.Set("Logging.DisableProgressBars", true) - - // Make our test directories and files - tempDir, err := os.MkdirTemp("", "UploadDir") - assert.NoError(t, err) - defer os.RemoveAll(tempDir) - permissions := os.FileMode(0777) - err = os.Chmod(tempDir, permissions) - require.NoError(t, err) - - testFileContent1 := "test file content" - testFileContent2 := "more test file content!" - tempFile1, err := os.CreateTemp(tempDir, "test1") - assert.NoError(t, err, "Error creating temp1 file") - tempFile2, err := os.CreateTemp(tempDir, "test1") - assert.NoError(t, err, "Error creating temp2 file") - defer os.Remove(tempFile1.Name()) - defer os.Remove(tempFile2.Name()) - _, err = tempFile1.WriteString(testFileContent1) - assert.NoError(t, err, "Error writing to temp1 file") - tempFile1.Close() - _, err = tempFile2.WriteString(testFileContent2) - assert.NoError(t, err, "Error writing to temp2 file") - tempFile2.Close() - - t.Run("testPelicanRecursiveGetAndPutOsdfURL", func(t *testing.T) { - config.SetPreferredPrefix("pelican") - // Set path for object to upload/download - tempPath := tempDir - dirName := filepath.Base(tempPath) - // Note: minimally fixing this test as it is soon to be replaced - uploadURL := "pelican://" + param.Server_Hostname.GetString() + ":" + strconv.Itoa(param.Server_WebPort.GetInt()) + "/test/" + dirName - - ////////////////////////////////////////////////////////// - - // Upload the file with PUT - transferDetailsUpload, err := DoPut(ctx, tempDir, uploadURL, true, 
WithTokenLocation(tempToken.Name())) - require.NoError(t, err) - if err == nil && len(transferDetailsUpload) == 2 { - countBytes17 := 0 - countBytes23 := 0 - // Verify we got the correct files back (have to do this since files upload in different orders at times) - for _, transfer := range transferDetailsUpload { - transferredBytes := transfer.TransferredBytes - switch transferredBytes { - case int64(17): - countBytes17++ - continue - case int64(23): - countBytes23++ - continue - default: - // We got a byte amount we are not expecting - t.Fatal("did not upload proper amount of bytes") - } - } - if countBytes17 != 1 || countBytes23 != 1 { - // We would hit this case if 1 counter got hit twice for some reason - t.Fatal("One of the files was not uploaded correctly") - } - } else if len(transferDetailsUpload) != 2 { - t.Fatalf("Amount of transfers results returned for upload was not correct. Transfer details returned: %d", len(transferDetailsUpload)) - } - - // Download the files we just uploaded - transferDetailsDownload, err := DoGet(ctx, uploadURL, t.TempDir(), true, WithTokenLocation(tempToken.Name())) - assert.NoError(t, err) - if err == nil && len(transferDetailsUpload) == 2 { - countBytesUploadIdx0 := 0 - countBytesUploadIdx1 := 0 - // Verify we got the correct files back (have to do this since files upload in different orders at times) - // In this case, we want to match them to the sizes of the uploaded files - for _, transfer := range transferDetailsUpload { - transferredBytes := transfer.TransferredBytes - switch transferredBytes { - case transferDetailsUpload[0].TransferredBytes: - countBytesUploadIdx0++ - continue - case transferDetailsUpload[1].TransferredBytes: - countBytesUploadIdx1++ - continue - default: - // We got a byte amount we are not expecting - t.Fatal("did not download proper amount of bytes") - } - } - if countBytesUploadIdx0 != 1 || countBytesUploadIdx1 != 1 { - // We would hit this case if 1 counter got hit twice for some reason - 
t.Fatal("One of the files was not downloaded correctly") - } else if len(transferDetailsDownload) != 2 { - t.Fatalf("Amount of transfers results returned for download was not correct. Transfer details returned: %d", len(transferDetailsDownload)) - } - } - }) - - t.Run("testPelicanRecursiveGetAndPutPelicanURL", func(t *testing.T) { - config.SetPreferredPrefix("pelican") - // Set path for object to upload/download - tempPath := tempDir - dirName := filepath.Base(tempPath) - uploadURL := "pelican:///test/" + dirName - - ////////////////////////////////////////////////////////// - - // Upload the file with PUT - transferDetailsUpload, err := DoPut(ctx, tempDir, uploadURL, true, WithTokenLocation(tempToken.Name())) - assert.NoError(t, err) - if err == nil && len(transferDetailsUpload) == 2 { - countBytes17 := 0 - countBytes23 := 0 - // Verify we got the correct files back (have to do this since files upload in different orders at times) - for _, transfer := range transferDetailsUpload { - transferredBytes := transfer.TransferredBytes - switch transferredBytes { - case int64(17): - countBytes17++ - continue - case int64(23): - countBytes23++ - continue - default: - // We got a byte amount we are not expecting - t.Fatal("did not upload proper amount of bytes") - } - } - if countBytes17 != 1 || countBytes23 != 1 { - // We would hit this case if 1 counter got hit twice for some reason - t.Fatal("One of the files was not uploaded correctly") - } - } else if len(transferDetailsUpload) != 2 { - t.Fatalf("Amount of transfers results returned for upload was not correct. 
Transfer details returned: %d", len(transferDetailsUpload)) - } - - // Download the files we just uploaded - transferDetailsDownload, err := DoGet(ctx, uploadURL, t.TempDir(), true, WithTokenLocation(tempToken.Name())) - assert.NoError(t, err) - if err == nil && len(transferDetailsUpload) == 2 { - countBytesUploadIdx0 := 0 - countBytesUploadIdx1 := 0 - // Verify we got the correct files back (have to do this since files upload in different orders at times) - // In this case, we want to match them to the sizes of the uploaded files - for _, transfer := range transferDetailsUpload { - transferredBytes := transfer.TransferredBytes - switch transferredBytes { - case transferDetailsUpload[0].TransferredBytes: - countBytesUploadIdx0++ - continue - case transferDetailsUpload[1].TransferredBytes: - countBytesUploadIdx1++ - continue - default: - // We got a byte amount we are not expecting - t.Fatal("did not download proper amount of bytes") - } - } - if countBytesUploadIdx0 != 1 || countBytesUploadIdx1 != 1 { - // We would hit this case if 1 counter got hit twice for some reason - t.Fatal("One of the files was not downloaded correctly") - } else if len(transferDetailsDownload) != 2 { - t.Fatalf("Amount of transfers results returned for download was not correct. 
Transfer details returned: %d", len(transferDetailsDownload)) - } - } - }) - - t.Run("testOsdfRecursiveGetAndPutOsdfURL", func(t *testing.T) { - config.SetPreferredPrefix("osdf") - // Set path for object to upload/download - tempPath := tempDir - dirName := filepath.Base(tempPath) - // Note: minimally fixing this test as it is soon to be replaced - uploadURL := "pelican://" + param.Server_Hostname.GetString() + ":" + strconv.Itoa(param.Server_WebPort.GetInt()) + "/test/" + dirName - - ////////////////////////////////////////////////////////// - - // Upload the file with PUT - transferDetailsUpload, err := DoPut(ctx, tempDir, uploadURL, true, WithTokenLocation(tempToken.Name())) - assert.NoError(t, err) - if err == nil && len(transferDetailsUpload) == 2 { - countBytes17 := 0 - countBytes23 := 0 - // Verify we got the correct files back (have to do this since files upload in different orders at times) - for _, transfer := range transferDetailsUpload { - transferredBytes := transfer.TransferredBytes - switch transferredBytes { - case int64(17): - countBytes17++ - continue - case int64(23): - countBytes23++ - continue - default: - // We got a byte amount we are not expecting - t.Fatal("did not upload proper amount of bytes") - } - } - if countBytes17 != 1 || countBytes23 != 1 { - // We would hit this case if 1 counter got hit twice for some reason - t.Fatal("One of the files was not uploaded correctly") - } - } else if len(transferDetailsUpload) != 2 { - t.Fatalf("Amount of transfers results returned for upload was not correct. 
Transfer details returned: %d", len(transferDetailsUpload)) - } - - // Download the files we just uploaded - transferDetailsDownload, err := DoGet(ctx, uploadURL, t.TempDir(), true, WithTokenLocation(tempToken.Name())) - assert.NoError(t, err) - if err == nil && len(transferDetailsUpload) == 2 { - countBytesUploadIdx0 := 0 - countBytesUploadIdx1 := 0 - // Verify we got the correct files back (have to do this since files upload in different orders at times) - // In this case, we want to match them to the sizes of the uploaded files - for _, transfer := range transferDetailsUpload { - transferredBytes := transfer.TransferredBytes - switch transferredBytes { - case transferDetailsUpload[0].TransferredBytes: - countBytesUploadIdx0++ - continue - case transferDetailsUpload[1].TransferredBytes: - countBytesUploadIdx1++ - continue - default: - // We got a byte amount we are not expecting - t.Fatal("did not download proper amount of bytes") - } - } - if countBytesUploadIdx0 != 1 || countBytesUploadIdx1 != 1 { - // We would hit this case if 1 counter got hit twice for some reason - t.Fatal("One of the files was not downloaded correctly") - } else if len(transferDetailsDownload) != 2 { - t.Fatalf("Amount of transfers results returned for download was not correct. 
Transfer details returned: %d", len(transferDetailsDownload)) - } - } - }) - - t.Run("testOsdfRecursiveGetAndPutPelicanURL", func(t *testing.T) { - config.SetPreferredPrefix("osdf") - // Set path for object to upload/download - tempPath := tempDir - dirName := filepath.Base(tempPath) - uploadURL := "pelican:///test/" + dirName - - ////////////////////////////////////////////////////////// - - // Upload the file with PUT - transferDetailsUpload, err := DoPut(ctx, tempDir, uploadURL, true, WithTokenLocation(tempToken.Name())) - assert.NoError(t, err) - if err == nil && len(transferDetailsUpload) == 2 { - countBytes17 := 0 - countBytes23 := 0 - // Verify we got the correct files back (have to do this since files upload in different orders at times) - for _, transfer := range transferDetailsUpload { - transferredBytes := transfer.TransferredBytes - switch transferredBytes { - case int64(17): - countBytes17++ - continue - case int64(23): - countBytes23++ - continue - default: - // We got a byte amount we are not expecting - t.Fatal("did not upload proper amount of bytes") - } - } - if countBytes17 != 1 || countBytes23 != 1 { - // We would hit this case if 1 counter got hit twice for some reason - t.Fatal("One of the files was not uploaded correctly") - } - } else if len(transferDetailsUpload) != 2 { - t.Fatalf("Amount of transfers results returned for upload was not correct. 
Transfer details returned: %d", len(transferDetailsUpload)) - } - - // Download the files we just uploaded - transferDetailsDownload, err := DoGet(ctx, uploadURL, t.TempDir(), true, WithTokenLocation(tempToken.Name())) - assert.NoError(t, err) - if err == nil && len(transferDetailsUpload) == 2 { - countBytesUploadIdx0 := 0 - countBytesUploadIdx1 := 0 - // Verify we got the correct files back (have to do this since files upload in different orders at times) - // In this case, we want to match them to the sizes of the uploaded files - for _, transfer := range transferDetailsUpload { - transferredBytes := transfer.TransferredBytes - switch transferredBytes { - case transferDetailsUpload[0].TransferredBytes: - countBytesUploadIdx0++ - continue - case transferDetailsUpload[1].TransferredBytes: - countBytesUploadIdx1++ - continue - default: - // We got a byte amount we are not expecting - t.Fatal("did not download proper amount of bytes") - } - } - if countBytesUploadIdx0 != 1 || countBytesUploadIdx1 != 1 { - // We would hit this case if 1 counter got hit twice for some reason - t.Fatal("One of the files was not downloaded correctly") - } else if len(transferDetailsDownload) != 2 { - t.Fatalf("Amount of transfers results returned for download was not correct. Transfer details returned: %d", len(transferDetailsDownload)) - } - } - }) - -} From 334fb4ee5e5def9fe9ac3e00fd2084f2f47355c1 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Wed, 28 Feb 2024 20:29:36 -0600 Subject: [PATCH 02/45] Create a new `file_cache` module (WIP) The file cache provides a simple caching interface. Only supporting full files - and minimal functionality - it's meant to support caching at worker ndoes. This commit is a work-in-progress, lacking tests and still needing integration with token-based authorization. 
--- client/handle_http.go | 21 +- config/resources/defaults.yaml | 3 + docs/parameters.yaml | 45 +++ file_cache/cache_api.go | 88 +++++ file_cache/simple_cache.go | 617 +++++++++++++++++++++++++++++++++ go.mod | 2 +- param/parameters.go | 6 + param/parameters_struct.go | 20 ++ 8 files changed, 798 insertions(+), 4 deletions(-) create mode 100644 file_cache/cache_api.go create mode 100644 file_cache/simple_cache.go diff --git a/client/handle_http.go b/client/handle_http.go index fefd9579b..c450bc12d 100644 --- a/client/handle_http.go +++ b/client/handle_http.go @@ -205,8 +205,9 @@ type ( ctx context.Context cancel context.CancelFunc callback TransferCallbackFunc - skipAcquire bool // Enable/disable the token acquisition logic. Defaults to acquiring a token - tokenLocation string + skipAcquire bool // Enable/disable the token acquisition logic. Defaults to acquiring a token + tokenLocation string // Location of a token file to use for transfers + token string // Token that should be used for transfers work chan *TransferJob closed bool caches []*url.URL @@ -220,6 +221,7 @@ type ( identTransferOptionCallback struct{} identTransferOptionTokenLocation struct{} identTransferOptionAcquireToken struct{} + identTransferOptionToken struct{} transferDetailsOptions struct { NeedsToken bool @@ -385,6 +387,13 @@ func WithTokenLocation(location string) TransferOption { return option.New(identTransferOptionTokenLocation{}, location) } +// Create an option to provide a specific token to the transfer +// +// The contents of the token will be used as part of the HTTP request +func WithToken(token string) TransferOption { + return option.New(identTransferOptionToken{}, token) +} + // Create an option to specify the token acquisition logic // // Token acquisition (e.g., using OAuth2 to get a token when one @@ -419,6 +428,8 @@ func (te *TransferEngine) NewClient(options ...TransferOption) (client *Transfer client.tokenLocation = option.Value().(string) case 
identTransferOptionAcquireToken{}: client.skipAcquire = !option.Value().(bool) + case identTransferOptionToken{}: + client.token = option.Value().(string) } } func() { @@ -738,6 +749,8 @@ func (tc *TransferClient) NewTransferJob(remoteUrl *url.URL, localPath string, u tc.tokenLocation = option.Value().(string) case identTransferOptionAcquireToken{}: tc.skipAcquire = !option.Value().(bool) + case identTransferOptionToken{}: + tj.token = option.Value().(string) } } @@ -778,7 +791,7 @@ func (tc *TransferClient) NewTransferJob(remoteUrl *url.URL, localPath string, u } tj.namespace = ns - if upload || ns.UseTokenOnRead { + if upload || ns.UseTokenOnRead && tj.token == "" { tj.token, err = getToken(remoteUrl, ns, true, "", tc.tokenLocation, !tj.skipAcquire) if err != nil { return nil, fmt.Errorf("failed to get token for transfer: %v", err) @@ -963,6 +976,8 @@ func (te *TransferEngine) createTransferFiles(job *clientTransferJob) (err error } }() + log.Debugln("Processing transfer job for URL", job.job.remoteURL.String()) + packOption := job.job.remoteURL.Query().Get("pack") if packOption != "" { log.Debugln("Will use unpack option value", packOption) diff --git a/config/resources/defaults.yaml b/config/resources/defaults.yaml index b2c7b3e0b..ec40165c3 100644 --- a/config/resources/defaults.yaml +++ b/config/resources/defaults.yaml @@ -46,6 +46,9 @@ Director: EnableBroker: true Cache: Port: 8442 +FileCache: + HighWaterMarkPercentage: 95 + LowWaterMarkPercentage: 85 Origin: NamespacePrefix: "" Multiuser: false diff --git a/docs/parameters.yaml b/docs/parameters.yaml index 6057eacc8..ddac09f74 100644 --- a/docs/parameters.yaml +++ b/docs/parameters.yaml @@ -609,6 +609,51 @@ default: path components: ["origin"] --- ############################ +# File-cache configs # +############################ +name: FileCache.RunLocation +description: >- + The directory for the runtime files of the file cache +type: string +root_default: /run/pelican/filecache +default: 
$XDG_RUNTIME_DIR/pelican/filecache +--- +name: FileCache.DataLocation +description: >- + The directory for the location of the cache data files - this is where the actual data in the cache is stored + for the file cache. +type: string +default: $PELICAN_FILECACHE_RUNLOCATION/cache +--- +name: FileCache.Socket +description: >- + The location of the socket used for client communication for the file cache +type: string +default: $PELICAN_FILECACHE_RUNLOCATION/cache.sock +--- +name: FileCache.Size +description: >- + The maximum size of the file cache. If not set, it is assumed the entire device can be used. +type: string +default: 0 +--- +name: FileCache.HighWaterMarkPercentage +description: >- + A percentage value where the cache cleanup routines will triggered. Once the cache usage + of completed files hits the high water mark, files will be deleted until the usage hits the + low water mark. +type: int +default: 95 +--- +name: FileCache.LowWaterMarkPercentage +description: >- + A percentage value where the cache cleanup routines will complete. Once the cache usage + of completed files hits the high water mark, files will be deleted until the usage hits the + low water mark. +type: int +default: 85 +--- +############################ # Cache-level configs # ############################ name: Cache.DataLocation diff --git a/file_cache/cache_api.go b/file_cache/cache_api.go new file mode 100644 index 000000000..eac573f12 --- /dev/null +++ b/file_cache/cache_api.go @@ -0,0 +1,88 @@ +/*************************************************************** + * + * Copyright (C) 2024, Pelican Project, Morgridge Institute for Research + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. 
You may + * obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + ***************************************************************/ + +package simple_cache + +import ( + "context" + "fmt" + "io" + "net" + "net/http" + "strings" + + "github.com/pelicanplatform/pelican/param" + log "github.com/sirupsen/logrus" + "golang.org/x/sync/errgroup" +) + +// Launch the unix socket listener as a separate goroutine +func LaunchListener(ctx context.Context, egrp *errgroup.Group) error { + socketName := param.FileCache_Socket.GetString() + listener, err := net.ListenUnix("unix", &net.UnixAddr{Name: socketName, Net: "unix"}) + if err != nil { + return err + } + sc, err := NewSimpleCache(ctx, egrp) + if err != nil { + return err + } + + handler := func(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + transferStatusStr := r.Header.Get("X-Transfer-Status") + sendTrailer := false + if transferStatusStr == "true" { + for _, encoding := range r.Header.Values("TE") { + if encoding == "trailers" { + sendTrailer = true + break + } + } + } + + bearerToken := r.Header.Get("Authorization") + bearerToken = strings.TrimPrefix(bearerToken, "Bearer ") + reader, err := sc.Get(r.URL.Path, bearerToken) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + log.Errorln("Failed to get file from cache") + return + } + w.WriteHeader(http.StatusOK) + if _, err = io.Copy(w, reader); err != nil && sendTrailer { + // TODO: Enumerate more error values + w.Header().Set("X-Transfer-Status", fmt.Sprintf("%d: %s", 500, err)) + } else if sendTrailer { 
w.Header().Set("X-Transfer-Status", "200: OK") + } + } + srv := http.Server{ + Handler: http.HandlerFunc(handler), + } + egrp.Go(func() error { + return srv.Serve(listener) + }) + egrp.Go(func() error { + <-ctx.Done(); return srv.Shutdown(context.Background()) + }) + return nil +} diff --git a/file_cache/simple_cache.go b/file_cache/simple_cache.go new file mode 100644 index 000000000..047388701 --- /dev/null +++ b/file_cache/simple_cache.go @@ -0,0 +1,617 @@ +/*************************************************************** + * + * Copyright (C) 2024, Pelican Project, Morgridge Institute for Research + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You may + * obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + ***************************************************************/ + +package simple_cache + +import ( + "container/heap" + "context" + "io" + "net/url" + "os" + "path" + "path/filepath" + "reflect" + "slices" + "sync" + "sync/atomic" + "syscall" + "time" + + "github.com/alecthomas/units" + "github.com/google/uuid" + "github.com/pelicanplatform/pelican/client" + "github.com/pelicanplatform/pelican/param" + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" + "golang.org/x/sync/errgroup" +) + +type ( + SimpleCache struct { + ctx context.Context + egrp *errgroup.Group + te *client.TransferEngine + tc *client.TransferClient + cancelReq chan cancelReq + basePath string + sizeReq chan availSizeReq + mutex sync.RWMutex + downloads map[string]*activeDownload + directorURL *url.URL + + // Cache static configuration + highWater uint64 + lowWater uint64 + + // LRU implementation + hitChan chan lruEntry // Notifies the central handler the cache has been used + lru lru // Manages a LRU of cache entries + lruLookup map[string]*lruEntry + cacheSize uint64 // Total cache size + } + + lruEntry struct { + lastUse time.Time + path string + size int64 + } + + lru []*lruEntry + + waiterInfo struct { + id uuid.UUID + size int64 + notify chan *downloadStatus + } + + // The waiters type fulfills the heap interface, allowing + // them to be used as a sorted priority queue + waiters []waiterInfo + + activeDownload struct { + tj *client.TransferJob + status *downloadStatus + waiterList waiters + } + + downloadStatus struct { + curSize atomic.Int64 + size atomic.Int64 + err atomic.Pointer[error] + done atomic.Bool + } + + cacheReader struct { + sc *SimpleCache + offset int64 + path string + token string + size int64 + avail int64 + fdOnce sync.Once + fd *os.File + openErr error + status chan *downloadStatus + } + + req struct { + id uuid.UUID + path string + token string + } + + cancelReq struct { + req req + done chan bool + } + + availSizeReq struct { + request req + size int64 
+ results chan *downloadStatus + } +) + +const ( + reqSize = 2 * 1024 * 1024 +) + +func newRequest(path, token string) (req req, err error) { + req.id, err = uuid.NewV7() + if err != nil { + return + } + req.path = path + req.token = token + return +} + +func (waiters waiters) Len() int { + return len(waiters) +} + +func (waiters waiters) Less(i, j int) bool { + return waiters[i].size < waiters[j].size +} + +func (waiters waiters) Swap(i, j int) { + waiters[i], waiters[j] = waiters[j], waiters[i] +} + +func (waiters *waiters) Push(x any) { + *waiters = append(*waiters, x.(waiterInfo)) +} + +func (waiters *waiters) Pop() any { + old := *waiters + n := len(old) + x := old[n-1] + *waiters = old[0 : n-1] + return x +} + +func (lru lru) Len() int { + return len(lru) +} + +func (lru lru) Less(i, j int) bool { + return lru[i].lastUse.Before(lru[j].lastUse) +} + +func (lru lru) Swap(i, j int) { + lru[i], lru[j] = lru[j], lru[i] +} + +func (lru *lru) Push(x any) { + *lru = append(*lru, x.(*lruEntry)) +} + +func (lru *lru) Pop() any { + old := *lru + n := len(old) + x := old[n-1] + *lru = old[0 : n-1] + return x +} + +// Create a simple cache object +// +// Launches background goroutines associated with the cache +func NewSimpleCache(ctx context.Context, egrp *errgroup.Group) (sc *SimpleCache, err error) { + + // Setup cache on disk + cacheDir := param.FileCache_DataLocation.GetString() + if cacheDir == "" { + err = errors.New("FileCache.DataLocation is not set; cannot determine where to place file cache's data") + return + } + if err = os.MkdirAll(cacheDir, os.FileMode(0700)); err != nil { + return + } + if err = os.RemoveAll(cacheDir); err != nil { + return + } + + sizeStr := param.FileCache_Size.GetString() + var cacheSize uint64 + if sizeStr == "" || sizeStr == "0" { + var stat syscall.Statfs_t + if err = syscall.Statfs(cacheDir, &stat); err != nil { + err = errors.Wrap(err, "Unable to determine free space for cache directory") + return + } + cacheSize = stat.Bavail * 
uint64(stat.Bsize) + } else { + var signedCacheSize int64 + signedCacheSize, err = units.ParseStrictBytes(param.FileCache_Size.GetString()) + if err != nil { + return + } + cacheSize = uint64(signedCacheSize) + } + highWaterPercentage := param.FileCache_HighWaterMarkPercentage.GetInt() + lowWaterPercentage := param.FileCache_LowWaterMarkPercentage.GetInt() + + sc = &SimpleCache{ + ctx: ctx, + egrp: egrp, + te: client.NewTransferEngine(ctx), + downloads: make(map[string]*activeDownload), + hitChan: make(chan lruEntry, 64), + highWater: (cacheSize / 100) * uint64(highWaterPercentage), + lowWater: (cacheSize / 100) * uint64(lowWaterPercentage), + cacheSize: cacheSize, + basePath: cacheDir, + } + + sc.tc, err = sc.te.NewClient(client.WithAcquireToken(false), client.WithCallback(sc.callback)) + if err != nil { + shutdownErr := sc.te.Shutdown() + if shutdownErr != nil { + log.Errorln("Failed to shutdown transfer engine") + } + return + } + + egrp.Go(sc.runMux) + + return +} + +// Callback for in-progress transfers +// +// The TransferClient will invoke the callback as it progresses; +// the callback info will be used to help the waiters progress. 
+func (sc *SimpleCache) callback(path string, downloaded int64, size int64, completed bool) { + ds := func() (ds *downloadStatus) { + sc.mutex.RLock() + defer sc.mutex.RUnlock() + dl := sc.downloads[path] + if dl != nil { + ds = dl.status + } + return + }() + if ds != nil { + ds.curSize.Store(downloaded) + ds.size.Store(size) + ds.done.Store(completed) + } +} + +// The main goroutine for managing the cache and its requests +func (sc *SimpleCache) runMux() error { + results := sc.tc.Results() + + type result struct { + path string + ds *downloadStatus + channel chan *downloadStatus + } + tmpResults := make([]result, 0) + cancelRequest := make([]chan bool, 0) + activeJobs := make(map[string]*activeDownload) + ticker := time.NewTicker(100 * time.Millisecond) + clientClosed := false + for { + lenResults := len(tmpResults) + lenCancel := len(cancelRequest) + lenChan := lenResults + lenCancel + cases := make([]reflect.SelectCase, lenChan+6) + jobPath := make(map[uuid.UUID]string) + for idx, info := range tmpResults { + cases[idx].Dir = reflect.SelectSend + cases[idx].Chan = reflect.ValueOf(info.channel) + cases[idx].Send = reflect.ValueOf(tmpResults[idx].ds) + } + for idx, channel := range cancelRequest { + cases[lenResults+idx].Dir = reflect.SelectSend + cases[lenResults+idx].Chan = reflect.ValueOf(channel) + cases[lenResults+idx].Send = reflect.ValueOf(true) + } + cases[lenChan].Dir = reflect.SelectRecv + cases[lenChan].Chan = reflect.ValueOf(sc.ctx.Done()) + cases[lenChan+1].Dir = reflect.SelectRecv + cases[lenChan+1].Chan = reflect.ValueOf(results) + if clientClosed { + cases[lenChan+1].Chan = reflect.ValueOf(nil) + } + cases[lenChan+2].Dir = reflect.SelectRecv + cases[lenChan+2].Chan = reflect.ValueOf(ticker.C) + cases[lenChan+3].Dir = reflect.SelectRecv + cases[lenChan+3].Chan = reflect.ValueOf(sc.sizeReq) + cases[lenChan+4].Dir = reflect.SelectRecv + cases[lenChan+4].Chan = reflect.ValueOf(sc.cancelReq) + cases[lenChan+5].Dir = reflect.SelectRecv + 
cases[lenChan+5].Chan = reflect.ValueOf(sc.hitChan) + chosen, recv, ok := reflect.Select(cases) + + if chosen < lenResults { + // Sent a result to the waiter + tmpResults = slices.Delete(tmpResults, chosen, chosen+1) + } else if chosen < lenChan { + // Acknowledged a cancellation + cancelRequest = slices.Delete(cancelRequest, chosen-lenResults, chosen-lenResults+1) + } else if chosen == lenChan { + // Cancellation; shut down + return nil + } else if chosen == lenChan+1 { + // New transfer results + if !ok { + // Client has closed, last notification for everyone + for path, ad := range activeJobs { + ad.status.done.Store(true) + for _, waiter := range ad.waiterList { + tmpResults = append(tmpResults, result{path: path, channel: waiter.notify, ds: ad.status}) + } + } + clientClosed = true + continue + } + results := recv.Interface().(*client.TransferResults) + path := jobPath[results.JobId] + delete(jobPath, results.JobId) + ad := activeJobs[path] + delete(activeJobs, path) + ad.status.err.Store(&results.Error) + ad.status.curSize.Store(results.TransferredBytes) + ad.status.size.Store(results.TransferredBytes) + ad.status.done.Store(true) + for _, waiter := range ad.waiterList { + tmpResults = append(tmpResults, result{path: path, channel: waiter.notify, ds: ad.status}) + } + if results.Error == nil { + entry := sc.lruLookup[path] + if entry == nil { + entry = &lruEntry{} + sc.lruLookup[path] = entry + entry.size = results.TransferredBytes + sc.cacheSize += uint64(entry.size) + sc.lru = append(sc.lru, entry) + } else { + entry.lastUse = time.Now() + } + } + } else if chosen == lenChan+2 { + // Ticker has fired - update progress + for path, dl := range activeJobs { + curSize := dl.status.curSize.Load() + for { + if dl.waiterList.Len() > 0 && dl.waiterList[0].size <= curSize { + waiter := heap.Pop(&dl.waiterList).(waiterInfo) + tmpResults = append(tmpResults, result{path: path, channel: waiter.notify, ds: dl.status}) + } else { break } + } + } + } else if chosen == lenChan+3 { + // New request + req := recv.Interface().(availSizeReq) + + // 
See if we can add the request to the waiter list + if ds := activeJobs[req.request.path]; ds != nil { + heap.Push(&ds.waiterList, waiterInfo{ + size: req.size, + notify: req.results, + }) + continue + } + // Start a new download + localPath := filepath.Join(sc.basePath, path.Clean(req.request.path)) + + // Ensure there's no .DONE file placed since the request was made. + if fpDone, err := os.Open(localPath + ".DONE"); err == nil { + fpDone.Close() + ds := &downloadStatus{} + ds.done.Store(true) + if fi, err := os.Stat(localPath); err == nil { + ds.curSize.Store(fi.Size()) + ds.size.Store(fi.Size()) + tmpResults = append(tmpResults, result{ + path: req.request.path, + channel: req.results, + ds: ds, + }); continue + } + } + + sourceURL := *sc.directorURL + sourceURL.Path = path.Join(sourceURL.Path, path.Clean(req.request.path)) + tj, err := sc.tc.NewTransferJob(&sourceURL, localPath, false, false, client.WithToken(req.request.token)) + if err != nil { + ds := &downloadStatus{} + ds.err.Store(&err) + tmpResults = append(tmpResults, result{ + path: req.request.path, + channel: req.results, + ds: ds, + }) + continue + } + ad := &activeDownload{ + tj: tj, + status: &downloadStatus{}, + waiterList: make(waiters, 0), + } + ad.waiterList = append(ad.waiterList, waiterInfo{ + size: req.size, + notify: req.results, + }) + activeJobs[req.request.path] = ad + } else if chosen == lenChan+4 { + // Cancel a given request. + req := recv.Interface().(cancelReq) + ds := activeJobs[req.req.path] + if ds != nil { + var idx int + found := false + var waiter waiterInfo + for idx, waiter = range ds.waiterList { + if waiter.id == req.req.id { + found = true; break + } + } + if found { + heap.Remove(&ds.waiterList, idx) + } + } + cancelRequest = append(cancelRequest, req.done) + } else if chosen == lenChan+5 { + // Notification there was a cache hit. 
+ hit := recv.Interface().(lruEntry) + entry := sc.lruLookup[hit.path] + if entry == nil { + entry = &lruEntry{} + sc.lruLookup[hit.path] = entry + entry.size = hit.size + sc.lru = append(sc.lru, entry) + sc.cacheSize += uint64(hit.size) + if sc.cacheSize > sc.highWater { + sc.purge() + } + } + entry.lastUse = hit.lastUse + } + } +} + +func (sc *SimpleCache) purge() { + heap.Init(&sc.lru) + start := time.Now() + for sc.cacheSize > sc.lowWater { + entry := heap.Pop(&sc.lru).(*lruEntry) + localPath := path.Join(sc.basePath, path.Clean(entry.path)) + if err := os.Remove(localPath + ".DONE"); err != nil { + log.Warningln("Failed to purge DONE file:", err) + } + if err := os.Remove(localPath); err != nil { + log.Warningln("Failed to purge file:", err) + } + sc.cacheSize -= uint64(entry.size) + // Since purge is called from the mux thread, blocking can cause + // other failures; do a time-based break even if we've not hit the low-water + if time.Since(start) > 3*time.Second { + break + } + } +} + +// Given a URL, return a reader from the disk cache +// +// If there is no sentinal $NAME.DONE file, then returns nil +func (sc *SimpleCache) getFromDisk(localPath string) io.ReadCloser { + localPath = filepath.Join(sc.basePath, path.Clean(localPath)) + fp, err := os.Open(localPath + ".DONE") + if err != nil { + return nil + } + defer fp.Close() + if fpReal, err := os.Open(localPath); err == nil { + return fpReal + } + return nil +} + +func (sc *SimpleCache) newCacheReader(path, token string) (reader *cacheReader, err error) { + reader = &cacheReader{ + path: path, + token: token, + sc: sc, + size: -1, + status: nil, + } + return +} + +// Get path from the cache +func (sc *SimpleCache) Get(path, token string) (io.ReadCloser, error) { + if fp := sc.getFromDisk(path); fp != nil { + return fp, nil + } + + return sc.newCacheReader(path, token) + +} + +// Read bytes from a file in the cache +// +// Does not request more data if bytes are not found +func (cr *cacheReader) 
readFromFile(p []byte, off int64) (n int, err error) { + cr.fdOnce.Do(func() { + cr.fd, cr.openErr = os.Open(filepath.Join(cr.sc.basePath, path.Clean(cr.path))) + }) + if cr.openErr != nil { + err = cr.openErr + return + } + return cr.fd.ReadAt(p, off) +} + +func (cr *cacheReader) Read(p []byte) (n int, err error) { + neededSize := cr.offset + int64(len(p)) + if cr.size >= 0 && neededSize > cr.size { + neededSize = cr.size + } + if neededSize > cr.avail { + // Insufficient available data; request more from the cache + if cr.status == nil { + // Send a request to the cache + var req req + req, err = newRequest(cr.path, cr.token) + if err != nil { + return + } + + // Bump up the size we're waiting on; only get notifications every 2MB + if len(p) < reqSize { + if cr.size >= 0 && cr.offset+reqSize > cr.size { + neededSize = cr.size + } else { + neededSize = cr.offset + reqSize + } + } + cr.status = make(chan *downloadStatus) + sizeReq := availSizeReq{ + request: req, + size: neededSize, + } + cr.sc.sizeReq <- sizeReq + } + select { + case <-cr.sc.ctx.Done(): + return 0, cr.sc.ctx.Err() + case availSize, ok := <-cr.status: + cr.status = nil + if !ok { + err = errors.New("unable to get response from cache engine") + return + } + dlErr := availSize.err.Load() + if dlErr != nil && *dlErr != nil { + err = *dlErr + return + } + done := availSize.done.Load() + dlSize := availSize.curSize.Load() + cr.size = availSize.size.Load() + cr.avail = dlSize + if dlSize < neededSize && !done { + err = errors.New("download thread returned too-short read") + return + } else { + n, err = cr.readFromFile(p, cr.offset) + if err != nil && err != io.EOF { + return + } + cr.offset += int64(n) + return + } + } + } else { + n, err = cr.readFromFile(p, cr.offset) + if err != nil && err != io.EOF { + return + } + cr.offset += int64(n) + return + } +} + +func (cr *cacheReader) Close() error { + return nil +} diff --git a/go.mod b/go.mod index 450cf368d..702f5b2f9 100644 --- a/go.mod +++ b/go.mod @@ 
-75,6 +75,7 @@ require ( go.opentelemetry.io/collector/pdata v1.0.0-rcv0016 // indirect go.opentelemetry.io/collector/semconv v0.87.0 // indirect go.uber.org/multierr v1.11.0 // indirect + golang.org/x/sys v0.15.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20231120223509-83a465c0220f // indirect google.golang.org/grpc v1.59.0 // indirect modernc.org/sqlite v1.28.0 // indirect @@ -181,7 +182,6 @@ require ( golang.org/x/arch v0.3.0 // indirect golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect golang.org/x/sync v0.6.0 - golang.org/x/sys v0.15.0 // indirect golang.org/x/text v0.14.0 golang.org/x/time v0.5.0 google.golang.org/appengine v1.6.7 // indirect diff --git a/param/parameters.go b/param/parameters.go index da3181d78..a69741bc8 100644 --- a/param/parameters.go +++ b/param/parameters.go @@ -90,6 +90,10 @@ var ( Federation_RegistryUrl = StringParam{"Federation.RegistryUrl"} Federation_TopologyNamespaceUrl = StringParam{"Federation.TopologyNamespaceUrl"} Federation_TopologyUrl = StringParam{"Federation.TopologyUrl"} + FileCache_DataLocation = StringParam{"FileCache.DataLocation"} + FileCache_RunLocation = StringParam{"FileCache.RunLocation"} + FileCache_Size = StringParam{"FileCache.Size"} + FileCache_Socket = StringParam{"FileCache.Socket"} IssuerKey = StringParam{"IssuerKey"} Issuer_AuthenticationSource = StringParam{"Issuer.AuthenticationSource"} Issuer_GroupFile = StringParam{"Issuer.GroupFile"} @@ -201,6 +205,8 @@ var ( Director_MaxStatResponse = IntParam{"Director.MaxStatResponse"} Director_MinStatResponse = IntParam{"Director.MinStatResponse"} Director_StatConcurrencyLimit = IntParam{"Director.StatConcurrencyLimit"} + FileCache_HighWaterMarkPercentage = IntParam{"FileCache.HighWaterMarkPercentage"} + FileCache_LowWaterMarkPercentage = IntParam{"FileCache.LowWaterMarkPercentage"} MinimumDownloadSpeed = IntParam{"MinimumDownloadSpeed"} Monitoring_PortHigher = IntParam{"Monitoring.PortHigher"} Monitoring_PortLower = 
IntParam{"Monitoring.PortLower"} diff --git a/param/parameters_struct.go b/param/parameters_struct.go index 2ddea8b13..348838ea2 100644 --- a/param/parameters_struct.go +++ b/param/parameters_struct.go @@ -72,6 +72,14 @@ type Config struct { TopologyReloadInterval time.Duration TopologyUrl string } + FileCache struct { + DataLocation string + HighWaterMarkPercentage int + LowWaterMarkPercentage int + RunLocation string + Size string + Socket string + } GeoIPOverrides interface{} Issuer struct { AuthenticationSource string @@ -296,6 +304,14 @@ type configWithType struct { TopologyReloadInterval struct { Type string; Value time.Duration } TopologyUrl struct { Type string; Value string } } + FileCache struct { + DataLocation struct { Type string; Value string } + HighWaterMarkPercentage struct { Type string; Value int } + LowWaterMarkPercentage struct { Type string; Value int } + RunLocation struct { Type string; Value string } + Size struct { Type string; Value string } + Socket struct { Type string; Value string } + } GeoIPOverrides struct { Type string; Value interface{} } Issuer struct { AuthenticationSource struct { Type string; Value string } From 05388e85e211b2584a9962018158f6eb2054017f Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Wed, 28 Feb 2024 20:36:53 -0600 Subject: [PATCH 03/45] Start integration between the file cache and client/launchers This adds the `file_cache` as a top-level module in Pelican that can be invoked as part of launchers. It also adds support for the client to accept caches of the form `unix://`. 
--- client/director.go | 44 +++++++++++++++++++++++++++++-------------- client/handle_http.go | 8 ++++++++ config/config.go | 22 ++++++++++++++++++---- launchers/launcher.go | 10 ++++++++++ 4 files changed, 66 insertions(+), 18 deletions(-) diff --git a/client/director.go b/client/director.go index 39f19bdc4..9069e4067 100644 --- a/client/director.go +++ b/client/director.go @@ -23,6 +23,7 @@ import ( "io" "net/http" "net/url" + "path" "sort" "strconv" "strings" @@ -244,7 +245,9 @@ func NewTransferDetailsUsingDirector(cache namespaces.DirectorCache, opts transf log.Errorln("Failed to parse cache:", cache, "error:", err) return nil } - if cacheURL.Host == "" { + if cacheURL.Scheme == "unix" && cacheURL.Host != "" { + cacheURL.Path = path.Clean("/" + path.Join(cacheURL.Host, cacheURL.Path)) + } else if cacheURL.Host == "" { // Assume the cache is just a hostname cacheURL.Host = cacheEndpoint cacheURL.Path = "" @@ -253,18 +256,21 @@ func NewTransferDetailsUsingDirector(cache namespaces.DirectorCache, opts transf } log.Debugf("Parsed Cache: %s", cacheURL.String()) if opts.NeedsToken { - cacheURL.Scheme = "https" - if !hasPort(cacheURL.Host) { - // Add port 8444 and 8443 - urlCopy := *cacheURL - urlCopy.Host += ":8444" - details = append(details, transferAttemptDetails{ - Url: &urlCopy, - Proxy: false, - PackOption: opts.PackOption, - }) - // Strip the port off and add 8443 - cacheURL.Host = cacheURL.Host + ":8443" + // Unless we're using the local Unix domain socket cache, force HTTPS + if cacheURL.Scheme != "unix" { + cacheURL.Scheme = "https" + if !hasPort(cacheURL.Host) { + // Add port 8444 and 8443 + urlCopy := *cacheURL + urlCopy.Host += ":8444" + details = append(details, transferAttemptDetails{ + Url: &urlCopy, + Proxy: false, + PackOption: opts.PackOption, + }) + // Strip the port off and add 8443 + cacheURL.Host = cacheURL.Host + ":8443" + } } // Whether port is specified or not, add a transfer without proxy details = append(details, transferAttemptDetails{ @@ 
-272,7 +278,9 @@ func NewTransferDetailsUsingDirector(cache namespaces.DirectorCache, opts transf Proxy: false, PackOption: opts.PackOption, }) - } else { + } else if cacheURL.Scheme == "" || cacheURL.Scheme == "http" { + // Assume a transfer not needing a token and not specifying a scheme is HTTP + // WARNING: This is legacy code; we should always specify a scheme cacheURL.Scheme = "http" if !hasPort(cacheURL.Host) { cacheURL.Host += ":8000" @@ -290,6 +298,14 @@ func NewTransferDetailsUsingDirector(cache namespaces.DirectorCache, opts transf PackOption: opts.PackOption, }) } + } else { + // A non-HTTP scheme is specified and a token is not needed; this wasn't possible + // in the legacy cases. Simply use the provided config + details = append(details, transferAttemptDetails{ + Url: cacheURL, + Proxy: false, + PackOption: opts.PackOption, + }) } return details diff --git a/client/handle_http.go b/client/handle_http.go index c450bc12d..b4fe54030 100644 --- a/client/handle_http.go +++ b/client/handle_http.go @@ -1246,6 +1246,14 @@ func downloadHTTP(ctx context.Context, te *TransferEngine, callback TransferCall if !transfer.Proxy { transport.Proxy = nil } + if transfer.Url.Scheme == "unix" { + transport.Proxy = nil // Proxies make no sense when reading via a Unix socket + transport = transport.Clone() + transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { + dialer := net.Dialer{} + return dialer.DialContext(ctx, "unix", transfer.Url.Path) + } + } httpClient, ok := client.HTTPClient.(*http.Client) if !ok { return 0, 0, "", errors.New("Internal error: implementation is not a http.Client type") diff --git a/config/config.go b/config/config.go index 1d1df49bc..a2130a0d6 100644 --- a/config/config.go +++ b/config/config.go @@ -30,7 +30,6 @@ import ( "net/http" "net/url" "os" - "path" "path/filepath" "sort" "strconv" @@ -101,6 +100,7 @@ const ( DirectorType RegistryType BrokerType + LocalCacheType EgrpKey ContextKey = "egrp" ) @@ -246,6 +246,9 
@@ func GetEnabledServerString(lowerCase bool) []string { if enabledServers.IsEnabled(CacheType) { servers = append(servers, CacheType.String()) } + if enabledServers.IsEnabled(LocalCacheType) { + servers = append(servers, LocalCacheType.String()) + } if enabledServers.IsEnabled(OriginType) { servers = append(servers, OriginType.String()) } @@ -278,6 +281,8 @@ func (sType ServerType) String() string { switch sType { case CacheType: return "Cache" + case LocalCacheType: + return "LocalCache" case OriginType: return "Origin" case DirectorType: @@ -295,6 +300,8 @@ func (sType *ServerType) SetString(name string) bool { switch sType { case "cache": *sType |= CacheType return true + case "localcache": + *sType |= LocalCacheType; return true case "origin": *sType |= OriginType return true @@ -815,6 +822,8 @@ func InitServer(ctx context.Context, currentServers ServerType) error { viper.SetDefault("Cache.RunLocation", filepath.Join("/run", "pelican", "xrootd", "cache")) } viper.SetDefault("Cache.DataLocation", "/run/pelican/xcache") + viper.SetDefault("FileCache.RunLocation", filepath.Join("/run", "pelican", "filecache")) + viper.SetDefault("Origin.Multiuser", true) viper.SetDefault("Director.GeoIPLocation", "/var/cache/pelican/maxmind/GeoLite2-City.mmdb") viper.SetDefault("Registry.DbLocation", "/var/lib/pelican/registry.sqlite") @@ -828,8 +837,9 @@ func InitServer(ctx context.Context, currentServers ServerType) error { viper.SetDefault("Shoveler.QueueDirectory", filepath.Join(configDir, "shoveler/queue")) viper.SetDefault("Shoveler.AMQPTokenLocation", filepath.Join(configDir, "shoveler-token")) + var runtimeDir string if userRuntimeDir := os.Getenv("XDG_RUNTIME_DIR"); userRuntimeDir != "" { - runtimeDir := filepath.Join(userRuntimeDir, "pelican") + runtimeDir = filepath.Join(userRuntimeDir, "pelican") err := os.MkdirAll(runtimeDir, 0750) if err != nil { return err } @@ -839,7 +849,6 @@ func InitServer(ctx context.Context, currentServers ServerType) error { if err != nil { return err } -
viper.SetDefault("Cache.DataLocation", path.Join(runtimeDir, "xcache")) } else { dir, err := os.MkdirTemp("", "pelican-xrootd-*") if err != nil { @@ -849,11 +858,16 @@ func InitServer(ctx context.Context, currentServers ServerType) error { if err != nil { return err } - viper.SetDefault("Cache.DataLocation", path.Join(dir, "xcache")) cleanupDirOnShutdown(ctx, dir) } + viper.SetDefault("Cache.DataLocation", filepath.Join(runtimeDir, "xcache")) + viper.SetDefault("FileCache.RunLocation", filepath.Join(runtimeDir, "cache")) viper.SetDefault("Origin.Multiuser", false) } + fcRunLocation := viper.GetString("FileCache.RunLocation") + viper.SetDefault("FileCache.Socket", filepath.Join(fcRunLocation, "cache.sock")) + viper.SetDefault("FileCache.DataLocation", filepath.Join(fcRunLocation, "cache")) + // Any platform-specific paths should go here err := InitServerOSDefaults() if err != nil { diff --git a/launchers/launcher.go b/launchers/launcher.go index 1873168a4..0cb371385 100644 --- a/launchers/launcher.go +++ b/launchers/launcher.go @@ -34,6 +34,7 @@ import ( "github.com/pelicanplatform/pelican/broker" "github.com/pelicanplatform/pelican/config" + "github.com/pelicanplatform/pelican/file_cache" "github.com/pelicanplatform/pelican/origin_ui" "github.com/pelicanplatform/pelican/param" "github.com/pelicanplatform/pelican/server_ui" @@ -181,6 +182,15 @@ func LaunchModules(ctx context.Context, modules config.ServerType) (context.Canc } } } + + if modules.IsEnabled(config.LocalCacheType) { + log.Debugln("Starting local cache listener") + if err := simple_cache.LaunchListener(ctx, egrp); err != nil { + log.Errorln("Failure when starting the local cache listener:", err) + return shutdownCancel, err + } + } + log.Info("Starting web engine...") lnReference = nil egrp.Go(func() error { From 171f6c502d53b0feddd647e57a52f8ef152a6a35 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sat, 2 Mar 2024 16:47:28 -0600 Subject: [PATCH 04/45] Revise `MakeRequest` to take a context If we 
want to keep most of Pelican cancellable, then any network calls need to take a context. --- cache_ui/advertise.go | 5 +++-- registry/client_commands.go | 15 ++++++++------- utils/web_utils.go | 5 +++-- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/cache_ui/advertise.go b/cache_ui/advertise.go index ddbb57242..21d0aec43 100644 --- a/cache_ui/advertise.go +++ b/cache_ui/advertise.go @@ -19,6 +19,7 @@ package cache_ui import ( + "context" "encoding/json" "net/url" "strings" @@ -76,14 +77,14 @@ func (server *CacheServer) GetNamespaceAdsFromDirector() error { // Attempt to get data from the 2.0 endpoint, if that returns a 404 error, then attempt to get data // from the 1.0 endpoint and convert from V1 to V2 - respData, err := utils.MakeRequest(directorNSListEndpointURL, "GET", nil, nil) + respData, err := utils.MakeRequest(context.Background(), directorNSListEndpointURL, "GET", nil, nil) if err != nil { if strings.Contains(err.Error(), "404") { directorNSListEndpointURL, err = url.JoinPath(directorEndpoint, "api", "v1.0", "director", "listNamespaces") if err != nil { return err } - respData, err = utils.MakeRequest(directorNSListEndpointURL, "GET", nil, nil) + respData, err = utils.MakeRequest(context.Background(), directorNSListEndpointURL, "GET", nil, nil) var respNSV1 []common.NamespaceAdV1 if err != nil { return errors.Wrap(err, "Failed to make request") diff --git a/registry/client_commands.go b/registry/client_commands.go index fd0910dc6..006cd0714 100644 --- a/registry/client_commands.go +++ b/registry/client_commands.go @@ -20,6 +20,7 @@ package registry import ( "bufio" + "context" "crypto/ecdsa" "encoding/hex" "encoding/json" @@ -56,7 +57,7 @@ func NamespaceRegisterWithIdentity(privateKey jwk.Key, namespaceRegistryEndpoint // it's also registered already } - resp, err := utils.MakeRequest(namespaceRegistryEndpoint, "POST", identifiedPayload, nil) + resp, err := utils.MakeRequest(context.Background(), namespaceRegistryEndpoint, "POST", 
identifiedPayload, nil) var respData clientResponseData // Handle case where there was an error encoded in the body @@ -79,7 +80,7 @@ func NamespaceRegisterWithIdentity(privateKey jwk.Key, namespaceRegistryEndpoint "identity_required": "true", "device_code": respData.DeviceCode, } - resp, err = utils.MakeRequest(namespaceRegistryEndpoint, "POST", identifiedPayload, nil) + resp, err = utils.MakeRequest(context.Background(), namespaceRegistryEndpoint, "POST", identifiedPayload, nil) if err != nil { return errors.Wrap(err, "Failed to make request") } @@ -135,7 +136,7 @@ func NamespaceRegister(privateKey jwk.Key, namespaceRegistryEndpoint string, acc "pubkey": keySet, } - resp, err := utils.MakeRequest(namespaceRegistryEndpoint, "POST", data, nil) + resp, err := utils.MakeRequest(context.Background(), namespaceRegistryEndpoint, "POST", data, nil) var respData clientResponseData // Handle case where there was an error encoded in the body @@ -179,7 +180,7 @@ func NamespaceRegister(privateKey jwk.Key, namespaceRegistryEndpoint string, acc } // Send the second POST request - resp, err = utils.MakeRequest(namespaceRegistryEndpoint, "POST", unidentifiedPayload, nil) + resp, err = utils.MakeRequest(context.Background(), namespaceRegistryEndpoint, "POST", unidentifiedPayload, nil) // Handle case where there was an error encoded in the body if unmarshalErr := json.Unmarshal(resp, &respData); unmarshalErr == nil { @@ -198,7 +199,7 @@ func NamespaceRegister(privateKey jwk.Key, namespaceRegistryEndpoint string, acc } func NamespaceList(endpoint string) error { - respData, err := utils.MakeRequest(endpoint, "GET", nil, nil) + respData, err := utils.MakeRequest(context.Background(), endpoint, "GET", nil, nil) var respErr clientResponseData if err != nil { if jsonErr := json.Unmarshal(respData, &respErr); jsonErr == nil { // Error creating json @@ -211,7 +212,7 @@ func NamespaceList(endpoint string) error { } func NamespaceGet(endpoint string) error { - respData, err := 
utils.MakeRequest(endpoint, "GET", nil, nil) + respData, err := utils.MakeRequest(context.Background(), endpoint, "GET", nil, nil) var respErr clientResponseData if err != nil { if jsonErr := json.Unmarshal(respData, &respErr); jsonErr == nil { // Error creating json @@ -265,7 +266,7 @@ func NamespaceDelete(endpoint string, prefix string) error { "Authorization": "Bearer " + tok, } - respData, err := utils.MakeRequest(endpoint, "DELETE", nil, authHeader) + respData, err := utils.MakeRequest(context.Background(), endpoint, "DELETE", nil, authHeader) var respErr clientResponseData if err != nil { if unmarshalErr := json.Unmarshal(respData, &respErr); unmarshalErr == nil { // Error creating json diff --git a/utils/web_utils.go b/utils/web_utils.go index eeb7e87cc..7e679f41f 100644 --- a/utils/web_utils.go +++ b/utils/web_utils.go @@ -20,6 +20,7 @@ package utils import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -74,9 +75,9 @@ type ( // MakeRequest makes an http request with our custom http client. It acts similarly to the http.NewRequest but // it only takes json as the request data. -func MakeRequest(url string, method string, data map[string]interface{}, headers map[string]string) ([]byte, error) { +func MakeRequest(ctx context.Context, url string, method string, data map[string]interface{}, headers map[string]string) ([]byte, error) { payload, _ := json.Marshal(data) - req, err := http.NewRequest(method, url, bytes.NewBuffer(payload)) + req, err := http.NewRequestWithContext(ctx, method, url, bytes.NewBuffer(payload)) if err != nil { return nil, err } From 9224cd8b48b66fb95ebf1e633c71811b8be6c3ba Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sat, 2 Mar 2024 16:51:47 -0600 Subject: [PATCH 05/45] Add support for fetching JWKS URLs by issuer name A necessary prereq to validating tokens given an issuer. 
--- utils/token_utils.go | 62 ++++++++++++++++++++++++++++++++++ utils/token_utils_test.go | 71 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 133 insertions(+) diff --git a/utils/token_utils.go b/utils/token_utils.go index 4307f65dc..cb42dad1d 100644 --- a/utils/token_utils.go +++ b/utils/token_utils.go @@ -19,10 +19,15 @@ package utils import ( + "context" "crypto/rand" "encoding/base64" + "encoding/json" "fmt" + "io" + "net/http" "net/url" + "path" "regexp" "time" @@ -48,6 +53,11 @@ type ( Claims map[string]string // Additional claims scope string // scope is a string with space-delimited list of scopes. To enforce type check, use AddRawScope or AddScopes to add scopes to your token } + + openIdConfiguration struct { + Issuer string `json:"issuer"` + JwksUri string `json:"jwks_uri"` + } ) const ( @@ -263,3 +273,55 @@ func (tokenConfig *TokenConfig) CreateTokenWithKey(key jwk.Key) (string, error) return string(signed), nil } + +// Given an issuer URL, lookup the corresponding JWKS URL using OAuth2 metadata discovery +func LookupIssuerJwksUrl(ctx context.Context, issuerUrlStr string) (jwksUrl *url.URL, err error) { + issuerUrl, err := url.Parse(issuerUrlStr) + if err != nil { + err = errors.Wrap(err, "failed to parse issuer as URL") + return + } + wellKnownUrl := *issuerUrl + wellKnownUrl.Path = path.Join(wellKnownUrl.Path, ".well-known/openid-configuration") + + client := &http.Client{Transport: config.GetTransport()} + + req, err := http.NewRequestWithContext(ctx, "GET", wellKnownUrl.String(), nil) + if err != nil { + err = errors.Wrap(err, "failed to generate new request to the remote issuer") + return + } + resp, err := client.Do(req) + if err != nil { + err = errors.Wrapf(err, "failed to get metadata from %s", issuerUrlStr) + return + } + defer resp.Body.Close() + + if resp.StatusCode > 299 { + err = errors.Errorf("issuer %s returned error %s (HTTP %d) for its OpenID auto-discovery configuration", issuerUrlStr, resp.Status, resp.StatusCode) + return 
+ } + + respBytes, err := io.ReadAll(resp.Body) + if err != nil { + err = errors.Wrapf(err, "failed to read HTTP response when looking up OpenID auto-discovery configuration for issuer %s", issuerUrlStr) + return + } + + var conf openIdConfiguration + if err = json.Unmarshal(respBytes, &conf); err != nil { + err = errors.Wrapf(err, "failed to parse the OpenID auto-discovery configuration for issuer %s", issuerUrl) + return + } + if conf.JwksUri == "" { + err = errors.Errorf("issuer %s provided no JWKS URL in its OpenID auto-discovery configuration", issuerUrl) + return + } + jwksUrl, err = url.Parse(conf.JwksUri) + if err != nil { + err = errors.Wrapf(err, "issuer %s provided an invalid JWKS URL in its OpenID auto-discovery configuration", issuerUrl) + return + } + return +} diff --git a/utils/token_utils_test.go b/utils/token_utils_test.go index 6226ef457..0df6d5af1 100644 --- a/utils/token_utils_test.go +++ b/utils/token_utils_test.go @@ -21,6 +21,9 @@ package utils import ( "context" "fmt" + "net/http" + "net/http/httptest" + "net/url" "path/filepath" "testing" "time" @@ -261,3 +264,71 @@ func TestCreateToken(t *testing.T) { assert.EqualError(t, err, "No issuer was found in the configuration file, "+ "and none was provided as a claim") } + +func TestLookupIssuerJwksUrl(t *testing.T) { + var resp *string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/issuer/.well-known/openid-configuration" { + if resp == nil || *resp == "" { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(*resp)) + } + })) + + issuerURL, err := url.Parse(srv.URL) + require.NoError(t, err) + issuerURL.Path = "/issuer" + + tests := []struct { + resp string + result string + errStr string + }{ + { + resp: `{"jwks_uri": "https://osdf.org"}`, + result: "https://osdf.org", + errStr: "", + }, + { + resp: "", + result: "", + errStr: fmt.Sprintf("issuer %s returned error 500 
Internal Server Error (HTTP 500) for its OpenID auto-discovery configuration", issuerURL), + }, + { + resp: `{}`, + result: "", + errStr: fmt.Sprintf("issuer %s provided no JWKS URL in its OpenID auto-discovery configuration", issuerURL), + }, + { + resp: `{{`, + result: "", + errStr: fmt.Sprintf("failed to parse the OpenID auto-discovery configuration for issuer %s: invalid character '{' looking for beginning of object key string", issuerURL), + }, + { + resp: `{"jwks_uri": "http_blah://foo"}`, + result: "", + errStr: fmt.Sprintf("issuer %s provided an invalid JWKS URL in its OpenID auto-discovery configuration: parse \"http_blah://foo\": first path segment in URL cannot contain colon", issuerURL), + }, + } + + ctx := context.Background() + for _, tt := range tests { + resp = &tt.resp + result, err := LookupIssuerJwksUrl(ctx, issuerURL.String()) + if tt.errStr == "" { + assert.NoError(t, err) + } else { + assert.Error(t, err) + if err != nil { + assert.Equal(t, tt.errStr, err.Error()) + } + } + if tt.result != "" { + assert.NoError(t, err) + assert.Equal(t, tt.result, result.String()) + } + } +} From 19724d6dca024752e4eccbe91164ac3379e04c23 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sat, 2 Mar 2024 17:29:09 -0600 Subject: [PATCH 06/45] Add support for "resource scopes" A resource-enabled scope is of the form: ``` authz:/resource ``` (example: `storage.read:/foo`). These resources are partially ordered and hierarchical, meaning some scopes "contain" the authorization for others. For example, `storage.read:/foo` implies access to `storage.read:/foo/bar` If a resource is not given, it is assumed to be `/`. 
--- docs/scopes.yaml | 8 +-- generate/scope_generator.go | 2 +- token_scopes/token_scope_utils.go | 69 +++++++++++++++++++++++++- token_scopes/token_scope_utils_test.go | 45 +++++++++++++++++ token_scopes/token_scopes.go | 10 ++-- 5 files changed, 123 insertions(+), 11 deletions(-) diff --git a/docs/scopes.yaml b/docs/scopes.yaml index 907763fd4..212b9b83c 100644 --- a/docs/scopes.yaml +++ b/docs/scopes.yaml @@ -96,25 +96,25 @@ acceptedBy: [cache"] ############################ # Storage Scopes # ############################ -name: "storage.read:" +name: "storage.read" description: >- For granting object read permissions to the bearer of the token. This scope must also posses a path to be valid, eg `storage.read:/foo/bar` issuedBy: ["origin"] acceptedBy: ["origin", "cache"] --- -name: "storage.create:" +name: "storage.create" description: >- For granting object creation permissions to the bearer of token. This scope must also posses a path to be valid, eg `storage.create:/foo/bar` issuedBy: ["origin"] acceptedBy: ["origin", "cache"] --- -name: "storage.modify:" +name: "storage.modify" description: >- For granting object modification permissions to the bearer of the token. This scope must also posses a path to be valid, eg `storage.modify:/foo/bar` issuedBy: ["origin"] acceptedBy: ["origin", "cache"] --- -name: "storage.stage:" +name: "storage.stage" description: >- For granting object staging permissions to the bearer of the token. 
This scope must also posses a path to be valid, eg `storage.stage:/foo/bar` issuedBy: ["origin"] diff --git a/generate/scope_generator.go b/generate/scope_generator.go index 8634d6d19..1e20d7f89 100644 --- a/generate/scope_generator.go +++ b/generate/scope_generator.go @@ -184,6 +184,6 @@ func (s TokenScope) Path(path string) (TokenScope, error) { return "", errors.New("cannot assign path to non-storage token scope") } - return TokenScope(s.String() + path), nil + return TokenScope(s.String() + ":" + path), nil } `)) diff --git a/token_scopes/token_scope_utils.go b/token_scopes/token_scope_utils.go index 8b5b0abb8..32452ed1a 100644 --- a/token_scopes/token_scope_utils.go +++ b/token_scopes/token_scope_utils.go @@ -22,6 +22,7 @@ import ( "context" "errors" "fmt" + "path" "slices" "sort" "strings" @@ -29,6 +30,47 @@ import ( "github.com/lestrrat-go/jwx/v2/jwt" ) +type ( + // A resourced scope is a scope whose privileges + // are narrowed to a specific resource. If there's + // the authorization for foo, then the ResourceScope of + // foo:/bar also contains foo:/bar/baz. 
+ ResourceScope struct { + Authorization TokenScope + Resource string + } +) + +func NewResourceScope(authz TokenScope, resource string) ResourceScope { + return ResourceScope{ + Authorization: authz, + Resource: path.Clean("/" + resource), + } +} + +func (rc ResourceScope) String() string { + if rc.Resource == "/" { + return string(rc.Authorization) + } + return rc.Authorization.String() + ":" + rc.Resource +} + +func (rc ResourceScope) Contains(other ResourceScope) bool { + if rc.Authorization != other.Authorization { + return false + } + if strings.HasPrefix(other.Resource, rc.Resource) { + if len(rc.Resource) == 1 { + return true + } + if len(other.Resource) > len(rc.Resource) { + return other.Resource[len(rc.Resource)] == '/' + } + return true + } + return false +} + // Get a string representation of a list of scopes, which can then be passed // to the Claim builder of JWT constructor func GetScopeString(scopes []TokenScope) (scopeString string) { @@ -47,6 +89,31 @@ func GetScopeString(scopes []TokenScope) (scopeString string) { return } +// Get a list of resource-style scopes from the token +func ParseResourceScopeString(tok jwt.Token) (scopes []ResourceScope) { + scopes = make([]ResourceScope, 0) + scopeAny, ok := tok.Get("scope") + if !ok { + return + } + scopeString, ok := scopeAny.(string) + if !ok { + return + } + for _, scope := range strings.Split(scopeString, " ") { + if scope == "" { + continue + } + info := strings.SplitN(scope, ":", 2) + if len(info) == 1 { + scopes = append(scopes, NewResourceScope(TokenScope(info[0]), "/")) + } else { + scopes = append(scopes, NewResourceScope(TokenScope(info[0]), info[1])) + } + } + return +} + // Return if expectedScopes contains the tokenScope and it's case-insensitive. 
// If all=false, it checks if the tokenScopes have any one scope in expectedScopes; // If all=true, it checks if tokenScopes is the same set as expectedScopes @@ -86,7 +153,7 @@ func CreateScopeValidator(expectedScopes []TokenScope, all bool) jwt.ValidatorFu } scope_any, present := tok.Get("scope") if !present { - return jwt.NewValidationError(errors.New("No scope is present; required for authorization")) + return jwt.NewValidationError(errors.New("no scope is present; required for authorization")) } scope, ok := scope_any.(string) if !ok { diff --git a/token_scopes/token_scope_utils_test.go b/token_scopes/token_scope_utils_test.go index 728475058..684c61602 100644 --- a/token_scopes/token_scope_utils_test.go +++ b/token_scopes/token_scope_utils_test.go @@ -21,6 +21,10 @@ package token_scopes import ( "strconv" "testing" + + "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestGetScopeString(t *testing.T) { @@ -89,3 +93,44 @@ func TestScopeContains(t *testing.T) { }) } } + +func TestResourceScopes(t *testing.T) { + tests := []struct { + name string + myScope ResourceScope + otherScope ResourceScope + expected bool + }{ + {"diffScope", NewResourceScope(Storage_Create, "/"), NewResourceScope(Storage_Modify, "/"), false}, + {"same", NewResourceScope(Storage_Create, ""), NewResourceScope(Storage_Create, ""), true}, + {"default", NewResourceScope(Storage_Create, ""), NewResourceScope(Storage_Create, "/"), true}, + {"sub", NewResourceScope(Storage_Create, ""), NewResourceScope(Storage_Create, "/foo"), true}, + {"subDeep", NewResourceScope(Storage_Create, "/foo"), NewResourceScope(Storage_Create, "/foo/bar"), true}, + {"subClean", NewResourceScope(Storage_Create, "/foo"), NewResourceScope(Storage_Create, "/foo/"), true}, + {"noPath", NewResourceScope(Storage_Create, "/foo"), NewResourceScope(Storage_Create, "/foobar"), false}, + {"sameDeep", NewResourceScope(Storage_Create, "/foo/bar"), 
NewResourceScope(Storage_Create, "/foo/bar"), true}, + {"sameDeep", NewResourceScope(Storage_Create, "/foo/bar"), NewResourceScope(Storage_Create, "/foo/barbaz"), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.myScope.Contains(tt.otherScope)) + }) + } +} + +func TestParseResources(t *testing.T) { + tok := jwt.New() + + require.NoError(t, tok.Set("scope", "blah")) + assert.Equal(t, []ResourceScope{{Authorization: TokenScope("blah"), Resource: "/"}}, ParseResourceScopeString(tok)) + + require.NoError(t, tok.Set("scope", 5)) + assert.Equal(t, []ResourceScope{}, ParseResourceScopeString(tok)) + + require.NoError(t, tok.Set("scope", "foo bar")) + assert.Equal(t, []ResourceScope{{Authorization: TokenScope("foo"), Resource: "/"}, {Authorization: TokenScope("bar"), Resource: "/"}}, ParseResourceScopeString(tok)) + + require.NoError(t, tok.Set("scope", "storage.create:/foo")) + assert.Equal(t, []ResourceScope{{Authorization: Storage_Create, Resource: "/foo"}}, ParseResourceScopeString(tok)) +} diff --git a/token_scopes/token_scopes.go b/token_scopes/token_scopes.go index c4fd658e0..c39bc7ed5 100644 --- a/token_scopes/token_scopes.go +++ b/token_scopes/token_scopes.go @@ -38,10 +38,10 @@ const ( Broker_Callback TokenScope = "broker.callback" // Storage Scopes - Storage_Read TokenScope = "storage.read:" - Storage_Create TokenScope = "storage.create:" - Storage_Modify TokenScope = "storage.modify:" - Storage_Stage TokenScope = "storage.stage:" + Storage_Read TokenScope = "storage.read" + Storage_Create TokenScope = "storage.create" + Storage_Modify TokenScope = "storage.modify" + Storage_Stage TokenScope = "storage.stage" ) func (s TokenScope) String() string { @@ -55,5 +55,5 @@ func (s TokenScope) Path(path string) (TokenScope, error) { return "", errors.New("cannot assign path to non-storage token scope") } - return TokenScope(s.String() + path), nil + return TokenScope(s.String() + ":" + path), nil } From 
3c8ef0c1fe99c7d72a56a6ee0d7e9a76d6f3325d Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sat, 2 Mar 2024 17:34:53 -0600 Subject: [PATCH 07/45] Add support for authorization in the local cache This adds support for authorization to the local cache and periodically synchronizes the local cache's authorization policy with the director. --- file_cache/cache_api.go | 21 +++- file_cache/cache_authz.go | 246 +++++++++++++++++++++++++++++++++++++ file_cache/cache_test.go | 121 ++++++++++++++++++ file_cache/simple_cache.go | 71 ++++++++++- param/parameters.go | 2 +- param/parameters_struct.go | 8 +- utils/token_utils_test.go | 3 +- 7 files changed, 458 insertions(+), 14 deletions(-) create mode 100644 file_cache/cache_authz.go create mode 100644 file_cache/cache_test.go diff --git a/file_cache/cache_api.go b/file_cache/cache_api.go index eac573f12..47155f990 100644 --- a/file_cache/cache_api.go +++ b/file_cache/cache_api.go @@ -20,6 +20,7 @@ package simple_cache import ( "context" + "errors" "fmt" "io" "net" @@ -59,12 +60,24 @@ func LaunchListener(ctx context.Context, egrp *errgroup.Group) error { } } - bearerToken := r.Header.Get("Authorization") - bearerToken = strings.TrimPrefix(bearerToken, "Bearer ") + authzHeader := r.Header.Get("Authorization") + bearerToken := "" + if strings.HasPrefix(authzHeader, "Bearer ") { + bearerToken = authzHeader[7:] // len("Bearer ") == 7 + } reader, err := sc.Get(r.URL.Path, bearerToken) - if err != nil { + if errors.Is(err, authorizationDenied) { + w.WriteHeader(http.StatusForbidden) + if _, err = w.Write([]byte("Authorization Denied")); err != nil { + log.Errorln("Failed to write authorization denied to client") + } + return + } else if err != nil { w.WriteHeader(http.StatusInternalServerError) - log.Errorln("Failed to get file from cache") + if _, err = w.Write([]byte("Unexpected internal error")); err != nil { + log.Errorln("Failed to write internal error message to client") + } + log.Errorln("Failed to get file from cache:",
err) return } w.WriteHeader(http.StatusOK) diff --git a/file_cache/cache_authz.go b/file_cache/cache_authz.go new file mode 100644 index 000000000..b7ac117b0 --- /dev/null +++ b/file_cache/cache_authz.go @@ -0,0 +1,246 @@ +/*************************************************************** + * + * Copyright (C) 2024, Pelican Project, Morgridge Institute for Research + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You may + * obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + ***************************************************************/ + +package simple_cache + +import ( + "context" + "net/http" + "net/url" + "path" + "slices" + "sync/atomic" + "time" + + "github.com/jellydator/ttlcache/v3" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/pelicanplatform/pelican/common" + "github.com/pelicanplatform/pelican/config" + "github.com/pelicanplatform/pelican/token_scopes" + "github.com/pelicanplatform/pelican/utils" + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" + "golang.org/x/sync/errgroup" +) + +type ( + authConfig struct { + ns atomic.Pointer[[]common.NamespaceAdV2] + issuers atomic.Pointer[map[string]bool] + issuerKeys *ttlcache.Cache[string, authConfigItem] + tokenAuthz *ttlcache.Cache[string, acls] + } + + authConfigItem struct { + set jwk.Set + err error + } + + prefixConfig struct { + public bool + issuers []issuerConfig + } + + issuerConfig struct { + url *url.URL + basePath string + } + + acls []token_scopes.ResourceScope + + aclsItem struct { + acls 
acls + err error + } +) + +func newAuthConfig(ctx context.Context, egrp *errgroup.Group) (ac *authConfig) { + ac = &authConfig{} + + loader := ttlcache.LoaderFunc[string, authConfigItem]( + func(cache *ttlcache.Cache[string, authConfigItem], issuerUrl string) *ttlcache.Item[string, authConfigItem] { + var ar *jwk.Cache + jwksUrl, err := utils.LookupIssuerJwksUrl(ctx, issuerUrl) + if err != nil { + log.Errorln("Failed to lookup JWKS URL:", err) + } else { + ar = jwk.NewCache(ctx) + client := &http.Client{Transport: config.GetTransport()} + if err = ar.Register(jwksUrl.String(), jwk.WithMinRefreshInterval(15*time.Minute), jwk.WithHTTPClient(client)); err != nil { + log.Errorln("Failed to register JWKS URL with cache: ", err) + } else { + log.Debugln("Setting public key cache for issuer", issuerUrl) + } + } + + ttl := ttlcache.DefaultTTL + if err != nil { + ttl = time.Duration(5 * time.Minute) + } + item := cache.Set(issuerUrl, authConfigItem{set: func() jwk.Set { if err != nil { return nil }; return jwk.NewCachedSet(ar, jwksUrl.String()) }(), err: err}, ttl) + return item + }, + ) + ac.issuerKeys = ttlcache.New[string, authConfigItem]( + ttlcache.WithTTL[string, authConfigItem](15*time.Minute), + ttlcache.WithLoader[string, authConfigItem](ttlcache.NewSuppressedLoader[string, authConfigItem](loader, nil)), + ) + + ac.tokenAuthz = ttlcache.New[string, acls]( + ttlcache.WithTTL[string, acls](5*time.Minute), + ttlcache.WithLoader[string, acls](ttlcache.LoaderFunc[string, acls](ac.loader)), + ) + + go ac.issuerKeys.Start() + egrp.Go(func() error { + <-ctx.Done() + ac.issuerKeys.Stop() + ac.issuerKeys.DeleteAll() + return nil + }) + + return +} + +func (ac *authConfig) updateConfig(nsAds []common.NamespaceAdV2) error { + issuers := make(map[string]bool) + for _, nsAd := range nsAds { + for _, issuer := range nsAd.Issuer { + issuers[issuer.IssuerUrl.String()] = true + } + } + ac.issuers.Store(&issuers) + ac.ns.Store(&nsAds) + return nil +} + +func (ac *authConfig) getResourceScopes(token string) (scopes
[]token_scopes.ResourceScope, issuer string, err error) { + tok, err := jwt.Parse([]byte(token)) + if err != nil { + err = errors.Wrap(err, "failed to parse incoming JWT when authorizing request") + return + } + issuer = tok.Issuer() + + issuers := ac.issuers.Load() + if !(*issuers)[issuer] { + err = errors.Errorf("token issuer %s is not one of the trusted issuers", issuer); return + } + + issuerConfItem := ac.issuerKeys.Get(issuer) + if issuerConfItem == nil { + err = errors.Errorf("unable to determine keys for issuer %s", issuer) + return + } + + issuerConf := issuerConfItem.Value() + if issuerConf.err != nil { + err = issuerConf.err + return + } + + tok, err = jwt.Parse([]byte(token), jwt.WithKeySet(issuerConfItem.Value().set)) + if err != nil { + return + } + err = jwt.Validate(tok) + if err != nil { + return + } + + scopes = token_scopes.ParseResourceScopeString(tok) + + return +} + +func calcResourceScopes(rs token_scopes.ResourceScope, basePaths []string, restrictedPaths []string) (results []token_scopes.ResourceScope) { + results = make([]token_scopes.ResourceScope, 0) + for _, basePath := range basePaths { + if len(restrictedPaths) == 0 { + results = append(results, token_scopes.NewResourceScope(rs.Authorization, path.Join(basePath, rs.Resource))) + } else { + for _, restrictedPath := range restrictedPaths { + tmpResource := token_scopes.NewResourceScope(rs.Authorization, restrictedPath) + if tmpResource.Contains(rs) { + // E.g., restricted_path=/foo, token scope is storage.read:/foo/bar; generate ACL of storage.read:/foo/bar + results = append(results, token_scopes.NewResourceScope(rs.Authorization, path.Join(basePath, rs.Resource))) + } else if rs.Contains(tmpResource) { + // E.g., restricted_path=/foo, token scope is storage.read:/; generate ACL of storage.read:/foo + results = append(results, token_scopes.NewResourceScope(rs.Authorization, path.Join(basePath, tmpResource.Resource))) + } + } + } + } + return +} + +func (ac *authConfig) getAcls(token string)
(newAcls acls, err error) { + namespaces := ac.ns.Load() + if namespaces == nil { + return + } + resources, issuer, err := ac.getResourceScopes(token) + if err != nil { + return + } + + newAcls = make(acls, 0) + for _, conf := range *namespaces { + if conf.Caps.PublicRead { + newAcls = append(newAcls, token_scopes.ResourceScope{Authorization: token_scopes.Storage_Read, Resource: conf.Path}) + } else if conf.Issuer != nil { + for _, resource := range resources { + if (resource.Authorization == token_scopes.Storage_Create || resource.Authorization == token_scopes.Storage_Modify) && !conf.Caps.Write { + continue + } + if resource.Authorization == token_scopes.Storage_Read && !conf.Caps.Read { + continue + } + for _, issuerConfig := range conf.Issuer { + if issuerConfig.IssuerUrl.String() != issuer { + continue + } + newAcls = append(newAcls, calcResourceScopes(resource, issuerConfig.BasePaths, issuerConfig.RestrictedPaths)...) + } + } + } + } + return +} + +func (ac *authConfig) loader(cache *ttlcache.Cache[string, acls], token string) *ttlcache.Item[string, acls] { + acls, err := ac.getAcls(token) + if err != nil { + // If the token is not a valid one signed by a known issuer, do not keep it in memory (avoids a DoS) + log.Warningln("Rejecting invalid token:", err) + return nil + } + + item := cache.Set(token, acls, ttlcache.DefaultTTL) + return item +} + +func (ac *authConfig) authorize(action token_scopes.TokenScope, resource, token string) bool { + aclsItem := ac.tokenAuthz.Get(token) + if aclsItem == nil { + return false + } + rsScope := token_scopes.NewResourceScope(action, resource) + return slices.ContainsFunc(aclsItem.Value(), func(rs token_scopes.ResourceScope) bool { return rs.Contains(rsScope) }) +} diff --git a/file_cache/cache_test.go b/file_cache/cache_test.go new file mode 100644 index 000000000..84cb696f0 --- /dev/null +++ b/file_cache/cache_test.go @@ -0,0 +1,121 @@ +/*************************************************************** + * + * Copyright 
(C) 2024, Pelican Project, Morgridge Institute for Research + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You may + * obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + ***************************************************************/ + +package simple_cache_test + +import ( + "context" + "io" + "os" + "path/filepath" + "testing" + + "github.com/pelicanplatform/pelican/config" + simple_cache "github.com/pelicanplatform/pelican/file_cache" + "github.com/pelicanplatform/pelican/launchers" + "github.com/pelicanplatform/pelican/test_utils" + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +func spinup(t *testing.T, ctx context.Context, egrp *errgroup.Group) string { + + modules := config.ServerType(0) + modules.Set(config.OriginType) + modules.Set(config.DirectorType) + modules.Set(config.RegistryType) + modules.Set(config.CacheType) + modules.Set(config.LocalCacheType) + + tmpPathPattern := "XRootD-Test_Origin*" + tmpPath, err := os.MkdirTemp("", tmpPathPattern) + require.NoError(t, err) + + permissions := os.FileMode(0755) + err = os.Chmod(tmpPath, permissions) + require.NoError(t, err) + t.Cleanup(func() { + err := os.RemoveAll(tmpPath) + require.NoError(t, err) + }) + + viper.Set("ConfigDir", tmpPath) + + config.InitConfig() + + originDir, err := os.MkdirTemp("", "Origin") + assert.NoError(t, err) + t.Cleanup(func() { + err := os.RemoveAll(originDir) + require.NoError(t, err) + }) + + // Change the permissions of the 
temporary origin directory + permissions = os.FileMode(0777) + err = os.Chmod(originDir, permissions) + require.NoError(t, err) + + viper.Set("Origin.ExportVolume", originDir+":/test") + viper.Set("Origin.Mode", "posix") + // Disable functionality we're not using (and is difficult to make work on Mac) + viper.Set("Origin.EnableCmsd", false) + viper.Set("Origin.EnableMacaroons", false) + viper.Set("Origin.EnableVoms", false) + viper.Set("Server.EnableUI", false) + viper.Set("Registry.DbLocation", filepath.Join(t.TempDir(), "ns-registry.sqlite")) + viper.Set("Origin.Port", 0) + viper.Set("Server.WebPort", 0) + viper.Set("Origin.RunLocation", tmpPath) + viper.Set("Cache.RunLocation", tmpPath) + viper.Set("Registry.RequireOriginApproval", false) + viper.Set("Registry.RequireCacheApproval", false) + + err = config.InitServer(ctx, modules) + require.NoError(t, err) + + cancel, err := launchers.LaunchModules(ctx, modules) + require.NoError(t, err) + t.Cleanup(func() { + cancel() + egrp.Wait() + }) + + return originDir +} + +func TestFileCacheSimpleGet(t *testing.T) { + ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) + defer cancel() + + originDir := spinup(t, ctx, egrp) + + err := os.WriteFile(filepath.Join(originDir, "hello_world.txt"), []byte("Hello, World!"), os.FileMode(0644)) + require.NoError(t, err) + + sc, err := simple_cache.NewSimpleCache(ctx, egrp) + require.NoError(t, err) + + reader, err := sc.Get("/test/hello_world.txt", "") + require.NoError(t, err) + + byteBuff, err := io.ReadAll(reader) + assert.NoError(t, err) + assert.Equal(t, "Hello, World!", string(byteBuff)) +} diff --git a/file_cache/simple_cache.go b/file_cache/simple_cache.go index 047388701..300d72012 100644 --- a/file_cache/simple_cache.go +++ b/file_cache/simple_cache.go @@ -21,6 +21,7 @@ package simple_cache import ( "container/heap" "context" + "encoding/json" "io" "net/url" "os" @@ -36,7 +37,10 @@ import ( "github.com/alecthomas/units" "github.com/google/uuid" 
"github.com/pelicanplatform/pelican/client" + "github.com/pelicanplatform/pelican/common" "github.com/pelicanplatform/pelican/param" + "github.com/pelicanplatform/pelican/token_scopes" + "github.com/pelicanplatform/pelican/utils" "github.com/pkg/errors" log "github.com/sirupsen/logrus" "golang.org/x/sync/errgroup" @@ -54,6 +58,7 @@ type ( mutex sync.RWMutex downloads map[string]*activeDownload directorURL *url.URL + ac *authConfig // Cache static configuration highWater uint64 @@ -128,6 +133,10 @@ type ( } ) +var ( + authorizationDenied error = errors.New("authorization denied") +) + const ( reqSize = 2 * 1024 * 1024 ) @@ -238,6 +247,7 @@ func NewSimpleCache(ctx context.Context, egrp *errgroup.Group) (sc *SimpleCache, lowWater: (cacheSize / 100) * uint64(lowWaterPercentage), cacheSize: cacheSize, basePath: cacheDir, + ac: newAuthConfig(ctx, egrp), } sc.tc, err = sc.te.NewClient(client.WithAcquireToken(false), client.WithCallback(sc.callback)) @@ -248,8 +258,12 @@ func NewSimpleCache(ctx context.Context, egrp *errgroup.Group) (sc *SimpleCache, } return } + if err = sc.updateConfig(); err != nil { + log.Warningln("First attempt to update cache's authorization failed:", err) + } egrp.Go(sc.runMux) + egrp.Go(sc.periodicUpdateConfig) return } @@ -324,10 +338,10 @@ func (sc *SimpleCache) runMux() error { if chosen < lenResults { // Sent a result to the waiter - slices.Delete(tmpResults, chosen, chosen+1) + tmpResults = slices.Delete(tmpResults, chosen, chosen+1) } else if chosen < lenChan { // Acknowledged a cancellation - slices.Delete(cancelRequest, chosen-lenResults, chosen-lenResults+1) + cancelRequest = slices.Delete(cancelRequest, chosen-lenResults, chosen-lenResults+1) } else if chosen == lenChan { // Cancellation; shut down return nil @@ -520,6 +534,10 @@ func (sc *SimpleCache) newCacheReader(path, token string) (reader *cacheReader, // Get path from the cache func (sc *SimpleCache) Get(path, token string) (io.ReadCloser, error) { + if 
!sc.ac.authorize(token_scopes.Storage_Read, path, token) { + return nil, authorizationDenied + } + if fp := sc.getFromDisk(path); fp != nil { return fp, nil } @@ -528,6 +546,55 @@ func (sc *SimpleCache) Get(path, token string) (io.ReadCloser, error) { } +func (sc *SimpleCache) updateConfig() error { + // Get the endpoint of the director + var respNS []common.NamespaceAdV2 + + directorEndpoint := param.Federation_DirectorUrl.GetString() + if directorEndpoint == "" { + return errors.New("No director specified; give the federation name (-f)") + } + + directorEndpointURL, err := url.Parse(directorEndpoint) + if err != nil { + return errors.Wrap(err, "Unable to parse director url") + } + + // Create the listNamespaces url + directorNSListEndpointURL, err := url.JoinPath(directorEndpointURL.String(), "api", "v2.0", "director", "listNamespaces") + if err != nil { + return errors.Wrap(err, "Unable to generate the director's listNamespaces endpoint") + } + + respData, err := utils.MakeRequest(sc.ctx, directorNSListEndpointURL, "GET", nil, nil) + if err != nil { + return err + } else { + err = json.Unmarshal(respData, &respNS) + if err != nil { + return errors.Wrapf(err, "Failed to marshal response in to JSON: %v", err) + } + } + + return sc.ac.updateConfig(respNS) +} + +// Periodically update the cache configuration from the registry +func (sc *SimpleCache) periodicUpdateConfig() error { + ticker := time.NewTicker(time.Minute) + for { + select { + case <-sc.ctx.Done(): + return nil + case <-ticker.C: + err := sc.updateConfig() + if err != nil { + log.Warningln("Failed to update the file cache config:", err) + } + } + } +} + // Read bytes from a file in the cache // // Does not request more data if bytes are not found diff --git a/param/parameters.go b/param/parameters.go index a69741bc8..69e4f2f72 100644 --- a/param/parameters.go +++ b/param/parameters.go @@ -206,7 +206,7 @@ var ( Director_MinStatResponse = IntParam{"Director.MinStatResponse"} Director_StatConcurrencyLimit 
= IntParam{"Director.StatConcurrencyLimit"} FileCache_HighWaterMarkPercentage = IntParam{"FileCache.HighWaterMarkPercentage"} - FileCache_LowWaterMarkPercentage = IntParam{"FileCcache.LowWaterMarkPercentage"} + FileCache_LowWaterMarkPercentage = IntParam{"FileCache.LowWaterMarkPercentage"} MinimumDownloadSpeed = IntParam{"MinimumDownloadSpeed"} Monitoring_PortHigher = IntParam{"Monitoring.PortHigher"} Monitoring_PortLower = IntParam{"Monitoring.PortLower"} diff --git a/param/parameters_struct.go b/param/parameters_struct.go index 348838ea2..cfbceda72 100644 --- a/param/parameters_struct.go +++ b/param/parameters_struct.go @@ -75,13 +75,11 @@ type Config struct { FileCache struct { DataLocation string HighWaterMarkPercentage int + LowWaterMarkPercentage int RunLocation string Size string Socket string } - FileCcache struct { - LowWaterMarkPercentage int - } GeoIPOverrides interface{} Issuer struct { AuthenticationSource string @@ -309,13 +307,11 @@ type configWithType struct { FileCache struct { DataLocation struct { Type string; Value string } HighWaterMarkPercentage struct { Type string; Value int } + LowWaterMarkPercentage struct { Type string; Value int } RunLocation struct { Type string; Value string } Size struct { Type string; Value string } Socket struct { Type string; Value string } } - FileCcache struct { - LowWaterMarkPercentage struct { Type string; Value int } - } GeoIPOverrides struct { Type string; Value interface{} } Issuer struct { AuthenticationSource struct { Type string; Value string } diff --git a/utils/token_utils_test.go b/utils/token_utils_test.go index 0df6d5af1..09a697ffe 100644 --- a/utils/token_utils_test.go +++ b/utils/token_utils_test.go @@ -274,7 +274,8 @@ func TestLookupIssuerJwksUrl(t *testing.T) { return } w.WriteHeader(http.StatusOK) - w.Write([]byte(*resp)) + _, err := w.Write([]byte(*resp)) + require.NoError(t, err) } })) From fda3ea0f50fb72d765bf629204555695f032c16f Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sat, 2 Mar 
2024 22:05:18 -0600 Subject: [PATCH 08/45] Remove unreachable code --- cache_ui/advertise.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/cache_ui/advertise.go b/cache_ui/advertise.go index 21d0aec43..d8c065ade 100644 --- a/cache_ui/advertise.go +++ b/cache_ui/advertise.go @@ -64,10 +64,6 @@ func (server *CacheServer) GetNamespaceAdsFromDirector() error { return errors.Wrap(err, "Unable to parse director url") } - if err != nil { - return errors.Wrapf(err, "Failed to get DirectorURL from config: %v", err) - } - // Create the listNamespaces url directorNSListEndpointURL, err := url.JoinPath(directorEndpointURL.String(), "api", "v2.0", "director", "listNamespaces") if err != nil { From 5f9419c3cdf8c8ab8a81b6e939db94eca5fbf2bf Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sat, 2 Mar 2024 22:06:55 -0600 Subject: [PATCH 09/45] Do not export JobID from TransferResults Instead, provide a ID() method which returns a string. Mirrors what is done for the TransferJob object. --- client/handle_http.go | 27 ++++++++++++++++++++++----- cmd/plugin.go | 2 +- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/client/handle_http.go b/client/handle_http.go index b4fe54030..8d68aa15a 100644 --- a/client/handle_http.go +++ b/client/handle_http.go @@ -83,7 +83,7 @@ type ( // Represents the results of a single object transfer, // potentially across multiple attempts / retries. TransferResults struct { - JobId uuid.UUID // The job ID this result corresponds to + jobId uuid.UUID // The job ID this result corresponds to job *TransferJob Error error TransferredBytes int64 @@ -319,11 +319,15 @@ func hasPort(host string) bool { func newTransferResults(job *TransferJob) TransferResults { return TransferResults{ job: job, - JobId: job.uuid, + jobId: job.uuid, Attempts: make([]TransferResult, 0), } } +func (tr TransferResults) ID() string { + return tr.jobId.String() +} + // Returns a new transfer engine object whose lifetime is tied // to the provided context. 
Will launcher worker goroutines to // handle the underlying transfers @@ -797,6 +801,19 @@ func (tc *TransferClient) NewTransferJob(remoteUrl *url.URL, localPath string, u return nil, fmt.Errorf("failed to get token for transfer: %v", err) } } + + log.Debugf("Created new transfer job, ID %s client %s, for URL %s", tj.uuid.String(), tc.id.String(), remoteUrl.String()) + return +} + +// Returns the status of the transfer job-to-file(s) lookup +// +// ok is true if the lookup has completed. +func (tj *TransferJob) GetLookupStatus() (ok bool, err error) { + ok = tj.lookupDone.Load() + if ok { + err = tj.lookupErr + } return } @@ -1084,7 +1101,7 @@ func runTransferWorker(ctx context.Context, workChan <-chan *clientTransferFile, results <- &clientTransferResults{ id: file.uuid, results: TransferResults{ - JobId: file.jobId, + jobId: file.jobId, Error: file.file.ctx.Err(), }, } @@ -1094,7 +1111,7 @@ func runTransferWorker(ctx context.Context, workChan <-chan *clientTransferFile, results <- &clientTransferResults{ id: file.uuid, results: TransferResults{ - JobId: file.jobId, + jobId: file.jobId, Error: file.file.err, }, } @@ -1107,7 +1124,7 @@ func runTransferWorker(ctx context.Context, workChan <-chan *clientTransferFile, } else { transferResults, err = downloadObject(file.file) } - transferResults.JobId = file.jobId + transferResults.jobId = file.jobId if err != nil { log.Errorf("Error when attempting to transfer object %s for client %s", file.file.remoteURL, file.uuid.String()) transferResults = newTransferResults(file.file.job) diff --git a/cmd/plugin.go b/cmd/plugin.go index 0ffb4abea..9d0c64783 100644 --- a/cmd/plugin.go +++ b/cmd/plugin.go @@ -456,7 +456,7 @@ func runPluginWorker(ctx context.Context, upload bool, workChan <-chan PluginTra hostname, _ := os.Hostname() resultAd.Set("TransferLocalMachineName", hostname) resultAd.Set("TransferProtocol", "stash") - transfer := jobMap[result.JobId.String()] + transfer := jobMap[result.ID()] resultAd.Set("TransferUrl", 
transfer.url.String()) resultAd.Set("TransferFileName", transfer.localFile) if upload { From 60f7510d241edc1e648df60dff29a3a2fcd01a77 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sat, 2 Mar 2024 22:11:52 -0600 Subject: [PATCH 10/45] Add initial tests for the local cache functionality Includes fixes required to make the code work with the tests. --- client/handle_http.go | 14 ++++- config/config.go | 7 ++- file_cache/cache_api.go | 12 +++- file_cache/cache_authz.go | 23 ++------ file_cache/cache_test.go | 18 +++++- file_cache/simple_cache.go | 116 +++++++++++++++++++++++++++++-------- launchers/launcher.go | 18 +++--- server_ui/advertise.go | 23 +++++--- 8 files changed, 165 insertions(+), 66 deletions(-) diff --git a/client/handle_http.go b/client/handle_http.go index 8d68aa15a..1458687fa 100644 --- a/client/handle_http.go +++ b/client/handle_http.go @@ -442,8 +442,13 @@ func (te *TransferEngine) NewClient(options ...TransferOption) (client *Transfer te.resultsMap[id] = client.results te.workMap[id] = client.work }() - log.Debugln("Inserted work map for client", id.String()) - te.notifyChan <- true + log.Debugln("Created new client", id.String()) + select { + case <-te.ctx.Done(): + log.Debugln("New client unable to start; transfer engine has been canceled") + err = te.ctx.Err() + case te.notifyChan <- true: + } return } @@ -465,7 +470,10 @@ func (te *TransferEngine) Shutdown() error { // Closes the TransferEngine. No new work may // be submitted. 
Any ongoing work will continue func (te *TransferEngine) Close() { - te.closeChan <- true + select { + case <-te.ctx.Done(): + case te.closeChan <- true: + } } // Launches a helper goroutine that ensures completed diff --git a/config/config.go b/config/config.go index a2130a0d6..91a363ac0 100644 --- a/config/config.go +++ b/config/config.go @@ -850,15 +850,16 @@ func InitServer(ctx context.Context, currentServers ServerType) error { return err } } else { - dir, err := os.MkdirTemp("", "pelican-xrootd-*") + var err error + runtimeDir, err = os.MkdirTemp("", "pelican-xrootd-*") if err != nil { return err } - err = setXrootdRunLocations(currentServers, dir) + err = setXrootdRunLocations(currentServers, runtimeDir) if err != nil { return err } - cleanupDirOnShutdown(ctx, dir) + cleanupDirOnShutdown(ctx, runtimeDir) } viper.SetDefault("Cache.DataLocation", filepath.Join(runtimeDir, "xcache")) viper.SetDefault("FileCache.RunLocation", filepath.Join(runtimeDir, "cache")) diff --git a/file_cache/cache_api.go b/file_cache/cache_api.go index 47155f990..f8a3b6628 100644 --- a/file_cache/cache_api.go +++ b/file_cache/cache_api.go @@ -20,21 +20,28 @@ package simple_cache import ( "context" - "errors" "fmt" "io" + "io/fs" "net" "net/http" + "os" + "path/filepath" "strings" "github.com/pelicanplatform/pelican/param" + "github.com/pkg/errors" log "github.com/sirupsen/logrus" "golang.org/x/sync/errgroup" ) // Launch the unix socket listener as a separate goroutine func LaunchListener(ctx context.Context, egrp *errgroup.Group) error { - socketName := param.FileCache_DataLocation.GetString() + socketName := param.FileCache_Socket.GetString() + if err := os.MkdirAll(filepath.Dir(socketName), fs.FileMode(0755)); err != nil { + return errors.Wrap(err, "failed to create socket directory") + } + listener, err := net.ListenUnix("unix", &net.UnixAddr{Name: socketName, Net: "unix"}) if err != nil { return err @@ -95,6 +102,7 @@ func LaunchListener(ctx context.Context, egrp *errgroup.Group) 
error { return srv.Serve(listener) }) egrp.Go(func() error { + <-ctx.Done() return srv.Shutdown(ctx) }) return nil diff --git a/file_cache/cache_authz.go b/file_cache/cache_authz.go index b7ac117b0..fd49d957a 100644 --- a/file_cache/cache_authz.go +++ b/file_cache/cache_authz.go @@ -21,7 +21,6 @@ package simple_cache import ( "context" "net/http" - "net/url" "path" "slices" "sync/atomic" @@ -52,22 +51,7 @@ type ( err error } - prefixConfig struct { - public bool - issuers []issuerConfig - } - - issuerConfig struct { - url *url.URL - basePath string - } - acls []token_scopes.ResourceScope - - aclsItem struct { - acls acls - err error - } ) func newAuthConfig(ctx context.Context, egrp *errgroup.Group) (ac *authConfig) { @@ -131,7 +115,11 @@ func (ac *authConfig) updateConfig(nsAds []common.NamespaceAdV2) error { } func (ac *authConfig) getResourceScopes(token string) (scopes []token_scopes.ResourceScope, issuer string, err error) { - tok, err := jwt.Parse([]byte(token)) + if token == "" { + return + } + + tok, err := jwt.Parse([]byte(token), jwt.WithVerify(false)) if err != nil { err = errors.Wrap(err, "failed to parse incoming JWT when authorizing request") return @@ -141,6 +129,7 @@ func (ac *authConfig) getResourceScopes(token string) (scopes []token_scopes.Res issuers := ac.issuers.Load() if !(*issuers)[issuer] { err = errors.Errorf("token issuer %s is not one of the trusted issuers", issuer) + return } issuerConfItem := ac.issuerKeys.Get(issuer) diff --git a/file_cache/cache_test.go b/file_cache/cache_test.go index 84cb696f0..6ccec8621 100644 --- a/file_cache/cache_test.go +++ b/file_cache/cache_test.go @@ -20,9 +20,11 @@ package simple_cache_test import ( "context" + "fmt" "io" "os" "path/filepath" + "runtime" "testing" "github.com/pelicanplatform/pelican/config" @@ -41,7 +43,11 @@ func spinup(t *testing.T, ctx context.Context, egrp *errgroup.Group) string { modules.Set(config.OriginType) modules.Set(config.DirectorType) modules.Set(config.RegistryType) - 
modules.Set(config.CacheType) + if runtime.GOOS == "darwin" { + viper.Set("Origin.EnableFallbackRead", true) + } else { + modules.Set(config.CacheType) + } modules.Set(config.LocalCacheType) tmpPathPattern := "XRootD-Test_Origin*" @@ -104,6 +110,7 @@ func TestFileCacheSimpleGet(t *testing.T) { ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) defer cancel() + viper.Set("Origin.EnablePublicReads", true) originDir := spinup(t, ctx, egrp) err := os.WriteFile(filepath.Join(originDir, "hello_world.txt"), []byte("Hello, World!"), os.FileMode(0644)) @@ -118,4 +125,13 @@ func TestFileCacheSimpleGet(t *testing.T) { byteBuff, err := io.ReadAll(reader) assert.NoError(t, err) assert.Equal(t, "Hello, World!", string(byteBuff)) + + // Query again -- cache hit case + reader, err = sc.Get("/test/hello_world.txt", "") + require.NoError(t, err) + + assert.Equal(t, "*os.File", fmt.Sprintf("%T", reader)) + byteBuff, err = io.ReadAll(reader) + assert.NoError(t, err) + assert.Equal(t, "Hello, World!", string(byteBuff)) } diff --git a/file_cache/simple_cache.go b/file_cache/simple_cache.go index 300d72012..1e1731628 100644 --- a/file_cache/simple_cache.go +++ b/file_cache/simple_cache.go @@ -22,6 +22,7 @@ import ( "container/heap" "context" "encoding/json" + "fmt" "io" "net/url" "os" @@ -199,6 +200,15 @@ func (lru *lru) Pop() any { return x } +func (ds *downloadStatus) String() string { + errP := ds.err.Load() + if errP == nil { + return fmt.Sprintf("{size=%d,total=%d,done=%v}", ds.curSize.Load(), ds.size.Load(), ds.done.Load()) + } else { + return fmt.Sprintf("{size=%d,total=%d,err=%s,done=%v}", ds.curSize.Load(), ds.size.Load(), *errP, ds.done.Load()) + } +} + // Create a simple cache object // // Launches background goroutines associated with the cache @@ -210,10 +220,10 @@ func NewSimpleCache(ctx context.Context, egrp *errgroup.Group) (sc *SimpleCache, err = errors.New("FileCache.DataLocation is not set; cannot determine where to place file cache's data") return } 
- if err = os.MkdirAll(cacheDir, os.FileMode(0700)); err != nil { + if err = os.RemoveAll(cacheDir); err != nil { return } - if err = os.RemoveAll(cacheDir); err != nil { + if err = os.MkdirAll(cacheDir, os.FileMode(0700)); err != nil { return } @@ -222,7 +232,7 @@ func NewSimpleCache(ctx context.Context, egrp *errgroup.Group) (sc *SimpleCache, if sizeStr == "" || sizeStr == "0" { var stat syscall.Statfs_t if err = syscall.Statfs(cacheDir, &stat); err != nil { - err = errors.Wrap(err, "Unable to determine free space for cache directory") + err = errors.Wrapf(err, "unable to determine free space for cache directory %s", cacheDir) return } cacheSize = stat.Bavail * uint64(stat.Bsize) @@ -237,17 +247,25 @@ func NewSimpleCache(ctx context.Context, egrp *errgroup.Group) (sc *SimpleCache, highWaterPercentage := param.FileCache_HighWaterMarkPercentage.GetInt() lowWaterPercentage := param.FileCache_LowWaterMarkPercentage.GetInt() + directorUrl, err := url.Parse(param.Federation_DirectorUrl.GetString()) + if err != nil { + return + } + sc = &SimpleCache{ - ctx: ctx, - egrp: egrp, - te: client.NewTransferEngine(ctx), - downloads: make(map[string]*activeDownload), - hitChan: make(chan lruEntry, 64), - highWater: (cacheSize / 100) * uint64(highWaterPercentage), - lowWater: (cacheSize / 100) * uint64(lowWaterPercentage), - cacheSize: cacheSize, - basePath: cacheDir, - ac: newAuthConfig(ctx, egrp), + ctx: ctx, + egrp: egrp, + te: client.NewTransferEngine(ctx), + downloads: make(map[string]*activeDownload), + hitChan: make(chan lruEntry, 64), + highWater: (cacheSize / 100) * uint64(highWaterPercentage), + lowWater: (cacheSize / 100) * uint64(lowWaterPercentage), + cacheSize: cacheSize, + basePath: cacheDir, + ac: newAuthConfig(ctx, egrp), + sizeReq: make(chan availSizeReq), + directorURL: directorUrl, + lruLookup: make(map[string]*lruEntry), } sc.tc, err = sc.te.NewClient(client.WithAcquireToken(false), client.WithCallback(sc.callback)) @@ -265,6 +283,7 @@ func NewSimpleCache(ctx 
context.Context, egrp *errgroup.Group) (sc *SimpleCache, egrp.Go(sc.runMux) egrp.Go(sc.periodicUpdateConfig) + log.Debugln("Successfully created a new local cache object") return } @@ -275,7 +294,7 @@ func NewSimpleCache(ctx context.Context, egrp *errgroup.Group) (sc *SimpleCache, func (sc *SimpleCache) callback(path string, downloaded int64, size int64, completed bool) { ds := func() (ds *downloadStatus) { sc.mutex.RLock() - defer sc.mutex.Unlock() + defer sc.mutex.RUnlock() dl := sc.downloads[path] if dl != nil { ds = dl.status @@ -301,6 +320,7 @@ func (sc *SimpleCache) runMux() error { tmpResults := make([]result, 0) cancelRequest := make([]chan bool, 0) activeJobs := make(map[string]*activeDownload) + jobPath := make(map[string]string) ticker := time.NewTicker(100 * time.Millisecond) clientClosed := false for { @@ -308,11 +328,10 @@ func (sc *SimpleCache) runMux() error { lenCancel := len(cancelRequest) lenChan := lenResults + lenCancel cases := make([]reflect.SelectCase, lenResults+6) - jobPath := make(map[uuid.UUID]string) for idx, info := range tmpResults { cases[idx].Dir = reflect.SelectSend - cases[idx].Chan = reflect.ValueOf(tmpResults[idx]) - cases[idx].Send = reflect.ValueOf(&activeJobs[info.path].status) + cases[idx].Chan = reflect.ValueOf(info.channel) + cases[idx].Send = reflect.ValueOf(info.ds) } for idx, channel := range cancelRequest { cases[lenResults+idx].Dir = reflect.SelectSend @@ -358,19 +377,34 @@ func (sc *SimpleCache) runMux() error { clientClosed = true continue } - results := recv.Interface().(*client.TransferResults) - path := jobPath[results.JobId] - delete(jobPath, results.JobId) + results := recv.Interface().(client.TransferResults) + path := jobPath[results.ID()] + if path == "" { + log.Errorf("Transfer results from job %s but no corresponding path known", results.ID()) + continue + } + delete(jobPath, results.ID()) ad := activeJobs[path] + if ad == nil { + log.Errorf("Transfer results from job %s returned for path %s but no active 
job known", results.ID(), path) + continue + } delete(activeJobs, path) - ad.status.err.Store(&results.Error) + if results.Error != nil { + ad.status.err.Store(&results.Error) + } ad.status.curSize.Store(results.TransferredBytes) ad.status.size.Store(results.TransferredBytes) ad.status.done.Store(true) for _, waiter := range ad.waiterList { - tmpResults = append(tmpResults, result{path: path, channel: waiter.notify}) + tmpResults = append(tmpResults, result{ds: ad.status, path: path, channel: waiter.notify}) } if results.Error == nil { + if fp, err := os.OpenFile(filepath.Join(sc.basePath, path)+".DONE", os.O_CREATE|os.O_WRONLY, os.FileMode(0600)); err != nil { + log.Debugln("Unable to save a DONE file for cache path", path) + } else { + fp.Close() + } entry := sc.lruLookup[path] if entry == nil { entry = &lruEntry{} @@ -384,15 +418,31 @@ func (sc *SimpleCache) runMux() error { } } else if chosen == lenChan+2 { // Ticker has fired - update progress + jobsToDelete := make([]string, 0) for path, dl := range activeJobs { + if _, err := dl.tj.GetLookupStatus(); err != nil { + dl.status.err.Store(&err) + for _, waiter := range dl.waiterList { + tmpResults = append(tmpResults, result{path: path, channel: waiter.notify, ds: dl.status}) + } + jobsToDelete = append(jobsToDelete, path) + delete(jobPath, dl.tj.ID()) + continue + } + curSize := dl.status.curSize.Load() for { if dl.waiterList.Len() > 0 && dl.waiterList[0].size <= curSize { waiter := heap.Pop(&dl.waiterList).(waiterInfo) tmpResults = append(tmpResults, result{path: path, channel: waiter.notify, ds: dl.status}) + } else { + break } } } + for _, path := range jobsToDelete { + delete(activeJobs, path) + } } else if chosen == lenChan+3 { // New request req := recv.Interface().(availSizeReq) @@ -446,7 +496,17 @@ func (sc *SimpleCache) runMux() error { size: req.size, notify: req.results, }) + if err := sc.tc.Submit(tj); err != nil { + ds := &downloadStatus{} + ds.err.Store(&err) + tmpResults = append(tmpResults, 
result{ + path: req.request.path, + channel: req.results, + ds: ds, + }) + } activeJobs[req.request.path] = ad + jobPath[tj.ID()] = req.request.path } else if chosen == lenChan+4 { // Cancel a given request. req := recv.Interface().(cancelReq) @@ -636,8 +696,14 @@ func (cr *cacheReader) Read(p []byte) (n int, err error) { sizeReq := availSizeReq{ request: req, size: neededSize, + results: cr.status, + } + select { + case <-cr.sc.ctx.Done(): + err = cr.sc.ctx.Err() + return + case cr.sc.sizeReq <- sizeReq: } - cr.sc.sizeReq <- sizeReq } select { case <-cr.sc.ctx.Done(): @@ -648,6 +714,10 @@ func (cr *cacheReader) Read(p []byte) (n int, err error) { err = errors.New("unable to get response from cache engine") return } + if availSize == nil { + err = errors.New("internal error - cache sent a nil result") + return + } dlErr := availSize.err.Load() if dlErr != nil && *dlErr != nil { err = *dlErr diff --git a/launchers/launcher.go b/launchers/launcher.go index 0cb371385..6e5f37b05 100644 --- a/launchers/launcher.go +++ b/launchers/launcher.go @@ -34,7 +34,7 @@ import ( "github.com/pelicanplatform/pelican/broker" "github.com/pelicanplatform/pelican/config" - "github.com/pelicanplatform/pelican/file_cache" + simple_cache "github.com/pelicanplatform/pelican/file_cache" "github.com/pelicanplatform/pelican/origin_ui" "github.com/pelicanplatform/pelican/param" "github.com/pelicanplatform/pelican/server_ui" @@ -183,14 +183,6 @@ func LaunchModules(ctx context.Context, modules config.ServerType) (context.Canc } } - if modules.IsEnabled(config.LocalCacheType) { - log.Debugln("Starting local cache listener") - if err := simple_cache.LaunchListener(ctx, egrp); err != nil { - log.Errorln("Failure when starting the local cache listener:", err) - return shutdownCancel, err - } - } - log.Info("Starting web engine...") lnReference = nil egrp.Go(func() error { @@ -257,6 +249,14 @@ func LaunchModules(ctx context.Context, modules config.ServerType) (context.Canc } } + if 
modules.IsEnabled(config.LocalCacheType) { + log.Debugln("Starting local cache listener") + if err := simple_cache.LaunchListener(ctx, egrp); err != nil { + log.Errorln("Failure when starting the local cache listener:", err) + return shutdownCancel, err + } + } + if param.Server_EnableUI.GetBool() { if err = web_ui.ConfigureEmbeddedPrometheus(ctx, engine); err != nil { return shutdownCancel, errors.Wrap(err, "Failed to configure embedded prometheus instance") diff --git a/server_ui/advertise.go b/server_ui/advertise.go index 74289f542..568efd1b6 100644 --- a/server_ui/advertise.go +++ b/server_ui/advertise.go @@ -47,17 +47,22 @@ type directorResponse struct { ApprovalError bool `json:"approval_error"` } +func doAdvertise(ctx context.Context, servers []server_utils.XRootDServer) { + log.Debugf("About to advertise %d XRootD servers", len(servers)) + err := Advertise(ctx, servers) + if err != nil { + log.Warningln("XRootD server advertise failed:", err) + metrics.SetComponentHealthStatus(metrics.OriginCache_Federation, metrics.StatusCritical, fmt.Sprintf("XRootD server advertise failed: %v", err)) + } else { + metrics.SetComponentHealthStatus(metrics.OriginCache_Federation, metrics.StatusOK, "") + } +} + func LaunchPeriodicAdvertise(ctx context.Context, egrp *errgroup.Group, servers []server_utils.XRootDServer) error { + doAdvertise(ctx, servers) + ticker := time.NewTicker(1 * time.Minute) egrp.Go(func() error { - log.Debugf("About to advertise %d XRootD servers", len(servers)) - err := Advertise(ctx, servers) - if err != nil { - log.Warningln("XRootD server advertise failed:", err) - metrics.SetComponentHealthStatus(metrics.OriginCache_Federation, metrics.StatusCritical, fmt.Sprintf("XRootD server advertise failed: %v", err)) - } else { - metrics.SetComponentHealthStatus(metrics.OriginCache_Federation, metrics.StatusOK, "") - } for { select { @@ -73,6 +78,8 @@ func LaunchPeriodicAdvertise(ctx context.Context, egrp *errgroup.Group, servers log.Infoln("Periodic 
advertisement loop has been terminated") return nil } + + doAdvertise(ctx, servers) } }) From 377f8798b3887dc9d5cc50ad986dc16312c3582d Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sun, 3 Mar 2024 11:34:23 -0600 Subject: [PATCH 11/45] Simplify use of `AddScopes` AddScopes now takes a variadic argument instead of a slice; further, added some generics so the same functions can be used for normal or resource-enabled scopes. --- broker/token_utils.go | 4 ++-- client/fed_test.go | 27 +++++-------------------- director/origin_api_test.go | 2 +- director/origin_monitor.go | 2 +- registry/registry_ui_test.go | 2 +- server_ui/advertise.go | 2 +- token_scopes/token_scope_utils.go | 19 +++++++++--------- utils/token_utils.go | 33 ++++++++++++++++--------------- utils/token_utils_test.go | 10 ++++------ web_ui/authentication.go | 3 +-- web_ui/prometheus.go | 4 ++-- 11 files changed, 45 insertions(+), 63 deletions(-) diff --git a/broker/token_utils.go b/broker/token_utils.go index 52e20bc12..73c4813e0 100644 --- a/broker/token_utils.go +++ b/broker/token_utils.go @@ -131,7 +131,7 @@ func createToken(namespace, subject, audience string, desiredScope token_scopes. 
Version: "1.0", Subject: subject, } - tokenCfg.AddScopes([]token_scopes.TokenScope{desiredScope}) + tokenCfg.AddScopes(desiredScope) token, err = tokenCfg.CreateToken() return @@ -170,7 +170,7 @@ func verifyToken(ctx context.Context, token, namespace, audience string, require scopeValidator := token_scopes.CreateScopeValidator([]token_scopes.TokenScope{requiredScope}, false) err = jwt.Validate(tok, - jwt.WithAudience(param.Server_ExternalWebUrl.GetString()), + jwt.WithAudience(audience), jwt.WithValidator(scopeValidator), jwt.WithClaimValue("iss", issuerUrl), ) diff --git a/client/fed_test.go b/client/fed_test.go index cdb152c06..ac5a65e98 100644 --- a/client/fed_test.go +++ b/client/fed_test.go @@ -58,18 +58,6 @@ func generateFileTestScitoken() (string, error) { return "", errors.New("Failed to create token: Invalid iss, Server_ExternalWebUrl is empty") } - scopes := []token_scopes.TokenScope{} - readScope, err := token_scopes.Storage_Read.Path("/") - if err != nil { - return "", errors.Wrap(err, "failed to create 'read' scope for file test token:") - } - scopes = append(scopes, readScope) - modScope, err := token_scopes.Storage_Modify.Path("/") - if err != nil { - return "", errors.Wrap(err, "failed to create 'modify' scope for file test token:") - } - scopes = append(scopes, modScope) - fTestTokenCfg := utils.TokenConfig{ TokenProfile: utils.WLCG, Lifetime: time.Minute, @@ -78,7 +66,8 @@ func generateFileTestScitoken() (string, error) { Version: "1.0", Subject: "origin", } - fTestTokenCfg.AddScopes(scopes) + fTestTokenCfg.AddResourceScopes(token_scopes.NewResourceScope(token_scopes.Storage_Read, "/"), + token_scopes.NewResourceScope(token_scopes.Storage_Modify, "/")) // CreateToken also handles validation for us tok, err := fTestTokenCfg.CreateToken() @@ -369,7 +358,7 @@ func TestGetAndPutAuth(t *testing.T) { modScope, err := token_scopes.Storage_Modify.Path("/") assert.NoError(t, err) scopes = append(scopes, modScope) - tokenConfig.AddScopes(scopes) + 
tokenConfig.AddScopes(scopes...) token, err := tokenConfig.CreateToken() assert.NoError(t, err) tempToken, err := os.CreateTemp(t.TempDir(), "token") @@ -532,14 +521,8 @@ func TestRecursiveUploadsAndDownloads(t *testing.T) { Audience: []string{audience}, Subject: "origin", } - scopes := []token_scopes.TokenScope{} - readScope, err := token_scopes.Storage_Read.Path("/") - assert.NoError(t, err) - scopes = append(scopes, readScope) - modScope, err := token_scopes.Storage_Modify.Path("/") - assert.NoError(t, err) - scopes = append(scopes, modScope) - tokenConfig.AddScopes(scopes) + tokenConfig.AddResourceScopes(token_scopes.NewResourceScope(token_scopes.Storage_Read, "/"), + token_scopes.NewResourceScope(token_scopes.Storage_Modify, "/")) token, err := tokenConfig.CreateToken() assert.NoError(t, err) tempToken, err := os.CreateTemp(t.TempDir(), "token") diff --git a/director/origin_api_test.go b/director/origin_api_test.go index e35b663c7..9f4fb457d 100644 --- a/director/origin_api_test.go +++ b/director/origin_api_test.go @@ -137,7 +137,7 @@ func TestVerifyAdvertiseToken(t *testing.T) { Audience: []string{"https://director-url.org"}, Subject: "origin", } - advTokenCfg.AddScopes([]token_scopes.TokenScope{token_scopes.Pelican_Advertise}) + advTokenCfg.AddScopes(token_scopes.Pelican_Advertise) // CreateToken also handles validation for us tok, err := advTokenCfg.CreateToken() diff --git a/director/origin_monitor.go b/director/origin_monitor.go index 33a94f547..c0c8cafa2 100644 --- a/director/origin_monitor.go +++ b/director/origin_monitor.go @@ -63,7 +63,7 @@ func reportStatusToOrigin(ctx context.Context, originWebUrl string, status strin Audience: []string{originWebUrl}, Subject: "director", } - testTokenCfg.AddScopes([]token_scopes.TokenScope{token_scopes.Pelican_DirectorTestReport}) + testTokenCfg.AddScopes(token_scopes.Pelican_DirectorTestReport) tok, err := testTokenCfg.CreateToken() if err != nil { diff --git a/registry/registry_ui_test.go 
b/registry/registry_ui_test.go index 501e315e4..f97944416 100644 --- a/registry/registry_ui_test.go +++ b/registry/registry_ui_test.go @@ -240,7 +240,7 @@ func TestListNamespaces(t *testing.T) { req, _ := http.NewRequest("GET", requestURL, nil) if tc.authUser { tokenCfg := utils.TokenConfig{Issuer: "https://mock-server.com", Lifetime: time.Minute, Subject: "admin", TokenProfile: utils.None} - tokenCfg.AddScopes([]token_scopes.TokenScope{token_scopes.WebUi_Access}) + tokenCfg.AddScopes(token_scopes.WebUi_Access) token, err := tokenCfg.CreateToken() require.NoError(t, err) req.AddCookie(&http.Cookie{Name: "login", Value: token, Path: "/"}) diff --git a/server_ui/advertise.go b/server_ui/advertise.go index 568efd1b6..3ca7d9812 100644 --- a/server_ui/advertise.go +++ b/server_ui/advertise.go @@ -150,7 +150,7 @@ func advertiseInternal(ctx context.Context, server server_utils.XRootDServer) er Audience: []string{param.Federation_DirectorUrl.GetString()}, Subject: "origin", } - advTokenCfg.AddScopes([]token_scopes.TokenScope{token_scopes.Pelican_Advertise}) + advTokenCfg.AddScopes(token_scopes.Pelican_Advertise) // CreateToken also handles validation for us tok, err := advTokenCfg.CreateToken() diff --git a/token_scopes/token_scope_utils.go b/token_scopes/token_scope_utils.go index 32452ed1a..35bd35126 100644 --- a/token_scopes/token_scope_utils.go +++ b/token_scopes/token_scope_utils.go @@ -39,6 +39,12 @@ type ( Authorization TokenScope Resource string } + + Scope interface { + TokenScope | ResourceScope + + String() string + } ) func NewResourceScope(authz TokenScope, resource string) ResourceScope { @@ -73,19 +79,14 @@ func (rc ResourceScope) Contains(other ResourceScope) bool { // Get a string representation of a list of scopes, which can then be passed // to the Claim builder of JWT constructor -func GetScopeString(scopes []TokenScope) (scopeString string) { - scopeString = "" +func GetScopeString[Scopes ~[]Sc, Sc Scope](scopes Scopes) (scopeString string) { if 
len(scopes) == 0 { return } - if len(scopes) == 1 { - scopeString = string(scopes[0]) - return - } - for _, scope := range scopes { - scopeString += scope.String() + " " + scopeString = scopes[0].String() + for _, scope := range scopes[1:] { + scopeString += " " + scope.String() } - scopeString = strings.TrimRight(scopeString, " ") return } diff --git a/utils/token_utils.go b/utils/token_utils.go index cb42dad1d..6051aab47 100644 --- a/utils/token_utils.go +++ b/utils/token_utils.go @@ -35,7 +35,6 @@ import ( "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwt" "github.com/pkg/errors" - "github.com/spf13/viper" "github.com/pelicanplatform/pelican/config" "github.com/pelicanplatform/pelican/token_scopes" @@ -64,6 +63,9 @@ const ( WLCG TokenProfile = "wlcg" Scitokens2 TokenProfile = "scitokens2" None TokenProfile = "none" + + WLCGAny string = "https://wlcg.cern.ch/jwt/v1/any" + ScitokensAny string = "ANY" ) func (p TokenProfile) String() string { @@ -150,16 +152,14 @@ func (config *TokenConfig) verifyCreateWLCG() error { return nil } -// AddScopes appends a list of token_scopes.TokenScope to the Scope field. -func (config *TokenConfig) AddScopes(scopes []token_scopes.TokenScope) { - if config.scope == "" { - config.scope = token_scopes.GetScopeString(scopes) - } else { - scopeStr := token_scopes.GetScopeString(scopes) - if scopeStr != "" { - config.scope += " " + scopeStr - } - } +// AddScopes appends multiple token_scopes.TokenScope to the Scope field. +func (config *TokenConfig) AddScopes(scopes ...token_scopes.TokenScope) { + config.AddRawScope(token_scopes.GetScopeString(scopes)) +} + +// AddResourceScopes appends multiple token_scopes.TokenScope to the Scope field. +func (config *TokenConfig) AddResourceScopes(scopes ...token_scopes.ResourceScope) { + config.AddRawScope(token_scopes.GetScopeString(scopes)) } // AddRawScope appends a space-delimited, case-sensitive scope string to the Scope field. 
@@ -170,10 +170,8 @@ func (config *TokenConfig) AddScopes(scopes []token_scopes.TokenScope) { func (config *TokenConfig) AddRawScope(scope string) { if config.scope == "" { config.scope = scope - } else { - if scope != "" { - config.scope += " " + scope - } + } else if scope != "" { + config.scope += " " + scope } } @@ -217,7 +215,10 @@ func (tokenConfig *TokenConfig) CreateTokenWithKey(key jwk.Key) (string, error) } issuerUrl = url.String() } else { - issuerUrlStr := viper.GetString("IssuerUrl") + issuerUrlStr, err := config.GetServerIssuerURL() + if err != nil { + return "", errors.Wrap(err, "unable to generate token issuer URL") + } url, err := url.Parse(issuerUrlStr) if err != nil { return "", errors.Wrap(err, "Failed to parse the configured IssuerUrl") diff --git a/utils/token_utils_test.go b/utils/token_utils_test.go index 09a697ffe..6bffdd9e8 100644 --- a/utils/token_utils_test.go +++ b/utils/token_utils_test.go @@ -136,7 +136,7 @@ func TestAddScopes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { config := &TokenConfig{scope: tt.initialScope} - config.AddScopes(tt.additionalScopes) + config.AddScopes(tt.additionalScopes...) assert.Equal(t, tt.expectedScope, config.GetScope(), fmt.Sprintf("AddScopes() = %v, want %v", config.scope, tt.expectedScope)) }) } @@ -258,11 +258,9 @@ func TestCreateToken(t *testing.T) { _, err = tokenConfig.CreateToken() assert.NoError(t, err) - // Test without configured issuer - tokenConfig = TokenConfig{TokenProfile: WLCG, Audience: []string{"foo"}, Subject: "bar", Lifetime: time.Minute * 10} - _, err = tokenConfig.CreateToken() - assert.EqualError(t, err, "No issuer was found in the configuration file, "+ - "and none was provided as a claim") + // Note: we used to test what occurred when no issuer was set (assuming it should fail). 
However, we switched to a new + // helper function in the `config` module which falls back to an auto-constructed IssuerUrl, meaning the + // test condition was no longer valid; the test was deleted. } func TestLookupIssuerJwksUrl(t *testing.T) { diff --git a/web_ui/authentication.go b/web_ui/authentication.go index 6fed0a2d9..bc2cff8aa 100644 --- a/web_ui/authentication.go +++ b/web_ui/authentication.go @@ -153,7 +153,6 @@ func GetUser(ctx *gin.Context) (string, error) { // Create a JWT and set the "login" cookie to store that JWT func setLoginCookie(ctx *gin.Context, user string) { - scopes := []token_scopes.TokenScope{token_scopes.WebUi_Access, token_scopes.Monitoring_Query, token_scopes.Monitoring_Scrape} loginCookieTokenCfg := utils.TokenConfig{ TokenProfile: utils.WLCG, Lifetime: 30 * time.Minute, @@ -162,7 +161,7 @@ func setLoginCookie(ctx *gin.Context, user string) { Version: "1.0", Subject: user, } - loginCookieTokenCfg.AddScopes(scopes) + loginCookieTokenCfg.AddScopes(token_scopes.WebUi_Access, token_scopes.Monitoring_Query, token_scopes.Monitoring_Scrape) // CreateToken also handles validation for us tok, err := loginCookieTokenCfg.CreateToken() diff --git a/web_ui/prometheus.go b/web_ui/prometheus.go index 60f64f031..08c1a1de0 100644 --- a/web_ui/prometheus.go +++ b/web_ui/prometheus.go @@ -159,7 +159,7 @@ func configDirectorPromScraper(ctx context.Context) (*config.ScrapeConfig, error Version: "1.0", Subject: "director", } - promTokenCfg.AddScopes([]token_scopes.TokenScope{token_scopes.Pelican_DirectorServiceDiscovery}) + promTokenCfg.AddScopes(token_scopes.Pelican_DirectorServiceDiscovery) // CreateToken also handles validation for us sdToken, err := promTokenCfg.CreateToken() @@ -175,7 +175,7 @@ func configDirectorPromScraper(ctx context.Context) (*config.ScrapeConfig, error Audience: []string{"prometheus"}, Subject: "director", } - scrapeTokenCfg.AddScopes([]token_scopes.TokenScope{token_scopes.Monitoring_Scrape}) + 
scrapeTokenCfg.AddScopes(token_scopes.Monitoring_Scrape) scraperToken, err := scrapeTokenCfg.CreateToken() if err != nil { From 33a5b8079e86776f5be3c5b90b5be57f14b97928 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sun, 3 Mar 2024 11:56:37 -0600 Subject: [PATCH 12/45] Update jwx library --- go.mod | 8 ++++---- go.sum | 32 ++++++++------------------------ 2 files changed, 12 insertions(+), 28 deletions(-) diff --git a/go.mod b/go.mod index 702f5b2f9..953dda5b4 100644 --- a/go.mod +++ b/go.mod @@ -20,7 +20,7 @@ require ( github.com/jellydator/ttlcache/v3 v3.1.0 github.com/jsipprell/keyctl v1.0.4-0.20211208153515-36ca02672b6c github.com/lestrrat-go/httprc v1.0.4 - github.com/lestrrat-go/jwx/v2 v2.0.16 + github.com/lestrrat-go/jwx/v2 v2.0.20 github.com/minio/minio-go/v7 v7.0.65 github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f github.com/oklog/run v1.1.0 @@ -44,10 +44,10 @@ require ( github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a github.com/zsais/go-gin-prometheus v0.1.0 go.uber.org/atomic v1.11.0 - golang.org/x/crypto v0.17.0 + golang.org/x/crypto v0.19.0 golang.org/x/net v0.19.0 golang.org/x/oauth2 v0.15.0 - golang.org/x/term v0.15.0 + golang.org/x/term v0.17.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/gorm v1.25.7 kernel.org/pub/linux/libs/security/libcap/cap v1.2.69 @@ -75,7 +75,7 @@ require ( go.opentelemetry.io/collector/pdata v1.0.0-rcv0016 // indirect go.opentelemetry.io/collector/semconv v0.87.0 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/sys v0.15.0 // indirect + golang.org/x/sys v0.17.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20231120223509-83a465c0220f // indirect google.golang.org/grpc v1.59.0 // indirect modernc.org/sqlite v1.28.0 // indirect diff --git a/go.sum b/go.sum index edb20675d..126c3100c 100644 --- a/go.sum +++ b/go.sum @@ -122,7 +122,6 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod 
h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/decred/dcrd/crypto/blake256 v1.0.1/go.mod h1:2OfgNZ5wDpcsFmHmCK5gZTPcCXqlm2ArzUIkw9czNJo= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 h1:8UrgZ3GkP4i/CLijOJx79Yu+etlyjdBU4sfcs2WYQMs= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0= github.com/dennwc/varint v1.0.0 h1:kGNFFSSw8ToIy3obO/kKr8U9GZYUAxQEVuix4zfDWzE= @@ -506,9 +505,8 @@ github.com/lestrrat-go/httprc v1.0.4 h1:bAZymwoZQb+Oq8MEbyipag7iSq6YIga8Wj6GOiJG github.com/lestrrat-go/httprc v1.0.4/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo= github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI= github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4= -github.com/lestrrat-go/jwx/v2 v2.0.16 h1:TuH3dBkYTy2giQg/9D8f20znS3JtMRuQJ372boS3lWk= -github.com/lestrrat-go/jwx/v2 v2.0.16/go.mod h1:jBHyESp4e7QxfERM0UKkQ80/94paqNIEcdEfiUYz5zE= -github.com/lestrrat-go/option v1.0.0/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= +github.com/lestrrat-go/jwx/v2 v2.0.20 h1:sAgXuWS/t8ykxS9Bi2Qtn5Qhpakw1wrcjxChudjolCc= +github.com/lestrrat-go/jwx/v2 v2.0.20/go.mod h1:UlCSmKqw+agm5BsOBfEAbTvKsEApaGNqHAEUTv5PJC4= github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU= github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= github.com/libsql/sqlite-antlr4-parser v0.0.0-20230802215326-5cb5bb604475 h1:6PfEMwfInASh9hkN83aR0j4W/eKaAZt/AURtXAXlas0= @@ -798,9 +796,8 @@ golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= 
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= -golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= -golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= +golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -834,7 +831,6 @@ golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzB golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0= golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -873,8 +869,6 @@ golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net 
v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= -golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -898,7 +892,6 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -943,25 +936,19 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= -golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= -golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= -golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4= -golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= +golang.org/x/term v0.17.0 h1:mkTF7LCd6WGJNL3K1Ad7kwxNfYAW6a8a8QqtMblp/4U= +golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -973,8 +960,6 @@ 
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -1029,7 +1014,6 @@ golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.15.0 h1:zdAyfUGbYmuVokhzVmghFl2ZJh5QhcfebBgmVPFYA+8= golang.org/x/tools v0.15.0/go.mod h1:hpksKq4dtpQWS1uQ61JkdqWM3LscIS6Slf+VVkm+wQk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= From 01a4ddad373d7af4518d0180c5e31e0246c3c2ed Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sun, 3 Mar 2024 11:59:36 -0600 Subject: [PATCH 13/45] Add tests for authenticated LocalCache and access via HTTP Includes corresponding bugfixes. 
--- file_cache/cache_api.go | 6 ++- file_cache/cache_authz.go | 22 ++++++-- file_cache/cache_test.go | 108 +++++++++++++++++++++++++++++++++++--- 3 files changed, 123 insertions(+), 13 deletions(-) diff --git a/file_cache/cache_api.go b/file_cache/cache_api.go index f8a3b6628..4b8cb73dc 100644 --- a/file_cache/cache_api.go +++ b/file_cache/cache_api.go @@ -26,6 +26,7 @@ import ( "net" "net/http" "os" + "path" "path/filepath" "strings" @@ -69,10 +70,11 @@ func LaunchListener(ctx context.Context, egrp *errgroup.Group) error { authzHeader := r.Header.Get("Authorization") bearerToken := "" - if strings.HasPrefix(bearerToken, "Bearer ") { + if strings.HasPrefix(authzHeader, "Bearer ") { bearerToken = authzHeader[7:] // len("Bearer ") == 7 } - reader, err := sc.Get(r.URL.Path, bearerToken) + path := path.Clean(r.URL.Path) + reader, err := sc.Get(path, bearerToken) if errors.Is(err, authorizationDenied) { w.WriteHeader(http.StatusForbidden) if _, err = w.Write([]byte("Authorization Denied")); err != nil { diff --git a/file_cache/cache_authz.go b/file_cache/cache_authz.go index fd49d957a..cbdaa932a 100644 --- a/file_cache/cache_authz.go +++ b/file_cache/cache_authz.go @@ -64,7 +64,7 @@ func newAuthConfig(ctx context.Context, egrp *errgroup.Group) (ac *authConfig) { if err != nil { log.Errorln("Failed to lookup JWKS URL:", err) } else { - ar := jwk.NewCache(ctx) + ar = jwk.NewCache(ctx) client := &http.Client{Transport: config.GetTransport()} if err = ar.Register(jwksUrl.String(), jwk.WithMinRefreshInterval(15*time.Minute), jwk.WithHTTPClient(client)); err != nil { log.Errorln("Failed to register JWKS URL with cache: ", err) @@ -74,10 +74,13 @@ func newAuthConfig(ctx context.Context, egrp *errgroup.Group) (ac *authConfig) { } ttl := ttlcache.DefaultTTL - if err != nil { + var item *ttlcache.Item[string, authConfigItem] + if ar != nil { + item = cache.Set(issuerUrl, authConfigItem{set: jwk.NewCachedSet(ar, jwksUrl.String()), err: nil}, ttl) + } else { ttl = time.Duration(5 * 
time.Minute) + item = cache.Set(issuerUrl, authConfigItem{set: nil, err: err}, ttl) } - item := cache.Set(issuerUrl, authConfigItem{set: jwk.NewCachedSet(ar, jwksUrl.String()), err: nil}, ttl) return item }, ) @@ -144,12 +147,23 @@ func (ac *authConfig) getResourceScopes(token string) (scopes []token_scopes.Res return } - tok, err = jwt.Parse([]byte(token), jwt.WithKeySet(issuerConfItem.Value().set)) + item := issuerConfItem.Value() + if item.set == nil { + if item.err == nil { + err = item.err + } else { + err = errors.Errorf("failed to fetch public key set") + } + return + } + tok, err = jwt.Parse([]byte(token), jwt.WithKeySet(item.set)) if err != nil { return } + err = jwt.Validate(tok) if err != nil { + err = errors.Wrap(err, "unable to get resource scopes because validation failed") return } diff --git a/file_cache/cache_test.go b/file_cache/cache_test.go index 6ccec8621..91c87a95e 100644 --- a/file_cache/cache_test.go +++ b/file_cache/cache_test.go @@ -22,22 +22,28 @@ import ( "context" "fmt" "io" + "net" + "net/http" "os" "path/filepath" "runtime" "testing" + "time" "github.com/pelicanplatform/pelican/config" simple_cache "github.com/pelicanplatform/pelican/file_cache" "github.com/pelicanplatform/pelican/launchers" + "github.com/pelicanplatform/pelican/param" "github.com/pelicanplatform/pelican/test_utils" + "github.com/pelicanplatform/pelican/token_scopes" + "github.com/pelicanplatform/pelican/utils" "github.com/spf13/viper" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" ) -func spinup(t *testing.T, ctx context.Context, egrp *errgroup.Group) string { +func spinup(t *testing.T, ctx context.Context, egrp *errgroup.Group) { modules := config.ServerType(0) modules.Set(config.OriginType) @@ -103,18 +109,20 @@ func spinup(t *testing.T, ctx context.Context, egrp *errgroup.Group) string { egrp.Wait() }) - return originDir + err = os.WriteFile(filepath.Join(originDir, "hello_world.txt"), []byte("Hello, 
World!"), os.FileMode(0644)) + require.NoError(t, err) } -func TestFileCacheSimpleGet(t *testing.T) { +// Setup a federation, invoke "get" through the local cache module +// +// The download is done twice -- once to verify functionality and once +// as a cache hit. +func TestFedPublicGet(t *testing.T) { ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) defer cancel() viper.Set("Origin.EnablePublicReads", true) - originDir := spinup(t, ctx, egrp) - - err := os.WriteFile(filepath.Join(originDir, "hello_world.txt"), []byte("Hello, World!"), os.FileMode(0644)) - require.NoError(t, err) + spinup(t, ctx, egrp) sc, err := simple_cache.NewSimpleCache(ctx, egrp) require.NoError(t, err) @@ -135,3 +143,89 @@ func TestFileCacheSimpleGet(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "Hello, World!", string(byteBuff)) } + +func TestFedAuthGet(t *testing.T) { + ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) + defer cancel() + + viper.Set("Origin.EnablePublicReads", false) + spinup(t, ctx, egrp) + + lc, err := simple_cache.NewSimpleCache(ctx, egrp) + require.NoError(t, err) + + issuer, err := config.GetServerIssuerURL() + require.NoError(t, err) + tokConf := utils.TokenConfig{ + TokenProfile: utils.WLCG, + Lifetime: time.Duration(time.Minute), + Issuer: issuer, + Subject: "test", + Audience: []string{utils.WLCGAny}, + } + tokConf.AddResourceScopes(token_scopes.NewResourceScope(token_scopes.Storage_Read, "/hello_world.txt")) + + tok, err := tokConf.CreateToken() + require.NoError(t, err) + + reader, err := lc.Get("/test/hello_world.txt", tok) + require.NoError(t, err) + + byteBuff, err := io.ReadAll(reader) + assert.NoError(t, err) + assert.Equal(t, "Hello, World!", string(byteBuff)) + + tokConf = utils.TokenConfig{ + TokenProfile: utils.WLCG, + Lifetime: time.Duration(time.Minute), + Issuer: issuer, + Subject: "test", + Audience: []string{utils.WLCGAny}, + } + 
tokConf.AddResourceScopes(token_scopes.NewResourceScope(token_scopes.Storage_Read, "/not_correct")) + + tok, err = tokConf.CreateToken() + require.NoError(t, err) + + _, err = lc.Get("/test/hello_world.txt", tok) + assert.Error(t, err) + assert.Equal(t, "authorization denied", err.Error()) +} + +func TestHttpReq(t *testing.T) { + ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) + defer cancel() + + viper.Set("Origin.EnablePublicReads", false) + spinup(t, ctx, egrp) + + issuer, err := config.GetServerIssuerURL() + require.NoError(t, err) + tokConf := utils.TokenConfig{ + TokenProfile: utils.WLCG, + Lifetime: time.Duration(time.Minute), + Issuer: issuer, + Subject: "test", + Audience: []string{utils.WLCGAny}, + } + tokConf.AddResourceScopes(token_scopes.NewResourceScope(token_scopes.Storage_Read, "/hello_world.txt")) + tok, err := tokConf.CreateToken() + require.NoError(t, err) + + transport := config.GetTransport().Clone() + transport.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", param.FileCache_Socket.GetString()) + } + + client := &http.Client{Transport: transport} + req, err := http.NewRequest("GET", "http://localhost/test/hello_world.txt", nil) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer "+tok) + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, "Hello, World!", string(body)) +} From 83ffc21f408c95ef5c4ab626211d7df218f15623 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sun, 3 Mar 2024 12:02:10 -0600 Subject: [PATCH 14/45] Fix linter issue in LocalCache test --- file_cache/cache_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/file_cache/cache_test.go b/file_cache/cache_test.go index 91c87a95e..da39f6241 100644 --- a/file_cache/cache_test.go +++ b/file_cache/cache_test.go @@ 
-106,7 +106,9 @@ func spinup(t *testing.T, ctx context.Context, egrp *errgroup.Group) { require.NoError(t, err) t.Cleanup(func() { cancel() - egrp.Wait() + if err = egrp.Wait(); err != nil && err != context.Canceled { + require.NoError(t, err) + } }) err = os.WriteFile(filepath.Join(originDir, "hello_world.txt"), []byte("Hello, World!"), os.FileMode(0644)) From 60c707133d7ab1d4268fc4b7315d766fb8de0ccd Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sun, 3 Mar 2024 15:12:34 -0600 Subject: [PATCH 15/45] Add unit test for invoking the Pelican client to download Includes corresponding bugfixes. The client can now work with the cache! --- client/director.go | 22 +++++++--- client/director_test.go | 8 ++-- client/handle_http.go | 74 +++++++++++++++++++++----------- client/main.go | 6 +-- director/redirect.go | 5 +-- file_cache/cache_test.go | 92 ++++++++++++++++++++++++---------------- 6 files changed, 129 insertions(+), 78 deletions(-) diff --git a/client/director.go b/client/director.go index 9069e4067..9068d10be 100644 --- a/client/director.go +++ b/client/director.go @@ -235,7 +235,7 @@ func GetCachesFromDirectorResponse(resp *http.Response, needsToken bool) (caches } // NewTransferDetails creates the TransferDetails struct with the given cache -func NewTransferDetailsUsingDirector(cache namespaces.DirectorCache, opts transferDetailsOptions) []transferAttemptDetails { +func newTransferDetailsUsingDirector(cache namespaces.DirectorCache, opts transferDetailsOptions) []transferAttemptDetails { details := make([]transferAttemptDetails, 0) cacheEndpoint := cache.EndpointUrl @@ -247,7 +247,7 @@ func NewTransferDetailsUsingDirector(cache namespaces.DirectorCache, opts transf } if cacheURL.Scheme == "unix" && cacheURL.Host != "" { cacheURL.Path = path.Clean("/" + path.Join(cacheURL.Host, cacheURL.Path)) - } else if cacheURL.Host == "" { + } else if cacheURL.Scheme != "unix" && cacheURL.Host == "" { // Assume the cache is just a hostname cacheURL.Host = cacheEndpoint 
cacheURL.Path = "" @@ -272,12 +272,16 @@ func NewTransferDetailsUsingDirector(cache namespaces.DirectorCache, opts transf cacheURL.Host = cacheURL.Host + ":8443" } } - // Whether port is specified or not, add a transfer without proxy - details = append(details, transferAttemptDetails{ + det := transferAttemptDetails{ Url: cacheURL, Proxy: false, PackOption: opts.PackOption, - }) + } + if cacheURL.Scheme == "unix" { + det.UnixSocket = cacheURL.Path + } + // Whether port is specified or not, add a transfer without proxy + details = append(details, det) } else if cacheURL.Scheme == "" || cacheURL.Scheme == "http" { // Assume a transfer not needing a token and not specifying a scheme is HTTP // WARNING: This is legacy code; we should always specify a scheme @@ -301,11 +305,15 @@ func NewTransferDetailsUsingDirector(cache namespaces.DirectorCache, opts transf } else { // A non-HTTP scheme is specified and a token is not needed; this wasn't possible // in the legacy cases. Simply use the provided config - details = append(details, transferAttemptDetails{ + det := transferAttemptDetails{ Url: cacheURL, Proxy: false, PackOption: opts.PackOption, - }) + } + if cacheURL.Scheme == "unix" { + det.UnixSocket = cacheURL.Path + } + details = append(details, det) } return details diff --git a/client/director_test.go b/client/director_test.go index db0e3406a..b8f8d08c9 100644 --- a/client/director_test.go +++ b/client/director_test.go @@ -155,7 +155,7 @@ func TestNewTransferDetailsUsingDirector(t *testing.T) { // Case 1: cache with http - transfers := NewTransferDetailsUsingDirector(nonAuthCache, transferDetailsOptions{nonAuthCache.AuthedReq, ""}) + transfers := newTransferDetailsUsingDirector(nonAuthCache, transferDetailsOptions{nonAuthCache.AuthedReq, ""}) assert.Equal(t, 2, len(transfers)) assert.Equal(t, "my-cache-url:8000", transfers[0].Url.Host) assert.Equal(t, "http", transfers[0].Url.Scheme) @@ -166,7 +166,7 @@ func TestNewTransferDetailsUsingDirector(t *testing.T) { 
assert.Equal(t, false, transfers[1].Proxy) // Case 2: cache with https - transfers = NewTransferDetailsUsingDirector(authCache, transferDetailsOptions{authCache.AuthedReq, ""}) + transfers = newTransferDetailsUsingDirector(authCache, transferDetailsOptions{authCache.AuthedReq, ""}) assert.Equal(t, 1, len(transfers)) assert.Equal(t, "my-cache-url:8443", transfers[0].Url.Host) assert.Equal(t, "https", transfers[0].Url.Scheme) @@ -174,7 +174,7 @@ func TestNewTransferDetailsUsingDirector(t *testing.T) { // Case 3: cache without port with http nonAuthCache.EndpointUrl = "my-cache-url" - transfers = NewTransferDetailsUsingDirector(nonAuthCache, transferDetailsOptions{nonAuthCache.AuthedReq, ""}) + transfers = newTransferDetailsUsingDirector(nonAuthCache, transferDetailsOptions{nonAuthCache.AuthedReq, ""}) assert.Equal(t, 2, len(transfers)) assert.Equal(t, "my-cache-url:8000", transfers[0].Url.Host) assert.Equal(t, "http", transfers[0].Url.Scheme) @@ -185,7 +185,7 @@ func TestNewTransferDetailsUsingDirector(t *testing.T) { // Case 4. 
cache without port with https authCache.EndpointUrl = "my-cache-url" - transfers = NewTransferDetailsUsingDirector(authCache, transferDetailsOptions{authCache.AuthedReq, ""}) + transfers = newTransferDetailsUsingDirector(authCache, transferDetailsOptions{authCache.AuthedReq, ""}) assert.Equal(t, 2, len(transfers)) assert.Equal(t, "my-cache-url:8444", transfers[0].Url.Host) assert.Equal(t, "https", transfers[0].Url.Scheme) diff --git a/client/handle_http.go b/client/handle_http.go index 1458687fa..daeeac8f0 100644 --- a/client/handle_http.go +++ b/client/handle_http.go @@ -113,6 +113,9 @@ type ( // Proxy specifies if a proxy should be used Proxy bool + // If the Url scheme is unix, this is the path to connect to + UnixSocket string + // Specifies the pack option in the transfer URL PackOption string } @@ -748,6 +751,7 @@ func (tc *TransferClient) NewTransferJob(remoteUrl *url.URL, localPath string, u tokenLocation: tc.tokenLocation, upload: upload, uuid: id, + token: tc.token, } tj.ctx, tj.cancel = context.WithCancel(tc.ctx) @@ -927,7 +931,10 @@ func newTransferDetails(cache namespaces.Cache, opts transferDetailsOptions) []t log.Errorln("Failed to parse cache:", cache, "error:", err) return nil } - if cacheURL.Host == "" { + if cacheURL.Scheme == "unix" && cacheURL.Host != "" { + cacheURL.Path = path.Clean("/" + cacheURL.Host + "/" + cacheURL.Path) + cacheURL.Host = "" + } else if cacheURL.Scheme != "unix" && cacheURL.Host == "" { // Assume the cache is just a hostname cacheURL.Host = cacheEndpoint cacheURL.Path = "" @@ -936,26 +943,32 @@ func newTransferDetails(cache namespaces.Cache, opts transferDetailsOptions) []t } log.Debugf("Parsed Cache: %s", cacheURL.String()) if opts.NeedsToken { - cacheURL.Scheme = "https" - if !hasPort(cacheURL.Host) { - // Add port 8444 and 8443 - urlCopy := *cacheURL - urlCopy.Host += ":8444" - details = append(details, transferAttemptDetails{ - Url: &urlCopy, - Proxy: false, - PackOption: opts.PackOption, - }) - // Strip the port off 
and add 8443 - cacheURL.Host = cacheURL.Host + ":8443" + if cacheURL.Scheme != "unix" { + cacheURL.Scheme = "https" + if !hasPort(cacheURL.Host) { + // Add port 8444 and 8443 + urlCopy := *cacheURL + urlCopy.Host += ":8444" + details = append(details, transferAttemptDetails{ + Url: &urlCopy, + Proxy: false, + PackOption: opts.PackOption, + }) + // Strip the port off and add 8443 + cacheURL.Host = cacheURL.Host + ":8443" + } } - // Whether port is specified or not, add a transfer without proxy - details = append(details, transferAttemptDetails{ + det := transferAttemptDetails{ Url: cacheURL, Proxy: false, PackOption: opts.PackOption, - }) - } else { + } + if cacheURL.Scheme == "unix" { + det.UnixSocket = cacheURL.Path + } + // Whether port is specified or not, add a transfer without proxy + details = append(details, det) + } else if cacheURL.Scheme == "" || cacheURL.Scheme == "http" { cacheURL.Scheme = "http" if !hasPort(cacheURL.Host) { cacheURL.Host += ":8000" @@ -973,6 +986,16 @@ func newTransferDetails(cache namespaces.Cache, opts transferDetailsOptions) []t PackOption: opts.PackOption, }) } + } else { + det := transferAttemptDetails{ + Url: cacheURL, + Proxy: false, + PackOption: opts.PackOption, + } + if cacheURL.Scheme == "unix" { + det.UnixSocket = cacheURL.Path + } + details = append(details, det) } return details @@ -980,9 +1003,9 @@ func newTransferDetails(cache namespaces.Cache, opts transferDetailsOptions) []t type CacheInterface interface{} -func GenerateTransferDetailsUsingCache(cache CacheInterface, opts transferDetailsOptions) []transferAttemptDetails { +func generateTransferDetailsUsingCache(cache CacheInterface, opts transferDetailsOptions) []transferAttemptDetails { if directorCache, ok := cache.(namespaces.DirectorCache); ok { - return NewTransferDetailsUsingDirector(directorCache, opts) + return newTransferDetailsUsingDirector(directorCache, opts) } else if cache, ok := cache.(namespaces.Cache); ok { return newTransferDetails(cache, opts) } @@ 
-1042,7 +1065,7 @@ func (te *TransferEngine) createTransferFiles(job *clientTransferJob) (err error NeedsToken: job.job.namespace.ReadHTTPS || job.job.namespace.UseTokenOnRead, PackOption: packOption, } - transfers = append(transfers, GenerateTransferDetailsUsingCache(cache, td)...) + transfers = append(transfers, generateTransferDetailsUsingCache(cache, td)...) } if len(transfers) > 0 { @@ -1271,13 +1294,16 @@ func downloadHTTP(ctx context.Context, te *TransferEngine, callback TransferCall if !transfer.Proxy { transport.Proxy = nil } + transferUrl := *transfer.Url if transfer.Url.Scheme == "unix" { transport.Proxy = nil // Proxies make no sense when reading via a Unix socket transport = transport.Clone() transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { dialer := net.Dialer{} - return dialer.DialContext(ctx, "unix", transfer.Url.Path) + return dialer.DialContext(ctx, "unix", transfer.UnixSocket) } + transferUrl.Scheme = "http" + transferUrl.Host = "localhost" } httpClient, ok := client.HTTPClient.(*http.Client) if !ok { @@ -1287,7 +1313,7 @@ func downloadHTTP(ctx context.Context, te *TransferEngine, callback TransferCall ctx, cancel := context.WithCancel(ctx) defer cancel() - log.Debugln("Transfer URL String:", transfer.Url.String()) + log.Debugln("Transfer URL String:", transferUrl.String()) var req *grab.Request var unpacker *autoUnpacker if transfer.PackOption != "" { @@ -1302,10 +1328,10 @@ func downloadHTTP(ctx context.Context, te *TransferEngine, callback TransferCall } } unpacker = newAutoUnpacker(dest, behavior) - if req, err = grab.NewRequestToWriter(unpacker, transfer.Url.String()); err != nil { + if req, err = grab.NewRequestToWriter(unpacker, transferUrl.String()); err != nil { return 0, 0, "", errors.Wrap(err, "Failed to create new download request") } - } else if req, err = grab.NewRequest(dest, transfer.Url.String()); err != nil { + } else if req, err = grab.NewRequest(dest, transferUrl.String()); err != nil { return 
0, 0, "", errors.Wrap(err, "Failed to create new download request") } diff --git a/client/main.go b/client/main.go index bf55f4773..cc7f89269 100644 --- a/client/main.go +++ b/client/main.go @@ -291,10 +291,8 @@ func getCachesFromNamespace(namespace namespaces.Namespace, useDirector bool, pr return } log.Debugf("Using the cache (%s) from the config override\n", preferredCaches[0]) - cache := namespaces.Cache{ - Endpoint: preferredCaches[0].String(), - AuthEndpoint: preferredCaches[0].String(), - Resource: preferredCaches[0].String(), + cache := namespaces.DirectorCache{ + EndpointUrl: preferredCaches[0].String(), } caches = []CacheInterface{cache} return diff --git a/director/redirect.go b/director/redirect.go index a1ad672f8..87a5105b9 100644 --- a/director/redirect.go +++ b/director/redirect.go @@ -261,9 +261,8 @@ func redirectToCache(ginCtx *gin.Context) { if len(namespaceAd.Generation) != 0 { tokenGen := "" first := true - hdrVals := []string{namespaceAd.Generation[0].CredentialIssuer.String(), fmt.Sprint(namespaceAd.Generation[0].MaxScopeDepth), string(namespaceAd.Generation[0].Strategy), - string(namespaceAd.Generation[0].Strategy)} - for idx, hdrKey := range []string{"issuer", "max-scope-depth", "strategy", "vault-server"} { + hdrVals := []string{namespaceAd.Generation[0].CredentialIssuer.String(), fmt.Sprint(namespaceAd.Generation[0].MaxScopeDepth), string(namespaceAd.Generation[0].Strategy)} + for idx, hdrKey := range []string{"issuer", "max-scope-depth", "strategy"} { hdrVal := hdrVals[idx] if hdrVal == "" { continue diff --git a/file_cache/cache_test.go b/file_cache/cache_test.go index da39f6241..c6f016513 100644 --- a/file_cache/cache_test.go +++ b/file_cache/cache_test.go @@ -24,12 +24,14 @@ import ( "io" "net" "net/http" + "net/url" "os" "path/filepath" "runtime" "testing" "time" + "github.com/pelicanplatform/pelican/client" "github.com/pelicanplatform/pelican/config" simple_cache "github.com/pelicanplatform/pelican/file_cache" 
"github.com/pelicanplatform/pelican/launchers" @@ -43,7 +45,7 @@ import ( "golang.org/x/sync/errgroup" ) -func spinup(t *testing.T, ctx context.Context, egrp *errgroup.Group) { +func spinup(t *testing.T, ctx context.Context, egrp *errgroup.Group) (token string) { modules := config.ServerType(0) modules.Set(config.OriginType) @@ -106,13 +108,29 @@ func spinup(t *testing.T, ctx context.Context, egrp *errgroup.Group) { require.NoError(t, err) t.Cleanup(func() { cancel() - if err = egrp.Wait(); err != nil && err != context.Canceled { + if err = egrp.Wait(); err != nil && err != context.Canceled && err != http.ErrServerClosed { require.NoError(t, err) } }) err = os.WriteFile(filepath.Join(originDir, "hello_world.txt"), []byte("Hello, World!"), os.FileMode(0644)) require.NoError(t, err) + + issuer, err := config.GetServerIssuerURL() + require.NoError(t, err) + tokConf := utils.TokenConfig{ + TokenProfile: utils.WLCG, + Lifetime: time.Duration(time.Minute), + Issuer: issuer, + Subject: "test", + Audience: []string{utils.WLCGAny}, + } + tokConf.AddResourceScopes(token_scopes.NewResourceScope(token_scopes.Storage_Read, "/hello_world.txt")) + + token, err = tokConf.CreateToken() + require.NoError(t, err) + + return } // Setup a federation, invoke "get" through the local cache module @@ -151,33 +169,21 @@ func TestFedAuthGet(t *testing.T) { defer cancel() viper.Set("Origin.EnablePublicReads", false) - spinup(t, ctx, egrp) + token := spinup(t, ctx, egrp) lc, err := simple_cache.NewSimpleCache(ctx, egrp) require.NoError(t, err) - issuer, err := config.GetServerIssuerURL() - require.NoError(t, err) - tokConf := utils.TokenConfig{ - TokenProfile: utils.WLCG, - Lifetime: time.Duration(time.Minute), - Issuer: issuer, - Subject: "test", - Audience: []string{utils.WLCGAny}, - } - tokConf.AddResourceScopes(token_scopes.NewResourceScope(token_scopes.Storage_Read, "/hello_world.txt")) - - tok, err := tokConf.CreateToken() - require.NoError(t, err) - - reader, err := 
lc.Get("/test/hello_world.txt", tok) + reader, err := lc.Get("/test/hello_world.txt", token) require.NoError(t, err) byteBuff, err := io.ReadAll(reader) assert.NoError(t, err) assert.Equal(t, "Hello, World!", string(byteBuff)) - tokConf = utils.TokenConfig{ + issuer, err := config.GetServerIssuerURL() + require.NoError(t, err) + tokConf := utils.TokenConfig{ TokenProfile: utils.WLCG, Lifetime: time.Duration(time.Minute), Issuer: issuer, @@ -186,10 +192,10 @@ func TestFedAuthGet(t *testing.T) { } tokConf.AddResourceScopes(token_scopes.NewResourceScope(token_scopes.Storage_Read, "/not_correct")) - tok, err = tokConf.CreateToken() + token, err = tokConf.CreateToken() require.NoError(t, err) - _, err = lc.Get("/test/hello_world.txt", tok) + _, err = lc.Get("/test/hello_world.txt", token) assert.Error(t, err) assert.Equal(t, "authorization denied", err.Error()) } @@ -199,20 +205,7 @@ func TestHttpReq(t *testing.T) { defer cancel() viper.Set("Origin.EnablePublicReads", false) - spinup(t, ctx, egrp) - - issuer, err := config.GetServerIssuerURL() - require.NoError(t, err) - tokConf := utils.TokenConfig{ - TokenProfile: utils.WLCG, - Lifetime: time.Duration(time.Minute), - Issuer: issuer, - Subject: "test", - Audience: []string{utils.WLCGAny}, - } - tokConf.AddResourceScopes(token_scopes.NewResourceScope(token_scopes.Storage_Read, "/hello_world.txt")) - tok, err := tokConf.CreateToken() - require.NoError(t, err) + token := spinup(t, ctx, egrp) transport := config.GetTransport().Clone() transport.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) { @@ -222,7 +215,7 @@ func TestHttpReq(t *testing.T) { client := &http.Client{Transport: transport} req, err := http.NewRequest("GET", "http://localhost/test/hello_world.txt", nil) require.NoError(t, err) - req.Header.Set("Authorization", "Bearer "+tok) + req.Header.Set("Authorization", "Bearer "+token) resp, err := client.Do(req) require.NoError(t, err) defer resp.Body.Close() @@ -231,3 +224,30 @@ func 
TestHttpReq(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "Hello, World!", string(body)) } + +func TestClient(t *testing.T) { + ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) + defer cancel() + tmpDir := t.TempDir() + + viper.Set("Origin.EnablePublicReads", false) + token := spinup(t, ctx, egrp) + + cacheUrl := &url.URL{ + Scheme: "unix", + Path: param.FileCache_Socket.GetString(), + } + + discoveryHost := param.Federation_DiscoveryUrl.GetString() + discoveryUrl, err := url.Parse(discoveryHost) + require.NoError(t, err) + tr, err := client.DoGet(ctx, "pelican://"+discoveryUrl.Host+"/test/hello_world.txt", filepath.Join(tmpDir, "hello_world.txt"), false, client.WithToken(token), client.WithCaches(cacheUrl), client.WithAcquireToken(false)) + assert.NoError(t, err) + require.Equal(t, 1, len(tr)) + assert.Equal(t, int64(13), tr[0].TransferredBytes) + assert.NoError(t, tr[0].Error) + + byteBuff, err := os.ReadFile(filepath.Join(tmpDir, "hello_world.txt")) + assert.NoError(t, err) + assert.Equal(t, "Hello, World!", string(byteBuff)) +} From e36ae1a274418ba0d86f1aa21bf01b9683d074e4 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sun, 3 Mar 2024 22:07:52 -0600 Subject: [PATCH 16/45] Add support for `Stat` in the local cache Also fixes the client's `DoStat` function to work without topology. 
--- client/handle_http.go | 14 ++++++-------- client/main.go | 34 ++++++++++++++++++++++------------ file_cache/simple_cache.go | 22 +++++++++++++++++++++- 3 files changed, 49 insertions(+), 21 deletions(-) diff --git a/client/handle_http.go b/client/handle_http.go index daeeac8f0..c1fea117e 100644 --- a/client/handle_http.go +++ b/client/handle_http.go @@ -1913,15 +1913,13 @@ func (te *TransferEngine) walkDirUpload(job *clientTransferJob, transfers []tran return err } -func statHttp(ctx context.Context, dest *url.URL, namespace namespaces.Namespace, tokenLocation string, acquire bool) (size uint64, err error) { - - token, err := getToken(dest, namespace, false, "", tokenLocation, acquire) - if err != nil { - return - } - +func statHttp(ctx context.Context, dest *url.URL, namespace namespaces.Namespace, token string) (size uint64, err error) { // Parse the writeback host as a URL - writebackhostUrl, err := url.Parse(namespace.WriteBackHost) + statHost := namespace.WriteBackHost + if statHost == "" { + statHost = namespace.DirListHost + } + writebackhostUrl, err := url.Parse(statHost) if err != nil { return } diff --git a/client/main.go b/client/main.go index cc7f89269..9e362cb5b 100644 --- a/client/main.go +++ b/client/main.go @@ -189,7 +189,7 @@ func DoStat(ctx context.Context, destination string, options ...TransferOption) } }() - dest_uri, err := url.Parse(destination) + destUri, err := url.Parse(destination) if err != nil { log.Errorln("Failed to parse destination URL") return 0, err @@ -197,21 +197,21 @@ func DoStat(ctx context.Context, destination string, options ...TransferOption) understoodSchemes := []string{"osdf", "pelican", ""} - _, foundSource := Find(understoodSchemes, dest_uri.Scheme) + _, foundSource := Find(understoodSchemes, destUri.Scheme) if !foundSource { - log.Errorln("Unknown schema provided:", dest_uri.Scheme) + log.Errorln("Unknown schema provided:", destUri.Scheme) return 0, errors.New("Unsupported scheme requested") } - origScheme := 
dest_uri.Scheme + origScheme := destUri.Scheme if config.GetPreferredPrefix() != "PELICAN" && origScheme == "" { - dest_uri.Scheme = "osdf" + destUri.Scheme = "osdf" } - if (dest_uri.Scheme == "osdf" || dest_uri.Scheme == "stash") && dest_uri.Host != "" { - dest_uri.Path = path.Clean("/" + dest_uri.Host + "/" + dest_uri.Path) - dest_uri.Host = "" - } else if dest_uri.Scheme == "pelican" { - federationUrl, _ := url.Parse(dest_uri.String()) + if (destUri.Scheme == "osdf" || destUri.Scheme == "stash") && destUri.Host != "" { + destUri.Path = path.Clean("/" + destUri.Host + "/" + destUri.Path) + destUri.Host = "" + } else if destUri.Scheme == "pelican" { + federationUrl, _ := url.Parse(destUri.String()) federationUrl.Scheme = "https" federationUrl.Path = "" viper.Set("Federation.DiscoveryUrl", federationUrl.String()) @@ -221,23 +221,33 @@ func DoStat(ctx context.Context, destination string, options ...TransferOption) } } - ns, err := namespaces.MatchNamespace(dest_uri.Path) + ns, err := getNamespaceInfo(destUri.Path, param.Federation_DirectorUrl.GetString(), false) if err != nil { return 0, err } tokenLocation := "" acquire := true + token := "" for _, option := range options { switch option.Ident() { case identTransferOptionTokenLocation{}: tokenLocation = option.Value().(string) case identTransferOptionAcquireToken{}: acquire = option.Value().(bool) + case identTransferOptionToken{}: + token = option.Value().(string) } } - if remoteSize, err = statHttp(ctx, dest_uri, ns, tokenLocation, acquire); err == nil { + if ns.UseTokenOnRead && token == "" { + token, err = getToken(destUri, ns, true, "", tokenLocation, acquire) + if err != nil { + return 0, fmt.Errorf("failed to get token for transfer: %v", err) + } + } + + if remoteSize, err = statHttp(ctx, destUri, ns, token); err == nil { return remoteSize, nil } return 0, err diff --git a/file_cache/simple_cache.go b/file_cache/simple_cache.go index 1e1731628..8cc6cd838 100644 --- a/file_cache/simple_cache.go +++ 
b/file_cache/simple_cache.go @@ -568,7 +568,7 @@ func (sc *SimpleCache) purge() { // Given a URL, return a reader from the disk cache // // If there is no sentinal $NAME.DONE file, then returns nil -func (sc *SimpleCache) getFromDisk(localPath string) io.ReadCloser { +func (sc *SimpleCache) getFromDisk(localPath string) *os.File { localPath = filepath.Join(sc.basePath, path.Clean(localPath)) fp, err := os.Open(localPath + ".DONE") if err != nil { @@ -606,6 +606,26 @@ func (sc *SimpleCache) Get(path, token string) (io.ReadCloser, error) { } +func (lc *SimpleCache) Stat(path, token string) (uint64, error) { + if !lc.ac.authorize(token_scopes.Storage_Read, path, token) { + return 0, authorizationDenied + } + + if fp := lc.getFromDisk(path); fp != nil { + finfo, err := fp.Stat() + if err != nil { + return 0, errors.New("Failed to determine cached file size for object") + } + return uint64(finfo.Size()), nil + } + + dUrl := *lc.directorURL + dUrl.Path = path + dUrl.Scheme = "pelican" + log.Debugln("LocalCache doing Stat:", dUrl.String()) + return client.DoStat(context.Background(), dUrl.String(), client.WithToken(token)) +} + func (sc *SimpleCache) updateConfig() error { // Get the endpoint of the director var respNS []common.NamespaceAdV2 From ecefa09e05e06ae75574b3b5b042ddabec9584f0 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sun, 3 Mar 2024 22:09:55 -0600 Subject: [PATCH 17/45] Add support for MaximumDownloadSpeed This implements an upper limit on the transfer rates; meant mainly to support unit tests. 
--- client/handle_http.go | 6 ++++++ docs/parameters.yaml | 9 +++++++++ param/parameters.go | 1 + param/parameters_struct.go | 2 ++ 4 files changed, 18 insertions(+) diff --git a/client/handle_http.go b/client/handle_http.go index c1fea117e..83c194c9a 100644 --- a/client/handle_http.go +++ b/client/handle_http.go @@ -49,6 +49,7 @@ import ( "github.com/studio-b12/gowebdav" "github.com/vbauerster/mpb/v8" "golang.org/x/sync/errgroup" + "golang.org/x/time/rate" "github.com/pelicanplatform/pelican/config" "github.com/pelicanplatform/pelican/namespaces" @@ -1335,6 +1336,11 @@ func downloadHTTP(ctx context.Context, te *TransferEngine, callback TransferCall return 0, 0, "", errors.Wrap(err, "Failed to create new download request") } + rateLimit := param.Client_MaximumDownloadSpeed.GetInt() + if rateLimit > 0 { + req.RateLimiter = rate.NewLimiter(rate.Limit(rateLimit), 64*1024) + } + if token != "" { req.HTTPRequest.Header.Set("Authorization", "Bearer "+token) } diff --git a/docs/parameters.yaml b/docs/parameters.yaml index ddac09f74..c9d08175b 100644 --- a/docs/parameters.yaml +++ b/docs/parameters.yaml @@ -372,6 +372,15 @@ type: int default: 102400 components: ["client"] --- +name: Client.MaximumDownloadSpeed +description: >- + The maximum speed allowed for a client to download a given file (enforced via rate limits). 
+ This is not intended for use by production clients but rather for unit tests; 0 disables the rate limit +type: int +default: 0 +components: ["client"] +deprecated: true +--- ############################ # Origin-level Configs # ############################ diff --git a/param/parameters.go b/param/parameters.go index 69e4f2f72..5b18ec302 100644 --- a/param/parameters.go +++ b/param/parameters.go @@ -197,6 +197,7 @@ var ( var ( Cache_Concurrency = IntParam{"Cache.Concurrency"} Cache_Port = IntParam{"Cache.Port"} + Client_MaximumDownloadSpeed = IntParam{"Client.MaximumDownloadSpeed"} Client_MinimumDownloadSpeed = IntParam{"Client.MinimumDownloadSpeed"} Client_SlowTransferRampupTime = IntParam{"Client.SlowTransferRampupTime"} Client_SlowTransferWindow = IntParam{"Client.SlowTransferWindow"} diff --git a/param/parameters_struct.go b/param/parameters_struct.go index cfbceda72..a5d48cdb8 100644 --- a/param/parameters_struct.go +++ b/param/parameters_struct.go @@ -37,6 +37,7 @@ type Config struct { Client struct { DisableHttpProxy bool DisableProxyFallback bool + MaximumDownloadSpeed int MinimumDownloadSpeed int SlowTransferRampupTime int SlowTransferWindow int @@ -269,6 +270,7 @@ type configWithType struct { Client struct { DisableHttpProxy struct { Type string; Value bool } DisableProxyFallback struct { Type string; Value bool } + MaximumDownloadSpeed struct { Type string; Value int } MinimumDownloadSpeed struct { Type string; Value int } SlowTransferRampupTime struct { Type string; Value int } SlowTransferWindow struct { Type string; Value int } From 9727bc41089fc0b16a0feda44e4fddfc80d30444 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sun, 3 Mar 2024 22:13:23 -0600 Subject: [PATCH 18/45] Add unit tests for large file support and stat The large file unit test ensures that partial progress is made; that is, data flows to the client while it is still being downloaded. 
--- client/handle_http.go | 4 +- config/config.go | 3 +- file_cache/cache_test.go | 115 +++++++++++++++++++++++++++++++++---- file_cache/simple_cache.go | 47 +++++++++++---- 4 files changed, 145 insertions(+), 24 deletions(-) diff --git a/client/handle_http.go b/client/handle_http.go index 83c194c9a..b6222c6ab 100644 --- a/client/handle_http.go +++ b/client/handle_http.go @@ -1382,8 +1382,8 @@ func downloadHTTP(ctx context.Context, te *TransferEngine, callback TransferCall totalSize = resp.Size() // Do a head request for content length if resp.Size is unknown if totalSize <= 0 { - headClient := &http.Client{Transport: config.GetTransport()} - headRequest, _ := http.NewRequest("HEAD", transfer.Url.String(), nil) + headClient := &http.Client{Transport: transport} + headRequest, _ := http.NewRequest("HEAD", transferUrl.String(), nil) var headResponse *http.Response headResponse, err = headClient.Do(headRequest) if err != nil { diff --git a/config/config.go b/config/config.go index 91a363ac0..298837da3 100644 --- a/config/config.go +++ b/config/config.go @@ -148,7 +148,8 @@ var ( // This function creates a new MetadataError by wrapping the previous error func NewMetadataError(err error, msg string) *MetadataErr { return &MetadataErr{ - msg: msg, + msg: msg, + innerErr: err, } } diff --git a/file_cache/cache_test.go b/file_cache/cache_test.go index c6f016513..55a518465 100644 --- a/file_cache/cache_test.go +++ b/file_cache/cache_test.go @@ -45,7 +45,14 @@ import ( "golang.org/x/sync/errgroup" ) -func spinup(t *testing.T, ctx context.Context, egrp *errgroup.Group) (token string) { +type ( + fedTest struct { + originDir string + token string + } +) + +func (ft *fedTest) spinup(t *testing.T, ctx context.Context, egrp *errgroup.Group) { modules := config.ServerType(0) modules.Set(config.OriginType) @@ -127,10 +134,11 @@ func spinup(t *testing.T, ctx context.Context, egrp *errgroup.Group) (token stri } 
tokConf.AddResourceScopes(token_scopes.NewResourceScope(token_scopes.Storage_Read, "/hello_world.txt")) - token, err = tokConf.CreateToken() + token, err := tokConf.CreateToken() require.NoError(t, err) - return + ft.originDir = originDir + ft.token = token } // Setup a federation, invoke "get" through the local cache module @@ -141,8 +149,10 @@ func TestFedPublicGet(t *testing.T) { ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) defer cancel() + viper.Reset() viper.Set("Origin.EnablePublicReads", true) - spinup(t, ctx, egrp) + ft := fedTest{} + ft.spinup(t, ctx, egrp) sc, err := simple_cache.NewSimpleCache(ctx, egrp) require.NoError(t, err) @@ -168,13 +178,15 @@ func TestFedAuthGet(t *testing.T) { ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) defer cancel() + viper.Reset() viper.Set("Origin.EnablePublicReads", false) - token := spinup(t, ctx, egrp) + ft := fedTest{} + ft.spinup(t, ctx, egrp) lc, err := simple_cache.NewSimpleCache(ctx, egrp) require.NoError(t, err) - reader, err := lc.Get("/test/hello_world.txt", token) + reader, err := lc.Get("/test/hello_world.txt", ft.token) require.NoError(t, err) byteBuff, err := io.ReadAll(reader) @@ -192,7 +204,7 @@ func TestFedAuthGet(t *testing.T) { } tokConf.AddResourceScopes(token_scopes.NewResourceScope(token_scopes.Storage_Read, "/not_correct")) - token, err = tokConf.CreateToken() + token, err := tokConf.CreateToken() require.NoError(t, err) _, err = lc.Get("/test/hello_world.txt", token) @@ -204,8 +216,10 @@ func TestHttpReq(t *testing.T) { ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) defer cancel() + viper.Reset() viper.Set("Origin.EnablePublicReads", false) - token := spinup(t, ctx, egrp) + ft := fedTest{} + ft.spinup(t, ctx, egrp) transport := config.GetTransport().Clone() transport.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) { @@ -215,7 +229,7 @@ func TestHttpReq(t *testing.T) { client := &http.Client{Transport: 
transport} req, err := http.NewRequest("GET", "http://localhost/test/hello_world.txt", nil) require.NoError(t, err) - req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Authorization", "Bearer "+ft.token) resp, err := client.Do(req) require.NoError(t, err) defer resp.Body.Close() @@ -230,8 +244,10 @@ func TestClient(t *testing.T) { defer cancel() tmpDir := t.TempDir() + viper.Reset() viper.Set("Origin.EnablePublicReads", false) - token := spinup(t, ctx, egrp) + ft := fedTest{} + ft.spinup(t, ctx, egrp) cacheUrl := &url.URL{ Scheme: "unix", @@ -241,7 +257,8 @@ func TestClient(t *testing.T) { discoveryHost := param.Federation_DiscoveryUrl.GetString() discoveryUrl, err := url.Parse(discoveryHost) require.NoError(t, err) - tr, err := client.DoGet(ctx, "pelican://"+discoveryUrl.Host+"/test/hello_world.txt", filepath.Join(tmpDir, "hello_world.txt"), false, client.WithToken(token), client.WithCaches(cacheUrl), client.WithAcquireToken(false)) + tr, err := client.DoGet(ctx, "pelican://"+discoveryUrl.Host+"/test/hello_world.txt", filepath.Join(tmpDir, "hello_world.txt"), false, + client.WithToken(ft.token), client.WithCaches(cacheUrl), client.WithAcquireToken(false)) assert.NoError(t, err) require.Equal(t, 1, len(tr)) assert.Equal(t, int64(13), tr[0].TransferredBytes) @@ -251,3 +268,79 @@ func TestClient(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "Hello, World!", string(byteBuff)) } + +func TestStat(t *testing.T) { + ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) + defer cancel() + + viper.Reset() + viper.Set("Origin.EnablePublicReads", true) + ft := fedTest{} + ft.spinup(t, ctx, egrp) + + lc, err := simple_cache.NewSimpleCache(ctx, egrp) + require.NoError(t, err) + + size, err := lc.Stat("/test/hello_world.txt", "") + require.NoError(t, err) + assert.Equal(t, uint64(13), size) + + reader, err := lc.Get("/test/hello_world.txt", "") + require.NoError(t, err) + byteBuff, err := io.ReadAll(reader) + assert.NoError(t, err) + 
assert.Equal(t, 13, len(byteBuff)) + + size, err = lc.Stat("/test/hello_world.txt", "") + assert.NoError(t, err) + assert.Equal(t, uint64(13), size) +} + +func TestLargeFile(t *testing.T) { + ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) + defer cancel() + tmpDir := t.TempDir() + + viper.Reset() + viper.Set("Origin.EnablePublicReads", true) + viper.Set("Client.MaximumDownloadSpeed", 40*1024*1024) + ft := fedTest{} + ft.spinup(t, ctx, egrp) + + cacheUrl := &url.URL{ + Scheme: "unix", + Path: param.FileCache_Socket.GetString(), + } + + fp, err := os.OpenFile(filepath.Join(ft.originDir, "hello_world.txt"), os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) + require.NoError(t, err) + + byteBuff := []byte("Hello, World!") + for { + byteBuff = append(byteBuff, []byte("Hello, World!")...) + if len(byteBuff) > 4096 { + break + } + } + size := 0 + for { + n, err := fp.Write(byteBuff) + require.NoError(t, err) + size += n + if size > 100*1024*1024 { + break + } + } + fp.Close() + + discoveryHost := param.Federation_DiscoveryUrl.GetString() + discoveryUrl, err := url.Parse(discoveryHost) + require.NoError(t, err) + tr, err := client.DoGet(ctx, "pelican://"+discoveryUrl.Host+"/test/hello_world.txt", filepath.Join(tmpDir, "hello_world.txt"), false, + client.WithCaches(cacheUrl)) + assert.NoError(t, err) + require.Equal(t, 1, len(tr)) + assert.Equal(t, int64(size), tr[0].TransferredBytes) + assert.NoError(t, tr[0].Error) + +} diff --git a/file_cache/simple_cache.go b/file_cache/simple_cache.go index 8cc6cd838..b3b45bfc2 100644 --- a/file_cache/simple_cache.go +++ b/file_cache/simple_cache.go @@ -378,18 +378,24 @@ func (sc *SimpleCache) runMux() error { continue } results := recv.Interface().(client.TransferResults) - path := jobPath[results.ID()] - if path == "" { + reqPath := jobPath[results.ID()] + if reqPath == "" { log.Errorf("Transfer results from job %s but no corresponding path known", results.ID()) continue } delete(jobPath, results.ID()) - ad := 
activeJobs[path] + ad := activeJobs[reqPath] if ad == nil { - log.Errorf("Transfer results from job %s returned for path %s but no active job known", results.ID(), path) + log.Errorf("Transfer results from job %s returned for path %s but no active job known", results.ID(), reqPath) continue } - delete(activeJobs, path) + delete(activeJobs, reqPath) + func() { + localPath := filepath.Join(sc.basePath, path.Clean(reqPath)) + sc.mutex.Lock() + defer sc.mutex.Unlock() + delete(sc.downloads, localPath) + }() if results.Error != nil { ad.status.err.Store(&results.Error) } @@ -397,18 +403,18 @@ func (sc *SimpleCache) runMux() error { ad.status.size.Store(results.TransferredBytes) ad.status.done.Store(true) for _, waiter := range ad.waiterList { - tmpResults = append(tmpResults, result{ds: ad.status, path: path, channel: waiter.notify}) + tmpResults = append(tmpResults, result{ds: ad.status, path: reqPath, channel: waiter.notify}) } if results.Error == nil { - if fp, err := os.OpenFile(filepath.Join(sc.basePath, path)+".DONE", os.O_CREATE|os.O_WRONLY, os.FileMode(0600)); err != nil { - log.Debugln("Unable to save a DONE file for cache path", path) + if fp, err := os.OpenFile(filepath.Join(sc.basePath, reqPath)+".DONE", os.O_CREATE|os.O_WRONLY, os.FileMode(0600)); err != nil { + log.Debugln("Unable to save a DONE file for cache path", reqPath) } else { fp.Close() } - entry := sc.lruLookup[path] + entry := sc.lruLookup[reqPath] if entry == nil { entry = &lruEntry{} - sc.lruLookup[path] = entry + sc.lruLookup[reqPath] = entry entry.size = results.TransferredBytes sc.cacheSize += uint64(entry.size) sc.lru = append(sc.lru, entry) @@ -443,6 +449,14 @@ func (sc *SimpleCache) runMux() error { for _, path := range jobsToDelete { delete(activeJobs, path) } + func() { + sc.mutex.Lock() + defer sc.mutex.Unlock() + for _, lpath := range jobsToDelete { + localPath := filepath.Join(sc.basePath, path.Clean(lpath)) + delete(sc.downloads, localPath) + } + }() } else if chosen == lenChan+3 { 
// New request req := recv.Interface().(availSizeReq) @@ -507,6 +521,11 @@ func (sc *SimpleCache) runMux() error { } activeJobs[req.request.path] = ad jobPath[tj.ID()] = req.request.path + func() { + sc.mutex.Lock() + defer sc.mutex.Unlock() + sc.downloads[localPath] = ad + }() } else if chosen == lenChan+4 { // Cancel a given request. req := recv.Interface().(cancelReq) @@ -694,6 +713,14 @@ func (cr *cacheReader) Read(p []byte) (n int, err error) { if cr.size >= 0 && neededSize > cr.size { neededSize = cr.size } + if neededSize > cr.avail && cr.fd != nil { + finfo, err := cr.fd.Stat() + if err == nil { + cr.avail = finfo.Size() + } else { + log.Warningln("Unable to stat open file handle:", err) + } + } if neededSize > cr.avail { // Insufficient available data; request more from the cache if cr.status == nil { From eb8d2e0fafdd8093d653060acb43f2c62e297b02 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Mon, 4 Mar 2024 09:32:42 -0600 Subject: [PATCH 19/45] Add integration test to ensure the LocalCache works with OSDF. 
--- client/handle_http.go | 11 +++++++---- cmd/plugin.go | 18 ++++++++++++++++-- config/config.go | 1 + file_cache/cache_api.go | 18 ++++++++++++++++-- github_scripts/citests.sh | 35 +++++++++++++++++++++++++++++++++++ 5 files changed, 75 insertions(+), 8 deletions(-) diff --git a/client/handle_http.go b/client/handle_http.go index b6222c6ab..02db6d26f 100644 --- a/client/handle_http.go +++ b/client/handle_http.go @@ -759,13 +759,13 @@ func (tc *TransferClient) NewTransferJob(remoteUrl *url.URL, localPath string, u for _, option := range options { switch option.Ident() { case identTransferOptionCaches{}: - tc.caches = option.Value().([]*url.URL) + tj.caches = option.Value().([]*url.URL) case identTransferOptionCallback{}: - tc.callback = option.Value().(TransferCallbackFunc) + tj.callback = option.Value().(TransferCallbackFunc) case identTransferOptionTokenLocation{}: - tc.tokenLocation = option.Value().(string) + tj.tokenLocation = option.Value().(string) case identTransferOptionAcquireToken{}: - tc.skipAcquire = !option.Value().(bool) + tj.skipAcquire = !option.Value().(bool) case identTransferOptionToken{}: tj.token = option.Value().(string) } @@ -1922,6 +1922,9 @@ func (te *TransferEngine) walkDirUpload(job *clientTransferJob, transfers []tran func statHttp(ctx context.Context, dest *url.URL, namespace namespaces.Namespace, token string) (size uint64, err error) { // Parse the writeback host as a URL statHost := namespace.WriteBackHost + if len(namespace.SortedDirectorCaches) > 0 { + statHost = namespace.SortedDirectorCaches[0].EndpointUrl + } if statHost == "" { statHost = namespace.DirListHost } diff --git a/cmd/plugin.go b/cmd/plugin.go index 9d0c64783..bc8e3512d 100644 --- a/cmd/plugin.go +++ b/cmd/plugin.go @@ -380,6 +380,19 @@ func runPluginWorker(ctx context.Context, upload bool, workChan <-chan PluginTra err = shutdownErr } }() + + caches := make([]*url.URL, 0, 1) + if nearestCache, ok := os.LookupEnv("NEAREST_CACHE"); ok && nearestCache != "" { + var 
nearestCacheURL *url.URL + if nearestCacheURL, err = url.Parse(nearestCache); err != nil { + err = errors.Wrapf(err, "unable to parse preferred cache (%s) as URL", nearestCacheURL) + return + } else { + caches = append(caches, nearestCacheURL) + log.Debugln("Setting nearest cache to", nearestCacheURL.String()) + } + } + tc, err := te.NewClient(client.WithAcquireToken(false)) if err != nil { return @@ -417,7 +430,7 @@ func runPluginWorker(ctx context.Context, upload bool, workChan <-chan PluginTra var tj *client.TransferJob urlCopy := *transfer.url - tj, err = tc.NewTransferJob(&urlCopy, transfer.localFile, upload, false, client.WithAcquireToken(false)) + tj, err = tc.NewTransferJob(&urlCopy, transfer.localFile, upload, false, client.WithAcquireToken(false), client.WithCaches(caches...)) jobMap[tj.ID()] = transfer if err != nil { return errors.Wrap(err, "Failed to create new transfer job") @@ -549,7 +562,8 @@ func writeOutfile(err error, resultAds []*classads.ClassAd, outputFile *os.File) // Error code 1 (serr) is ERROR_INVALID_FUNCTION, the expected Windows syscall error // Error code EINVAL is returned on Linux // Error code ENODEV (/dev/null) or ENOTTY (/dev/stdout) is returned on Mac OS X - if errors.As(err, &perr) && errors.As(perr.Unwrap(), &serr) && (int(serr) == 1 || serr == syscall.EINVAL || serr == syscall.ENODEV || serr == syscall.ENOTTY) { + // Error code EBADF is returned on Mac OS X if /dev/stdout is redirected to a pipe in the shell + if errors.As(err, &perr) && errors.As(perr.Unwrap(), &serr) && (int(serr) == 1 || serr == syscall.EINVAL || serr == syscall.ENODEV || serr == syscall.ENOTTY || serr == syscall.EBADF) { log.Debugf("Error when syncing: %s; can be ignored\n", perr) } else { if errors.As(err, &perr) && errors.As(perr.Unwrap(), &serr) { diff --git a/config/config.go b/config/config.go index 298837da3..5bd417c14 100644 --- a/config/config.go +++ b/config/config.go @@ -303,6 +303,7 @@ func (sType *ServerType) SetString(name string) bool { 
return true case "localcache": *sType |= LocalCacheType + return true case "origin": *sType |= OriginType return true diff --git a/file_cache/cache_api.go b/file_cache/cache_api.go index 4b8cb73dc..b95c12ebe 100644 --- a/file_cache/cache_api.go +++ b/file_cache/cache_api.go @@ -28,6 +28,7 @@ import ( "os" "path" "path/filepath" + "strconv" "strings" "github.com/pelicanplatform/pelican/param" @@ -53,7 +54,7 @@ func LaunchListener(ctx context.Context, egrp *errgroup.Group) error { } handler := func(w http.ResponseWriter, r *http.Request) { - if r.Method != "GET" { + if r.Method != "GET" && r.Method != "HEAD" { w.WriteHeader(http.StatusMethodNotAllowed) return } @@ -74,7 +75,17 @@ func LaunchListener(ctx context.Context, egrp *errgroup.Group) error { bearerToken = authzHeader[7:] // len("Bearer ") == 7 } path := path.Clean(r.URL.Path) - reader, err := sc.Get(path, bearerToken) + + var size uint64 + var reader io.ReadCloser + if r.Method == "HEAD" { + size, err = sc.Stat(path, bearerToken) + if err == nil { + w.Header().Set("Content-Length", strconv.FormatUint(size, 10)) + } + } else { + reader, err = sc.Get(path, bearerToken) + } if errors.Is(err, authorizationDenied) { w.WriteHeader(http.StatusForbidden) if _, err = w.Write([]byte("Authorization Denied")); err != nil { @@ -90,6 +101,9 @@ func LaunchListener(ctx context.Context, egrp *errgroup.Group) error { return } w.WriteHeader(http.StatusOK) + if r.Method == "HEAD" { + return + } if _, err = io.Copy(w, reader); err != nil && sendTrailer { // TODO: Enumerate more error values w.Header().Set("X-Transfer-Status", fmt.Sprintf("%d: %s", 500, err)) diff --git a/github_scripts/citests.sh b/github_scripts/citests.sh index 5793a6490..f745c6ef0 100755 --- a/github_scripts/citests.sh +++ b/github_scripts/citests.sh @@ -54,7 +54,42 @@ EOF ./stash_plugin -infile $PWD/infile -outfile $PWD/outfile + +##################################### +## Test LocalCache in front of OSDF +##################################### 
+SOCKET_DIR="`mktemp -d -t pelican-citest`" +export PELICAN_FILECACHE_SOCKET=$SOCKET_DIR/socket +export PELICAN_FILECACHE_DATALOCATION=$SOCKET_DIR/data + +./pelican serve -d -f osg-htc.org --module localcache & +PELICAN_PID=$! + +cleanup() { + rm -rf -- "$SOCKET_DIR" + kill $PELICAN_PID + wait $PELICAN_PID || : +} +trap cleanup EXIT + +sleep 1 + +NEAREST_CACHE="unix://$SOCKET_DIR/socket" ./stash_plugin -d osdf:///ospool/uc-shared/public/OSG-Staff/validation/test.txt /dev/null +exit_status=$? + +if ! [[ "$exit_status" = 0 ]]; then + echo "Cache plugin download failed" + exit 1 +fi + +if [ ! -e "$SOCKET_DIR/data/ospool/uc-shared/public/OSG-Staff/validation/test.txt.DONE" ]; then + echo "Test file not in local cache" + exit 1 +fi + +######################################## # Test we return 0 when HOME is not set +######################################## OLDHOME=$HOME unset HOME ./stash_plugin -classad From 459b7a60c57ec53840813820ace99c10f393f0e6 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Tue, 5 Mar 2024 08:48:28 -0600 Subject: [PATCH 20/45] Mass rename to ensure we consistently use 'localcache' --- config/config.go | 10 ++--- docs/parameters.yaml | 42 +++++++++++-------- launchers/launcher.go | 4 +- {file_cache => local_cache}/cache_api.go | 6 +-- {file_cache => local_cache}/cache_authz.go | 2 +- {file_cache => local_cache}/cache_test.go | 20 ++++----- .../local_cache.go | 42 +++++++++---------- param/parameters.go | 12 +++--- param/parameters_struct.go | 32 +++++++------- 9 files changed, 88 insertions(+), 82 deletions(-) rename {file_cache => local_cache}/cache_api.go (96%) rename {file_cache => local_cache}/cache_authz.go (99%) rename {file_cache => local_cache}/cache_test.go (94%) rename file_cache/simple_cache.go => local_cache/local_cache.go (94%) diff --git a/config/config.go b/config/config.go index 5bd417c14..3a73addfb 100644 --- a/config/config.go +++ b/config/config.go @@ -824,7 +824,7 @@ func InitServer(ctx context.Context, currentServers 
ServerType) error { viper.SetDefault("Cache.RunLocation", filepath.Join("/run", "pelican", "xrootd", "cache")) } viper.SetDefault("Cache.DataLocation", "/run/pelican/xcache") - viper.SetDefault("FileCache.RunLocation", filepath.Join("/run", "pelican", "filecache")) + viper.SetDefault("LocalCache.RunLocation", filepath.Join("/run", "pelican", "localcache")) viper.SetDefault("Origin.Multiuser", true) viper.SetDefault("Director.GeoIPLocation", "/var/cache/pelican/maxmind/GeoLite2-City.mmdb") @@ -864,12 +864,12 @@ func InitServer(ctx context.Context, currentServers ServerType) error { cleanupDirOnShutdown(ctx, runtimeDir) } viper.SetDefault("Cache.DataLocation", filepath.Join(runtimeDir, "xcache")) - viper.SetDefault("FileCache.RunLocation", filepath.Join(runtimeDir, "cache")) + viper.SetDefault("LocalCache.RunLocation", filepath.Join(runtimeDir, "cache")) viper.SetDefault("Origin.Multiuser", false) } - fcRunLocation := viper.GetString("FileCache.RunLocation") - viper.SetDefault("FileCache.Socket", filepath.Join(fcRunLocation, "cache.sock")) - viper.SetDefault("FileCache.DataLocation", filepath.Join(fcRunLocation, "cache")) + fcRunLocation := viper.GetString("LocalCache.RunLocation") + viper.SetDefault("LocalCache.Socket", filepath.Join(fcRunLocation, "cache.sock")) + viper.SetDefault("LocalCache.DataLocation", filepath.Join(fcRunLocation, "cache")) // Any platform-specific paths should go here err := InitServerOSDefaults() diff --git a/docs/parameters.yaml b/docs/parameters.yaml index c9d08175b..32e5b26f0 100644 --- a/docs/parameters.yaml +++ b/docs/parameters.yaml @@ -618,49 +618,55 @@ default: path components: ["origin"] --- ############################ -# File-cache configs # +# Local cache configs # ############################ -name: FileCache.RunLocation +name: LocalCache.RunLocation description: >- - The directory for the runtime files of the file cache -type: string -root_default: /run/pelican/filecache -default: $XDG_RUNTIME_DIR/pelican/filecache + The 
directory for the runtime files of the local cache +type: filename +root_default: /run/pelican/localcache +default: $XDG_RUNTIME_DIR/pelican/localcache +components: ["localcache"] --- -name: FileCache.DataLocation +name: LocalCache.DataLocation description: >- The directory for the location of the cache data files - this is where the actual data in the cache is stored - for the file cache. -type: string -default: $PELICAN_FILECACHE_RUNLOCATION/cache + for the local cache. +type: filename +default: $PELICAN_LOCALCACHE_RUNLOCATION/cache +components: ["localcache"] --- -name: FileCache.Socket +name: LocalCache.Socket description: >- - The location of the socket used for client communication for the file cache -type: string -default: $PELICAN_FILECACHE_RUNLOCATION/cache.sock + The location of the socket used for client communication for the local cache +type: filename +default: $PELICAN_LOCALCACHE_RUNLOCATION/cache.sock +components: ["localcache"] --- -name: FileCache.Size +name: LocalCache.Size description: >- - The maximum size of the file cache. If not set, it is assumed the entire device can be used. + The maximum size of the local cache. If not set, it is assumed the entire device can be used. type: string default: 0 +components: ["localcache"] --- -name: FileCache.HighWaterMarkPercentage +name: LocalCache.HighWaterMarkPercentage description: >- A percentage value where the cache cleanup routines will triggered. Once the cache usage of completed files hits the high water mark, files will be deleted until the usage hits the low water mark. type: int default: 95 +components: ["localcache"] --- -name: FileCache.LowWaterMarkPercentage +name: LocalCache.LowWaterMarkPercentage description: >- A percentage value where the cache cleanup routines will complete. Once the cache usage of completed files hits the high water mark, files will be deleted until the usage hits the low water mark. 
type: int default: 85 +components: ["localcache"] --- ############################ # Cache-level configs # diff --git a/launchers/launcher.go b/launchers/launcher.go index 6e5f37b05..2b76455c7 100644 --- a/launchers/launcher.go +++ b/launchers/launcher.go @@ -34,7 +34,7 @@ import ( "github.com/pelicanplatform/pelican/broker" "github.com/pelicanplatform/pelican/config" - simple_cache "github.com/pelicanplatform/pelican/file_cache" + "github.com/pelicanplatform/pelican/local_cache" "github.com/pelicanplatform/pelican/origin_ui" "github.com/pelicanplatform/pelican/param" "github.com/pelicanplatform/pelican/server_ui" @@ -251,7 +251,7 @@ func LaunchModules(ctx context.Context, modules config.ServerType) (context.Canc if modules.IsEnabled(config.LocalCacheType) { log.Debugln("Starting local cache listener") - if err := simple_cache.LaunchListener(ctx, egrp); err != nil { + if err := local_cache.LaunchListener(ctx, egrp); err != nil { log.Errorln("Failure when starting the local cache listener:", err) return shutdownCancel, err } diff --git a/file_cache/cache_api.go b/local_cache/cache_api.go similarity index 96% rename from file_cache/cache_api.go rename to local_cache/cache_api.go index b95c12ebe..3d861353e 100644 --- a/file_cache/cache_api.go +++ b/local_cache/cache_api.go @@ -16,7 +16,7 @@ * ***************************************************************/ -package simple_cache +package local_cache import ( "context" @@ -39,7 +39,7 @@ import ( // Launch the unix socket listener as a separate goroutine func LaunchListener(ctx context.Context, egrp *errgroup.Group) error { - socketName := param.FileCache_Socket.GetString() + socketName := param.LocalCache_Socket.GetString() if err := os.MkdirAll(filepath.Dir(socketName), fs.FileMode(0755)); err != nil { return errors.Wrap(err, "failed to create socket directory") } @@ -48,7 +48,7 @@ func LaunchListener(ctx context.Context, egrp *errgroup.Group) error { if err != nil { return err } - sc, err := NewSimpleCache(ctx, egrp) 
+ sc, err := NewLocalCache(ctx, egrp) if err != nil { return err } diff --git a/file_cache/cache_authz.go b/local_cache/cache_authz.go similarity index 99% rename from file_cache/cache_authz.go rename to local_cache/cache_authz.go index cbdaa932a..19804b9ac 100644 --- a/file_cache/cache_authz.go +++ b/local_cache/cache_authz.go @@ -16,7 +16,7 @@ * ***************************************************************/ -package simple_cache +package local_cache import ( "context" diff --git a/file_cache/cache_test.go b/local_cache/cache_test.go similarity index 94% rename from file_cache/cache_test.go rename to local_cache/cache_test.go index 55a518465..261c6ac91 100644 --- a/file_cache/cache_test.go +++ b/local_cache/cache_test.go @@ -16,7 +16,7 @@ * ***************************************************************/ -package simple_cache_test +package local_cache_test import ( "context" @@ -33,8 +33,8 @@ import ( "github.com/pelicanplatform/pelican/client" "github.com/pelicanplatform/pelican/config" - simple_cache "github.com/pelicanplatform/pelican/file_cache" "github.com/pelicanplatform/pelican/launchers" + local_cache "github.com/pelicanplatform/pelican/local_cache" "github.com/pelicanplatform/pelican/param" "github.com/pelicanplatform/pelican/test_utils" "github.com/pelicanplatform/pelican/token_scopes" @@ -154,10 +154,10 @@ func TestFedPublicGet(t *testing.T) { ft := fedTest{} ft.spinup(t, ctx, egrp) - sc, err := simple_cache.NewSimpleCache(ctx, egrp) + lc, err := local_cache.NewLocalCache(ctx, egrp) require.NoError(t, err) - reader, err := sc.Get("/test/hello_world.txt", "") + reader, err := lc.Get("/test/hello_world.txt", "") require.NoError(t, err) byteBuff, err := io.ReadAll(reader) @@ -165,7 +165,7 @@ func TestFedPublicGet(t *testing.T) { assert.Equal(t, "Hello, World!", string(byteBuff)) // Query again -- cache hit case - reader, err = sc.Get("/test/hello_world.txt", "") + reader, err = lc.Get("/test/hello_world.txt", "") require.NoError(t, err) assert.Equal(t, 
"*os.File", fmt.Sprintf("%T", reader)) @@ -183,7 +183,7 @@ func TestFedAuthGet(t *testing.T) { ft := fedTest{} ft.spinup(t, ctx, egrp) - lc, err := simple_cache.NewSimpleCache(ctx, egrp) + lc, err := local_cache.NewLocalCache(ctx, egrp) require.NoError(t, err) reader, err := lc.Get("/test/hello_world.txt", ft.token) @@ -223,7 +223,7 @@ func TestHttpReq(t *testing.T) { transport := config.GetTransport().Clone() transport.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) { - return net.Dial("unix", param.FileCache_Socket.GetString()) + return net.Dial("unix", param.LocalCache_Socket.GetString()) } client := &http.Client{Transport: transport} @@ -251,7 +251,7 @@ func TestClient(t *testing.T) { cacheUrl := &url.URL{ Scheme: "unix", - Path: param.FileCache_Socket.GetString(), + Path: param.LocalCache_Socket.GetString(), } discoveryHost := param.Federation_DiscoveryUrl.GetString() @@ -278,7 +278,7 @@ func TestStat(t *testing.T) { ft := fedTest{} ft.spinup(t, ctx, egrp) - lc, err := simple_cache.NewSimpleCache(ctx, egrp) + lc, err := local_cache.NewLocalCache(ctx, egrp) require.NoError(t, err) size, err := lc.Stat("/test/hello_world.txt", "") @@ -309,7 +309,7 @@ func TestLargeFile(t *testing.T) { cacheUrl := &url.URL{ Scheme: "unix", - Path: param.FileCache_Socket.GetString(), + Path: param.LocalCache_Socket.GetString(), } fp, err := os.OpenFile(filepath.Join(ft.originDir, "hello_world.txt"), os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) diff --git a/file_cache/simple_cache.go b/local_cache/local_cache.go similarity index 94% rename from file_cache/simple_cache.go rename to local_cache/local_cache.go index b3b45bfc2..2557e81e5 100644 --- a/file_cache/simple_cache.go +++ b/local_cache/local_cache.go @@ -16,7 +16,7 @@ * ***************************************************************/ -package simple_cache +package local_cache import ( "container/heap" @@ -48,7 +48,7 @@ import ( ) type ( - SimpleCache struct { + LocalCache struct { ctx context.Context egrp 
*errgroup.Group te *client.TransferEngine @@ -104,7 +104,7 @@ type ( } cacheReader struct { - sc *SimpleCache + sc *LocalCache offset int64 path string token string @@ -209,15 +209,15 @@ func (ds *downloadStatus) String() string { } } -// Create a simple cache object +// Create a local cache object // // Launches background goroutines associated with the cache -func NewSimpleCache(ctx context.Context, egrp *errgroup.Group) (sc *SimpleCache, err error) { +func NewLocalCache(ctx context.Context, egrp *errgroup.Group) (sc *LocalCache, err error) { // Setup cache on disk - cacheDir := param.FileCache_DataLocation.GetString() + cacheDir := param.LocalCache_DataLocation.GetString() if cacheDir == "" { - err = errors.New("FileCache.DataLocation is not set; cannot determine where to place file cache's data") + err = errors.New("LocalCache.DataLocation is not set; cannot determine where to place file cache's data") return } if err = os.RemoveAll(cacheDir); err != nil { @@ -227,7 +227,7 @@ func NewSimpleCache(ctx context.Context, egrp *errgroup.Group) (sc *SimpleCache, return } - sizeStr := param.FileCache_Size.GetString() + sizeStr := param.LocalCache_Size.GetString() var cacheSize uint64 if sizeStr == "" || sizeStr == "0" { var stat syscall.Statfs_t @@ -238,21 +238,21 @@ func NewSimpleCache(ctx context.Context, egrp *errgroup.Group) (sc *SimpleCache, cacheSize = stat.Bavail * uint64(stat.Bsize) } else { var signedCacheSize int64 - signedCacheSize, err = units.ParseStrictBytes(param.FileCache_Size.GetString()) + signedCacheSize, err = units.ParseStrictBytes(param.LocalCache_Size.GetString()) if err != nil { return } cacheSize = uint64(signedCacheSize) } - highWaterPercentage := param.FileCache_HighWaterMarkPercentage.GetInt() - lowWaterPercentage := param.FileCache_LowWaterMarkPercentage.GetInt() + highWaterPercentage := param.LocalCache_HighWaterMarkPercentage.GetInt() + lowWaterPercentage := param.LocalCache_LowWaterMarkPercentage.GetInt() directorUrl, err := 
url.Parse(param.Federation_DirectorUrl.GetString()) if err != nil { return } - sc = &SimpleCache{ + sc = &LocalCache{ ctx: ctx, egrp: egrp, te: client.NewTransferEngine(ctx), @@ -291,7 +291,7 @@ func NewSimpleCache(ctx context.Context, egrp *errgroup.Group) (sc *SimpleCache, // // The TransferClient will invoke the callback as it progresses; // the callback info will be used to help the waiters progress. -func (sc *SimpleCache) callback(path string, downloaded int64, size int64, completed bool) { +func (sc *LocalCache) callback(path string, downloaded int64, size int64, completed bool) { ds := func() (ds *downloadStatus) { sc.mutex.RLock() defer sc.mutex.RUnlock() @@ -309,7 +309,7 @@ func (sc *SimpleCache) callback(path string, downloaded int64, size int64, compl } // The main goroutine for managing the cache and its requests -func (sc *SimpleCache) runMux() error { +func (sc *LocalCache) runMux() error { results := sc.tc.Results() type result struct { @@ -563,7 +563,7 @@ func (sc *SimpleCache) runMux() error { } } -func (sc *SimpleCache) purge() { +func (sc *LocalCache) purge() { heap.Init(&sc.lru) start := time.Now() for sc.cacheSize > sc.lowWater { @@ -587,7 +587,7 @@ func (sc *SimpleCache) purge() { // Given a URL, return a reader from the disk cache // // If there is no sentinal $NAME.DONE file, then returns nil -func (sc *SimpleCache) getFromDisk(localPath string) *os.File { +func (sc *LocalCache) getFromDisk(localPath string) *os.File { localPath = filepath.Join(sc.basePath, path.Clean(localPath)) fp, err := os.Open(localPath + ".DONE") if err != nil { @@ -600,7 +600,7 @@ func (sc *SimpleCache) getFromDisk(localPath string) *os.File { return nil } -func (sc *SimpleCache) newCacheReader(path, token string) (reader *cacheReader, err error) { +func (sc *LocalCache) newCacheReader(path, token string) (reader *cacheReader, err error) { reader = &cacheReader{ path: path, token: token, @@ -612,7 +612,7 @@ func (sc *SimpleCache) newCacheReader(path, token string) 
(reader *cacheReader, } // Get path from the cache -func (sc *SimpleCache) Get(path, token string) (io.ReadCloser, error) { +func (sc *LocalCache) Get(path, token string) (io.ReadCloser, error) { if !sc.ac.authorize(token_scopes.Storage_Read, path, token) { return nil, authorizationDenied } @@ -625,7 +625,7 @@ func (sc *SimpleCache) Get(path, token string) (io.ReadCloser, error) { } -func (lc *SimpleCache) Stat(path, token string) (uint64, error) { +func (lc *LocalCache) Stat(path, token string) (uint64, error) { if !lc.ac.authorize(token_scopes.Storage_Read, path, token) { return 0, authorizationDenied } @@ -645,7 +645,7 @@ func (lc *SimpleCache) Stat(path, token string) (uint64, error) { return client.DoStat(context.Background(), dUrl.String(), client.WithToken(token)) } -func (sc *SimpleCache) updateConfig() error { +func (sc *LocalCache) updateConfig() error { // Get the endpoint of the director var respNS []common.NamespaceAdV2 @@ -679,7 +679,7 @@ func (sc *SimpleCache) updateConfig() error { } // Periodically update the cache configuration from the registry -func (sc *SimpleCache) periodicUpdateConfig() error { +func (sc *LocalCache) periodicUpdateConfig() error { ticker := time.NewTicker(time.Minute) for { select { diff --git a/param/parameters.go b/param/parameters.go index 5b18ec302..e48427fab 100644 --- a/param/parameters.go +++ b/param/parameters.go @@ -90,10 +90,6 @@ var ( Federation_RegistryUrl = StringParam{"Federation.RegistryUrl"} Federation_TopologyNamespaceUrl = StringParam{"Federation.TopologyNamespaceUrl"} Federation_TopologyUrl = StringParam{"Federation.TopologyUrl"} - FileCache_DataLocation = StringParam{"FileCache.DataLocation"} - FileCache_RunLocation = StringParam{"FileCache.RunLocation"} - FileCache_Size = StringParam{"FileCache.Size"} - FileCache_Socket = StringParam{"FileCache.Socket"} IssuerKey = StringParam{"IssuerKey"} Issuer_AuthenticationSource = StringParam{"Issuer.AuthenticationSource"} Issuer_GroupFile = 
StringParam{"Issuer.GroupFile"} @@ -102,6 +98,10 @@ var ( Issuer_QDLLocation = StringParam{"Issuer.QDLLocation"} Issuer_ScitokensServerLocation = StringParam{"Issuer.ScitokensServerLocation"} Issuer_TomcatLocation = StringParam{"Issuer.TomcatLocation"} + LocalCache_DataLocation = StringParam{"LocalCache.DataLocation"} + LocalCache_RunLocation = StringParam{"LocalCache.RunLocation"} + LocalCache_Size = StringParam{"LocalCache.Size"} + LocalCache_Socket = StringParam{"LocalCache.Socket"} Logging_Cache_Ofs = StringParam{"Logging.Cache.Ofs"} Logging_Cache_Pss = StringParam{"Logging.Cache.Pss"} Logging_Cache_Scitokens = StringParam{"Logging.Cache.Scitokens"} @@ -206,8 +206,8 @@ var ( Director_MaxStatResponse = IntParam{"Director.MaxStatResponse"} Director_MinStatResponse = IntParam{"Director.MinStatResponse"} Director_StatConcurrencyLimit = IntParam{"Director.StatConcurrencyLimit"} - FileCache_HighWaterMarkPercentage = IntParam{"FileCache.HighWaterMarkPercentage"} - FileCache_LowWaterMarkPercentage = IntParam{"FileCache.LowWaterMarkPercentage"} + LocalCache_HighWaterMarkPercentage = IntParam{"LocalCache.HighWaterMarkPercentage"} + LocalCache_LowWaterMarkPercentage = IntParam{"LocalCache.LowWaterMarkPercentage"} MinimumDownloadSpeed = IntParam{"MinimumDownloadSpeed"} Monitoring_PortHigher = IntParam{"Monitoring.PortHigher"} Monitoring_PortLower = IntParam{"Monitoring.PortLower"} diff --git a/param/parameters_struct.go b/param/parameters_struct.go index a5d48cdb8..6531d90bc 100644 --- a/param/parameters_struct.go +++ b/param/parameters_struct.go @@ -73,14 +73,6 @@ type Config struct { TopologyReloadInterval time.Duration TopologyUrl string } - FileCache struct { - DataLocation string - HighWaterMarkPercentage int - LowWaterMarkPercentage int - RunLocation string - Size string - Socket string - } GeoIPOverrides interface{} Issuer struct { AuthenticationSource string @@ -95,6 +87,14 @@ type Config struct { TomcatLocation string } IssuerKey string + LocalCache struct { + 
DataLocation string + HighWaterMarkPercentage int + LowWaterMarkPercentage int + RunLocation string + Size string + Socket string + } Logging struct { Cache struct { Ofs string @@ -306,14 +306,6 @@ type configWithType struct { TopologyReloadInterval struct { Type string; Value time.Duration } TopologyUrl struct { Type string; Value string } } - FileCache struct { - DataLocation struct { Type string; Value string } - HighWaterMarkPercentage struct { Type string; Value int } - LowWaterMarkPercentage struct { Type string; Value int } - RunLocation struct { Type string; Value string } - Size struct { Type string; Value string } - Socket struct { Type string; Value string } - } GeoIPOverrides struct { Type string; Value interface{} } Issuer struct { AuthenticationSource struct { Type string; Value string } @@ -328,6 +320,14 @@ type configWithType struct { TomcatLocation struct { Type string; Value string } } IssuerKey struct { Type string; Value string } + LocalCache struct { + DataLocation struct { Type string; Value string } + HighWaterMarkPercentage struct { Type string; Value int } + LowWaterMarkPercentage struct { Type string; Value int } + RunLocation struct { Type string; Value string } + Size struct { Type string; Value string } + Socket struct { Type string; Value string } + } Logging struct { Cache struct { Ofs struct { Type string; Value string } From 9abbdd5298b5703eb0e9b626eeb4c4e964bdd37e Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Tue, 5 Mar 2024 09:52:58 -0600 Subject: [PATCH 21/45] Fix compilation on Windows --- local_cache/cache_size_unix.go | 31 +++++++++++++++++++++++++++++++ local_cache/cache_size_windows.go | 21 +++++++++++++++++++++ local_cache/local_cache.go | 21 +++------------------ 3 files changed, 55 insertions(+), 18 deletions(-) create mode 100644 local_cache/cache_size_unix.go create mode 100644 local_cache/cache_size_windows.go diff --git a/local_cache/cache_size_unix.go b/local_cache/cache_size_unix.go new file mode 100644 index 
000000000..b802ea57d --- /dev/null +++ b/local_cache/cache_size_unix.go @@ -0,0 +1,31 @@ +//go:build !windows + +package local_cache + +import ( + "syscall" + + "github.com/alecthomas/units" + "github.com/pelicanplatform/pelican/param" + "github.com/pkg/errors" +) + +func getCacheSize(cacheDir string) (cacheSize uint64, err error) { + sizeStr := param.LocalCache_Size.GetString() + if sizeStr == "" || sizeStr == "0" { + var stat syscall.Statfs_t + if err = syscall.Statfs(cacheDir, &stat); err != nil { + err = errors.Wrapf(err, "unable to determine free space for cache directory %s", cacheDir) + return + } + cacheSize = stat.Bavail * uint64(stat.Bsize) + } else { + var signedCacheSize int64 + signedCacheSize, err = units.ParseStrictBytes(param.LocalCache_Size.GetString()) + if err != nil { + return + } + cacheSize = uint64(signedCacheSize) + } + return +} diff --git a/local_cache/cache_size_windows.go b/local_cache/cache_size_windows.go new file mode 100644 index 000000000..63d9371b7 --- /dev/null +++ b/local_cache/cache_size_windows.go @@ -0,0 +1,21 @@ +//go:build windows + +package local_cache + +import ( + "github.com/alecthomas/units" + "github.com/pelicanplatform/pelican/param" + log "github.com/sirupsen/logrus" +) + +func getCacheSize(cacheDir string) (cacheSize uint64, err error) { + sizeStr := param.LocalCache_Size.GetString() + if sizeStr == "" || sizeStr == "0" { + log.Warningln("Cache size is unset and Pelican is unable to determine filesystem size; using 10GB as the default") + sizeStr = "10GB" + } + if signedCacheSize, err := units.ParseStrictBytes(param.LocalCache_Size.GetString()); err == nil { + cacheSize = uint64(signedCacheSize) + } + return +} diff --git a/local_cache/local_cache.go b/local_cache/local_cache.go index 2557e81e5..fc5729ce7 100644 --- a/local_cache/local_cache.go +++ b/local_cache/local_cache.go @@ -32,10 +32,8 @@ import ( "slices" "sync" "sync/atomic" - "syscall" "time" - "github.com/alecthomas/units" "github.com/google/uuid" 
"github.com/pelicanplatform/pelican/client" "github.com/pelicanplatform/pelican/common" @@ -227,22 +225,9 @@ func NewLocalCache(ctx context.Context, egrp *errgroup.Group) (sc *LocalCache, e return } - sizeStr := param.LocalCache_Size.GetString() - var cacheSize uint64 - if sizeStr == "" || sizeStr == "0" { - var stat syscall.Statfs_t - if err = syscall.Statfs(cacheDir, &stat); err != nil { - err = errors.Wrapf(err, "unable to determine free space for cache directory %s", cacheDir) - return - } - cacheSize = stat.Bavail * uint64(stat.Bsize) - } else { - var signedCacheSize int64 - signedCacheSize, err = units.ParseStrictBytes(param.LocalCache_Size.GetString()) - if err != nil { - return - } - cacheSize = uint64(signedCacheSize) + cacheSize, err := getCacheSize(cacheDir) + if err != nil { + return } highWaterPercentage := param.LocalCache_HighWaterMarkPercentage.GetInt() lowWaterPercentage := param.LocalCache_LowWaterMarkPercentage.GetInt() From d08c46b70049ae7a55ec6f42bdb45e134ecefc35 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Tue, 5 Mar 2024 14:34:02 -0800 Subject: [PATCH 22/45] Fix error handling around the client Make sure that `DoGet` / `DoCopy` / `DoPut` goes through the transfer results and promote any errors. Also touch up exported and internal functions in the client to follow go naming practices. 
--- client/get_best_cache.go | 2 +- client/handle_http.go | 2 +- client/main.go | 37 +++++++++++++++++++++++++++---------- client/main_test.go | 2 +- 4 files changed, 30 insertions(+), 13 deletions(-) diff --git a/client/get_best_cache.go b/client/get_best_cache.go index 6e042e6e0..89c5ebd2d 100644 --- a/client/get_best_cache.go +++ b/client/get_best_cache.go @@ -74,7 +74,7 @@ func GetBestCache(cacheListName string) ([]string, error) { headers.Host = cur_site log.Debugf("Trying server site of %s", cur_site) - for _, ip := range get_ips(cur_site) { + for _, ip := range getIPs(cur_site) { GeoIpUrl.Host = ip GeoIpUrl.Scheme = "http" diff --git a/client/handle_http.go b/client/handle_http.go index 02db6d26f..8c0affc80 100644 --- a/client/handle_http.go +++ b/client/handle_http.go @@ -1158,7 +1158,7 @@ func runTransferWorker(ctx context.Context, workChan <-chan *clientTransferFile, } transferResults.jobId = file.jobId if err != nil { - log.Errorf("Error when attempting to transfer object %s for client %s", file.file.remoteURL, file.uuid.String()) + log.Errorf("Error when attempting to transfer object %s for client %s: %v", file.file.remoteURL, file.uuid.String(), err) transferResults = newTransferResults(file.file.job) transferResults.Error = err } else if transferResults.Error == nil { diff --git a/client/main.go b/client/main.go index 9e362cb5b..efc5fc4e4 100644 --- a/client/main.go +++ b/client/main.go @@ -197,7 +197,7 @@ func DoStat(ctx context.Context, destination string, options ...TransferOption) understoodSchemes := []string{"osdf", "pelican", ""} - _, foundSource := Find(understoodSchemes, destUri.Scheme) + _, foundSource := find(understoodSchemes, destUri.Scheme) if !foundSource { log.Errorln("Unknown schema provided:", destUri.Scheme) return 0, errors.New("Unsupported scheme requested") @@ -490,7 +490,7 @@ func DoPut(ctx context.Context, localObject string, remoteDestination string, re understoodSchemes := []string{"file", "osdf", "pelican", ""} - _, 
foundDest := Find(understoodSchemes, remoteDestScheme) + _, foundDest := find(understoodSchemes, remoteDestScheme) if !foundDest { return nil, fmt.Errorf("Do not understand the destination scheme: %s. Permitted values are %s", remoteDestUrl.Scheme, strings.Join(understoodSchemes, ", ")) @@ -517,6 +517,11 @@ func DoPut(ctx context.Context, localObject string, remoteDestination string, re if tj.lookupErr != nil { err = tj.lookupErr } + for _, result := range transferResults { + if err == nil && result.Error != nil { + err = result.Error + } + } return } @@ -575,7 +580,7 @@ func DoGet(ctx context.Context, remoteObject string, localDestination string, re understoodSchemes := []string{"file", "osdf", "pelican", ""} - _, foundSource := Find(understoodSchemes, remoteObjectScheme) + _, foundSource := find(understoodSchemes, remoteObjectScheme) if !foundSource { return nil, fmt.Errorf("Do not understand the source scheme: %s. Permitted values are %s", remoteObjectUrl.Scheme, strings.Join(understoodSchemes, ", ")) @@ -646,6 +651,10 @@ func DoGet(ctx context.Context, remoteObject string, localDestination string, re var downloaded int64 = 0 for _, result := range transferResults { downloaded += result.TransferredBytes + if err == nil && result.Error != nil { + success = false + err = result.Error + } } payload.end1 = end.Unix() @@ -741,13 +750,13 @@ func DoCopy(ctx context.Context, sourceFile string, destination string, recursiv understoodSchemes := []string{"stash", "file", "osdf", "pelican", ""} - _, foundSource := Find(understoodSchemes, sourceScheme) + _, foundSource := find(understoodSchemes, sourceScheme) if !foundSource { log.Errorln("Do not understand source scheme:", sourceURL.Scheme) return nil, errors.New("Do not understand source scheme") } - _, foundDest := Find(understoodSchemes, destScheme) + _, foundDest := find(understoodSchemes, destScheme) if !foundDest { log.Errorln("Do not understand destination scheme:", sourceURL.Scheme) return nil, errors.New("Do not 
understand destination scheme") @@ -827,13 +836,21 @@ func DoCopy(ctx context.Context, sourceFile string, destination string, recursiv } transferResults, err = tc.Shutdown() if err == nil { - success = true + if tj.lookupErr == nil { + success = true + } else { + err = tj.lookupErr + } } end := time.Now() for _, result := range transferResults { downloaded += result.TransferredBytes + if err == nil && result.Error != nil { + success = false + err = result.Error + } } payload.end1 = end.Unix() @@ -854,10 +871,10 @@ func DoCopy(ctx context.Context, sourceFile string, destination string, recursiv } } -// Find takes a slice and looks for an element in it. If found it will +// find takes a slice and looks for an element in it. If found it will // return it's key, otherwise it will return -1 and a bool of false. // From https://golangcode.com/check-if-element-exists-in-slice/ -func Find(slice []string, val string) (int, bool) { +func find(slice []string, val string) (int, bool) { for i, item := range slice { if item == val { return i, true @@ -866,10 +883,10 @@ func Find(slice []string, val string) (int, bool) { return -1, false } -// get_ips will resolve a hostname and return all corresponding IP addresses +// getIPs will resolve a hostname and return all corresponding IP addresses // in DNS. 
This can be used to randomly pick an IP when DNS round robin // is used -func get_ips(name string) []string { +func getIPs(name string) []string { var ipv4s []string var ipv6s []string diff --git a/client/main_test.go b/client/main_test.go index 8f28c61ac..435a22e9b 100644 --- a/client/main_test.go +++ b/client/main_test.go @@ -39,7 +39,7 @@ import ( func TestGetIps(t *testing.T) { t.Parallel() - ips := get_ips("wlcg-wpad.fnal.gov") + ips := getIPs("wlcg-wpad.fnal.gov") for _, ip := range ips { parsedIP := net.ParseIP(ip) if parsedIP.To4() != nil { From f68e9407126304e055eb1ed50ef71da6cf0281ee Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Tue, 5 Mar 2024 15:13:16 -0800 Subject: [PATCH 23/45] Add support for on-demand purging and corresponding unit tests Ensure that purging works on a full cache. --- config/resources/defaults.yaml | 2 +- docs/scopes.yaml | 9 ++ launchers/launcher.go | 17 ++- local_cache/cache_api.go | 69 +++++++-- local_cache/cache_size_windows.go | 2 +- local_cache/cache_test.go | 242 ++++++++++++++++++++++++++---- local_cache/local_cache.go | 156 +++++++++++++------ token_scopes/token_scopes.go | 1 + 8 files changed, 404 insertions(+), 94 deletions(-) diff --git a/config/resources/defaults.yaml b/config/resources/defaults.yaml index ec40165c3..3326ae57c 100644 --- a/config/resources/defaults.yaml +++ b/config/resources/defaults.yaml @@ -46,7 +46,7 @@ Director: EnableBroker: true Cache: Port: 8442 -FileCache: +LocalCache: HighWaterMarkPercentage: 95 LowWaterMarkPercentage: 85 Origin: diff --git a/docs/scopes.yaml b/docs/scopes.yaml index 212b9b83c..458427ec1 100644 --- a/docs/scopes.yaml +++ b/docs/scopes.yaml @@ -94,6 +94,15 @@ issuedBy: ["origin"] acceptedBy: [cache"] --- ############################ +# LocalCache Scopes # +############################ +name: localcache.purge +description: >- + Permits invocation of the purge routine in a local cache +issuedBy: ["localcache"] +acceptedBy: ["localcache"] +--- +############################ # 
Storage Scopes # ############################ name: "storage.read" diff --git a/launchers/launcher.go b/launchers/launcher.go index 2b76455c7..e93da229d 100644 --- a/launchers/launcher.go +++ b/launchers/launcher.go @@ -183,6 +183,17 @@ func LaunchModules(ctx context.Context, modules config.ServerType) (context.Canc } } + var lc *local_cache.LocalCache + if modules.IsEnabled(config.LocalCacheType) { + // Create and register the cache routines before the web interface is up + lc, err = local_cache.NewLocalCache(ctx, egrp, local_cache.WithDeferConfig(true)) + if err != nil { + return shutdownCancel, err + } + rootGroup := engine.Group("/") + lc.Register(ctx, rootGroup) + } + log.Info("Starting web engine...") lnReference = nil egrp.Go(func() error { @@ -251,10 +262,14 @@ func LaunchModules(ctx context.Context, modules config.ServerType) (context.Canc if modules.IsEnabled(config.LocalCacheType) { log.Debugln("Starting local cache listener") - if err := local_cache.LaunchListener(ctx, egrp); err != nil { + if err := lc.Config(egrp); err != nil { + log.Warning("Failure when configuring the local cache; cache may incorrectly generate 403 errors until reconfiguration runs") + } + if err := lc.LaunchListener(ctx, egrp); err != nil { log.Errorln("Failure when starting the local cache listener:", err) return shutdownCancel, err } + } if param.Server_EnableUI.GetBool() { diff --git a/local_cache/cache_api.go b/local_cache/cache_api.go index 3d861353e..d4e9b772a 100644 --- a/local_cache/cache_api.go +++ b/local_cache/cache_api.go @@ -31,26 +31,41 @@ import ( "strconv" "strings" + "github.com/gin-gonic/gin" + "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/pelicanplatform/pelican/config" "github.com/pelicanplatform/pelican/param" + "github.com/pelicanplatform/pelican/token_scopes" "github.com/pkg/errors" log "github.com/sirupsen/logrus" "golang.org/x/sync/errgroup" ) +type ( + localCacheResp struct { + Status localCacheResponseStatus `json:"status"` + Msg string 
`json:"msg,omitempty"` + } + + localCacheResponseStatus string +) + +const ( + responseOk localCacheResponseStatus = "success" + responseFailed localCacheResponseStatus = "error" +) + // Launch the unix socket listener as a separate goroutine -func LaunchListener(ctx context.Context, egrp *errgroup.Group) error { +func (lc *LocalCache) LaunchListener(ctx context.Context, egrp *errgroup.Group) (err error) { socketName := param.LocalCache_Socket.GetString() - if err := os.MkdirAll(filepath.Dir(socketName), fs.FileMode(0755)); err != nil { - return errors.Wrap(err, "failed to create socket directory") + if err = os.MkdirAll(filepath.Dir(socketName), fs.FileMode(0755)); err != nil { + err = errors.Wrap(err, "failed to create socket directory") + return } listener, err := net.ListenUnix("unix", &net.UnixAddr{Name: socketName, Net: "unix"}) if err != nil { - return err - } - sc, err := NewLocalCache(ctx, egrp) - if err != nil { - return err + return } handler := func(w http.ResponseWriter, r *http.Request) { @@ -64,6 +79,7 @@ func LaunchListener(ctx context.Context, egrp *errgroup.Group) error { for _, encoding := range r.Header.Values("TE") { if encoding == "trailers" { sendTrailer = true + w.Header().Set("Trailer", "X-Transfer-Status") break } } @@ -79,12 +95,12 @@ func LaunchListener(ctx context.Context, egrp *errgroup.Group) error { var size uint64 var reader io.ReadCloser if r.Method == "HEAD" { - size, err = sc.Stat(path, bearerToken) + size, err = lc.Stat(path, bearerToken) if err == nil { w.Header().Set("Content-Length", strconv.FormatUint(size, 10)) } } else { - reader, err = sc.Get(path, bearerToken) + reader, err = lc.Get(path, bearerToken) } if errors.Is(err, authorizationDenied) { w.WriteHeader(http.StatusForbidden) @@ -121,5 +137,36 @@ func LaunchListener(ctx context.Context, egrp *errgroup.Group) error { <-ctx.Done() return srv.Shutdown(ctx) }) - return nil + return +} + +// Register the control & monitoring routines with Gin +func (lc *LocalCache) 
Register(ctx context.Context, router *gin.RouterGroup) { + router.POST("/api/v1.0/localcache/purge", func(ginCtx *gin.Context) { lc.purgeCmd(ginCtx) }) +} + +// Authorize the request then trigger the purge routine +func (lc *LocalCache) purgeCmd(ginCtx *gin.Context) { + token := ginCtx.GetHeader("Authorization") + var hasPrefix bool + if token, hasPrefix = strings.CutPrefix(token, "Bearer "); !hasPrefix { + ginCtx.AbortWithStatusJSON(http.StatusUnauthorized, localCacheResp{responseFailed, "Bearer token required to authenticate"}) + return + } + + jwks, err := config.GetIssuerPublicJWKS() + if err != nil { + ginCtx.AbortWithStatusJSON(http.StatusInternalServerError, localCacheResp{responseFailed, "Unable to get local server token issuer"}) + return + } + tok, err := jwt.Parse([]byte(token), jwt.WithKeySet(jwks)) + if err != nil { + ginCtx.AbortWithStatusJSON(http.StatusUnauthorized, localCacheResp{responseFailed, "Authorization token cannot be verified"}) + } + scopeValidator := token_scopes.CreateScopeValidator([]token_scopes.TokenScope{token_scopes.Localcache_Purge}, true) + if err = jwt.Validate(tok, jwt.WithValidator(scopeValidator)); err != nil { + ginCtx.AbortWithStatusJSON(http.StatusUnauthorized, localCacheResp{responseFailed, "Authorization token is not valid: " + err.Error()}) + return + } + lc.purge() } diff --git a/local_cache/cache_size_windows.go b/local_cache/cache_size_windows.go index 63d9371b7..f8327765e 100644 --- a/local_cache/cache_size_windows.go +++ b/local_cache/cache_size_windows.go @@ -8,7 +8,7 @@ import ( log "github.com/sirupsen/logrus" ) -func getCacheSize(cacheDir string) (cacheSize uint64, err error) { +func getCacheSize(string) (cacheSize uint64, err error) { sizeStr := param.LocalCache_Size.GetString() if sizeStr == "" || sizeStr == "0" { log.Warningln("Cache size is unset and Pelican is unable to determine filesystem size; using 10GB as the default") diff --git a/local_cache/cache_test.go b/local_cache/cache_test.go index 
261c6ac91..9345770a1 100644 --- a/local_cache/cache_test.go +++ b/local_cache/cache_test.go @@ -20,6 +20,7 @@ package local_cache_test import ( "context" + "errors" "fmt" "io" "net" @@ -39,6 +40,7 @@ import ( "github.com/pelicanplatform/pelican/test_utils" "github.com/pelicanplatform/pelican/token_scopes" "github.com/pelicanplatform/pelican/utils" + log "github.com/sirupsen/logrus" "github.com/spf13/viper" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -174,6 +176,7 @@ func TestFedPublicGet(t *testing.T) { assert.Equal(t, "Hello, World!", string(byteBuff)) } +// Test the local cache library on an authenticated GET. func TestFedAuthGet(t *testing.T) { ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) defer cancel() @@ -212,6 +215,7 @@ func TestFedAuthGet(t *testing.T) { assert.Equal(t, "authorization denied", err.Error()) } +// Test a raw HTTP request (no Pelican client) works with the local cache func TestHttpReq(t *testing.T) { ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) defer cancel() @@ -239,6 +243,7 @@ func TestHttpReq(t *testing.T) { assert.Equal(t, "Hello, World!", string(body)) } +// Test that the client library (with authentication) works with the local cache func TestClient(t *testing.T) { ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) defer cancel() @@ -254,21 +259,54 @@ func TestClient(t *testing.T) { Path: param.LocalCache_Socket.GetString(), } - discoveryHost := param.Federation_DiscoveryUrl.GetString() - discoveryUrl, err := url.Parse(discoveryHost) - require.NoError(t, err) - tr, err := client.DoGet(ctx, "pelican://"+discoveryUrl.Host+"/test/hello_world.txt", filepath.Join(tmpDir, "hello_world.txt"), false, - client.WithToken(ft.token), client.WithCaches(cacheUrl), client.WithAcquireToken(false)) - assert.NoError(t, err) - require.Equal(t, 1, len(tr)) - assert.Equal(t, int64(13), tr[0].TransferredBytes) - assert.NoError(t, tr[0].Error) + 
t.Run("correct-auth", func(t *testing.T) { + discoveryHost := param.Federation_DiscoveryUrl.GetString() + discoveryUrl, err := url.Parse(discoveryHost) + require.NoError(t, err) + tr, err := client.DoGet(ctx, "pelican://"+discoveryUrl.Host+"/test/hello_world.txt", filepath.Join(tmpDir, "hello_world.txt"), false, + client.WithToken(ft.token), client.WithCaches(cacheUrl), client.WithAcquireToken(false)) + assert.NoError(t, err) + require.Equal(t, 1, len(tr)) + assert.Equal(t, int64(13), tr[0].TransferredBytes) + assert.NoError(t, tr[0].Error) + + byteBuff, err := os.ReadFile(filepath.Join(tmpDir, "hello_world.txt")) + assert.NoError(t, err) + assert.Equal(t, "Hello, World!", string(byteBuff)) + }) + t.Run("incorrect-auth", func(t *testing.T) { + _, err := client.DoGet(ctx, "pelican:///test/hello_world.txt", filepath.Join(tmpDir, "hello_world.txt"), false, + client.WithToken("badtoken"), client.WithCaches(cacheUrl), client.WithAcquireToken(false)) + assert.Error(t, err) + assert.ErrorIs(t, err, &client.ConnectionSetupError{}) + var cse *client.ConnectionSetupError + assert.True(t, errors.As(err, &cse)) + assert.Equal(t, "failed connection setup: server returned 403 Forbidden", cse.Error()) + }) - byteBuff, err := os.ReadFile(filepath.Join(tmpDir, "hello_world.txt")) - assert.NoError(t, err) - assert.Equal(t, "Hello, World!", string(byteBuff)) + t.Run("file-not-found", func(t *testing.T) { + issuer, err := config.GetServerIssuerURL() + require.NoError(t, err) + tokConf := utils.TokenConfig{ + TokenProfile: utils.WLCG, + Lifetime: time.Duration(time.Minute), + Issuer: issuer, + Subject: "test", + Audience: []string{utils.WLCGAny}, + } + tokConf.AddResourceScopes(token_scopes.NewResourceScope(token_scopes.Storage_Read, "/hello_world.txt.1")) + + token, err := tokConf.CreateToken() + require.NoError(t, err) + + _, err = client.DoGet(ctx, "pelican:///test/hello_world.txt.1", filepath.Join(tmpDir, "hello_world.txt.1"), false, + client.WithToken(token), 
client.WithCaches(cacheUrl), client.WithAcquireToken(false)) + assert.Error(t, err) + assert.Equal(t, "failed to download file: transfer error: failed connection setup: server returned 404 Not Found", err.Error()) + }) } +// Test that HEAD requests to the local cache return the correct result func TestStat(t *testing.T) { ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) defer cancel() @@ -296,6 +334,37 @@ func TestStat(t *testing.T) { assert.Equal(t, uint64(13), size) } +// Creates a buffer of at least 1MB +func makeBigBuffer() []byte { + byteBuff := []byte("Hello, World!") + for { + byteBuff = append(byteBuff, []byte("Hello, World!")...) + if len(byteBuff) > 1024*1024 { + break + } + } + return byteBuff +} + +// Writes a file at least the specified size in MB +func writeBigBuffer(t *testing.T, fp io.WriteCloser, sizeMB int) (size int) { + defer fp.Close() + byteBuff := makeBigBuffer() + size = 0 + for { + n, err := fp.Write(byteBuff) + require.NoError(t, err) + size += n + if size > sizeMB*1024*1024 { + break + } + } + return +} + +// Create a 100MB file in the origin. Download it (slowly) via the local cache. +// +// This triggers multiple internal requests to wait on the slow download func TestLargeFile(t *testing.T) { ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) defer cancel() @@ -314,24 +383,7 @@ func TestLargeFile(t *testing.T) { fp, err := os.OpenFile(filepath.Join(ft.originDir, "hello_world.txt"), os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) require.NoError(t, err) - - byteBuff := []byte("Hello, World!") - for { - byteBuff = append(byteBuff, []byte("Hello, World!")...) 
- if len(byteBuff) > 4096 { - break - } - } - size := 0 - for { - n, err := fp.Write(byteBuff) - require.NoError(t, err) - size += n - if size > 100*1024*1024 { - break - } - } - fp.Close() + size := writeBigBuffer(t, fp, 100) discoveryHost := param.Federation_DiscoveryUrl.GetString() discoveryUrl, err := url.Parse(discoveryHost) @@ -344,3 +396,133 @@ func TestLargeFile(t *testing.T) { assert.NoError(t, tr[0].Error) } + +// Create five 1MB files. Trigger a purge, ensuring that the cleanup is +// done according to LRU +func TestPurge(t *testing.T) { + ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) + defer cancel() + tmpDir := t.TempDir() + + viper.Reset() + viper.Set("Origin.EnablePublicReads", true) + viper.Set("LocalCache.Size", "5MB") + ft := fedTest{} + ft.spinup(t, ctx, egrp) + + cacheUrl := &url.URL{ + Scheme: "unix", + Path: param.LocalCache_Socket.GetString(), + } + + size := 0 + for idx := 0; idx < 5; idx++ { + log.Debugln("Will write origin file", filepath.Join(ft.originDir, fmt.Sprintf("hello_world.txt.%d", idx))) + fp, err := os.OpenFile(filepath.Join(ft.originDir, fmt.Sprintf("hello_world.txt.%d", idx)), os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) + require.NoError(t, err) + size = writeBigBuffer(t, fp, 1) + } + require.NotEqual(t, 0, size) + + for idx := 0; idx < 5; idx++ { + tr, err := client.DoGet(ctx, fmt.Sprintf("pelican:///test/hello_world.txt.%d", idx), filepath.Join(tmpDir, fmt.Sprintf("hello_world.txt.%d", idx)), false, + client.WithCaches(cacheUrl)) + assert.NoError(t, err) + require.Equal(t, 1, len(tr)) + assert.Equal(t, int64(size), tr[0].TransferredBytes) + assert.NoError(t, tr[0].Error) + } + + // Size of the cache should be just small enough that the 5th file triggers LRU deletion of the first. 
+ for idx := 0; idx < 5; idx++ { + func() { + fp, err := os.Open(filepath.Join(param.LocalCache_DataLocation.GetString(), "test", fmt.Sprintf("hello_world.txt.%d.DONE", idx))) + if idx == 0 { + log.Errorln("Error:", err) + assert.ErrorIs(t, err, os.ErrNotExist) + } else { + assert.NoError(t, err) + } + defer fp.Close() + }() + } +} + +// Create four 1MB files (above low-water mark). Force a purge, ensuring that the cleanup is +// done according to LRU +func TestForcePurge(t *testing.T) { + ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) + defer cancel() + tmpDir := t.TempDir() + + viper.Reset() + viper.Set("Origin.EnablePublicReads", true) + viper.Set("LocalCache.Size", "5MB") + // Decrease the low water mark so invoking purge will result in 3 files in the cache. + viper.Set("LocalCache.LowWaterMarkPercentage", "80") + ft := fedTest{} + ft.spinup(t, ctx, egrp) + + issuer, err := config.GetServerIssuerURL() + require.NoError(t, err) + tokConf := utils.TokenConfig{ + TokenProfile: utils.WLCG, + Lifetime: time.Duration(time.Minute), + Issuer: issuer, + Subject: "test", + Audience: []string{utils.WLCGAny}, + } + tokConf.AddScopes(token_scopes.Localcache_Purge) + + token, err := tokConf.CreateToken() + require.NoError(t, err) + + cacheUrl := &url.URL{ + Scheme: "unix", + Path: param.LocalCache_Socket.GetString(), + } + + // Populate the cache with our test files + size := 0 + for idx := 0; idx < 4; idx++ { + log.Debugln("Will write origin file", filepath.Join(ft.originDir, fmt.Sprintf("hello_world.txt.%d", idx))) + fp, err := os.OpenFile(filepath.Join(ft.originDir, fmt.Sprintf("hello_world.txt.%d", idx)), os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) + require.NoError(t, err) + size = writeBigBuffer(t, fp, 1) + } + require.NotEqual(t, 0, size) + + for idx := 0; idx < 4; idx++ { + tr, err := client.DoGet(ctx, fmt.Sprintf("pelican:///test/hello_world.txt.%d", idx), filepath.Join(tmpDir, fmt.Sprintf("hello_world.txt.%d", idx)), false, + 
client.WithCaches(cacheUrl)) + assert.NoError(t, err) + require.Equal(t, 1, len(tr)) + assert.Equal(t, int64(size), tr[0].TransferredBytes) + assert.NoError(t, tr[0].Error) + } + + // Size of the cache should be large enough that purge hasn't fired yet. + for idx := 0; idx < 4; idx++ { + func() { + fp, err := os.Open(filepath.Join(param.LocalCache_DataLocation.GetString(), "test", fmt.Sprintf("hello_world.txt.%d.DONE", idx))) + assert.NoError(t, err) + defer fp.Close() + }() + } + + _, err = utils.MakeRequest(ctx, param.Server_ExternalWebUrl.GetString()+"/api/v1.0/localcache/purge", "POST", nil, map[string]string{"Authorization": "Bearer " + token}) + require.NoError(t, err) + + // Low water mark is small enough that a force purge will delete a file. + for idx := 0; idx < 4; idx++ { + func() { + fp, err := os.Open(filepath.Join(param.LocalCache_DataLocation.GetString(), "test", fmt.Sprintf("hello_world.txt.%d.DONE", idx))) + if idx == 0 { + assert.ErrorIs(t, err, os.ErrNotExist) + } else { + assert.NoError(t, err) + } + defer fp.Close() + }() + } +} diff --git a/local_cache/local_cache.go b/local_cache/local_cache.go index fc5729ce7..78af8c501 100644 --- a/local_cache/local_cache.go +++ b/local_cache/local_cache.go @@ -35,6 +35,7 @@ import ( "time" "github.com/google/uuid" + "github.com/lestrrat-go/option" "github.com/pelicanplatform/pelican/client" "github.com/pelicanplatform/pelican/common" "github.com/pelicanplatform/pelican/param" @@ -47,17 +48,19 @@ import ( type ( LocalCache struct { - ctx context.Context - egrp *errgroup.Group - te *client.TransferEngine - tc *client.TransferClient - cancelReq chan cancelReq - basePath string - sizeReq chan availSizeReq - mutex sync.RWMutex - downloads map[string]*activeDownload - directorURL *url.URL - ac *authConfig + ctx context.Context + egrp *errgroup.Group + te *client.TransferEngine + tc *client.TransferClient + cancelReq chan cancelReq + basePath string + sizeReq chan availSizeReq + mutex sync.RWMutex + purgeMutex 
sync.Mutex + downloads map[string]*activeDownload + directorURL *url.URL + ac *authConfig + wasConfigured bool // Cache static configuration highWater uint64 @@ -130,6 +133,9 @@ type ( size int64 results chan *downloadStatus } + + LocalCacheOption = option.Interface + identLocalCacheOptionDeferConfig struct{} ) var ( @@ -207,10 +213,25 @@ func (ds *downloadStatus) String() string { } } +// Create an option to defer the configuration of the local cache +// +// Useful in cases where the cache should be created before the web interface +// is up -- but the web interface is needed to complete configuration. +func WithDeferConfig(deferConfig bool) LocalCacheOption { + return option.New(identLocalCacheOptionDeferConfig{}, deferConfig) +} + // Create a local cache object // // Launches background goroutines associated with the cache -func NewLocalCache(ctx context.Context, egrp *errgroup.Group) (sc *LocalCache, err error) { +func NewLocalCache(ctx context.Context, egrp *errgroup.Group, options ...LocalCacheOption) (lc *LocalCache, err error) { + deferConfig := false + for _, option := range options { + switch option.Ident() { + case identLocalCacheOptionDeferConfig{}: + deferConfig = option.Value().(bool) + } + } // Setup cache on disk cacheDir := param.LocalCache_DataLocation.GetString() @@ -231,13 +252,16 @@ func NewLocalCache(ctx context.Context, egrp *errgroup.Group) (sc *LocalCache, e } highWaterPercentage := param.LocalCache_HighWaterMarkPercentage.GetInt() lowWaterPercentage := param.LocalCache_LowWaterMarkPercentage.GetInt() + highWater := (cacheSize / 100) * uint64(highWaterPercentage) + lowWater := (cacheSize / 100) * uint64(lowWaterPercentage) + log.Infof("Cache size is %d bytes; for purge, high water mark is %d bytes, low water mark is %d bytes", cacheSize, highWater, lowWater) directorUrl, err := url.Parse(param.Federation_DirectorUrl.GetString()) if err != nil { return } - sc = &LocalCache{ + lc = &LocalCache{ ctx: ctx, egrp: egrp, te: 
client.NewTransferEngine(ctx), @@ -245,7 +269,7 @@ func NewLocalCache(ctx context.Context, egrp *errgroup.Group) (sc *LocalCache, e hitChan: make(chan lruEntry, 64), highWater: (cacheSize / 100) * uint64(highWaterPercentage), lowWater: (cacheSize / 100) * uint64(lowWaterPercentage), - cacheSize: cacheSize, + cacheSize: 0, basePath: cacheDir, ac: newAuthConfig(ctx, egrp), sizeReq: make(chan availSizeReq), @@ -253,25 +277,39 @@ func NewLocalCache(ctx context.Context, egrp *errgroup.Group) (sc *LocalCache, e lruLookup: make(map[string]*lruEntry), } - sc.tc, err = sc.te.NewClient(client.WithAcquireToken(false), client.WithCallback(sc.callback)) + lc.tc, err = lc.te.NewClient(client.WithAcquireToken(false), client.WithCallback(lc.callback)) if err != nil { - shutdownErr := sc.te.Shutdown() + shutdownErr := lc.te.Shutdown() if shutdownErr != nil { log.Errorln("Failed to shutdown transfer engine") } return } - if err = sc.updateConfig(); err != nil { - log.Warningln("First attempt to update cache's authorization failed:", err) + if !deferConfig { + if err = lc.Config(egrp); err != nil { + log.Warningln("First attempt to update cache's authorization failed:", err) + } } - egrp.Go(sc.runMux) - egrp.Go(sc.periodicUpdateConfig) + egrp.Go(lc.runMux) log.Debugln("Successfully created a new local cache object") return } +// Try to configure the local cache and launch the reconfigure goroutine +func (lc *LocalCache) Config(egrp *errgroup.Group) (err error) { + if lc.wasConfigured { + return + } + lc.wasConfigured = true + if err = lc.updateConfig(); err != nil { + log.Warningln("First attempt to update cache's authorization failed:", err) + } + egrp.Go(lc.periodicUpdateConfig) + return +} + // Callback for in-progress transfers // // The TransferClient will invoke the callback as it progresses; @@ -396,16 +434,7 @@ func (sc *LocalCache) runMux() error { } else { fp.Close() } - entry := sc.lruLookup[reqPath] - if entry == nil { - entry = &lruEntry{} - sc.lruLookup[reqPath] = entry 
- entry.size = results.TransferredBytes - sc.cacheSize += uint64(entry.size) - sc.lru = append(sc.lru, entry) - } else { - entry.lastUse = time.Now() - } + sc.lruHit(lruEntry{lastUse: time.Now(), path: reqPath, size: results.TransferredBytes}) } } else if chosen == lenChan+2 { // Ticker has fired - update progress @@ -470,6 +499,7 @@ func (sc *LocalCache) runMux() error { channel: req.results, ds: ds, }) + sc.lruHit(lruEntry{lastUse: time.Now(), path: req.request.path, size: fi.Size()}) } } @@ -532,35 +562,56 @@ func (sc *LocalCache) runMux() error { } else if chosen == lenChan+5 { // Notification there was a cache hit. hit := recv.Interface().(lruEntry) - entry := sc.lruLookup[hit.path] - if entry == nil { - entry = &lruEntry{} - sc.lruLookup[hit.path] = entry - entry.size = hit.size - sc.lru = append(sc.lru, entry) - sc.cacheSize += uint64(hit.size) - if sc.cacheSize > sc.highWater { - sc.purge() - } - } - entry.lastUse = hit.lastUse + sc.lruHit(hit) + } + } +} + +func (lc *LocalCache) lruHit(hit lruEntry) { + entry := lc.lruLookup[hit.path] + if entry == nil { + entry = &hit + lc.lruLookup[hit.path] = entry + lc.lru = append(lc.lru, entry) + lc.cacheSize += uint64(hit.size) + if lc.cacheSize > lc.highWater { + lc.purge() } } + entry.lastUse = hit.lastUse + if hit.size > entry.size { + entry.size = hit.size + } } -func (sc *LocalCache) purge() { - heap.Init(&sc.lru) +func (lc *LocalCache) purge() { + log.Debugln("Starting purge routine") + lc.purgeMutex.Lock() + defer lc.purgeMutex.Unlock() + heap.Init(&lc.lru) start := time.Now() - for sc.cacheSize > sc.lowWater { - entry := heap.Pop(&sc.lru).(*lruEntry) - localPath := path.Join(sc.basePath, path.Clean(entry.path)) + for lc.cacheSize > lc.lowWater { + if len(lc.lru) == 0 { + log.Warningln("Potential consistency error: purge ran until cache was empty") + break + } + entry := heap.Pop(&lc.lru).(*lruEntry) + if entry == nil { + log.Warningln("Consistency error: purge run but no entry provided") + continue + } + if 
entry.path == "" { + log.Warningln("Consistency error: purge ran on an empty path") + continue + } + localPath := path.Join(lc.basePath, path.Clean(entry.path)) if err := os.Remove(localPath + ".DONE"); err != nil { log.Warningln("Failed to purge DONE file:", err) } if err := os.Remove(localPath); err != nil { log.Warningln("Failed to purge file:", err) } - sc.cacheSize -= uint64(entry.size) + lc.cacheSize -= uint64(entry.size) // Since purge is called from the mux thread, blocking can cause // other failures; do a time-based break even if we've not hit the low-water if time.Since(start) > 3*time.Second { @@ -603,6 +654,11 @@ func (sc *LocalCache) Get(path, token string) (io.ReadCloser, error) { } if fp := sc.getFromDisk(path); fp != nil { + finfo, err := fp.Stat() + if err != nil { + log.Warningf("Able to open %s in cache but unable to stat it: %v", path, err) + } + sc.hitChan <- lruEntry{lastUse: time.Now(), path: path, size: finfo.Size()} return fp, nil } @@ -626,8 +682,8 @@ func (lc *LocalCache) Stat(path, token string) (uint64, error) { dUrl := *lc.directorURL dUrl.Path = path dUrl.Scheme = "pelican" - log.Debugln("LocalCache doing Stat:", dUrl.String()) - return client.DoStat(context.Background(), dUrl.String(), client.WithToken(token)) + size, err := client.DoStat(context.Background(), dUrl.String(), client.WithToken(token)) + return size, err } func (sc *LocalCache) updateConfig() error { diff --git a/token_scopes/token_scopes.go b/token_scopes/token_scopes.go index c39bc7ed5..b20ef764f 100644 --- a/token_scopes/token_scopes.go +++ b/token_scopes/token_scopes.go @@ -36,6 +36,7 @@ const ( Broker_Reverse TokenScope = "broker.reverse" Broker_Retrieve TokenScope = "broker.retrieve" Broker_Callback TokenScope = "broker.callback" + Localcache_Purge TokenScope = "localcache.purge" // Storage Scopes Storage_Read TokenScope = "storage.read" From 242851e2fff47d86b94a27e40406849a4bd7b100 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Tue, 5 Mar 2024 15:31:36 
-0800 Subject: [PATCH 24/45] Disable caches for all the unit test platforms There are currently sequencing issues in registering the cache for enabling immediate downloads like we've setup. Disable for now, revisit later. --- local_cache/cache_test.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/local_cache/cache_test.go b/local_cache/cache_test.go index 9345770a1..da1b26a99 100644 --- a/local_cache/cache_test.go +++ b/local_cache/cache_test.go @@ -60,11 +60,16 @@ func (ft *fedTest) spinup(t *testing.T, ctx context.Context, egrp *errgroup.Grou modules.Set(config.OriginType) modules.Set(config.DirectorType) modules.Set(config.RegistryType) + // TODO: the cache startup routines not sequenced correctly for the downloads + // to immediately work through the cache. For now, unit tests will just use the origin. + viper.Set("Origin.EnableFallbackRead", true) + /* if runtime.GOOS == "darwin" { viper.Set("Origin.EnableFallbackRead", true) } else { modules.Set(config.CacheType) } + */ modules.Set(config.LocalCacheType) tmpPathPattern := "XRootD-Test_Origin*" From 5edb5cde9748a35b8c25542c9ffff4d3a903c978 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Wed, 6 Mar 2024 05:59:11 -0800 Subject: [PATCH 25/45] For consistency, accept `stash` protocol for `DoStat` It appears that, unlike the other client APIs, `DoStat` has never accepted `stash`; add it for now. 
--- client/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/main.go b/client/main.go index efc5fc4e4..fb18fdc3e 100644 --- a/client/main.go +++ b/client/main.go @@ -195,7 +195,7 @@ func DoStat(ctx context.Context, destination string, options ...TransferOption) return 0, err } - understoodSchemes := []string{"osdf", "pelican", ""} + understoodSchemes := []string{"osdf", "pelican", "stash", ""} _, foundSource := find(understoodSchemes, destUri.Scheme) if !foundSource { From 5383e2190afcd58f2f8f0b6f63d3d7327639510a Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Wed, 6 Mar 2024 06:15:00 -0800 Subject: [PATCH 26/45] Tweak cache test - Do not test on Windows (origin is unavailable). - Ensure `xrootd` user can read the created test files. --- local_cache/cache_test.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/local_cache/cache_test.go b/local_cache/cache_test.go index 0ad38ba7d..23a0c80c4 100644 --- a/local_cache/cache_test.go +++ b/local_cache/cache_test.go @@ -1,3 +1,5 @@ +//go:build !windows + /*************************************************************** * * Copyright (C) 2024, Pelican Project, Morgridge Institute for Research @@ -96,7 +98,7 @@ func (ft *fedTest) spinup(t *testing.T, ctx context.Context, egrp *errgroup.Grou }) // Change the permissions of the temporary origin directory - permissions = os.FileMode(0777) + permissions = os.FileMode(0755) err = os.Chmod(originDir, permissions) require.NoError(t, err) @@ -386,7 +388,7 @@ func TestLargeFile(t *testing.T) { Path: param.LocalCache_Socket.GetString(), } - fp, err := os.OpenFile(filepath.Join(ft.originDir, "hello_world.txt"), os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) + fp, err := os.OpenFile(filepath.Join(ft.originDir, "hello_world.txt"), os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) require.NoError(t, err) size := writeBigBuffer(t, fp, 100) @@ -423,7 +425,7 @@ func TestPurge(t *testing.T) { size := 0 for idx := 0; idx < 5; idx++ { 
log.Debugln("Will write origin file", filepath.Join(ft.originDir, fmt.Sprintf("hello_world.txt.%d", idx))) - fp, err := os.OpenFile(filepath.Join(ft.originDir, fmt.Sprintf("hello_world.txt.%d", idx)), os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) + fp, err := os.OpenFile(filepath.Join(ft.originDir, fmt.Sprintf("hello_world.txt.%d", idx)), os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) require.NoError(t, err) size = writeBigBuffer(t, fp, 1) } From a2159ac84236fc54cb7d7c93ca81076e6f89cfd7 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Wed, 6 Mar 2024 06:37:48 -0800 Subject: [PATCH 27/45] Correct missed file permission --- local_cache/cache_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/local_cache/cache_test.go b/local_cache/cache_test.go index 23a0c80c4..feae745c4 100644 --- a/local_cache/cache_test.go +++ b/local_cache/cache_test.go @@ -493,7 +493,7 @@ func TestForcePurge(t *testing.T) { size := 0 for idx := 0; idx < 4; idx++ { log.Debugln("Will write origin file", filepath.Join(ft.originDir, fmt.Sprintf("hello_world.txt.%d", idx))) - fp, err := os.OpenFile(filepath.Join(ft.originDir, fmt.Sprintf("hello_world.txt.%d", idx)), os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) + fp, err := os.OpenFile(filepath.Join(ft.originDir, fmt.Sprintf("hello_world.txt.%d", idx)), os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) require.NoError(t, err) size = writeBigBuffer(t, fp, 1) } From fb467a53d333b84a5e85631635732e6da2ea45e2 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Wed, 6 Mar 2024 07:14:40 -0800 Subject: [PATCH 28/45] Bugfix: recursive downloads were not incorporating directory structure The recursive downloads repeatedly overwrote the same file; this adds the directory structure to the download and ensures in the unit test that the expected file exists. 
--- client/fed_test.go | 11 ++++++++--- client/handle_http.go | 12 ++++++------ 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/client/fed_test.go b/client/fed_test.go index 196622494..bf6cb17bd 100644 --- a/client/fed_test.go +++ b/client/fed_test.go @@ -26,6 +26,7 @@ import ( "io" "net/http" "os" + "path" "path/filepath" "strconv" "testing" @@ -738,7 +739,8 @@ func TestRecursiveUploadsAndDownloads(t *testing.T) { } // Download the files we just uploaded - transferDetailsDownload, err := client.DoGet(ctx, uploadURL, t.TempDir(), true, client.WithTokenLocation(tempToken.Name())) + tmpDir := t.TempDir() + transferDetailsDownload, err := client.DoGet(ctx, uploadURL, tmpDir, true, client.WithTokenLocation(tempToken.Name())) assert.NoError(t, err) if err == nil && len(transferDetailsUpload) == 2 { countBytesUploadIdx0 := 0 @@ -762,9 +764,12 @@ func TestRecursiveUploadsAndDownloads(t *testing.T) { if countBytesUploadIdx0 != 1 || countBytesUploadIdx1 != 1 { // We would hit this case if 1 counter got hit twice for some reason t.Fatal("One of the files was not downloaded correctly") - } else if len(transferDetailsDownload) != 2 { - t.Fatalf("Amount of transfers results returned for download was not correct. Transfer details returned: %d", len(transferDetailsDownload)) } + contents, err := os.ReadFile(filepath.Join(tmpDir, path.Join(dirName, path.Base(tempFile2.Name())))) + assert.NoError(t, err) + assert.Equal(t, testFileContent2, string(contents)) + } else if err == nil && len(transferDetailsDownload) != 2 { + t.Fatalf("Amount of transfers results returned for download was not correct. 
Transfer details returned: %d", len(transferDetailsDownload)) } }) diff --git a/client/handle_http.go b/client/handle_http.go index 4e9bfc432..5512837ed 100644 --- a/client/handle_http.go +++ b/client/handle_http.go @@ -1185,7 +1185,7 @@ func runTransferWorker(ctx context.Context, workChan <-chan *clientTransferFile, } func downloadObject(transfer *transferFile) (transferResults TransferResults, err error) { - log.Debugln("Downloading file from", transfer.remoteURL) + log.Debugln("Downloading file from", transfer.remoteURL, "to", transfer.localPath) // Remove the source from the file path directory := path.Dir(transfer.localPath) var downloaded int64 @@ -1849,18 +1849,18 @@ func (te *TransferEngine) walkDirDownload(job *clientTransferJob, transfers []tr // // Recursively walks through the remote server directory, emitting transfer files // for the engine to process. -func (te *TransferEngine) walkDirDownloadHelper(job *clientTransferJob, transfers []transferAttemptDetails, files chan *clientTransferFile, path string, client *gowebdav.Client) error { - log.Debugln("Reading directory: ", path) +func (te *TransferEngine) walkDirDownloadHelper(job *clientTransferJob, transfers []transferAttemptDetails, files chan *clientTransferFile, remotePath string, client *gowebdav.Client) error { // Check for cancelation since the client does not respect the context if err := job.job.ctx.Err(); err != nil { return err } - infos, err := client.ReadDir(path) + infos, err := client.ReadDir(remotePath) if err != nil { return err } + localBase := strings.TrimPrefix(remotePath, job.job.remoteURL.Path) for _, info := range infos { - newPath := path + "/" + info.Name() + newPath := remotePath + "/" + info.Name() if info.IsDir() { err := te.walkDirDownloadHelper(job, transfers, files, newPath, client) if err != nil { @@ -1881,7 +1881,7 @@ func (te *TransferEngine) walkDirDownloadHelper(job *clientTransferJob, transfer engine: te, remoteURL: &url.URL{Path: newPath}, packOption: 
transfers[0].PackOption, - localPath: job.job.localPath, + localPath: path.Join(job.job.localPath, localBase, info.Name()), upload: job.job.upload, token: job.job.token, attempts: transfers, From 59a4ef86698ef37670f65ba415e88d61525212bb Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Wed, 6 Mar 2024 13:50:52 -0800 Subject: [PATCH 29/45] Add directory template characters; fails on Linux otherwise --- github_scripts/citests.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/github_scripts/citests.sh b/github_scripts/citests.sh index f745c6ef0..1b962abcb 100755 --- a/github_scripts/citests.sh +++ b/github_scripts/citests.sh @@ -58,7 +58,7 @@ EOF ##################################### ## Test LocalCache in front of OSDF ##################################### -SOCKET_DIR="`mktemp -d -t pelican-citest`" +SOCKET_DIR="`mktemp -d -t pelican-citest-XXXXXX`" export PELICAN_FILECACHE_SOCKET=$SOCKET_DIR/socket export PELICAN_FILECACHE_DATALOCATION=$SOCKET_DIR/data From dceedad4dcd818dfe200c28bb7ab1ab6b85560c9 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Wed, 6 Mar 2024 19:03:21 -0800 Subject: [PATCH 30/45] Make integration test more robust Correct parameter name and wait for the new socket to drop. --- github_scripts/citests.sh | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/github_scripts/citests.sh b/github_scripts/citests.sh index 1b962abcb..1ab31712d 100755 --- a/github_scripts/citests.sh +++ b/github_scripts/citests.sh @@ -59,8 +59,9 @@ EOF ## Test LocalCache in front of OSDF ##################################### SOCKET_DIR="`mktemp -d -t pelican-citest-XXXXXX`" -export PELICAN_FILECACHE_SOCKET=$SOCKET_DIR/socket -export PELICAN_FILECACHE_DATALOCATION=$SOCKET_DIR/data +export PELICAN_LOCALCACHE_SOCKET=$SOCKET_DIR/socket +export PELICAN_LOCALCACHE_DATALOCATION=$SOCKET_DIR/data +export PELICAN_SERVER_ENABLEUI=false ./pelican serve -d -f osg-htc.org --module localcache & PELICAN_PID=$! 
@@ -72,7 +73,17 @@ cleanup() { } trap cleanup EXIT -sleep 1 +for idx in {1..20}; do + if [ -e "$SOCKET_DIR/socket" ]; then + break + fi + sleep 0.3 +done +if [ ! -e "$SOCKET_DIR/socket" ]; then + echo "pelican serve never dropped localcache socket" + exit 1 +fi + NEAREST_CACHE="unix://$SOCKET_DIR/socket" ./stash_plugin -d osdf:///ospool/uc-shared/public/OSG-Staff/validation/test.txt /dev/null exit_status=$? From aa8d10c4e7ed6007d892ec53ced95e37865b5d85 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Thu, 7 Mar 2024 05:49:27 -0800 Subject: [PATCH 31/45] Fix overwrite of URL path via pointer The endpoint is a *url.URL with all workers during a download using the same pointer. Thus, when one worker modified the URL's path, the modification was used by others. This caused a race condition where the wrong contents of the file would be downloaded. Also fixed up the unit test which (a) tested the wrong results for size (looking at the sizes of the upload slice instead of download) and (b) didn't check the contents of the other file. 
--- client/fed_test.go | 9 ++++++--- client/handle_http.go | 6 +++++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/client/fed_test.go b/client/fed_test.go index bf6cb17bd..f5320a85a 100644 --- a/client/fed_test.go +++ b/client/fed_test.go @@ -742,12 +742,12 @@ func TestRecursiveUploadsAndDownloads(t *testing.T) { tmpDir := t.TempDir() transferDetailsDownload, err := client.DoGet(ctx, uploadURL, tmpDir, true, client.WithTokenLocation(tempToken.Name())) assert.NoError(t, err) - if err == nil && len(transferDetailsUpload) == 2 { + if err == nil && len(transferDetailsDownload) == 2 { countBytesUploadIdx0 := 0 countBytesUploadIdx1 := 0 // Verify we got the correct files back (have to do this since files upload in different orders at times) // In this case, we want to match them to the sizes of the uploaded files - for _, transfer := range transferDetailsUpload { + for _, transfer := range transferDetailsDownload { transferredBytes := transfer.TransferredBytes switch transferredBytes { case transferDetailsUpload[0].TransferredBytes: @@ -768,8 +768,11 @@ func TestRecursiveUploadsAndDownloads(t *testing.T) { contents, err := os.ReadFile(filepath.Join(tmpDir, path.Join(dirName, path.Base(tempFile2.Name())))) assert.NoError(t, err) assert.Equal(t, testFileContent2, string(contents)) + contents, err = os.ReadFile(filepath.Join(tmpDir, path.Join(dirName, path.Base(tempFile1.Name())))) + assert.NoError(t, err) + assert.Equal(t, testFileContent1, string(contents)) } else if err == nil && len(transferDetailsDownload) != 2 { - t.Fatalf("Amount of transfers results returned for download was not correct. Transfer details returned: %d", len(transferDetailsDownload)) + t.Fatalf("Number of transfers results returned for download was not correct. 
Transfer details returned: %d", len(transferDetailsDownload)) } }) diff --git a/client/handle_http.go b/client/handle_http.go index 5512837ed..0a0f38c74 100644 --- a/client/handle_http.go +++ b/client/handle_http.go @@ -1200,7 +1200,11 @@ func downloadObject(transfer *transferFile) (transferResults TransferResults, er var serverVersion string attempt.Number = idx // Start with 0 attempt.Endpoint = transferEndpoint.Url.Host - transferEndpoint.Url.Path = transfer.remoteURL.Path + // Make a copy of the transfer endpoint URL; otherwise, when we mutate the pointer, other parallel + // workers might download from the wrong path. + transferEndpointUrl := *transferEndpoint.Url + transferEndpointUrl.Path = transfer.remoteURL.Path + transferEndpoint.Url = &transferEndpointUrl transferStartTime := time.Now() if downloaded, timeToFirstByte, serverVersion, err = downloadHTTP(transfer.ctx, transfer.engine, transfer.callback, transferEndpoint, transfer.localPath, transfer.token, &transfer.accounting); err != nil { log.Debugln("Failed to download:", err) From b5ae63173c6afb5123ee7302821bc29247dcb979 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sat, 9 Mar 2024 07:33:48 -0800 Subject: [PATCH 32/45] Refactor response object to be in common between cache and broker --- broker/broker_test.go | 3 ++- broker/client.go | 15 +++++++------- broker/server_apis.go | 27 +++++++----------------- common/api_resp.go | 45 ++++++++++++++++++++++++++++++++++++++++ local_cache/cache_api.go | 25 +++++++--------------- 5 files changed, 69 insertions(+), 46 deletions(-) create mode 100644 common/api_resp.go diff --git a/broker/broker_test.go b/broker/broker_test.go index 306cddf9a..d7b918f58 100644 --- a/broker/broker_test.go +++ b/broker/broker_test.go @@ -32,6 +32,7 @@ import ( "testing" "time" + "github.com/pelicanplatform/pelican/common" "github.com/pelicanplatform/pelican/config" "github.com/pelicanplatform/pelican/param" "github.com/pelicanplatform/pelican/registry" @@ -307,7 +308,7 @@ 
func TestRetrieveTimeout(t *testing.T) { err = json.Unmarshal(responseBytes, &brokerResp) require.NoError(t, err) - assert.Equal(t, brokerReponseStatusTimeout, brokerResp.Status) + assert.Equal(t, common.RespPollTimeout, brokerResp.Status) ctx, cancelFunc := context.WithTimeout(ctx, 50*time.Millisecond) defer cancelFunc() diff --git a/broker/client.go b/broker/client.go index bb4d4962b..6b31fff65 100644 --- a/broker/client.go +++ b/broker/client.go @@ -43,6 +43,7 @@ import ( "sync/atomic" "time" + "github.com/pelicanplatform/pelican/common" "github.com/pelicanplatform/pelican/config" "github.com/pelicanplatform/pelican/param" "github.com/pelicanplatform/pelican/token_scopes" @@ -239,7 +240,7 @@ func ConnectToOrigin(ctx context.Context, brokerUrl, prefix, originName string) err = errors.Wrap(err, "Failure when reading response from broker response") } if resp.StatusCode >= 400 { - errResp := brokerMsgResp{} + errResp := common.SimpleApiResp{} log.Errorf("Failure (status code %d) when invoking the broker: %s", resp.StatusCode, string(responseBytes)) if err = json.Unmarshal(responseBytes, &errResp); err != nil { err = errors.Errorf("Failure when invoking the broker (status code %d); unable to parse error message", resp.StatusCode) @@ -298,7 +299,7 @@ func ConnectToOrigin(ctx context.Context, brokerUrl, prefix, originName string) hj, ok := writer.(http.Hijacker) if !ok { log.Debug("Not able to hijack underlying TCP connection from server") - resp := brokerMsgResp{ + resp := common.SimpleApiResp{ Msg: "Unable to reverse TCP connection; HTTP/2 in use", Status: "error", } @@ -469,7 +470,7 @@ func doCallback(ctx context.Context, brokerResp reversalRequest) (listener net.L } if resp.StatusCode >= 400 { - errResp := brokerMsgResp{} + errResp := common.SimpleApiResp{} if err = json.Unmarshal(responseBytes, &errResp); err != nil { err = errors.Errorf("Failure when invoking cache %s callback (status code %d); unable to parse error message", brokerResp.CallbackUrl, 
resp.StatusCode) } else { @@ -616,7 +617,7 @@ func LaunchRequestMonitor(ctx context.Context, egrp *errgroup.Group, resultChan break } if resp.StatusCode >= 400 { - errResp := brokerMsgResp{} + errResp := common.SimpleApiResp{} if err = json.Unmarshal(responseBytes, &errResp); err != nil { log.Errorf("Failure when invoking the broker (status code %d); unable to parse error message", resp.StatusCode) } else { @@ -633,7 +634,7 @@ func LaunchRequestMonitor(ctx context.Context, egrp *errgroup.Group, resultChan break } - if brokerResp.Status == brokerResponseStatusOK { + if brokerResp.Status == common.RespOK { listener, err := doCallback(ctx, brokerResp.Request) if err != nil { log.Errorln("Failed to callback to the cache:", err) @@ -641,9 +642,9 @@ func LaunchRequestMonitor(ctx context.Context, egrp *errgroup.Group, resultChan break } resultChan <- listener - } else if brokerResp.Status == brokerReponseStatusFailed { + } else if brokerResp.Status == common.RespFailed { log.Errorln("Broker responded to origin retrieve with an error:", brokerResp.Msg) - } else if brokerResp.Status != brokerReponseStatusTimeout { // We expect timeouts; do not log them. + } else if brokerResp.Status != common.RespPollTimeout { // We expect timeouts; do not log them. 
if brokerResp.Msg != "" { log.Errorf("Broker responded with unknown status (%s); msg: %s", brokerResp.Status, brokerResp.Msg) } else { diff --git a/broker/server_apis.go b/broker/server_apis.go index a9338314d..a09f912c3 100644 --- a/broker/server_apis.go +++ b/broker/server_apis.go @@ -25,6 +25,7 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/pelicanplatform/pelican/common" "github.com/pelicanplatform/pelican/param" "github.com/pelicanplatform/pelican/token_scopes" "github.com/pkg/errors" @@ -39,17 +40,9 @@ type ( Prefix string `json:"prefix"` } - brokerResponseStatus string - - // Base response for a request or retrieval - brokerMsgResp struct { - Status brokerResponseStatus `json:"status"` - Msg string `json:"msg,omitempty"` - } - // Response for a successful retrieval brokerRetrievalResp struct { - brokerMsgResp + common.SimpleApiResp Request reversalRequest `json:"req"` } @@ -59,27 +52,21 @@ type ( } ) -const ( - brokerResponseStatusOK brokerResponseStatus = "success" - brokerReponseStatusFailed brokerResponseStatus = "error" - brokerReponseStatusTimeout brokerResponseStatus = "timeout" -) - func newBrokerReqResp(req reversalRequest) (result brokerRetrievalResp) { result.Request = req - result.brokerMsgResp.Status = brokerResponseStatusOK + result.SimpleApiResp.Status = common.RespOK return } -func newBrokerRespFail(msg string) brokerMsgResp { - return brokerMsgResp{ - Status: brokerReponseStatusFailed, +func newBrokerRespFail(msg string) common.SimpleApiResp { + return common.SimpleApiResp{ + Status: common.RespFailed, Msg: msg, } } func newBrokerRespTimeout() (result brokerRetrievalResp) { - result.brokerMsgResp.Status = brokerReponseStatusTimeout + result.SimpleApiResp.Status = common.RespPollTimeout return } diff --git a/common/api_resp.go b/common/api_resp.go new file mode 100644 index 000000000..f677f38d4 --- /dev/null +++ b/common/api_resp.go @@ -0,0 +1,45 @@ +/*************************************************************** + * + * 
Copyright (C) 2024, Pelican Project, Morgridge Institute for Research + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You may + * obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + ***************************************************************/ + +package common + +type ( + + // A short response object, meant for the result from most + // of the Pelican APIs. Will generate a JSON of the form: + // {"status": "error", "msg": "Some Error Message"} + // or + // {"status": "success"} + SimpleApiResp struct { + Status SimpleRespStatus `json:"status"` + Msg string `json:"msg,omitempty"` + } + + // The standardized status message for the API response + SimpleRespStatus string +) + +const ( + // Indicates the API succeeded. + RespOK SimpleRespStatus = "success" + // Indicates the API call failed; the SimpleApiResp Msg should be non-empty in this case + RespFailed SimpleRespStatus = "error" + // For long-polling APIs, indicates the requested timeout was hit without any response generated. + // Should not be considered an error or success but rather indication the long-poll should be retried. 
+ RespPollTimeout SimpleRespStatus = "timeout" +) diff --git a/local_cache/cache_api.go b/local_cache/cache_api.go index d4e9b772a..a7ac89a5c 100644 --- a/local_cache/cache_api.go +++ b/local_cache/cache_api.go @@ -33,6 +33,7 @@ import ( "github.com/gin-gonic/gin" "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/pelicanplatform/pelican/common" "github.com/pelicanplatform/pelican/config" "github.com/pelicanplatform/pelican/param" "github.com/pelicanplatform/pelican/token_scopes" @@ -41,20 +42,6 @@ import ( "golang.org/x/sync/errgroup" ) -type ( - localCacheResp struct { - Status localCacheResponseStatus `json:"status"` - Msg string `json:"msg,omitempty"` - } - - localCacheResponseStatus string -) - -const ( - responseOk localCacheResponseStatus = "success" - responseFailed localCacheResponseStatus = "error" -) - // Launch the unix socket listener as a separate goroutine func (lc *LocalCache) LaunchListener(ctx context.Context, egrp *errgroup.Group) (err error) { socketName := param.LocalCache_Socket.GetString() @@ -150,22 +137,24 @@ func (lc *LocalCache) purgeCmd(ginCtx *gin.Context) { token := ginCtx.GetHeader("Authorization") var hasPrefix bool if token, hasPrefix = strings.CutPrefix(token, "Bearer "); !hasPrefix { - ginCtx.AbortWithStatusJSON(http.StatusUnauthorized, localCacheResp{responseFailed, "Bearer token required to authenticate"}) + ginCtx.AbortWithStatusJSON(http.StatusUnauthorized, common.SimpleApiResp{Status: common.RespFailed, Msg: "Bearer token required to authenticate"}) return } jwks, err := config.GetIssuerPublicJWKS() if err != nil { - ginCtx.AbortWithStatusJSON(http.StatusInternalServerError, localCacheResp{responseFailed, "Unable to get local server token issuer"}) + ginCtx.AbortWithStatusJSON(http.StatusInternalServerError, common.SimpleApiResp{Status: common.RespFailed, Msg: "Unable to get local server token issuer"}) return } tok, err := jwt.Parse([]byte(token), jwt.WithKeySet(jwks)) if err != nil { - 
ginCtx.AbortWithStatusJSON(http.StatusUnauthorized, localCacheResp{responseFailed, "Authorization token cannot be verified"}) + ginCtx.AbortWithStatusJSON(http.StatusUnauthorized, common.SimpleApiResp{Status: common.RespFailed, Msg: "Authorization token cannot be verified"}) } scopeValidator := token_scopes.CreateScopeValidator([]token_scopes.TokenScope{token_scopes.Localcache_Purge}, true) if err = jwt.Validate(tok, jwt.WithValidator(scopeValidator)); err != nil { - ginCtx.AbortWithStatusJSON(http.StatusUnauthorized, localCacheResp{responseFailed, "Authorization token is not valid: " + err.Error()}) + ginCtx.AbortWithStatusJSON(http.StatusUnauthorized, common.SimpleApiResp{Status: common.RespFailed, Msg: "Authorization token is not valid: " + err.Error()}) + return + } return } lc.purge() From 3572efcaf395a00d3e9be9928c5a50596ec4ef4d Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sat, 9 Mar 2024 07:40:51 -0800 Subject: [PATCH 33/45] Various fixups from code review Most significant is ensuring the `purge` method returns an error object so the corresponding REST API can have an appropriate status code. --- client/fed_test.go | 2 +- client/handle_http.go | 2 ++ docs/parameters.yaml | 4 +++- github_scripts/citests.sh | 2 +- launchers/launcher.go | 2 +- local_cache/cache_api.go | 11 ++++++++++- local_cache/cache_authz.go | 20 +++++++++++++++++++- local_cache/cache_test.go | 7 ------- local_cache/local_cache.go | 21 ++++++++++++++++----- 9 files changed, 53 insertions(+), 18 deletions(-) diff --git a/client/fed_test.go b/client/fed_test.go index f5320a85a..7dd59302a 100644 --- a/client/fed_test.go +++ b/client/fed_test.go @@ -2,7 +2,7 @@ /*************************************************************** * - * Copyright (C) 2023, University of Nebraska-Lincoln + * Copyright (C) 2024, University of Nebraska-Lincoln * * Licensed under the Apache License, Version 2.0 (the "License"); you * may not use this file except in compliance with the License. 
You may diff --git a/client/handle_http.go b/client/handle_http.go index 0a0f38c74..5f9f9d61c 100644 --- a/client/handle_http.go +++ b/client/handle_http.go @@ -1317,6 +1317,8 @@ func downloadHTTP(ctx context.Context, te *TransferEngine, callback TransferCall return dialer.DialContext(ctx, "unix", transfer.UnixSocket) } transferUrl.Scheme = "http" + // The host is ignored since we override the dial function; however, I find it useful + // in debug messages to see that this went to the local cache. transferUrl.Host = "localhost" } httpClient, ok := client.HTTPClient.(*http.Client) diff --git a/docs/parameters.yaml b/docs/parameters.yaml index 935190a06..517627697 100644 --- a/docs/parameters.yaml +++ b/docs/parameters.yaml @@ -392,7 +392,7 @@ description: >- type: int default: 0 components: ["client"] -deprecated: true +hidden: true --- ############################ # Origin-level Configs # @@ -659,6 +659,8 @@ components: ["localcache"] name: LocalCache.Size description: >- The maximum size of the local cache. If not set, it is assumed the entire device can be used. + This parameter can be provided with units (e.g., 20GB, 150MB); if no unit is provided, then + it is assumed to be in bytes. 
type: string default: 0 components: ["localcache"] diff --git a/github_scripts/citests.sh b/github_scripts/citests.sh index 1ab31712d..172855ec1 100755 --- a/github_scripts/citests.sh +++ b/github_scripts/citests.sh @@ -56,7 +56,7 @@ EOF ##################################### -## Test LocalCache in front of OSDF +## Test LocalCache in front of OSDF # ##################################### SOCKET_DIR="`mktemp -d -t pelican-citest-XXXXXX`" export PELICAN_LOCALCACHE_SOCKET=$SOCKET_DIR/socket diff --git a/launchers/launcher.go b/launchers/launcher.go index e93da229d..20d8aa2b6 100644 --- a/launchers/launcher.go +++ b/launchers/launcher.go @@ -261,7 +261,7 @@ func LaunchModules(ctx context.Context, modules config.ServerType) (context.Canc } if modules.IsEnabled(config.LocalCacheType) { - log.Debugln("Starting local cache listener") + log.Debugln("Starting local cache listener at", param.LocalCache_Socket.GetString()) if err := lc.Config(egrp); err != nil { log.Warning("Failure when configuring the local cache; cache may incorrectly generate 403 errors until reconfiguration runs") } diff --git a/local_cache/cache_api.go b/local_cache/cache_api.go index a7ac89a5c..dcc8ae364 100644 --- a/local_cache/cache_api.go +++ b/local_cache/cache_api.go @@ -155,7 +155,16 @@ func (lc *LocalCache) purgeCmd(ginCtx *gin.Context) { ginCtx.AbortWithStatusJSON(http.StatusUnauthorized, common.SimpleApiResp{Status: common.RespFailed, Msg: "Authorization token is not valid: " + err.Error()}) return } + err = lc.purge() + if err != nil { + if err == purgeTimeout { + // Note we don't use common.RespTimeout here; that is reserved for a long-poll timeout. + ginCtx.AbortWithStatusJSON(http.StatusRequestTimeout, common.SimpleApiResp{Status: common.RespFailed, Msg: err.Error()}) + } else { + // Note we don't pass uncategorized errors to the user to avoid leaking potentially sensitive information. 
+ ginCtx.AbortWithStatusJSON(http.StatusInternalServerError, common.SimpleApiResp{Status: common.RespFailed, Msg: "Failed to successfully run purge"}) + } return } - lc.purge() + ginCtx.JSON(http.StatusOK, common.SimpleApiResp{Status: common.RespOK}) } diff --git a/local_cache/cache_authz.go b/local_cache/cache_authz.go index 07324bbfa..c5d569c43 100644 --- a/local_cache/cache_authz.go +++ b/local_cache/cache_authz.go @@ -94,11 +94,20 @@ func newAuthConfig(ctx context.Context, egrp *errgroup.Group) (ac *authConfig) { ttlcache.WithLoader[string, acls](ttlcache.LoaderFunc[string, acls](ac.loader)), ) - go ac.issuerKeys.Start() + egrp.Go(func() error { + ac.issuerKeys.Start() + return nil + }) + egrp.Go(func() error { + ac.tokenAuthz.Start() + return nil + }) egrp.Go(func() error { <-ctx.Done() ac.issuerKeys.Stop() ac.issuerKeys.DeleteAll() + ac.tokenAuthz.Stop() + ac.tokenAuthz.DeleteAll() return nil }) @@ -193,6 +202,15 @@ func calcResourceScopes(rs token_scopes.ResourceScope, basePaths []string, restr return } +// Given a token, calculate the corresponding access control list +// +// The returned ACLs indicate what the bearer of the token is authorized (read, write) +// to do with respect to the root of the cache. For example, if /foo is +// labeled as a public prefix by the director, then one ACL returned will +// be read:/foo. +// +// If the token verification fails then an error will be returned; no authorization +// should be given. func (ac *authConfig) getAcls(token string) (newAcls acls, err error) { namespaces := ac.ns.Load() if namespaces == nil { diff --git a/local_cache/cache_test.go b/local_cache/cache_test.go index feae745c4..bb4890139 100644 --- a/local_cache/cache_test.go +++ b/local_cache/cache_test.go @@ -65,13 +65,6 @@ func (ft *fedTest) spinup(t *testing.T, ctx context.Context, egrp *errgroup.Grou // TODO: the cache startup routines not sequenced correctly for the downloads // to immediately work through the cache. 
For now, unit tests will just use the origin. viper.Set("Origin.EnableFallbackRead", true) - /* - if runtime.GOOS == "darwin" { - viper.Set("Origin.EnableFallbackRead", true) - } else { - modules.Set(config.CacheType) - } - */ modules.Set(config.LocalCacheType) tmpPathPattern := "XRootD-Test_Origin*" diff --git a/local_cache/local_cache.go b/local_cache/local_cache.go index 78af8c501..a0e021598 100644 --- a/local_cache/local_cache.go +++ b/local_cache/local_cache.go @@ -140,6 +140,7 @@ type ( var ( authorizationDenied error = errors.New("authorization denied") + purgeTimeout error = errors.New("purge attempt has timed out") ) const ( @@ -584,14 +585,16 @@ func (lc *LocalCache) lruHit(hit lruEntry) { } } -func (lc *LocalCache) purge() { +func (lc *LocalCache) purge() (err error) { log.Debugln("Starting purge routine") lc.purgeMutex.Lock() defer lc.purgeMutex.Unlock() heap.Init(&lc.lru) start := time.Now() + log.Debugf("Purge running with cache size %d and low watermark of %d", lc.cacheSize, lc.lowWater) for lc.cacheSize > lc.lowWater { if len(lc.lru) == 0 { + err = errors.New("purge ran until cache was empty") log.Warningln("Potential consistency error: purge ran until cache was empty") break } @@ -605,19 +608,27 @@ func (lc *LocalCache) purge() { continue } localPath := path.Join(lc.basePath, path.Clean(entry.path)) - if err := os.Remove(localPath + ".DONE"); err != nil { - log.Warningln("Failed to purge DONE file:", err) + if rmErr := os.Remove(localPath + ".DONE"); rmErr != nil { + log.Warningln("Failed to purge DONE file:", rmErr) + if err == nil { + err = rmErr + } } - if err := os.Remove(localPath); err != nil { - log.Warningln("Failed to purge file:", err) + if rmErr := os.Remove(localPath); rmErr != nil { + log.Warningln("Failed to purge file:", rmErr) + if err == nil { + err = rmErr + } } lc.cacheSize -= uint64(entry.size) // Since purge is called from the mux thread, blocking can cause // other failures; do a time-based break even if we've not hit the 
low-water if time.Since(start) > 3*time.Second { + err = purgeTimeout break } } + return } // Given a URL, return a reader from the disk cache From 5850086201fc0ab4a7f4b4bbd7b4db0dd83e9ac3 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sat, 9 Mar 2024 09:21:00 -0800 Subject: [PATCH 34/45] Refactor TokenConfig so setting audience is type-safe Adds type-safety to various fields of `TokenConfig`: - Version can no longer be set to something invalid without generating an error - Audience "Any" is now a helper function with the correct behavior based on the profile - There are `New` methods that force setting a valid profile. --- broker/token_utils.go | 13 +-- client/fed_test.go | 37 +++---- cmd/origin_token.go | 26 +++-- cmd/origin_token_test.go | 14 +-- director/origin_api_test.go | 41 +++----- director/origin_monitor.go | 13 +-- director/redirect_test.go | 15 ++- director/stat.go | 17 ++- local_cache/cache_test.go | 49 ++++----- local_cache/local_cache.go | 4 +- registry/client_commands.go | 15 ++- registry/registry_ui_test.go | 5 +- server_ui/advertise.go | 13 +-- server_utils/test_file_transfer.go | 15 ++- token/token_create.go | 159 ++++++++++++++++++++--------- token/token_create_test.go | 56 +++++----- web_ui/authentication.go | 13 +-- web_ui/authorization.go | 15 ++- web_ui/prometheus.go | 26 ++--- web_ui/prometheus_test.go | 31 +++--- 20 files changed, 297 insertions(+), 280 deletions(-) diff --git a/broker/token_utils.go b/broker/token_utils.go index 453be70c9..b660e9ee8 100644 --- a/broker/token_utils.go +++ b/broker/token_utils.go @@ -123,14 +123,11 @@ func createToken(namespace, subject, audience string, desiredScope token_scopes. 
return } - tokenCfg := token.TokenConfig{ - Lifetime: time.Minute, - TokenProfile: token.WLCG, - Audience: []string{audience}, - Issuer: issuerUrl, - Version: "1.0", - Subject: subject, - } + tokenCfg := token.NewWLCGToken() + tokenCfg.Lifetime = time.Minute + tokenCfg.Issuer = issuerUrl + tokenCfg.Subject = subject + tokenCfg.AddAudiences(audience) tokenCfg.AddScopes(desiredScope) tokenStr, err = tokenCfg.CreateToken() diff --git a/client/fed_test.go b/client/fed_test.go index 7dd59302a..068f3e9ab 100644 --- a/client/fed_test.go +++ b/client/fed_test.go @@ -59,14 +59,11 @@ func generateFileTestScitoken() (string, error) { return "", errors.New("Failed to create token: Invalid iss, Server_ExternalWebUrl is empty") } - fTestTokenCfg := token.TokenConfig{ - TokenProfile: token.WLCG, - Lifetime: time.Minute, - Issuer: issuerUrl, - Audience: []string{config.GetServerAudience()}, - Version: "1.0", - Subject: "origin", - } + fTestTokenCfg := token.NewWLCGToken() + fTestTokenCfg.Lifetime = time.Minute + fTestTokenCfg.Issuer = issuerUrl + fTestTokenCfg.Subject = "origin" + fTestTokenCfg.AddAudiences(config.GetServerAudience()) fTestTokenCfg.AddResourceScopes(token_scopes.NewResourceScope(token_scopes.Storage_Read, "/"), token_scopes.NewResourceScope(token_scopes.Storage_Modify, "/")) @@ -345,13 +342,11 @@ func TestGetAndPutAuth(t *testing.T) { audience := config.GetServerAudience() // Create a token file - tokenConfig := token.TokenConfig{ - TokenProfile: token.WLCG, - Lifetime: time.Minute, - Issuer: issuer, - Audience: []string{audience}, - Subject: "origin", - } + tokenConfig := token.NewWLCGToken() + tokenConfig.Lifetime = time.Minute + tokenConfig.Issuer = issuer + tokenConfig.Subject = "origin" + tokenConfig.AddAudiences(audience) scopes := []token_scopes.TokenScope{} readScope, err := token_scopes.Storage_Read.Path("/") @@ -516,13 +511,11 @@ func TestRecursiveUploadsAndDownloads(t *testing.T) { require.NoError(t, err) audience := config.GetServerAudience() - 
tokenConfig := token.TokenConfig{ - TokenProfile: token.WLCG, - Lifetime: time.Minute, - Issuer: issuer, - Audience: []string{audience}, - Subject: "origin", - } + tokenConfig := token.NewWLCGToken() + tokenConfig.Lifetime = time.Minute + tokenConfig.Issuer = issuer + tokenConfig.Subject = "origin" + tokenConfig.AddAudiences(audience) tokenConfig.AddResourceScopes(token_scopes.NewResourceScope(token_scopes.Storage_Read, "/"), token_scopes.NewResourceScope(token_scopes.Storage_Modify, "/")) token, err := tokenConfig.CreateToken() diff --git a/cmd/origin_token.go b/cmd/origin_token.go index 49105d035..eff9c3a20 100644 --- a/cmd/origin_token.go +++ b/cmd/origin_token.go @@ -46,8 +46,11 @@ func parseInputSlice(rawSlice *[]string, claimPrefix string) []string { // Parse claims to tokenConfig, excluding "sub". `claims` should be in the form of // = -func parseClaimsToTokenConfig(claims []string) (*token.TokenConfig, error) { - tokenConfig := token.TokenConfig{} +func parseClaimsToTokenConfig(profile string, claims []string) (*token.TokenConfig, error) { + tokenConfig, err := token.NewTokenConfig(token.TokenProfile(profile)) + if err != nil { + return nil, err + } for _, claim := range claims { // Split by the first "=" delimiter parts := strings.SplitN(claim, "=", 2) @@ -60,13 +63,15 @@ func parseClaimsToTokenConfig(claims []string) (*token.TokenConfig, error) { switch key { case "aud": - tokenConfig.Audience = append(tokenConfig.Audience, val) + tokenConfig.AddAudiences(val) case "scope": tokenConfig.AddRawScope(val) case "ver": - tokenConfig.Version = val + fallthrough case "wlcg.ver": - tokenConfig.Version = val + if err = tokenConfig.SetVersion(val); err != nil { + return nil, err + } case "iss": tokenConfig.Issuer = val default: @@ -124,17 +129,16 @@ func cliTokenCreate(cmd *cobra.Command, args []string) error { args = append(args, audSlice...) 
} - tokenConfig, err := parseClaimsToTokenConfig(args) - if err != nil { - return errors.Wrap(err, "Failed to parse token claims") - } - // Get flags used for auxiliary parts of token creation that can't be fed directly to claimsMap profile, err := cmd.Flags().GetString("profile") if err != nil { return errors.Wrapf(err, "Failed to get profile '%s' from input", profile) } - tokenConfig.TokenProfile = token.TokenProfile(profile) + + tokenConfig, err := parseClaimsToTokenConfig(profile, args) + if err != nil { + return errors.Wrap(err, "Failed to parse token claims") + } lifetime, err := cmd.Flags().GetInt("lifetime") if err != nil { diff --git a/cmd/origin_token_test.go b/cmd/origin_token_test.go index a1ec7c806..e97e5effc 100644 --- a/cmd/origin_token_test.go +++ b/cmd/origin_token_test.go @@ -21,22 +21,24 @@ package main import ( "testing" + "github.com/pelicanplatform/pelican/token" + "github.com/stretchr/testify/assert" ) func TestParseClaimsToTokenConfig(t *testing.T) { // Should parse basic fields correctly claims := []string{"aud=foo", "scope=baz", "ver=1.0", "iss=http://random.org"} - tokenConfig, err := parseClaimsToTokenConfig(claims) + tokenConfig, err := parseClaimsToTokenConfig(token.TokenProfileWLCG.String(), claims) assert.NoError(t, err) assert.Equal(t, "http://random.org", tokenConfig.Issuer) - assert.Equal(t, []string{"foo"}, tokenConfig.Audience) + assert.Equal(t, []string{"foo"}, tokenConfig.GetAudiences()) assert.Equal(t, "baz", tokenConfig.GetScope()) - assert.Equal(t, "1.0", tokenConfig.Version) + assert.Equal(t, "1.0", tokenConfig.GetVersion()) // Give it something valid claims = []string{"foo=boo", "bar=baz"} - tokenConfig, err = parseClaimsToTokenConfig(claims) + tokenConfig, err = parseClaimsToTokenConfig(token.TokenProfileWLCG.String(), claims) assert.NoError(t, err) assert.Equal(t, "boo", tokenConfig.Claims["foo"]) assert.Equal(t, "baz", tokenConfig.Claims["bar"]) @@ -44,14 +46,14 @@ func TestParseClaimsToTokenConfig(t *testing.T) { // 
Give it something with multiple of the same claim key claims = []string{"foo=boo", "foo=baz"} - tokenConfig, err = parseClaimsToTokenConfig(claims) + tokenConfig, err = parseClaimsToTokenConfig(token.TokenProfileWLCG.String(), claims) assert.NoError(t, err) assert.Equal(t, "boo baz", tokenConfig.Claims["foo"]) assert.Equal(t, 1, len(tokenConfig.Claims)) // Give it something without = delimiter claims = []string{"foo=boo", "barbaz"} - _, err = parseClaimsToTokenConfig(claims) + _, err = parseClaimsToTokenConfig(token.TokenProfileWLCG.String(), claims) assert.EqualError(t, err, "The claim 'barbaz' is invalid. Did you forget an '='?") } diff --git a/director/origin_api_test.go b/director/origin_api_test.go index 86e15af81..a14861755 100644 --- a/director/origin_api_test.go +++ b/director/origin_api_test.go @@ -110,14 +110,11 @@ func TestVerifyAdvertiseToken(t *testing.T) { issuerUrl, err := server_utils.GetNSIssuerURL("/test-namespace") assert.NoError(t, err) - advTokenCfg := token.TokenConfig{ - TokenProfile: token.WLCG, - Version: "1.0", - Lifetime: time.Minute, - Issuer: issuerUrl, - Audience: []string{"https://director-url.org"}, - Subject: "origin", - } + advTokenCfg := token.NewWLCGToken() + advTokenCfg.Lifetime = time.Minute + advTokenCfg.Issuer = issuerUrl + advTokenCfg.Subject = "origin" + advTokenCfg.AddAudiences("https://director-url.org") advTokenCfg.AddScopes(token_scopes.Pelican_Advertise) // CreateToken also handles validation for us @@ -129,14 +126,11 @@ func TestVerifyAdvertiseToken(t *testing.T) { assert.Equal(t, true, ok, "Expected scope to be 'pelican.advertise'") //Create token without a scope - should return an error upon validation - scopelessTokCfg := token.TokenConfig{ - TokenProfile: token.WLCG, - Lifetime: time.Minute, - Issuer: "https://get-your-tokens.org", - Audience: []string{"director.test"}, - Version: "1.0", - Subject: "origin", - } + scopelessTokCfg := token.NewWLCGToken() + scopelessTokCfg.Lifetime = time.Minute + 
scopelessTokCfg.Issuer = "https://get-your-tokens.org" + scopelessTokCfg.Subject = "origin" + scopelessTokCfg.AddAudiences("director.test") tok, err = scopelessTokCfg.CreateToken() assert.NoError(t, err, "error creating scopeless token. Should have succeeded") @@ -146,15 +140,12 @@ func TestVerifyAdvertiseToken(t *testing.T) { assert.Equal(t, "No scope is present; required to advertise to director", err.Error()) // Create a token with a bad scope - should return an error upon validation - wrongScopeTokenCfg := token.TokenConfig{ - TokenProfile: token.WLCG, - Lifetime: time.Minute, - Issuer: "https://get-your-tokens.org", - Audience: []string{"director.test"}, - Version: "1.0", - Subject: "origin", - Claims: map[string]string{"scope": "wrong.scope"}, - } + wrongScopeTokenCfg := token.NewWLCGToken() + wrongScopeTokenCfg.Lifetime = time.Minute + wrongScopeTokenCfg.Issuer = "https://get-your-tokens.org" + wrongScopeTokenCfg.AddAudiences("director.test") + wrongScopeTokenCfg.Subject = "origin" + wrongScopeTokenCfg.Claims = map[string]string{"scope": "wrong.scope"} tok, err = wrongScopeTokenCfg.CreateToken() assert.NoError(t, err, "error creating wrong-scope token. 
Should have succeeded") diff --git a/director/origin_monitor.go b/director/origin_monitor.go index 97b154b1d..10f7ac034 100644 --- a/director/origin_monitor.go +++ b/director/origin_monitor.go @@ -48,14 +48,11 @@ func reportStatusToOrigin(ctx context.Context, originWebUrl string, status strin return errors.Wrapf(err, "failed to parse external URL %v", param.Server_ExternalWebUrl.GetString()) } - testTokenCfg := token.TokenConfig{ - TokenProfile: token.WLCG, - Version: "1.0", - Lifetime: time.Minute, - Issuer: directorUrl.String(), - Audience: []string{originWebUrl}, - Subject: "director", - } + testTokenCfg := token.NewWLCGToken() + testTokenCfg.Lifetime = time.Minute + testTokenCfg.Issuer = directorUrl.String() + testTokenCfg.AddAudiences(originWebUrl) + testTokenCfg.Subject = "director" testTokenCfg.AddScopes(token_scopes.Pelican_DirectorTestReport) tok, err := testTokenCfg.CreateToken() diff --git a/director/redirect_test.go b/director/redirect_test.go index 4a486f3ac..476c567b4 100644 --- a/director/redirect_test.go +++ b/director/redirect_test.go @@ -228,15 +228,12 @@ func TestDirectorRegistration(t *testing.T) { } generateReadToken := func(key jwk.Key, object, issuer string) string { - tc := token.TokenConfig{ - TokenProfile: token.WLCG, - Version: "1.0", - Lifetime: time.Minute, - Issuer: issuer, - Audience: []string{"director"}, - Subject: "test", - Claims: map[string]string{"scope": "storage.read:" + object}, - } + tc := token.NewWLCGToken() + tc.Lifetime = time.Minute + tc.Issuer = issuer + tc.AddAudiences("director") + tc.Subject = "test" + tc.Claims = map[string]string{"scope": "storage.read:" + object} tok, err := tc.CreateTokenWithKey(key) require.NoError(t, err) return tok diff --git a/director/stat.go b/director/stat.go index 99f537cf9..682917854 100644 --- a/director/stat.go +++ b/director/stat.go @@ -31,6 +31,7 @@ import ( "github.com/pelicanplatform/pelican/config" "github.com/pelicanplatform/pelican/param" 
"github.com/pelicanplatform/pelican/token" + "github.com/pelicanplatform/pelican/token_scopes" "github.com/pkg/errors" log "github.com/sirupsen/logrus" ) @@ -99,15 +100,13 @@ func NewObjectStat() *ObjectStat { // Implementation of sending a HEAD request to an origin for an object func (stat *ObjectStat) sendHeadReqToOrigin(objectName string, dataUrl url.URL, timeout time.Duration, ctx context.Context) (*objectMetadata, error) { - tokenConf := token.TokenConfig{ - Lifetime: time.Minute, - TokenProfile: token.WLCG, - Audience: []string{dataUrl.String()}, - Subject: dataUrl.String(), - // Federation as the issuer - Issuer: param.Server_ExternalWebUrl.GetString(), - } - tokenConf.AddRawScope("storage.read:/") + tokenConf := token.NewWLCGToken() + tokenConf.Lifetime = time.Minute + tokenConf.AddAudiences(dataUrl.String()) + tokenConf.Subject = dataUrl.String() + // Federation as the issuer + tokenConf.Issuer = param.Server_ExternalWebUrl.GetString() + tokenConf.AddResourceScopes(token_scopes.NewResourceScope(token_scopes.Storage_Read, "/")) token, err := tokenConf.CreateToken() if err != nil { return nil, errors.Wrap(err, "Failed to create token") diff --git a/local_cache/cache_test.go b/local_cache/cache_test.go index bb4890139..3787d2206 100644 --- a/local_cache/cache_test.go +++ b/local_cache/cache_test.go @@ -127,13 +127,11 @@ func (ft *fedTest) spinup(t *testing.T, ctx context.Context, egrp *errgroup.Grou issuer, err := config.GetServerIssuerURL() require.NoError(t, err) - tokConf := token.TokenConfig{ - TokenProfile: token.WLCG, - Lifetime: time.Duration(time.Minute), - Issuer: issuer, - Subject: "test", - Audience: []string{token.WLCGAny}, - } + tokConf := token.NewWLCGToken() + tokConf.Lifetime = time.Duration(time.Minute) + tokConf.Issuer = issuer + tokConf.Subject = "test" + tokConf.AddAudienceAny() tokConf.AddResourceScopes(token_scopes.NewResourceScope(token_scopes.Storage_Read, "/hello_world.txt")) token, err := tokConf.CreateToken() @@ -198,13 +196,11 @@ 
func TestFedAuthGet(t *testing.T) { issuer, err := config.GetServerIssuerURL() require.NoError(t, err) - tokConf := token.TokenConfig{ - TokenProfile: token.WLCG, - Lifetime: time.Duration(time.Minute), - Issuer: issuer, - Subject: "test", - Audience: []string{token.WLCGAny}, - } + tokConf := token.NewWLCGToken() + tokConf.Lifetime = time.Duration(time.Minute) + tokConf.Issuer = issuer + tokConf.Subject = "test" + tokConf.AddAudienceAny() tokConf.AddResourceScopes(token_scopes.NewResourceScope(token_scopes.Storage_Read, "/not_correct")) token, err := tokConf.CreateToken() @@ -287,13 +283,12 @@ func TestClient(t *testing.T) { t.Run("file-not-found", func(t *testing.T) { issuer, err := config.GetServerIssuerURL() require.NoError(t, err) - tokConf := token.TokenConfig{ - TokenProfile: token.WLCG, - Lifetime: time.Duration(time.Minute), - Issuer: issuer, - Subject: "test", - Audience: []string{token.WLCGAny}, - } + tokConf := token.NewWLCGToken() + + tokConf.Lifetime = time.Duration(time.Minute) + tokConf.Issuer = issuer + tokConf.Subject = "test" + tokConf.AddAudienceAny() tokConf.AddResourceScopes(token_scopes.NewResourceScope(token_scopes.Storage_Read, "/hello_world.txt.1")) token, err := tokConf.CreateToken() @@ -465,13 +460,11 @@ func TestForcePurge(t *testing.T) { issuer, err := config.GetServerIssuerURL() require.NoError(t, err) - tokConf := token.TokenConfig{ - TokenProfile: token.WLCG, - Lifetime: time.Duration(time.Minute), - Issuer: issuer, - Subject: "test", - Audience: []string{token.WLCGAny}, - } + tokConf := token.NewWLCGToken() + tokConf.Lifetime = time.Duration(time.Minute) + tokConf.Issuer = issuer + tokConf.Subject = "test" + tokConf.AddAudienceAny() tokConf.AddScopes(token_scopes.Localcache_Purge) token, err := tokConf.CreateToken() diff --git a/local_cache/local_cache.go b/local_cache/local_cache.go index a0e021598..a765a0a6d 100644 --- a/local_cache/local_cache.go +++ b/local_cache/local_cache.go @@ -576,7 +576,9 @@ func (lc *LocalCache) 
lruHit(hit lruEntry) { lc.lru = append(lc.lru, entry) lc.cacheSize += uint64(hit.size) if lc.cacheSize > lc.highWater { - lc.purge() + if err := lc.purge(); err != nil { + log.Warningln("Failure when purging cache:", err) + } } } entry.lastUse = hit.lastUse diff --git a/registry/client_commands.go b/registry/client_commands.go index 725ffc1a3..375763b2c 100644 --- a/registry/client_commands.go +++ b/registry/client_commands.go @@ -238,15 +238,12 @@ func NamespaceDelete(endpoint string, prefix string) error { // including an audience with these tokens. // TODO: Investigate whether 1 min is a good expiration interval // or whether this should be altered. - delTokenCfg := token.TokenConfig{ - TokenProfile: token.WLCG, - Lifetime: time.Minute, - Issuer: issuerURL, - Audience: []string{"registry"}, - Version: "1.0", - Subject: "origin", - Claims: map[string]string{"scope": token_scopes.Pelican_NamespaceDelete.String()}, - } + delTokenCfg := token.NewWLCGToken() + delTokenCfg.Lifetime = time.Minute + delTokenCfg.Issuer = issuerURL + delTokenCfg.AddAudiences("registry") + delTokenCfg.Subject = "origin" + delTokenCfg.AddScopes(token_scopes.Pelican_NamespaceDelete) // CreateToken also handles validation for us tok, err := delTokenCfg.CreateToken() diff --git a/registry/registry_ui_test.go b/registry/registry_ui_test.go index 9f4944dde..52029e73f 100644 --- a/registry/registry_ui_test.go +++ b/registry/registry_ui_test.go @@ -239,7 +239,10 @@ func TestListNamespaces(t *testing.T) { requestURL := "/namespaces?server_type=" + tc.serverType + "&status=" + tc.status req, _ := http.NewRequest("GET", requestURL, nil) if tc.authUser { - tokenCfg := token.TokenConfig{Issuer: "https://mock-server.com", Lifetime: time.Minute, Subject: "admin", TokenProfile: token.None} + tokenCfg := token.NewWLCGToken() + tokenCfg.Issuer = "https://mock-server.com" + tokenCfg.Lifetime = time.Minute + tokenCfg.Subject = "admin" tokenCfg.AddScopes(token_scopes.WebUi_Access) token, err := 
tokenCfg.CreateToken() require.NoError(t, err) diff --git a/server_ui/advertise.go b/server_ui/advertise.go index e5c304b54..367e6c789 100644 --- a/server_ui/advertise.go +++ b/server_ui/advertise.go @@ -141,14 +141,11 @@ func advertiseInternal(ctx context.Context, server server_utils.XRootDServer) er return err } - advTokenCfg := token.TokenConfig{ - TokenProfile: token.WLCG, - Version: "1.0", - Lifetime: time.Minute, - Issuer: issuerUrl, - Audience: []string{param.Federation_DirectorUrl.GetString()}, - Subject: "origin", - } + advTokenCfg := token.NewWLCGToken() + advTokenCfg.Lifetime = time.Minute + advTokenCfg.Issuer = issuerUrl + advTokenCfg.AddAudiences(param.Federation_DirectorUrl.GetString()) + advTokenCfg.Subject = "origin" advTokenCfg.AddScopes(token_scopes.Pelican_Advertise) // CreateToken also handles validation for us diff --git a/server_utils/test_file_transfer.go b/server_utils/test_file_transfer.go index 050969ecb..420d79cd8 100644 --- a/server_utils/test_file_transfer.go +++ b/server_utils/test_file_transfer.go @@ -77,15 +77,12 @@ func (t TestFileTransferImpl) generateFileTestScitoken() (string, error) { return "", errors.New("Failed to create token: Invalid iss, Server_ExternalWebUrl is empty") } - fTestTokenCfg := token.TokenConfig{ - TokenProfile: token.WLCG, - Lifetime: time.Minute, - Issuer: issuerUrl, - Audience: t.audiences, - Version: "1.0", - Subject: "origin", - Claims: map[string]string{"scope": "storage.read:/ storage.modify:/"}, - } + fTestTokenCfg := token.NewWLCGToken() + fTestTokenCfg.Lifetime = time.Minute + fTestTokenCfg.Issuer = issuerUrl + fTestTokenCfg.Subject = "origin" + fTestTokenCfg.Claims = map[string]string{"scope": "storage.read:/ storage.modify:/"} + fTestTokenCfg.AddAudiences(t.audiences...) 
// CreateToken also handles validation for us tok, err := fTestTokenCfg.CreateToken() diff --git a/token/token_create.go b/token/token_create.go index 3063078c8..b2e19776d 100644 --- a/token/token_create.go +++ b/token/token_create.go @@ -43,11 +43,11 @@ import ( type ( TokenProfile string TokenConfig struct { - TokenProfile TokenProfile + tokenProfile TokenProfile Lifetime time.Duration // Lifetime is used to set 'exp' claim from now Issuer string // Issuer is 'iss' claim - Audience []string // Audience is 'aud' claim - Version string // Version is the version for different profiles. 'wlcg.ver' for WLCG profile and 'ver' for scitokens2 + audience []string // Audience is 'aud' claim + version string // Version is the version for different profiles. 'wlcg.ver' for WLCG profile and 'ver' for scitokens2 Subject string // Subject is 'sub' claim Claims map[string]string // Additional claims scope string // scope is a string with space-delimited list of scopes. To enforce type check, use AddRawScope or AddScopes to add scopes to your token @@ -59,13 +59,19 @@ type ( } ) +var ( + scitokensVerPattern *regexp.Regexp = regexp.MustCompile(`^scitokens:2\.[0-9]+$`) + wlcgVerPattern *regexp.Regexp = regexp.MustCompile(`^1\.[0-9]+$`) +) + const ( - WLCG TokenProfile = "wlcg" - Scitokens2 TokenProfile = "scitokens2" - None TokenProfile = "none" + TokenProfileWLCG TokenProfile = "wlcg" + TokenProfileScitokens2 TokenProfile = "scitokens2" + TokenProfileNone TokenProfile = "none" + tokenProfileEmpty TokenProfile = "" - WLCGAny string = "https://wlcg.cern.ch/jwt/v1/any" - ScitokensAny string = "ANY" + wlcgAny string = "https://wlcg.cern.ch/jwt/v1/any" + scitokensAny string = "ANY" ) func (p TokenProfile) String() string { @@ -81,47 +87,111 @@ func (config *TokenConfig) Validate() (bool, error) { if _, err := url.Parse(config.Issuer); err != nil { return false, errors.Wrap(err, "Invalid issuer, issuer is not a valid Url") } - switch config.TokenProfile { - case Scitokens2: + switch 
config.tokenProfile { + case TokenProfileScitokens2: if err := config.verifyCreateSciTokens2(); err != nil { return false, err } - case WLCG: + case TokenProfileWLCG: if err := config.verifyCreateWLCG(); err != nil { return false, err } - case None: + case TokenProfileNone: return true, nil // we don't have profile specific check for None type + case tokenProfileEmpty: + return false, errors.New("token profile is not set") default: - return false, errors.New(fmt.Sprint("Unsupported token profile: ", config.TokenProfile.String())) + return false, errors.Errorf("unsupported token profile: %s", config.tokenProfile.String()) } return true, nil } +func NewTokenConfig(tokenProfile TokenProfile) (tc TokenConfig, err error) { + switch tokenProfile { + case TokenProfileScitokens2: + fallthrough + case TokenProfileWLCG: + fallthrough + case TokenProfileNone: + tc.tokenProfile = tokenProfile + case tokenProfileEmpty: + err = errors.New("token profile is not set") + default: + err = errors.Errorf("unsupported token profile: %s", tokenProfile.String()) + } + return +} + +func NewWLCGToken() (tc TokenConfig) { + tc.tokenProfile = TokenProfileWLCG + return +} + +func NewScitoken() (tc TokenConfig) { + tc.tokenProfile = TokenProfileScitokens2 + return +} + +func (config *TokenConfig) GetVersion() string { + return config.version +} + +func (config *TokenConfig) SetVersion(ver string) error { + if config.tokenProfile == TokenProfileScitokens2 { + if ver == "" { + ver = "scitokens:2.0" + } else if !scitokensVerPattern.MatchString(ver) { + return errors.New("the provided version '" + config.version + + "' is not valid. It must match 'scitokens:', where version is of the form 2.x") + } + } else if config.tokenProfile == TokenProfileWLCG { + if ver == "" { + ver = "1.0" + } else if !wlcgVerPattern.MatchString(config.version) { + return errors.New("the provided version '" + config.version + "' is not valid. 
It must be of the form '1.x'") + } + } + config.version = ver + return nil +} + +func (config *TokenConfig) AddAudienceAny() { + newAud := "" + switch config.tokenProfile { + case TokenProfileScitokens2: + newAud = string(scitokensAny) + case TokenProfileWLCG: + newAud = string(wlcgAny) + } + if newAud != "" { + config.audience = append(config.audience, newAud) + } +} + +func (config *TokenConfig) AddAudiences(audiences ...string) { + config.audience = append(config.audience, audiences...) +} + +func (config *TokenConfig) GetAudiences() []string { + return config.audience +} + // Verify if the token matches scitoken2 profile requirement func (config *TokenConfig) verifyCreateSciTokens2() error { // required fields: aud, ver, scope - if len(config.Audience) == 0 { - errMsg := "The 'audience' claim is required for the scitokens2 profile, but it could not be found." - return errors.New(errMsg) + if len(config.audience) == 0 { + return errors.New("the 'audience' claim is required for the scitokens2 profile, but it could not be found") } if config.scope == "" { - errMsg := "The 'scope' claim is required for the scitokens2 profile, but it could not be found." - return errors.New(errMsg) + return errors.New("the 'scope' claim is required for the scitokens2 profile, but it could not be found") } - if config.Version == "" { - config.Version = "scitokens:2.0" - } else { - verPattern := `^scitokens:2\.[0-9]+$` - re := regexp.MustCompile(verPattern) - - if !re.MatchString(config.Version) { - errMsg := "The provided version '" + config.Version + - "' is not valid. It must match 'scitokens:', where version is of the form 2.x" - return errors.New(errMsg) - } + if config.version == "" { + config.version = "scitokens:2.0" + } else if !scitokensVerPattern.MatchString(config.version) { + return errors.New("the provided version '" + config.version + + "' is not valid. 
It must match 'scitokens:', where version is of the form 2.x") } return nil } @@ -129,25 +199,20 @@ func (config *TokenConfig) verifyCreateSciTokens2() error { // Verify if the token matches WLCG profile requirement func (config *TokenConfig) verifyCreateWLCG() error { // required fields: sub, wlcg.ver, aud - if len(config.Audience) == 0 { - errMsg := "The 'audience' claim is required for the scitokens2 profile, but it could not be found." + if len(config.audience) == 0 { + errMsg := "the 'audience' claim is required for the WLCG profile, but it could not be found" return errors.New(errMsg) } if config.Subject == "" { - errMsg := "The 'subject' claim is required for the scitokens2 profile, but it could not be found." + errMsg := "the 'subject' claim is required for the WLCG profile, but it could not be found" return errors.New(errMsg) } - if config.Version == "" { - config.Version = "1.0" - } else { - verPattern := `^1\.[0-9]+$` - re := regexp.MustCompile(verPattern) - if !re.MatchString(config.Version) { - errMsg := "The provided version '" + config.Version + "' is not valid. It must be of the form '1.x'" - return errors.New(errMsg) - } + if config.version == "" { + config.version = "1.0" + } else if !wlcgVerPattern.MatchString(config.version) { + return errors.New("the provided version '" + config.version + "' is not valid. It must be of the form '1.x'") } return nil } @@ -198,7 +263,7 @@ func (tokenConfig *TokenConfig) CreateToken() (string, error) { // Variant of CreateToken with a JWT provided by the caller func (tokenConfig *TokenConfig) CreateTokenWithKey(key jwk.Key) (string, error) { if ok, err := tokenConfig.Validate(); !ok || err != nil { - return "", errors.Wrap(err, "Invalid tokenConfig") + return "", errors.Wrap(err, "invalid tokenConfig") } jti_bytes := make([]byte, 16) @@ -236,7 +301,7 @@ func (tokenConfig *TokenConfig) CreateTokenWithKey(key jwk.Key) (string, error) IssuedAt(now). Expiration(now.Add(tokenConfig.Lifetime)). NotBefore(now). 
- Audience(tokenConfig.Audience). + Audience(tokenConfig.audience). Subject(tokenConfig.Subject). JwtID(jti) @@ -244,10 +309,10 @@ func (tokenConfig *TokenConfig) CreateTokenWithKey(key jwk.Key) (string, error) builder.Claim("scope", tokenConfig.scope) } - if tokenConfig.TokenProfile == Scitokens2 { - builder.Claim("ver", tokenConfig.Version) - } else if tokenConfig.TokenProfile == WLCG { - builder.Claim("wlcg.ver", tokenConfig.Version) + if tokenConfig.tokenProfile == TokenProfileScitokens2 { + builder.Claim("ver", tokenConfig.version) + } else if tokenConfig.tokenProfile == TokenProfileWLCG { + builder.Claim("wlcg.ver", tokenConfig.version) } if tokenConfig.Claims != nil { diff --git a/token/token_create_test.go b/token/token_create_test.go index 16f85daba..076a23bb6 100644 --- a/token/token_create_test.go +++ b/token/token_create_test.go @@ -39,58 +39,58 @@ import ( func TestVerifyCreateSciTokens2(t *testing.T) { // Start by feeding it a valid claims map - tokenConfig := TokenConfig{TokenProfile: Scitokens2, Audience: []string{"foo"}, Version: "scitokens:2.0", scope: "read:/storage"} + tokenConfig := TokenConfig{tokenProfile: TokenProfileScitokens2, audience: []string{"foo"}, version: "scitokens:2.0", scope: "read:/storage"} err := tokenConfig.verifyCreateSciTokens2() assert.NoError(t, err) // Fail to give it audience - tokenConfig = TokenConfig{TokenProfile: Scitokens2, Version: "scitokens:2.0", scope: "read:/storage"} + tokenConfig = TokenConfig{tokenProfile: TokenProfileScitokens2, version: "scitokens:2.0", scope: "read:/storage"} err = tokenConfig.verifyCreateSciTokens2() - assert.EqualError(t, err, "The 'audience' claim is required for the scitokens2 profile, but it could not be found.") + assert.EqualError(t, err, "the 'audience' claim is required for the scitokens2 profile, but it could not be found") // Fail to give it scope - tokenConfig = TokenConfig{TokenProfile: Scitokens2, Audience: []string{"foo"}, Version: "scitokens:2.0"} + tokenConfig = 
TokenConfig{tokenProfile: TokenProfileScitokens2, audience: []string{"foo"}, version: "scitokens:2.0"} err = tokenConfig.verifyCreateSciTokens2() - assert.EqualError(t, err, "The 'scope' claim is required for the scitokens2 profile, but it could not be found.") + assert.EqualError(t, err, "the 'scope' claim is required for the scitokens2 profile, but it could not be found") // Give it bad version - tokenConfig = TokenConfig{TokenProfile: Scitokens2, Audience: []string{"foo"}, Version: "scitokens:2.xxxx", scope: "read:/storage"} + tokenConfig = TokenConfig{tokenProfile: TokenProfileScitokens2, audience: []string{"foo"}, version: "scitokens:2.xxxx", scope: "read:/storage"} err = tokenConfig.verifyCreateSciTokens2() - assert.EqualError(t, err, "The provided version 'scitokens:2.xxxx' is not valid. It must match 'scitokens:', where version is of the form 2.x") + assert.EqualError(t, err, "the provided version 'scitokens:2.xxxx' is not valid. It must match 'scitokens:', where version is of the form 2.x") // Don't give it a version and make sure it gets set correctly - tokenConfig = TokenConfig{TokenProfile: Scitokens2, Audience: []string{"foo"}, scope: "read:/storage"} + tokenConfig = TokenConfig{tokenProfile: TokenProfileScitokens2, audience: []string{"foo"}, scope: "read:/storage"} err = tokenConfig.verifyCreateSciTokens2() assert.NoError(t, err) - assert.Equal(t, tokenConfig.Version, "scitokens:2.0") + assert.Equal(t, tokenConfig.version, "scitokens:2.0") } func TestVerifyCreateWLCG(t *testing.T) { // Start by feeding it a valid claims map - tokenConfig := TokenConfig{TokenProfile: WLCG, Audience: []string{"director"}, Version: "1.0", Subject: "foo"} + tokenConfig := TokenConfig{tokenProfile: TokenProfileWLCG, audience: []string{"director"}, version: "1.0", Subject: "foo"} err := tokenConfig.verifyCreateWLCG() assert.NoError(t, err) // Fail to give it a sub - tokenConfig = TokenConfig{TokenProfile: WLCG, Audience: []string{"director"}, Version: "1.0"} + tokenConfig = 
TokenConfig{tokenProfile: TokenProfileWLCG, audience: []string{"director"}, version: "1.0"} err = tokenConfig.verifyCreateWLCG() - assert.EqualError(t, err, "The 'subject' claim is required for the scitokens2 profile, but it could not be found.") + assert.EqualError(t, err, "the 'subject' claim is required for the WLCG profile, but it could not be found") // Fail to give it an aud - tokenConfig = TokenConfig{TokenProfile: WLCG, Version: "1.0", Subject: "foo"} + tokenConfig = TokenConfig{tokenProfile: TokenProfileWLCG, version: "1.0", Subject: "foo"} err = tokenConfig.verifyCreateWLCG() - assert.EqualError(t, err, "The 'audience' claim is required for the scitokens2 profile, but it could not be found.") + assert.EqualError(t, err, "the 'audience' claim is required for the WLCG profile, but it could not be found") // Give it bad version - tokenConfig = TokenConfig{TokenProfile: WLCG, Audience: []string{"director"}, Version: "1.xxxx", Subject: "foo"} + tokenConfig = TokenConfig{tokenProfile: TokenProfileWLCG, audience: []string{"director"}, version: "1.xxxx", Subject: "foo"} err = tokenConfig.verifyCreateWLCG() - assert.EqualError(t, err, "The provided version '1.xxxx' is not valid. It must be of the form '1.x'") + assert.EqualError(t, err, "the provided version '1.xxxx' is not valid. 
It must be of the form '1.x'") // Don't give it a version and make sure it gets set correctly - tokenConfig = TokenConfig{TokenProfile: WLCG, Audience: []string{"director"}, Subject: "foo"} + tokenConfig = TokenConfig{tokenProfile: TokenProfileWLCG, audience: []string{"director"}, Subject: "foo"} err = tokenConfig.verifyCreateWLCG() assert.NoError(t, err) - assert.Equal(t, tokenConfig.Version, "1.0") + assert.Equal(t, tokenConfig.version, "1.0") } // TestAddScopes tests the AddScopes method of TokenConfig @@ -217,33 +217,33 @@ func TestCreateToken(t *testing.T) { assert.NoError(t, err) // Test that the wlcg profile works - tokenConfig := TokenConfig{TokenProfile: WLCG, Audience: []string{"foo"}, Subject: "bar", Lifetime: time.Minute * 10} + tokenConfig := TokenConfig{tokenProfile: TokenProfileWLCG, audience: []string{"foo"}, Subject: "bar", Lifetime: time.Minute * 10} _, err = tokenConfig.CreateToken() assert.NoError(t, err) // Test that the wlcg profile fails if neither sub or aud not found - tokenConfig = TokenConfig{TokenProfile: WLCG, Lifetime: time.Minute * 10} + tokenConfig = TokenConfig{tokenProfile: TokenProfileWLCG, Lifetime: time.Minute * 10} _, err = tokenConfig.CreateToken() - assert.EqualError(t, err, "Invalid tokenConfig: The 'audience' claim is required for the scitokens2 profile, but it could not be found.") + assert.EqualError(t, err, "invalid tokenConfig: the 'audience' claim is required for the WLCG profile, but it could not be found") // Test that the scitokens2 profile works - tokenConfig = TokenConfig{TokenProfile: Scitokens2, Audience: []string{"foo"}, scope: "bar", Lifetime: time.Minute * 10} + tokenConfig = TokenConfig{tokenProfile: TokenProfileScitokens2, audience: []string{"foo"}, scope: "bar", Lifetime: time.Minute * 10} _, err = tokenConfig.CreateToken() assert.NoError(t, err) // Test that the scitokens2 profile fails if claims not found - tokenConfig = TokenConfig{TokenProfile: Scitokens2, Lifetime: time.Minute * 10} + tokenConfig = 
TokenConfig{tokenProfile: TokenProfileScitokens2, Lifetime: time.Minute * 10} _, err = tokenConfig.CreateToken() - assert.EqualError(t, err, "Invalid tokenConfig: The 'audience' claim is required for the scitokens2 profile, but it could not be found.") + assert.EqualError(t, err, "invalid tokenConfig: the 'audience' claim is required for the scitokens2 profile, but it could not be found") // Test an unrecognized profile - tokenConfig = TokenConfig{TokenProfile: TokenProfile("unknown"), Lifetime: time.Minute * 10} + tokenConfig = TokenConfig{tokenProfile: TokenProfile("unknown"), Lifetime: time.Minute * 10} _, err = tokenConfig.CreateToken() - assert.EqualError(t, err, "Invalid tokenConfig: Unsupported token profile: unknown") + assert.EqualError(t, err, "invalid tokenConfig: unsupported token profile: unknown") // Test that additional claims can be passed into the token - tokenConfig = TokenConfig{TokenProfile: WLCG, Audience: []string{"foo"}, Subject: "bar", Lifetime: time.Minute * 10, Claims: map[string]string{"foo": "bar"}} + tokenConfig = TokenConfig{tokenProfile: TokenProfileWLCG, audience: []string{"foo"}, Subject: "bar", Lifetime: time.Minute * 10, Claims: map[string]string{"foo": "bar"}} token, err := tokenConfig.CreateToken() require.NoError(t, err) jwt, err := jwt.ParseString(token, jwt.WithVerify(false)) @@ -254,7 +254,7 @@ func TestCreateToken(t *testing.T) { // Test providing issuer via claim viper.Set("IssuerUrl", "") - tokenConfig = TokenConfig{TokenProfile: WLCG, Audience: []string{"foo"}, Subject: "bar", Issuer: "https://localhost:9999", Lifetime: time.Minute * 10} + tokenConfig = TokenConfig{tokenProfile: TokenProfileWLCG, audience: []string{"foo"}, Subject: "bar", Issuer: "https://localhost:9999", Lifetime: time.Minute * 10} _, err = tokenConfig.CreateToken() assert.NoError(t, err) diff --git a/web_ui/authentication.go b/web_ui/authentication.go index 05a8a4bf5..c3a6481ed 100644 --- a/web_ui/authentication.go +++ b/web_ui/authentication.go @@ 
-153,14 +153,11 @@ func GetUser(ctx *gin.Context) (string, error) { // Create a JWT and set the "login" cookie to store that JWT func setLoginCookie(ctx *gin.Context, user string) { - loginCookieTokenCfg := token.TokenConfig{ - TokenProfile: token.WLCG, - Lifetime: 30 * time.Minute, - Issuer: param.Server_ExternalWebUrl.GetString(), - Audience: []string{param.Server_ExternalWebUrl.GetString()}, - Version: "1.0", - Subject: user, - } + loginCookieTokenCfg := token.NewWLCGToken() + loginCookieTokenCfg.Lifetime = 30 * time.Minute + loginCookieTokenCfg.Issuer = param.Server_ExternalWebUrl.GetString() + loginCookieTokenCfg.AddAudiences(param.Server_ExternalWebUrl.GetString()) + loginCookieTokenCfg.Subject = user loginCookieTokenCfg.AddScopes(token_scopes.WebUi_Access, token_scopes.Monitoring_Query, token_scopes.Monitoring_Scrape) // CreateToken also handles validation for us diff --git a/web_ui/authorization.go b/web_ui/authorization.go index e4ab89c9d..eb68c9491 100644 --- a/web_ui/authorization.go +++ b/web_ui/authorization.go @@ -35,15 +35,12 @@ import ( // the server itself func createPromMetricToken() (string, error) { serverUrl := param.Server_ExternalWebUrl.GetString() - promMetricTokCfg := token.TokenConfig{ - TokenProfile: token.WLCG, - Lifetime: param.Monitoring_TokenExpiresIn.GetDuration(), - Issuer: serverUrl, - Audience: []string{serverUrl}, - Version: "1.0", - Subject: serverUrl, - Claims: map[string]string{"scope": token_scopes.Monitoring_Scrape.String()}, - } + promMetricTokCfg := token.NewWLCGToken() + promMetricTokCfg.Lifetime = param.Monitoring_TokenExpiresIn.GetDuration() + promMetricTokCfg.Issuer = serverUrl + promMetricTokCfg.AddAudiences(serverUrl) + promMetricTokCfg.Subject = serverUrl + promMetricTokCfg.Claims = map[string]string{"scope": token_scopes.Monitoring_Scrape.String()} // CreateToken also handles validation for us tok, err := promMetricTokCfg.CreateToken() diff --git a/web_ui/prometheus.go b/web_ui/prometheus.go index 
34ad3fcc6..2eecc9d45 100644 --- a/web_ui/prometheus.go +++ b/web_ui/prometheus.go @@ -152,14 +152,11 @@ func configDirectorPromScraper(ctx context.Context) (*config.ScrapeConfig, error return nil, fmt.Errorf("parse external URL %v: %w", param.Server_ExternalWebUrl.GetString(), err) } - promTokenCfg := token.TokenConfig{ - TokenProfile: token.WLCG, - Lifetime: param.Monitoring_TokenExpiresIn.GetDuration(), - Issuer: directorBaseUrl.String(), - Audience: []string{directorBaseUrl.String()}, - Version: "1.0", - Subject: "director", - } + promTokenCfg := token.NewWLCGToken() + promTokenCfg.Lifetime = param.Monitoring_TokenExpiresIn.GetDuration() + promTokenCfg.Issuer = directorBaseUrl.String() + promTokenCfg.AddAudiences(directorBaseUrl.String()) + promTokenCfg.Subject = "director" promTokenCfg.AddScopes(token_scopes.Pelican_DirectorServiceDiscovery) // CreateToken also handles validation for us @@ -168,14 +165,11 @@ func configDirectorPromScraper(ctx context.Context) (*config.ScrapeConfig, error return nil, errors.Wrap(err, "failed to create director prometheus token") } - scrapeTokenCfg := token.TokenConfig{ - TokenProfile: token.WLCG, - Version: "1.0", - Lifetime: param.Monitoring_TokenExpiresIn.GetDuration(), - Issuer: directorBaseUrl.String(), - Audience: []string{"prometheus"}, - Subject: "director", - } + scrapeTokenCfg := token.NewWLCGToken() + scrapeTokenCfg.Lifetime = param.Monitoring_TokenExpiresIn.GetDuration() + scrapeTokenCfg.Issuer = directorBaseUrl.String() + scrapeTokenCfg.AddAudiences("prometheus") + scrapeTokenCfg.Subject = "director" scrapeTokenCfg.AddScopes(token_scopes.Monitoring_Scrape) scraperToken, err := scrapeTokenCfg.CreateToken() diff --git a/web_ui/prometheus_test.go b/web_ui/prometheus_test.go index 1a875e6c1..a5fc62bdb 100644 --- a/web_ui/prometheus_test.go +++ b/web_ui/prometheus_test.go @@ -73,15 +73,13 @@ func TestPrometheusProtectionCookieAuth(t *testing.T) { } issuerUrl := param.Server_ExternalWebUrl.GetString() - promTokenCfg := 
token.TokenConfig{ - TokenProfile: token.WLCG, - Lifetime: 10 * time.Minute, - Issuer: issuerUrl, - Audience: []string{issuerUrl}, - Version: "1.0", - Subject: "sub", - Claims: map[string]string{"scope": token_scopes.Monitoring_Query.String()}, - } + promTokenCfg := token.NewWLCGToken() + + promTokenCfg.Lifetime = 10 * time.Minute + promTokenCfg.Issuer = issuerUrl + promTokenCfg.AddAudiences(issuerUrl) + promTokenCfg.Subject = "sub" + promTokenCfg.Claims = map[string]string{"scope": token_scopes.Monitoring_Query.String()} tok, err := promTokenCfg.CreateToken() assert.NoError(t, err, "failed to create prometheus token") @@ -128,15 +126,12 @@ func TestPrometheusProtectionOriginHeaderScope(t *testing.T) { // Shared function to create a token createToken := func(scope, aud string) string { - tokenCfg := token.TokenConfig{ - TokenProfile: token.WLCG, - Lifetime: param.Monitoring_TokenExpiresIn.GetDuration(), - Issuer: issuerUrl, - Audience: []string{aud}, - Version: "1.0", - Subject: "sub", - Claims: map[string]string{"scope": scope}, - } + tokenCfg := token.NewWLCGToken() + tokenCfg.Lifetime = param.Monitoring_TokenExpiresIn.GetDuration() + tokenCfg.Issuer = issuerUrl + tokenCfg.AddAudiences(aud) + tokenCfg.Subject = "sub" + tokenCfg.Claims = map[string]string{"scope": scope} tok, err := tokenCfg.CreateToken() assert.NoError(t, err, "failed to create prometheus test token") From 3ab83898a7372ea66cc945ce5ac98d1b1d095aaa Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sat, 9 Mar 2024 09:41:12 -0800 Subject: [PATCH 35/45] Add unit test for calculating resource scopes --- local_cache/cache_internal_test.go | 82 ++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 local_cache/cache_internal_test.go diff --git a/local_cache/cache_internal_test.go b/local_cache/cache_internal_test.go new file mode 100644 index 000000000..840223af9 --- /dev/null +++ b/local_cache/cache_internal_test.go @@ -0,0 +1,82 @@ +//go:build !windows + 
+/*************************************************************** + * + * Copyright (C) 2024, Pelican Project, Morgridge Institute for Research + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You may + * obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + ***************************************************************/ + +package local_cache + +import ( + "testing" + + "github.com/pelicanplatform/pelican/token_scopes" + "github.com/stretchr/testify/assert" +) + +func TestCalcResources(t *testing.T) { + tests := []struct { + scopes token_scopes.ResourceScope + basePaths []string + restrictedPaths []string + result []token_scopes.ResourceScope + }{ + { + scopes: token_scopes.NewResourceScope(token_scopes.Storage_Read, "/"), + basePaths: []string{"/foo", "/bar"}, + result: []token_scopes.ResourceScope{ + token_scopes.NewResourceScope(token_scopes.Storage_Read, "/foo"), + token_scopes.NewResourceScope(token_scopes.Storage_Read, "/bar"), + }, + }, + { + scopes: token_scopes.NewResourceScope(token_scopes.Storage_Read, "/"), + basePaths: []string{"/foo", "/bar"}, + restrictedPaths: []string{"/baz"}, + result: []token_scopes.ResourceScope{ + token_scopes.NewResourceScope(token_scopes.Storage_Read, "/foo/baz"), + token_scopes.NewResourceScope(token_scopes.Storage_Read, "/bar/baz"), + }, + }, + { + scopes: token_scopes.NewResourceScope(token_scopes.Storage_Read, "/"), + basePaths: []string{"/"}, + restrictedPaths: []string{"/foo", "/bar"}, + result: []token_scopes.ResourceScope{ + 
token_scopes.NewResourceScope(token_scopes.Storage_Read, "/foo"), + token_scopes.NewResourceScope(token_scopes.Storage_Read, "/bar"), + }, + }, + { + scopes: token_scopes.NewResourceScope(token_scopes.Storage_Read, "/baz"), + basePaths: []string{"/foo"}, + restrictedPaths: []string{"/bar"}, + result: []token_scopes.ResourceScope{}, + }, + { + scopes: token_scopes.NewResourceScope(token_scopes.Storage_Read, "/bar/baz"), + basePaths: []string{"/foo"}, + restrictedPaths: []string{"/bar"}, + result: []token_scopes.ResourceScope{ + token_scopes.NewResourceScope(token_scopes.Storage_Read, "/foo/bar/baz"), + }, + }, + } + for _, test := range tests { + result := calcResourceScopes(test.scopes, test.basePaths, test.restrictedPaths) + assert.Equal(t, test.result, result) + } +} From 224b4f7f11945e56ef610f992e556331e8cbcff0 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sat, 9 Mar 2024 10:14:38 -0800 Subject: [PATCH 36/45] Refactor FedTest to be shared across client and local_cache Removes some nearly-duplicate code between the two modules' tests. 
--- client/fed_test.go | 150 ++++---------------------------- fed_test_utils/fed.go | 167 ++++++++++++++++++++++++++++++++++++ local_cache/cache_test.go | 176 ++++++-------------------------------- 3 files changed, 210 insertions(+), 283 deletions(-) create mode 100644 fed_test_utils/fed.go diff --git a/client/fed_test.go b/client/fed_test.go index 068f3e9ab..ba729f661 100644 --- a/client/fed_test.go +++ b/client/fed_test.go @@ -37,10 +37,10 @@ import ( "github.com/spf13/viper" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/sync/errgroup" "github.com/pelicanplatform/pelican/client" "github.com/pelicanplatform/pelican/config" + "github.com/pelicanplatform/pelican/fed_test_utils" "github.com/pelicanplatform/pelican/launchers" "github.com/pelicanplatform/pelican/param" "github.com/pelicanplatform/pelican/server_utils" @@ -140,7 +140,7 @@ func TestFullUpload(t *testing.T) { require.NoError(t, err) } - desiredURL := param.Server_ExternalWebUrl.GetString() + "/.well-known/openid-configuration" + desiredURL := param.Server_ExternalWebUrl.GetString() + "/api/v1.0/health" err = server_utils.WaitUntilWorking(ctx, "GET", desiredURL, "director", 200) require.NoError(t, err) @@ -155,12 +155,12 @@ func TestFullUpload(t *testing.T) { responseBody, err := io.ReadAll(resp.Body) require.NoError(t, err) expectedResponse := struct { - JwksUri string `json:"jwks_uri"` + Msg string `json:"message"` }{} err = json.Unmarshal(responseBody, &expectedResponse) require.NoError(t, err) - assert.NotEmpty(t, expectedResponse.JwksUri) + assert.NotEmpty(t, expectedResponse.Msg) t.Run("testFullUpload", func(t *testing.T) { testFileContent := "test file content" @@ -204,128 +204,13 @@ func TestFullUpload(t *testing.T) { os.RemoveAll(originDir) }) - cancel() - fedCancel() - assert.NoError(t, egrp.Wait()) - viper.Reset() -} - -type FedTest struct { - T *testing.T - TmpPath string - OriginDir string - Output *os.File - Cancel context.CancelFunc - FedCancel 
context.CancelFunc - ErrGroup *errgroup.Group -} - -func (f *FedTest) Spinup() { - //////////////////////////////Setup our test federation////////////////////////////////////////// - ctx, cancel, egrp := test_utils.TestContext(context.Background(), f.T) - - modules := config.ServerType(0) - modules.Set(config.OriginType) - modules.Set(config.DirectorType) - modules.Set(config.RegistryType) - - // Create our own temp directory (for some reason t.TempDir() does not play well with xrootd) - tmpPathPattern := "XRootD-Test_Origin*" - tmpPath, err := os.MkdirTemp("", tmpPathPattern) - require.NoError(f.T, err) - f.TmpPath = tmpPath - - // Need to set permissions or the xrootd process we spawn won't be able to write PID/UID files - permissions := os.FileMode(0755) - err = os.Chmod(tmpPath, permissions) - require.NoError(f.T, err) - - viper.Set("ConfigDir", tmpPath) - - config.InitConfig() - // Create a file to capture output from commands - output, err := os.CreateTemp(f.T.TempDir(), "output") - assert.NoError(f.T, err) - f.Output = output - viper.Set("Logging.LogLocation", output.Name()) - - originDir, err := os.MkdirTemp("", "Origin") - assert.NoError(f.T, err) - f.OriginDir = originDir - - // Change the permissions of the temporary origin directory - permissions = os.FileMode(0777) - err = os.Chmod(originDir, permissions) - require.NoError(f.T, err) - - viper.Set("Origin.ExportVolume", originDir+":/test") - viper.Set("Origin.Mode", "posix") - viper.Set("Origin.EnableFallbackRead", true) - // Disable functionality we're not using (and is difficult to make work on Mac) - viper.Set("Origin.EnableCmsd", false) - viper.Set("Origin.EnableMacaroons", false) - viper.Set("Origin.EnableVoms", false) - viper.Set("Origin.EnableWrite", true) - viper.Set("TLSSkipVerify", true) - viper.Set("Server.EnableUI", false) - viper.Set("Registry.DbLocation", filepath.Join(f.T.TempDir(), "ns-registry.sqlite")) - viper.Set("Origin.Port", 0) - viper.Set("Server.WebPort", 0) - 
viper.Set("Origin.RunLocation", tmpPath) - - err = config.InitServer(ctx, modules) - require.NoError(f.T, err) - - viper.Set("Registry.RequireOriginApproval", false) - viper.Set("Registry.RequireCacheApproval", false) - - f.FedCancel, err = launchers.LaunchModules(ctx, modules) - if err != nil { - f.T.Fatalf("Failure in fedServeInternal: %v", err) - } - - desiredURL := param.Server_ExternalWebUrl.GetString() + "/.well-known/openid-configuration" - err = server_utils.WaitUntilWorking(ctx, "GET", desiredURL, "director", 200) - require.NoError(f.T, err) - - httpc := http.Client{ - Transport: config.GetTransport(), - } - resp, err := httpc.Get(desiredURL) - require.NoError(f.T, err) - - assert.Equal(f.T, resp.StatusCode, http.StatusOK) - - responseBody, err := io.ReadAll(resp.Body) - require.NoError(f.T, err) - expectedResponse := struct { - JwksUri string `json:"jwks_uri"` - }{} - err = json.Unmarshal(responseBody, &expectedResponse) - require.NoError(f.T, err) - - f.Cancel = cancel - f.ErrGroup = egrp -} - -func (f *FedTest) Teardown() { - os.RemoveAll(f.TmpPath) - os.RemoveAll(f.OriginDir) - f.Cancel() - f.FedCancel() - assert.NoError(f.T, f.ErrGroup.Wait()) viper.Reset() } // A test that spins up a federation, and tests object get and put func TestGetAndPutAuth(t *testing.T) { - // Create instance of test federation - ctx, _, _ := test_utils.TestContext(context.Background(), t) - viper.Reset() - fed := FedTest{T: t} - fed.Spinup() - defer fed.Teardown() + fed := fed_test_utils.NewFedTest(t) // Other set-up items: testFileContent := "test file content" @@ -376,14 +261,14 @@ func TestGetAndPutAuth(t *testing.T) { uploadURL := "pelican:///test/" + fileName // Upload the file with PUT - transferResultsUpload, err := client.DoPut(ctx, tempFile.Name(), uploadURL, false, client.WithTokenLocation(tempToken.Name())) + transferResultsUpload, err := client.DoPut(fed.Ctx, tempFile.Name(), uploadURL, false, client.WithTokenLocation(tempToken.Name())) assert.NoError(t, err) if 
err == nil { assert.Equal(t, transferResultsUpload[0].TransferredBytes, int64(17)) } // Download that same file with GET - transferResultsDownload, err := client.DoGet(ctx, uploadURL, t.TempDir(), false, client.WithTokenLocation(tempToken.Name())) + transferResultsDownload, err := client.DoGet(fed.Ctx, uploadURL, t.TempDir(), false, client.WithTokenLocation(tempToken.Name())) assert.NoError(t, err) if err == nil { assert.Equal(t, transferResultsDownload[0].TransferredBytes, transferResultsUpload[0].TransferredBytes) @@ -400,14 +285,14 @@ func TestGetAndPutAuth(t *testing.T) { uploadURL := "pelican:///test/" + fileName // Upload the file with PUT - transferResultsUpload, err := client.DoPut(ctx, tempFile.Name(), uploadURL, false, client.WithTokenLocation(tempToken.Name())) + transferResultsUpload, err := client.DoPut(fed.Ctx, tempFile.Name(), uploadURL, false, client.WithTokenLocation(tempToken.Name())) assert.NoError(t, err) if err == nil { assert.Equal(t, transferResultsUpload[0].TransferredBytes, int64(17)) } // Download that same file with GET - transferResultsDownload, err := client.DoGet(ctx, uploadURL, t.TempDir(), false, client.WithTokenLocation(tempToken.Name())) + transferResultsDownload, err := client.DoGet(fed.Ctx, uploadURL, t.TempDir(), false, client.WithTokenLocation(tempToken.Name())) assert.NoError(t, err) if err == nil { assert.Equal(t, transferResultsDownload[0].TransferredBytes, transferResultsUpload[0].TransferredBytes) @@ -423,14 +308,14 @@ func TestGetAndPutAuth(t *testing.T) { uploadURL := "pelican:///test/" + fileName // Upload the file with PUT - transferResultsUpload, err := client.DoPut(ctx, tempFile.Name(), uploadURL, false, client.WithTokenLocation(tempToken.Name())) + transferResultsUpload, err := client.DoPut(fed.Ctx, tempFile.Name(), uploadURL, false, client.WithTokenLocation(tempToken.Name())) assert.NoError(t, err) if err == nil { assert.Equal(t, transferResultsUpload[0].TransferredBytes, int64(17)) } // Download that same file 
with GET - transferResultsDownload, err := client.DoGet(ctx, uploadURL, t.TempDir(), false, client.WithTokenLocation(tempToken.Name())) + transferResultsDownload, err := client.DoGet(fed.Ctx, uploadURL, t.TempDir(), false, client.WithTokenLocation(tempToken.Name())) assert.NoError(t, err) if err == nil { assert.Equal(t, transferResultsDownload[0].TransferredBytes, transferResultsUpload[0].TransferredBytes) @@ -447,14 +332,14 @@ func TestGetAndPutAuth(t *testing.T) { uploadURL := "pelican:///test/" + fileName // Upload the file with PUT - transferResultsUpload, err := client.DoPut(ctx, tempFile.Name(), uploadURL, false, client.WithTokenLocation(tempToken.Name())) + transferResultsUpload, err := client.DoPut(fed.Ctx, tempFile.Name(), uploadURL, false, client.WithTokenLocation(tempToken.Name())) assert.NoError(t, err) if err == nil { assert.Equal(t, transferResultsUpload[0].TransferredBytes, int64(17)) } // Download that same file with GET - transferResultsDownload, err := client.DoGet(ctx, uploadURL, t.TempDir(), false, client.WithTokenLocation(tempToken.Name())) + transferResultsDownload, err := client.DoGet(fed.Ctx, uploadURL, t.TempDir(), false, client.WithTokenLocation(tempToken.Name())) assert.NoError(t, err) if err == nil { assert.Equal(t, transferResultsDownload[0].TransferredBytes, transferResultsUpload[0].TransferredBytes) @@ -467,9 +352,8 @@ func TestGetPublicRead(t *testing.T) { ctx, _, _ := test_utils.TestContext(context.Background(), t) viper.Reset() viper.Set("Origin.EnablePublicReads", true) - fed := FedTest{T: t} - fed.Spinup() - defer fed.Teardown() + fed := fed_test_utils.NewFedTest(t) + t.Run("testPubObjGet", func(t *testing.T) { testFileContent := "test file content" // Drop the testFileContent into the origin directory @@ -501,9 +385,7 @@ func TestRecursiveUploadsAndDownloads(t *testing.T) { ctx, _, _ := test_utils.TestContext(context.Background(), t) viper.Reset() - fed := FedTest{T: t} - fed.Spinup() - defer fed.Teardown() + 
fed_test_utils.NewFedTest(t) //////////////////////////SETUP/////////////////////////// // Create a token file diff --git a/fed_test_utils/fed.go b/fed_test_utils/fed.go new file mode 100644 index 000000000..0c96616e1 --- /dev/null +++ b/fed_test_utils/fed.go @@ -0,0 +1,167 @@ +//go:build !windows + +/*************************************************************** + * + * Copyright (C) 2024, Pelican Project, Morgridge Institute for Research + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You may + * obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + ***************************************************************/ + +package fed_test_utils + +import ( + "context" + "encoding/json" + "io" + "net/http" + "os" + "path/filepath" + "testing" + "time" + + "github.com/pelicanplatform/pelican/config" + "github.com/pelicanplatform/pelican/launchers" + "github.com/pelicanplatform/pelican/param" + "github.com/pelicanplatform/pelican/server_utils" + "github.com/pelicanplatform/pelican/test_utils" + "github.com/pelicanplatform/pelican/token" + "github.com/pelicanplatform/pelican/token_scopes" + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +type ( + FedTest struct { + OriginDir string + Token string + Ctx context.Context + Egrp *errgroup.Group + } +) + +func NewFedTest(t *testing.T) (ft *FedTest) { + ft = &FedTest{} + + ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) + t.Cleanup(func() { + cancel() + if err := egrp.Wait(); err != nil && err != context.Canceled && err != http.ErrServerClosed { + require.NoError(t, err) + } + }) + + ft.Ctx = ctx + ft.Egrp = egrp + + modules := config.ServerType(0) + modules.Set(config.OriginType) + modules.Set(config.DirectorType) + modules.Set(config.RegistryType) + // TODO: the cache startup routines not sequenced correctly for the downloads + // to immediately work through the cache. For now, unit tests will just use the origin. 
+ viper.Set("Origin.EnableFallbackRead", true) + modules.Set(config.LocalCacheType) + + tmpPathPattern := "PelicanOrigin-FedTest*" + tmpPath, err := os.MkdirTemp("", tmpPathPattern) + require.NoError(t, err) + + permissions := os.FileMode(0755) + err = os.Chmod(tmpPath, permissions) + require.NoError(t, err) + t.Cleanup(func() { + err := os.RemoveAll(tmpPath) + require.NoError(t, err) + }) + + viper.Set("ConfigDir", tmpPath) + + config.InitConfig() + + originDir, err := os.MkdirTemp("", "Origin") + assert.NoError(t, err) + t.Cleanup(func() { + err := os.RemoveAll(originDir) + require.NoError(t, err) + }) + + // Change the permissions of the temporary origin directory + permissions = os.FileMode(0755) + err = os.Chmod(originDir, permissions) + require.NoError(t, err) + + viper.Set("Origin.ExportVolume", originDir+":/test") + viper.Set("Origin.Mode", "posix") + viper.Set("Origin.EnableFallbackRead", true) + // Disable functionality we're not using (and is difficult to make work on Mac) + viper.Set("Origin.EnableCmsd", false) + viper.Set("Origin.EnableMacaroons", false) + viper.Set("Origin.EnableVoms", false) + viper.Set("Server.EnableUI", false) + viper.Set("Registry.DbLocation", filepath.Join(t.TempDir(), "ns-registry.sqlite")) + viper.Set("Origin.Port", 0) + viper.Set("Server.WebPort", 0) + viper.Set("Origin.RunLocation", tmpPath) + viper.Set("Registry.RequireOriginApproval", false) + viper.Set("Registry.RequireCacheApproval", false) + + err = config.InitServer(ctx, modules) + require.NoError(t, err) + + _, err = launchers.LaunchModules(ctx, modules) + require.NoError(t, err) + + desiredURL := param.Server_ExternalWebUrl.GetString() + "/api/v1.0/health" + err = server_utils.WaitUntilWorking(ctx, "GET", desiredURL, "director", 200) + require.NoError(t, err) + + httpc := http.Client{ + Transport: config.GetTransport(), + } + resp, err := httpc.Get(desiredURL) + require.NoError(t, err) + + assert.Equal(t, resp.StatusCode, http.StatusOK) + + responseBody, err := 
io.ReadAll(resp.Body) + require.NoError(t, err) + expectedResponse := struct { + Msg string `json:"message"` + }{} + err = json.Unmarshal(responseBody, &expectedResponse) + require.NoError(t, err) + assert.NotEmpty(t, expectedResponse.Msg) + + err = os.WriteFile(filepath.Join(originDir, "hello_world.txt"), []byte("Hello, World!"), os.FileMode(0644)) + require.NoError(t, err) + + issuer, err := config.GetServerIssuerURL() + require.NoError(t, err) + tokConf := token.NewWLCGToken() + tokConf.Lifetime = time.Duration(time.Minute) + tokConf.Issuer = issuer + tokConf.Subject = "test" + tokConf.AddAudienceAny() + tokConf.AddResourceScopes(token_scopes.NewResourceScope(token_scopes.Storage_Read, "/hello_world.txt")) + + token, err := tokConf.CreateToken() + require.NoError(t, err) + + ft.OriginDir = originDir + ft.Token = token + + return +} diff --git a/local_cache/cache_test.go b/local_cache/cache_test.go index 3787d2206..f11dfdc9c 100644 --- a/local_cache/cache_test.go +++ b/local_cache/cache_test.go @@ -35,10 +35,9 @@ import ( "github.com/pelicanplatform/pelican/client" "github.com/pelicanplatform/pelican/config" - "github.com/pelicanplatform/pelican/launchers" + "github.com/pelicanplatform/pelican/fed_test_utils" local_cache "github.com/pelicanplatform/pelican/local_cache" "github.com/pelicanplatform/pelican/param" - "github.com/pelicanplatform/pelican/test_utils" "github.com/pelicanplatform/pelican/token" "github.com/pelicanplatform/pelican/token_scopes" "github.com/pelicanplatform/pelican/utils" @@ -46,115 +45,18 @@ import ( "github.com/spf13/viper" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/sync/errgroup" ) -type ( - fedTest struct { - originDir string - token string - } -) - -func (ft *fedTest) spinup(t *testing.T, ctx context.Context, egrp *errgroup.Group) { - - modules := config.ServerType(0) - modules.Set(config.OriginType) - modules.Set(config.DirectorType) - modules.Set(config.RegistryType) - // TODO: the cache 
startup routines not sequenced correctly for the downloads - // to immediately work through the cache. For now, unit tests will just use the origin. - viper.Set("Origin.EnableFallbackRead", true) - modules.Set(config.LocalCacheType) - - tmpPathPattern := "XRootD-Test_Origin*" - tmpPath, err := os.MkdirTemp("", tmpPathPattern) - require.NoError(t, err) - - permissions := os.FileMode(0755) - err = os.Chmod(tmpPath, permissions) - require.NoError(t, err) - t.Cleanup(func() { - err := os.RemoveAll(tmpPath) - require.NoError(t, err) - }) - - viper.Set("ConfigDir", tmpPath) - - config.InitConfig() - - originDir, err := os.MkdirTemp("", "Origin") - assert.NoError(t, err) - t.Cleanup(func() { - err := os.RemoveAll(originDir) - require.NoError(t, err) - }) - - // Change the permissions of the temporary origin directory - permissions = os.FileMode(0755) - err = os.Chmod(originDir, permissions) - require.NoError(t, err) - - viper.Set("Origin.ExportVolume", originDir+":/test") - viper.Set("Origin.Mode", "posix") - // Disable functionality we're not using (and is difficult to make work on Mac) - viper.Set("Origin.EnableCmsd", false) - viper.Set("Origin.EnableMacaroons", false) - viper.Set("Origin.EnableVoms", false) - viper.Set("Server.EnableUI", false) - viper.Set("Registry.DbLocation", filepath.Join(t.TempDir(), "ns-registry.sqlite")) - viper.Set("Origin.Port", 0) - viper.Set("Server.WebPort", 0) - viper.Set("Origin.RunLocation", tmpPath) - viper.Set("Cache.RunLocation", tmpPath) - viper.Set("Registry.RequireOriginApproval", false) - viper.Set("Registry.RequireCacheApproval", false) - - err = config.InitServer(ctx, modules) - require.NoError(t, err) - - cancel, err := launchers.LaunchModules(ctx, modules) - require.NoError(t, err) - t.Cleanup(func() { - cancel() - if err = egrp.Wait(); err != nil && err != context.Canceled && err != http.ErrServerClosed { - require.NoError(t, err) - } - }) - - err = os.WriteFile(filepath.Join(originDir, "hello_world.txt"), []byte("Hello, 
World!"), os.FileMode(0644)) - require.NoError(t, err) - - issuer, err := config.GetServerIssuerURL() - require.NoError(t, err) - tokConf := token.NewWLCGToken() - tokConf.Lifetime = time.Duration(time.Minute) - tokConf.Issuer = issuer - tokConf.Subject = "test" - tokConf.AddAudienceAny() - tokConf.AddResourceScopes(token_scopes.NewResourceScope(token_scopes.Storage_Read, "/hello_world.txt")) - - token, err := tokConf.CreateToken() - require.NoError(t, err) - - ft.originDir = originDir - ft.token = token -} - // Setup a federation, invoke "get" through the local cache module // // The download is done twice -- once to verify functionality and once // as a cache hit. func TestFedPublicGet(t *testing.T) { - ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) - defer cancel() - viper.Reset() viper.Set("Origin.EnablePublicReads", true) - ft := fedTest{} - ft.spinup(t, ctx, egrp) + ft := fed_test_utils.NewFedTest(t) - lc, err := local_cache.NewLocalCache(ctx, egrp) + lc, err := local_cache.NewLocalCache(ft.Ctx, ft.Egrp) require.NoError(t, err) reader, err := lc.Get("/test/hello_world.txt", "") @@ -176,18 +78,14 @@ func TestFedPublicGet(t *testing.T) { // Test the local cache library on an authenticated GET. 
func TestFedAuthGet(t *testing.T) { - ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) - defer cancel() - viper.Reset() viper.Set("Origin.EnablePublicReads", false) - ft := fedTest{} - ft.spinup(t, ctx, egrp) + ft := fed_test_utils.NewFedTest(t) - lc, err := local_cache.NewLocalCache(ctx, egrp) + lc, err := local_cache.NewLocalCache(ft.Ctx, ft.Egrp) require.NoError(t, err) - reader, err := lc.Get("/test/hello_world.txt", ft.token) + reader, err := lc.Get("/test/hello_world.txt", ft.Token) require.NoError(t, err) byteBuff, err := io.ReadAll(reader) @@ -213,13 +111,9 @@ func TestFedAuthGet(t *testing.T) { // Test a raw HTTP request (no Pelican client) works with the local cache func TestHttpReq(t *testing.T) { - ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) - defer cancel() - viper.Reset() viper.Set("Origin.EnablePublicReads", false) - ft := fedTest{} - ft.spinup(t, ctx, egrp) + ft := fed_test_utils.NewFedTest(t) transport := config.GetTransport().Clone() transport.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) { @@ -229,7 +123,7 @@ func TestHttpReq(t *testing.T) { client := &http.Client{Transport: transport} req, err := http.NewRequest("GET", "http://localhost/test/hello_world.txt", nil) require.NoError(t, err) - req.Header.Set("Authorization", "Bearer "+ft.token) + req.Header.Set("Authorization", "Bearer "+ft.Token) resp, err := client.Do(req) require.NoError(t, err) defer resp.Body.Close() @@ -241,14 +135,11 @@ func TestHttpReq(t *testing.T) { // Test that the client library (with authentication) works with the local cache func TestClient(t *testing.T) { - ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) - defer cancel() tmpDir := t.TempDir() viper.Reset() viper.Set("Origin.EnablePublicReads", false) - ft := fedTest{} - ft.spinup(t, ctx, egrp) + ft := fed_test_utils.NewFedTest(t) cacheUrl := &url.URL{ Scheme: "unix", @@ -259,8 +150,8 @@ func TestClient(t *testing.T) { 
discoveryHost := param.Federation_DiscoveryUrl.GetString() discoveryUrl, err := url.Parse(discoveryHost) require.NoError(t, err) - tr, err := client.DoGet(ctx, "pelican://"+discoveryUrl.Host+"/test/hello_world.txt", filepath.Join(tmpDir, "hello_world.txt"), false, - client.WithToken(ft.token), client.WithCaches(cacheUrl), client.WithAcquireToken(false)) + tr, err := client.DoGet(ft.Ctx, "pelican://"+discoveryUrl.Host+"/test/hello_world.txt", filepath.Join(tmpDir, "hello_world.txt"), false, + client.WithToken(ft.Token), client.WithCaches(cacheUrl), client.WithAcquireToken(false)) assert.NoError(t, err) require.Equal(t, 1, len(tr)) assert.Equal(t, int64(13), tr[0].TransferredBytes) @@ -271,7 +162,7 @@ func TestClient(t *testing.T) { assert.Equal(t, "Hello, World!", string(byteBuff)) }) t.Run("incorrect-auth", func(t *testing.T) { - _, err := client.DoGet(ctx, "pelican:///test/hello_world.txt", filepath.Join(tmpDir, "hello_world.txt"), false, + _, err := client.DoGet(ft.Ctx, "pelican:///test/hello_world.txt", filepath.Join(tmpDir, "hello_world.txt"), false, client.WithToken("badtoken"), client.WithCaches(cacheUrl), client.WithAcquireToken(false)) assert.Error(t, err) assert.ErrorIs(t, err, &client.ConnectionSetupError{}) @@ -294,7 +185,7 @@ func TestClient(t *testing.T) { token, err := tokConf.CreateToken() require.NoError(t, err) - _, err = client.DoGet(ctx, "pelican:///test/hello_world.txt.1", filepath.Join(tmpDir, "hello_world.txt.1"), false, + _, err = client.DoGet(ft.Ctx, "pelican:///test/hello_world.txt.1", filepath.Join(tmpDir, "hello_world.txt.1"), false, client.WithToken(token), client.WithCaches(cacheUrl), client.WithAcquireToken(false)) assert.Error(t, err) assert.Equal(t, "failed to download file: transfer error: failed connection setup: server returned 404 Not Found", err.Error()) @@ -303,15 +194,11 @@ func TestClient(t *testing.T) { // Test that HEAD requests to the local cache return the correct result func TestStat(t *testing.T) { - ctx, cancel, egrp 
:= test_utils.TestContext(context.Background(), t) - defer cancel() - viper.Reset() viper.Set("Origin.EnablePublicReads", true) - ft := fedTest{} - ft.spinup(t, ctx, egrp) + ft := fed_test_utils.NewFedTest(t) - lc, err := local_cache.NewLocalCache(ctx, egrp) + lc, err := local_cache.NewLocalCache(ft.Ctx, ft.Egrp) require.NoError(t, err) size, err := lc.Stat("/test/hello_world.txt", "") @@ -361,29 +248,26 @@ func writeBigBuffer(t *testing.T, fp io.WriteCloser, sizeMB int) (size int) { // // This triggers multiple internal requests to wait on the slow download func TestLargeFile(t *testing.T) { - ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) - defer cancel() tmpDir := t.TempDir() viper.Reset() viper.Set("Origin.EnablePublicReads", true) viper.Set("Client.MaximumDownloadSpeed", 40*1024*1024) - ft := fedTest{} - ft.spinup(t, ctx, egrp) + ft := fed_test_utils.NewFedTest(t) cacheUrl := &url.URL{ Scheme: "unix", Path: param.LocalCache_Socket.GetString(), } - fp, err := os.OpenFile(filepath.Join(ft.originDir, "hello_world.txt"), os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) + fp, err := os.OpenFile(filepath.Join(ft.OriginDir, "hello_world.txt"), os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) require.NoError(t, err) size := writeBigBuffer(t, fp, 100) discoveryHost := param.Federation_DiscoveryUrl.GetString() discoveryUrl, err := url.Parse(discoveryHost) require.NoError(t, err) - tr, err := client.DoGet(ctx, "pelican://"+discoveryUrl.Host+"/test/hello_world.txt", filepath.Join(tmpDir, "hello_world.txt"), false, + tr, err := client.DoGet(ft.Ctx, "pelican://"+discoveryUrl.Host+"/test/hello_world.txt", filepath.Join(tmpDir, "hello_world.txt"), false, client.WithCaches(cacheUrl)) assert.NoError(t, err) require.Equal(t, 1, len(tr)) @@ -395,15 +279,12 @@ func TestLargeFile(t *testing.T) { // Create five 1MB files. 
Trigger a purge, ensuring that the cleanup is // done according to LRU func TestPurge(t *testing.T) { - ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) - defer cancel() tmpDir := t.TempDir() viper.Reset() viper.Set("Origin.EnablePublicReads", true) viper.Set("LocalCache.Size", "5MB") - ft := fedTest{} - ft.spinup(t, ctx, egrp) + ft := fed_test_utils.NewFedTest(t) cacheUrl := &url.URL{ Scheme: "unix", @@ -412,15 +293,15 @@ func TestPurge(t *testing.T) { size := 0 for idx := 0; idx < 5; idx++ { - log.Debugln("Will write origin file", filepath.Join(ft.originDir, fmt.Sprintf("hello_world.txt.%d", idx))) - fp, err := os.OpenFile(filepath.Join(ft.originDir, fmt.Sprintf("hello_world.txt.%d", idx)), os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) + log.Debugln("Will write origin file", filepath.Join(ft.OriginDir, fmt.Sprintf("hello_world.txt.%d", idx))) + fp, err := os.OpenFile(filepath.Join(ft.OriginDir, fmt.Sprintf("hello_world.txt.%d", idx)), os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) require.NoError(t, err) size = writeBigBuffer(t, fp, 1) } require.NotEqual(t, 0, size) for idx := 0; idx < 5; idx++ { - tr, err := client.DoGet(ctx, fmt.Sprintf("pelican:///test/hello_world.txt.%d", idx), filepath.Join(tmpDir, fmt.Sprintf("hello_world.txt.%d", idx)), false, + tr, err := client.DoGet(ft.Ctx, fmt.Sprintf("pelican:///test/hello_world.txt.%d", idx), filepath.Join(tmpDir, fmt.Sprintf("hello_world.txt.%d", idx)), false, client.WithCaches(cacheUrl)) assert.NoError(t, err) require.Equal(t, 1, len(tr)) @@ -446,8 +327,6 @@ func TestPurge(t *testing.T) { // Create four 1MB files (above low-water mark). 
Force a purge, ensuring that the cleanup is // done according to LRU func TestForcePurge(t *testing.T) { - ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) - defer cancel() tmpDir := t.TempDir() viper.Reset() @@ -455,8 +334,7 @@ func TestForcePurge(t *testing.T) { viper.Set("LocalCache.Size", "5MB") // Decrease the low water mark so invoking purge will result in 3 files in the cache. viper.Set("LocalCache.LowWaterMarkPercentage", "80") - ft := fedTest{} - ft.spinup(t, ctx, egrp) + ft := fed_test_utils.NewFedTest(t) issuer, err := config.GetServerIssuerURL() require.NoError(t, err) @@ -478,15 +356,15 @@ func TestForcePurge(t *testing.T) { // Populate the cache with our test files size := 0 for idx := 0; idx < 4; idx++ { - log.Debugln("Will write origin file", filepath.Join(ft.originDir, fmt.Sprintf("hello_world.txt.%d", idx))) - fp, err := os.OpenFile(filepath.Join(ft.originDir, fmt.Sprintf("hello_world.txt.%d", idx)), os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) + log.Debugln("Will write origin file", filepath.Join(ft.OriginDir, fmt.Sprintf("hello_world.txt.%d", idx))) + fp, err := os.OpenFile(filepath.Join(ft.OriginDir, fmt.Sprintf("hello_world.txt.%d", idx)), os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) require.NoError(t, err) size = writeBigBuffer(t, fp, 1) } require.NotEqual(t, 0, size) for idx := 0; idx < 4; idx++ { - tr, err := client.DoGet(ctx, fmt.Sprintf("pelican:///test/hello_world.txt.%d", idx), filepath.Join(tmpDir, fmt.Sprintf("hello_world.txt.%d", idx)), false, + tr, err := client.DoGet(ft.Ctx, fmt.Sprintf("pelican:///test/hello_world.txt.%d", idx), filepath.Join(tmpDir, fmt.Sprintf("hello_world.txt.%d", idx)), false, client.WithCaches(cacheUrl)) assert.NoError(t, err) require.Equal(t, 1, len(tr)) @@ -503,7 +381,7 @@ func TestForcePurge(t *testing.T) { }() } - _, err = utils.MakeRequest(ctx, param.Server_ExternalWebUrl.GetString()+"/api/v1.0/localcache/purge", "POST", nil, map[string]string{"Authorization": "Bearer " + token}) + _, err = 
utils.MakeRequest(ft.Ctx, param.Server_ExternalWebUrl.GetString()+"/api/v1.0/localcache/purge", "POST", nil, map[string]string{"Authorization": "Bearer " + token}) require.NoError(t, err) // Low water mark is small enough that a force purge will delete a file. From b1dc96cd7b9de2ad71be243daaab6459a5883b90 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sat, 9 Mar 2024 13:45:39 -0600 Subject: [PATCH 37/45] Switch to utility-based verification routine --- local_cache/cache_api.go | 31 +++++++++++++------------------ local_cache/cache_test.go | 4 ++++ 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/local_cache/cache_api.go b/local_cache/cache_api.go index dcc8ae364..a5f46eb0a 100644 --- a/local_cache/cache_api.go +++ b/local_cache/cache_api.go @@ -32,10 +32,9 @@ import ( "strings" "github.com/gin-gonic/gin" - "github.com/lestrrat-go/jwx/v2/jwt" "github.com/pelicanplatform/pelican/common" - "github.com/pelicanplatform/pelican/config" "github.com/pelicanplatform/pelican/param" + "github.com/pelicanplatform/pelican/token" "github.com/pelicanplatform/pelican/token_scopes" "github.com/pkg/errors" log "github.com/sirupsen/logrus" @@ -134,27 +133,23 @@ func (lc *LocalCache) Register(ctx context.Context, router *gin.RouterGroup) { // Authorize the request then trigger the purge routine func (lc *LocalCache) purgeCmd(ginCtx *gin.Context) { - token := ginCtx.GetHeader("Authorization") - var hasPrefix bool - if token, hasPrefix = strings.CutPrefix(token, "Bearer "); !hasPrefix { - ginCtx.AbortWithStatusJSON(http.StatusUnauthorized, common.SimpleApiResp{Status: common.RespFailed, Msg: "Bearer token required to authenticate"}) - return - } - jwks, err := config.GetIssuerPublicJWKS() + status, verified, err := token.Verify(ginCtx, token.AuthOption{ + Sources: []token.TokenSource{token.Header}, + Issuers: []token.TokenIssuer{token.LocalIssuer}, + Scopes: []token_scopes.TokenScope{token_scopes.Localcache_Purge}, + }) if err != nil { - 
ginCtx.AbortWithStatusJSON(http.StatusInternalServerError, common.SimpleApiResp{Status: common.RespFailed, Msg: "Unable to get local server token issuer"}) + if status == http.StatusOK { + status = http.StatusInternalServerError + } + ginCtx.AbortWithStatusJSON(status, common.SimpleApiResp{Status: common.RespFailed, Msg: err.Error()}) return - } - tok, err := jwt.Parse([]byte(token), jwt.WithKeySet(jwks)) - if err != nil { - ginCtx.AbortWithStatusJSON(http.StatusUnauthorized, common.SimpleApiResp{Status: common.RespFailed, Msg: "Authorization token cannot be verified"}) - } - scopeValidator := token_scopes.CreateScopeValidator([]token_scopes.TokenScope{token_scopes.Localcache_Purge}, true) - if err = jwt.Validate(tok, jwt.WithValidator(scopeValidator)); err != nil { - ginCtx.AbortWithStatusJSON(http.StatusUnauthorized, common.SimpleApiResp{Status: common.RespFailed, Msg: "Authorization token is not valid: " + err.Error()}) + } else if !verified { + ginCtx.AbortWithStatusJSON(http.StatusInternalServerError, common.SimpleApiResp{Status: common.RespFailed, Msg: "Unknown verification error"}) return } + err = lc.purge() if err != nil { if err == purgeTimeout { diff --git a/local_cache/cache_test.go b/local_cache/cache_test.go index f11dfdc9c..10f2dcdc5 100644 --- a/local_cache/cache_test.go +++ b/local_cache/cache_test.go @@ -353,6 +353,10 @@ func TestForcePurge(t *testing.T) { Path: param.LocalCache_Socket.GetString(), } + _, err = utils.MakeRequest(ft.Ctx, param.Server_ExternalWebUrl.GetString()+"/api/v1.0/localcache/purge", "POST", nil, map[string]string{"Authorization": "Bearer abcd"}) + assert.Error(t, err) + require.Equal(t, fmt.Sprintf("The POST attempt to %s/api/v1.0/localcache/purge resulted in status code 403", param.Server_ExternalWebUrl.GetString()), err.Error()) + // Populate the cache with our test files size := 0 for idx := 0; idx < 4; idx++ { From fe957ea6ecff4fe6a1cbe3c09ce93faf8887c7f8 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sun, 10 Mar 
2024 09:41:07 -0500 Subject: [PATCH 38/45] Fix tests to have audience --- registry/registry_ui_test.go | 1 + token/token_verify.go | 14 ++++++-------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/registry/registry_ui_test.go b/registry/registry_ui_test.go index 52029e73f..d59485708 100644 --- a/registry/registry_ui_test.go +++ b/registry/registry_ui_test.go @@ -244,6 +244,7 @@ func TestListNamespaces(t *testing.T) { tokenCfg.Lifetime = time.Minute tokenCfg.Subject = "admin" tokenCfg.AddScopes(token_scopes.WebUi_Access) + tokenCfg.AddAudienceAny() token, err := tokenCfg.CreateToken() require.NoError(t, err) req.AddCookie(&http.Cookie{Name: "login", Value: token, Path: "/"}) diff --git a/token/token_verify.go b/token/token_verify.go index 99f3a7839..e47c886e4 100644 --- a/token/token_verify.go +++ b/token/token_verify.go @@ -161,7 +161,7 @@ func (a AuthCheckImpl) localIssuerCheck(c *gin.Context, strToken string, expecte // authOption.Scopes, return true and set "User" context to the issuer if any of the issuer check succeed // // Scope check will pass if your token has ANY of the scopes in authOption.Scopes -func Verify(ctx *gin.Context, authOption AuthOption) (status int, verfied bool, err error) { +func Verify(ctx *gin.Context, authOption AuthOption) (status int, verified bool, err error) { token := "" // Find token from the provided sources list, stop when found the first token tokenFound := false @@ -177,16 +177,17 @@ func Verify(ctx *gin.Context, authOption AuthOption) (status int, verfied bool, } else { token = cookieToken tokenFound = true - break } case Header: headerToken := ctx.Request.Header["Authorization"] if len(headerToken) <= 0 { continue } else { - token = strings.TrimPrefix(headerToken[0], "Bearer ") - tokenFound = true - break + var found bool + token, found = strings.CutPrefix(headerToken[0], "Bearer ") + if found { + tokenFound = true + } } case Authz: authzToken := ctx.Request.URL.Query()["authz"] @@ -195,7 +196,6 @@ func 
Verify(ctx *gin.Context, authOption AuthOption) (status int, verfied bool, } else { token = authzToken[0] tokenFound = true - break } default: log.Error("Invalid/unsupported token source") @@ -214,14 +214,12 @@ func Verify(ctx *gin.Context, authOption AuthOption) (status int, verfied bool, case FederationIssuer: if err := authChecker.federationIssuerCheck(ctx, token, authOption.Scopes, authOption.AllScopes); err != nil { errMsg += fmt.Sprintln("Cannot verify token with federation issuer: ", err) - break } else { return http.StatusOK, true, nil } case LocalIssuer: if err := authChecker.localIssuerCheck(ctx, token, authOption.Scopes, authOption.AllScopes); err != nil { errMsg += fmt.Sprintln("Cannot verify token with server issuer: ", err) - break } else { return http.StatusOK, true, nil } From d2b775f9a2b0a3a0191b73d2c2d878ab57224ee0 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sun, 10 Mar 2024 09:41:54 -0500 Subject: [PATCH 39/45] Ensure registry tests can run without a network connection --- registry/client_commands_test.go | 5 +++++ registry/registry_ui_test.go | 31 +++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/registry/client_commands_test.go b/registry/client_commands_test.go index ee8bd0a03..873bcea1f 100644 --- a/registry/client_commands_test.go +++ b/registry/client_commands_test.go @@ -131,6 +131,11 @@ func TestRegistryKeyChainingOSDF(t *testing.T) { viper.Reset() _ = config.SetPreferredPrefix("OSDF") + viper.Set("Federation.DirectorUrl", "https://osdf-director.osg-htc.org") + viper.Set("Federation.RegistryUrl", "https://osdf-registry.osg-htc.org") + viper.Set("Federation.JwkUrl", "https://osg-htc.org/osdf/public_signing_key.jwks") + viper.Set("Federation.BrokerUrl", "https://osdf-director.osg-htc.org") + // On by default, but just to make things explicit viper.Set("Registry.RequireKeyChaining", true) diff --git a/registry/registry_ui_test.go b/registry/registry_ui_test.go index d59485708..87de3d0e8 100644 --- 
a/registry/registry_ui_test.go +++ b/registry/registry_ui_test.go @@ -27,6 +27,7 @@ import ( "encoding/json" "fmt" "io" + "net" "net/http" "net/http/httptest" "net/url" @@ -1245,6 +1246,36 @@ func TestPopulateRegistrationFields(t *testing.T) { } func TestGetCachedInstitutions(t *testing.T) { + svr := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if r.URL.Path == "/institution_ids" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte(`[{"id": "https://osg-htc.org/iid/05ejpqr48", "name": "Worcester Polytechnic Institute", "ror_id": "https://ror.org/05ejpqr48"}, {"id": "https://osg-htc.org/iid/017t4sb47", "name": "Wright Institute", "ror_id": "https://ror.org/017t4sb47"}, {"id": "https://osg-htc.org/iid/03v76x132", "name": "Yale University", "ror_id": "https://ror.org/03v76x132"}]`)) + require.NoError(t, err) + return + } + w.WriteHeader(http.StatusNotFound) + })) + + // Hijack the common transport used by Pelican, forcing all connections to go to our test server + transport := config.GetTransport() + oldDial := transport.DialContext + transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + dialer := net.Dialer{} + return dialer.DialContext(ctx, svr.Listener.Addr().Network(), svr.Listener.Addr().String()) + } + oldConfig := transport.TLSClientConfig + transport.TLSClientConfig = svr.TLS.Clone() + transport.TLSClientConfig.InsecureSkipVerify = true + t.Cleanup(func() { + transport.DialContext = oldDial + transport.TLSClientConfig = oldConfig + }) + t.Run("nil-cache-returns-error", func(t *testing.T) { func() { institutionsCacheMutex.Lock() From d98f5ade6cf52dd9563e7d09f5fbf9b1b4f4a195 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sun, 10 Mar 2024 10:08:13 -0500 Subject: [PATCH 40/45] Fix SetVersion to use correct version --- 
token/token_create.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/token/token_create.go b/token/token_create.go index b2e19776d..eaaf6da40 100644 --- a/token/token_create.go +++ b/token/token_create.go @@ -141,14 +141,14 @@ func (config *TokenConfig) SetVersion(ver string) error { if ver == "" { ver = "scitokens:2.0" } else if !scitokensVerPattern.MatchString(ver) { - return errors.New("the provided version '" + config.version + + return errors.New("the provided version '" + ver + "' is not valid. It must match 'scitokens:', where version is of the form 2.x") } } else if config.tokenProfile == TokenProfileWLCG { if ver == "" { ver = "1.0" - } else if !wlcgVerPattern.MatchString(config.version) { - return errors.New("the provided version '" + config.version + "' is not valid. It must be of the form '1.x'") + } else if !wlcgVerPattern.MatchString(ver) { + return errors.New("the provided version '" + ver + "' is not valid. It must be of the form '1.x'") } } config.version = ver From cd3d90629d7e0c477c46d4f0b074c276ec8b2a21 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sun, 10 Mar 2024 11:26:30 -0500 Subject: [PATCH 41/45] Correct ownership of the origin directory when run as root --- fed_test_utils/fed.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/fed_test_utils/fed.go b/fed_test_utils/fed.go index 0c96616e1..b3be7ac1d 100644 --- a/fed_test_utils/fed.go +++ b/fed_test_utils/fed.go @@ -103,6 +103,11 @@ func NewFedTest(t *testing.T) (ft *FedTest) { err = os.Chmod(originDir, permissions) require.NoError(t, err) + // Change ownership on the temporary origin directory so files can be uploaded + uinfo, err := config.GetDaemonUserInfo() + require.NoError(t, err) + require.NoError(t, os.Chown(originDir, uinfo.Uid, uinfo.Gid)) + viper.Set("Origin.ExportVolume", originDir+":/test") viper.Set("Origin.Mode", "posix") viper.Set("Origin.EnableFallbackRead", true) From 316cc2e818a468e4e1f21bc3c7fffba240d0b4ce Mon Sep 17 00:00:00 2001 
From: Brian Bockelman Date: Sun, 10 Mar 2024 21:13:17 -0500 Subject: [PATCH 42/45] Sort the caches based on responsiveness If we cannot get a HEAD request through a cache within a second, it's a strong signal that it won't be responsive later on when we try to download. Do simultaneous HEAD queries and sort those without a response after 1 second to the end. The intent is to reduce the time cost of a totally unresponsive cache. --- client/handle_http.go | 111 ++++++++++++++++++++++++++++++++++--- client/handle_http_test.go | 71 ++++++++++++++++++++++-- 2 files changed, 171 insertions(+), 11 deletions(-) diff --git a/client/handle_http.go b/client/handle_http.go index 5f9f9d61c..31a4cdf6a 100644 --- a/client/handle_http.go +++ b/client/handle_http.go @@ -1184,6 +1184,98 @@ func runTransferWorker(ctx context.Context, workChan <-chan *clientTransferFile, } } +// If there are multiple potential attempts, try to see if we can quickly eliminate some of them +// +// Attempts a HEAD against all the endpoints simultaneously. Put any that don't respond within +// a second behind those that do respond. 
+func sortAttempts(ctx context.Context, path string, attempts []transferAttemptDetails) (size int64, results []transferAttemptDetails) { + size = -1 + if len(attempts) < 2 { + results = attempts + return + } + transport := config.GetTransport() + headChan := make(chan struct { + idx int + size uint64 + err error + }) + defer close(headChan) + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + for idx, transferEndpoint := range attempts { + tUrl := *transferEndpoint.Url + tUrl.Path = path + + go func(idx int, tUrl string) { + headClient := &http.Client{Transport: transport} + headRequest, _ := http.NewRequestWithContext(ctx, "HEAD", tUrl, nil) + var headResponse *http.Response + headResponse, err := headClient.Do(headRequest) + if err != nil { + headChan <- struct { + idx int + size uint64 + err error + }{idx, 0, err} + return + } + headResponse.Body.Close() + contentLengthStr := headResponse.Header.Get("Content-Length") + size := int64(0) + if contentLengthStr != "" { + size, err = strconv.ParseInt(contentLengthStr, 10, 64) + if err != nil { + log.Errorln("problem converting content-length to an int:", err) + } + } + headChan <- struct { + idx int + size uint64 + err error + }{idx, uint64(size), nil} + }(idx, tUrl.String()) + } + finished := make(map[int]bool) + for ctr := 0; ctr != len(attempts); ctr++ { + result := <-headChan + if result.err != nil { + if result.err != context.Canceled { + log.Debugf("Failure when doing a HEAD request against %s: %s", attempts[result.idx].Url.String(), result.err.Error()) + } + } else { + finished[result.idx] = true + if result.idx == 0 { + cancel() + } + if size <= int64(result.size) { + size = int64(result.size) + } + } + } + // Sort all the successful attempts first; use stable sort so the original ordering + // is preserved if the two entries are both successful or both unsuccessful. 
+ type sorter struct { + good bool + attempt transferAttemptDetails + } + tmpResults := make([]sorter, len(attempts)) + for idx, attempt := range attempts { + tmpResults[idx] = sorter{finished[idx], attempt} + } + results = make([]transferAttemptDetails, len(attempts)) + slices.SortStableFunc(tmpResults, func(left sorter, right sorter) int { + if left.good && !right.good { + return -1 + } + return 0 + }) + for idx, val := range tmpResults { + results[idx] = val.attempt + } + return +} + func downloadObject(transfer *transferFile) (transferResults TransferResults, err error) { log.Debugln("Downloading file from", transfer.remoteURL, "to", transfer.localPath) // Remove the source from the file path @@ -1192,9 +1284,12 @@ func downloadObject(transfer *transferFile) (transferResults TransferResults, er if err = os.MkdirAll(directory, 0700); err != nil { return } + + size, attempts := sortAttempts(transfer.job.ctx, transfer.remoteURL.Path, transfer.attempts) + transferResults = newTransferResults(transfer.job) success := false - for idx, transferEndpoint := range transfer.attempts { // For each transfer (usually 3), populate each attempt given + for idx, transferEndpoint := range attempts { // For each transfer attempt (usually 3), try to download via HTTP var attempt TransferResult var timeToFirstByte float64 var serverVersion string @@ -1206,7 +1301,7 @@ func downloadObject(transfer *transferFile) (transferResults TransferResults, er transferEndpointUrl.Path = transfer.remoteURL.Path transferEndpoint.Url = &transferEndpointUrl transferStartTime := time.Now() - if downloaded, timeToFirstByte, serverVersion, err = downloadHTTP(transfer.ctx, transfer.engine, transfer.callback, transferEndpoint, transfer.localPath, transfer.token, &transfer.accounting); err != nil { + if downloaded, timeToFirstByte, serverVersion, err = downloadHTTP(transfer.ctx, transfer.engine, transfer.callback, transferEndpoint, transfer.localPath, size, transfer.token, &transfer.accounting); err != 
nil { log.Debugln("Failed to download:", err) transferEndTime := time.Now() transferTime := transferEndTime.Unix() - transferStartTime.Unix() @@ -1278,7 +1373,7 @@ func parseTransferStatus(status string) (int, string) { // Perform the actual download of the file // // Returns the downloaded size, time to 1st byte downloaded, serverVersion and an error if there is one -func downloadHTTP(ctx context.Context, te *TransferEngine, callback TransferCallbackFunc, transfer transferAttemptDetails, dest string, token string, payload *payloadStruct) (downloaded int64, timeToFirstByte float64, serverVersion string, err error) { +func downloadHTTP(ctx context.Context, te *TransferEngine, callback TransferCallbackFunc, transfer transferAttemptDetails, dest string, totalSize int64, token string, payload *payloadStruct) (downloaded int64, timeToFirstByte float64, serverVersion string, err error) { defer func() { if r := recover(); r != nil { log.Errorln("Panic occurred in downloadHTTP:", r) @@ -1287,15 +1382,17 @@ func downloadHTTP(ctx context.Context, te *TransferEngine, callback TransferCall } }() - var totalSize int64 = 0 - lastUpdate := time.Now() if callback != nil { callback(dest, 0, 0, false) } defer func() { if callback != nil { - callback(dest, downloaded, totalSize, true) + finalSize := int64(0) + if totalSize >= 0 { + finalSize = totalSize + } + callback(dest, downloaded, finalSize, true) } if te != nil { te.ewmaCtr.Add(int64(time.Since(lastUpdate))) @@ -1396,7 +1493,7 @@ func downloadHTTP(ctx context.Context, te *TransferEngine, callback TransferCall // Size of the download totalSize = resp.Size() // Do a head request for content length if resp.Size is unknown - if totalSize <= 0 { + if totalSize <= 0 && !resp.IsComplete() { headClient := &http.Client{Transport: transport} headRequest, _ := http.NewRequest("HEAD", transferUrl.String(), nil) var headResponse *http.Response diff --git a/client/handle_http_test.go b/client/handle_http_test.go index 32992df1a..6eecf3f57 
100644 --- a/client/handle_http_test.go +++ b/client/handle_http_test.go @@ -36,6 +36,7 @@ import ( "github.com/spf13/viper" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/pelicanplatform/pelican/config" "github.com/pelicanplatform/pelican/namespaces" @@ -187,7 +188,7 @@ func TestSlowTransfers(t *testing.T) { var err error // Do a quick timeout go func() { - _, _, _, err = downloadHTTP(ctx, nil, nil, transfers[0], filepath.Join(t.TempDir(), "test.txt"), "", nil) + _, _, _, err = downloadHTTP(ctx, nil, nil, transfers[0], filepath.Join(t.TempDir(), "test.txt"), -1, "", nil) finishedChannel <- true }() @@ -258,7 +259,7 @@ func TestStoppedTransfer(t *testing.T) { var err error go func() { - _, _, _, err = downloadHTTP(ctx, nil, nil, transfers[0], filepath.Join(t.TempDir(), "test.txt"), "", nil) + _, _, _, err = downloadHTTP(ctx, nil, nil, transfers[0], filepath.Join(t.TempDir(), "test.txt"), -1, "", nil) finishedChannel <- true }() @@ -290,7 +291,7 @@ func TestConnectionError(t *testing.T) { addr := l.Addr().String() l.Close() - _, _, _, err = downloadHTTP(ctx, nil, nil, transferAttemptDetails{Url: &url.URL{Host: addr, Scheme: "http"}, Proxy: false}, filepath.Join(t.TempDir(), "test.txt"), "", nil) + _, _, _, err = downloadHTTP(ctx, nil, nil, transferAttemptDetails{Url: &url.URL{Host: addr, Scheme: "http"}, Proxy: false}, filepath.Join(t.TempDir(), "test.txt"), -1, "", nil) assert.IsType(t, &ConnectionSetupError{}, err) @@ -325,7 +326,7 @@ func TestTrailerError(t *testing.T) { assert.Equal(t, svr.URL, transfers[0].Url.String()) // Call DownloadHTTP and check if the error is returned correctly - _, _, _, err := downloadHTTP(ctx, nil, nil, transfers[0], filepath.Join(t.TempDir(), "test.txt"), "", nil) + _, _, _, err := downloadHTTP(ctx, nil, nil, transfers[0], filepath.Join(t.TempDir(), "test.txt"), -1, "", nil) assert.NotNil(t, err) assert.EqualError(t, err, "transfer error: Unable to read test.txt; input/output error") @@ 
-388,3 +389,65 @@ func TestFailedUpload(t *testing.T) { assert.Fail(t, "Timeout while waiting for response") } } + +func TestSortAttempts(t *testing.T) { + ctx, cancel, _ := test_utils.TestContext(context.Background(), t) + + neverRespond := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + select { + case <-ctx.Done(): + case <-ticker.C: + } + }) + alwaysRespond := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "HEAD" { + w.Header().Set("Content-Length", "42") + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusInternalServerError) + } + }) + svr1 := httptest.NewServer(neverRespond) + defer svr1.Close() + url1, err := url.Parse(svr1.URL) + require.NoError(t, err) + attempt1 := transferAttemptDetails{Url: url1} + + svr2 := httptest.NewServer(alwaysRespond) + defer svr2.Close() + url2, err := url.Parse(svr2.URL) + require.NoError(t, err) + attempt2 := transferAttemptDetails{Url: url2} + + svr3 := httptest.NewServer(alwaysRespond) + defer svr3.Close() + url3, err := url.Parse(svr3.URL) + require.NoError(t, err) + attempt3 := transferAttemptDetails{Url: url3} + + defer cancel() + + size, results := sortAttempts(ctx, "/path", []transferAttemptDetails{attempt1, attempt2, attempt3}) + assert.Equal(t, int64(42), size) + assert.Equal(t, svr2.URL, results[0].Url.String()) + assert.Equal(t, svr3.URL, results[1].Url.String()) + assert.Equal(t, svr1.URL, results[2].Url.String()) + + size, results = sortAttempts(ctx, "/path", []transferAttemptDetails{attempt2, attempt3, attempt1}) + assert.Equal(t, int64(42), size) + assert.Equal(t, svr2.URL, results[0].Url.String()) + assert.Equal(t, svr3.URL, results[1].Url.String()) + assert.Equal(t, svr1.URL, results[2].Url.String()) + + size, results = sortAttempts(ctx, "/path", []transferAttemptDetails{attempt1, attempt1}) + assert.Equal(t, int64(-1), size) + assert.Equal(t, svr1.URL, results[0].Url.String()) + 
assert.Equal(t, svr1.URL, results[1].Url.String()) + + size, results = sortAttempts(ctx, "/path", []transferAttemptDetails{attempt2, attempt3}) + assert.Equal(t, int64(42), size) + assert.Equal(t, svr2.URL, results[0].Url.String()) + assert.Equal(t, svr3.URL, results[1].Url.String()) +} From 223125e4fa39760b619fcd3d0d11dfcda101bd00 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sun, 10 Mar 2024 22:15:21 -0500 Subject: [PATCH 43/45] Perform HEAD against multiple caches if provided When we need to do stat, if there are only caches provided, then perform HEAD against multiple simultaneously. Meant to reduce the impact of one bad cache. --- client/handle_http.go | 196 +++++++++++++++++++++++++++--------------- 1 file changed, 126 insertions(+), 70 deletions(-) diff --git a/client/handle_http.go b/client/handle_http.go index 31a4cdf6a..b28a9beac 100644 --- a/client/handle_http.go +++ b/client/handle_http.go @@ -2045,86 +2045,142 @@ func (te *TransferEngine) walkDirUpload(job *clientTransferJob, transfers []tran return err } +// Invoke HEAD against a remote URL, using the provided namespace information +// +// If a "dirlist host" is given, then that is used for the namespace info. +// Otherwise, the first three caches are queried simultaneously. +// For any of the queries, if the attempt with the proxy fails, a second attempt +// is made without. 
func statHttp(ctx context.Context, dest *url.URL, namespace namespaces.Namespace, token string) (size uint64, err error) { - // Parse the writeback host as a URL - statHost := namespace.WriteBackHost - if len(namespace.SortedDirectorCaches) > 0 { - statHost = namespace.SortedDirectorCaches[0].EndpointUrl - } - if statHost == "" { - statHost = namespace.DirListHost + statHosts := make([]url.URL, 0, 3) + if namespace.DirListHost != "" { + var endpoint *url.URL + endpoint, err = url.Parse(namespace.DirListHost) + if err != nil { + return + } + statHosts = append(statHosts, *endpoint) + } else if len(namespace.SortedDirectorCaches) > 0 { + for idx, cache := range namespace.SortedDirectorCaches { + if idx > 2 { + break + } + var endpoint *url.URL + endpoint, err = url.Parse(cache.EndpointUrl) + if err != nil { + return + } + statHosts = append(statHosts, *endpoint) + } + } else if namespace.WriteBackHost != "" { + var endpoint *url.URL + endpoint, err = url.Parse(namespace.WriteBackHost) + if err != nil { + return + } + statHosts = append(statHosts, *endpoint) } - writebackhostUrl, err := url.Parse(statHost) - if err != nil { - return + + type statResults struct { + size uint64 + err error } - dest.Host = writebackhostUrl.Host - dest.Scheme = "https" + resultsChan := make(chan statResults) + transport := config.GetTransport() + client := &http.Client{Transport: transport} + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() - canDisableProxy := CanDisableProxy() - disableProxy := !isProxyEnabled() + for _, statUrl := range statHosts { + destCopy := *dest + destCopy.Host = statUrl.Host + destCopy.Scheme = statUrl.Scheme - var resp *http.Response - for { - transport := config.GetTransport() - if disableProxy { - log.Debugln("Performing HEAD (without proxy)", dest.String()) - transport.Proxy = nil - } else { - log.Debugln("Performing HEAD", dest.String()) - } + go func(endpoint *url.URL) { + canDisableProxy := CanDisableProxy() + disableProxy := 
!isProxyEnabled() - client := &http.Client{Transport: transport} - var req *http.Request - req, err = http.NewRequestWithContext(ctx, "HEAD", dest.String(), nil) - if err != nil { - log.Errorln("Failed to create HTTP request:", err) - return - } + var resp *http.Response + for { + if disableProxy { + log.Debugln("Performing HEAD (without proxy)", endpoint.String()) + transport.Proxy = nil + } else { + log.Debugln("Performing HEAD", endpoint.String()) + } - if token != "" { - req.Header.Set("Authorization", "Bearer "+token) - } + var req *http.Request + req, err = http.NewRequestWithContext(ctx, "HEAD", endpoint.String(), nil) + if err != nil { + log.Errorln("Failed to create HTTP request:", err) + resultsChan <- statResults{0, err} + return + } - resp, err = client.Do(req) - if err == nil { - break - } - if urle, ok := err.(*url.Error); canDisableProxy && !disableProxy && ok && urle.Unwrap() != nil { - if ope, ok := urle.Unwrap().(*net.OpError); ok && ope.Op == "proxyconnect" { - log.Warnln("Failed to connect to proxy; will retry without:", ope) - disableProxy = true - continue + if token != "" { + req.Header.Set("Authorization", "Bearer "+token) + } + + resp, err = client.Do(req) + if err == nil { + break + } + if urle, ok := err.(*url.Error); canDisableProxy && !disableProxy && ok && urle.Unwrap() != nil { + if ope, ok := urle.Unwrap().(*net.OpError); ok && ope.Op == "proxyconnect" { + log.Warnln("Failed to connect to proxy; will retry without:", ope) + disableProxy = true + continue + } + } + log.Errorln("Failed to get HTTP response:", err) + resultsChan <- statResults{0, err} + return } - } - log.Errorln("Failed to get HTTP response:", err) - return - } - if resp.StatusCode == 200 { - defer resp.Body.Close() - contentLengthStr := resp.Header.Get("Content-Length") - if len(contentLengthStr) == 0 { - log.Errorln("HEAD response did not include Content-Length header") - err = errors.New("HEAD response did not include Content-Length header") - return - } - var 
contentLength int64 - contentLength, err = strconv.ParseInt(contentLengthStr, 10, 64) - if err != nil { - log.Errorf("Unable to parse Content-Length header value (%s) as integer: %s", contentLengthStr, err) - return - } - return uint64(contentLength), nil - } else { - var respB []byte - respB, err = io.ReadAll(resp.Body) - if err != nil { - log.Errorln("Failed to read error message:", err) - return + defer resp.Body.Close() + if resp.StatusCode == 200 { + contentLengthStr := resp.Header.Get("Content-Length") + if len(contentLengthStr) == 0 { + log.Errorln("HEAD response did not include Content-Length header") + err = errors.New("HEAD response did not include Content-Length header") + resultsChan <- statResults{0, err} + return + } + var contentLength int64 + contentLength, err = strconv.ParseInt(contentLengthStr, 10, 64) + if err != nil { + log.Errorf("Unable to parse Content-Length header value (%s) as integer: %s", contentLengthStr, err) + resultsChan <- statResults{0, err} + return + } + resultsChan <- statResults{uint64(contentLength), nil} + } else { + var respB []byte + respB, err = io.ReadAll(resp.Body) + if err != nil { + log.Errorln("Failed to read error message:", err) + return + } + err = &HttpErrResp{resp.StatusCode, fmt.Sprintf("Request failed (HTTP status %d): %s", resp.StatusCode, string(respB))} + resultsChan <- statResults{0, err} + } + }(&destCopy) + } + success := false + for ctr := 0; ctr < len(statHosts); ctr++ { + result := <-resultsChan + if result.err == nil { + if !success { + cancel() + success = true + size = result.size + } + } else if err == nil && result.err != context.Canceled { + err = result.err } - defer resp.Body.Close() - err = &HttpErrResp{resp.StatusCode, fmt.Sprintf("Request failed (HTTP status %d): %s", resp.StatusCode, string(respB))} - return } + if success { + err = nil + } + return } From f7d47c989b9599ef3bc2bdfdccef7cf67d3f8316 Mon Sep 17 00:00:00 2001 From: Haoming Meng Date: Mon, 11 Mar 2024 17:14:47 +0000 Subject: 
[PATCH 44/45] Add comment to token audience function --- token/token_create.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/token/token_create.go b/token/token_create.go index eaaf6da40..0cb44d609 100644 --- a/token/token_create.go +++ b/token/token_create.go @@ -155,6 +155,10 @@ func (config *TokenConfig) SetVersion(ver string) error { return nil } +// Add audience="any" to the config based on the token profile. +// +// For WLCG profile, it will be "https://wlcg.cern.ch/jwt/v1/any". +// For Scitokens profile, it will be "ANY" func (config *TokenConfig) AddAudienceAny() { newAud := "" switch config.tokenProfile { From fffac65e6c8c833795bdf442eba099a640d56454 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Mon, 11 Mar 2024 13:35:22 -0500 Subject: [PATCH 45/45] Improve sorting logic for HEAD requests If the first cache returns back successfully, sort all the error'd out cases to the back of the array and the pending (canceled) ones to the middle. --- client/handle_http.go | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/client/handle_http.go b/client/handle_http.go index b28a9beac..fc49b4be0 100644 --- a/client/handle_http.go +++ b/client/handle_http.go @@ -1226,27 +1226,40 @@ func sortAttempts(ctx context.Context, path string, attempts []transferAttemptDe if contentLengthStr != "" { size, err = strconv.ParseInt(contentLengthStr, 10, 64) if err != nil { - log.Errorln("problem converting content-length to an int:", err) + err = errors.Wrap(err, "problem converting Content-Length in response to an int") + log.Errorln(err.Error()) + } } headChan <- struct { idx int size uint64 err error - }{idx, uint64(size), nil} + }{idx, uint64(size), err} }(idx, tUrl.String()) } - finished := make(map[int]bool) + // 1 -> success. + // 0 -> pending. + // -1 -> error. 
+ finished := make(map[int]int) for ctr := 0; ctr != len(attempts); ctr++ { result := <-headChan if result.err != nil { if result.err != context.Canceled { log.Debugf("Failure when doing a HEAD request against %s: %s", attempts[result.idx].Url.String(), result.err.Error()) + finished[result.idx] = -1 } } else { - finished[result.idx] = true + finished[result.idx] = 1 if result.idx == 0 { cancel() + // If the first responds successfully, we want to return immediately instead of giving + // the other caches time to respond - the result is "good enough". + // - Any cache with confirmed errors (-1) is sorted to the back. + // - Any cache which is still pending (0) is kept in place. + for ctr := 0; ctr < len(attempts); ctr++ { + finished[ctr] = 1 + } } if size <= int64(result.size) { size = int64(result.size) @@ -1256,7 +1269,7 @@ func sortAttempts(ctx context.Context, path string, attempts []transferAttemptDe // Sort all the successful attempts first; use stable sort so the original ordering // is preserved if the two entries are both successful or both unsuccessful. type sorter struct { - good bool + good int attempt transferAttemptDetails } tmpResults := make([]sorter, len(attempts)) @@ -1265,7 +1278,7 @@ func sortAttempts(ctx context.Context, path string, attempts []transferAttemptDe } results = make([]transferAttemptDetails, len(attempts)) slices.SortStableFunc(tmpResults, func(left sorter, right sorter) int { - if left.good && !right.good { + if left.good > right.good { return -1 } return 0