diff --git a/src/neptune/internal/backends/hosted_neptune_backend.py b/src/neptune/internal/backends/hosted_neptune_backend.py
index ac4f7a4a1..6a6cfd740 100644
--- a/src/neptune/internal/backends/hosted_neptune_backend.py
+++ b/src/neptune/internal/backends/hosted_neptune_backend.py
@@ -457,7 +457,7 @@ def execute_operations(
         dropped_count = operations_batch.dropped_operations_count
 
         operations_preprocessor = OperationsPreprocessor()
-        operations_preprocessor.process_batch(operations_batch.operations)
+        operations_preprocessor.process(operations_batch.operations)
 
         preprocessed_operations = operations_preprocessor.get_operations()
         errors.extend(preprocessed_operations.errors)
diff --git a/src/neptune/internal/backends/operations_preprocessor.py b/src/neptune/internal/backends/operations_preprocessor.py
index 7e6816a60..b71abc768 100644
--- a/src/neptune/internal/backends/operations_preprocessor.py
+++ b/src/neptune/internal/backends/operations_preprocessor.py
@@ -21,7 +21,6 @@
 from typing import (
     Callable,
     List,
-    Type,
     TypeVar,
 )
 
@@ -45,7 +44,6 @@
     DeleteFiles,
     LogFloats,
     LogImages,
-    LogOperation,
     LogStrings,
     Operation,
     RemoveStrings,
@@ -72,41 +70,24 @@ class AccumulatedOperations:
 
     errors: List[MetadataInconsistency] = dataclasses.field(default_factory=list)
 
-    def all_operations(self) -> List[Operation]:
-        return self.upload_operations + self.artifact_operations + self.other_operations
-
 
 class OperationsPreprocessor:
     def __init__(self):
         self._accumulators: typing.Dict[str, "_OperationsAccumulator"] = dict()
         self.processed_ops_count = 0
-        self.final_ops_count = 0
-        self.final_append_count = 0
-
-    def process(self, operation: Operation) -> bool:
-        """Adds a single operation to its processed list.
-        Returns `False` iff the new operation can't be in queue until one of already enqueued operations gets
-        synchronized with server first.
- """ - try: - self._process_op(operation) - self.processed_ops_count += 1 - return True - except RequiresPreviousCompleted: - return False - - def process_batch(self, operations: List[Operation]) -> None: + + def process(self, operations: List[Operation]): for op in operations: - if not self.process(op): + try: + self._process_op(op) + self.processed_ops_count += 1 + except RequiresPreviousCompleted: return def _process_op(self, op: Operation) -> "_OperationsAccumulator": path_str = path_to_str(op.path) target_acc = self._accumulators.setdefault(path_str, _OperationsAccumulator(op.path)) - old_ops_count, old_append_count = target_acc.get_op_count(), target_acc.get_append_count() target_acc.visit(op) - self.final_ops_count += target_acc.get_op_count() - old_ops_count - self.final_append_count += target_acc.get_append_count() - old_append_count return target_acc @staticmethod @@ -162,8 +143,6 @@ def __init__(self, path: List[str]): self._modify_ops = [] self._config_ops = [] self._errors = [] - self._ops_count = 0 - self._append_count = 0 def get_operations(self) -> List[Operation]: return self._delete_ops + self._modify_ops + self._config_ops @@ -171,12 +150,6 @@ def get_operations(self) -> List[Operation]: def get_errors(self) -> List[MetadataInconsistency]: return self._errors - def get_op_count(self) -> int: - return self._ops_count - - def get_append_count(self) -> int: - return self._append_count - def _check_prerequisites(self, op: Operation): if (OperationsPreprocessor.is_file_op(op) or OperationsPreprocessor.is_artifact_op(op)) and len( self._delete_ops @@ -206,9 +179,7 @@ def _process_modify_op( else: self._check_prerequisites(op) self._type = expected_type - old_op_count = len(self._modify_ops) self._modify_ops = modifier(self._modify_ops, op) - self._ops_count += len(self._modify_ops) - old_op_count def _process_config_op(self, expected_type: _DataType, op: Operation) -> None: @@ -228,9 +199,7 @@ def _process_config_op(self, expected_type: _DataType, op: Operation) -> None: else: self._check_prerequisites(op) self._type = expected_type - old_op_count = len(self._config_ops) self._config_ops = [op] - self._ops_count += len(self._config_ops) - old_op_count def visit_assign_float(self, op: AssignFloat) -> None: self._process_modify_op(_DataType.FLOAT, op, self._assign_modifier()) @@ -326,8 +295,6 @@ def visit_delete_attribute(self, op: DeleteAttribute) -> None: self._modify_ops = [] self._config_ops = [] self._type = None - self._ops_count = len(self._delete_ops) - self._append_count = 0 else: # This case is tricky. There was no delete operation, but some modifications was performed. # We do not know if this attribute exists on server side and we do not want a delete op to fail. @@ -336,8 +303,6 @@ def visit_delete_attribute(self, op: DeleteAttribute) -> None: self._modify_ops = [] self._config_ops = [] self._type = None - self._ops_count = len(self._delete_ops) - self._append_count = 0 else: if self._delete_ops: # Do nothing if there already is a delete operation @@ -347,7 +312,6 @@ def visit_delete_attribute(self, op: DeleteAttribute) -> None: # If value has not been set locally yet and no delete operation was performed, # simply perform single delete operation. 
                 self._delete_ops.append(op)
-                self._ops_count = len(self._delete_ops)
 
     @staticmethod
     def _artifact_log_modifier(
@@ -376,30 +340,23 @@ def visit_copy_attribute(self, op: CopyAttribute) -> None:
     def _assign_modifier():
         return lambda ops, new_op: [new_op]
 
-    def _clear_modifier(self):
-        def modifier(ops: List[Operation], new_op: Operation):
-            for op in ops:
-                if isinstance(op, LogOperation):
-                    self._append_count -= op.value_count()
-            return [new_op]
-
-        return modifier
+    @staticmethod
+    def _clear_modifier():
+        return lambda ops, new_op: [new_op]
 
-    def _log_modifier(self, log_op_class: Type[LogOperation], clear_op_class: type, log_combine: Callable[[T, T], T]):
-        def modifier(ops: List[Operation], new_op: Operation):
+    @staticmethod
+    def _log_modifier(log_op_class: type, clear_op_class: type, log_combine: Callable[[T, T], T]):
+        def modifier(ops, new_op):
             if len(ops) == 0:
-                res = [new_op]
+                return [new_op]
             elif len(ops) == 1 and isinstance(ops[0], log_op_class):
-                res = [log_combine(ops[0], new_op)]
+                return [log_combine(ops[0], new_op)]
             elif len(ops) == 1 and isinstance(ops[0], clear_op_class):
-                res = [ops[0], new_op]
+                return [ops[0], new_op]
             elif len(ops) == 2:
-                res = [ops[0], log_combine(ops[1], new_op)]
+                return [ops[0], log_combine(ops[1], new_op)]
             else:
                 raise InternalClientError("Preprocessing operations failed: len(ops) == {}".format(len(ops)))
-            if isinstance(new_op, log_op_class):  # Check just so that static typing doesn't complain
-                self._append_count += new_op.value_count()
-            return res
 
         return modifier
 
diff --git a/src/neptune/internal/operation.py b/src/neptune/internal/operation.py
index 02124c73a..b0080617f 100644
--- a/src/neptune/internal/operation.py
+++ b/src/neptune/internal/operation.py
@@ -292,9 +292,7 @@ def from_dict(data: dict) -> "UploadFileSet":
 
 
 class LogOperation(Operation, abc.ABC):
-    @abc.abstractmethod
-    def value_count(self) -> int:
-        pass
+    pass
 
 
 @dataclass
@@ -334,9 +332,6 @@ def from_dict(data: dict) -> "LogFloats":
             [LogFloats.ValueType.from_dict(value) for value in data["values"]],
         )
 
-    def value_count(self) -> int:
-        return len(self.values)
-
 
 @dataclass
 class LogStrings(LogOperation):
@@ -360,9 +355,6 @@ def from_dict(data: dict) -> "LogStrings":
             [LogStrings.ValueType.from_dict(value) for value in data["values"]],
         )
 
-    def value_count(self) -> int:
-        return len(self.values)
-
 
 @dataclass
 class ImageValue:
@@ -408,9 +400,6 @@ def from_dict(data: dict) -> "LogImages":
             [LogImages.ValueType.from_dict(value, ImageValue.deserializer) for value in data["values"]],
         )
 
-    def value_count(self) -> int:
-        return len(self.values)
-
 
 @dataclass
 class ClearFloatLog(Operation):
diff --git a/src/neptune/internal/operation_processors/async_operation_processor.py b/src/neptune/internal/operation_processors/async_operation_processor.py
index 3b6f9a9da..865102e76 100644
--- a/src/neptune/internal/operation_processors/async_operation_processor.py
+++ b/src/neptune/internal/operation_processors/async_operation_processor.py
@@ -26,32 +26,23 @@
 )
 from typing import (
     Callable,
-    ClassVar,
     List,
     Optional,
-    Tuple,
 )
 
 from neptune.constants import ASYNC_DIRECTORY
 from neptune.envs import NEPTUNE_SYNC_AFTER_STOP_TIMEOUT
 from neptune.exceptions import NeptuneSynchronizationAlreadyStoppedException
 from neptune.internal.backends.neptune_backend import NeptuneBackend
-from neptune.internal.backends.operations_preprocessor import OperationsPreprocessor
 from neptune.internal.container_type import ContainerType
-from neptune.internal.disk_queue import (
-    DiskQueue,
-    QueueElement,
-)
+from neptune.internal.disk_queue import DiskQueue
 from neptune.internal.id_formats import UniqueId
 from neptune.internal.init.parameters import (
     ASYNC_LAG_THRESHOLD,
     ASYNC_NO_PROGRESS_THRESHOLD,
     DEFAULT_STOP_TIMEOUT,
 )
-from neptune.internal.operation import (
-    CopyAttribute,
-    Operation,
-)
+from neptune.internal.operation import Operation
 from neptune.internal.operation_processors.operation_processor import OperationProcessor
 from neptune.internal.operation_processors.operation_storage import (
     OperationStorage,
@@ -264,10 +255,6 @@ def close(self):
         self._queue.close()
 
     class ConsumerThread(Daemon):
-        MAX_OPERATIONS_IN_BATCH: ClassVar[int] = 1000
-        MAX_APPENDS_IN_BATCH: ClassVar[int] = 100000
-        MAX_BATCH_SIZE_BYTES: ClassVar[int] = 100 * 1024 * 1024
-
        def __init__(
             self,
             processor: "AsyncOperationProcessor",
@@ -279,7 +266,6 @@ def __init__(
             self._batch_size = batch_size
             self._last_flush = 0
             self._no_progress_exceeded = False
-            self._last_disk_record: Optional[QueueElement[Operation]] = None
 
         def run(self):
             try:
@@ -296,42 +282,10 @@ def work(self) -> None:
             self._processor._queue.flush()
 
             while True:
-                batch = self.collect_batch()
+                batch = self._processor._queue.get_batch(self._batch_size)
                 if not batch:
                     return
-                operations, version = batch
-                self.process_batch(operations, version)
-
-        def collect_batch(self) -> Optional[Tuple[List[Operation], int]]:
-            preprocessor = OperationsPreprocessor()
-            version: Optional[int] = None
-            total_bytes = 0
-            copy_ops: List[CopyAttribute] = []
-            while (
-                preprocessor.final_ops_count < self.MAX_OPERATIONS_IN_BATCH
-                and preprocessor.final_append_count < self.MAX_APPENDS_IN_BATCH
-                and total_bytes < self.MAX_BATCH_SIZE_BYTES
-            ):
-                record: Optional[QueueElement[Operation]] = self._last_disk_record or self._processor._queue.get()
-                self._last_disk_record = None
-                if not record:
-                    break
-                if isinstance(record.obj, CopyAttribute):
-                    # CopyAttribute can be only at the start of a batch.
-                    if copy_ops or preprocessor.final_ops_count:
-                        self._last_disk_record = record
-                        break
-                    else:
-                        version = record.ver
-                        copy_ops.append(record.obj)
-                        total_bytes += record.size
-                elif preprocessor.process(record.obj):
-                    version = record.ver
-                    total_bytes += record.size
-                else:
-                    self._last_disk_record = record
-                    break
-            return (copy_ops + preprocessor.get_operations().all_operations(), version) if version is not None else None
+                self.process_batch([element.obj for element in batch], batch[-1].ver)
 
         def _check_no_progress(self):
             if not self._no_progress_exceeded:
diff --git a/tests/unit/neptune/new/internal/backends/test_operations_preprocessor.py b/tests/unit/neptune/new/internal/backends/test_operations_preprocessor.py
index 4d0ab59be..042da27ae 100644
--- a/tests/unit/neptune/new/internal/backends/test_operations_preprocessor.py
+++ b/tests/unit/neptune/new/internal/backends/test_operations_preprocessor.py
@@ -61,7 +61,7 @@ def test_delete_attribute(self):
         ]
 
         # when
-        processor.process_batch(operations)
+        processor.process(operations)
 
         # then
         result = processor.get_operations()
@@ -101,7 +101,7 @@ def test_assign(self):
         ]
 
         # when
-        processor.process_batch(operations)
+        processor.process(operations)
 
         # then
         result = processor.get_operations()
@@ -160,7 +160,7 @@ def test_series(self):
         ]
 
         # when
-        processor.process_batch(operations)
+        processor.process(operations)
 
         # then
         result = processor.get_operations()
@@ -231,7 +231,7 @@ def test_sets(self):
         ]
 
         # when
-        processor.process_batch(operations)
+        processor.process(operations)
 
         # then
         result = processor.get_operations()
@@ -283,7 +283,7 @@ def test_file_set(self):
         ]
 
         # when
-        processor.process_batch(operations)
+        processor.process(operations)
 
         # then
         result = processor.get_operations()
@@ -328,25 +328,25 @@ def test_file_ops_delete(self):
         ]
 
         # when
-        processor.process_batch(operations)
+        processor.process(operations)
 
         # then: there's a cutoff after DeleteAttribute(["a"])
         result = processor.get_operations()
         self.assertEqual(
+            result.upload_operations,
             [
                 UploadFileSet(["b"], ["abc", "defgh"], reset=True),
                 UploadFileSet(["c"], ["abc", "defgh"], reset=True),
                 UploadFileSet(["c"], ["qqq"], reset=False),
                 UploadFileSet(["d"], ["hhh", "gij"], reset=False),
             ],
-            result.upload_operations,
         )
         self.assertEqual(result.artifact_operations, [])
         self.assertEqual(
+            result.other_operations,
             [
                 DeleteAttribute(["a"]),
             ],
-            result.other_operations,
         )
         self.assertEqual(result.errors, [])
         self.assertEqual(processor.processed_ops_count, 6)
@@ -375,7 +375,7 @@ def test_artifacts(self):
         ]
 
         # when
-        processor.process_batch(operations)
+        processor.process(operations)
 
         # then: there's a cutoff before second TrackFilesToArtifact(["a"]) due to DeleteAttribute(["a"])
         result = processor.get_operations()
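A quick sketch of the control flow this patch restores, for reference. `ConsumerThread.work()` now sends whatever `DiskQueue.get_batch(self._batch_size)` returns, and `OperationsPreprocessor.process()` consumes the whole list but stops at the first operation that has to wait for earlier ones to reach the server (signalled by `RequiresPreviousCompleted`). The snippet below is a minimal, self-contained illustration of that stop-at-conflict contract; `process`, `_process_op`, `processed_ops_count`, and `RequiresPreviousCompleted` mirror names from the diff, while `ToyPreprocessor` and its conflict rule are invented for the example.

```python
from typing import List


class RequiresPreviousCompleted(Exception):
    """Raised when an operation must wait until earlier enqueued ops are synced."""


class ToyPreprocessor:
    """Illustrative stand-in for OperationsPreprocessor, not the real class."""

    def __init__(self) -> None:
        self.accepted: List[str] = []
        self.processed_ops_count = 0

    def _process_op(self, op: str) -> None:
        # Toy conflict rule standing in for the real prerequisite checks
        # (file/artifact ops queued behind a pending delete must wait).
        if op.startswith("upload") and any(p.startswith("delete") for p in self.accepted):
            raise RequiresPreviousCompleted()
        self.accepted.append(op)

    def process(self, operations: List[str]) -> None:
        # Same shape as the restored `process`: handle ops in order, stop at
        # the first one that needs earlier operations to finish server-side.
        for op in operations:
            try:
                self._process_op(op)
                self.processed_ops_count += 1
            except RequiresPreviousCompleted:
                return


pre = ToyPreprocessor()
pre.process(["log a", "delete b", "upload b", "log c"])
# "upload b" ends the batch; "log c" is left for a later request.
assert pre.processed_ops_count == 2
```

Operations past the cutoff are not dropped: the `processed_ops_count` bookkeeping lets the caller acknowledge only the prefix that was actually handled and resend the remainder in a subsequent request.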