diff --git a/pkg/oci/pull.go b/pkg/oci/pull.go index 6618262a7..c9f851462 100644 --- a/pkg/oci/pull.go +++ b/pkg/oci/pull.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "net/url" "path/filepath" "strings" @@ -11,6 +12,7 @@ import ( "github.com/pkg/errors" "github.com/replicatedhq/troubleshoot/internal/util" "github.com/replicatedhq/troubleshoot/pkg/version" + "k8s.io/klog/v2" "oras.land/oras-go/pkg/auth" dockerauth "oras.land/oras-go/pkg/auth/docker" "oras.land/oras-go/pkg/content" @@ -27,14 +29,39 @@ var ( ) func PullPreflightFromOCI(uri string) ([]byte, error) { - return pullFromOCI(uri, "replicated.preflight.spec", "replicated-preflight") + return pullFromOCI(context.Background(), uri, "replicated.preflight.spec", "replicated-preflight") } func PullSupportBundleFromOCI(uri string) ([]byte, error) { - return pullFromOCI(uri, "replicated.supportbundle.spec", "replicated-supportbundle") + return pullFromOCI(context.Background(), uri, "replicated.supportbundle.spec", "replicated-supportbundle") } -func pullFromOCI(uri string, mediaType string, imageName string) ([]byte, error) { +func PullSpecsFromOCI(ctx context.Context, uri string) ([]string, error) { + rawSpecs := []string{} + + // First try to pull the preflight spec + rawPreflight, err := pullFromOCI(ctx, uri, "replicated.preflight.spec", "replicated-preflight") + if err != nil { + // Ignore "not found" error and continue fetching the support bundle spec + if !errors.Is(err, ErrNoRelease) { + return nil, err + } + } else { + rawSpecs = append(rawSpecs, string(rawPreflight)) + } + + // Then try to pull the support bundle spec + rawSupportBundle, err := pullFromOCI(ctx, uri, "replicated.supportbundle.spec", "replicated-supportbundle") + // If we had found a preflight spec, do not return an error + if err != nil && len(rawSpecs) == 0 { + return nil, err + } + rawSpecs = append(rawSpecs, string(rawSupportBundle)) + + return rawSpecs, nil +} + +func pullFromOCI(ctx context.Context, uri string, mediaType string, imageName string) ([]byte, error) { // helm credentials helmCredentialsFile := filepath.Join(util.HomeDir(), HelmCredentialsFileBasename) dockerauthClient, err := dockerauth.NewClientWithDockerFallback(helmCredentialsFile) @@ -52,6 +79,7 @@ func pullFromOCI(uri string, mediaType string, imageName string) ([]byte, error) return nil, errors.Wrap(err, "failed to create resolver") } + // TODO: How do we handle "not found" cases? memoryStore := content.NewMemory() allowedMediaTypes := []string{ mediaType, @@ -60,24 +88,13 @@ func pullFromOCI(uri string, mediaType string, imageName string) ([]byte, error) var descriptors, layers []ocispec.Descriptor registryStore := content.Registry{Resolver: resolver} - // remove the oci:// - uri = strings.TrimPrefix(uri, "oci://") - - uriParts := strings.Split(uri, ":") - uri = fmt.Sprintf("%s/%s", uriParts[0], imageName) - - if len(uriParts) > 1 { - uri = fmt.Sprintf("%s:%s", uri, uriParts[1]) - } else { - uri = fmt.Sprintf("%s:latest", uri) - } - - parsedRef, err := registry.ParseReference(uri) + parsedRef, err := toRegistryRef(uri) if err != nil { - return nil, errors.Wrap(err, "failed to parse reference") + return nil, err } + klog.V(1).Infof("Pulling OCI image from %q", parsedRef.String()) - manifest, err := oras.Copy(context.TODO(), registryStore, parsedRef.String(), memoryStore, "", + manifest, err := oras.Copy(ctx, registryStore, parsedRef.String(), memoryStore, "", oras.WithPullEmptyNameAllowed(), oras.WithAllowedMediaTypes(allowedMediaTypes), oras.WithLayerDescriptors(func(l []ocispec.Descriptor) { @@ -94,7 +111,7 @@ func pullFromOCI(uri string, mediaType string, imageName string) ([]byte, error) descriptors = append(descriptors, manifest) descriptors = append(descriptors, layers...) - // expect 1 descriptor + // expect 2 descriptors if len(descriptors) != 2 { return nil, fmt.Errorf("expected 2 descriptor, got %d", len(descriptors)) } @@ -120,3 +137,26 @@ func pullFromOCI(uri string, mediaType string, imageName string) ([]byte, error) return matchingSpec, nil } + +func toRegistryRef(raw string) (registry.Reference, error) { + u, err := url.Parse(raw) + if err != nil { + return registry.Reference{}, err + } + + // Always check the scheme. If more schemes need to be supported + // we need to compare u.Scheme against a list of supported schemes. + // url.Parse(raw) will not return an error is a scheme is not present. + if u.Scheme != "oci" { + return registry.Reference{}, fmt.Errorf("%q is an invalid OCI registry scheme", u.Scheme) + } + + parts := strings.Split(u.EscapedPath(), ":") + tag := "latest" + if len(parts) > 1 { + tag = parts[1] + } + // remove the oci:// + uri := fmt.Sprintf("%s%s:%s", u.Host, parts[0], tag) + return registry.ParseReference(uri) +} diff --git a/pkg/oci/pull_test.go b/pkg/oci/pull_test.go new file mode 100644 index 000000000..e8d98c578 --- /dev/null +++ b/pkg/oci/pull_test.go @@ -0,0 +1,56 @@ +package oci + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_toRegistryRef(t *testing.T) { + tests := []struct { + name string + uri string + want string + wantErr bool + }{ + { + name: "valid uri", + uri: "oci://localhost/replicated-preflight", + want: "localhost/replicated-preflight:latest", + }, + { + name: "valid uri with port", + uri: "oci://localhost:5000/replicated-preflight", + want: "localhost:5000/replicated-preflight:latest", + }, + { + name: "valid uri with tag", + uri: "oci://localhost:5000/replicated-preflight:v4", + want: "localhost:5000/replicated-preflight:v4", + }, + { + name: "invalid uri - missing scheme", + uri: "localhost:5000/replicated-preflight:v4", + wantErr: true, + }, + { + name: "invalid uri - wrong scheme", + uri: "https://localhost:5000/replicated-preflight:v4", + wantErr: true, + }, + { + name: "empty uri", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := toRegistryRef(tt.uri) + require.Equalf(t, (err != nil), tt.wantErr, "toRegistryRef() error = %v, wantErr %v", err, tt.wantErr) + + gotStr := got.String() + assert.Equalf(t, tt.want, gotStr, "toRegistryRef() = %v, want %v", gotStr, tt.want) + }) + } +}