diff --git a/pkg/controllers/tagging/tagging_controller.go b/pkg/controllers/tagging/tagging_controller.go index 909c8237eb..84d5f07cbc 100644 --- a/pkg/controllers/tagging/tagging_controller.go +++ b/pkg/controllers/tagging/tagging_controller.go @@ -225,7 +225,7 @@ func (tc *Controller) process() bool { recordWorkItemLatencyMetrics(workItemDequeuingTimeWorkItemMetric, timeTaken) klog.Infof("Dequeuing latency %f seconds", timeTaken) - instanceID, err := awsv1.KubernetesInstanceID(workItem.node.Spec.ProviderID).MapToAWSInstanceID() + instanceID, err := awsv1.ParseProviderID(workItem.node.Spec.ProviderID) if err != nil { err = fmt.Errorf("Error in getting instanceID for node %s, error: %v", workItem.node.GetName(), err) utilruntime.HandleError(err) @@ -233,9 +233,9 @@ func (tc *Controller) process() bool { } klog.Infof("Instance ID of work item %s is %s", workItem, instanceID) - if variant.IsVariantNode(string(instanceID)) { + if variant.IsVariantNode(instanceID) { klog.Infof("Skip processing the node %s since it is a %s node", - instanceID, variant.NodeType(string(instanceID))) + instanceID, variant.NodeType(instanceID)) tc.workqueue.Forget(obj) return nil } @@ -297,7 +297,7 @@ func (tc *Controller) tagEc2Instance(node *v1.Node) error { return nil } - instanceID, _ := awsv1.KubernetesInstanceID(node.Spec.ProviderID).MapToAWSInstanceID() + instanceID, _ := awsv1.ParseProviderID(node.Spec.ProviderID) err := tc.cloud.TagResource(string(instanceID), tc.tags) @@ -349,7 +349,7 @@ func (tc *Controller) untagNodeResources(node *v1.Node) error { // untagEc2Instances deletes the provided tags to each EC2 instances in // the cluster. func (tc *Controller) untagEc2Instance(node *v1.Node) error { - instanceID, _ := awsv1.KubernetesInstanceID(node.Spec.ProviderID).MapToAWSInstanceID() + instanceID, _ := awsv1.ParseProviderID(node.Spec.ProviderID) err := tc.cloud.UntagResource(string(instanceID), tc.tags) diff --git a/pkg/providers/v1/aws.go b/pkg/providers/v1/aws.go index 5122ecd432..0613c93632 100644 --- a/pkg/providers/v1/aws.go +++ b/pkg/providers/v1/aws.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "io" + "k8s.io/cloud-provider-aws/pkg/providers/v1/awsnode" "net" "regexp" "sort" @@ -418,7 +419,7 @@ func InstanceIDIndexFunc(obj interface{}) ([]string, error) { // provider ID hasn't been populated yet return []string{""}, nil } - instanceID, err := KubernetesInstanceID(node.Spec.ProviderID).MapToAWSInstanceID() + instanceID, err := ParseProviderID(node.Spec.ProviderID) if err != nil { //logging the error as warning as Informer.AddIndexers would panic if there is an error klog.Warningf("error mapping node %q's provider ID %q to instance ID: %v", node.Name, node.Spec.ProviderID, err) @@ -832,16 +833,16 @@ func extractIPv6NodeAddresses(instance *ec2.Instance) ([]v1.NodeAddress, error) // This method will not be called from the node that is requesting this ID. i.e. metadata service // and other local methods cannot be used here func (c *Cloud) NodeAddressesByProviderID(ctx context.Context, providerID string) ([]v1.NodeAddress, error) { - instanceID, err := KubernetesInstanceID(providerID).MapToAWSInstanceID() + instanceID, err := ParseProviderID(providerID) if err != nil { return nil, err } - if v := variant.GetVariant(string(instanceID)); v != nil { - return v.NodeAddresses(string(instanceID), c.vpcID) + if v := variant.GetVariant(instanceID); v != nil { + return v.NodeAddresses(instanceID, c.vpcID) } - instance, err := describeInstance(c.ec2, instanceID) + instance, err := describeInstance(c.ec2, string(instanceID)) if err != nil { return nil, err } @@ -871,17 +872,17 @@ func (c *Cloud) NodeAddressesByProviderID(ctx context.Context, providerID string // InstanceExistsByProviderID returns true if the instance with the given provider id still exists. // If false is returned with no error, the instance will be immediately deleted by the cloud controller manager. func (c *Cloud) InstanceExistsByProviderID(ctx context.Context, providerID string) (bool, error) { - instanceID, err := KubernetesInstanceID(providerID).MapToAWSInstanceID() + instanceID, err := ParseProviderID(providerID) if err != nil { return false, err } - if v := variant.GetVariant(string(instanceID)); v != nil { - return v.InstanceExists(string(instanceID), c.vpcID) + if v := variant.GetVariant(instanceID); v != nil { + return v.InstanceExists(instanceID, c.vpcID) } request := &ec2.DescribeInstancesInput{ - InstanceIds: []*string{instanceID.awsString()}, + InstanceIds: []*string{instanceID.AwsString()}, } instances, err := c.ec2.DescribeInstances(request) @@ -910,17 +911,17 @@ func (c *Cloud) InstanceExistsByProviderID(ctx context.Context, providerID strin // InstanceShutdownByProviderID returns true if the instance is terminated func (c *Cloud) InstanceShutdownByProviderID(ctx context.Context, providerID string) (bool, error) { - instanceID, err := KubernetesInstanceID(providerID).MapToAWSInstanceID() + instanceID, err := ParseProviderID(providerID) if err != nil { return false, err } - if v := variant.GetVariant(string(instanceID)); v != nil { - return v.InstanceShutdown(string(instanceID), c.vpcID) + if v := variant.GetVariant(instanceID); v != nil { + return v.InstanceShutdown(instanceID, c.vpcID) } request := &ec2.DescribeInstancesInput{ - InstanceIds: []*string{instanceID.awsString()}, + InstanceIds: []*string{instanceID.AwsString()}, } instances, err := c.ec2.DescribeInstances(request) @@ -969,16 +970,16 @@ func (c *Cloud) InstanceID(ctx context.Context, nodeName types.NodeName) (string // This method will not be called from the node that is requesting this ID. i.e. metadata service // and other local methods cannot be used here func (c *Cloud) InstanceTypeByProviderID(ctx context.Context, providerID string) (string, error) { - instanceID, err := KubernetesInstanceID(providerID).MapToAWSInstanceID() + instanceID, err := ParseProviderID(providerID) if err != nil { return "", err } - if v := variant.GetVariant(string(instanceID)); v != nil { - return v.InstanceTypeByProviderID(string(instanceID)) + if v := variant.GetVariant(instanceID); v != nil { + return v.InstanceTypeByProviderID(instanceID) } - instance, err := describeInstance(c.ec2, instanceID) + instance, err := describeInstance(c.ec2, string(instanceID)) if err != nil { return "", err } @@ -1010,13 +1011,13 @@ func (c *Cloud) GetZone(ctx context.Context) (cloudprovider.Zone, error) { // This is particularly useful in external cloud providers where the kubelet // does not initialize node data. func (c *Cloud) GetZoneByProviderID(ctx context.Context, providerID string) (cloudprovider.Zone, error) { - instanceID, err := KubernetesInstanceID(providerID).MapToAWSInstanceID() + instanceID, err := ParseProviderID(providerID) if err != nil { return cloudprovider.Zone{}, err } - if v := variant.GetVariant(string(instanceID)); v != nil { - return v.GetZone(string(instanceID), c.vpcID, c.region) + if v := variant.GetVariant(instanceID); v != nil { + return v.GetZone(instanceID, c.vpcID, c.region) } instance, err := c.getInstanceByID(string(instanceID)) @@ -2651,7 +2652,7 @@ func (c *Cloud) getTaggedSecurityGroups() (map[string]*ec2.SecurityGroup, error) // Open security group ingress rules on the instances so that the load balancer can talk to them // Will also remove any security groups ingress rules for the load balancer that are _not_ needed for allInstances -func (c *Cloud) updateInstanceSecurityGroupsForLoadBalancer(lb *elb.LoadBalancerDescription, instances map[InstanceID]*ec2.Instance, annotations map[string]string) error { +func (c *Cloud) updateInstanceSecurityGroupsForLoadBalancer(lb *elb.LoadBalancerDescription, instances map[awsnode.NodeID]*ec2.Instance, annotations map[string]string) error { if c.cfg.Global.DisableSecurityGroupIngress { return nil } @@ -3228,15 +3229,15 @@ func nodeNameToIPAddress(nodeName string) string { return strings.ReplaceAll(nodeName, "-", ".") } -func (c *Cloud) nodeNameToInstanceID(nodeName types.NodeName) (InstanceID, error) { +func (c *Cloud) nodeNameToInstanceID(nodeName types.NodeName) (awsnode.NodeID, error) { if strings.HasPrefix(string(nodeName), rbnNamePrefix) { // depending on if you use a RHEL (e.g. AL2) or Debian (e.g. standard Ubuntu) based distribution, the // hostname on the machine may be either i-00000000000000001 or i-00000000000000001.region.compute.internal. // This handles both scenarios by returning anything before the first '.' in the node name if it has an RBN prefix. if idx := strings.IndexByte(string(nodeName), '.'); idx != -1 { - return InstanceID(nodeName[0:idx]), nil + return awsnode.NodeID(nodeName[0:idx]), nil } - return InstanceID(nodeName), nil + return awsnode.NodeID(nodeName), nil } if len(nodeName) == 0 { return "", fmt.Errorf("no nodeName provided") @@ -3254,10 +3255,10 @@ func (c *Cloud) nodeNameToInstanceID(nodeName types.NodeName) (InstanceID, error return "", fmt.Errorf("node has no providerID") } - return KubernetesInstanceID(node.Spec.ProviderID).MapToAWSInstanceID() + return ParseProviderID(node.Spec.ProviderID) } -func (c *Cloud) instanceIDToNodeName(instanceID InstanceID) (types.NodeName, error) { +func (c *Cloud) instanceIDToNodeName(instanceID awsnode.NodeID) (types.NodeName, error) { if len(instanceID) == 0 { return "", fmt.Errorf("no instanceID provided") } diff --git a/pkg/providers/v1/aws_instance.go b/pkg/providers/v1/aws_instance.go index e7e8b152a1..01423a6e92 100644 --- a/pkg/providers/v1/aws_instance.go +++ b/pkg/providers/v1/aws_instance.go @@ -20,7 +20,6 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" "k8s.io/apimachinery/pkg/types" - "k8s.io/cloud-provider-aws/pkg/providers/v1/iface" ) @@ -67,5 +66,5 @@ func newAWSInstance(ec2Service iface.EC2, instance *ec2.Instance) *awsInstance { // Gets the full information about this instance from the EC2 API func (i *awsInstance) describeInstance() (*ec2.Instance, error) { - return describeInstance(i.ec2, InstanceID(i.awsID)) + return describeInstance(i.ec2, i.awsID) } diff --git a/pkg/providers/v1/aws_loadbalancer.go b/pkg/providers/v1/aws_loadbalancer.go index c39ea3de37..f149d92ad0 100644 --- a/pkg/providers/v1/aws_loadbalancer.go +++ b/pkg/providers/v1/aws_loadbalancer.go @@ -20,6 +20,7 @@ import ( "crypto/sha1" "encoding/hex" "fmt" + "k8s.io/cloud-provider-aws/pkg/providers/v1/awsnode" "reflect" "regexp" "strconv" @@ -781,7 +782,7 @@ func (c *Cloud) chunkTargetDescriptions(targets []*elbv2.TargetDescription, chun // updateInstanceSecurityGroupsForNLB will adjust securityGroup's settings to allow inbound traffic into instances from clientCIDRs and portMappings. // TIP: if either instances or clientCIDRs or portMappings are nil, then the securityGroup rules for lbName are cleared. -func (c *Cloud) updateInstanceSecurityGroupsForNLB(lbName string, instances map[InstanceID]*ec2.Instance, subnetCIDRs []string, clientCIDRs []string, portMappings []nlbPortMapping) error { +func (c *Cloud) updateInstanceSecurityGroupsForNLB(lbName string, instances map[awsnode.NodeID]*ec2.Instance, subnetCIDRs []string, clientCIDRs []string, portMappings []nlbPortMapping) error { if c.cfg.Global.DisableSecurityGroupIngress { return nil } @@ -1430,7 +1431,7 @@ func (c *Cloud) ensureLoadBalancerHealthCheck(loadBalancer *elb.LoadBalancerDesc } // Makes sure that exactly the specified hosts are registered as instances with the load balancer -func (c *Cloud) ensureLoadBalancerInstances(loadBalancerName string, lbInstances []*elb.Instance, instanceIDs map[InstanceID]*ec2.Instance) error { +func (c *Cloud) ensureLoadBalancerInstances(loadBalancerName string, lbInstances []*elb.Instance, instanceIDs map[awsnode.NodeID]*ec2.Instance) error { expected := sets.NewString() for id := range instanceIDs { expected.Insert(string(id)) @@ -1607,7 +1608,7 @@ func proxyProtocolEnabled(backend *elb.BackendServerDescription) bool { // findInstancesForELB gets the EC2 instances corresponding to the Nodes, for setting up an ELB // We ignore Nodes (with a log message) where the instanceid cannot be determined from the provider, // and we ignore instances which are not found -func (c *Cloud) findInstancesForELB(nodes []*v1.Node, annotations map[string]string) (map[InstanceID]*ec2.Instance, error) { +func (c *Cloud) findInstancesForELB(nodes []*v1.Node, annotations map[string]string) (map[awsnode.NodeID]*ec2.Instance, error) { targetNodes := filterTargetNodes(nodes, annotations) diff --git a/pkg/providers/v1/aws_loadbalancer_test.go b/pkg/providers/v1/aws_loadbalancer_test.go index 309d9eb209..a24c077fcb 100644 --- a/pkg/providers/v1/aws_loadbalancer_test.go +++ b/pkg/providers/v1/aws_loadbalancer_test.go @@ -18,6 +18,7 @@ package aws import ( "fmt" + "k8s.io/cloud-provider-aws/pkg/providers/v1/awsnode" "reflect" "testing" "time" @@ -592,7 +593,7 @@ func TestCloud_findInstancesForELB(t *testing.T) { return } - want := map[InstanceID]*ec2.Instance{ + want := map[awsnode.NodeID]*ec2.Instance{ "i-self": awsServices.selfInstance, } got, err := c.findInstancesForELB([]*v1.Node{defaultNode}, nil) @@ -601,9 +602,9 @@ func TestCloud_findInstancesForELB(t *testing.T) { // Add a new EC2 instance awsServices.instances = append(awsServices.instances, newInstance) - want = map[InstanceID]*ec2.Instance{ + want = map[awsnode.NodeID]*ec2.Instance{ "i-self": awsServices.selfInstance, - InstanceID(aws.StringValue(newInstance.InstanceId)): newInstance, + awsnode.NodeID(aws.StringValue(newInstance.InstanceId)): newInstance, } got, err = c.findInstancesForELB([]*v1.Node{defaultNode, newNode}, nil) assert.NoError(t, err) diff --git a/pkg/providers/v1/aws_routes.go b/pkg/providers/v1/aws_routes.go index e3e7c5b7a4..16f6e0e29e 100644 --- a/pkg/providers/v1/aws_routes.go +++ b/pkg/providers/v1/aws_routes.go @@ -19,9 +19,9 @@ package aws import ( "context" "fmt" - "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" + "k8s.io/cloud-provider-aws/pkg/providers/v1/awsnode" "k8s.io/klog/v2" cloudprovider "k8s.io/cloud-provider" @@ -114,7 +114,7 @@ func (c *Cloud) ListRoutes(ctx context.Context, clusterName string) ([]*cloudpro if instanceID != "" { _, found := instances[instanceID] if found { - node, err := c.instanceIDToNodeName(InstanceID(instanceID)) + node, err := c.instanceIDToNodeName(awsnode.NodeID(instanceID)) if err != nil { return nil, err } diff --git a/pkg/providers/v1/aws_test.go b/pkg/providers/v1/aws_test.go index 577f5d72cf..65b462574e 100644 --- a/pkg/providers/v1/aws_test.go +++ b/pkg/providers/v1/aws_test.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "io" + "k8s.io/cloud-provider-aws/pkg/providers/v1/awsnode" "math/rand" "reflect" "sort" @@ -2399,7 +2400,7 @@ func TestNodeNameToInstanceID(t *testing.T) { func TestInstanceIDToNodeName(t *testing.T) { testCases := []struct { name string - instanceID InstanceID + instanceID awsnode.NodeID node *v1.Node expectedNodeName types.NodeName expectedErr error diff --git a/pkg/providers/v1/awsnode/identifier.go b/pkg/providers/v1/awsnode/identifier.go new file mode 100644 index 0000000000..a534308f31 --- /dev/null +++ b/pkg/providers/v1/awsnode/identifier.go @@ -0,0 +1,11 @@ +package awsnode + +import "github.com/aws/aws-sdk-go/aws" + +// NodeID is the ID used to uniquely identify a node within an AWS service +type NodeID string + +// AwsString returns a pointer to the string value of the NodeID. Useful for AWS APIs +func (i NodeID) AwsString() *string { + return aws.String(string(i)) +} diff --git a/pkg/providers/v1/instances.go b/pkg/providers/v1/instances.go index 08ae3aff21..28c21e0b58 100644 --- a/pkg/providers/v1/instances.go +++ b/pkg/providers/v1/instances.go @@ -18,6 +18,7 @@ package aws import ( "fmt" + "k8s.io/cloud-provider-aws/pkg/providers/v1/awsnode" "net/url" "regexp" "strings" @@ -36,29 +37,14 @@ import ( // awsInstanceRegMatch represents Regex Match for AWS instance. var awsInstanceRegMatch = regexp.MustCompile("^i-[^/]*$") -// InstanceID represents the ID of the instance in the AWS API, e.g. i-12345678 -// The "traditional" format is "i-12345678" -// A new longer format is also being introduced: "i-12345678abcdef01" -// We should not assume anything about the length or format, though it seems -// reasonable to assume that instances will continue to start with "i-". -type InstanceID string - -func (i InstanceID) awsString() *string { - return aws.String(string(i)) -} - -// KubernetesInstanceID represents the id for an instance in the kubernetes API; -// the following form +// ParseProviderID turns a Kubernetes ProviderID into an AWS node id +// the following are forms of ProviderIDs that are supported: // - aws://// // - aws://// // - aws:////fargate- // - -type KubernetesInstanceID string - -// MapToAWSInstanceID extracts the InstanceID from the KubernetesInstanceID -func (name KubernetesInstanceID) MapToAWSInstanceID() (InstanceID, error) { - s := string(name) - +func ParseProviderID(providerID string) (awsnode.NodeID, error) { + s := providerID if !strings.HasPrefix(s, "aws://") { // Assume a bare aws instance id (i-1234...) // Build a URL with an empty host (AZ) @@ -66,10 +52,14 @@ func (name KubernetesInstanceID) MapToAWSInstanceID() (InstanceID, error) { } url, err := url.Parse(s) if err != nil { - return "", fmt.Errorf("Invalid instance name (%s): %v", name, err) + return "", fmt.Errorf("Invalid instance name (%s): %v", providerID, err) } if url.Scheme != "aws" { - return "", fmt.Errorf("Invalid scheme for AWS instance (%s)", name) + return "", fmt.Errorf("Invalid scheme for AWS instance (%s)", providerID) + } + + if nodeID := variant.GetNodeID(*url); nodeID != "" { + return nodeID, nil } awsID := "" @@ -81,21 +71,21 @@ func (name KubernetesInstanceID) MapToAWSInstanceID() (InstanceID, error) { // We sanity check the resulting instance ID; the two known formats are // i-12345678 and i-12345678abcdef01 - if awsID == "" || !(awsInstanceRegMatch.MatchString(awsID) || variant.IsVariantNode(awsID)) { - return "", fmt.Errorf("Invalid format for AWS instance (%s)", name) + if awsID == "" || !awsInstanceRegMatch.MatchString(awsID) { + return "", fmt.Errorf("Invalid format for AWS instance (%s)", providerID) } - return InstanceID(awsID), nil + return awsnode.NodeID(awsID), nil } // mapToAWSInstanceID extracts the InstanceIDs from the Nodes, returning an error if a Node cannot be mapped -func mapToAWSInstanceIDs(nodes []*v1.Node) ([]InstanceID, error) { - var instanceIDs []InstanceID +func mapToAWSInstanceIDs(nodes []*v1.Node) ([]awsnode.NodeID, error) { + var instanceIDs []awsnode.NodeID for _, node := range nodes { if node.Spec.ProviderID == "" { return nil, fmt.Errorf("node %q did not have ProviderID set", node.Name) } - instanceID, err := KubernetesInstanceID(node.Spec.ProviderID).MapToAWSInstanceID() + instanceID, err := ParseProviderID(node.Spec.ProviderID) if err != nil { return nil, fmt.Errorf("unable to parse ProviderID %q for node %q", node.Spec.ProviderID, node.Name) } @@ -106,14 +96,14 @@ func mapToAWSInstanceIDs(nodes []*v1.Node) ([]InstanceID, error) { } // mapToAWSInstanceIDsTolerant extracts the InstanceIDs from the Nodes, skipping Nodes that cannot be mapped -func mapToAWSInstanceIDsTolerant(nodes []*v1.Node) []InstanceID { - var instanceIDs []InstanceID +func mapToAWSInstanceIDsTolerant(nodes []*v1.Node) []awsnode.NodeID { + var instanceIDs []awsnode.NodeID for _, node := range nodes { if node.Spec.ProviderID == "" { klog.Warningf("node %q did not have ProviderID set", node.Name) continue } - instanceID, err := KubernetesInstanceID(node.Spec.ProviderID).MapToAWSInstanceID() + instanceID, err := ParseProviderID(node.Spec.ProviderID) if err != nil { klog.Warningf("unable to parse ProviderID %q for node %q", node.Spec.ProviderID, node.Name) continue @@ -125,9 +115,9 @@ func mapToAWSInstanceIDsTolerant(nodes []*v1.Node) []InstanceID { } // Gets the full information about this instance from the EC2 API -func describeInstance(ec2Client iface.EC2, instanceID InstanceID) (*ec2.Instance, error) { +func describeInstance(ec2Client iface.EC2, instanceID string) (*ec2.Instance, error) { request := &ec2.DescribeInstancesInput{ - InstanceIds: []*string{instanceID.awsString()}, + InstanceIds: []*string{&instanceID}, } instances, err := ec2Client.DescribeInstances(request) @@ -165,9 +155,9 @@ func (c *instanceCache) describeAllInstancesUncached() (*allInstancesSnapshot, e return nil, err } - m := make(map[InstanceID]*ec2.Instance) + m := make(map[awsnode.NodeID]*ec2.Instance) for _, i := range instances { - id := InstanceID(aws.StringValue(i.InstanceId)) + id := awsnode.NodeID(aws.StringValue(i.InstanceId)) m[id] = i } @@ -190,7 +180,7 @@ type cacheCriteria struct { // HasInstances is a list of InstanceIDs that must be in a cached snapshot for it to be considered valid. // If an instance is not found in the cached snapshot, the snapshot be ignored and we will re-fetch. - HasInstances []InstanceID + HasInstances []awsnode.NodeID } // describeAllInstancesCached returns all instances, using cached results if applicable @@ -238,12 +228,12 @@ func (s *allInstancesSnapshot) MeetsCriteria(criteria cacheCriteria) bool { // along with the timestamp for cache-invalidation purposes type allInstancesSnapshot struct { timestamp time.Time - instances map[InstanceID]*ec2.Instance + instances map[awsnode.NodeID]*ec2.Instance } // FindInstances returns the instances corresponding to the specified ids. If an id is not found, it is ignored. -func (s *allInstancesSnapshot) FindInstances(ids []InstanceID) map[InstanceID]*ec2.Instance { - m := make(map[InstanceID]*ec2.Instance) +func (s *allInstancesSnapshot) FindInstances(ids []awsnode.NodeID) map[awsnode.NodeID]*ec2.Instance { + m := make(map[awsnode.NodeID]*ec2.Instance) for _, id := range ids { instance := s.instances[id] if instance != nil { diff --git a/pkg/providers/v1/instances_test.go b/pkg/providers/v1/instances_test.go index ac431c6cf6..5ed6318261 100644 --- a/pkg/providers/v1/instances_test.go +++ b/pkg/providers/v1/instances_test.go @@ -17,6 +17,7 @@ limitations under the License. package aws import ( + "k8s.io/cloud-provider-aws/pkg/providers/v1/awsnode" "testing" "time" @@ -28,8 +29,8 @@ import ( func TestMapToAWSInstanceIDs(t *testing.T) { tests := []struct { - Kubernetes KubernetesInstanceID - Aws InstanceID + Kubernetes string + Aws awsnode.NodeID ExpectError bool }{ { @@ -87,7 +88,7 @@ func TestMapToAWSInstanceIDs(t *testing.T) { } for _, test := range tests { - awsID, err := test.Kubernetes.MapToAWSInstanceID() + awsID, err := ParseProviderID(test.Kubernetes) if err != nil { if !test.ExpectError { t.Errorf("unexpected error parsing %s: %v", test.Kubernetes, err) @@ -146,18 +147,18 @@ func TestSnapshotMeetsCriteria(t *testing.T) { t.Errorf("Snapshot did not honor MaxAge") } - if snapshot.MeetsCriteria(cacheCriteria{HasInstances: []InstanceID{InstanceID("i-12345678")}}) { + if snapshot.MeetsCriteria(cacheCriteria{HasInstances: []awsnode.NodeID{awsnode.NodeID("i-12345678")}}) { t.Errorf("Snapshot did not honor HasInstances with missing instances") } - snapshot.instances = make(map[InstanceID]*ec2.Instance) - snapshot.instances[InstanceID("i-12345678")] = &ec2.Instance{} + snapshot.instances = make(map[awsnode.NodeID]*ec2.Instance) + snapshot.instances[awsnode.NodeID("i-12345678")] = &ec2.Instance{} - if !snapshot.MeetsCriteria(cacheCriteria{HasInstances: []InstanceID{InstanceID("i-12345678")}}) { + if !snapshot.MeetsCriteria(cacheCriteria{HasInstances: []awsnode.NodeID{awsnode.NodeID("i-12345678")}}) { t.Errorf("Snapshot did not honor HasInstances with matching instances") } - if snapshot.MeetsCriteria(cacheCriteria{HasInstances: []InstanceID{InstanceID("i-12345678"), InstanceID("i-00000000")}}) { + if snapshot.MeetsCriteria(cacheCriteria{HasInstances: []awsnode.NodeID{awsnode.NodeID("i-12345678"), awsnode.NodeID("i-00000000")}}) { t.Errorf("Snapshot did not honor HasInstances with partially matching instances") } } @@ -177,22 +178,22 @@ func TestOlderThan(t *testing.T) { func TestSnapshotFindInstances(t *testing.T) { snapshot := &allInstancesSnapshot{} - snapshot.instances = make(map[InstanceID]*ec2.Instance) + snapshot.instances = make(map[awsnode.NodeID]*ec2.Instance) { - id := InstanceID("i-12345678") - snapshot.instances[id] = &ec2.Instance{InstanceId: id.awsString()} + id := awsnode.NodeID("i-12345678") + snapshot.instances[id] = &ec2.Instance{InstanceId: id.AwsString()} } { - id := InstanceID("i-23456789") - snapshot.instances[id] = &ec2.Instance{InstanceId: id.awsString()} + id := awsnode.NodeID("i-23456789") + snapshot.instances[id] = &ec2.Instance{InstanceId: id.AwsString()} } - instances := snapshot.FindInstances([]InstanceID{InstanceID("i-12345678"), InstanceID("i-23456789"), InstanceID("i-00000000")}) + instances := snapshot.FindInstances([]awsnode.NodeID{"i-12345678", "i-23456789", "i-00000000"}) if len(instances) != 2 { t.Errorf("findInstances returned %d results, expected 2", len(instances)) } - for _, id := range []InstanceID{InstanceID("i-12345678"), InstanceID("i-23456789")} { + for _, id := range []awsnode.NodeID{awsnode.NodeID("i-12345678"), awsnode.NodeID("i-23456789")} { i := instances[id] if i == nil { t.Errorf("findInstances did not return %s", id) diff --git a/pkg/providers/v1/variant/fargate/fargate.go b/pkg/providers/v1/variant/fargate/fargate.go index f4d7174603..fd296c0ca8 100644 --- a/pkg/providers/v1/variant/fargate/fargate.go +++ b/pkg/providers/v1/variant/fargate/fargate.go @@ -2,6 +2,8 @@ package fargate import ( "fmt" + "k8s.io/cloud-provider-aws/pkg/providers/v1/awsnode" + "net/url" "strings" awssdk "github.com/aws/aws-sdk-go/aws" @@ -37,12 +39,12 @@ func (f *fargateVariant) Initialize(cloudConfig *config.CloudConfig, credentials return nil } -func (f *fargateVariant) InstanceTypeByProviderID(instanceID string) (string, error) { +func (f *fargateVariant) InstanceTypeByProviderID(nodeID awsnode.NodeID) (string, error) { return "", nil } -func (f *fargateVariant) GetZone(instanceID, vpcID, region string) (cloudprovider.Zone, error) { - eni, err := f.DescribeNetworkInterfaces(f.ec2API, instanceID, vpcID) +func (f *fargateVariant) GetZone(nodeID awsnode.NodeID, vpcID, region string) (cloudprovider.Zone, error) { + eni, err := f.DescribeNetworkInterfaces(f.ec2API, nodeID, vpcID) if eni == nil || err != nil { return cloudprovider.Zone{}, err } @@ -52,12 +54,12 @@ func (f *fargateVariant) GetZone(instanceID, vpcID, region string) (cloudprovide }, nil } -func (f *fargateVariant) IsSupportedNode(nodeName string) bool { - return strings.HasPrefix(nodeName, fargateNodeNamePrefix) +func (f *fargateVariant) IsSupportedNode(nodeID awsnode.NodeID) bool { + return strings.HasPrefix(string(nodeID), fargateNodeNamePrefix) } -func (f *fargateVariant) NodeAddresses(instanceID, vpcID string) ([]v1.NodeAddress, error) { - eni, err := f.DescribeNetworkInterfaces(f.ec2API, instanceID, vpcID) +func (f *fargateVariant) NodeAddresses(nodeID awsnode.NodeID, vpcID string) ([]v1.NodeAddress, error) { + eni, err := f.DescribeNetworkInterfaces(f.ec2API, nodeID, vpcID) if eni == nil || err != nil { return nil, err } @@ -83,16 +85,29 @@ func (f *fargateVariant) NodeAddresses(instanceID, vpcID string) ([]v1.NodeAddre return addresses, nil } -func (f *fargateVariant) InstanceExists(instanceID, vpcID string) (bool, error) { - eni, err := f.DescribeNetworkInterfaces(f.ec2API, instanceID, vpcID) +func (f *fargateVariant) InstanceExists(nodeID awsnode.NodeID, vpcID string) (bool, error) { + eni, err := f.DescribeNetworkInterfaces(f.ec2API, nodeID, vpcID) return eni != nil, err } -func (f *fargateVariant) InstanceShutdown(instanceID, vpcID string) (bool, error) { - eni, err := f.DescribeNetworkInterfaces(f.ec2API, instanceID, vpcID) +func (f *fargateVariant) InstanceShutdown(nodeID awsnode.NodeID, vpcID string) (bool, error) { + eni, err := f.DescribeNetworkInterfaces(f.ec2API, nodeID, vpcID) return eni != nil, err } +func (f *fargateVariant) NodeID(providerID url.URL) awsnode.NodeID { + tokens := strings.Split(strings.Trim(providerID.Path, "/"), "/") + // last token in the providerID is the aws resource ID for Fargate nodes + if len(tokens) == 0 { + return "" + } + nodeName := awsnode.NodeID(tokens[len(tokens)-1]) + if f.IsSupportedNode(nodeName) { + return nodeName + } + return "" +} + func newEc2Filter(name string, values ...string) *ec2.Filter { filter := &ec2.Filter{ Name: awssdk.String(name), @@ -116,8 +131,8 @@ func nodeNameToIPAddress(nodeName string) string { } // DescribeNetworkInterfaces returns network interface information for the given DNS name. -func (f *fargateVariant) DescribeNetworkInterfaces(ec2API iface.EC2, instanceID, vpcID string) (*ec2.NetworkInterface, error) { - eniEndpoint := strings.TrimPrefix(instanceID, fargateNodeNamePrefix) +func (f *fargateVariant) DescribeNetworkInterfaces(ec2API iface.EC2, nodeID awsnode.NodeID, vpcID string) (*ec2.NetworkInterface, error) { + eniEndpoint := strings.TrimPrefix(string(nodeID), fargateNodeNamePrefix) filters := []*ec2.Filter{ newEc2Filter("attachment.status", "attached"), diff --git a/pkg/providers/v1/variant/variant.go b/pkg/providers/v1/variant/variant.go index 39df86b795..e74c62a7cb 100644 --- a/pkg/providers/v1/variant/variant.go +++ b/pkg/providers/v1/variant/variant.go @@ -2,6 +2,8 @@ package variant import ( "fmt" + "k8s.io/cloud-provider-aws/pkg/providers/v1/awsnode" + "net/url" "sync" v1 "k8s.io/api/core/v1" @@ -20,12 +22,13 @@ var variants = make(map[string]Variant) type Variant interface { Initialize(cloudConfig *config.CloudConfig, credentials *credentials.Credentials, provider config.SDKProvider, ec2API iface.EC2, region string) error - IsSupportedNode(nodeName string) bool - NodeAddresses(instanceID, vpcID string) ([]v1.NodeAddress, error) - GetZone(instanceID, vpcID, region string) (cloudprovider.Zone, error) - InstanceExists(instanceID, vpcID string) (bool, error) - InstanceShutdown(instanceID, vpcID string) (bool, error) - InstanceTypeByProviderID(id string) (string, error) + IsSupportedNode(nodeID awsnode.NodeID) bool + NodeAddresses(nodeID awsnode.NodeID, vpcID string) ([]v1.NodeAddress, error) + GetZone(nodeID awsnode.NodeID, vpcID, region string) (cloudprovider.Zone, error) + InstanceExists(nodeID awsnode.NodeID, vpcID string) (bool, error) + InstanceShutdown(nodeID awsnode.NodeID, vpcID string) (bool, error) + InstanceTypeByProviderID(nodeID awsnode.NodeID) (string, error) + NodeID(providerID url.URL) awsnode.NodeID } // RegisterVariant is used to register code that needs to be called for a specific variant @@ -39,11 +42,11 @@ func RegisterVariant(name string, variant Variant) { } // IsVariantNode helps evaluate if a specific variant handles a given instance -func IsVariantNode(instanceID string) bool { +func IsVariantNode(nodeID awsnode.NodeID) bool { variantsLock.Lock() defer variantsLock.Unlock() for _, v := range variants { - if v.IsSupportedNode(instanceID) { + if v.IsSupportedNode(nodeID) { return true } } @@ -51,11 +54,11 @@ func IsVariantNode(instanceID string) bool { } // NodeType returns the type name example: "fargate" -func NodeType(instanceID string) string { +func NodeType(nodeID awsnode.NodeID) string { variantsLock.Lock() defer variantsLock.Unlock() for key, v := range variants { - if v.IsSupportedNode(instanceID) { + if v.IsSupportedNode(nodeID) { return key } } @@ -63,17 +66,30 @@ func NodeType(instanceID string) string { } // GetVariant returns the interface that can then be used to handle a specific instance -func GetVariant(instanceID string) Variant { +func GetVariant(nodeID awsnode.NodeID) Variant { variantsLock.Lock() defer variantsLock.Unlock() for _, v := range variants { - if v.IsSupportedNode(instanceID) { + if v.IsSupportedNode(nodeID) { return v } } return nil } +// GetNodeID returns the node id of the variant if a variant supports this particular provider id +// A return value of an empty string denotes no variant supported the node with this providerId. +func GetNodeID(providerID url.URL) awsnode.NodeID { + variantsLock.Lock() + defer variantsLock.Unlock() + for _, v := range variants { + if varID := v.NodeID(providerID); varID != "" { + return varID + } + } + return "" +} + // GetVariants returns the names of all the variants registered func GetVariants() []Variant { variantsLock.Lock()