diff --git a/cmd/allInOne.go b/cmd/allInOne.go index d7e9e6a..132c81d 100644 --- a/cmd/allInOne.go +++ b/cmd/allInOne.go @@ -17,6 +17,7 @@ import ( storage2 "github.com/terrariumcloud/terrarium/internal/module/services/storage" "github.com/terrariumcloud/terrarium/internal/module/services/tag_manager" "github.com/terrariumcloud/terrarium/internal/module/services/version_manager" + providerStorage "github.com/terrariumcloud/terrarium/internal/provider/services/storage" providerVersionManager "github.com/terrariumcloud/terrarium/internal/provider/services/version_manager" "github.com/terrariumcloud/terrarium/internal/release/services/release" "github.com/terrariumcloud/terrarium/internal/restapi/browse" @@ -92,6 +93,12 @@ var allInOneCmd = &cobra.Command{ Schema: providerVersionManager.GetProviderVersionsSchema(providerVersionManager.VersionsTableName), } + providerStorageServiceServer := &providerStorage.StorageService{ + Client: storage.NewS3Client(awsSessionConfig), + BucketName: providerStorage.BucketName, + Region: awsSessionConfig.Region, + } + services := []grpcServices.Service{ dependencyServiceServer, registrarServiceServer, @@ -100,6 +107,7 @@ var allInOneCmd = &cobra.Command{ releaseServiceServer, versionManagerServer, providerVersionManagerServer, + providerStorageServiceServer, } otelShutdown := initOpenTelemetry("all-in-one") @@ -114,6 +122,7 @@ var allInOneCmd = &cobra.Command{ dependency_manager.NewDependencyManagerGrpcClient(allInOneInternalEndpoint), release.NewPublisherGrpcClient(allInOneInternalEndpoint), providerVersionManager.NewVersionManagerGrpcClient(allInOneInternalEndpoint), + providerStorage.NewStorageGrpcClient(allInOneInternalEndpoint), ) startAllInOneGrpcServices([]grpcServices.Service{gatewayServer}, allInOneGrpcGatewayEndpoint) @@ -124,7 +133,7 @@ var allInOneCmd = &cobra.Command{ providerVersionManager.NewVersionManagerGrpcClient(allInOneInternalEndpoint)) modulesAPIServer := modulesv1.New(version_manager.NewVersionManagerGrpcClient(allInOneInternalEndpoint), storage2.NewStorageGrpcClient(allInOneInternalEndpoint)) - providersAPIServer := providersv1.New(providerVersionManager.NewVersionManagerGrpcClient(allInOneInternalEndpoint)) + providersAPIServer := providersv1.New(providerVersionManager.NewVersionManagerGrpcClient(allInOneInternalEndpoint), providerStorage.NewStorageGrpcClient(allInOneInternalEndpoint)) router := mux.NewRouter() router.PathPrefix("/modules").Handler(modulesAPIServer.GetHttpHandler("/modules")) @@ -146,6 +155,7 @@ func init() { allInOneCmd.Flags().StringVar(&dependency_manager.ModuleDependenciesTableName, "module-dependencies-table", dependency_manager.DefaultModuleDependenciesTableName, "Module dependencies table name") allInOneCmd.Flags().StringVar(&dependency_manager.ContainerDependenciesTableName, "container-dependencies-table", dependency_manager.DefaultContainerDependenciesTableName, "Module container dependencies table name") allInOneCmd.Flags().StringVar(&providerVersionManager.VersionsTableName, "provider-table", providerVersionManager.DefaultProviderVersionsTableName, "Provider versions table name") + allInOneCmd.Flags().StringVar(&providerStorage.BucketName, "provider-storage-bucket", providerStorage.DefaultBucketName, "Provider bucket name") } func startAllInOneGrpcServices(services []grpcServices.Service, endpoint string) { diff --git a/cmd/gateway.go b/cmd/gateway.go index a0121e7..2a556ee 100644 --- a/cmd/gateway.go +++ b/cmd/gateway.go @@ -7,6 +7,7 @@ import ( "github.com/terrariumcloud/terrarium/internal/module/services/storage" "github.com/terrariumcloud/terrarium/internal/module/services/tag_manager" "github.com/terrariumcloud/terrarium/internal/module/services/version_manager" + providerStorage "github.com/terrariumcloud/terrarium/internal/provider/services/storage" providerVersionManager "github.com/terrariumcloud/terrarium/internal/provider/services/version_manager" "github.com/terrariumcloud/terrarium/internal/release/services/release" @@ -25,10 +26,11 @@ func init() { gatewayCmd.Flags().StringVarP(®istrar.RegistrarServiceEndpoint, "registrar", "", registrar.DefaultRegistrarServiceEndpoint, "GRPC Endpoint for Registrar Service") gatewayCmd.Flags().StringVarP(&dependency_manager.DependencyManagerEndpoint, "dependency-manager", "", dependency_manager.DefaultDependencyManagerEndpoint, "GRPC Endpoint for Dependency Manager Service") gatewayCmd.Flags().StringVarP(&version_manager.VersionManagerEndpoint, "version-manager", "", version_manager.DefaultVersionManagerEndpoint, "GRPC Endpoint for Module Version Manager Service") - gatewayCmd.Flags().StringVarP(&storage.StorageServiceEndpoint, "storage", "", storage.DefaultStorageServiceDefaultEndpoint, "GRPC Endpoint for Storage Service") + gatewayCmd.Flags().StringVarP(&storage.StorageServiceEndpoint, "storage", "", storage.DefaultStorageServiceDefaultEndpoint, "GRPC Endpoint for Module Storage Service") gatewayCmd.Flags().StringVarP(&tag_manager.TagManagerEndpoint, "tag-manager", "", tag_manager.DefaultTagManagerEndpoint, "GRPC Endpoint for Tag Service") gatewayCmd.Flags().StringVarP(&release.ReleaseServiceEndpoint, "release", "", release.DefaultReleaseServiceEndpoint, "GRPC Endpoint for Release Service") gatewayCmd.Flags().StringVarP(&providerVersionManager.VersionManagerEndpoint, "provider-version-manager", "", providerVersionManager.DefaultProviderVersionManagerEndpoint, "GRPC Endpoint for Provider Version Manager Service") + gatewayCmd.Flags().StringVarP(&providerStorage.StorageServiceEndpoint, "provider-storage", "", providerStorage.DefaultStorageServiceDefaultEndpoint, "GRPC Endpoint for Provider Storage Service") } func runGateway(cmd *cobra.Command, args []string) { @@ -40,6 +42,7 @@ func runGateway(cmd *cobra.Command, args []string) { dependency_manager.NewDependencyManagerGrpcClient(dependency_manager.DependencyManagerEndpoint), release.NewPublisherGrpcClient(release.ReleaseServiceEndpoint), providerVersionManager.NewVersionManagerGrpcClient(providerVersionManager.VersionManagerEndpoint), + providerStorage.NewStorageGrpcClient(providerStorage.StorageServiceEndpoint), ) startGRPCService("api-gateway", gatewayServer) diff --git a/cmd/provider_storage.go b/cmd/provider_storage.go new file mode 100644 index 0000000..9018243 --- /dev/null +++ b/cmd/provider_storage.go @@ -0,0 +1,31 @@ +package cmd + +import ( + providerStorage "github.com/terrariumcloud/terrarium/internal/provider/services/storage" + "github.com/terrariumcloud/terrarium/internal/storage" + + "github.com/spf13/cobra" +) + +var providerStorageServiceCmd = &cobra.Command{ + Use: "provider-storage", + Short: "Starts the Terrarium GRPC Provider Storage service", + Long: "Runs the Terrarium GRPC Provider Storage server.", + Run: runProviderStorageService, +} + +func init() { + rootCmd.AddCommand(providerStorageServiceCmd) + providerStorageServiceCmd.Flags().StringVarP(&providerStorage.BucketName, "bucket", "b", providerStorage.DefaultBucketName, "Provider bucket name") +} + +func runProviderStorageService(cmd *cobra.Command, args []string) { + + storageServiceServer := &providerStorage.StorageService{ + Client: storage.NewS3Client(awsSessionConfig), + BucketName: providerStorage.BucketName, + Region: awsSessionConfig.Region, + } + + startGRPCService("provider-storage-s3", storageServiceServer) +} diff --git a/cmd/rest_providers_v1.go b/cmd/rest_providers_v1.go index 405ca32..e800cdd 100644 --- a/cmd/rest_providers_v1.go +++ b/cmd/rest_providers_v1.go @@ -1,6 +1,7 @@ package cmd import ( + "github.com/terrariumcloud/terrarium/internal/provider/services/storage" "github.com/terrariumcloud/terrarium/internal/provider/services/version_manager" providersv1 "github.com/terrariumcloud/terrarium/internal/restapi/providers/v1" @@ -25,10 +26,15 @@ func init() { "Mount path for the rest API server used to process request relative to a particular URL in a reverse proxy type setup", ) providersV1Cmd.Flags().StringVarP(&version_manager.VersionManagerEndpoint, "provider-version-manager", "", version_manager.DefaultProviderVersionManagerEndpoint, "GRPC Endpoint for Version Manager Service") + providersV1Cmd.Flags().StringVarP(&storage.StorageServiceEndpoint, "provider-storage", "", storage.DefaultStorageServiceDefaultEndpoint, "GRPC Endpoint for Provider Storage Service") + rootCmd.AddCommand(providersV1Cmd) } func runRESTProvidersV1Server(cmd *cobra.Command, args []string) { - restAPIServer := providersv1.New(version_manager.NewVersionManagerGrpcClient(version_manager.VersionManagerEndpoint)) + restAPIServer := providersv1.New( + version_manager.NewVersionManagerGrpcClient(version_manager.VersionManagerEndpoint), + storage.NewStorageGrpcClient(storage.StorageServiceEndpoint), + ) startRESTAPIService("rest-providers-v1", mountPathProviders, restAPIServer) } diff --git a/docker-compose.yaml b/docker-compose.yaml index 880cf74..2ef4237 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -170,6 +170,27 @@ services: - "$AWS_SECRET_ACCESS_KEY" - "--aws-region" - "$AWS_DEFAULT_REGION" + provider-storage: + build: . + image: terrarium:dev + container_name: terrarium-provider-storage-service + environment: + - AWS_ACCESS_KEY_ID + - AWS_SECRET_ACCESS_KEY + - AWS_DEFAULT_REGION + - OTEL_EXPORTER_OTLP_ENDPOINT=jaeger:4317 + ports: + - 50010:3001 + networks: + - terrarium + command: + - provider-storage + - "--aws-access-key-id" + - "$AWS_ACCESS_KEY_ID" + - "--aws-secret-access-key" + - "$AWS_SECRET_ACCESS_KEY" + - "--aws-region" + - "$AWS_DEFAULT_REGION" release: build: . image: terrarium:dev diff --git a/internal/common/gateway/gateway.go b/internal/common/gateway/gateway.go index 8fd7b52..4c657b9 100644 --- a/internal/common/gateway/gateway.go +++ b/internal/common/gateway/gateway.go @@ -44,6 +44,7 @@ type TerrariumGrpcGateway struct { storageClient moduleServices.StorageClient dependencyManagerClient moduleServices.DependencyManagerClient releasePublisherClient release.PublisherClient + providerStorageClient providerServices.StorageClient } func New(registrarClient moduleServices.RegistrarClient, @@ -52,7 +53,8 @@ func New(registrarClient moduleServices.RegistrarClient, storageClient moduleServices.StorageClient, dependencyManagerClient moduleServices.DependencyManagerClient, releasePublisherClient release.PublisherClient, - providerVersionManagerClient providerServices.VersionManagerClient) *TerrariumGrpcGateway { + providerVersionManagerClient providerServices.VersionManagerClient, + providerStorageClient providerServices.StorageClient) *TerrariumGrpcGateway { return &TerrariumGrpcGateway{ registrarClient: registrarClient, tagManagerClient: tagManagerClient, @@ -61,6 +63,7 @@ func New(registrarClient moduleServices.RegistrarClient, dependencyManagerClient: dependencyManagerClient, releasePublisherClient: releasePublisherClient, providerVersionManagerClient: providerVersionManagerClient, + providerStorageClient: providerStorageClient, } } diff --git a/internal/provider/services/mocks/mock_clients.go b/internal/provider/services/mocks/mock_clients.go new file mode 100644 index 0000000..70a6382 --- /dev/null +++ b/internal/provider/services/mocks/mock_clients.go @@ -0,0 +1,94 @@ +package mocks + +import ( + "context" + + providerServices "github.com/terrariumcloud/terrarium/internal/provider/services" + + "google.golang.org/grpc" +) + +type MockProviderStorageClient struct { + providerServices.StorageClient + DownloadSourceZipInvocations int + DownloadSourceZipClient providerServices.Storage_DownloadProviderSourceZipClient + DownloadSourceZipError error + DownloadShasumInvocations int + DownloadShasumClient providerServices.Storage_DownloadShasumClient + DownloadShasumError error + DownloadShasumSignatureInvocations int + DownloadShasumSignatureClient providerServices.Storage_DownloadShasumSignatureClient + DownloadShasumSignatureError error +} + +func (m *MockProviderStorageClient) DownloadProviderSourceZip(ctx context.Context, in *providerServices.DownloadSourceZipRequest, opts ...grpc.CallOption) (providerServices.Storage_DownloadProviderSourceZipClient, error) { + m.DownloadSourceZipInvocations++ + return m.DownloadSourceZipClient, m.DownloadSourceZipError +} + +func (m *MockProviderStorageClient) DownloadShasum(ctx context.Context, in *providerServices.DownloadShasumRequest, opts ...grpc.CallOption) (providerServices.Storage_DownloadShasumClient, error) { + m.DownloadShasumInvocations++ + return m.DownloadShasumClient, m.DownloadShasumError +} + +func (m *MockProviderStorageClient) DownloadShasumSignature(ctx context.Context, in *providerServices.DownloadShasumRequest, opts ...grpc.CallOption) (providerServices.Storage_DownloadShasumSignatureClient, error) { + m.DownloadShasumSignatureInvocations++ + return m.DownloadShasumSignatureClient, m.DownloadShasumSignatureError +} + +type MockStorage_DownloadProviderSourceZipClient struct { + providerServices.Storage_DownloadProviderSourceZipClient + RecvInvocations int + RecvResponse *providerServices.SourceZipResponse + RecvError error + CloseSendInvocations int + CloseSendError error +} + +func (m *MockStorage_DownloadProviderSourceZipClient) Recv() (*providerServices.SourceZipResponse, error) { + m.RecvInvocations++ + return m.RecvResponse, m.RecvError +} + +func (m *MockStorage_DownloadProviderSourceZipClient) CloseSend() error { + m.CloseSendInvocations++ + return m.CloseSendError +} + +type MockStorage_DownloadProviderShasumClient struct { + providerServices.Storage_DownloadShasumClient + RecvInvocations int + RecvResponse *providerServices.DownloadShasumResponse + RecvError error + CloseSendInvocations int + CloseSendError error +} + +func (m *MockStorage_DownloadProviderShasumClient) Recv() (*providerServices.DownloadShasumResponse, error) { + m.RecvInvocations++ + return m.RecvResponse, m.RecvError +} + +func (m *MockStorage_DownloadProviderShasumClient) CloseSend() error { + m.CloseSendInvocations++ + return m.CloseSendError +} + +type MockStorage_DownloadProviderShasumSignatureClient struct { + providerServices.Storage_DownloadShasumSignatureClient + RecvInvocations int + RecvResponse *providerServices.DownloadShasumResponse + RecvError error + CloseSendInvocations int + CloseSendError error +} + +func (m *MockStorage_DownloadProviderShasumSignatureClient) Recv() (*providerServices.DownloadShasumResponse, error) { + m.RecvInvocations++ + return m.RecvResponse, m.RecvError +} + +func (m *MockStorage_DownloadProviderShasumSignatureClient) CloseSend() error { + m.CloseSendInvocations++ + return m.CloseSendError +} diff --git a/internal/provider/services/mocks/mock_servers.go b/internal/provider/services/mocks/mock_servers.go new file mode 100644 index 0000000..10e5c91 --- /dev/null +++ b/internal/provider/services/mocks/mock_servers.go @@ -0,0 +1,64 @@ +package mocks + +import ( + "context" + + providerServices "github.com/terrariumcloud/terrarium/internal/provider/services" +) + +type MockDownloadProviderSourceZipServer struct { + providerServices.Storage_DownloadProviderSourceZipServer + SendInvocations int + SendResponse *providerServices.SourceZipResponse + SendError error + TotalReceived []byte +} + +func (mds *MockDownloadProviderSourceZipServer) Context() context.Context { + return context.TODO() +} + +func (mds *MockDownloadProviderSourceZipServer) Send(res *providerServices.SourceZipResponse) error { + mds.SendInvocations++ + mds.SendResponse = res + mds.TotalReceived = append(mds.TotalReceived, mds.SendResponse.ZipDataChunk...) + return mds.SendError +} + +type MockDownloadProviderShasumServer struct { + providerServices.Storage_DownloadShasumServer + SendInvocations int + SendResponse *providerServices.DownloadShasumResponse + SendError error + TotalReceived []byte +} + +func (mds *MockDownloadProviderShasumServer) Context() context.Context { + return context.TODO() +} + +func (mds *MockDownloadProviderShasumServer) Send(res *providerServices.DownloadShasumResponse) error { + mds.SendInvocations++ + mds.SendResponse = res + mds.TotalReceived = append(mds.TotalReceived, mds.SendResponse.ShasumDataChunk...) + return mds.SendError +} + +type MockDownloadProviderShasumSignatureServer struct { + providerServices.Storage_DownloadShasumSignatureServer + SendInvocations int + SendResponse *providerServices.DownloadShasumResponse + SendError error + TotalReceived []byte +} + +func (mds *MockDownloadProviderShasumSignatureServer) Context() context.Context { + return context.TODO() +} + +func (mds *MockDownloadProviderShasumSignatureServer) Send(res *providerServices.DownloadShasumResponse) error { + mds.SendInvocations++ + mds.SendResponse = res + mds.TotalReceived = append(mds.TotalReceived, mds.SendResponse.ShasumDataChunk...) + return mds.SendError +} diff --git a/internal/provider/services/storage/client.go b/internal/provider/services/storage/client.go new file mode 100644 index 0000000..b850054 --- /dev/null +++ b/internal/provider/services/storage/client.go @@ -0,0 +1,171 @@ +package storage + +import ( + "context" + "github.com/terrariumcloud/terrarium/internal/common/grpc_service" + "github.com/terrariumcloud/terrarium/internal/provider/services" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + "io" +) + +type storageGrpcClient struct { + endpoint string +} + +func NewStorageGrpcClient(endpoint string) services.StorageClient { + return &storageGrpcClient{endpoint: endpoint} +} + +func (s storageGrpcClient) DownloadProviderSourceZip(ctx context.Context, in *services.DownloadSourceZipRequest, opts ...grpc.CallOption) (services.Storage_DownloadProviderSourceZipClient, error) { + if conn, err := grpc_service.CreateGRPCConnection(s.endpoint); err != nil { + return nil, err + } else { + client := services.NewStorageClient(conn) + if download, err := client.DownloadProviderSourceZip(ctx, in, opts...); err == nil { + return &downloadSourceZipClient{conn: conn, client: download}, nil + } else { + _ = conn.Close() + return nil, err + } + } +} + +func (s storageGrpcClient) DownloadShasum(ctx context.Context, in *services.DownloadShasumRequest, opts ...grpc.CallOption) (services.Storage_DownloadShasumClient, error) { + if conn, err := grpc_service.CreateGRPCConnection(s.endpoint); err != nil { + return nil, err + } else { + client := services.NewStorageClient(conn) + if download, err := client.DownloadShasum(ctx, in, opts...); err == nil { + return &downloadShasumClient{conn: conn, client: download}, nil + } else { + _ = conn.Close() + return nil, err + } + } +} + +func (s storageGrpcClient) DownloadShasumSignature(ctx context.Context, in *services.DownloadShasumRequest, opts ...grpc.CallOption) (services.Storage_DownloadShasumSignatureClient, error) { + if conn, err := grpc_service.CreateGRPCConnection(s.endpoint); err != nil { + return nil, err + } else { + client := services.NewStorageClient(conn) + if download, err := client.DownloadShasumSignature(ctx, in, opts...); err == nil { + return &downloadShasumSignatureClient{conn: conn, client: download}, nil + } else { + _ = conn.Close() + return nil, err + } + } +} + +type downloadSourceZipClient struct { + conn *grpc.ClientConn + client services.Storage_DownloadProviderSourceZipClient +} + +func (d downloadSourceZipClient) Recv() (*services.SourceZipResponse, error) { + result, err := d.client.Recv() + if err == io.EOF { + _ = d.conn.Close() + } + return result, err +} + +func (d downloadSourceZipClient) Header() (metadata.MD, error) { + return d.client.Header() +} + +func (d downloadSourceZipClient) Trailer() metadata.MD { + return d.client.Trailer() +} + +func (d downloadSourceZipClient) CloseSend() error { + return d.client.CloseSend() +} + +func (d downloadSourceZipClient) Context() context.Context { + return d.client.Context() +} + +func (d downloadSourceZipClient) SendMsg(m any) error { + return d.client.SendMsg(m) +} + +func (d downloadSourceZipClient) RecvMsg(m any) error { + return d.client.RecvMsg(m) +} + +type downloadShasumClient struct { + conn *grpc.ClientConn + client services.Storage_DownloadShasumClient +} + +func (d downloadShasumClient) Recv() (*services.DownloadShasumResponse, error) { + result, err := d.client.Recv() + if err == io.EOF { + _ = d.conn.Close() + } + return result, err +} + +func (d downloadShasumClient) Header() (metadata.MD, error) { + return d.client.Header() +} + +func (d downloadShasumClient) Trailer() metadata.MD { + return d.client.Trailer() +} + +func (d downloadShasumClient) CloseSend() error { + return d.client.CloseSend() +} + +func (d downloadShasumClient) Context() context.Context { + return d.client.Context() +} + +func (d downloadShasumClient) SendMsg(m any) error { + return d.client.SendMsg(m) +} + +func (d downloadShasumClient) RecvMsg(m any) error { + return d.client.RecvMsg(m) +} + +type downloadShasumSignatureClient struct { + conn *grpc.ClientConn + client services.Storage_DownloadShasumSignatureClient +} + +func (d downloadShasumSignatureClient) Recv() (*services.DownloadShasumResponse, error) { + result, err := d.client.Recv() + if err == io.EOF { + _ = d.conn.Close() + } + return result, err +} + +func (d downloadShasumSignatureClient) Header() (metadata.MD, error) { + return d.client.Header() +} + +func (d downloadShasumSignatureClient) Trailer() metadata.MD { + return d.client.Trailer() +} + +func (d downloadShasumSignatureClient) CloseSend() error { + return d.client.CloseSend() +} + +func (d downloadShasumSignatureClient) Context() context.Context { + return d.client.Context() +} + +func (d downloadShasumSignatureClient) SendMsg(m any) error { + return d.client.SendMsg(m) +} + +func (d downloadShasumSignatureClient) RecvMsg(m any) error { + return d.client.RecvMsg(m) +} diff --git a/internal/provider/services/storage/storage.go b/internal/provider/services/storage/storage.go new file mode 100644 index 0000000..ca3e10a --- /dev/null +++ b/internal/provider/services/storage/storage.go @@ -0,0 +1,227 @@ +package storage + +import ( + "fmt" + "io" + "log" + "strings" + + "github.com/terrariumcloud/terrarium/internal/provider/services" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + + "github.com/terrariumcloud/terrarium/internal/storage" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +const ( + DefaultBucketName = "terrarium-providers" + DefaultStorageServiceDefaultEndpoint = "storage:3001" + DefaultChunkSize int64 = 64 * 1024 // 64 KB +) + +var ( + BucketName = DefaultBucketName + StorageServiceEndpoint = DefaultStorageServiceDefaultEndpoint + ChunkSize = DefaultChunkSize + + BucketInitializationError = status.Error(codes.Unknown, "Failed to initialize bucket for storage.") + DownloadSourceZipError = status.Error(codes.Unknown, "Failed to download source zip.") + SendSourceZipError = status.Error(codes.Unknown, "Failed to send source zip.") + SendShasumError = status.Error(codes.Unknown, "Failed to send shasum file.") + DownloadShasumError = status.Error(codes.Unknown, "Failed to download shasum.") +) + +type StorageService struct { + services.UnimplementedStorageServer + Client storage.AWSS3BucketClient + BucketName string + Region string +} + +// Registers StorageService with grpc server +func (s *StorageService) RegisterWithServer(grpcServer grpc.ServiceRegistrar) error { + if err := storage.InitializeS3Bucket(s.BucketName, s.Region, s.Client); err != nil { + log.Println("Error initializing S3 bucket for Provider storage", err) + return BucketInitializationError + } + + services.RegisterStorageServer(grpcServer, s) + + return nil +} + +// Download Source Zip from storage +func (s *StorageService) DownloadProviderSourceZip(request *services.DownloadSourceZipRequest, server services.Storage_DownloadProviderSourceZipServer) error { + + log.Println("Downloading source zip.") + + ctx := server.Context() + span := trace.SpanFromContext(ctx) + span.SetAttributes( + attribute.String("provider.name", request.GetProvider().GetName()), + attribute.String("provider.version", request.GetProvider().GetVersion()), + attribute.String("provider.os", request.GetProvider().GetOs()), + attribute.String("provider.arch", request.GetProvider().GetArch()), + ) + + providerAddress := strings.Split(request.Provider.GetName(), "/") + filename := fmt.Sprintf("terraform-provider-%s_%s_%s_%s.zip", providerAddress[1], request.GetProvider().GetVersion(), request.GetProvider().GetOs(), request.GetProvider().GetArch()) + fileLocation := ResolveS3Locations(request.Provider.GetName(), request.Provider.GetVersion(), filename) + + in := &s3.GetObjectInput{ + Bucket: aws.String(BucketName), + Key: aws.String(fileLocation), + } + + out, err := s.Client.GetObject(ctx, in) + if err != nil { + span.RecordError(err) + log.Println("Error downloading source zip for provider binary", err) + return DownloadSourceZipError + } + + buf := make([]byte, ChunkSize) + res := &services.SourceZipResponse{} + + for { + n, err := out.Body.Read(buf) + if err != nil && err != io.EOF { + span.RecordError(err) + log.Println("Failed to download source zip", err) + return DownloadSourceZipError + } + if n == 0 { + break + } + res.ZipDataChunk = buf[:n] + if err := server.Send(res); err != nil { + span.RecordError(err) + log.Println("Failed to send source zip", err) + return SendSourceZipError + } + } + + log.Println("Source zip downloaded.") + return nil + +} + +// Download Shasum from storage +func (s *StorageService) DownloadShasum(request *services.DownloadShasumRequest, server services.Storage_DownloadShasumServer) error { + + log.Println("Downloading shasum file.") + + ctx := server.Context() + span := trace.SpanFromContext(ctx) + span.SetAttributes( + attribute.String("provider.name", request.GetProvider().GetName()), + attribute.String("provider.version", request.GetProvider().GetVersion()), + ) + + providerAddress := strings.Split(request.GetProvider().GetName(), "/") + suffix := fmt.Sprintf("terraform-provider-%s_%s_SHA256SUMS", providerAddress[1], request.GetProvider().GetVersion()) + fileLocation := ResolveS3Locations(request.Provider.GetName(), request.Provider.GetVersion(), suffix) + + in := &s3.GetObjectInput{ + Bucket: aws.String(BucketName), + Key: aws.String(fileLocation), + } + + out, err := s.Client.GetObject(ctx, in) + if err != nil { + span.RecordError(err) + log.Println("Error downloading shasum file", err) + return DownloadShasumError + } + + buf := make([]byte, ChunkSize) + res := &services.DownloadShasumResponse{} + + for { + n, err := out.Body.Read(buf) + if err != nil && err != io.EOF { + span.RecordError(err) + log.Println("Failed to download shasum file", err) + return DownloadShasumError + } + if n == 0 { + break + } + + res.ShasumDataChunk = buf[:n] + if err := server.Send(res); err != nil { + span.RecordError(err) + log.Println("Failed to send shasum file", err) + return SendShasumError + } + } + + log.Println("Shasum file downloaded.") + return nil +} + +// Download Shasum Signature from storage +func (s *StorageService) DownloadShasumSignature(request *services.DownloadShasumRequest, server services.Storage_DownloadShasumSignatureServer) error { + + log.Println("Downloading shasum signature file.") + + ctx := server.Context() + span := trace.SpanFromContext(ctx) + span.SetAttributes( + attribute.String("provider.name", request.GetProvider().GetName()), + attribute.String("provider.version", request.GetProvider().GetVersion()), + ) + + providerAddress := strings.Split(request.GetProvider().GetName(), "/") + suffix := fmt.Sprintf("terraform-provider-%s_%s_SHA256SUMS.sig", providerAddress[1], request.GetProvider().GetVersion()) + fileLocation := ResolveS3Locations(request.Provider.GetName(), request.Provider.GetVersion(), suffix) + + in := &s3.GetObjectInput{ + Bucket: aws.String(BucketName), + Key: aws.String(fileLocation), + } + + out, err := s.Client.GetObject(ctx, in) + if err != nil { + span.RecordError(err) + log.Println("Error downloading shasum signature file", err) + return DownloadShasumError + } + + buf := make([]byte, ChunkSize) + res := &services.DownloadShasumResponse{} + + for { + n, err := out.Body.Read(buf) + if err != nil && err != io.EOF { + span.RecordError(err) + log.Println("Failed to download shasum signature file", err) + return DownloadShasumError + } + if n == 0 { + break + } + + res.ShasumDataChunk = buf[:n] + if err := server.Send(res); err != nil { + span.RecordError(err) + log.Println("Failed to send shasum signature file", err) + return SendShasumError + } + } + + log.Println("Shasum signature file downloaded.") + return nil +} + +func ResolveS3Locations(providerID, providerVersion, value string) string { + fileLocation := fmt.Sprintf("%s/%s/%s", providerID, providerVersion, value) + return fileLocation +} diff --git a/internal/provider/services/storage/storage_test.go b/internal/provider/services/storage/storage_test.go new file mode 100644 index 0000000..79893d2 --- /dev/null +++ b/internal/provider/services/storage/storage_test.go @@ -0,0 +1,370 @@ +package storage + +import ( + "bytes" + "errors" + mocks2 "github.com/terrariumcloud/terrarium/internal/storage/mocks" + "testing" + + "github.com/aws/aws-sdk-go-v2/service/s3" + terrarium "github.com/terrariumcloud/terrarium/internal/provider/services" + "github.com/terrariumcloud/terrarium/internal/provider/services/mocks" + "github.com/terrariumcloud/terrarium/pkg/terrarium/provider" + "google.golang.org/grpc" +) + +type ClosingBuffer struct { + *bytes.Buffer +} + +func (cb *ClosingBuffer) Close() error { + return nil +} + +// Test_RegisterStorageWithServer checks: +// - if there was no error with bucket init +// - if error was returned when bucket init fails +func Test_RegisterStorageWithServer(t *testing.T) { + t.Parallel() + + t.Run("when bucket init is successful", func(t *testing.T) { + s3Client := &mocks2.S3{} + + ss := &StorageService{Client: s3Client} + + s := grpc.NewServer(*new([]grpc.ServerOption)...) + + err := ss.RegisterWithServer(s) + + if err != nil { + t.Errorf("Expected no error, got %v.", err) + } + + if s3Client.HeadBucketInvocations != 1 { + t.Errorf("Expected 1 call to HeadBucket, got %v.", s3Client.HeadBucketInvocations) + } + + if s3Client.CreateBucketInvocations != 0 { + t.Errorf("Expected no calls to CreateBucket, got %v.", s3Client.CreateBucketInvocations) + } + }) + + t.Run("when bucket init fails", func(t *testing.T) { + s3Client := &mocks2.S3{ + HeadBucketError: errors.New("some error"), + CreateBucketError: errors.New("some error"), + } + + vms := &StorageService{Client: s3Client} + + s := grpc.NewServer(*new([]grpc.ServerOption)...) + + err := vms.RegisterWithServer(s) + + if err != BucketInitializationError { + t.Errorf("Expected %v, got %v.", BucketInitializationError, err) + } + + if s3Client.HeadBucketInvocations != 1 { + t.Errorf("Expected 1 call to DescribeTable, got %v.", s3Client.HeadBucketInvocations) + } + + if s3Client.CreateBucketInvocations != 1 { + t.Errorf("Expected 1 calls to CreateTable, got %v.", s3Client.CreateBucketInvocations) + } + }) +} + +// Test_DownloadProviderSourceZip checks: +// - if correct response is returned when source zip is downloaded +// - if error is returned when GetObject fails +// - if error is returned when Send fails +func Test_DownloadProviderSourceZip(t *testing.T) { + t.Parallel() + + t.Run("When source zip is downloaded", func(t *testing.T) { + var length int64 = 70000 + buf := &ClosingBuffer{bytes.NewBuffer(make([]byte, length))} + + s3Client := &mocks2.S3{GetObjectOut: &s3.GetObjectOutput{Body: buf, ContentLength: length}} + + svc := &StorageService{Client: s3Client} + + res := &terrarium.SourceZipResponse{ZipDataChunk: make([]byte, length)} + + mds := &mocks.MockDownloadProviderSourceZipServer{SendResponse: res} + + req := &terrarium.DownloadSourceZipRequest{ + Provider: &terrarium.ProviderRequest{Name: "TestOrg/TestProvider", Version: "v1", Os: "linux", Arch: "amd64"}, + } + + err := svc.DownloadProviderSourceZip(req, mds) + + if err != nil { + t.Errorf("Expected no error, got %v.", err) + } + + if s3Client.GetObjectInvocations != 1 { + t.Errorf("Expected 1 call to GetObject, got %v", s3Client.GetObjectInvocations) + } + + if mds.SendInvocations != 2 { + t.Errorf("Expected 2 call to Send, got %v", mds.SendInvocations) + } + + if !bytes.Equal(mds.TotalReceived, res.ZipDataChunk) { + t.Errorf("Expected same data to be returned.") + } + }) + + t.Run("when GetObject fails", func(t *testing.T) { + s3Client := &mocks2.S3{GetObjectError: errors.New("some error")} + + svc := &StorageService{Client: s3Client} + + mds := &mocks.MockDownloadProviderSourceZipServer{} + + req := &terrarium.DownloadSourceZipRequest{ + Provider: &terrarium.ProviderRequest{Name: "TestOrg/TestProvider", Version: "v1", Os: "linux", Arch: "amd64"}, + } + + err := svc.DownloadProviderSourceZip(req, mds) + + if s3Client.GetObjectInvocations != 1 { + t.Errorf("Expected 1 call to GetObject, got %v", s3Client.GetObjectInvocations) + } + + if mds.SendInvocations != 0 { + t.Errorf("Expected 0 call to Sends, got %v", mds.SendInvocations) + } + + if err != DownloadSourceZipError { + t.Errorf("Expected %v, got %v.", DownloadSourceZipError, err) + } + }) + + t.Run("when Send fails", func(t *testing.T) { + var length int64 = 70000 + buf := &ClosingBuffer{bytes.NewBuffer(make([]byte, length))} + + s3Client := &mocks2.S3{GetObjectOut: &s3.GetObjectOutput{Body: buf, ContentLength: length}} + + svc := &StorageService{Client: s3Client} + + mds := &mocks.MockDownloadProviderSourceZipServer{SendError: errors.New("some error")} + + req := &terrarium.DownloadSourceZipRequest{ + Provider: &terrarium.ProviderRequest{Name: "TestOrg/TestProvider", Version: "v1", Os: "linux", Arch: "amd64"}, + } + + err := svc.DownloadProviderSourceZip(req, mds) + + if s3Client.GetObjectInvocations != 1 { + t.Errorf("Expected 1 call to GetObject, got %v", s3Client.GetObjectInvocations) + } + + if mds.SendInvocations != 1 { + t.Errorf("Expected 1 call to Send, got %v", mds.SendInvocations) + } + + if err != SendSourceZipError { + t.Errorf("Expected %v, got %v.", SendSourceZipError, err) + } + }) +} + +// Test_DownloadShasum checks: +// - if correct response is returned when shasum file is downloaded +// - if error is returned when GetObject fails +// - if error is returned when Send fails +func Test_DownloadShasum(t *testing.T) { + t.Parallel() + + t.Run("when shasum file is downloaded", func(t *testing.T) { + var length int64 = 70000 + buf := &ClosingBuffer{bytes.NewBuffer(make([]byte, length))} + + s3Client := &mocks2.S3{GetObjectOut: &s3.GetObjectOutput{Body: buf, ContentLength: length}} + + svc := &StorageService{Client: s3Client} + + res := &terrarium.DownloadShasumResponse{ShasumDataChunk: make([]byte, length)} + + mds := &mocks.MockDownloadProviderShasumServer{SendResponse: res} + + req := &terrarium.DownloadShasumRequest{ + Provider: &provider.Provider{Name: "TestOrg/TestProvider", Version: "v1"}, + } + + err := svc.DownloadShasum(req, mds) + + if err != nil { + t.Errorf("Expected no error, got %v.", err) + } + + if s3Client.GetObjectInvocations != 1 { + t.Errorf("Expected 1 call to GetObject, got %v", s3Client.GetObjectInvocations) + } + + if mds.SendInvocations != 2 { + t.Errorf("Expected 2 call to Send, got %v", mds.SendInvocations) + } + + if !bytes.Equal(mds.TotalReceived, res.ShasumDataChunk) { + t.Errorf("Expected same data to be returned.") + } + }) + + t.Run("when GetObject fails", func(t *testing.T) { + s3Client := &mocks2.S3{GetObjectError: errors.New("some error")} + + svc := &StorageService{Client: s3Client} + + mds := &mocks.MockDownloadProviderShasumServer{} + + req := &terrarium.DownloadShasumRequest{ + Provider: &provider.Provider{Name: "TestOrg/TestProvider", Version: "v1"}, + } + + err := svc.DownloadShasum(req, mds) + + if s3Client.GetObjectInvocations != 1 { + t.Errorf("Expected 1 call to GetObject, got %v", s3Client.GetObjectInvocations) + } + + if mds.SendInvocations != 0 { + t.Errorf("Expected 0 call to Sends, got %v", mds.SendInvocations) + } + + if err != DownloadShasumError { + t.Errorf("Expected %v, got %v.", DownloadShasumError, err) + } + }) + + t.Run("when Send fails", func(t *testing.T) { + var length int64 = 70000 + buf := &ClosingBuffer{bytes.NewBuffer(make([]byte, length))} + + s3Client := &mocks2.S3{GetObjectOut: &s3.GetObjectOutput{Body: buf, ContentLength: length}} + + svc := &StorageService{Client: s3Client} + + mds := &mocks.MockDownloadProviderShasumServer{SendError: errors.New("some error")} + + req := &terrarium.DownloadShasumRequest{ + Provider: &provider.Provider{Name: "TestOrg/TestProvider", Version: "v1"}, + } + + err := svc.DownloadShasum(req, mds) + + if s3Client.GetObjectInvocations != 1 { + t.Errorf("Expected 1 call to GetObject, got %v", s3Client.GetObjectInvocations) + } + + if mds.SendInvocations != 1 { + t.Errorf("Expected 1 call to Send, got %v", mds.SendInvocations) + } + + if err != SendShasumError { + t.Errorf("Expected %v, got %v.", SendShasumError, err) + } + }) +} + +// Test_DownloadShasumSignature checks: +// - if correct response is returned when shasum signature file is downloaded +// - if error is returned when GetObject fails +// - if error is returned when Send fails +func Test_DownloadShasumSignature(t *testing.T) { + t.Parallel() + + t.Run("when shasum signature file is downloaded", func(t *testing.T) { + var length int64 = 70000 + buf := &ClosingBuffer{bytes.NewBuffer(make([]byte, length))} + + s3Client := &mocks2.S3{GetObjectOut: &s3.GetObjectOutput{Body: buf, ContentLength: length}} + + svc := &StorageService{Client: s3Client} + + res := &terrarium.DownloadShasumResponse{ShasumDataChunk: make([]byte, length)} + + mds := &mocks.MockDownloadProviderShasumSignatureServer{SendResponse: res} + + req := &terrarium.DownloadShasumRequest{ + Provider: &provider.Provider{Name: "TestOrg/TestProvider", Version: "v1"}, + } + + err := svc.DownloadShasumSignature(req, mds) + + if err != nil { + t.Errorf("Expected no error, got %v.", err) + } + + if s3Client.GetObjectInvocations != 1 { + t.Errorf("Expected 1 call to GetObject, got %v", s3Client.GetObjectInvocations) + } + + if mds.SendInvocations != 2 { + t.Errorf("Expected 2 call to Send, got %v", mds.SendInvocations) + } + + if !bytes.Equal(mds.TotalReceived, res.ShasumDataChunk) { + t.Errorf("Expected same data to be returned.") + } + }) + + t.Run("when GetObject fails", func(t *testing.T) { + s3Client := &mocks2.S3{GetObjectError: errors.New("some error")} + + svc := &StorageService{Client: s3Client} + + mds := &mocks.MockDownloadProviderShasumSignatureServer{} + + req := &terrarium.DownloadShasumRequest{ + Provider: &provider.Provider{Name: "TestOrg/TestProvider", Version: "v1"}, + } + + err := svc.DownloadShasumSignature(req, mds) + + if s3Client.GetObjectInvocations != 1 { + t.Errorf("Expected 1 call to GetObject, got %v", s3Client.GetObjectInvocations) + } + + if mds.SendInvocations != 0 { + t.Errorf("Expected 0 call to Sends, got %v", mds.SendInvocations) + } + + if err != DownloadShasumError { + t.Errorf("Expected %v, got %v.", DownloadShasumError, err) + } + }) + + t.Run("when Send fails", func(t *testing.T) { + var length int64 = 70000 + buf := &ClosingBuffer{bytes.NewBuffer(make([]byte, length))} + + s3Client := &mocks2.S3{GetObjectOut: &s3.GetObjectOutput{Body: buf, ContentLength: length}} + + svc := &StorageService{Client: s3Client} + + mds := &mocks.MockDownloadProviderShasumSignatureServer{SendError: errors.New("some error")} + + req := &terrarium.DownloadShasumRequest{ + Provider: &provider.Provider{Name: "TestOrg/TestProvider", Version: "v1"}, + } + + err := svc.DownloadShasumSignature(req, mds) + + if s3Client.GetObjectInvocations != 1 { + t.Errorf("Expected 1 call to GetObject, got %v", s3Client.GetObjectInvocations) + } + + if mds.SendInvocations != 1 { + t.Errorf("Expected 1 call to Send, got %v", mds.SendInvocations) + } + + if err != SendShasumError { + t.Errorf("Expected %v, got %v.", SendShasumError, err) + } + }) +} diff --git a/internal/restapi/providers/v1/handler.go b/internal/restapi/providers/v1/handler.go index c6d7ff0..1cd1f68 100644 --- a/internal/restapi/providers/v1/handler.go +++ b/internal/restapi/providers/v1/handler.go @@ -2,7 +2,9 @@ package v1 import ( "encoding/json" + "errors" "fmt" + "io" "log" "net/http" "os" @@ -20,12 +22,16 @@ import ( type providersV1HttpService struct { versionManagerClient services.VersionManagerClient + storageClient services.StorageClient responseHandler restapi.ResponseHandler errorHandler restapi.ErrorHandler } -func New(versionManagerClient services.VersionManagerClient) *providersV1HttpService { - return &providersV1HttpService{versionManagerClient: versionManagerClient} +func New(versionManagerClient services.VersionManagerClient, storageClient services.StorageClient) *providersV1HttpService { + return &providersV1HttpService{ + versionManagerClient: versionManagerClient, + storageClient: storageClient, + } } func (h *providersV1HttpService) GetHttpHandler(mountPath string) http.Handler { @@ -43,6 +49,9 @@ func (h *providersV1HttpService) createRouter(mountPath string) *mux.Router { sr.StrictSlash(true) sr.Handle("/{organization_name}/{name}/versions", h.getProviderVersionHandler()).Methods(http.MethodGet) sr.Handle("/{organization_name}/{name}/{version}/download/{os}/{arch}", h.downloadProviderHandler()).Methods(http.MethodGet) + sr.Handle("/{organization_name}/{name}/{version}/{os}/{arch}/terraform-provider-{name}_{version}_{os}_{arch}.zip", h.archiveHandler()).Methods(http.MethodGet) + sr.Handle("/{organization_name}/{name}/{version}/terraform-provider-{name}_{version}_SHA256SUMS", h.shasumHandler()).Methods(http.MethodGet) + sr.Handle("/{organization_name}/{name}/{version}/terraform-provider-{name}_{version}_SHA256SUMS.sig", h.shasumSignatureHandler()).Methods(http.MethodGet) return r } @@ -120,3 +129,105 @@ func (h *providersV1HttpService) downloadProviderHandler() http.Handler { _, _ = rw.Write(data) }) } + +// archiveHandler performs a fetch of the provider binary from the chosen backing store and presents it to the client. +func (h *providersV1HttpService) archiveHandler() http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + providerName := GetProviderNameFromRequest(r) + + ctx := r.Context() + span := trace.SpanFromContext(ctx) + providerVersion, providerOS, providerArch := GetProviderInputsFromRequest(r) + span.SetAttributes( + attribute.String("provider.name", providerName), + attribute.String("provider.version", providerVersion), + attribute.String("provider.os", providerOS), + attribute.String("provider.arch", providerArch), + ) + downloadStream, err := h.storageClient.DownloadProviderSourceZip(r.Context(), &services.DownloadSourceZipRequest{ + Provider: GetProviderLocationFromRequest(r), + }) + if err != nil { + log.Printf("Failed to connect: %v", err) + span.RecordError(err) + h.errorHandler.Write(rw, errors.New("failed to initiate the download of the archive from storage backend service"), http.StatusInternalServerError) + return + } + r.Header.Set("Content-Type", "application/zip") + for { + chunk, err := downloadStream.Recv() + if err == io.EOF { + return + } + _, _ = rw.Write(chunk.ZipDataChunk) + } + }) +} + +// shasumHandler performs a fetch of the shasum file from the chosen backing store and presents it to the client. +func (h *providersV1HttpService) shasumHandler() http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + providerName := GetProviderNameFromRequest(r) + + ctx := r.Context() + span := trace.SpanFromContext(ctx) + providerVersion, _, _ := GetProviderInputsFromRequest(r) + span.SetAttributes( + attribute.String("provider.name", providerName), + attribute.String("provider.version", providerVersion), + ) + + downloadStream, err := h.storageClient.DownloadShasum(r.Context(), &services.DownloadShasumRequest{ + Provider: GetVersionedProviderFromRequest(r), + }) + if err != nil { + log.Printf("Failed to connect: %v", err) + span.RecordError(err) + h.errorHandler.Write(rw, errors.New("failed to initiate the download of the shasum file from storage backend service"), http.StatusInternalServerError) + return + } + + r.Header.Set("Content-Type", "text/plain") + for { + chunk, err := downloadStream.Recv() + if err == io.EOF { + return + } + _, _ = rw.Write(chunk.ShasumDataChunk) + } + }) +} + +// shasumSignatureHandler performs a fetch of the shasum signature file from the chosen backing store and presents it to the client. +func (h *providersV1HttpService) shasumSignatureHandler() http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + providerName := GetProviderNameFromRequest(r) + + ctx := r.Context() + span := trace.SpanFromContext(ctx) + providerVersion, _, _ := GetProviderInputsFromRequest(r) + span.SetAttributes( + attribute.String("provider.name", providerName), + attribute.String("provider.version", providerVersion), + ) + + downloadStream, err := h.storageClient.DownloadShasumSignature(r.Context(), &services.DownloadShasumRequest{ + Provider: GetVersionedProviderFromRequest(r), + }) + if err != nil { + log.Printf("Failed to connect: %v", err) + span.RecordError(err) + h.errorHandler.Write(rw, errors.New("failed to initiate the download of the shasum signature file from storage backend service"), http.StatusInternalServerError) + return + } + + r.Header.Set("Content-Type", "text/plain") + for { + chunk, err := downloadStream.Recv() + if err == io.EOF { + return + } + _, _ = rw.Write(chunk.ShasumDataChunk) + } + }) +} diff --git a/internal/restapi/providers/v1/helpers.go b/internal/restapi/providers/v1/helpers.go index 1fb5d22..d92ab71 100644 --- a/internal/restapi/providers/v1/helpers.go +++ b/internal/restapi/providers/v1/helpers.go @@ -5,6 +5,9 @@ import ( "net/http" "github.com/gorilla/mux" + + "github.com/terrariumcloud/terrarium/internal/provider/services" + pb "github.com/terrariumcloud/terrarium/pkg/terrarium/provider" ) func GetProviderNameFromRequest(r *http.Request) string { @@ -18,3 +21,29 @@ func GetProviderInputsFromRequest(r *http.Request) (string, string, string) { params := mux.Vars(r) return params["version"], params["os"], params["arch"] } + +func GetProviderLocationFromRequest(r *http.Request) *services.ProviderRequest { + params := mux.Vars(r) + orgName := params["organization_name"] + providerName := params["name"] + version := params["version"] + os := params["os"] + arch := params["arch"] + return &services.ProviderRequest{ + Name: fmt.Sprintf("%s/%s", orgName, providerName), + Version: version, + Os: os, + Arch: arch, + } +} + +func GetVersionedProviderFromRequest(r *http.Request) *pb.Provider { + params := mux.Vars(r) + orgName := params["organization_name"] + providerName := params["name"] + version := params["version"] + return &pb.Provider{ + Name: fmt.Sprintf("%s/%s", orgName, providerName), + Version: version, + } +} diff --git a/internal/restapi/providers/v1/helpers_test.go b/internal/restapi/providers/v1/helpers_test.go index 5961589..c17f25b 100644 --- a/internal/restapi/providers/v1/helpers_test.go +++ b/internal/restapi/providers/v1/helpers_test.go @@ -8,22 +8,22 @@ import ( ) func Test_GetProviderNameFromRequest(t *testing.T) { - req := httptest.NewRequest("GET", "/providers/v1/hashicorp/random/versions", nil) + req := httptest.NewRequest("GET", "/providers/v1/test-org/test-provider/versions", nil) req = mux.SetURLVars(req, map[string]string{ - "organization_name": "hashicorp", - "name": "random", + "organization_name": "test-org", + "name": "test-provider", }) providerName := GetProviderNameFromRequest(req) - expectedProviderName := "hashicorp/random" + expectedProviderName := "test-org/test-provider" if providerName != expectedProviderName { t.Errorf("Expected provider name to be %s, but got %s", expectedProviderName, providerName) } } func Test_GetProviderInputsFromRequest(t *testing.T) { - req := httptest.NewRequest("GET", "/providers/v1/hashicorp/random/2.0.0/download/linux/amd64", nil) + req := httptest.NewRequest("GET", "/providers/v1/test-org/test-provider/2.0.0/download/linux/amd64", nil) req = mux.SetURLVars(req, map[string]string{ "version": "2.0.0", "os": "linux", @@ -47,3 +47,57 @@ func Test_GetProviderInputsFromRequest(t *testing.T) { t.Errorf("Expected arch. to be %s, but got %s", expectedArch, arch) } } + +func Test_GetProviderLocationFromRequest(t *testing.T) { + req := httptest.NewRequest("GET", "/providers/v1/test-org/test-provider/2.0.0/linux/amd64/terraform-provider-test-provider_2.0.0_linux_amd64.zip", nil) + req = mux.SetURLVars(req, map[string]string{ + "organization_name": "test-org", + "name": "test-provider", + "version": "2.0.0", + "os": "linux", + "arch": "amd64", + }) + + provider := GetProviderLocationFromRequest(req) + + expectedProviderName := "test-org/test-provider" + if provider.Name != expectedProviderName { + t.Errorf("Expected provider name to be %s, but got %s", expectedProviderName, provider.Name) + } + + expectedVersion := "2.0.0" + if provider.Version != expectedVersion { + t.Errorf("Expected version to be %s, but got %s", expectedVersion, provider.Version) + } + + expectedOS := "linux" + if provider.Os != expectedOS { + t.Errorf("Expected OS to be %s, but got %s", expectedOS, provider.Os) + } + + expectedArch := "amd64" + if provider.Arch != expectedArch { + t.Errorf("Expected arch. to be %s, but got %s", expectedArch, provider.Arch) + } +} + +func Test_GetVersionedProviderFromRequest(t *testing.T) { + req := httptest.NewRequest("GET", "/providers/v1/test-org/test-provider/2.0.0/terraform-provider-test-provider_2.0.0_SHA256SUMS", nil) + req = mux.SetURLVars(req, map[string]string{ + "organization_name": "test-org", + "name": "test-provider", + "version": "2.0.0", + }) + + provider := GetVersionedProviderFromRequest(req) + + expectedProviderName := "test-org/test-provider" + if provider.Name != expectedProviderName { + t.Errorf("Expected provider name to be %s, but got %s", expectedProviderName, provider.Name) + } + + expectedVersion := "2.0.0" + if provider.Version != expectedVersion { + t.Errorf("Expected version to be %s, but got %s", expectedVersion, provider.Version) + } +}