Skip to content

Commit

Permalink
Remove test fields from queries
Browse files Browse the repository at this point in the history
  • Loading branch information
ziv17 committed Aug 15, 2024
1 parent 3787efa commit f9b77c9
Show file tree
Hide file tree
Showing 10 changed files with 68 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, request_params: RequestParams):
def generate_items(self) -> None:
res1 = get_accidents_stats(
table_obj=InvolvedMarkerView,
filters=get_injured_filters(self.request_params),
filters=get_injured_filters(self.request_params.location_info),
group_by=("accident_year", "injury_severity"),
count="injury_severity",
start_time=self.request_params.start_time,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import datetime
import copy
from typing import Dict
from anyway.request_params import RequestParams, LocationInfo
from anyway.backend_constants import InjurySeverity
Expand All @@ -10,6 +9,7 @@
get_accidents_stats,
join_strings,
get_location_text,
get_injured_filters,
)
from anyway.backend_constants import BE_CONST
from flask_babel import _
Expand Down Expand Up @@ -40,7 +40,7 @@ def get_injured_count_by_severity(
start_time: datetime.date,
end_time: datetime.date,
):
filters = copy.copy(location_info)
filters = get_injured_filters(location_info)
filters["injury_severity"] = [
InjurySeverity.KILLED.value,
InjurySeverity.SEVERE_INJURED.value,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def generate_items(self) -> None:

@staticmethod
def count_accidents_by_driver_type(request_params: RequestParams):
filters = get_injured_filters(request_params)
filters = get_injured_filters(request_params.location_info)
filters["involved_type"] = [
consts.InvolvedType.DRIVER.value,
consts.InvolvedType.INJURED_DRIVER.value,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ def __init__(self, request_params: RequestParams):
self.rank = 18
self.information = "Injured and killed pedestrians by severity and year"

def validate_parameters(self, yishuv_name, street1_hebrew):
def validate_parameters(self, yishuv_symbol, street1):
# TODO: validate each parameter and display message accordingly
return (
yishuv_name is not None
and street1_hebrew is not None
yishuv_symbol is not None
and street1 is not None
and self.request_params.years_ago is not None
)

Expand All @@ -53,10 +53,10 @@ def convert_to_dict(query_results):

def generate_items(self) -> None:
try:
yishuv_name = self.request_params.location_info.get("yishuv_name")
street1_hebrew = self.request_params.location_info.get("street1_hebrew")
yishuv_symbol = self.request_params.location_info.get("yishuv_symbol")
street1 = self.request_params.location_info.get("street1")

# if not self.validate_parameters(yishuv_name, street1_hebrew):
# if not self.validate_parameters(yishuv_symbol, street1_hebrew):
# # TODO: this will fail since there is no news_flash_obj in request_params
# logging.exception(f"Could not validate parameters yishuv_name + street1_hebrew in widget : {self.name}")
# return None
Expand All @@ -74,7 +74,7 @@ def generate_items(self) -> None:
func.count().label("count"),
)
.filter(loc_ex)
.filter(InvolvedMarkerView.accident_yishuv_name == yishuv_name)
.filter(InvolvedMarkerView.accident_yishuv_symbol == yishuv_symbol)
.filter(
InvolvedMarkerView.injury_severity.in_(
[
Expand All @@ -87,8 +87,8 @@ def generate_items(self) -> None:
.filter(InvolvedMarkerView.injured_type == InjuredType.PEDESTRIAN.value)
.filter(
or_(
InvolvedMarkerView.street1_hebrew == street1_hebrew,
InvolvedMarkerView.street2_hebrew == street1_hebrew,
InvolvedMarkerView.street1 == street1,
InvolvedMarkerView.street2 == street1,
)
)
.filter(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ def __init__(self, request_params: RequestParams):

def generate_items(self) -> None:
self.items = SevereFatalCountByVehicleByYearWidget.separate_data(
self.request_params.location_info["yishuv_name"],
self.request_params.location_info["yishuv_symbol"],
self.request_params.start_time,
self.request_params.end_time,
self.request_params.resolution,
)

@staticmethod
def separate_data(yishuv, start_time, end_time, resolution) -> Dict[str, Any]:
def separate_data(yishuv_symbol, start_time, end_time, resolution) -> Dict[str, Any]:
output = {
"e_bikes": get_accidents_stats(
table_obj=InvolvedMarkerView,
Expand All @@ -39,7 +39,7 @@ def separate_data(yishuv, start_time, end_time, resolution) -> Dict[str, Any]:
InjurySeverity.SEVERE_INJURED.value,
],
"involve_vehicle_type": VehicleType.ELECTRIC_BIKE.value,
"accident_yishuv_name": yishuv,
"accident_yishuv_symbol": yishuv_symbol,
},
group_by="accident_year",
count="accident_year",
Expand All @@ -55,7 +55,7 @@ def separate_data(yishuv, start_time, end_time, resolution) -> Dict[str, Any]:
InjurySeverity.SEVERE_INJURED.value,
],
"involve_vehicle_type": VehicleType.BIKE.value,
"accident_yishuv_name": yishuv,
"accident_yishuv_symbol": yishuv_symbol,
},
group_by="accident_year",
count="accident_year",
Expand All @@ -71,7 +71,7 @@ def separate_data(yishuv, start_time, end_time, resolution) -> Dict[str, Any]:
InjurySeverity.SEVERE_INJURED.value,
],
"involve_vehicle_type": VehicleType.ELECTRIC_SCOOTER.value,
"accident_yishuv_name": yishuv,
"accident_yishuv_symbol": yishuv_symbol,
},
group_by="accident_year",
count="accident_year",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ def __init__(self, request_params: RequestParams):

def generate_items(self) -> None:
self.items = SmallMotorSevereFatalCountByYearWidget.get_motor_stats(
self.request_params.location_info["yishuv_name"],
self.request_params.location_info["yishuv_symbol"],
self.request_params.start_time,
self.request_params.end_time,
self.request_params.resolution,
)

@staticmethod
def get_motor_stats(location_info, start_time, end_time, resolution):
def get_motor_stats(yishuv_symbol, start_time, end_time, resolution):
count_by_year = get_accidents_stats(
table_obj=InvolvedMarkerView,
filters={
Expand All @@ -37,7 +37,7 @@ def get_motor_stats(location_info, start_time, end_time, resolution):
InjurySeverity.SEVERE_INJURED.value,
],
"involve_vehicle_type": VehicleCategory.BICYCLE_AND_SMALL_MOTOR.get_codes(),
"accident_yishuv_name": location_info,
"accident_yishuv_symbol": yishuv_symbol,
},
group_by="accident_year",
count="accident_year",
Expand Down
28 changes: 14 additions & 14 deletions anyway/widgets/urban_widgets/urban_crosswalk_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from anyway.widgets.widget_utils import get_accidents_stats


# TODO: pretty sure there are errors in this widget, for example, is_included returns self.items
class UrbanCrosswalkWidget(UrbanWidget):
name: str = "urban_accidents_by_cross_location"
files = [__file__]
Expand All @@ -21,15 +20,14 @@ def __init__(self, request_params: RequestParams):

def generate_items(self) -> None:
self.items = UrbanCrosswalkWidget.get_crosswalk(
self.request_params.location_info["yishuv_name"],
self.request_params.location_info["street1_hebrew"],
self.request_params.location_info,
self.request_params.start_time,
self.request_params.end_time,
self.request_params.resolution,
)

@staticmethod
def get_crosswalk(yishuv, street, start_time, end_time, resolution) -> Dict[str, Any]:
def get_crosswalk(location_info: dict, start_time, end_time, resolution) -> Dict[str, Any]:
cross_output = {
"with_crosswalk": get_accidents_stats(
table_obj=InvolvedMarkerView,
Expand All @@ -39,11 +37,11 @@ def get_crosswalk(yishuv, street, start_time, end_time, resolution) -> Dict[str,
InjurySeverity.SEVERE_INJURED.value,
],
"cross_location": CrossCategory.CROSSWALK.get_codes(),
"accident_yishuv_name": yishuv,
"street1_hebrew": street,
"accident_yishuv_symbol": location_info["yishuv_symbol"],
"street1": location_info["street"],
},
group_by="street1_hebrew",
count="street1_hebrew",
group_by="street1",
count="street1",
start_time=start_time,
end_time=end_time,
resolution=resolution,
Expand All @@ -56,20 +54,22 @@ def get_crosswalk(yishuv, street, start_time, end_time, resolution) -> Dict[str,
InjurySeverity.SEVERE_INJURED.value,
],
"cross_location": CrossCategory.NONE.get_codes(),
"accident_yishuv_name": yishuv,
"street1_hebrew": street,
"accident_yishuv_symbol": location_info["yishuv_symbol"],
"street1": location_info["street"],
},
group_by="street1_hebrew",
count="street1_hebrew",
group_by="street1",
count="street1",
start_time=start_time,
end_time=end_time,
resolution=resolution,
),
}
if not cross_output["with_crosswalk"]:
cross_output["with_crosswalk"] = [{"street1_hebrew": street, "count": 0}]
cross_output["with_crosswalk"] = [{"street1_hebrew": location_info["_hebrew"],
"count": 0}]
if not cross_output["without_crosswalk"]:
cross_output["without_crosswalk"] = [{"street1_hebrew": street, "count": 0}]
cross_output["without_crosswalk"] = [{"street1_hebrew": location_info["street_hebrew"],
"count": 0}]
return cross_output

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions anyway/widgets/urban_widgets/urban_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ def __init__(self, request_params: RequestParams):
def is_urban(request_params: RequestParams) -> bool:
return (
request_params is not None
and "yishuv_name" in request_params.location_info
and "street1_hebrew" in request_params.location_info
and "yishuv_symbol" in request_params.location_info
and "street1" in request_params.location_info
)

@staticmethod
Expand Down
15 changes: 8 additions & 7 deletions anyway/widgets/widget_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def get_query(table_obj, filters, start_time, end_time):

def remove_loc_text_fields_from_filter(filters: dict) -> dict:
def remove_first_if_both_exist(d: dict, first: str, second: str) -> dict:
if first in d and second in d:
d.pop(first)
if first in d and second in d:
d.pop(first)

res = copy.copy(filters)
remove_first_if_both_exist(res, "road_segment_name", "road_segment_id")
Expand Down Expand Up @@ -221,9 +221,9 @@ def gen_entity_labels(entity: Type[LabeledCode]) -> dict:
# return filters


def get_injured_filters(request_params: RequestParams):
new_filters = copy.copy(request_params.location_info)
for curr_filter, curr_value in request_params.location_info.items():
def get_injured_filters(location_info: dict):
new_filters = copy.copy(location_info)
for curr_filter, curr_value in location_info.items():
if curr_filter in ["region_hebrew", "district_hebrew", "yishuv_name", "yishuv_symbol"]:
new_filter_name = "accident_" + curr_filter
new_filters[new_filter_name] = curr_value
Expand Down Expand Up @@ -305,9 +305,10 @@ def get_involved_counts(
.order_by(table.accident_year)
)
filters = add_resolution_location_accuracy_filter(location_info, table)
if "yishuv_symbol" in location_info:
filters = remove_loc_text_fields_from_filter(filters)
if "yishuv_symbol" in filters:
filters["accident_yishuv_symbol"] = filters["yishuv_symbol"]
filters = remove_loc_text_fields_from_filter(filters)
filters.pop("yishuv_symbol")
ex = get_expression_for_fields(filters, table)
query = query.filter(ex).group_by(table.accident_year)

Expand Down
46 changes: 23 additions & 23 deletions tests/test_infographics_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,29 +214,29 @@ def test_get_expression_for_non_road_segment_fields(self):
str(actual), "4")

def test_remove_names_from_filters(self):
self.assertEqual({},remove_loc_text_fields_from_filter({}), "1")
expected = {"a": 1, "yishuv_symbol": 2}
test = {"a": 1, "yishuv_symbol": 2, "yishuv_name": "yishuv"}
actual = remove_loc_text_fields_from_filter(test)
self.assertEqual(expected, actual, "2")
expected = {"a": 1, "yishuv_symbol": 2,
"street1": 3, "street2": 4,
"road_segment_id": 17}
test = {"a": 1, "yishuv_symbol": 2, "yishuv_name": "yishuv",
"street1": 3, "street1_hebrew": "Hebrew",
"street2": 4, "street2_hebrew": "Hebrew2",
"road_segment_name": "seg name", "road_segment_id": 17}
actual = remove_loc_text_fields_from_filter(test)
self.assertEqual(expected, actual, "3")
expected = {"a": 1, "accident_yishuv_symbol": 2,
"street1_hebrew": "Hebrew", "street2": 4,
"road_segment_id": 17}
test = {"a": 1, "accident_yishuv_symbol": 2, "accident_yishuv_name": "yishuv",
"street1_hebrew": "Hebrew",
"street2": 4, "street2_hebrew": "Hebrew2",
"road_segment_name": "seg name", "road_segment_id": 17}
actual = remove_loc_text_fields_from_filter(test)
self.assertEqual(expected, actual, "3")
self.assertEqual({}, remove_loc_text_fields_from_filter({}), "1")
expected = {"a": 1, "yishuv_symbol": 2}
test = {"a": 1, "yishuv_symbol": 2, "yishuv_name": "yishuv"}
actual = remove_loc_text_fields_from_filter(test)
self.assertEqual(expected, actual, "2")
expected = {"a": 1, "yishuv_symbol": 2,
"street1": 3, "street2": 4,
"road_segment_id": 17}
test = {"a": 1, "yishuv_symbol": 2, "yishuv_name": "yishuv",
"street1": 3, "street1_hebrew": "Hebrew",
"street2": 4, "street2_hebrew": "Hebrew2",
"road_segment_name": "seg name", "road_segment_id": 17}
actual = remove_loc_text_fields_from_filter(test)
self.assertEqual(expected, actual, "3")
expected = {"a": 1, "accident_yishuv_symbol": 2,
"street1_hebrew": "Hebrew", "street2": 4,
"road_segment_id": 17}
test = {"a": 1, "accident_yishuv_symbol": 2, "accident_yishuv_name": "yishuv",
"street1_hebrew": "Hebrew",
"street2": 4, "street2_hebrew": "Hebrew2",
"road_segment_name": "seg name", "road_segment_id": 17}
actual = remove_loc_text_fields_from_filter(test)
self.assertEqual(expected, actual, "3")


if __name__ == '__main__':
Expand Down

0 comments on commit f9b77c9

Please sign in to comment.