-
Notifications
You must be signed in to change notification settings - Fork 0
/
tfl_manager.py
74 lines (61 loc) · 2.82 KB
/
tfl_manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from dataclasses import dataclass
from matplotlib import pyplot as plt
import numpy as np
from PIL import Image
from tensorflow.keras.models import load_model
from candidates import Candidates
from find_lights import find_tfl_lights
from create_data_set import crop_image
from SFM import calc_TFL_dist, get_foe_rotate
from visualation import visual
class FrameContainer(object):
    """Container for a single video frame: the decoded image, the detected
    traffic-light points, and the SFM bookkeeping fields that the distance
    pipeline fills in later."""

    def __init__(self, img_path, traffic_lights):
        # Decode the frame image from disk and keep the detected 2D points.
        self.img = plt.imread(img_path)
        self.traffic_light = traffic_lights
        # Populated downstream by the SFM stage (calc_TFL_dist et al.).
        self.traffic_lights_3d_location, self.EM, self.corresponding_ind, self.valid = [], [], [], []
# TODO: merge the FrameContainer with the Candidates class
class TflManager:
    """Orchestrates the per-frame traffic-light pipeline: candidate detection,
    CNN classification, SFM-based distance estimation, and visualization.

    NOTE: the previous version was decorated with ``@dataclass`` despite having a
    hand-written ``__init__`` and no field annotations. ``dataclass`` keeps the
    explicit ``__init__`` but still generates a field-less ``__eq__`` — making
    every two TflManager instances compare equal — so the decorator was removed.
    """

    def __init__(self, pp, focal, egomotion):
        self.principle_point = pp
        self.focal = focal
        # Ego-motion matrices between consecutive frames, indexed by frame number.
        self.em = egomotion
        # Trained TFL classifier (Keras model saved alongside the project).
        self.my_model = load_model("model.h5")
        # Sentinel "previous frame" so calc_distance has something to read on frame 1.
        self.prev_candidates = Candidates("", [], [])

    def on_frame(self, frame_path, index):
        """Process one frame end-to-end: detect lights, confirm TFLs, estimate
        distances (requires a previous frame, so skipped when index == 0),
        visualize, and remember this frame's TFLs for the next call."""
        lights_candidates, tfl_candidates = self.find_tfl(frame_path)
        if index == 0:
            # No previous frame yet -> SFM distance estimation is impossible.
            distances, rot_pts, foe = 0, 0, 0
        else:
            distances, foe, rot_pts = self.calc_distance(tfl_candidates, index)
        visual(lights_candidates, tfl_candidates, distances, rot_pts, foe)
        self.prev_candidates = tfl_candidates

    def find_tfl(self, frame_path):
        """Return ``(lights_candidates, tfl_candidates)`` for the frame:
        all detected light sources, and the subset the CNN confirms as TFLs."""
        points, aux = self.find_lights(frame_path)
        lights_candidates = Candidates(frame_path, points, aux)
        points, aux = self.recognize_tfl(lights_candidates)
        tfl_candidates = Candidates(frame_path, points, aux)
        return lights_candidates, tfl_candidates

    def find_lights(self, frame) -> tuple:
        """Run the low-level light detector on the image file at *frame*.

        Returns the ``(points, auxiliary)`` pair produced by find_tfl_lights.
        """
        image = np.array(Image.open(frame))
        return find_tfl_lights(image)

    def recognize_tfl(self, candidate: Candidates) -> tuple:
        """Classify each candidate point with the CNN and keep only those whose
        class-1 ("is a traffic light") probability exceeds 0.5.

        Returns ``(tfl_points, auxiliary)`` filtered in lockstep.
        """
        cropped_images = [crop_image(candidate.frame_path, point[0], point[1])
                          for point in candidate.points]
        predictions = self.my_model.predict(np.array(cropped_images))
        tfl_array, auxiliary = [], []
        for idx, probability in enumerate(predictions[:, 1]):
            if probability > 0.5:
                tfl_array.append(candidate.points[idx])
                auxiliary.append(candidate.auxiliary[idx])
        return tfl_array, auxiliary

    def calc_distance(self, cur_frame: Candidates, index: int) -> tuple:
        """Estimate distances to the current frame's TFLs via SFM against the
        previous frame.

        Returns ``(distances, foe, rot_pts)`` — note: a 3-tuple, not a float
        (the previous ``-> float`` annotation was wrong).
        """
        prev_container = FrameContainer(self.prev_candidates.frame_path,
                                        np.array(self.prev_candidates.points))
        curr_container = FrameContainer(cur_frame.frame_path, np.array(cur_frame.points))
        # Ego-motion from frame index-1 to frame index.
        curr_container.EM = self.em[index - 1]
        z = calc_TFL_dist(prev_container, curr_container, self.focal, self.principle_point)
        foe, rot_pts = get_foe_rotate(prev_container, curr_container, self.focal, self.principle_point)
        return z, foe, rot_pts