Add logging support (debug and info levels using -vv)
fialhocoelho committed Jun 11, 2024
1 parent f666016 commit c2eb6f4
Showing 9 changed files with 2,399 additions and 704 deletions.
53 changes: 37 additions & 16 deletions notebooks/chronos_usage.ipynb
@@ -9,21 +9,14 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:utils.nexdata:Loading config file ../config/config.yaml\n",
"INFO:utils.nexdata:waves_palmas train path: ../data/raw/santos_dataset/train/waves_palmas.parquet\n",
"INFO:utils.nexdata:waves_palmas test path: ../data/raw/santos_dataset/test/waves_palmas.parquet\n",
"INFO:utils.nexdata:current_praticagem train path: ../data/raw/santos_dataset/train/current_praticagem.parquet\n",
"INFO:utils.nexdata:current_praticagem test path: ../data/raw/santos_dataset/test/current_praticagem.parquet\n",
"INFO:utils.nexdata:Random seed: 42\n",
"INFO:utils.nexdata:Default device: cuda\n",
"INFO:utils.nexdata:Defining paths...\n"
"INFO:utils.nexdata: Defining paths...\n"
]
}
],
@@ -37,11 +30,27 @@
"from chronos import ChronosPipeline\n",
"import numpy as np\n",
"\n",
"import logging\n",
"\n",
"sys.path.append('../src/')\n",
"from utils.nexdata import *\n",
"from utils.nexutil import *\n",
"\n",
"params = NexData(nexus_folder='../')\n",
"# Simular argumentos da linha de comando\n",
"sys.argv = ['chronos.py', '-v']\n",
"\n",
"# Configure the root logger\n",
"# Parse arguments\n",
"args = parse_args()\n",
"log_level = get_log_level(args.verbose)\n",
"log_fmt = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'\n",
"\n",
"logging.basicConfig(level=log_level, format=log_fmt)\n",
"\n",
"\n",
"params = NexData(nexus_folder='../',\n",
" log_level = log_level)\n",
"\n",
"set_random_seeds(params.data_params['default_seed'])"
]
},
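For reference, `parse_args` and `get_log_level` are imported from the repo's `src/utils` modules, so their definitions are not part of this diff. The sketch below is a hypothetical reconstruction, assuming an argparse counted `-v` flag, of how a verbosity count could map to `logging.INFO`/`logging.DEBUG` as the commit message describes; the actual implementations may differ.

```python
# Hypothetical sketch -- not the repo's actual code: one plausible way
# parse_args() and get_log_level() could map -v/-vv to logging levels.
import argparse
import logging

def parse_args():
    parser = argparse.ArgumentParser(description='Chronos forecasting')
    # action='count' turns repeated -v flags into an integer (0, 1, 2, ...)
    parser.add_argument('-v', '--verbose', action='count', default=0,
                        help='-v for INFO, -vv for DEBUG')
    return parser.parse_args()

def get_log_level(verbosity):
    # 0 -> WARNING (quiet default), 1 -> INFO, 2 or more -> DEBUG
    if verbosity >= 2:
        return logging.DEBUG
    if verbosity == 1:
        return logging.INFO
    return logging.WARNING
```

Under these assumptions, the `sys.argv = ['chronos.py', '-v']` line in the cell above would yield `logging.INFO`, while `-vv` would yield `logging.DEBUG`.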
@@ -60,15 +69,20 @@
"source": [
"# Load the ChronosPipeline model from the pretrained\n",
"# 'amazon/chronos-t5-large' model\n",
"chronos_pipeline = ChronosPipeline.from_pretrained(\n",
" 'amazon/chronos-t5-large',\n",
" device_map='cuda',\n",
" torch_dtype=torch.bfloat16,\n",
")\n",
"try:\n",
" chronos_pipeline = ChronosPipeline.from_pretrained(\n",
" 'amazon/chronos-t5-large',\n",
" device_map='cuda',\n",
" torch_dtype=torch.bfloat16,\n",
" )\n",
" params.logger.info(f'Chronos model load with successfull o/')\n",
"except Exception as err:\n",
" params.logger.critical(f'Chronos model cannot be loaded. {err}')\n",
" raise\n",
"\n",
"# Iterate over each ocean variable defined in the parameters\n",
"for ocean_variable in params.features.keys():\n",
" print(f'Ocean variable: {ocean_variable}')\n",
" params.logger.debug(f'Ocean variable: {ocean_variable}')\n",
"\n",
" # Retrieve target features and experiment IDs\n",
" target_features = params.features[ocean_variable]\n",
@@ -78,8 +92,11 @@
" # Load train and test data for the target feature\n",
" df_train_target = pd.read_parquet(\n",
" target_features['train_filepath'])\n",
" params.logger.debug(f' df_train_target:\\n{df_train_target.head}')\n",
" df_test_target = pd.read_parquet(\n",
" target_features['test_filepath'])\n",
" params.logger.debug(f' df_test_target:\\n{df_test_target.head}')\n",
"\n",
"\n",
" # Process the training dataframe with specified parameters\n",
" df_train_processed_target = process_dataframe(\n",
@@ -90,6 +107,7 @@
" params.data_params['interp_method'],\n",
" params.data_params['datetime_col'],\n",
" params.data_params['round_freq'])\n",
" params.logger.debug(' df_train_target are processed.')\n",
"\n",
" # Process the test dataframe with specified parameters\n",
" df_test_processed_target = process_dataframe(\n",
@@ -100,6 +118,7 @@
" params.data_params['interp_method'],\n",
" params.data_params['datetime_col'],\n",
" params.data_params['round_freq'])\n",
" params.logger.debug(' df_test_target are processed.')\n",
"\n",
" # Define the context and forecast window lengths and shift\n",
" context_len = params.model_params['context_window_len']\n",
@@ -111,6 +130,7 @@
" X_test_index, y_test_index = generate_indices(\n",
" df_test_processed_target, context_len, forecast_len,\n",
" shift, mode)\n",
" params.logger.debug(' X_test_index, y_test_index are created.')\n",
"\n",
" # Initialize DataFrames for predictions and index of agreement (IOA) values\n",
" df_y_hat = pd.DataFrame()\n",
@@ -123,6 +143,7 @@
" df_y_hat.index, params.data_params['datetime_col']\n",
" ])\n",
"\n",
" params.logger.debug(' start loop from df features')\n",
" # Iterate over each target feature for prediction\n",
" for target_feature in target_features['list_features']:\n",
" y_hat = []\n",