-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor Prescriptors and project to handle new prescription #89
Changes from all commits
1f1001a
9db5c04
a184505
33059f2
7acd0d2
1e5c73b
8bc07d0
b917f17
1cb7620
d0596b3
f5ae448
babf912
926ffe0
7b53ef1
3bdde8c
0f58c24
04c26e6
2f34f9d
8a96046
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ ignore=prescriptors/esp | |
|
||
recursive=y | ||
|
||
fail-under=9.65 | ||
fail-under=9.7 | ||
|
||
jobs=0 | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,44 +18,35 @@ | |
import dash_bootstrap_components as dbc | ||
|
||
from data import constants | ||
from data.eluc_data import ELUCEncoder | ||
import app.constants as app_constants | ||
from app import utils | ||
from prescriptors.nsga2.torch_prescriptor import TorchPrescriptor | ||
|
||
app = Dash(__name__, | ||
external_stylesheets=[dbc.themes.BOOTSTRAP, dbc.icons.BOOTSTRAP], | ||
prevent_initial_callbacks="initial_duplicate") | ||
server = app.server | ||
|
||
df = pd.read_csv(app_constants.DATA_FILE_PATH, index_col=app_constants.INDEX_COLS) | ||
df.rename(columns={col + ".1": col for col in app_constants.INDEX_COLS}, inplace=True) | ||
COUNTRIES_DF = regionmask.defined_regions.natural_earth_v5_0_0.countries_110.to_dataframe() | ||
|
||
# Prescriptor list should be in order of least to most change | ||
# Load pareto df | ||
pareto_df = pd.read_csv(app_constants.PARETO_CSV_PATH) | ||
pareto_df = pareto_df.sort_values(by="change", ascending=True) | ||
pareto_df.sort_values(by="change", inplace=True) | ||
prescriptor_list = list(pareto_df["id"]) | ||
|
||
encoder = ELUCEncoder.from_json(app_constants.PRESCRIPTOR_PATH / "fields.json") | ||
# TODO: Stop hard-coding candidate params -> make cand config file? | ||
candidate_params = { | ||
"in_size": len(constants.CAO_MAPPING["context"]), | ||
"hidden_size": 16, | ||
"out_size": len(constants.RECO_COLS) | ||
} | ||
prescriptor = TorchPrescriptor(None, encoder, None, 1, candidate_params) | ||
# Load prescriptors | ||
prescriptor_manager = utils.load_prescriptors() | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We now use a single PrescriptorManager object that holds all the individual LandUsePrescriptors. This also lets us add Heuristics in the future, since they implement Prescriptor too. |
||
|
||
# Load predictors | ||
predictors = utils.load_predictors() | ||
|
||
# Cells | ||
min_lat = df.index.get_level_values("lat").min() | ||
max_lat = df.index.get_level_values("lat").max() | ||
min_lon = df.index.get_level_values("lon").min() | ||
max_lon = df.index.get_level_values("lon").max() | ||
min_time = df.index.get_level_values("time").min() | ||
max_time = df.index.get_level_values("time").max() | ||
min_lat = df["lat"].min() | ||
max_lat = df["lat"].max() | ||
min_lon = df["lon"].min() | ||
max_lon = df["lon"].max() | ||
min_time = df["time"].min() | ||
max_time = df["time"].max() | ||
|
||
lat_list = list(np.arange(min_lat, max_lat + app_constants.GRID_STEP, app_constants.GRID_STEP)) | ||
lon_list = list(np.arange(min_lon, max_lon + app_constants.GRID_STEP, app_constants.GRID_STEP)) | ||
|
@@ -478,9 +469,7 @@ def select_prescriptor(_, presc_idx, year, lat, lon): | |
presc_id = prescriptor_list[presc_idx] | ||
context = df.loc[year, lat, lon][constants.CAO_MAPPING["context"]] | ||
context_df = pd.DataFrame([context]) | ||
prescribed = prescriptor.prescribe_land_use(context_df, | ||
cand_id=presc_id, | ||
results_dir=app_constants.PRESCRIPTOR_PATH) | ||
prescribed = prescriptor_manager.prescribe(presc_id, context_df) | ||
# Prescribed gives it to us in diff format, we need to recompute recommendations | ||
for col in constants.RECO_COLS: | ||
prescribed[col] = context[col] + prescribed[f"{col}_diff"] | ||
|
@@ -544,7 +533,7 @@ def compute_land_change(sliders, year, lat, lon, locked): | |
warnings.append(html.P("WARNING: Negative values detected. Please lower the value of a locked slider.")) | ||
|
||
# Compute total change | ||
change = prescriptor.compute_percent_changed(context_actions_df) | ||
change = prescriptor_manager.compute_percent_changed(context_actions_df) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. compute_percent_changed is now part of PrescriptorManager. |
||
|
||
return warnings, f"{change['change'].iloc[0] * 100:.2f}" | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
id,parents,NSGA-II_rank,distance,ELUC,change | ||
1_1,"(None, None)",1,inf,-0.011886149,0.0040596884414631 | ||
49_0,"('45_27', '21_80')",1,0.20109607207578,-2.1743317,0.0533331221986483 | ||
61_92,"('26_38', '60_66')",1,0.0427627827542644,-4.3928847,0.0690307965236667 | ||
85_51,"('82_16', '53_65')",1,0.0506403998177242,-8.327706,0.1045782392704721 | ||
82_16,"('51_17', '76_44')",1,0.0606549652913774,-11.905426,0.1340041432727491 | ||
70_30,"('59_5', '51_15')",1,0.0230698188448444,-13.6480255,0.1685300249153988 | ||
67_70,"('39_79', '59_5')",1,0.032665984351135,-14.625315,0.1968721749125319 | ||
36_3,"('30_85', '34_74')",1,0.0577219544984151,-15.683007,0.2284119445886663 | ||
62_84,"('58_38', '56_3')",1,0.0310138542659899,-16.641254,0.2609490730066026 | ||
54_67,"('43_94', '43_94')",1,inf,-17.59739,0.2924915519940842 |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,11 @@ | |
|
||
import app.constants as app_constants | ||
from data import constants | ||
|
||
from prescriptors.prescriptor_manager import PrescriptorManager | ||
from prescriptors.nsga2.land_use_prescriptor import LandUsePrescriptor | ||
|
||
from predictors.predictor import Predictor | ||
from predictors.neural_network.neural_net_predictor import NeuralNetPredictor | ||
from predictors.sklearn.sklearn_predictor import LinearRegressionPredictor, RandomForestPredictor | ||
|
||
|
@@ -258,9 +263,27 @@ def create_pareto(pareto_df: pd.DataFrame, presc_id: int) -> go.Figure: | |
" Average ELUC: %{y} tC/ha<extra></extra>") | ||
return fig | ||
|
||
def load_predictors() -> dict: | ||
def load_prescriptors() -> PrescriptorManager:
    """
    Loads in the prescriptors listed in the pareto CSV, downloads from HuggingFace first if needed.
    Prescriptors are inserted in order of least to most change, matching the sorted pareto front.
    TODO: Currently hard-coded to load specific prescriptors from pareto path.
    :return: PrescriptorManager wrapping a dict of candidate id -> LandUsePrescriptor.
    """
    pareto_df = pd.read_csv(app_constants.PARETO_CSV_PATH).sort_values(by="change")

    prescriptors = {}
    for cand_id in pareto_df["id"]:
        cand_path = f"danyoung/eluc-{cand_id}"
        # The repo id contains "/", so flatten it to "--" to get a single local directory name
        cand_local_dir = app_constants.PRESCRIPTOR_PATH / cand_path.replace("/", "--")
        prescriptors[cand_id] = LandUsePrescriptor.from_pretrained(cand_path, local_dir=cand_local_dir)

    # NOTE(review): the second constructor argument is passed as None, as before;
    # its meaning is not visible from this file - confirm against PrescriptorManager.
    return PrescriptorManager(prescriptors, None)
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Reads the hard-coded pareto CSV and downloads the corresponding models from HuggingFace. |
||
|
||
def load_predictors() -> dict[str, Predictor]: | ||
""" | ||
Loads in predictors from disk. | ||
Loads in predictors from disk, downloads from HuggingFace first if needed. | ||
TODO: Currently hard-coded to load specific predictors. We need to make this able to handle any amount! | ||
:return: dict of predictor name -> predictor object. | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Change our Dockerfile to download the data from HuggingFace before we start the app, so that we don't have to push the CSV to the git repo.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How fast (or slow) is it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It takes about a minute. We could also upload the preprocessed app dataset to HuggingFace, which would remove most of this time.