Training updates (#215)
* tweak training

* use rust binding

* repeat frames at random rather than only last frame

* add multiple segment type option

* try random section

* add min path

* fix too-small tracks

* added mask segment type as default

* fix segment type load for old meta

---------

Co-authored-by: gferraro <[email protected]>
gferraro authored Nov 13, 2024
1 parent d299543 commit ab8cd07
Showing 34 changed files with 1,341 additions and 785 deletions.
4 changes: 2 additions & 2 deletions pirequirements.txt
@@ -9,7 +9,7 @@ scipy==1.9.3
python-dateutil
scikit-learn==1.1.3
tables==3.8.0
-h5py==3.8.0
+h5py==3.10.0
pyyaml==6.0
pillow==10.0.1
attrs==24.2.0
@@ -26,4 +26,4 @@ dbus-python==1.3.2
importlib_resources==5.10.2
opencv-python==4.8.0.76
inotify_simple==1.3.5
-python-cptv==0.0.3
+python-cptv==0.0.5
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -48,7 +48,7 @@ dependencies = [
"importlib_resources==5.10.2",
"opencv-python==4.8.0.76",
"inotify_simple==1.3.5",
"python-cptv==0.0.3"
"python-cptv==0.0.5"
]

[project.scripts]
6 changes: 3 additions & 3 deletions requirements.txt
@@ -1,4 +1,4 @@
-tensorflow~=2.14.0
+tensorflow~=2.17.0
matplotlib~=3.0
pytz
cptv~=1.5.4
@@ -7,7 +7,7 @@ scipy
python-dateutil
scikit-learn
tables~=3.8.0
-h5py~=3.9.0
+h5py~=3.10.0
pyyaml>=4.2b1
pillow~=10.0.1
attrs~=24.2.0
@@ -26,4 +26,4 @@ joblib
#requires sudo apt-get install libopencv-dev used for ir track extraction on server
# pybgs==3.2.0.post1 this was used for ir
inotify_simple==1.3.5
-python-cptv==0.0.3
+python-cptv==0.0.5
5 changes: 5 additions & 0 deletions src/autobuild-cron
@@ -0,0 +1,5 @@
+#run the first of every month
+SHELL=/bin/bash
+BASH_ENV=~/.bashrc_conda
+
+* * 1 * * cp ( cd /home/cp/cacophony/classifier-pipeline/src && ./autobuild.sh /data2/cptv-files) 2>&1 | logger --tag classifier-auto-build
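A note on the schedule above: in a system crontab the fields are minute, hour, day-of-month, month, day-of-week, then the user (`cp` here) and the command. `* * 1 * *` therefore fires every minute of the first day of the month, not once a month as the comment suggests; `0 0 1 * *` would fire once, at midnight on the 1st. A minimal sketch verifying this with the third-party croniter package (an assumption, not a dependency of this repository):

```python
# Sketch: enumerate the next few firing times of the crontab schedule above.
# Assumes the third-party "croniter" package is installed.
from datetime import datetime

from croniter import croniter

it = croniter("* * 1 * *", datetime(2024, 11, 30))
for _ in range(3):
    # Fires every minute once day-of-month == 1:
    # 2024-12-01 00:00, 00:01, 00:02
    print(it.get_next(datetime))
```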
18 changes: 11 additions & 7 deletions src/autobuild.sh
@@ -1,11 +1,15 @@
-#!/bin/sh
-
+#!/bin/bash
+set -e
+set -x
+conda init bash
+conda activate tf
config="classifier-thermal.yaml"
-month_ago=$(python3 rebuildDate.py -c $config)
+echo "Saving into $1"
+month_ago=$(python3 rebuildDate.py $1)
echo $month_ago
-python3 ../../cptv-download/cptv-download.py -l 0 -i 'poor tracking' -i 'untagged' -i 'part' -i 'untagged-by-humans' -i 'unknown' -i 'unidentified' -m 'human-tagged' --start-date "$month_ago" "../clips$month_ago" [email protected] userpassword
-echo "Downloading into ../clips$month_ago"
-python3 load.py -target "../clips$month_ago" -c $config
-python3 build.py -c $config
+python3 ../../cptv-download/cptv-download.py -l 0 -i 'poor tracking' -i 'untagged' -i 'part' -i 'untagged-by-humans' -i 'unknown' -i 'unidentified' -m 'human-tagged' --start-date "$month_ago" "$1" [email protected] userpassword
+echo "Downloading into $1"
+python3 build.py -c $config --ext ".cptv" $1
dt=$(date '+%d%m%Y-%H%M%S');
+export XLA_FLAGS=--xla_gpu_cuda_data_dir=/home/cp/miniconda3/envs/tf/lib/
python3 train.py -c $config $dt
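The script now passes the download directory straight to `rebuildDate.py` (and drops the separate `load.py` step, with `build.py` reading `.cptv` files directly). `rebuildDate.py` itself is not part of this diff; the sketch below is a hypothetical stand-in showing the role it plays, assuming it simply prints the cutoff date fed to `--start-date`:

```python
# Hypothetical stand-in for rebuildDate.py: print a cutoff date for
# cptv-download's --start-date flag. The real script's logic (e.g. whether
# it derives the date from files already in the target directory) is not
# shown in this commit.
import argparse
from datetime import date

from dateutil.relativedelta import relativedelta  # python-dateutil is pinned above

parser = argparse.ArgumentParser()
parser.add_argument("target", help="directory the clips will be saved into")
args = parser.parse_args()

print(date.today() - relativedelta(months=1))  # e.g. 2024-10-13
```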
70 changes: 62 additions & 8 deletions src/build.py
@@ -19,8 +19,8 @@
from ml_tools.tfwriter import create_tf_records
from ml_tools.irwriter import save_data as save_ir_data
from ml_tools.thermalwriter import save_data as save_thermal_data
-
-
+from ml_tools.tools import CustomJSONEncoder
+import attrs
import numpy as np

from pathlib import Path
@@ -57,7 +57,7 @@ def parse_args():
)
parser.add_argument("--split-file", help="Json file defining a split")
parser.add_argument(
"--ext", default=".hdf5", help="Extension of files to load .mp4,.cptv,.hdf5"
"--ext", default=".cptv", help="Extension of files to load .mp4,.cptv,.hdf5"
)

parser.add_argument("-c", "--config-file", help="Path to config file to use")
@@ -571,7 +571,7 @@ def add_samples(
dataset.add_samples(samples)


-def validate_datasets(datasets, test_bins, date):
+def validate_datasets(datasets, test_bins, after_date):
# check that clips are only in one dataset
# that only test set has clips after date
# that test set is the only dataset with test_clips
@@ -580,7 +580,7 @@ def validate_datasets(datasets, test_bins, date):
# for track in dataset.tracks:
# assert track.start_time < date

-for i, dataset in enumerate(datasets):
+for i, dataset in enumerate(datasets[:2]):
dont_check = set(
[
sample.bin_id
@@ -608,6 +608,15 @@ def validate_datasets(datasets, test_bins, date):
if sample.label in split_by_clip
]
)
+if other.name == "test" and after_date is not None:
+    dont_check_other = set(
+        [
+            sample.bin_id
+            for sample in other.samples_by_id.values()
+            if sample.rec_time > after_date
+        ]
+    )
+    dont_check = dont_check | dont_check_other
other_bins = set([sample.bin_id for sample in other.samples_by_id.values()])
other_bins = other_bins - dont_check
other_clips = set(
@@ -717,6 +726,42 @@ def dump_split_ids(datasets, out_file="datasplit.json"):
return


+def rough_balance(datasets):
+    dev_threshold = 2000
+    logging.info("Roughly Balancing")
+    print_counts(*datasets)
+
+    for dataset in datasets:
+        lbl_counts = {}
+        counts = []
+        for label in dataset.labels:
+            label_count = len(dataset.samples_by_label.get(label, []))
+            lbl_counts[label] = label_count
+            counts.append(label_count)
+        counts.sort()
+        std_dev = np.std(counts)
+        logging.info("Counts are %s std dev %s", counts, std_dev)
+        if std_dev < dev_threshold or len(counts) <= 1:
+            logging.info("Not balancing")
+            continue
+        if len(counts) <= 2:
+            cap_at = counts[-2]
+        elif len(counts) < 7:
+            cap_at = counts[-2]
+        else:
+            cap_at = counts[-2]
+        logging.info("Capping dataset %s at %s", dataset.name, cap_at)
+        for lbl, count in lbl_counts.items():
+            if count <= cap_at:
+                continue
+            samples_to_remove = count - cap_at
+            by_labels = dataset.samples_by_label[lbl]
+            np.random.shuffle(by_labels)
+            for i in range(samples_to_remove):
+                dataset.remove_sample(by_labels[i])
+    print_counts(*datasets)
+
+
def main():
init_logging()
args = parse_args()
@@ -782,6 +827,8 @@ def main():
print("Splitting data set into train / validation")

datasets = split_randomly(master_dataset, config, args.date, test_clips)
+
+rough_balance(datasets)
validate_datasets(datasets, test_clips, args.date)
dump_split_ids(datasets, record_dir / "datasplit.json")

@@ -849,15 +896,20 @@
{
"segment_frame_spacing": master_dataset.segment_spacing * 9,
"segment_width": master_dataset.segment_length,
"segment_type": master_dataset.segment_type,
"segment_types": master_dataset.segment_types,
"segment_min_avg_mass": master_dataset.segment_min_avg_mass,
"max_segments": master_dataset.max_segments,
"dont_filter_segment": True,
"skip_ffc": True,
"tag_precedence": config.load.tag_precedence,
"tag_precedence": config.build.tag_precedence,
"min_mass": master_dataset.min_frame_mass,
"thermal_diff_norm": config.build.thermal_diff_norm,
"filter_by_lq": master_dataset.filter_by_lq,
"max_frames": master_dataset.max_frames,
}
)
+# dont filter the test set,
+extra_args["filter_by_fp"] = dataset.name != "test"
create_tf_records(
dataset,
dir,
@@ -879,10 +931,12 @@ def main():
"type": config.train.type,
"counts": dataset_counts,
"by_label": False,
"config": attrs.asdict(config),
"segment_types": master_dataset.segment_types,
}

with open(meta_filename, "w") as f:
-json.dump(meta_data, f, indent=4)
+json.dump(meta_data, f, indent=4, cls=CustomJSONEncoder)


if __name__ == "__main__":
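The new `rough_balance` step caps over-represented labels before the split is validated and written out. A toy walk-through of the capping rule with invented counts (note that all three branches above currently cap at `counts[-2]`, the second-largest count):

```python
# Toy numbers illustrating rough_balance's rule: if the standard deviation of
# per-label sample counts exceeds dev_threshold (2000), cap every label at
# the second-largest count. The counts here are made up for illustration.
import numpy as np

lbl_counts = {"bird": 10_000, "possum": 1_200, "cat": 900}
counts = sorted(lbl_counts.values())  # [900, 1200, 10000]
print(round(np.std(counts)))          # ~4221, well above 2000, so balancing runs

cap_at = counts[-2]                   # 1200
capped = {lbl: min(n, cap_at) for lbl, n in lbl_counts.items()}
print(capped)                         # {'bird': 1200, 'possum': 1200, 'cat': 900}
```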
8 changes: 1 addition & 7 deletions src/classify/clipclassifier.py
@@ -11,13 +11,8 @@
from track.clip import Clip
from track.cliptrackextractor import ClipTrackExtractor, is_affected_by_ffc
from ml_tools import tools
-from ml_tools.kerasmodel import KerasModel
from track.irtrackextractor import IRTrackExtractor
-from ml_tools.previewer import Previewer
-from track.track import Track

-from cptv import CPTVReader
-from datetime import datetime
from ml_tools.interpreter import get_interpreter


@@ -134,7 +129,7 @@ def process_file(self, filename, cache=None, reuse_frames=None):
clip = Clip(track_extractor.config, filename)
clip.load_metadata(
meta_data,
-self.config.load.tag_precedence,
+self.config.build.tag_precedence,
)
track_extractor.parse_clip(clip)

@@ -250,7 +245,6 @@ def save_metadata(
prediction = predictions.prediction_for(track.get_id())
if prediction is None:
continue
-
prediction_meta = prediction.get_metadata()
prediction_meta["model_id"] = model_id
prediction_info.append(prediction_meta)
19 changes: 13 additions & 6 deletions src/classify/trackprediction.py
@@ -79,6 +79,9 @@ def clarity(self):
best = np.argsort(self.prediction)
return self.prediction[best[-1]] - self.prediction[best[-2]]

+def __str__(self):
+    return f"{self.frames} conf: {np.round(100*self.prediction)}"


class TrackPrediction:
"""
@@ -107,18 +110,23 @@ def __init__(self, track_id, labels, keep_all=True, start_frame=None):
self.masses = []

def classified_clip(
-self, predictions, smoothed_predictions, prediction_frames, top_score=None
+self,
+predictions,
+smoothed_predictions,
+prediction_frames,
+masses,
+top_score=None,
):
self.num_frames_classified = len(predictions)
-for prediction, smoothed_prediction, frames in zip(
-predictions, smoothed_predictions, prediction_frames
+for prediction, smoothed_prediction, frames, mass in zip(
+predictions, smoothed_predictions, prediction_frames, masses
):
prediction = Prediction(
prediction,
smoothed_prediction,
frames,
-np.amax(frames),
None,
+mass,
)
self.predictions.append(prediction)

@@ -162,11 +170,10 @@ def classified_frames(self, frame_numbers, predictions, mass):
self.class_best_score += smoothed_prediction

def classified_frame(self, frame_number, predictions, mass):
-self.prediction_frames.append([frame_number])
self.last_frame_classified = frame_number
self.num_frames_classified += 1
self.masses.append(mass)
-smoothed_prediction = prediction * prediction * mass
+smoothed_prediction = predictions**2 * mass

prediction = Prediction(
predictions,
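The `classified_frame` fix replaces a reference to the undefined name `prediction` with the `predictions` argument, squared and weighted by the tracked blob's mass, so confident predictions on high-mass frames dominate the accumulated class score. A small sketch with made-up numbers:

```python
# Made-up numbers showing the per-frame smoothing above: each frame's
# per-class score vector is weighted by itself (squared) times the pixel
# mass of the tracked blob for that frame.
import numpy as np

predictions = np.array([0.1, 0.7, 0.2])  # per-class scores for one frame
mass = 45                                # pixel mass of the tracked blob
smoothed_prediction = predictions**2 * mass
print(smoothed_prediction)               # [ 0.45 22.05  1.8 ]
```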
54 changes: 53 additions & 1 deletion src/config/buildconfig.py
@@ -22,6 +22,7 @@
import logging
from os import path
from .defaultconfig import DefaultConfig
+from ml_tools.rectangle import Rectangle


@attr.s
@@ -34,6 +35,45 @@ class BuildConfig(DefaultConfig):
min_frame_mass = attr.ib()
filter_by_lq = attr.ib()
max_segments = attr.ib()
+thermal_diff_norm = attr.ib()
+tag_precedence = attr.ib()
+excluded_tags = attr.ib()
+country = attr.ib()
+use_segments = attr.ib()
+max_frames = attr.ib()
+
+EXCLUDED_TAGS = ["poor tracking", "part", "untagged", "unidentified"]
+NO_MIN_FRAMES = ["stoat", "mustelid", "weasel", "ferret"]
+# country bounding boxs
+COUNTRY_LOCATIONS = {
+    "AU": Rectangle.from_ltrb(
+        113.338953078, -10.6681857235, 153.569469029, -43.6345972634
+    ),
+    "NZ": Rectangle.from_ltrb(
+        166.509144322, -34.4506617165, 178.517093541, -46.641235447
+    ),
+}
+
+DEFAULT_GROUPS = {
+    0: [
+        "bird",
+        "false-positive",
+        "hedgehog",
+        "possum",
+        "rodent",
+        "mustelid",
+        "cat",
+        "kiwi",
+        "dog",
+        "leporidae",
+        "human",
+        "insect",
+        "pest",
+    ],
+    1: ["unidentified", "other"],
+    2: ["part", "bad track"],
+    3: ["default"],
+}

@classmethod
def load(cls, build):
@@ -46,6 +86,12 @@ def load(cls, build):
min_frame_mass=build["min_frame_mass"],
filter_by_lq=build["filter_by_lq"],
max_segments=build["max_segments"],
+thermal_diff_norm=build["thermal_diff_norm"],
+tag_precedence=build["tag_precedence"],
+excluded_tags=build["excluded_tags"],
+country=build["country"],
+use_segments=build["use_segments"],
+max_frames=build["max_frames"],
)

@classmethod
@@ -58,7 +104,13 @@ def get_defaults(cls):
segment_min_avg_mass=10,
min_frame_mass=10,
filter_by_lq=False,
-max_segments=5,
+max_segments=3,
+thermal_diff_norm=False,
+tag_precedence=BuildConfig.DEFAULT_GROUPS,
+excluded_tags=BuildConfig.EXCLUDED_TAGS,
+country=None,
+use_segments=True,
+max_frames=75,
)

def validate(self):
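The new `COUNTRY_LOCATIONS` boxes store (left, top, right, bottom) as west longitude, north latitude, east longitude, south latitude. This diff only shows `Rectangle.from_ltrb` being imported, so the point-in-box helper below is an assumption written out by hand rather than a `Rectangle` method confirmed by the source:

```python
# Hand-rolled point-in-box check for the COUNTRY_LOCATIONS bounds above;
# Rectangle may provide an equivalent method, but it is not shown in this
# diff. Coordinates are (lng, lat); "top" is the northernmost latitude.
NZ = (166.509144322, -34.4506617165, 178.517093541, -46.641235447)

def in_box(lng, lat, box):
    left, top, right, bottom = box
    return left <= lng <= right and bottom <= lat <= top

print(in_box(174.76, -36.85, NZ))  # Auckland -> True
print(in_box(151.21, -33.87, NZ))  # Sydney  -> False
```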