Skip to content

Commit

Permalink
Review remarks.
Browse files Browse the repository at this point in the history
  • Loading branch information
mbobrovskyi committed Nov 27, 2024
1 parent b20d2e2 commit 1680a2a
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 25 deletions.
10 changes: 7 additions & 3 deletions pkg/controller/tas/topology_ungater.go
Original file line number Diff line number Diff line change
Expand Up @@ -448,20 +448,24 @@ func readRanksForLabels(
) (map[int]*corev1.Pod, error) {
result := make(map[int]*corev1.Pod)
podSetSize := int(*psa.Count)
singleJobSize := podSetSize
if rjInfo != nil {
singleJobSize = podSetSize / rjInfo.replicasCount
}

for _, pod := range pods {
podIndex, err := utilpod.ReadUIntFromLabel(pod, podIndexLabel)
podIndex, err := utilpod.ReadUIntFromLabelBelowBound(pod, podIndexLabel, singleJobSize)
if err != nil {
// the Pod has no rank information - ranks cannot be used
return nil, err
}
rank := *podIndex
if rjInfo != nil {
jobIndex, err := utilpod.ReadUIntFromLabel(pod, rjInfo.jobIndexLabel)
jobIndex, err := utilpod.ReadUIntFromLabelBelowBound(pod, rjInfo.jobIndexLabel, rjInfo.replicasCount)
if err != nil {
// the Pod has no Job index information - ranks cannot be used
return nil, err
}
singleJobSize := podSetSize / rjInfo.replicasCount
if *podIndex >= singleJobSize {
// the pod index exceeds size, this scenario is not
// supported by the rank-based ordering of pods.
Expand Down
24 changes: 13 additions & 11 deletions pkg/util/pod/pod.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,35 +64,37 @@ func gateIndex(p *corev1.Pod, gateName string) int {
})
}

var (
errLabelNotFound = errors.New("label not found")
errInvalidUInt = errors.New("invalid unsigned integer")
errValidation = errors.New("validation error")
)

func ReadUIntFromLabel(obj client.Object, labelKey string) (*int, error) {
return ReadUIntFromLabelWithMax(obj, labelKey, math.MaxInt)
return ReadUIntFromLabelBelowBound(obj, labelKey, math.MaxInt)
}

func ReadUIntFromLabelWithMax(obj client.Object, labelKey string, max int) (*int, error) {
func ReadUIntFromLabelBelowBound(obj client.Object, labelKey string, bound int) (*int, error) {
value, found := obj.GetLabels()[labelKey]
kind := obj.GetObjectKind().GroupVersionKind().Kind
if !found {
return nil, fmt.Errorf("no label %q for %s %q", labelKey, kind, klog.KObj(obj))
return nil, fmt.Errorf("%w: no label %q for %s %q", errLabelNotFound, labelKey, kind, klog.KObj(obj))
}
intValue, err := readUIntFromStringWithMax(value, max)
intValue, err := readUIntFromStringBelowBound(value, bound)
if err != nil {
return nil, fmt.Errorf("incorrect label value %q for %s %q: %w", value, kind, klog.KObj(obj), err)
}
return intValue, nil
}

var (
errInvalidUInt = errors.New("invalid unsigned integer")
)

func readUIntFromStringWithMax(value string, max int) (*int, error) {
func readUIntFromStringBelowBound(value string, bound int) (*int, error) {
uintValue, err := strconv.ParseUint(value, 10, 0)
if err != nil {
return nil, fmt.Errorf("%w: %s", errInvalidUInt, err.Error())
}
intValue := int(uintValue)
if intValue > max {
return nil, fmt.Errorf("value should be less than or equal to %d", max)
if intValue > bound {
return nil, fmt.Errorf("%w: value should be less than %d", errValidation, bound)
}
return ptr.To(intValue), nil
}
30 changes: 19 additions & 11 deletions pkg/util/pod/pod_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,20 @@ func TestReadUIntFromLabel(t *testing.T) {
label string
max int
wantVal *int
wantErr string
wantErr error
}{
"label not found": {
obj: &corev1.Pod{
TypeMeta: metav1.TypeMeta{Kind: "Pod", APIVersion: ""},
ObjectMeta: metav1.ObjectMeta{
Name: "pod",
Namespace: "ns",
},
},
label: "label",
max: math.MaxInt,
wantErr: errLabelNotFound,
},
"valid label value": {
obj: &corev1.Pod{
TypeMeta: metav1.TypeMeta{Kind: "Pod", APIVersion: ""},
Expand All @@ -233,7 +245,7 @@ func TestReadUIntFromLabel(t *testing.T) {
},
},
label: "label",
wantErr: "incorrect label value \"value\" for Pod \"ns/pod\": invalid unsigned integer: strconv.ParseUint: parsing \"value\": invalid syntax",
wantErr: errInvalidUInt,
},
"less than zero": {
obj: &corev1.Pod{
Expand All @@ -245,7 +257,7 @@ func TestReadUIntFromLabel(t *testing.T) {
},
},
label: "label",
wantErr: "incorrect label value \"-1\" for Pod \"ns/pod\": invalid unsigned integer: strconv.ParseUint: parsing \"-1\": invalid syntax",
wantErr: errInvalidUInt,
},
"greater than max": {
obj: &corev1.Pod{
Expand All @@ -258,24 +270,20 @@ func TestReadUIntFromLabel(t *testing.T) {
},
label: "label",
max: 1000,
wantErr: "incorrect label value \"1001\" for Pod \"ns/pod\": value should be less than or equal to 1000",
wantErr: errValidation,
},
}

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
gotValue, gotErr := ReadUIntFromLabelWithMax(tc.obj, tc.label, tc.max)
var gotErrStr string
if gotErr != nil {
gotErrStr = gotErr.Error()
}
gotValue, gotErr := ReadUIntFromLabelBelowBound(tc.obj, tc.label, tc.max)

if diff := cmp.Diff(tc.wantVal, gotValue); diff != "" {
t.Errorf("Unexpected value (-want,+got):\n%s", diff)
}

if diff := cmp.Diff(tc.wantErr, gotErrStr); diff != "" {
t.Errorf("Unexpected error (-want,+got):\n%s", diff)
if diff := cmp.Diff(tc.wantErr, gotErr, cmpopts.EquateErrors()); diff != "" {
t.Errorf("Reconcile returned error (-want,+got):\n%s", diff)
}
})
}
Expand Down

0 comments on commit 1680a2a

Please sign in to comment.