From 156a0dd4501733e167c64015289ebabd3096e6ec Mon Sep 17 00:00:00 2001 From: "yihao.dai" Date: Sun, 7 Jan 2024 19:38:49 +0800 Subject: [PATCH] feat: Add import reader for Parquet (#29618) This PR implements a Parquet reader for import. issue: https://github.com/milvus-io/milvus/issues/28521 --------- Signed-off-by: bigsheeper --- .../util/importutilv2/parquet/field_reader.go | 568 ++++++++++++++++++ internal/util/importutilv2/parquet/reader.go | 113 ++++ .../util/importutilv2/parquet/reader_test.go | 447 ++++++++++++++ internal/util/importutilv2/parquet/util.go | 175 ++++++ 4 files changed, 1303 insertions(+) create mode 100644 internal/util/importutilv2/parquet/field_reader.go create mode 100644 internal/util/importutilv2/parquet/reader.go create mode 100644 internal/util/importutilv2/parquet/reader_test.go create mode 100644 internal/util/importutilv2/parquet/util.go diff --git a/internal/util/importutilv2/parquet/field_reader.go b/internal/util/importutilv2/parquet/field_reader.go new file mode 100644 index 0000000000000..162ea59e92ad3 --- /dev/null +++ b/internal/util/importutilv2/parquet/field_reader.go @@ -0,0 +1,568 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
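+
+// FieldReader below wraps a single pqarrow.ColumnReader and turns each
+// Arrow batch into plain Go values keyed by the Milvus field type. Every
+// ReadXxxData helper in this file is a typed variant of the same loop,
+// shown here as an illustrative sketch (not part of this patch; `reader`
+// is a *pqarrow.FileReader and `idx` a column index):
+//
+//    cr, _ := reader.GetColumn(context.Background(), idx)
+//    chunked, _ := cr.NextBatch(64) // *arrow.Chunked holding up to 64 rows
+//    for _, chunk := range chunked.Chunks() {
+//        if col, ok := chunk.(*array.Int64); ok {
+//            for i := 0; i < col.Len(); i++ {
+//                _ = col.Value(i) // copy into the output slice
+//            }
+//        }
+//    }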
+ +package parquet + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/apache/arrow/go/v12/arrow" + "github.com/apache/arrow/go/v12/arrow/array" + "github.com/apache/arrow/go/v12/parquet/pqarrow" + "github.com/samber/lo" + "golang.org/x/exp/constraints" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type FieldReader struct { + columnIndex int + columnReader *pqarrow.ColumnReader + + dim int + field *schemapb.FieldSchema +} + +func NewFieldReader(reader *pqarrow.FileReader, columnIndex int, field *schemapb.FieldSchema) (*FieldReader, error) { + columnReader, err := reader.GetColumn(context.Background(), columnIndex) // TODO: dyh, resolve context + if err != nil { + return nil, err + } + + var dim int64 = 1 + if typeutil.IsVectorType(field.GetDataType()) { + dim, err = typeutil.GetDim(field) + if err != nil { + return nil, err + } + } + + cr := &FieldReader{ + columnIndex: columnIndex, + columnReader: columnReader, + dim: int(dim), + field: field, + } + return cr, nil +} + +func (c *FieldReader) Next(count int64) (any, error) { + switch c.field.GetDataType() { + case schemapb.DataType_Bool: + return ReadBoolData(c, count) + case schemapb.DataType_Int8: + return ReadIntegerOrFloatData[int8](c, count) + case schemapb.DataType_Int16: + return ReadIntegerOrFloatData[int16](c, count) + case schemapb.DataType_Int32: + return ReadIntegerOrFloatData[int32](c, count) + case schemapb.DataType_Int64: + return ReadIntegerOrFloatData[int64](c, count) + case schemapb.DataType_Float: + data, err := ReadIntegerOrFloatData[float32](c, count) + if err != nil { + return nil, err + } + if data == nil { + return nil, nil + } + return data, typeutil.VerifyFloats32(data.([]float32)) + case schemapb.DataType_Double: + data, err := ReadIntegerOrFloatData[float64](c, count) + if err != nil { + return nil, err + } + if data == nil { + return nil, nil + } + return data, typeutil.VerifyFloats64(data.([]float64)) + case schemapb.DataType_VarChar, schemapb.DataType_String: + return ReadStringData(c, count) + case schemapb.DataType_JSON: + // JSON field read data from string array Parquet + data, err := ReadStringData(c, count) + if err != nil { + return nil, err + } + if data == nil { + return nil, nil + } + byteArr := make([][]byte, 0) + for _, str := range data.([]string) { + var dummy interface{} + err = json.Unmarshal([]byte(str), &dummy) + if err != nil { + return nil, err + } + byteArr = append(byteArr, []byte(str)) + } + return byteArr, nil + case schemapb.DataType_BinaryVector: + return ReadBinaryData(c, count) + case schemapb.DataType_FloatVector: + arrayData, err := ReadIntegerOrFloatArrayData[float32](c, count) + if err != nil { + return nil, err + } + if arrayData == nil { + return nil, nil + } + vectors := lo.Flatten(arrayData.([][]float32)) + return vectors, nil + case schemapb.DataType_Array: + data := make([]*schemapb.ScalarField, 0, count) + elementType := c.field.GetElementType() + switch elementType { + case schemapb.DataType_Bool: + boolArray, err := ReadBoolArrayData(c, count) + if err != nil { + return nil, err + } + if boolArray == nil { + return nil, nil + } + for _, elementArray := range boolArray.([][]bool) { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_BoolData{ + BoolData: &schemapb.BoolArray{ + Data: elementArray, + }, + }, + }) + } + + case schemapb.DataType_Int8: + int8Array, err := ReadIntegerOrFloatArrayData[int32](c, count) + if err 
!= nil { + return nil, err + } + if int8Array == nil { + return nil, nil + } + for _, elementArray := range int8Array.([][]int32) { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: elementArray, + }, + }, + }) + } + + case schemapb.DataType_Int16: + int16Array, err := ReadIntegerOrFloatArrayData[int32](c, count) + if err != nil { + return nil, err + } + if int16Array == nil { + return nil, nil + } + for _, elementArray := range int16Array.([][]int32) { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: elementArray, + }, + }, + }) + } + + case schemapb.DataType_Int32: + int32Array, err := ReadIntegerOrFloatArrayData[int32](c, count) + if err != nil { + return nil, err + } + if int32Array == nil { + return nil, nil + } + for _, elementArray := range int32Array.([][]int32) { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: elementArray, + }, + }, + }) + } + + case schemapb.DataType_Int64: + int64Array, err := ReadIntegerOrFloatArrayData[int64](c, count) + if err != nil { + return nil, err + } + if int64Array == nil { + return nil, nil + } + for _, elementArray := range int64Array.([][]int64) { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: elementArray, + }, + }, + }) + } + + case schemapb.DataType_Float: + float32Array, err := ReadIntegerOrFloatArrayData[float32](c, count) + if err != nil { + return nil, err + } + if float32Array == nil { + return nil, nil + } + for _, elementArray := range float32Array.([][]float32) { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_FloatData{ + FloatData: &schemapb.FloatArray{ + Data: elementArray, + }, + }, + }) + } + + case schemapb.DataType_Double: + float64Array, err := ReadIntegerOrFloatArrayData[float64](c, count) + if err != nil { + return nil, err + } + if float64Array == nil { + return nil, nil + } + for _, elementArray := range float64Array.([][]float64) { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_DoubleData{ + DoubleData: &schemapb.DoubleArray{ + Data: elementArray, + }, + }, + }) + } + + case schemapb.DataType_VarChar, schemapb.DataType_String: + stringArray, err := ReadStringArrayData(c, count) + if err != nil { + return nil, err + } + if stringArray == nil { + return nil, nil + } + for _, elementArray := range stringArray.([][]string) { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: elementArray, + }, + }, + }) + } + default: + return nil, merr.WrapErrImportFailed(fmt.Sprintf("unsupported data type '%s' for array field '%s'", + elementType.String(), c.field.GetName())) + } + return data, nil + default: + return nil, merr.WrapErrImportFailed(fmt.Sprintf("unsupported data type '%s' for field '%s'", + c.field.GetDataType().String(), c.field.GetName())) + } +} + +func (c *FieldReader) Close() {} + +func ReadBoolData(pcr *FieldReader, count int64) (any, error) { + chunked, err := pcr.columnReader.NextBatch(count) + if err != nil { + return nil, err + } + data := make([]bool, 0, count) + for _, chunk := range chunked.Chunks() { + dataNums := chunk.Data().Len() + chunkData := make([]bool, dataNums) + boolReader, ok := chunk.(*array.Boolean) + if !ok { + return nil, WrapTypeErr("bool", chunk.DataType().Name(), 
pcr.field) + } + for i := 0; i < dataNums; i++ { + chunkData[i] = boolReader.Value(i) + } + data = append(data, chunkData...) + } + if len(data) == 0 { + return nil, nil + } + return data, nil +} + +func ReadIntegerOrFloatData[T constraints.Integer | constraints.Float](pcr *FieldReader, count int64) (any, error) { + chunked, err := pcr.columnReader.NextBatch(count) + if err != nil { + return nil, err + } + data := make([]T, 0, count) + for _, chunk := range chunked.Chunks() { + dataNums := chunk.Data().Len() + chunkData := make([]T, dataNums) + switch chunk.DataType().ID() { + case arrow.INT8: + int8Reader := chunk.(*array.Int8) + for i := 0; i < dataNums; i++ { + chunkData[i] = T(int8Reader.Value(i)) + } + case arrow.INT16: + int16Reader := chunk.(*array.Int16) + for i := 0; i < dataNums; i++ { + chunkData[i] = T(int16Reader.Value(i)) + } + case arrow.INT32: + int32Reader := chunk.(*array.Int32) + for i := 0; i < dataNums; i++ { + chunkData[i] = T(int32Reader.Value(i)) + } + case arrow.INT64: + int64Reader := chunk.(*array.Int64) + for i := 0; i < dataNums; i++ { + chunkData[i] = T(int64Reader.Value(i)) + } + case arrow.FLOAT32: + float32Reader := chunk.(*array.Float32) + for i := 0; i < dataNums; i++ { + chunkData[i] = T(float32Reader.Value(i)) + } + case arrow.FLOAT64: + float64Reader := chunk.(*array.Float64) + for i := 0; i < dataNums; i++ { + chunkData[i] = T(float64Reader.Value(i)) + } + default: + return nil, WrapTypeErr("integer|float", chunk.DataType().Name(), pcr.field) + } + data = append(data, chunkData...) + } + if len(data) == 0 { + return nil, nil + } + return data, nil +} + +func ReadStringData(pcr *FieldReader, count int64) (any, error) { + chunked, err := pcr.columnReader.NextBatch(count) + if err != nil { + return nil, err + } + data := make([]string, 0, count) + for _, chunk := range chunked.Chunks() { + dataNums := chunk.Data().Len() + chunkData := make([]string, dataNums) + stringReader, ok := chunk.(*array.String) + if !ok { + return nil, WrapTypeErr("string", chunk.DataType().Name(), pcr.field) + } + for i := 0; i < dataNums; i++ { + chunkData[i] = stringReader.Value(i) + } + data = append(data, chunkData...) + } + if len(data) == 0 { + return nil, nil + } + return data, nil +} + +func ReadBinaryData(pcr *FieldReader, count int64) (any, error) { + chunked, err := pcr.columnReader.NextBatch(count) + if err != nil { + return nil, err + } + data := make([]byte, 0, count) + for _, chunk := range chunked.Chunks() { + dataNums := chunk.Data().Len() + switch chunk.DataType().ID() { + case arrow.BINARY: + binaryReader := chunk.(*array.Binary) + for i := 0; i < dataNums; i++ { + data = append(data, binaryReader.Value(i)...) 
+ } + case arrow.LIST: + listReader := chunk.(*array.List) + if !isRegularVector(listReader.Offsets(), pcr.dim, true) { + return nil, merr.WrapErrImportFailed("binary vector is irregular") + } + uint8Reader, ok := listReader.ListValues().(*array.Uint8) + if !ok { + return nil, WrapTypeErr("binary", listReader.ListValues().DataType().Name(), pcr.field) + } + for i := 0; i < uint8Reader.Len(); i++ { + data = append(data, uint8Reader.Value(i)) + } + default: + return nil, WrapTypeErr("binary", chunk.DataType().Name(), pcr.field) + } + } + if len(data) == 0 { + return nil, nil + } + return data, nil +} + +func isRegularVector(offsets []int32, dim int, isBinary bool) bool { + if len(offsets) < 1 { + return false + } + if isBinary { + dim = dim / 8 + } + for i := 1; i < len(offsets); i++ { + if offsets[i]-offsets[i-1] != int32(dim) { + return false + } + } + return true +} + +func ReadBoolArrayData(pcr *FieldReader, count int64) (any, error) { + chunked, err := pcr.columnReader.NextBatch(count) + if err != nil { + return nil, err + } + data := make([][]bool, 0, count) + for _, chunk := range chunked.Chunks() { + listReader, ok := chunk.(*array.List) + if !ok { + return nil, WrapTypeErr("list", chunk.DataType().Name(), pcr.field) + } + boolReader, ok := listReader.ListValues().(*array.Boolean) + if !ok { + return nil, WrapTypeErr("boolArray", chunk.DataType().Name(), pcr.field) + } + offsets := listReader.Offsets() + for i := 1; i < len(offsets); i++ { + start, end := offsets[i-1], offsets[i] + elementData := make([]bool, 0, end-start) + for j := start; j < end; j++ { + elementData = append(elementData, boolReader.Value(int(j))) + } + data = append(data, elementData) + } + } + if len(data) == 0 { + return nil, nil + } + return data, nil +} + +func ReadIntegerOrFloatArrayData[T constraints.Integer | constraints.Float](pcr *FieldReader, count int64) (any, error) { + chunked, err := pcr.columnReader.NextBatch(count) + if err != nil { + return nil, err + } + data := make([][]T, 0, count) + + getDataFunc := func(offsets []int32, getValue func(int) T) { + for i := 1; i < len(offsets); i++ { + start, end := offsets[i-1], offsets[i] + elementData := make([]T, 0, end-start) + for j := start; j < end; j++ { + elementData = append(elementData, getValue(int(j))) + } + data = append(data, elementData) + } + } + for _, chunk := range chunked.Chunks() { + listReader, ok := chunk.(*array.List) + if !ok { + return nil, WrapTypeErr("list", chunk.DataType().Name(), pcr.field) + } + offsets := listReader.Offsets() + if typeutil.IsVectorType(pcr.field.GetDataType()) && + !isRegularVector(offsets, pcr.dim, pcr.field.GetDataType() == schemapb.DataType_BinaryVector) { + return nil, merr.WrapErrImportFailed("float vector is irregular") + } + valueReader := listReader.ListValues() + switch valueReader.DataType().ID() { + case arrow.INT8: + int8Reader := valueReader.(*array.Int8) + getDataFunc(offsets, func(i int) T { + return T(int8Reader.Value(i)) + }) + case arrow.INT16: + int16Reader := valueReader.(*array.Int16) + getDataFunc(offsets, func(i int) T { + return T(int16Reader.Value(i)) + }) + case arrow.INT32: + int32Reader := valueReader.(*array.Int32) + getDataFunc(offsets, func(i int) T { + return T(int32Reader.Value(i)) + }) + case arrow.INT64: + int64Reader := valueReader.(*array.Int64) + getDataFunc(offsets, func(i int) T { + return T(int64Reader.Value(i)) + }) + case arrow.FLOAT32: + float32Reader := valueReader.(*array.Float32) + getDataFunc(offsets, func(i int) T { + return T(float32Reader.Value(i)) + }) + case 
arrow.FLOAT64: + float64Reader := valueReader.(*array.Float64) + getDataFunc(offsets, func(i int) T { + return T(float64Reader.Value(i)) + }) + default: + return nil, WrapTypeErr("integerArray|floatArray", chunk.DataType().Name(), pcr.field) + } + } + if len(data) == 0 { + return nil, nil + } + return data, nil +} + +func ReadStringArrayData(pcr *FieldReader, count int64) (any, error) { + chunked, err := pcr.columnReader.NextBatch(count) + if err != nil { + return nil, err + } + data := make([][]string, 0, count) + for _, chunk := range chunked.Chunks() { + listReader, ok := chunk.(*array.List) + if !ok { + return nil, WrapTypeErr("list", chunk.DataType().Name(), pcr.field) + } + stringReader, ok := listReader.ListValues().(*array.String) + if !ok { + return nil, WrapTypeErr("stringArray", chunk.DataType().Name(), pcr.field) + } + offsets := listReader.Offsets() + for i := 1; i < len(offsets); i++ { + start, end := offsets[i-1], offsets[i] + elementData := make([]string, 0, end-start) + for j := start; j < end; j++ { + elementData = append(elementData, stringReader.Value(int(j))) + } + data = append(data, elementData) + } + } + if len(data) == 0 { + return nil, nil + } + return data, nil +} diff --git a/internal/util/importutilv2/parquet/reader.go b/internal/util/importutilv2/parquet/reader.go new file mode 100644 index 0000000000000..ba2c4f9d21c28 --- /dev/null +++ b/internal/util/importutilv2/parquet/reader.go @@ -0,0 +1,113 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
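+
+// Reader below composes the per-column FieldReaders into row batches. A
+// typical consumer loop looks roughly like this (an illustrative sketch,
+// not part of this patch; `schema` and `cmReader` are assumed to come from
+// the import task):
+//
+//    r, err := NewReader(schema, cmReader, 16*1024*1024)
+//    if err != nil {
+//        return err
+//    }
+//    defer r.Close()
+//    for {
+//        insertData, err := r.Read()
+//        if err != nil {
+//            return err
+//        }
+//        if insertData == nil {
+//            break // every column is exhausted
+//        }
+//        // hand insertData to the segment writer
+//    }
+//
+// Read keeps appending one row per field until the accumulated batch
+// reaches bufferSize, and returns nil once the file runs out of rows.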
+
+package parquet
+
+import (
+    "fmt"
+
+    "github.com/apache/arrow/go/v12/arrow/memory"
+    "github.com/apache/arrow/go/v12/parquet"
+    "github.com/apache/arrow/go/v12/parquet/file"
+    "github.com/apache/arrow/go/v12/parquet/pqarrow"
+    "go.uber.org/zap"
+
+    "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+    "github.com/milvus-io/milvus/internal/storage"
+    "github.com/milvus-io/milvus/pkg/log"
+    "github.com/milvus-io/milvus/pkg/util/merr"
+)
+
+type Reader struct {
+    reader *file.Reader
+
+    bufferSize int
+
+    schema *schemapb.CollectionSchema
+    frs    map[int64]*FieldReader // fieldID -> FieldReader
+}
+
+func NewReader(schema *schemapb.CollectionSchema, cmReader storage.FileReader, bufferSize int) (*Reader, error) {
+    const pqBufSize = 32 * 1024 * 1024 // TODO: dyh, make it configurable
+    size := calcBufferSize(pqBufSize, schema)
+    reader, err := file.NewParquetReader(cmReader, file.WithReadProps(&parquet.ReaderProperties{
+        BufferSize:            int64(size),
+        BufferedStreamEnabled: true,
+    }))
+    if err != nil {
+        return nil, merr.WrapErrImportFailed(fmt.Sprintf("new parquet reader failed, err=%v", err))
+    }
+    log.Info("create parquet reader done", zap.Int("row group num", reader.NumRowGroups()),
+        zap.Int64("num rows", reader.NumRows()))
+
+    fileReader, err := pqarrow.NewFileReader(reader, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator)
+    if err != nil {
+        return nil, merr.WrapErrImportFailed(fmt.Sprintf("new parquet file reader failed, err=%v", err))
+    }
+
+    crs, err := CreateFieldReaders(fileReader, schema)
+    if err != nil {
+        return nil, err
+    }
+    return &Reader{
+        reader:     reader,
+        bufferSize: bufferSize,
+        schema:     schema,
+        frs:        crs,
+    }, nil
+}
+
+func (r *Reader) Read() (*storage.InsertData, error) {
+    insertData, err := storage.NewInsertData(r.schema)
+    if err != nil {
+        return nil, err
+    }
+OUTER:
+    for {
+        for fieldID, cr := range r.frs {
+            data, err := cr.Next(1)
+            if err != nil {
+                return nil, err
+            }
+            if data == nil {
+                break OUTER
+            }
+            err = insertData.Data[fieldID].AppendRows(data)
+            if err != nil {
+                return nil, err
+            }
+        }
+        if insertData.GetMemorySize() >= r.bufferSize {
+            break
+        }
+    }
+    for fieldID := range r.frs {
+        if insertData.Data[fieldID].RowNum() == 0 {
+            return nil, nil
+        }
+    }
+    return insertData, nil
+}
+
+func (r *Reader) Close() {
+    for _, cr := range r.frs {
+        cr.Close()
+    }
+    err := r.reader.Close()
+    if err != nil {
+        log.Warn("close parquet reader failed", zap.Error(err))
+    }
+}
diff --git a/internal/util/importutilv2/parquet/reader_test.go b/internal/util/importutilv2/parquet/reader_test.go
new file mode 100644
index 0000000000000..1b0e1a0639493
--- /dev/null
+++ b/internal/util/importutilv2/parquet/reader_test.go
@@ -0,0 +1,447 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
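+
+// The helpers below generate Parquet files whose layout matches what the
+// reader expects: one Arrow column per Milvus field, written through
+// pqarrow. The essential recipe, as an illustrative sketch (not part of
+// this patch; `arrowSchema` and `w` are assumed):
+//
+//    builder := array.NewInt64Builder(memory.NewGoAllocator())
+//    builder.Append(1)
+//    col := builder.NewInt64Array()
+//    rec := array.NewRecord(arrowSchema, []arrow.Array{col}, 1)
+//    fw, _ := pqarrow.NewFileWriter(arrowSchema, w,
+//        parquet.NewWriterProperties(), pqarrow.DefaultWriterProps())
+//    _ = fw.Write(rec)
+//    _ = fw.Close()
+//
+// For list columns (vectors and arrays), buildArrayData fills the
+// ListBuilder's value builder first and then calls
+// AppendValues(offsets, valid), where offsets[i] marks the index at which
+// row i's values begin in the value builder.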
+ +package parquet + +import ( + "context" + "fmt" + "io" + "math" + "math/rand" + "os" + "testing" + + "github.com/apache/arrow/go/v12/arrow" + "github.com/apache/arrow/go/v12/arrow/array" + "github.com/apache/arrow/go/v12/arrow/memory" + "github.com/apache/arrow/go/v12/parquet" + "github.com/apache/arrow/go/v12/parquet/pqarrow" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type ReaderSuite struct { + suite.Suite + + numRows int + pkDataType schemapb.DataType + vecDataType schemapb.DataType +} + +func (s *ReaderSuite) SetupSuite() { + paramtable.Get().Init(paramtable.NewBaseTable()) +} + +func (s *ReaderSuite) SetupTest() { + // default suite params + s.numRows = 100 + s.pkDataType = schemapb.DataType_Int64 + s.vecDataType = schemapb.DataType_FloatVector +} + +func milvusDataTypeToArrowType(dataType schemapb.DataType, isBinary bool) arrow.DataType { + switch dataType { + case schemapb.DataType_Bool: + return &arrow.BooleanType{} + case schemapb.DataType_Int8: + return &arrow.Int8Type{} + case schemapb.DataType_Int16: + return &arrow.Int16Type{} + case schemapb.DataType_Int32: + return &arrow.Int32Type{} + case schemapb.DataType_Int64: + return &arrow.Int64Type{} + case schemapb.DataType_Float: + return &arrow.Float32Type{} + case schemapb.DataType_Double: + return &arrow.Float64Type{} + case schemapb.DataType_VarChar, schemapb.DataType_String: + return &arrow.StringType{} + case schemapb.DataType_Array: + return &arrow.ListType{} + case schemapb.DataType_JSON: + return &arrow.StringType{} + case schemapb.DataType_FloatVector: + return arrow.ListOfField(arrow.Field{ + Name: "item", + Type: &arrow.Float32Type{}, + Nullable: true, + Metadata: arrow.Metadata{}, + }) + case schemapb.DataType_BinaryVector: + if isBinary { + return &arrow.BinaryType{} + } + return arrow.ListOfField(arrow.Field{ + Name: "item", + Type: &arrow.Uint8Type{}, + Nullable: true, + Metadata: arrow.Metadata{}, + }) + case schemapb.DataType_Float16Vector: + return arrow.ListOfField(arrow.Field{ + Name: "item", + Type: &arrow.Float16Type{}, + Nullable: true, + Metadata: arrow.Metadata{}, + }) + default: + panic("unsupported data type") + } +} + +func convertMilvusSchemaToArrowSchema(schema *schemapb.CollectionSchema) *arrow.Schema { + fields := make([]arrow.Field, 0) + for _, field := range schema.GetFields() { + if field.GetDataType() == schemapb.DataType_Array { + fields = append(fields, arrow.Field{ + Name: field.GetName(), + Type: arrow.ListOfField(arrow.Field{ + Name: "item", + Type: milvusDataTypeToArrowType(field.GetElementType(), false), + Nullable: true, + Metadata: arrow.Metadata{}, + }), + Nullable: true, + Metadata: arrow.Metadata{}, + }) + continue + } + fields = append(fields, arrow.Field{ + Name: field.GetName(), + Type: milvusDataTypeToArrowType(field.GetDataType(), field.Name == "FieldBinaryVector2"), + Nullable: true, + Metadata: arrow.Metadata{}, + }) + } + return arrow.NewSchema(fields, nil) +} + +func randomString(length int) string { + letterRunes := []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + b := make([]rune, length) + for i := range b { + b[i] = letterRunes[rand.Intn(len(letterRunes))] + } + return string(b) +} + +func 
buildArrayData(dataType, elementType schemapb.DataType, dim, rows int, isBinary bool) arrow.Array { + mem := memory.NewGoAllocator() + switch dataType { + case schemapb.DataType_Bool: + builder := array.NewBooleanBuilder(mem) + for i := 0; i < rows; i++ { + builder.Append(i%2 == 0) + } + return builder.NewBooleanArray() + case schemapb.DataType_Int8: + builder := array.NewInt8Builder(mem) + for i := 0; i < rows; i++ { + builder.Append(int8(i)) + } + return builder.NewInt8Array() + case schemapb.DataType_Int16: + builder := array.NewInt16Builder(mem) + for i := 0; i < rows; i++ { + builder.Append(int16(i)) + } + return builder.NewInt16Array() + case schemapb.DataType_Int32: + builder := array.NewInt32Builder(mem) + for i := 0; i < rows; i++ { + builder.Append(int32(i)) + } + return builder.NewInt32Array() + case schemapb.DataType_Int64: + builder := array.NewInt64Builder(mem) + for i := 0; i < rows; i++ { + builder.Append(int64(i)) + } + return builder.NewInt64Array() + case schemapb.DataType_Float: + builder := array.NewFloat32Builder(mem) + for i := 0; i < rows; i++ { + builder.Append(float32(i) * 0.1) + } + return builder.NewFloat32Array() + case schemapb.DataType_Double: + builder := array.NewFloat64Builder(mem) + for i := 0; i < rows; i++ { + builder.Append(float64(i) * 0.02) + } + return builder.NewFloat64Array() + case schemapb.DataType_VarChar, schemapb.DataType_String: + builder := array.NewStringBuilder(mem) + for i := 0; i < rows; i++ { + builder.Append(randomString(10)) + } + return builder.NewStringArray() + case schemapb.DataType_FloatVector: + builder := array.NewListBuilder(mem, &arrow.Float32Type{}) + offsets := make([]int32, 0, rows) + valid := make([]bool, 0, rows) + for i := 0; i < dim*rows; i++ { + builder.ValueBuilder().(*array.Float32Builder).Append(float32(i)) + } + for i := 0; i < rows; i++ { + offsets = append(offsets, int32(i*dim)) + valid = append(valid, true) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + case schemapb.DataType_BinaryVector: + if isBinary { + builder := array.NewBinaryBuilder(mem, &arrow.BinaryType{}) + for i := 0; i < rows; i++ { + element := make([]byte, dim/8) + for j := 0; j < dim/8; j++ { + element[j] = randomString(1)[0] + } + builder.Append(element) + } + return builder.NewBinaryArray() + } + builder := array.NewListBuilder(mem, &arrow.Uint8Type{}) + offsets := make([]int32, 0, rows) + valid := make([]bool, 0) + for i := 0; i < dim*rows/8; i++ { + builder.ValueBuilder().(*array.Uint8Builder).Append(uint8(i)) + } + for i := 0; i < rows; i++ { + offsets = append(offsets, int32(dim*i/8)) + valid = append(valid, true) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + case schemapb.DataType_JSON: + builder := array.NewStringBuilder(mem) + for i := 0; i < rows; i++ { + builder.Append(fmt.Sprintf("{\"a\": \"%s\", \"b\": %d}", randomString(3), i)) + } + return builder.NewStringArray() + case schemapb.DataType_Array: + offsets := make([]int32, 0, rows) + valid := make([]bool, 0, rows) + index := 0 + for i := 0; i < rows; i++ { + index += i % 10 + offsets = append(offsets, int32(index)) + valid = append(valid, true) + } + switch elementType { + case schemapb.DataType_Bool: + builder := array.NewListBuilder(mem, &arrow.BooleanType{}) + valueBuilder := builder.ValueBuilder().(*array.BooleanBuilder) + for i := 0; i < index; i++ { + valueBuilder.Append(i%2 == 0) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + case schemapb.DataType_Int8: + builder := 
array.NewListBuilder(mem, &arrow.Int8Type{}) + valueBuilder := builder.ValueBuilder().(*array.Int8Builder) + for i := 0; i < index; i++ { + valueBuilder.Append(int8(i)) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + case schemapb.DataType_Int16: + builder := array.NewListBuilder(mem, &arrow.Int16Type{}) + valueBuilder := builder.ValueBuilder().(*array.Int16Builder) + for i := 0; i < index; i++ { + valueBuilder.Append(int16(i)) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + case schemapb.DataType_Int32: + builder := array.NewListBuilder(mem, &arrow.Int32Type{}) + valueBuilder := builder.ValueBuilder().(*array.Int32Builder) + for i := 0; i < index; i++ { + valueBuilder.Append(int32(i)) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + case schemapb.DataType_Int64: + builder := array.NewListBuilder(mem, &arrow.Int64Type{}) + valueBuilder := builder.ValueBuilder().(*array.Int64Builder) + for i := 0; i < index; i++ { + valueBuilder.Append(int64(i)) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + case schemapb.DataType_Float: + builder := array.NewListBuilder(mem, &arrow.Float32Type{}) + valueBuilder := builder.ValueBuilder().(*array.Float32Builder) + for i := 0; i < index; i++ { + valueBuilder.Append(float32(i) * 0.1) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + case schemapb.DataType_Double: + builder := array.NewListBuilder(mem, &arrow.Float64Type{}) + valueBuilder := builder.ValueBuilder().(*array.Float64Builder) + for i := 0; i < index; i++ { + valueBuilder.Append(float64(i) * 0.02) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + case schemapb.DataType_VarChar, schemapb.DataType_String: + builder := array.NewListBuilder(mem, &arrow.StringType{}) + valueBuilder := builder.ValueBuilder().(*array.StringBuilder) + for i := 0; i < index; i++ { + valueBuilder.Append(randomString(5) + "-" + fmt.Sprintf("%d", i)) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + } + } + return nil +} + +func writeParquet(w io.Writer, schema *schemapb.CollectionSchema, numRows int) error { + pqSchema := convertMilvusSchemaToArrowSchema(schema) + fw, err := pqarrow.NewFileWriter(pqSchema, w, parquet.NewWriterProperties(parquet.WithMaxRowGroupLength(int64(numRows))), pqarrow.DefaultWriterProps()) + if err != nil { + return err + } + defer fw.Close() + + columns := make([]arrow.Array, 0, len(schema.Fields)) + for _, field := range schema.Fields { + var dim int64 = 1 + if typeutil.IsVectorType(field.GetDataType()) { + dim, err = typeutil.GetDim(field) + if err != nil { + return err + } + } + columnData := buildArrayData(field.DataType, field.ElementType, int(dim), numRows, field.Name == "FieldBinaryVector2") + columns = append(columns, columnData) + } + recordBatch := array.NewRecord(pqSchema, columns, int64(numRows)) + err = fw.Write(recordBatch) + if err != nil { + return err + } + + return nil +} + +func (s *ReaderSuite) run(dt schemapb.DataType) { + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + Name: "pk", + IsPrimaryKey: true, + DataType: s.pkDataType, + }, + { + FieldID: 101, + Name: "vec", + DataType: s.vecDataType, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "8", + }, + }, + }, + { + FieldID: 102, + Name: dt.String(), + DataType: dt, + ElementType: schemapb.DataType_Int32, + }, + }, + } + + filePath := 
fmt.Sprintf("/tmp/test_%d_reader.parquet", rand.Int()) + defer os.Remove(filePath) + wf, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE, 0o666) + assert.NoError(s.T(), err) + err = writeParquet(wf, schema, s.numRows) + assert.NoError(s.T(), err) + + ctx := context.Background() + f := storage.NewChunkManagerFactory("local", storage.RootPath("/tmp/milvus_test/test_parquet_reader/")) + cm, err := f.NewPersistentStorageChunkManager(ctx) + assert.NoError(s.T(), err) + cmReader, err := cm.Reader(ctx, filePath) + assert.NoError(s.T(), err) + reader, err := NewReader(schema, cmReader, math.MaxInt) + s.NoError(err) + + checkFn := func(actualInsertData *storage.InsertData, offsetBegin, expectRows int) { + // expectInsertData := insertData + for _, data := range actualInsertData.Data { + s.Equal(expectRows, data.RowNum()) + // TODO: dyh, check rows + // fieldDataType := typeutil.GetField(schema, fieldID).GetDataType() + // for i := 0; i < expectRows; i++ { + // expect := expectInsertData.Data[fieldID].GetRow(i + offsetBegin) + // actual := data.GetRow(i) + // if fieldDataType == schemapb.DataType_Array { + // s.True(slices.Equal(expect.(*schemapb.ScalarField).GetIntData().GetData(), actual.(*schemapb.ScalarField).GetIntData().GetData())) + // } else { + // s.Equal(expect, actual) + // } + // } + } + } + + res, err := reader.Read() + s.NoError(err) + checkFn(res, 0, s.numRows) +} + +func (s *ReaderSuite) TestReadScalarFields() { + s.run(schemapb.DataType_Bool) + s.run(schemapb.DataType_Int8) + s.run(schemapb.DataType_Int16) + s.run(schemapb.DataType_Int32) + s.run(schemapb.DataType_Int64) + s.run(schemapb.DataType_Float) + s.run(schemapb.DataType_Double) + s.run(schemapb.DataType_VarChar) + s.run(schemapb.DataType_Array) + s.run(schemapb.DataType_JSON) +} + +func (s *ReaderSuite) TestStringPK() { + s.pkDataType = schemapb.DataType_VarChar + s.run(schemapb.DataType_Int32) +} + +func (s *ReaderSuite) TestBinaryAndFloat16Vector() { + s.vecDataType = schemapb.DataType_BinaryVector + s.run(schemapb.DataType_Int32) + // s.vecDataType = schemapb.DataType_Float16Vector + // s.run(schemapb.DataType_Int32) // TODO: dyh, support float16 vector +} + +func TestUtil(t *testing.T) { + suite.Run(t, new(ReaderSuite)) +} diff --git a/internal/util/importutilv2/parquet/util.go b/internal/util/importutilv2/parquet/util.go new file mode 100644 index 0000000000000..3c2e4baac5673 --- /dev/null +++ b/internal/util/importutilv2/parquet/util.go @@ -0,0 +1,175 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package parquet
+
+import (
+    "fmt"
+
+    "github.com/apache/arrow/go/v12/arrow"
+    "github.com/apache/arrow/go/v12/parquet/pqarrow"
+    "github.com/samber/lo"
+
+    "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+    "github.com/milvus-io/milvus/pkg/util/merr"
+    "github.com/milvus-io/milvus/pkg/util/typeutil"
+)
+
+func WrapTypeErr(expect string, actual string, field *schemapb.FieldSchema) error {
+    return merr.WrapErrImportFailed(
+        fmt.Sprintf("expect '%s' type for field '%s', but got '%s' type",
+            expect, field.GetName(), actual))
+}
+
+func calcBufferSize(blockSize int, schema *schemapb.CollectionSchema) int {
+    if len(schema.GetFields()) <= 0 {
+        return blockSize
+    }
+    return blockSize / len(schema.GetFields())
+}
+
+func CreateFieldReaders(fileReader *pqarrow.FileReader, schema *schemapb.CollectionSchema) (map[int64]*FieldReader, error) {
+    nameToField := lo.KeyBy(schema.GetFields(), func(field *schemapb.FieldSchema) string {
+        return field.GetName()
+    })
+
+    pqSchema, err := fileReader.Schema()
+    if err != nil {
+        return nil, merr.WrapErrImportFailed(fmt.Sprintf("get parquet schema failed, err=%v", err))
+    }
+
+    crs := make(map[int64]*FieldReader)
+    for i, pqField := range pqSchema.Fields() {
+        field, ok := nameToField[pqField.Name]
+        if !ok {
+            // TODO @cai.zhang: handle dynamic field
+            return nil, merr.WrapErrImportFailed(fmt.Sprintf("the field '%s' is not in the schema; "+
+                "if it's a dynamic field, please reformat the data by bulk_writer", pqField.Name))
+        }
+        if field.GetIsPrimaryKey() && field.GetAutoID() {
+            return nil, merr.WrapErrImportFailed(
+                fmt.Sprintf("the primary key '%s' is auto-generated, no need to provide", field.GetName()))
+        }
+
+        arrowType, isList := convertArrowSchemaToDataType(pqField, false)
+        dataType := field.GetDataType()
+        if isList {
+            if !typeutil.IsVectorType(dataType) && dataType != schemapb.DataType_Array {
+                return nil, WrapTypeErr(dataType.String(), pqField.Type.Name(), field)
+            }
+            if dataType == schemapb.DataType_Array {
+                dataType = field.GetElementType()
+            }
+        }
+        if !isConvertible(arrowType, dataType, isList) {
+            return nil, WrapTypeErr(dataType.String(), pqField.Type.Name(), field)
+        }
+
+        cr, err := NewFieldReader(fileReader, i, field)
+        if err != nil {
+            return nil, err
+        }
+        if _, ok = crs[field.GetFieldID()]; ok {
+            return nil, merr.WrapErrImportFailed(
+                fmt.Sprintf("there are multiple fields with the same name: %s", field.GetName()))
+        }
+        crs[field.GetFieldID()] = cr
+    }
+
+    for _, field := range nameToField {
+        if (field.GetIsPrimaryKey() && field.GetAutoID()) || field.GetIsDynamic() {
+            continue
+        }
+        if _, ok := crs[field.GetFieldID()]; !ok {
+            return nil, merr.WrapErrImportFailed(
+                fmt.Sprintf("no parquet field found for milvus field '%s'", field.GetName()))
+        }
+    }
+    return crs, nil
+}
+
+func convertArrowSchemaToDataType(field arrow.Field, isList bool) (schemapb.DataType, bool) {
+    switch field.Type.ID() {
+    case arrow.BOOL:
+        return schemapb.DataType_Bool, false
+    case arrow.UINT8:
+        if isList {
+            return schemapb.DataType_BinaryVector, false
+        }
+        return schemapb.DataType_None, false
+    case arrow.INT8:
+        return schemapb.DataType_Int8, false
+    case arrow.INT16:
+        return schemapb.DataType_Int16, false
+    case arrow.INT32:
+        return schemapb.DataType_Int32, false
+    case arrow.INT64:
+        return schemapb.DataType_Int64, false
+    case arrow.FLOAT16:
+        if isList {
+            return schemapb.DataType_Float16Vector, false
+        }
+        return schemapb.DataType_None, false
+    case arrow.FLOAT32:
+        return schemapb.DataType_Float, false
+    case arrow.FLOAT64:
+        return schemapb.DataType_Double, false
+    case arrow.STRING:
+        return schemapb.DataType_VarChar, false
+    case arrow.BINARY:
+        return schemapb.DataType_BinaryVector, false
+    case arrow.LIST:
+        elementType, _ := convertArrowSchemaToDataType(field.Type.(*arrow.ListType).ElemField(), true)
+        return elementType, true
+    default:
+        return schemapb.DataType_None, false
+    }
+}
+
+func isConvertible(src, dst schemapb.DataType, isList bool) bool {
+    switch src {
+    case schemapb.DataType_Bool:
+        return typeutil.IsBoolType(dst)
+    case schemapb.DataType_Int8:
+        return typeutil.IsArithmetic(dst)
+    case schemapb.DataType_Int16:
+        return typeutil.IsArithmetic(dst) && dst != schemapb.DataType_Int8
+    case schemapb.DataType_Int32:
+        return typeutil.IsArithmetic(dst) && dst != schemapb.DataType_Int8 && dst != schemapb.DataType_Int16
+    case schemapb.DataType_Int64:
+        return typeutil.IsFloatingType(dst) || dst == schemapb.DataType_Int64
+    case schemapb.DataType_Float:
+        if isList && dst == schemapb.DataType_FloatVector {
+            return true
+        }
+        return typeutil.IsFloatingType(dst)
+    case schemapb.DataType_Double:
+        if isList && dst == schemapb.DataType_FloatVector {
+            return true
+        }
+        return dst == schemapb.DataType_Double
+    case schemapb.DataType_String, schemapb.DataType_VarChar:
+        return typeutil.IsStringType(dst) || typeutil.IsJSONType(dst)
+    case schemapb.DataType_JSON:
+        return typeutil.IsJSONType(dst)
+    case schemapb.DataType_BinaryVector:
+        return dst == schemapb.DataType_BinaryVector
+    case schemapb.DataType_Float16Vector:
+        return dst == schemapb.DataType_Float16Vector
+    default:
+        return false
+    }
+}
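
Taken together, the pieces above compose as in the following minimal sketch. It is illustrative only, not part of the patch: it assumes a local chunk manager rooted at /tmp, a caller-prepared *schemapb.CollectionSchema, and an existing Parquet file at filePath; importParquet is a hypothetical helper mirroring ReaderSuite.run from the tests.

package example

import (
    "context"
    "math"

    "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
    "github.com/milvus-io/milvus/internal/storage"
    "github.com/milvus-io/milvus/internal/util/importutilv2/parquet"
)

// importParquet drains every row of filePath into InsertData batches.
func importParquet(ctx context.Context, schema *schemapb.CollectionSchema, filePath string) error {
    f := storage.NewChunkManagerFactory("local", storage.RootPath("/tmp"))
    cm, err := f.NewPersistentStorageChunkManager(ctx)
    if err != nil {
        return err
    }
    cmReader, err := cm.Reader(ctx, filePath)
    if err != nil {
        return err
    }
    reader, err := parquet.NewReader(schema, cmReader, math.MaxInt) // one batch holds the whole file
    if err != nil {
        return err
    }
    defer reader.Close()
    for {
        insertData, err := reader.Read()
        if err != nil {
            return err
        }
        if insertData == nil {
            return nil // every column is exhausted
        }
        _ = insertData // hand the batch to the caller's sink here
    }
}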