Add docstrings and sort imports
edeno committed Oct 3, 2023
1 parent 6e000d1 commit ecfd924
Showing 1 changed file with 183 additions and 22 deletions.
205 changes: 183 additions & 22 deletions src/spikegadgets_to_nwb/convert_position.py
@@ -1,6 +1,7 @@
import os
import logging
import os
import re
import subprocess
from pathlib import Path
from xml.etree import ElementTree

@@ -11,17 +12,38 @@
from pynwb.image import ImageSeries
from scipy.ndimage import label
from scipy.stats import linregress
import subprocess

from spikegadgets_to_nwb.convert_rec_header import detect_ptp_from_header

NANOSECONDS_PER_SECOND = 1e9


def parse_dtype(fieldstr: str) -> np.dtype:
"""Parses last fields parameter (<time uint32><...>) as a single string
Assumes it is formatted as <name number * type> or <name type>
Returns: np.dtype
"""
Parses the last fields parameter (<time uint32><...>) as a single string.
Assumes it is formatted as <name number * type> or <name type>. Returns a numpy dtype object.
Parameters
----------
fieldstr : str
The string to parse.
Returns
-------
np.dtype
The numpy dtype object.
Raises
------
AttributeError
If the field type is not valid.
Examples
--------
>>> fieldstr = '<time uint32><x float32><y float32><z float32>'
>>> parse_dtype(fieldstr)
dtype([('time', '<u4'), ('x', '<f4'), ('y', '<f4'), ('z', '<f4')])
"""
# Returns np.dtype from field string
sep = " ".join(
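The rest of the parser is collapsed in this view; as a rough sketch of the documented format (<name type> or <name number*type>), a re-implementation could look like the following (the helper name is hypothetical, not the committed code):

import re
import numpy as np

def parse_dtype_sketch(fieldstr: str) -> np.dtype:
    # Each field looks like <name type> or <name count*type>, e.g. <time uint32>.
    names, formats = [], []
    for name, spec in re.findall(r"<(\w+)\s+([^>]+)>", fieldstr):
        names.append(name)
        if "*" in spec:
            count, base = (s.strip() for s in spec.split("*"))
            formats.append((base, int(count)))  # fixed-length sub-array field
        else:
            formats.append(spec.strip())
    return np.dtype({"names": names, "formats": formats})

parse_dtype_sketch("<time uint32><x float32><y float32>")
# dtype([('time', '<u4'), ('x', '<f4'), ('y', '<f4')])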
@@ -54,15 +76,23 @@ def parse_dtype(fieldstr: str) -> np.dtype:


def read_trodes_datafile(filename: Path) -> dict:
"""Read trodes binary.
"""
Read trodes binary.
Parameters
----------
filename : str
filename : Path
Path to the trodes binary file.
Returns
-------
data_file : dict
dict
A dictionary containing the settings and data from the trodes binary file.
Raises
------
Exception
If the settings format is not supported.
"""
with open(filename, "rb") as file:
@@ -91,7 +121,19 @@ def read_trodes_datafile(filename: Path) -> dict:


def get_framerate(timestamps: np.ndarray) -> float:
"""Frames per second"""
"""
Calculates the framerate of a video based on the timestamps of each frame.
Parameters
----------
timestamps : np.ndarray
An array of timestamps for each frame in the video.
Returns
-------
frame_rate: float
The framerate of the video in frames per second.
"""
timestamps = np.asarray(timestamps)
return NANOSECONDS_PER_SECOND / np.median(np.diff(timestamps))
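For example, with timestamps in nanoseconds (as assumed throughout this module), a 30 Hz camera gives:

import numpy as np

timestamps_ns = np.arange(10) * (NANOSECONDS_PER_SECOND / 30)  # ~33.3 ms between frames
get_framerate(timestamps_ns)  # ≈ 30.0 frames per second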

@@ -102,19 +144,24 @@ def find_acquisition_timing_pause(
max_duration: float = 1.0,
n_search: int = 100,
) -> float:
"""Landmark timing 'gap' (0.5 s pause in video stream) parameters
"""
Find the midpoint time of a timing pause in the video stream.
Parameters
----------
timestamps : int64
min_duration : minimum duration of gap (in seconds)
max_duration : maximum duration of gap (in seconds)
n_search : search only the first `n_search` entries
timestamps : np.ndarray
An array of timestamps for each frame in the video.
min_duration : float, optional
The minimum duration of the pause in seconds, by default 0.4.
max_duration : float, optional
The maximum duration of the pause in seconds, by default 1.0.
n_search : int, optional
The number of frames to search for the pause, by default 100.
Returns
-------
pause_mid_time
Midpoint time of timing pause
pause_mid_time : float
The midpoint time of the timing pause.
"""
timestamps = np.asarray(timestamps)
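The remainder of this function is collapsed; as a rough sketch of the documented behavior (assuming nanosecond timestamps, as elsewhere in this module), it amounts to finding the first gap between min_duration and max_duration within the first n_search samples and returning its midpoint:

import numpy as np

def find_pause_midpoint_sketch(timestamps, min_duration=0.4, max_duration=1.0, n_search=100):
    # Illustrative only; the committed implementation may differ in its details.
    timestamps = np.asarray(timestamps)[:n_search]
    gaps = np.diff(timestamps) / NANOSECONDS_PER_SECOND  # gap durations in seconds
    (candidates,) = np.nonzero((gaps > min_duration) & (gaps < max_duration))
    first = candidates[0]
    return (timestamps[first] + timestamps[first + 1]) / 2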
@@ -135,7 +182,22 @@ def find_large_frame_jumps(
def find_large_frame_jumps(
frame_count: np.ndarray, min_frame_jump: int = 15
) -> np.ndarray:
"""Want to avoid regressing over large frame count skips"""
"""
Find large frame jumps in the video.
Parameters
----------
frame_count : np.ndarray
An array of frame counts for each frame in the video.
min_frame_jump : int, optional
The minimum number of frames to consider a jump as large, by default 15.
Returns
-------
np.ndarray
A boolean array indicating whether each frame has a large jump.
"""
logger = logging.getLogger("convert")
frame_count = np.asarray(frame_count)
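The collapsed body presumably flags frames where the counter advances by at least min_frame_jump; a minimal sketch of that check:

import numpy as np

def find_large_frame_jumps_sketch(frame_count, min_frame_jump=15):
    # True wherever the counter jumps by min_frame_jump or more from the previous frame.
    frame_count = np.asarray(frame_count)
    return np.insert(np.diff(frame_count) >= min_frame_jump, 0, False)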

@@ -147,14 +209,48 @@


def detect_repeat_timestamps(timestamps: np.ndarray) -> np.ndarray:
"""
Detects repeated timestamps in an array of timestamps.
Parameters
----------
timestamps : np.ndarray
Array of timestamps.
Returns
-------
np.ndarray
Boolean array indicating whether each timestamp is repeated.
"""
return np.insert(timestamps[:-1] >= timestamps[1:], 0, False)
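For instance:

import numpy as np

detect_repeat_timestamps(np.array([1, 2, 2, 3]))
# array([False, False,  True, False]): the third entry does not advance past the second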


import numpy as np


def detect_trodes_time_repeats_or_frame_jumps(
trodes_time: np.ndarray, frame_count: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
"""If a trodes time index repeats, then the Trodes clock has frozen
due to headstage disconnects."""
"""
Detects if a Trodes time index repeats, indicating that the Trodes clock has frozen
due to headstage disconnects. Also detects large frame jumps.
Parameters
----------
trodes_time : np.ndarray
Array of Trodes time indices.
frame_count : np.ndarray
Array of frame counts.
Returns
-------
tuple[np.ndarray, np.ndarray]
A tuple containing two arrays:
- non_repeat_timestamp_labels : np.ndarray
Array of labels for non-repeating timestamps.
- non_repeat_timestamp_labels_id : np.ndarray
Array of unique IDs for non-repeating timestamps.
"""
logger = logging.getLogger("convert")

trodes_time = np.asarray(trodes_time)
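The rest of this function is collapsed; conceptually, and only as a sketch, it combines the repeat and frame-jump checks above and labels contiguous runs of good samples, for example with scipy.ndimage.label, which this module already imports:

import numpy as np
from scipy.ndimage import label

def label_good_runs_sketch(trodes_time, frame_count, min_frame_jump=15):
    # Hypothetical combination of the two checks; names and details are illustrative.
    is_bad = detect_repeat_timestamps(trodes_time) | find_large_frame_jumps(frame_count, min_frame_jump)
    non_repeat_timestamp_labels, n_runs = label(~is_bad)
    non_repeat_timestamp_labels_id = np.arange(1, n_runs + 1)
    return non_repeat_timestamp_labels, non_repeat_timestamp_labels_id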
@@ -232,8 +328,37 @@ def remove_acquisition_timing_pause_non_ptp(
frame_count: np.ndarray,
camera_systime: np.ndarray,
is_valid_camera_time: np.ndarray,
pause_mid_time: np.ndarray,
pause_mid_time: float,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""
Trim DIO, frame-count, and camera timestamps that precede the acquisition timing pause (used when PTP timestamps are not available).
Parameters
----------
dio_systime : np.ndarray
Digital I/O system time.
frame_count : np.ndarray
Frame count.
camera_systime : np.ndarray
Camera system time.
is_valid_camera_time : np.ndarray
Boolean array indicating whether the camera time is valid.
pause_mid_time : float
Midpoint time of the pause.
Returns
-------
tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
A tuple containing the following arrays:
- dio_systime : np.ndarray
Digital I/O system time after removing the pause.
- frame_count : np.ndarray
Frame count after removing the pause.
- is_valid_camera_time : np.ndarray
Boolean array indicating whether the camera time is valid after removing the pause.
- camera_systime : np.ndarray
Camera system time after removing the pause.
"""
dio_systime = dio_systime[dio_systime > pause_mid_time]
frame_count = frame_count[is_valid_camera_time][camera_systime > pause_mid_time]
is_valid_camera_time[is_valid_camera_time] = camera_systime > pause_mid_time
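A toy call with made-up nanosecond values (the return order here follows the docstring above):

import numpy as np

dio, frames, valid, cam = remove_acquisition_timing_pause_non_ptp(
    dio_systime=np.array([0.1e9, 0.9e9, 1.5e9]),
    frame_count=np.arange(4),
    camera_systime=np.array([0.2e9, 0.8e9, 1.4e9, 2.0e9]),
    is_valid_camera_time=np.ones(4, dtype=bool),
    pause_mid_time=1.0e9,
)
# dio -> [1.5e9], frames -> [2, 3]: only samples after the pause midpoint remain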
Expand Down Expand Up @@ -416,6 +541,24 @@ def add_position(
video_directory: str,
convert_video: bool = False,
):
"""
Add position data to an NWBFile.
Parameters
----------
nwb_file : NWBFile
The NWBFile to add the position data to.
metadata : dict
Metadata about the experiment.
session_df : pd.DataFrame
A DataFrame containing information about the session.
rec_header : ElementTree.ElementTree
The recording header.
video_directory : str
The directory containing the video files.
convert_video : bool, optional
Whether to convert the video files to NWB format, by default False.
"""
logger = logging.getLogger("convert")

LED_POS_NAMES = [
@@ -581,8 +724,26 @@ def add_position(
nwb_file.processing["video_files"].add(video)
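A hypothetical call, assuming the NWB file, metadata dictionary, session table, and parsed .rec header are already in hand:

add_position(
    nwb_file=nwbfile,                  # pynwb.NWBFile under construction
    metadata=metadata,                 # dict of experiment metadata
    session_df=session_df,             # pd.DataFrame describing the session's files
    rec_header=rec_header,             # ElementTree parsed from the .rec header
    video_directory="/path/to/videos",
    convert_video=False,               # leave the .h264 files as-is
)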


def convert_h264_to_mp4(file: str):
"""Converts h264 file to mp4 file using ffmpeg"""
def convert_h264_to_mp4(file: str) -> str:
"""
Converts h264 file to mp4 file using ffmpeg.
Parameters
----------
file : str
The path to the input h264 file.
Returns
-------
str
The path to the output mp4 file.
Raises
------
subprocess.CalledProcessError
If the ffmpeg command fails.
"""
new_file_name = file.replace(".h264", ".mp4")
logger = logging.getLogger("convert")
if os.path.exists(new_file_name):
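The conversion itself is collapsed in this view; one plausible ffmpeg invocation for this kind of step is a stream copy into an mp4 container (the exact flags used by the committed code may differ):

import subprocess

subprocess.run(
    ["ffmpeg", "-i", file, "-c:v", "copy", new_file_name],
    check=True,  # surfaces failures as subprocess.CalledProcessError, as documented above
)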
