Skip to content

Commit

Permalink
Merge branch 'pi-classifier' of https://github.com/TheCacophonyProjec…
Browse files Browse the repository at this point in the history
…t/classifier-pipeline into pi-classifier
  • Loading branch information
gferraro committed Sep 30, 2024
2 parents d6daa4b + b4d53a7 commit def00bd
Show file tree
Hide file tree
Showing 11 changed files with 207 additions and 72 deletions.
8 changes: 7 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,10 @@ deploy:
script: bash ./pypi_push.sh
on:
branch: pi-classifier
skip_cleanup: 'true'
skip_cleanup: 'true'
- provider: script
script: bash ./pypi_push.sh
on:
branch: add-event
skip_cleanup: 'true'

2 changes: 1 addition & 1 deletion pirequirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ tables==3.8.0
h5py==3.8.0
pyyaml==6.0
pillow==10.0.1
attrs==19.2.0
attrs==24.2.0
filelock==3.0.12
Astral==1.10.1
timezonefinder==4.1.0
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ build-backend = "hatchling.build"

[project]
name = "classifier_pipeline"
version = "0.0.16"
version = "0.0.17"
authors = [
{ name="Giampaolo Ferraro", email="[email protected]" },
]
Expand All @@ -35,7 +35,7 @@ dependencies = [
"h5py==3.8.0",
"pyyaml==6.0",
"pillow==10.0.1",
"attrs==19.2.0",
"attrs==24.2.0",
"filelock==3.0.12",
"Astral==1.10.1",
"timezonefinder==4.1.0",
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ tables~=3.8.0
h5py~=3.9.0
pyyaml>=4.2b1
pillow~=10.0.1
attrs~=19.1
attrs~=24.2.0
filelock~=3.0.12
Astral~=1.10.1
timezonefinder~=6.2.0
Expand Down
92 changes: 62 additions & 30 deletions src/classify/trackprediction.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import attr
import logging
import numpy as np
import time
from attrs import define, field

# uniform prior stats start with uniform distribution. This is the safest bet, but means that
# it takes a while to make predictions. When off the first prediction is used instead causing
Expand Down Expand Up @@ -55,6 +57,29 @@ def classify_time(self):
return np.sum(classify_time)


@define
class Prediction:
prediction = field()
smoothed_prediction = field()
frames = field()
predicted_at_frame = field()
mass = field()
predicted_time = field(init=False)

def __attrs_post_init__(self):
self.predicted_time = time.time()

def get_metadata(self):
meta = attr.asdict(self)
meta["smoothed_prediction"] = np.uint32(np.round(self.smoothed_prediction))
meta["prediction"] = np.uint8(np.round(100 * self.prediction))
return meta

def clarity(self):
best = np.argsort(self.prediction)
return self.prediction[best[-1]] - self.prediction[best[-2]]


class TrackPrediction:
"""
Class to hold the information about the predicted class of a track.
Expand All @@ -69,9 +94,7 @@ def __init__(self, track_id, labels, keep_all=True, start_frame=None):
fp_index = None
self.track_id = track_id
self.predictions = []
self.prediction_frames = []
self.fp_index = fp_index
self.smoothed_predictions = []
self.class_best_score = None
self.start_frame = start_frame

Expand All @@ -87,15 +110,20 @@ def classified_clip(
self, predictions, smoothed_predictions, prediction_frames, top_score=None
):
self.num_frames_classified = len(predictions)
if smoothed_predictions is None:
self.smoothed_predictions = predictions
else:
self.smoothed_predictions = smoothed_predictions
self.predictions = predictions
self.prediction_frames = prediction_frames
for prediction, smoothed_prediction, frames in zip(
predictions, smoothed_predictions, prediction_frames
):
prediction = Prediction(
prediction,
smoothed_prediction,
frames,
np.amax(frames),
None,
)
self.predictions.append(prediction)

if self.num_frames_classified > 0:
self.class_best_score = np.sum(self.smoothed_predictions, axis=0)
self.class_best_score = np.sum(smoothed_predictions, axis=0)
# normalize so it sums to 1
if top_score is None:
self.class_best_score = self.class_best_score / np.sum(
Expand All @@ -112,36 +140,45 @@ def normalize_score(self):
self.class_best_score
)

def classified_frames(self, frame_numbers, prediction, mass):
def classified_frames(self, frame_numbers, predictions, mass):
self.num_frames_classified += len(frame_numbers)
smoothed_prediction = prediction**2 * mass
self.last_frame_classified = np.max(frame_numbers)
smoothed_prediction = predictions**2 * mass
prediction = Prediction(
predictions,
smoothed_prediction,
frame_numbers,
self.last_frame_classified,
mass,
)
if self.keep_all:
self.prediction_frames.append(frame_numbers)
self.predictions.append(prediction)
self.smoothed_predictions.append(smoothed_prediction)

else:
self.prediction_frames = [frame_numbers]
self.predictions = [prediction]
self.smoothed_predictions = [smoothed_prediction]

if self.class_best_score is None:
self.class_best_score = smoothed_prediction.copy()
else:
self.class_best_score += smoothed_prediction

def classified_frame(self, frame_number, prediction, mass):
def classified_frame(self, frame_number, predictions, mass):
self.prediction_frames.append([frame_number])
self.last_frame_classified = frame_number
self.num_frames_classified += 1
self.masses.append(mass)
smoothed_prediction = prediction * prediction * mass

prediction = Prediction(
predictions,
smoothed_prediction,
frame_number,
self.last_frame_classified,
mass,
)
if self.keep_all:
self.predictions.append(prediction)
self.smoothed_predictions.append(smoothed_prediction)

else:
self.predictions = [prediction]
self.smoothed_predictions = [smoothed_prediction]

if self.class_best_score is None:
self.class_best_score = smoothed_prediction
Expand Down Expand Up @@ -257,9 +294,7 @@ def max_score(self):
return float(np.amax(self.class_best_score))

def clarity_at(self, frame):
pred = self.predictions[frame]
best = np.argsort(pred)
return pred[best[-1]] - pred[best[-2]]
return self.predictions[frame].clarity

@property
def clarity(self):
Expand Down Expand Up @@ -359,13 +394,10 @@ def get_metadata(self):
)
prediction_meta["clarity"] = round(self.clarity, 3) if self.clarity else 0
prediction_meta["all_class_confidences"] = {}
if self.prediction_frames is not None:
prediction_meta["prediction_frames"] = self.prediction_frames

# for ease always multiply by 100, depending on smoothing applied values might be large
prediction_meta["predictions"] = np.uint32(
np.round(100 * self.smoothed_predictions)
)
preds = []
for p in self.predictions:
preds.append(p.get_metadata())
prediction_meta["predictions"] = preds
if self.class_best_score is not None:
for i, value in enumerate(self.class_best_score):
label = self.labels[i]
Expand Down
35 changes: 32 additions & 3 deletions src/dbustest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,42 @@
import threading
from datetime import datetime

labels = []

def catchall_tracking_signals_handler(what, confidence, region, tracking):

def catchall_tracking_signals_handler(
prediction,
what,
confidence,
region,
frame,
mass,
blank,
tracking,
last_prediction_frame,
):
print(
"Received a trackng signal and it says " + what,
confidence,
"% at ",
region,
" tracking?",
tracking,
"prediction",
prediction,
"frame",
frame,
"mass",
mass,
"blank",
blank,
"last prediction frame",
last_prediction_frame,
)
index = 0
for x in prediction:
print("For ", labels[index], " have confidence ", int(x), "%")
index += 1


def catchall_rec_signals_handler(dt, is_recording):
Expand All @@ -49,13 +75,16 @@ def quit(self):
self.loop.quit()

def run_server(self):
dbus_object = None
try:
bus = dbus.SystemBus()
object = bus.get_object(DBUS_NAME, DBUS_PATH)
dbus_object = bus.get_object(DBUS_NAME, DBUS_PATH)
except dbus.exceptions.DBusException as e:
print("Failed to initialize D-Bus object: '%s'" % str(e))
sys.exit(2)

global labels
labels = dbus_object.ClassificationLabels()
print("Labels are ", labels)
bus.add_signal_receiver(
self.callback,
dbus_interface=DBUS_NAME,
Expand Down
27 changes: 15 additions & 12 deletions src/piclassifier/piclassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(
self.constant_recorder = None
self._output_dir = thermal_config.recorder.output_dir
self.headers = headers
super().__init__()
self.classifier = None
self.frame_num = 0
self.clip = None
self.enable_per_track_information = False
Expand Down Expand Up @@ -277,6 +277,7 @@ def __init__(
except ValueError:
self.fp_index = None
self.startup_classifier()
super().__init__()

def new_clip(self, preview_frames):
self.clip = Clip(
Expand Down Expand Up @@ -438,20 +439,20 @@ def identify_last_frame(self):
self.tracking = track
track_prediction.normalize_score()
self.service.tracking(
track_prediction.predicted_tag(),
track_prediction.max_score,
track.bounds_history[-1].to_ltrb(),
track_prediction.class_best_score,
track.bounds_history[-1],
True,
track_prediction.last_frame_classified,
)
elif track_prediction.tracking:
track_prediction.tracking = False
self.tracking = None
track_prediction.normalize_score()
self.service.tracking(
track_prediction.predicted_tag(),
track_prediction.max_score,
track.bounds_history[-1].to_ltrb(),
track_prediction.class_best_score,
track.bounds_history[-1],
False,
track_prediction.last_frame_classified,
)

new_prediction = True
Expand Down Expand Up @@ -584,17 +585,19 @@ def process_frame(self, lepton_frame, received_at):
tracking = self.tracking in self.clip.active_tracks
score = 0
prediction = ""
all_scores = None
last_prediction = 0
if self.classify:
track_prediction = self.predictions.prediction_for(
self.tracking.get_id()
)
prediction = track_prediction.predicted_tag()
score = track_prediction.max_score
all_scores = track_prediction.class_best_score
last_prediction = track_prediction.last_frame_classified
self.service.tracking(
prediction,
score,
self.tracking.bounds_history[-1].to_ltrb(),
all_scores,
self.tracking.bounds_history[-1],
tracking,
last_prediction,
)

if not tracking:
Expand Down
Loading

0 comments on commit def00bd

Please sign in to comment.