add z-score normalization
fialhocoelho committed Jun 23, 2024
1 parent dcd10b3 commit 6b903d3
Showing 1 changed file with 98 additions and 173 deletions.
271 changes: 98 additions & 173 deletions notebooks/timegpt_usage.ipynb
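The notebook below calls `normalize_z_score_df` and `denormalize_z_score` from the repo's utils, but their implementations are not part of this diff. A minimal sketch consistent with the call sites, with signatures inferred rather than confirmed:

```python
# Sketch only: names and signatures inferred from the call sites in this diff,
# not from the actual utils module.
import numpy as np

def normalize_z_score_df(df, means=None, stds=None):
    # Fit column-wise z-score stats on numeric columns, or reuse supplied ones.
    numeric = df.select_dtypes(include=np.number).columns
    if means is None:
        means, stds = df[numeric].mean(), df[numeric].std()
    out = df.copy()
    out[numeric] = (df[numeric] - means) / stds
    return out, means, stds

def denormalize_z_score(values, mean, std):
    # Invert the transform for a single column of model output.
    return np.asarray(values) * std + mean
```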
@@ -9,7 +9,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 1,
"metadata": {},
"outputs": [
{
@@ -57,7 +57,7 @@
"\n",
"set_random_seeds(params.data_params['default_seed'])\n",
"\n",
"id_experiment = 'timegpt_forecast_target_features'"
"id_experiment = 'timegpt_znorm_forecast_target_features'"
]
},
{
@@ -69,7 +69,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"metadata": {},
"outputs": [
{
@@ -84,8 +84,8 @@
"try:\n",
" nixtla_client = NixtlaClient(\n",
" api_key = load_api_key(\"../config/nixtla_api.key\"),\n",
" max_retries=30,\n",
" retry_interval=30,\n",
" max_retries=params.model_params['attempts_after_failure'],\n",
" retry_interval=params.model_params['retry_interval'],\n",
" )\n",
" params.logger.info(f' TimeGPT model load with successfull o/')\n",
"except Exception as err:\n",
@@ -152,6 +152,8 @@
" target_features['train_filepath'])\n",
" df_test_target = pd.read_parquet(\n",
" target_features['test_filepath'])\n",
" #print(f'Train filepath: {target_features['train_filepath']}')\n",
" #print(f'Test filepath: {target_features['test_filepath']}')\n",
"\n",
" try:\n",
" # Process the training dataframe with specified parameters\n",
Expand Down Expand Up @@ -192,7 +194,8 @@
" X_test_index, y_test_index = generate_indices(\n",
" df_test_processed_target, context_len, forecast_len,\n",
" shift, mode)\n",
" #params.logger.debug(' X_test_index, y_test_index are created.')\n",
"\n",
" len_X_test_index = len(X_test_index)\n",
"\n",
" # Initialize DataFrames for predictions and index of agreement (IOA) values\n",
" df_y_hat = pd.DataFrame()\n",
@@ -205,7 +208,6 @@
" df_y_hat.index, params.data_params['datetime_col']\n",
" ])\n",
"\n",
" #params.logger.debug(' start loop from df features')\n",
" # Iterate over each target feature for prediction\n",
" for target_feature in target_features['list_features']:\n",
" # Add training data to improve the size of the inference data\n",
@@ -214,13 +216,10 @@
"\n",
" test_signal = df_test_processed_target.loc[:, \n",
" target_feature]\n",
" \n",
" len_X_test_index = len(X_test_index)\n",
"\n",
" df_target_col_name = f'{ocean_variable}_{target_feature}'\n",
" params.logger.debug(\n",
" f\" Target feature: {target_feature} | {df_target_col_name}\")\n",
" \n",
"\n",
" df_train_composed[df_target_col_name] = train_signal.values\n",
" df_test_composed[df_target_col_name] = test_signal.values\n",
@@ -244,171 +243,21 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 179/179 [33:55<00:00, 11.37s/it]\n"
" 18%|█▊ | 33/179 [04:41<20:46, 8.54s/it]INFO:utils.nexdata: \n",
"target feature: current_praticagem_cross_shore_current\n",
"INFO:nixtla.nixtla_client:Validating inputs...\n",
"INFO:nixtla.nixtla_client:Preprocessing dataframes...\n",
"INFO:nixtla.nixtla_client:Using the following exogenous variables: sofs_praticagem_cross_shore_current, sofs_praticagem_ssh, astronomical_tide_astronomical_tide\n",
"INFO:nixtla.nixtla_client:Calling Forecast Endpoint...\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>datetime</th>\n",
" <th>current_praticagem_cross_shore_current</th>\n",
" <th>waves_palmas_hs</th>\n",
" <th>waves_palmas_tp</th>\n",
" <th>waves_palmas_ws</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>168</th>\n",
" <td>2022-01-08 00:00:00</td>\n",
" <td>-0.496241</td>\n",
" <td>1.205139</td>\n",
" <td>10.518562</td>\n",
" <td>-0.015064</td>\n",
" </tr>\n",
" <tr>\n",
" <th>169</th>\n",
" <td>2022-01-08 01:00:00</td>\n",
" <td>-0.475252</td>\n",
" <td>1.117056</td>\n",
" <td>10.122596</td>\n",
" <td>0.002176</td>\n",
" </tr>\n",
" <tr>\n",
" <th>170</th>\n",
" <td>2022-01-08 02:00:00</td>\n",
" <td>-0.221233</td>\n",
" <td>1.495980</td>\n",
" <td>10.170714</td>\n",
" <td>0.200205</td>\n",
" </tr>\n",
" <tr>\n",
" <th>171</th>\n",
" <td>2022-01-08 03:00:00</td>\n",
" <td>-0.304862</td>\n",
" <td>1.322694</td>\n",
" <td>9.801672</td>\n",
" <td>0.191701</td>\n",
" </tr>\n",
" <tr>\n",
" <th>172</th>\n",
" <td>2022-01-08 04:00:00</td>\n",
" <td>-0.258456</td>\n",
" <td>1.231174</td>\n",
" <td>10.200695</td>\n",
" <td>0.186209</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8755</th>\n",
" <td>2022-12-31 19:00:00</td>\n",
" <td>0.117516</td>\n",
" <td>0.660472</td>\n",
" <td>9.674246</td>\n",
" <td>0.078386</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8756</th>\n",
" <td>2022-12-31 20:00:00</td>\n",
" <td>0.144268</td>\n",
" <td>0.662330</td>\n",
" <td>9.598036</td>\n",
" <td>0.114077</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8757</th>\n",
" <td>2022-12-31 21:00:00</td>\n",
" <td>-0.021904</td>\n",
" <td>0.564333</td>\n",
" <td>9.813744</td>\n",
" <td>0.033776</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8758</th>\n",
" <td>2022-12-31 22:00:00</td>\n",
" <td>0.150519</td>\n",
" <td>0.658883</td>\n",
" <td>9.335948</td>\n",
" <td>0.162465</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8759</th>\n",
" <td>2022-12-31 23:00:00</td>\n",
" <td>0.177804</td>\n",
" <td>0.735480</td>\n",
" <td>9.691661</td>\n",
" <td>0.144297</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>8592 rows × 5 columns</p>\n",
"</div>"
],
"text/plain": [
" datetime current_praticagem_cross_shore_current \\\n",
"168 2022-01-08 00:00:00 -0.496241 \n",
"169 2022-01-08 01:00:00 -0.475252 \n",
"170 2022-01-08 02:00:00 -0.221233 \n",
"171 2022-01-08 03:00:00 -0.304862 \n",
"172 2022-01-08 04:00:00 -0.258456 \n",
"... ... ... \n",
"8755 2022-12-31 19:00:00 0.117516 \n",
"8756 2022-12-31 20:00:00 0.144268 \n",
"8757 2022-12-31 21:00:00 -0.021904 \n",
"8758 2022-12-31 22:00:00 0.150519 \n",
"8759 2022-12-31 23:00:00 0.177804 \n",
"\n",
" waves_palmas_hs waves_palmas_tp waves_palmas_ws \n",
"168 1.205139 10.518562 -0.015064 \n",
"169 1.117056 10.122596 0.002176 \n",
"170 1.495980 10.170714 0.200205 \n",
"171 1.322694 9.801672 0.191701 \n",
"172 1.231174 10.200695 0.186209 \n",
"... ... ... ... \n",
"8755 0.660472 9.674246 0.078386 \n",
"8756 0.662330 9.598036 0.114077 \n",
"8757 0.564333 9.813744 0.033776 \n",
"8758 0.658883 9.335948 0.162465 \n",
"8759 0.735480 9.691661 0.144297 \n",
"\n",
"[8592 rows x 5 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fcst_df_full = pd.DataFrame()\n",
"\n",
"fcst_df_full.index = pd.RangeIndex(start=context_len,\n",
" stop=df_test_composed.shape[0],\n",
" step=1)\n",
" stop=df_test_composed.shape[0],\n",
" step=1)\n",
"fcst_df_full[params.data_params['datetime_col']] = df_y_hat[\n",
" params.data_params['datetime_col']]\n",
"\n",
@@ -423,11 +272,23 @@
" y_test_df = df_test_composed.loc[y_test_index[idx], :]\n",
"\n",
" # Concatenate training and test signals\n",
" combined_df = pd.concat([df_train_composed,\n",
" X_test_df], axis=0).reset_index(drop=True)\n",
" if params.model_params['normalize']:\n",
" aux_combined_df = pd.concat([df_train_composed,\n",
" X_test_df], axis=0).reset_index(\n",
" drop=True)\n",
" combined_df, _,_ = normalize_z_score_df(aux_combined_df)\n",
" df_cct = pd.concat([df_train_composed,y_test_df], axis=0)\n",
" _, means, stds = normalize_z_score_df(df_cct.reset_index(drop=True))\n",
" y_test_df_norm , _, _ = normalize_z_score_df(y_test_df, means, stds)\n",
" else:\n",
" combined_df = pd.concat([df_train_composed,\n",
" X_test_df], axis=0).reset_index(drop=True) \n",
" \n",
" attempts = params.model_params['attempts_after_failure']\n",
"\n",
" exog_df = y_test_df_norm[exog_list] if params.model_params['normalize']\\\n",
" else y_test_df[exog_list] # Somente features exógenas\n",
"\n",
" for attempt in range(attempts):\n",
" try:\n",
" fcst_df = nixtla_client.forecast(\n",
@@ -436,7 +297,7 @@
" freq=params.data_params['target_freq'],\n",
" time_col=params.data_params['datetime_col'],\n",
" target_col=tgt_feature,\n",
" X_df=y_test_df[exog_list], # Somente features exógenas\n",
" X_df = exog_df,\n",
" model=params.model_params['timegpt_model'],\n",
" #finetune_steps=params.model_params[\n",
" # 'timegpt_finetune_steps'],\n",
@@ -454,7 +315,12 @@
" sys.exit(1)\n",
" \n",
" dt_list.extend(fcst_df.datetime.values)\n",
" fcst_list.extend(fcst_df.TimeGPT.values)\n",
" if params.model_params['normalize']:\n",
" fcst_list.extend(denormalize_z_score(fcst_df.TimeGPT.values,\n",
" means[tgt_feature],\n",
" stds[tgt_feature]))\n",
" else:\n",
" fcst_list.extend(fcst_df.TimeGPT.values)\n",
" clear_output(wait=True)\n",
" #TODO: Pegar tbm o valor medido para entrar no df\n",
"\n",
@@ -466,13 +332,72 @@
" f\"{id_experiment}_\"\n",
" f\"{params.timestamp}.pkl\")\n",
"\n",
"params.logger.info(f'Output file: {filename}')\n",
"params.logger.info(f' Output file: {filename}')\n",
"\n",
"fcst_df_full.to_parquet(filename)\n",
"\n",
"display(fcst_df_full)"
]
},
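The comparison cells below score each forecast window with `calculate_ioa`, which is also not defined in this commit. The name suggests Willmott's index of agreement; a sketch under that assumption:

```python
# Assumption: calculate_ioa implements Willmott's index of agreement,
# d = 1 - SSE / potential error. The repo's actual metric is not in this diff.
import numpy as np

def calculate_ioa(observed, predicted):
    obs_mean = observed.mean()
    sse = np.sum((predicted - observed) ** 2)
    pot = np.sum((np.abs(predicted - obs_mean) + np.abs(observed - obs_mean)) ** 2)
    return 1.0 - sse / pot
```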
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for tgt_feature in ['waves_palmas_hs']:\n",
" ioa_list_norm = []\n",
" ioa_list_non_norm = []\n",
" for idx in range(len(y_test_index)):\n",
" ioa_norm = calculate_ioa(df_test_composed.loc[y_test_index[idx],tgt_feature].values,\n",
" fcst_df_full.loc[y_test_index[idx],tgt_feature].values)\n",
" ioa_list_norm.append(ioa_norm)\n",
"\n",
" ioa_non_norm = calculate_ioa(df_test_composed.loc[y_test_index[idx], tgt_feature].values,\n",
" fcst_df_without_normalized.loc[y_test_index[idx],tgt_feature].values)\n",
" ioa_list_non_norm.append(ioa_non_norm)\n",
"\n",
" plt.plot(np.cumsum(ioa_list_norm), label='norm')\n",
" plt.plot(np.cumsum(ioa_list_non_norm), label='non_norm')\n",
" #plt.plot(df_test_composed.loc[y_test_index[idx],'datetime'],\n",
" # df_test_composed.loc[y_test_index[idx],tgt_feature],\n",
" # label='Measured')\n",
" plt.legend()\n",
" plt.show()\n",
"\n",
" print(f'{tgt_feature} norm mean: {np.std(ioa_list_norm)}')\n",
" print(f'{tgt_feature} non-norm mean: {np.std(ioa_list_non_norm)}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for tgt_feature in ['waves_palmas_hs']:\n",
" for idx in range(150, 160):\n",
" print(f'ioa norm: {calculate_ioa(df_test_composed.loc[y_test_index[idx],\n",
" tgt_feature].values,\n",
" fcst_df_full.loc[y_test_index[idx],\n",
" tgt_feature].values)}')\n",
" print(f'ioa non_norm: {calculate_ioa(df_test_composed.loc[y_test_index[idx],\n",
" tgt_feature].values,\n",
" fcst_df_without_normalized.loc[y_test_index[idx],\n",
" tgt_feature].values)}')\n",
" plt.plot(fcst_df_full.loc[y_test_index[idx],'datetime'],\n",
" fcst_df_full.loc[y_test_index[idx],tgt_feature],\n",
" label=tgt_feature)\n",
" plt.plot(fcst_df_without_normalized.loc[y_test_index[idx],'datetime'],\n",
" fcst_df_without_normalized.loc[y_test_index[idx],tgt_feature],\n",
" label='non_norm')\n",
" plt.plot(df_test_composed.loc[y_test_index[idx],'datetime'],\n",
" df_test_composed.loc[y_test_index[idx],tgt_feature],\n",
" label='Measured')\n",
" plt.legend()\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
