From 82055b3468d77e02443683efa0e32779a7af0b4a Mon Sep 17 00:00:00 2001 From: gouhongshen Date: Wed, 16 Oct 2024 20:16:54 +0800 Subject: [PATCH] fix Intersection2Vector panic (#19383) --- pkg/container/vector/vector.go | 129 ++++++++++++--- pkg/container/vector/vector_test.go | 238 ++++++++++++++++++++++++++++ 2 files changed, 344 insertions(+), 23 deletions(-) diff --git a/pkg/container/vector/vector.go b/pkg/container/vector/vector.go index 5efb6d84595c4..7a507b02873d6 100644 --- a/pkg/container/vector/vector.go +++ b/pkg/container/vector/vector.go @@ -4301,7 +4301,12 @@ func BuildVarlenaFromArray[T types.RealNumbers](vec *Vector, v *types.Varlena, a // Intersection2VectorOrdered does a ∩ b ==> ret, keeps all item unique and sorted // it assumes that a and b all sorted already -func Intersection2VectorOrdered[T types.OrderedT | types.Decimal128](a, b []T, ret *Vector, mp *mpool.MPool, cmp func(x, y T) int) { +func Intersection2VectorOrdered[T types.OrderedT | types.Decimal128]( + a, b []T, + ret *Vector, + mp *mpool.MPool, + cmp func(x, y T) int) (err error) { + var long, short []T if len(a) < len(b) { long = b @@ -4312,44 +4317,76 @@ func Intersection2VectorOrdered[T types.OrderedT | types.Decimal128](a, b []T, r } var lenLong, lenShort = len(long), len(short) - ret.PreExtend(lenLong+lenShort, mp) + if err = ret.PreExtend(lenLong+lenShort, mp); err != nil { + return err + } for i := range short { idx := sort.Search(lenLong, func(j int) bool { return cmp(long[j], short[i]) >= 0 }) + if idx >= lenLong { break } + j := idx if cmp(short[i], long[idx]) == 0 { - AppendFixed(ret, short[i], false, mp) + if err = AppendFixed(ret, short[i], false, mp); err != nil { + return err + } + + j++ + + // skip the same item + for j < lenLong && cmp(long[j], long[j-1]) == 0 { + j++ + } + + if j >= lenLong { + break + } } + idx = j + long = long[idx:] + lenLong = len(long) } + return nil } // Union2VectorOrdered does a ∪ b ==> ret, keeps all item unique and sorted // it assumes that a and b all sorted already -func Union2VectorOrdered[T types.OrderedT | types.Decimal128](a, b []T, ret *Vector, mp *mpool.MPool, cmp func(x, y T) int) { +func Union2VectorOrdered[T types.OrderedT | types.Decimal128]( + a, b []T, + ret *Vector, + mp *mpool.MPool, + cmp func(x, y T) int) (err error) { + var i, j int var prevVal T var lenA, lenB = len(a), len(b) - ret.PreExtend(lenA+lenB, mp) + if err = ret.PreExtend(lenA+lenB, mp); err != nil { + return err + } for i < lenA && j < lenB { if cmp(a[i], b[j]) <= 0 { if (i == 0 && j == 0) || cmp(prevVal, a[i]) != 0 { prevVal = a[i] - AppendFixed(ret, a[i], false, mp) + if err = AppendFixed(ret, a[i], false, mp); err != nil { + return err + } } i++ } else { if (i == 0 && j == 0) || cmp(prevVal, b[j]) != 0 { prevVal = b[j] - AppendFixed(ret, b[j], false, mp) + if err = AppendFixed(ret, b[j], false, mp); err != nil { + return err + } } j++ } @@ -4358,21 +4395,30 @@ func Union2VectorOrdered[T types.OrderedT | types.Decimal128](a, b []T, ret *Vec for ; i < lenA; i++ { if (i == 0 && j == 0) || cmp(prevVal, a[i]) != 0 { prevVal = a[i] - AppendFixed(ret, a[i], false, mp) + if err = AppendFixed(ret, a[i], false, mp); err != nil { + return err + } } } for ; j < lenB; j++ { if (i == 0 && j == 0) || cmp(prevVal, b[j]) != 0 { prevVal = b[j] - AppendFixed(ret, b[j], false, mp) + if err = AppendFixed(ret, b[j], false, mp); err != nil { + return err + } } } + return nil } // Intersection2VectorVarlen does a ∩ b ==> ret, keeps all item unique and sorted // it assumes that va and vb all sorted already -func Intersection2VectorVarlen(va, vb *Vector, ret *Vector, mp *mpool.MPool) { +func Intersection2VectorVarlen( + va, vb *Vector, + ret *Vector, + mp *mpool.MPool) (err error) { + var shortCol, longCol []types.Varlena var shortArea, longArea []byte @@ -4393,28 +4439,53 @@ func Intersection2VectorVarlen(va, vb *Vector, ret *Vector, mp *mpool.MPool) { var lenLong, lenShort = len(longCol), len(shortCol) - ret.PreExtend(lenLong+lenShort, mp) + if err = ret.PreExtend(lenLong+lenShort, mp); err != nil { + return err + } for i := range shortCol { shortBytes := shortCol[i].GetByteSlice(shortArea) idx := sort.Search(lenLong, func(j int) bool { return bytes.Compare(longCol[j].GetByteSlice(longArea), shortBytes) >= 0 }) + if idx >= lenLong { break } + j := idx if bytes.Equal(shortBytes, longCol[idx].GetByteSlice(longArea)) { - AppendBytes(ret, shortBytes, false, mp) + if err = AppendBytes(ret, shortBytes, false, mp); err != nil { + return err + } + + // skip the same item + j++ + for j < lenLong && bytes.Equal( + longCol[j].GetByteSlice(longArea), longCol[j-1].GetByteSlice(longArea)) { + j++ + } + + if j >= lenLong { + break + } } + idx = j + longCol = longCol[idx:] + lenLong = len(longCol) } + return nil } // Union2VectorValen does a ∪ b ==> ret, keeps all item unique and sorted // it assumes that va and vb all sorted already -func Union2VectorValen(va, vb *Vector, ret *Vector, mp *mpool.MPool) { +func Union2VectorValen( + va, vb *Vector, + ret *Vector, + mp *mpool.MPool) (err error) { + var i, j int var prevVal []byte @@ -4423,22 +4494,28 @@ func Union2VectorValen(va, vb *Vector, ret *Vector, mp *mpool.MPool) { var lenA, lenB = len(cola), len(colb) - ret.PreExtend(lenA+lenB, mp) + if err = ret.PreExtend(lenA+lenB, mp); err != nil { + return err + } for i < lenA && j < lenB { - bb := colb[j].GetByteSlice(areab) ba := cola[i].GetByteSlice(areaa) + bb := colb[j].GetByteSlice(areab) if bytes.Compare(ba, bb) <= 0 { - if (i == 0 && j == 0) || bytes.Equal(prevVal, ba) { + if (i == 0 && j == 0) || !bytes.Equal(prevVal, ba) { prevVal = ba - AppendBytes(ret, ba, false, mp) + if err = AppendBytes(ret, ba, false, mp); err != nil { + return err + } } i++ } else { - if (i == 0 && j == 0) || bytes.Equal(prevVal, bb) { + if (i == 0 && j == 0) || !bytes.Equal(prevVal, bb) { prevVal = bb - AppendBytes(ret, bb, false, mp) + if err = AppendBytes(ret, bb, false, mp); err != nil { + return err + } } j++ } @@ -4446,17 +4523,23 @@ func Union2VectorValen(va, vb *Vector, ret *Vector, mp *mpool.MPool) { for ; i < lenA; i++ { ba := cola[i].GetByteSlice(areaa) - if (i == 0 && j == 0) || bytes.Equal(prevVal, ba) { + if (i == 0 && j == 0) || !bytes.Equal(prevVal, ba) { prevVal = ba - AppendBytes(ret, ba, false, mp) + if err = AppendBytes(ret, ba, false, mp); err != nil { + return err + } } } for ; j < lenB; j++ { bb := colb[j].GetByteSlice(areab) - if (i == 0 && j == 0) || bytes.Equal(prevVal, bb) { + if (i == 0 && j == 0) || !bytes.Equal(prevVal, bb) { prevVal = bb - AppendBytes(ret, bb, false, mp) + if err = AppendBytes(ret, bb, false, mp); err != nil { + return err + } } } + + return nil } diff --git a/pkg/container/vector/vector_test.go b/pkg/container/vector/vector_test.go index 71c57791d11a8..66922c09dea31 100644 --- a/pkg/container/vector/vector_test.go +++ b/pkg/container/vector/vector_test.go @@ -15,6 +15,10 @@ package vector import ( + "fmt" + rand2 "math/rand" + "slices" + "strings" "testing" "github.com/matrixorigin/matrixone/pkg/common/mpool" @@ -1935,3 +1939,237 @@ func BenchmarkMustFixedCol(b *testing.B) { MustFixedCol[int8](vec) } } + +func TestIntersection2VectorOrdered(t *testing.T) { + const ll = 10000 + const cnt = 100 + + mp := mpool.MustNewZero() + + lenA := rand2.Intn(ll) + ll/5 + lenB := rand2.Intn(ll) + ll/5 + + for range cnt { + var a []int32 = make([]int32, lenA) + var b []int32 = make([]int32, lenB) + + for i := 0; i < lenA; i++ { + a[i] = rand2.Int31() % (ll / 2) + } + + for i := 0; i < lenB; i++ { + b[i] = rand2.Int31() % (ll / 2) + } + + cmp := func(x, y int32) int { + return int(x) - int(y) + } + + slices.SortFunc(a, cmp) + slices.SortFunc(b, cmp) + + ret := NewVec(types.T_int32.ToType()) + Intersection2VectorOrdered(a, b, ret, mp, cmp) + + mm := make(map[int32]struct{}) + + for i := range a { + for j := range b { + if cmp(a[i], b[j]) == 0 { + mm[a[i]] = struct{}{} + } + } + } + + col := MustFixedCol[int32](ret) + + require.Equal(t, len(mm), len(col)) + + for i := range col { + _, ok := mm[col[i]] + require.True(t, ok) + } + } +} + +func TestIntersection2VectorVarlen(t *testing.T) { + const ll = 5000 + const cnt = 100 + + mp := mpool.MustNewZero() + + lenA := rand2.Intn(ll) + ll/5 + lenB := rand2.Intn(ll) + ll/5 + + for range cnt { + var a = make([]string, lenA) + var b = make([]string, lenB) + + va := NewVec(types.T_text.ToType()) + vb := NewVec(types.T_text.ToType()) + + for i := 0; i < lenA; i++ { + x := rand2.Int31() % (ll / 2) + a[i] = fmt.Sprintf("%d", x) + } + + for i := 0; i < lenB; i++ { + x := rand2.Int31() % (ll / 2) + b[i] = fmt.Sprintf("%d", x) + } + + cmp := func(x, y string) int { + return strings.Compare(string(x), string(y)) + } + + slices.SortFunc(a, cmp) + slices.SortFunc(b, cmp) + + for i := 0; i < lenA; i++ { + AppendBytes(va, []byte(a[i]), false, mp) + } + + for i := 0; i < lenB; i++ { + AppendBytes(vb, []byte(b[i]), false, mp) + } + + ret := NewVec(types.T_text.ToType()) + Intersection2VectorVarlen(va, vb, ret, mp) + + mm := make(map[string]struct{}) + + for i := range a { + for j := range b { + if cmp(a[i], b[j]) == 0 { + mm[a[i]] = struct{}{} + } + } + } + + col, area := MustVarlenaRawData(ret) + + require.Equal(t, len(mm), len(col)) + + for i := range col { + _, ok := mm[col[i].GetString(area)] + require.True(t, ok) + } + } +} + +func TestUnion2VectorOrdered(t *testing.T) { + const ll = 10000 + const cnt = 100 + + mp := mpool.MustNewZero() + + lenA := rand2.Intn(ll) + ll/5 + lenB := rand2.Intn(ll) + ll/5 + + for range cnt { + var a []int32 = make([]int32, lenA) + var b []int32 = make([]int32, lenB) + + for i := 0; i < lenA; i++ { + a[i] = rand2.Int31() % (ll / 2) + } + + for i := 0; i < lenB; i++ { + b[i] = rand2.Int31() % (ll / 2) + } + + cmp := func(x, y int32) int { + return int(x) - int(y) + } + + slices.SortFunc(a, cmp) + slices.SortFunc(b, cmp) + + ret := NewVec(types.T_int32.ToType()) + Union2VectorOrdered(a, b, ret, mp, cmp) + + mm := make(map[int32]struct{}) + + for i := range a { + mm[a[i]] = struct{}{} + } + + for i := range b { + mm[b[i]] = struct{}{} + } + + col := MustFixedCol[int32](ret) + + require.Equal(t, len(mm), len(col)) + + for i := range col { + _, ok := mm[col[i]] + require.True(t, ok) + } + } +} + +func TestUnion2VectorVarlen(t *testing.T) { + const ll = 5000 + const cnt = 100 + + mp := mpool.MustNewZero() + + lenA := rand2.Intn(ll) + ll/5 + lenB := rand2.Intn(ll) + ll/5 + + for range cnt { + var a = make([]string, lenA) + var b = make([]string, lenB) + + va := NewVec(types.T_text.ToType()) + vb := NewVec(types.T_text.ToType()) + + for i := 0; i < lenA; i++ { + x := rand2.Int31() % (ll / 2) + a[i] = fmt.Sprintf("%d", x) + } + + for i := 0; i < lenB; i++ { + x := rand2.Int31() % (ll / 2) + b[i] = fmt.Sprintf("%d", x) + } + + cmp := func(x, y string) int { + return strings.Compare(string(x), string(y)) + } + + slices.SortFunc(a, cmp) + slices.SortFunc(b, cmp) + + for i := 0; i < lenA; i++ { + AppendBytes(va, []byte(a[i]), false, mp) + } + + for i := 0; i < lenB; i++ { + AppendBytes(vb, []byte(b[i]), false, mp) + } + + ret := NewVec(types.T_text.ToType()) + Union2VectorValen(va, vb, ret, mp) + + mm := make(map[string]struct{}) + + for i := range a { + mm[a[i]] = struct{}{} + } + + for i := range b { + mm[b[i]] = struct{}{} + } + + col, area := MustVarlenaRawData(ret) + + require.Equal(t, len(mm), len(col)) + + for i := range col { + _, ok := mm[col[i].GetString(area)] + require.True(t, ok) + } + } +}