diff --git a/charidotella/__init__.py b/charidotella/__init__.py index 1e6d0f2..15a76de 100644 --- a/charidotella/__init__.py +++ b/charidotella/__init__.py @@ -3,22 +3,20 @@ import argparse import copy import functools -import importlib.resources import json import pathlib import re -import shutil import sys import tempfile import typing import uuid -import event_stream import jsonschema import toml from . import animals as animals from . import filters as filters +from . import formats as formats from . import tasks as tasks from . import utilities as utilities from .version import __version__ as __version__ @@ -79,8 +77,9 @@ def main(): init_parser.add_argument( "--glob", "-g", - default="recordings/*.es", - help="Glob pattern used to search for Event Stream files", + nargs="*", + default=["recordings/*.es", "recordings/*.aedat4"], + help="Glob pattern used to search for Event Stream and AEDAT4 files", ) init_parser.add_argument( "--configuration", @@ -332,14 +331,14 @@ def run_generators(configuration: dict[str, typing.Any]): utilities.error( f'"{configuration_path}" already exists (use --force to override it)' ) - paths = [ - path.resolve() - for path in pathlib.Path(".").glob(args.glob) - if path.is_file() and path.suffix == ".es" - ] + paths = [] + for glob in args.glob: + for path in pathlib.Path(".").glob(glob): + if path.is_file(): + paths.append(path.resolve()) paths.sort(key=lambda path: (path.stem, path.parent)) if len(paths) == 0: - utilities.error(f'no .es files match "{args.glob}"') + utilities.error(f'no files match "{args.glob}"') if args.new_names: names = animals.generate_names(len(paths)) else: @@ -349,23 +348,9 @@ def run_generators(configuration: dict[str, typing.Any]): for path in paths: if path.stem in name_to_path: utilities.error( - f'two files have the same name ("{name_to_path[path.stem]}" and "{path}"), rename one or do *not* use the flag --preserve-name' + f'two files have the same name ("{name_to_path[path.stem]}" and "{path}"), rename one or do use the flag --new-names' ) name_to_path[path.stem] = path - attachments: dict[str, list[dict[str, str]]] = {} - for name, path in zip(names, paths): - for sibling in path.parent.iterdir(): - if sibling != path and sibling.stem == path.stem: - if not name in attachments: - attachments[name] = [] - attachments[name].append( - { - "source": str( - sibling.relative_to(configuration_path.parent) - ), - "target": f"{name}{sibling.suffix}", - } - ) jobs = [] for index, (name, path) in enumerate(zip(names, paths)): utilities.info( @@ -374,7 +359,7 @@ def run_generators(configuration: dict[str, typing.Any]): ) begin: typing.Optional[int] = None end: typing.Optional[int] = None - with event_stream.Decoder(path) as decoder: + with formats.Decoder(path) as decoder: for packet in decoder: if begin is None: begin = int(packet["t"][0]) @@ -737,14 +722,6 @@ def run_generators(configuration: dict[str, typing.Any]): configuration_file, encoder=Encoder(), ) - configuration_file.write( - "\n\n# attachments are copied in target directories, algonside generated files \n" - ) - toml.dump( - {"attachments": attachments}, - configuration_file, - encoder=Encoder(), - ) with open( utilities.with_suffix(configuration_path, ".part"), "r", @@ -768,8 +745,6 @@ def run_generators(configuration: dict[str, typing.Any]): configuration["tasks"] = {} if not "jobs" in configuration: configuration["jobs"] = [] - if not "attachments" in configuration: - configuration["attachments"] = {} run_generators(configuration) jsonschema.validate(configuration, configuration_schema()) if len(configuration["filters"]) == 0: @@ -808,12 +783,6 @@ def run_generators(configuration: dict[str, typing.Any]): utilities.error( f"parsing \"end\" ({job['end']}) in \"{job['name']}\" failed ({exception})" ) - for name, attachment in configuration["attachments"].items(): - targets = [file["target"] for file in attachment] - if len(targets) != len(set(targets)): - utilities.error( - f'two or more attachments share the same target in "{name}"' - ) configuration["filters"] = { name: { "type": filter["type"], @@ -872,37 +841,6 @@ def run_generators(configuration: dict[str, typing.Any]): parameters["filters"] = {} if not "tasks" in parameters: parameters["tasks"] = {} - if not "attachments" in parameters: - parameters["attachments"] = {} - if job["name"] in configuration["attachments"]: - for attachment in configuration["attachments"][job["name"]]: - if ( - not args.force - and attachment["target"] in parameters["attachments"] - and (directory / name / attachment["target"]).is_file() - ): - utilities.info( - "⏭ ", - f"skip copy {pathlib.Path(configuration_path.parent) / attachment['source']} → {attachment['target']}", - ) - else: - utilities.info( - "🗃 ", - f"copy {pathlib.Path(configuration_path.parent) / attachment['source']} → {attachment['target']}", - ) - shutil.copy2( - pathlib.Path(configuration_path.parent) - / attachment["source"], - utilities.with_suffix( - directory / name / attachment["target"], ".part" - ), - ) - utilities.with_suffix( - directory / name / attachment["target"], ".part" - ).replace(directory / name / attachment["target"]) - parameters["attachments"][attachment["target"]] = attachment[ - "source" - ] if len(job["filters"]) == 1: filter_name = job["filters"][0] filter = configuration["filters"][filter_name] @@ -1017,8 +955,6 @@ def run_generators(configuration: dict[str, typing.Any]): configuration["tasks"] = {} if not "jobs" in configuration: configuration["jobs"] = [] - if not "attachments" in configuration: - configuration["attachments"] = [] run_generators(configuration) jsonschema.validate(configuration, configuration_schema()) with open(pathlib.Path(args.output), "w", encoding="utf-8") as output_file: diff --git a/charidotella/filters/arbiter_saturation.py b/charidotella/filters/arbiter_saturation.py index 10461ff..ebdea27 100644 --- a/charidotella/filters/arbiter_saturation.py +++ b/charidotella/filters/arbiter_saturation.py @@ -6,6 +6,8 @@ import event_stream import numpy +from .. import formats + def consume_packets( events_packets: list[numpy.ndarray], @@ -53,7 +55,7 @@ def apply( parameters: dict[str, typing.Any], ) -> None: events_packets = [] - with event_stream.Decoder(input) as decoder: + with formats.Decoder(input) as decoder: with event_stream.Encoder( output, "dvs", diff --git a/charidotella/filters/default.py b/charidotella/filters/default.py index 4ac2b0e..63815ed 100644 --- a/charidotella/filters/default.py +++ b/charidotella/filters/default.py @@ -6,6 +6,8 @@ import event_stream import numpy +from .. import formats + def apply( input: pathlib.Path, @@ -14,7 +16,7 @@ def apply( end: int, parameters: dict[str, typing.Any], ) -> None: - with event_stream.Decoder(input) as decoder: + with formats.Decoder(input) as decoder: with event_stream.Encoder( output, "dvs", diff --git a/charidotella/filters/hot_pixels.py b/charidotella/filters/hot_pixels.py index 3b6325d..a00f1ff 100644 --- a/charidotella/filters/hot_pixels.py +++ b/charidotella/filters/hot_pixels.py @@ -7,6 +7,8 @@ import numpy import scipy.ndimage +from .. import formats + def apply( input: pathlib.Path, @@ -15,7 +17,8 @@ def apply( end: int, parameters: dict[str, typing.Any], ) -> None: - with event_stream.Decoder(input) as decoder: + count = None + with formats.Decoder(input) as decoder: count = numpy.zeros((decoder.width, decoder.height), dtype=numpy.uint64) for packet in decoder: if packet["t"][-1] < begin: @@ -33,6 +36,7 @@ def apply( else: events = packet numpy.add.at(count, (events["x"], events["y"]), 1) + assert count is not None shifted: list[numpy.ndarray] = [] for x, y in ((1, 0), (0, 1), (1, 2), (2, 1)): kernel = numpy.zeros((3, 3)) @@ -47,7 +51,7 @@ def apply( ) ratios = numpy.divide(count, numpy.maximum.reduce(shifted) + 1.0) mask = ratios < parameters["ratio"] - with event_stream.Decoder(input) as decoder: + with formats.Decoder(input) as decoder: with event_stream.Encoder( output, "dvs", diff --git a/charidotella/filters/refractory.py b/charidotella/filters/refractory.py index f57cfb5..0e7ffef 100644 --- a/charidotella/filters/refractory.py +++ b/charidotella/filters/refractory.py @@ -6,7 +6,7 @@ import event_stream import numpy -from .. import utilities +from .. import formats, utilities def apply( @@ -16,7 +16,7 @@ def apply( end: int, parameters: dict[str, typing.Any], ) -> None: - with event_stream.Decoder(input) as decoder: + with formats.Decoder(input) as decoder: refractory = numpy.uint64(utilities.timecode(parameters["refractory"])) threshold_t = numpy.zeros((decoder.width, decoder.height), dtype=numpy.uint64) with event_stream.Encoder( diff --git a/charidotella/filters/transpose.py b/charidotella/filters/transpose.py index 089536c..cb5bdb5 100644 --- a/charidotella/filters/transpose.py +++ b/charidotella/filters/transpose.py @@ -6,6 +6,8 @@ import event_stream import numpy +from .. import formats + def apply( input: pathlib.Path, @@ -14,7 +16,7 @@ def apply( end: int, parameters: dict[str, typing.Any], ) -> None: - with event_stream.Decoder(input) as decoder: + with formats.Decoder(input) as decoder: if parameters["method"] == "flip_left_right": width, height = decoder.width, decoder.height method = 0 diff --git a/charidotella/formats.py b/charidotella/formats.py new file mode 100644 index 0000000..0db35db --- /dev/null +++ b/charidotella/formats.py @@ -0,0 +1,81 @@ +import pathlib +import types +import typing + +import aedat +import event_stream +import numpy + +TYPE_EVENT_STREAM: int = 0 +TYPE_AEDAT: int = 1 + + +class Decoder: + def __init__(self, path: pathlib.Path): + self.width: int + self.height: int + if path.suffix == ".es": + self.type = TYPE_EVENT_STREAM + elif path.suffix == ".aedat4": + self.type = TYPE_AEDAT + else: + with open(path, "rb") as file: + magic = file.read(12) + if magic == b"Event Stream": + self.type = TYPE_EVENT_STREAM + elif magic == b"#!AER-DAT4.0": + self.type = TYPE_AEDAT + else: + raise Exception(f"unsupported file {path}") + self.t0 = None + if self.type == TYPE_EVENT_STREAM: + self.decoder = event_stream.Decoder(path) + assert self.decoder.type == "dvs" + self.width = self.decoder.width + self.height = self.decoder.height + else: + self.decoder = aedat.Decoder(path) # type: ignore + found = False + for stream in self.decoder.id_to_stream().values(): + if stream["type"] == "events": + self.width = stream["width"] + self.height = stream["height"] + found = True + break + if not found: + raise Exception(f"the file {path} contains no events") + + def __iter__(self): + return self + + def __next__(self) -> numpy.ndarray: + assert self.decoder is not None + if self.type == TYPE_EVENT_STREAM: + return self.decoder.__next__() + while True: + packet = self.decoder.__next__() + if "events" in packet: + events = packet["events"] + if len(events) > 0: + if self.t0 is None: + self.t0 = events["t"][0] + events["t"] -= self.t0 + events["y"] = self.height - 1 - events["y"] + return events + + def __enter__(self) -> "Decoder": + return self + + def __exit__( + self, + exception_type: typing.Optional[typing.Type[BaseException]], + value: typing.Optional[BaseException], + traceback: typing.Optional[types.TracebackType], + ) -> bool: + assert self.decoder is not None + if self.type == TYPE_EVENT_STREAM: + result = self.decoder.__exit__(exception_type, value, traceback) + else: + result = False + self.decoder = None + return result diff --git a/charidotella/tasks/colourtime.py b/charidotella/tasks/colourtime.py index ac86c87..582ebd2 100644 --- a/charidotella/tasks/colourtime.py +++ b/charidotella/tasks/colourtime.py @@ -9,6 +9,8 @@ import matplotlib.colors import PIL.Image +from .. import formats + EXTENSION = ".png" @@ -26,13 +28,13 @@ def run( ) else: time_mapping = colourtime.generate_linear_time_mapping(begin=begin, end=end) - with event_stream.Decoder(input) as decoder: + with formats.Decoder(input) as decoder: image = colourtime.convert( begin=begin, end=end, width=decoder.width, height=decoder.height, - decoder=decoder, + decoder=decoder, # type: ignore colormap=matplotlib.colormaps[parameters["colormap"]], # type: ignore time_mapping=time_mapping, alpha=parameters["alpha"], diff --git a/charidotella/version.py b/charidotella/version.py index df4be5e..528787c 100644 --- a/charidotella/version.py +++ b/charidotella/version.py @@ -1 +1 @@ -__version__ = "2.1.4" +__version__ = "3.0.0" diff --git a/configuration-schema.json b/configuration-schema.json index 65e25ec..1794e02 100644 --- a/configuration-schema.json +++ b/configuration-schema.json @@ -1,23 +1,6 @@ { "additionalProperties": false, "properties": { - "attachments": { - "additionalProperties": { - "items": { - "properties": { - "source": { - "type": "string" - }, - "target": { - "type": "string" - } - }, - "type": "object" - }, - "type": "array" - }, - "type": "object" - }, "directory": { "type": "string" }, diff --git a/setup.py b/setup.py index afb4730..41b12c0 100644 --- a/setup.py +++ b/setup.py @@ -112,6 +112,7 @@ include_package_data=True, package_data={"": ["charidotella/assets/*"]}, install_requires=[ + "aedat", "colourtime", "coolname", "event_stream",