Bulkwriter set row group for parquet (#1836)
Signed-off-by: yhmo <[email protected]>
yhmo authored Dec 28, 2023
1 parent a38a502 commit f67c8c9
Showing 6 changed files with 153 additions and 56 deletions.
21 changes: 16 additions & 5 deletions examples/example_bulkwriter.py
@@ -78,7 +78,7 @@ def build_simple_collection():
print(f"Collection '{collection.name}' created")
return collection.schema

def build_all_type_schema(bin_vec: bool):
def build_all_type_schema(bin_vec: bool, has_array: bool):
print(f"\n===================== build all types schema ====================")
fields = [
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=False),
@@ -93,6 +93,11 @@ def build_all_type_schema(bin_vec: bool):
FieldSchema(name="json", dtype=DataType.JSON),
FieldSchema(name="vector", dtype=DataType.BINARY_VECTOR, dim=DIM) if bin_vec else FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=DIM),
]

if has_array:
fields.append(FieldSchema(name="array_str", dtype=DataType.ARRAY, max_capacity=100, element_type=DataType.VARCHAR, max_length=128))
fields.append(FieldSchema(name="array_int", dtype=DataType.ARRAY, max_capacity=100, element_type=DataType.INT64))

schema = CollectionSchema(fields=fields, enable_dynamic_field=True)
return schema

@@ -118,8 +123,6 @@ def local_writer(schema: CollectionSchema, file_type: BulkFileType):
segment_size=128*1024*1024,
file_type=file_type,
) as local_writer:
# read data from csv
read_sample_data("./data/train_embeddings.csv", local_writer)

# append rows
for i in range(100000):
@@ -245,6 +248,9 @@ def all_types_writer(bin_vec: bool, schema: CollectionSchema, file_type: BulkFil
"json": {"dummy": i, "ok": f"name_{i}"},
"vector": gen_binary_vector() if bin_vec else gen_float_vector(),
f"dynamic_{i}": i,
# bulkinsert doesn't support importing npy files with array fields; for the NPY file type the values below are stored in the dynamic field
"array_str": [f"str_{k}" for k in range(5)],
"array_int": [k for k in range(10)],
}
remote_writer.append_row(row)

@@ -263,6 +269,9 @@ def all_types_writer(bin_vec: bool, schema: CollectionSchema, file_type: BulkFil
"json": json.dumps({"dummy": i, "ok": f"name_{i}"}),
"vector": gen_binary_vector() if bin_vec else gen_float_vector(),
f"dynamic_{i}": i,
# bulkinsert doesn't support importing npy files with array fields; for the NPY file type the values below are stored in the dynamic field
"array_str": np.array([f"str_{k}" for k in range(5)], np.dtype("str")),
"array_int": np.array([k for k in range(10)], np.dtype("int64")),
})

print(f"{remote_writer.total_row_count} rows appended")
@@ -383,15 +392,17 @@ def cloud_bulkinsert():
parallel_append(schema)

# float vectors + all scalar types
schema = build_all_type_schema(bin_vec=False)
for file_type in file_types:
# Note: bulkinsert doesn't support importing npy files with array fields
schema = build_all_type_schema(bin_vec=False, has_array=False if file_type==BulkFileType.NPY else True)
batch_files = all_types_writer(bin_vec=False, schema=schema, file_type=file_type)
call_bulkinsert(schema, batch_files)
retrieve_imported_data(bin_vec=False)

# binary vectors + all scalar types
schema = build_all_type_schema(bin_vec=True)
for file_type in file_types:
# Note: bulkinsert doesn't support importing npy files with array fields
schema = build_all_type_schema(bin_vec=True, has_array=False if file_type == BulkFileType.NPY else True)
batch_files = all_types_writer(bin_vec=True, schema=schema, file_type=file_type)
call_bulkinsert(schema, batch_files)
retrieve_imported_data(bin_vec=True)
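The example above appends ARRAY-typed values alongside the other field types. Below is a minimal standalone sketch of the same pattern; it assumes the import paths and writer parameters used elsewhere in this example (LocalBulkWriter, segment_size, BulkFileType.PARQUET) and uses an illustrative output path and an 8-dimensional vector, so treat it as a sketch rather than a drop-in script.

import random

from pymilvus import CollectionSchema, DataType, FieldSchema, LocalBulkWriter, BulkFileType

DIM = 8  # illustrative dimension

# Schema with scalar ARRAY fields, mirroring build_all_type_schema(has_array=True)
schema = CollectionSchema(fields=[
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=False),
    FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=DIM),
    FieldSchema(name="array_int", dtype=DataType.ARRAY, max_capacity=100, element_type=DataType.INT64),
    FieldSchema(name="array_str", dtype=DataType.ARRAY, max_capacity=100, element_type=DataType.VARCHAR, max_length=128),
])

# ARRAY fields work with JSON and Parquet output; for NPY the example drops them from the schema
with LocalBulkWriter(
    schema=schema,
    local_path="/tmp/bulk_data",          # illustrative output directory
    segment_size=128 * 1024 * 1024,
    file_type=BulkFileType.PARQUET,
) as writer:
    for i in range(1000):
        writer.append_row({
            "id": i,
            "vector": [random.random() for _ in range(DIM)],
            "array_int": [k for k in range(10)],
            "array_str": [f"str_{k}" for k in range(5)],
        })
    writer.commit()
    print(f"{writer.total_row_count} rows appended")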
70 changes: 54 additions & 16 deletions pymilvus/bulk_writer/buffer.py
@@ -28,6 +28,7 @@

from .constants import (
DYNAMIC_FIELD_NAME,
MB,
NUMPY_TYPE_CREATOR,
BulkFileType,
)
@@ -74,6 +75,14 @@ def _throw(self, msg: str):
logger.error(msg)
raise MilvusException(message=msg)

def _raw_obj(self, x: object):
if isinstance(x, np.ndarray):
return x.tolist()
if isinstance(x, np.generic):
return x.item()

return x

def append_row(self, row: dict):
dynamic_values = {}
if DYNAMIC_FIELD_NAME in row and not isinstance(row[DYNAMIC_FIELD_NAME], dict):
@@ -85,14 +94,14 @@ def append_row(self, row: dict):
continue

if k not in self._buffer:
dynamic_values[k] = row[k]
dynamic_values[k] = self._raw_obj(row[k])
else:
self._buffer[k].append(row[k])

if DYNAMIC_FIELD_NAME in self._buffer:
self._buffer[DYNAMIC_FIELD_NAME].append(dynamic_values)
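The new _raw_obj helper unwraps numpy containers and scalars into plain Python objects before they are collected into the dynamic field. A small standalone sketch of that conversion (the function name here simply mirrors the helper above):

import numpy as np

def raw_obj(x):
    if isinstance(x, np.ndarray):
        return x.tolist()   # ndarray -> plain Python list
    if isinstance(x, np.generic):
        return x.item()     # numpy scalar -> plain Python scalar
    return x                # already a plain Python object

print(raw_obj(np.int64(7)), raw_obj(np.array([1.5, 2.5])), raw_obj("unchanged"))
# 7 [1.5, 2.5] unchanged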

def persist(self, local_path: str) -> list:
def persist(self, local_path: str, **kwargs) -> list:
# verify row count of fields are equal
row_count = -1
for k in self._buffer:
@@ -107,17 +116,18 @@

# output files
if self._file_type == BulkFileType.NPY:
return self._persist_npy(local_path)
return self._persist_npy(local_path, **kwargs)
if self._file_type == BulkFileType.JSON_RB:
return self._persist_json_rows(local_path)
return self._persist_json_rows(local_path, **kwargs)
if self._file_type == BulkFileType.PARQUET:
return self._persist_parquet(local_path)
return self._persist_parquet(local_path, **kwargs)

self._throw(f"Unsupported file type: {self._file_type}")
return []

def _persist_npy(self, local_path: str):
def _persist_npy(self, local_path: str, **kwargs):
file_list = []
row_count = len(next(iter(self._buffer.values())))
for k in self._buffer:
full_file_name = Path(local_path).joinpath(k + ".npy")
file_list.append(str(full_file_name))
@@ -127,7 +137,10 @@ def _persist_npy(self, local_path: str):
# numpy data type specify
dt = None
field_schema = self._fields[k]
if field_schema.dtype.name in NUMPY_TYPE_CREATOR:
if field_schema.dtype == DataType.ARRAY:
element_type = field_schema.element_type
dt = NUMPY_TYPE_CREATOR[element_type.name]
elif field_schema.dtype.name in NUMPY_TYPE_CREATOR:
dt = NUMPY_TYPE_CREATOR[field_schema.dtype.name]

# for JSON field, convert to string array
@@ -140,9 +153,9 @@ def _persist_npy(self, local_path: str):
arr = np.array(self._buffer[k], dtype=dt)
np.save(str(full_file_name), arr)
except Exception as e:
self._throw(f"Failed to persist column-based file {full_file_name}, error: {e}")
self._throw(f"Failed to persist file {full_file_name}, error: {e}")

logger.info(f"Successfully persist column-based file {full_file_name}")
logger.info(f"Successfully persist file {full_file_name}, row count: {row_count}")

if len(file_list) != len(self._buffer):
logger.error("Some of fields were not persisted successfully, abort the files")
@@ -154,7 +167,7 @@

return file_list

def _persist_json_rows(self, local_path: str):
def _persist_json_rows(self, local_path: str, **kwargs):
rows = []
row_count = len(next(iter(self._buffer.values())))
row_index = 0
@@ -173,12 +186,12 @@
with file_path.open("w") as json_file:
json.dump(data, json_file, indent=2)
except Exception as e:
self._throw(f"Failed to persist row-based file {file_path}, error: {e}")
self._throw(f"Failed to persist file {file_path}, error: {e}")

logger.info(f"Successfully persist row-based file {file_path}")
logger.info(f"Successfully persist file {file_path}, row count: {len(rows)}")
return [str(file_path)]

def _persist_parquet(self, local_path: str):
def _persist_parquet(self, local_path: str, **kwargs):
file_path = Path(local_path + ".parquet")

data = {}
@@ -203,10 +216,35 @@
elif field_schema.dtype.name in NUMPY_TYPE_CREATOR:
dt = NUMPY_TYPE_CREATOR[field_schema.dtype.name]
data[k] = pd.Series(self._buffer[k], dtype=dt)
else:
# dtype is null, let pandas deduce the type, might not work
data[k] = pd.Series(self._buffer[k])

# calculate a proper row group size
row_group_size_min = 1000
row_group_size = 10000
row_group_size_max = 1000000
if "buffer_size" in kwargs and "buffer_row_count" in kwargs:
row_group_bytes = kwargs.get(
"row_group_bytes", 32 * MB
) # 32MB is an empirical value that avoids high memory usage of the parquet reader on the server side
buffer_size = kwargs.get("buffer_size", 1)
buffer_row_count = kwargs.get("buffer_row_count", 1)
size_per_row = int(buffer_size / buffer_row_count) + 1
row_group_size = int(row_group_bytes / size_per_row)
if row_group_size < row_group_size_min:
row_group_size = row_group_size_min
if row_group_size > row_group_size_max:
row_group_size = row_group_size_max

# write to Parquet file
data_frame = pd.DataFrame(data=data)
data_frame.to_parquet(file_path, engine="pyarrow") # don't use fastparquet

logger.info(f"Successfully persist parquet file {file_path}")
data_frame.to_parquet(
file_path, row_group_size=row_group_size, engine="pyarrow"
) # don't use fastparquet

logger.info(
f"Successfully persist file {file_path}, total size: {buffer_size},"
f" row count: {buffer_row_count}, row group size: {row_group_size}"
)
return [str(file_path)]
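The new kwargs let the caller pass buffer_size and buffer_row_count down to persist(), and the Parquet path uses them to choose a row group size: roughly 32 MB of buffered data per row group, clamped to between 1,000 and 1,000,000 rows. A standalone restatement of that calculation (calc_row_group_size is an illustrative name, not part of the library):

MB = 1024 * 1024

def calc_row_group_size(buffer_size: int, buffer_row_count: int,
                        row_group_bytes: int = 32 * MB) -> int:
    # average bytes per row, rounded up so it is never zero
    size_per_row = int(buffer_size / buffer_row_count) + 1
    # rows per ~32 MB row group, clamped to [1_000, 1_000_000]
    row_group_size = int(row_group_bytes / size_per_row)
    return max(1_000, min(row_group_size, 1_000_000))

# e.g. a 256 MB buffer holding 2,000,000 rows (~135 bytes/row) -> 248551 rows per group
print(calc_row_group_size(256 * MB, 2_000_000))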
94 changes: 68 additions & 26 deletions pymilvus/bulk_writer/bulk_writer.py
@@ -18,7 +18,7 @@

from pymilvus.client.types import DataType
from pymilvus.exceptions import MilvusException
from pymilvus.orm.schema import CollectionSchema
from pymilvus.orm.schema import CollectionSchema, FieldSchema

from .buffer import (
Buffer,
@@ -39,6 +39,7 @@ def __init__(
schema: CollectionSchema,
segment_size: int,
file_type: BulkFileType = BulkFileType.NPY,
**kwargs,
):
self._schema = schema
self._buffer_size = 0
@@ -107,6 +108,62 @@ def _throw(self, msg: str):
logger.error(msg)
raise MilvusException(message=msg)

def _verify_vector(self, x: object, field: FieldSchema):
dtype = DataType(field.dtype)
validator = TYPE_VALIDATOR[dtype.name]
dim = field.params["dim"]
if not validator(x, dim):
self._throw(
f"Illegal vector data for vector field: '{field.name}',"
f" dim is not {dim} or type mismatch"
)

return len(x) * 4 if dtype == DataType.FLOAT_VECTOR else len(x) / 8

def _verify_json(self, x: object, field: FieldSchema):
size = 0
validator = TYPE_VALIDATOR[DataType.JSON.name]
if isinstance(x, str):
size = len(x)
x = self._try_convert_json(field.name, x)
elif validator(x):
size = len(json.dumps(x))
else:
self._throw(f"Illegal JSON value for field '{field.name}', type mismatch")

return x, size

def _verify_varchar(self, x: object, field: FieldSchema):
max_len = field.params["max_length"]
validator = TYPE_VALIDATOR[DataType.VARCHAR.name]
if not validator(x, max_len):
self._throw(
f"Illegal varchar value for field '{field.name}',"
f" length exceeds {max_len} or type mismatch"
)

return len(x)

def _verify_array(self, x: object, field: FieldSchema):
max_capacity = field.params["max_capacity"]
element_type = field.element_type
validator = TYPE_VALIDATOR[DataType.ARRAY.name]
if not validator(x, max_capacity):
self._throw(
f"Illegal array value for field '{field.name}', length exceeds capacity or type mismatch"
)

row_size = 0
if element_type.name in TYPE_SIZE:
row_size = TYPE_SIZE[element_type.name] * len(x)
elif element_type == DataType.VARCHAR:
for ele in x:
row_size = row_size + self._verify_varchar(ele, field)
else:
self._throw(f"Unsupported element type for array field '{field.name}'")

return row_size

def _verify_row(self, row: dict):
if not isinstance(row, dict):
self._throw("The input row must be a dict object")
@@ -125,41 +182,26 @@ def _verify_row(self, row: dict):
self._throw(f"The field '{field.name}' is missed in the row")

dtype = DataType(field.dtype)
validator = TYPE_VALIDATOR[dtype.name]
if dtype in {DataType.BINARY_VECTOR, DataType.FLOAT_VECTOR}:
if isinstance(row[field.name], np.ndarray):
row[field.name] = row[field.name].tolist()
dim = field.params["dim"]
if not validator(row[field.name], dim):
self._throw(
f"Illegal vector data for vector field: '{field.name}',"
f" dim is not {dim} or type mismatch"
)

vec_size = (
len(row[field.name]) * 4
if dtype == DataType.FLOAT_VECTOR
else len(row[field.name]) / 8
)
row_size = row_size + vec_size
row_size = row_size + self._verify_vector(row[field.name], field)
elif dtype == DataType.VARCHAR:
max_len = field.params["max_length"]
if not validator(row[field.name], max_len):
self._throw(
f"Illegal varchar value for field '{field.name}',"
f" length exceeds {max_len} or type mismatch"
)

row_size = row_size + len(row[field.name])
row_size = row_size + self._verify_varchar(row[field.name], field)
elif dtype == DataType.JSON:
row[field.name] = self._try_convert_json(field.name, row[field.name])
if not validator(row[field.name]):
self._throw(f"Illegal JSON value for field '{field.name}', type mismatch")
row[field.name], size = self._verify_json(row[field.name], field)
row_size = row_size + size
elif dtype == DataType.ARRAY:
if isinstance(row[field.name], np.ndarray):
row[field.name] = row[field.name].tolist()

row_size = row_size + len(row[field.name])
row_size = row_size + self._verify_array(row[field.name], field)
else:
if isinstance(row[field.name], np.generic):
row[field.name] = row[field.name].item()

validator = TYPE_VALIDATOR[dtype.name]
if not validator(row[field.name]):
self._throw(
f"Illegal scalar value for field '{field.name}', value overflow or type mismatch"
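The refactored _verify_array estimates each row's contribution to the buffer size from the array's element type: a fixed byte width per element for numeric types (via TYPE_SIZE), or the summed string lengths for VARCHAR elements. A simplified standalone sketch of that accounting (estimate_array_size is an illustrative name; the TYPE_SIZE import path follows this repository's layout):

from pymilvus import DataType
from pymilvus.bulk_writer.constants import TYPE_SIZE

def estimate_array_size(values: list, element_type: DataType) -> int:
    if element_type.name in TYPE_SIZE:
        # fixed-width elements: bytes per element * element count
        return TYPE_SIZE[element_type.name] * len(values)
    if element_type == DataType.VARCHAR:
        # variable-length strings: sum of string lengths
        return sum(len(v) for v in values)
    raise ValueError(f"Unsupported array element type: {element_type}")

print(estimate_array_size([1, 2, 3], DataType.INT64))      # 24
print(estimate_array_size(["a", "bb"], DataType.VARCHAR))  # 3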
12 changes: 6 additions & 6 deletions pymilvus/bulk_writer/constants.py
@@ -26,11 +26,11 @@

TYPE_SIZE = {
DataType.BOOL.name: 1,
DataType.INT8.name: 8,
DataType.INT16.name: 8,
DataType.INT32.name: 8,
DataType.INT8.name: 1,
DataType.INT16.name: 2,
DataType.INT32.name: 4,
DataType.INT64.name: 8,
DataType.FLOAT.name: 8,
DataType.FLOAT.name: 4,
DataType.DOUBLE.name: 8,
}

@@ -43,10 +43,10 @@
DataType.FLOAT.name: lambda x: isinstance(x, float),
DataType.DOUBLE.name: lambda x: isinstance(x, float),
DataType.VARCHAR.name: lambda x, max_len: isinstance(x, str) and len(x) <= max_len,
DataType.JSON.name: lambda x: isinstance(x, dict) and len(x) <= 65535,
DataType.JSON.name: lambda x: isinstance(x, (dict, list)),
DataType.FLOAT_VECTOR.name: lambda x, dim: isinstance(x, list) and len(x) == dim,
DataType.BINARY_VECTOR.name: lambda x, dim: isinstance(x, list) and len(x) * 8 == dim,
DataType.ARRAY.name: lambda x: isinstance(x, list),
DataType.ARRAY.name: lambda x, cap: isinstance(x, list) and len(x) <= cap,
}

NUMPY_TYPE_CREATOR = {
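Besides correcting the TYPE_SIZE byte widths, the constants change relaxes the JSON validator to accept either a dict or a list and gives the ARRAY validator a capacity argument. A short sketch exercising the updated validators (import path assumed from this repository's layout):

from pymilvus import DataType
from pymilvus.bulk_writer.constants import TYPE_VALIDATOR

print(TYPE_VALIDATOR[DataType.JSON.name]({"k": 1}))               # True: dict accepted
print(TYPE_VALIDATOR[DataType.JSON.name]([1, 2, 3]))              # True: list now accepted too
print(TYPE_VALIDATOR[DataType.ARRAY.name]([1, 2, 3], 100))        # True: within max_capacity
print(TYPE_VALIDATOR[DataType.ARRAY.name](list(range(200)), 100)) # False: exceeds max_capacity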