Skip to content

Commit

Permalink
Add support for AEDAT4 files
Browse files Browse the repository at this point in the history
  • Loading branch information
aMarcireau committed Dec 9, 2023
1 parent 75a72a7 commit 0638f55
Show file tree
Hide file tree
Showing 11 changed files with 116 additions and 103 deletions.
88 changes: 12 additions & 76 deletions charidotella/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,20 @@
import argparse
import copy
import functools
import importlib.resources
import json
import pathlib
import re
import shutil
import sys
import tempfile
import typing
import uuid

import event_stream
import jsonschema
import toml

from . import animals as animals
from . import filters as filters
from . import formats as formats
from . import tasks as tasks
from . import utilities as utilities
from .version import __version__ as __version__
Expand Down Expand Up @@ -79,8 +77,9 @@ def main():
init_parser.add_argument(
"--glob",
"-g",
default="recordings/*.es",
help="Glob pattern used to search for Event Stream files",
nargs="*",
default=["recordings/*.es", "recordings/*.aedat4"],
help="Glob pattern used to search for Event Stream and AEDAT4 files",
)
init_parser.add_argument(
"--configuration",
Expand Down Expand Up @@ -332,14 +331,14 @@ def run_generators(configuration: dict[str, typing.Any]):
utilities.error(
f'"{configuration_path}" already exists (use --force to override it)'
)
paths = [
path.resolve()
for path in pathlib.Path(".").glob(args.glob)
if path.is_file() and path.suffix == ".es"
]
paths = []
for glob in args.glob:
for path in pathlib.Path(".").glob(glob):
if path.is_file():
paths.append(path.resolve())
paths.sort(key=lambda path: (path.stem, path.parent))
if len(paths) == 0:
utilities.error(f'no .es files match "{args.glob}"')
utilities.error(f'no files match "{args.glob}"')
if args.new_names:
names = animals.generate_names(len(paths))
else:
Expand All @@ -349,23 +348,9 @@ def run_generators(configuration: dict[str, typing.Any]):
for path in paths:
if path.stem in name_to_path:
utilities.error(
f'two files have the same name ("{name_to_path[path.stem]}" and "{path}"), rename one or do *not* use the flag --preserve-name'
f'two files have the same name ("{name_to_path[path.stem]}" and "{path}"), rename one or do use the flag --new-names'
)
name_to_path[path.stem] = path
attachments: dict[str, list[dict[str, str]]] = {}
for name, path in zip(names, paths):
for sibling in path.parent.iterdir():
if sibling != path and sibling.stem == path.stem:
if not name in attachments:
attachments[name] = []
attachments[name].append(
{
"source": str(
sibling.relative_to(configuration_path.parent)
),
"target": f"{name}{sibling.suffix}",
}
)
jobs = []
for index, (name, path) in enumerate(zip(names, paths)):
utilities.info(
Expand All @@ -374,7 +359,7 @@ def run_generators(configuration: dict[str, typing.Any]):
)
begin: typing.Optional[int] = None
end: typing.Optional[int] = None
with event_stream.Decoder(path) as decoder:
with formats.Decoder(path) as decoder:
for packet in decoder:
if begin is None:
begin = int(packet["t"][0])
Expand Down Expand Up @@ -737,14 +722,6 @@ def run_generators(configuration: dict[str, typing.Any]):
configuration_file,
encoder=Encoder(),
)
configuration_file.write(
"\n\n# attachments are copied in target directories, algonside generated files \n"
)
toml.dump(
{"attachments": attachments},
configuration_file,
encoder=Encoder(),
)
with open(
utilities.with_suffix(configuration_path, ".part"),
"r",
Expand All @@ -768,8 +745,6 @@ def run_generators(configuration: dict[str, typing.Any]):
configuration["tasks"] = {}
if not "jobs" in configuration:
configuration["jobs"] = []
if not "attachments" in configuration:
configuration["attachments"] = {}
run_generators(configuration)
jsonschema.validate(configuration, configuration_schema())
if len(configuration["filters"]) == 0:
Expand Down Expand Up @@ -808,12 +783,6 @@ def run_generators(configuration: dict[str, typing.Any]):
utilities.error(
f"parsing \"end\" ({job['end']}) in \"{job['name']}\" failed ({exception})"
)
for name, attachment in configuration["attachments"].items():
targets = [file["target"] for file in attachment]
if len(targets) != len(set(targets)):
utilities.error(
f'two or more attachments share the same target in "{name}"'
)
configuration["filters"] = {
name: {
"type": filter["type"],
Expand Down Expand Up @@ -872,37 +841,6 @@ def run_generators(configuration: dict[str, typing.Any]):
parameters["filters"] = {}
if not "tasks" in parameters:
parameters["tasks"] = {}
if not "attachments" in parameters:
parameters["attachments"] = {}
if job["name"] in configuration["attachments"]:
for attachment in configuration["attachments"][job["name"]]:
if (
not args.force
and attachment["target"] in parameters["attachments"]
and (directory / name / attachment["target"]).is_file()
):
utilities.info(
"⏭ ",
f"skip copy {pathlib.Path(configuration_path.parent) / attachment['source']}{attachment['target']}",
)
else:
utilities.info(
"🗃 ",
f"copy {pathlib.Path(configuration_path.parent) / attachment['source']}{attachment['target']}",
)
shutil.copy2(
pathlib.Path(configuration_path.parent)
/ attachment["source"],
utilities.with_suffix(
directory / name / attachment["target"], ".part"
),
)
utilities.with_suffix(
directory / name / attachment["target"], ".part"
).replace(directory / name / attachment["target"])
parameters["attachments"][attachment["target"]] = attachment[
"source"
]
if len(job["filters"]) == 1:
filter_name = job["filters"][0]
filter = configuration["filters"][filter_name]
Expand Down Expand Up @@ -1017,8 +955,6 @@ def run_generators(configuration: dict[str, typing.Any]):
configuration["tasks"] = {}
if not "jobs" in configuration:
configuration["jobs"] = []
if not "attachments" in configuration:
configuration["attachments"] = []
run_generators(configuration)
jsonschema.validate(configuration, configuration_schema())
with open(pathlib.Path(args.output), "w", encoding="utf-8") as output_file:
Expand Down
4 changes: 3 additions & 1 deletion charidotella/filters/arbiter_saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import event_stream
import numpy

from .. import formats


def consume_packets(
events_packets: list[numpy.ndarray],
Expand Down Expand Up @@ -53,7 +55,7 @@ def apply(
parameters: dict[str, typing.Any],
) -> None:
events_packets = []
with event_stream.Decoder(input) as decoder:
with formats.Decoder(input) as decoder:
with event_stream.Encoder(
output,
"dvs",
Expand Down
4 changes: 3 additions & 1 deletion charidotella/filters/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import event_stream
import numpy

from .. import formats


def apply(
input: pathlib.Path,
Expand All @@ -14,7 +16,7 @@ def apply(
end: int,
parameters: dict[str, typing.Any],
) -> None:
with event_stream.Decoder(input) as decoder:
with formats.Decoder(input) as decoder:
with event_stream.Encoder(
output,
"dvs",
Expand Down
8 changes: 6 additions & 2 deletions charidotella/filters/hot_pixels.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import numpy
import scipy.ndimage

from .. import formats


def apply(
input: pathlib.Path,
Expand All @@ -15,7 +17,8 @@ def apply(
end: int,
parameters: dict[str, typing.Any],
) -> None:
with event_stream.Decoder(input) as decoder:
count = None
with formats.Decoder(input) as decoder:
count = numpy.zeros((decoder.width, decoder.height), dtype=numpy.uint64)
for packet in decoder:
if packet["t"][-1] < begin:
Expand All @@ -33,6 +36,7 @@ def apply(
else:
events = packet
numpy.add.at(count, (events["x"], events["y"]), 1)
assert count is not None
shifted: list[numpy.ndarray] = []
for x, y in ((1, 0), (0, 1), (1, 2), (2, 1)):
kernel = numpy.zeros((3, 3))
Expand All @@ -47,7 +51,7 @@ def apply(
)
ratios = numpy.divide(count, numpy.maximum.reduce(shifted) + 1.0)
mask = ratios < parameters["ratio"]
with event_stream.Decoder(input) as decoder:
with formats.Decoder(input) as decoder:
with event_stream.Encoder(
output,
"dvs",
Expand Down
4 changes: 2 additions & 2 deletions charidotella/filters/refractory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import event_stream
import numpy

from .. import utilities
from .. import formats, utilities


def apply(
Expand All @@ -16,7 +16,7 @@ def apply(
end: int,
parameters: dict[str, typing.Any],
) -> None:
with event_stream.Decoder(input) as decoder:
with formats.Decoder(input) as decoder:
refractory = numpy.uint64(utilities.timecode(parameters["refractory"]))
threshold_t = numpy.zeros((decoder.width, decoder.height), dtype=numpy.uint64)
with event_stream.Encoder(
Expand Down
4 changes: 3 additions & 1 deletion charidotella/filters/transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import event_stream
import numpy

from .. import formats


def apply(
input: pathlib.Path,
Expand All @@ -14,7 +16,7 @@ def apply(
end: int,
parameters: dict[str, typing.Any],
) -> None:
with event_stream.Decoder(input) as decoder:
with formats.Decoder(input) as decoder:
if parameters["method"] == "flip_left_right":
width, height = decoder.width, decoder.height
method = 0
Expand Down
81 changes: 81 additions & 0 deletions charidotella/formats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import pathlib
import types
import typing

import aedat
import event_stream
import numpy

TYPE_EVENT_STREAM: int = 0
TYPE_AEDAT: int = 1


class Decoder:
    """Iterate DVS event packets from an Event Stream (.es) or AEDAT4 (.aedat4) file.

    The container format is detected from the file extension, or, when the
    extension is neither ".es" nor ".aedat4", from the file's 12-byte magic
    header. AEDAT4 packets are normalised to resemble Event Stream packets:
    timestamps are offset so the first decoded event is at t = 0, and the y
    coordinate is mirrored (y -> height - 1 - y).
    """

    def __init__(self, path: pathlib.Path):
        self.width: int
        self.height: int
        # Detect the container format, preferring the extension and falling
        # back to the magic bytes at the start of the file.
        if path.suffix == ".es":
            self.type = TYPE_EVENT_STREAM
        elif path.suffix == ".aedat4":
            self.type = TYPE_AEDAT
        else:
            with open(path, "rb") as file:
                magic = file.read(12)
            if magic == b"Event Stream":
                self.type = TYPE_EVENT_STREAM
            elif magic == b"#!AER-DAT4.0":
                self.type = TYPE_AEDAT
            else:
                raise Exception(f"unsupported file {path}")
        # Timestamp of the first AEDAT4 event, captured lazily by __next__.
        self.t0 = None
        if self.type == TYPE_EVENT_STREAM:
            self.decoder = event_stream.Decoder(path)
            assert self.decoder.type == "dvs"
            self.width = self.decoder.width
            self.height = self.decoder.height
        else:
            self.decoder = aedat.Decoder(path)  # type: ignore
            # Read the sensor geometry from the first "events" stream.
            for stream in self.decoder.id_to_stream().values():
                if stream["type"] == "events":
                    self.width = stream["width"]
                    self.height = stream["height"]
                    break
            else:
                raise Exception(f"the file {path} contains no events")

    def __iter__(self):
        """Return self so the decoder can be consumed by a for loop."""
        return self

    def __next__(self) -> numpy.ndarray:
        """Return the next non-empty events packet as a structured numpy array."""
        assert self.decoder is not None
        if self.type == TYPE_EVENT_STREAM:
            return next(self.decoder)
        # AEDAT4 files may interleave other packet types and empty events
        # packets: skip anything that carries no events.
        while True:
            packet = next(self.decoder)
            if "events" not in packet:
                continue
            events = packet["events"]
            if len(events) == 0:
                continue
            if self.t0 is None:
                self.t0 = events["t"][0]
            events["t"] -= self.t0
            events["y"] = self.height - 1 - events["y"]
            return events

    def __enter__(self) -> "Decoder":
        return self

    def __exit__(
        self,
        exception_type: typing.Optional[typing.Type[BaseException]],
        value: typing.Optional[BaseException],
        traceback: typing.Optional[types.TracebackType],
    ) -> bool:
        # Drop the underlying decoder reference either way; only the Event
        # Stream decoder is itself a context manager that must be exited.
        decoder, self.decoder = self.decoder, None
        assert decoder is not None
        if self.type == TYPE_EVENT_STREAM:
            return decoder.__exit__(exception_type, value, traceback)
        return False
6 changes: 4 additions & 2 deletions charidotella/tasks/colourtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import matplotlib.colors
import PIL.Image

from .. import formats

EXTENSION = ".png"


Expand All @@ -26,13 +28,13 @@ def run(
)
else:
time_mapping = colourtime.generate_linear_time_mapping(begin=begin, end=end)
with event_stream.Decoder(input) as decoder:
with formats.Decoder(input) as decoder:
image = colourtime.convert(
begin=begin,
end=end,
width=decoder.width,
height=decoder.height,
decoder=decoder,
decoder=decoder, # type: ignore
colormap=matplotlib.colormaps[parameters["colormap"]], # type: ignore
time_mapping=time_mapping,
alpha=parameters["alpha"],
Expand Down
2 changes: 1 addition & 1 deletion charidotella/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.1.4"
__version__ = "3.0.0"
Loading

0 comments on commit 0638f55

Please sign in to comment.