Skip to content

Commit

Permalink
apply style to notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
wgifford committed Aug 2, 2024
1 parent 38ca85c commit a48b04f
Show file tree
Hide file tree
Showing 19 changed files with 1,059 additions and 962 deletions.
21 changes: 8 additions & 13 deletions notebooks/hfdemo/patch_tsmixer_blog.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -96,28 +96,27 @@
"source": [
"# Standard\n",
"import os\n",
"import random\n",
"\n",
"# supress some warnings\n",
"import warnings\n",
"\n",
"import pandas as pd\n",
"\n",
"# Third Party\n",
"from transformers import (\n",
" EarlyStoppingCallback,\n",
" PatchTSMixerConfig,\n",
" PatchTSMixerForPrediction,\n",
" set_seed,\n",
" Trainer,\n",
" TrainingArguments,\n",
" set_seed,\n",
")\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"\n",
"# First Party\n",
"from tsfm_public.toolkit.dataset import ForecastDFDataset\n",
"from tsfm_public.toolkit.time_series_preprocessor import TimeSeriesPreprocessor\n",
"from tsfm_public.toolkit.util import select_by_index\n",
"\n",
"# supress some warnings\n",
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\", module=\"torch\")"
]
Expand Down Expand Up @@ -916,9 +915,7 @@
],
"source": [
"print(\"Loading pretrained model\")\n",
"finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\n",
" \"patchtsmixer_4/electricity/model/pretrain/\"\n",
")\n",
"finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\"patchtsmixer_4/electricity/model/pretrain/\")\n",
"print(\"Done\")"
]
},
Expand Down Expand Up @@ -1323,9 +1320,7 @@
],
"source": [
"# Reload the model\n",
"finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\n",
" \"patchtsmixer_4/electricity/model/pretrain/\"\n",
")\n",
"finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\"patchtsmixer_4/electricity/model/pretrain/\")\n",
"finetune_forecast_trainer = Trainer(\n",
" model=finetune_forecast_model,\n",
" args=finetune_forecast_args,\n",
Expand Down
12 changes: 5 additions & 7 deletions notebooks/hfdemo/patch_tsmixer_getting_started.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@
],
"source": [
"# Standard\n",
"import os\n",
"import random\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"\n",
"# Third Party\n",
"from transformers import (\n",
" EarlyStoppingCallback,\n",
Expand All @@ -43,9 +46,6 @@
" Trainer,\n",
" TrainingArguments,\n",
")\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"\n",
"# First Party\n",
"from tsfm_public.toolkit.dataset import ForecastDFDataset\n",
Expand Down Expand Up @@ -321,9 +321,7 @@
],
"source": [
"print(\"Loading pretrained model\")\n",
"inference_forecast_model = PatchTSMixerForPrediction.from_pretrained(\n",
" \"ibm-granite/granite-timeseries-patchtsmixer\"\n",
")\n",
"inference_forecast_model = PatchTSMixerForPrediction.from_pretrained(\"ibm-granite/granite-timeseries-patchtsmixer\")\n",
"print(\"Done\")"
]
},
Expand Down
15 changes: 6 additions & 9 deletions notebooks/hfdemo/patch_tsmixer_transfer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@
"import os\n",
"import random\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"\n",
"# Third Party\n",
"from transformers import (\n",
" EarlyStoppingCallback,\n",
Expand All @@ -57,9 +61,6 @@
" Trainer,\n",
" TrainingArguments,\n",
")\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"\n",
"# First Party\n",
"from tsfm_public.toolkit.dataset import ForecastDFDataset\n",
Expand Down Expand Up @@ -923,9 +924,7 @@
],
"source": [
"print(\"Loading pretrained model\")\n",
"finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\n",
" \"patchtsmixer/electricity/model/pretrain/\"\n",
")\n",
"finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\"patchtsmixer/electricity/model/pretrain/\")\n",
"print(\"Done\")"
]
},
Expand Down Expand Up @@ -1415,9 +1414,7 @@
],
"source": [
"# Reload the model\n",
"finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\n",
" \"patchtsmixer/electricity/model/pretrain/\"\n",
")\n",
"finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\"patchtsmixer/electricity/model/pretrain/\")\n",
"finetune_forecast_trainer = Trainer(\n",
" model=finetune_forecast_model,\n",
" args=finetune_forecast_args,\n",
Expand Down
12 changes: 5 additions & 7 deletions notebooks/hfdemo/patch_tst_getting_started.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@
"outputs": [],
"source": [
"# Standard\n",
"import os\n",
"import random\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"\n",
"# Third Party\n",
"from transformers import (\n",
" EarlyStoppingCallback,\n",
Expand All @@ -30,9 +33,6 @@
" Trainer,\n",
" TrainingArguments,\n",
")\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"\n",
"# First Party\n",
"from tsfm_public.toolkit.dataset import ForecastDFDataset\n",
Expand Down Expand Up @@ -309,9 +309,7 @@
],
"source": [
"print(\"Loading pretrained model\")\n",
"inference_forecast_model = PatchTSTForPrediction.from_pretrained(\n",
" \"ibm-granite/granite-timeseries-patchtst\"\n",
")\n",
"inference_forecast_model = PatchTSTForPrediction.from_pretrained(\"ibm-granite/granite-timeseries-patchtst\")\n",
"print(\"Done\")"
]
},
Expand Down
7 changes: 4 additions & 3 deletions notebooks/hfdemo/patch_tst_transfer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
"import os\n",
"import random\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"\n",
"# Third Party\n",
"from transformers import (\n",
" EarlyStoppingCallback,\n",
Expand All @@ -43,9 +47,6 @@
" Trainer,\n",
" TrainingArguments,\n",
")\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"\n",
"# First Party\n",
"from tsfm_public.toolkit.dataset import ForecastDFDataset\n",
Expand Down
48 changes: 19 additions & 29 deletions notebooks/hfdemo/tinytimemixer/ttm_benchmarking_1024_96.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,24 @@
"# Standard\n",
"import math\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"\n",
"# Third Party\n",
"from torch.optim import AdamW\n",
"from torch.optim.lr_scheduler import OneCycleLR\n",
"from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# First Party\n",
"from tsfm_public.toolkit.callbacks import TrackingCallback\n",
"# Local\n",
"from tsfm_public.models.tinytimemixer import TinyTimeMixerForPrediction\n",
"from tsfm_public.models.tinytimemixer.utils import (\n",
" count_parameters,\n",
" get_data,\n",
" plot_preds,\n",
")\n",
"\n",
"# Local\n",
"from tsfm_public.models.tinytimemixer import TinyTimeMixerForPrediction\n"
"# First Party\n",
"from tsfm_public.toolkit.callbacks import TrackingCallback"
]
},
{
Expand Down Expand Up @@ -84,13 +84,13 @@
"\n",
"# Make sure all the datasets in the following `list_datasets` are\n",
"# saved in the `DATA_ROOT_PATH` folder. Or, change it accordingly.\n",
"# Refer to the get_data() function \n",
"# in notebooks/hfdemo/tinytimemixer/utils/ttm_utils.py \n",
"# Refer to the get_data() function\n",
"# in notebooks/hfdemo/tinytimemixer/utils/ttm_utils.py\n",
"# to see how it is used.\n",
"DATA_ROOT_PATH = \"datasets/\"\n",
"\n",
"# This is where results will be saved\n",
"OUT_DIR = \"ttm_results_benchmark_1024_96_tmp\"\n"
"OUT_DIR = \"ttm_results_benchmark_1024_96_tmp\""
]
},
{
Expand All @@ -115,7 +115,7 @@
" \"weather\",\n",
" \"electricity\",\n",
" \"traffic\",\n",
"]\n"
"]"
]
},
{
Expand All @@ -138,9 +138,7 @@
"elif context_length == 1024:\n",
" hf_model_branch = \"1024_96_v1\"\n",
"else:\n",
" raise ValueError(\n",
" \"Current supported context lengths are 512 and 1024. Stay tuned for more TTMs!\"\n",
" )\n"
" raise ValueError(\"Current supported context lengths are 512 and 1024. Stay tuned for more TTMs!\")"
]
},
{
Expand Down Expand Up @@ -195,9 +193,7 @@
" ##### Use the pretrained model in zero-shot forecasting #####\n",
" #############################################################\n",
" # Load model\n",
" zeroshot_model = TinyTimeMixerForPrediction.from_pretrained(\n",
" hf_model_path, revision=hf_model_branch\n",
" )\n",
" zeroshot_model = TinyTimeMixerForPrediction.from_pretrained(hf_model_path, revision=hf_model_branch)\n",
"\n",
" # zeroshot_trainer\n",
" zeroshot_trainer = Trainer(\n",
Expand Down Expand Up @@ -242,7 +238,7 @@
" context_length,\n",
" forecast_length,\n",
" fewshot_fraction=fewshot_percent / 100,\n",
" data_root_path=DATA_ROOT_PATH\n",
" data_root_path=DATA_ROOT_PATH,\n",
" )\n",
"\n",
" # change head dropout to 0.7 for ett datasets\n",
Expand Down Expand Up @@ -343,20 +339,14 @@
"\n",
" # write results\n",
" all_results[f\"fs{fewshot_percent}_mse\"].append(fewshot_output[\"eval_loss\"])\n",
" all_results[f\"fs{fewshot_percent}_mean_epoch_time\"].append(\n",
" tracking_callback.mean_epoch_time\n",
" )\n",
" all_results[f\"fs{fewshot_percent}_total_train_time\"].append(\n",
" tracking_callback.total_train_time\n",
" )\n",
" all_results[f\"fs{fewshot_percent}_best_val_metric\"].append(\n",
" tracking_callback.best_eval_metric\n",
" )\n",
" all_results[f\"fs{fewshot_percent}_mean_epoch_time\"].append(tracking_callback.mean_epoch_time)\n",
" all_results[f\"fs{fewshot_percent}_total_train_time\"].append(tracking_callback.total_train_time)\n",
" all_results[f\"fs{fewshot_percent}_best_val_metric\"].append(tracking_callback.best_eval_metric)\n",
"\n",
" df_out = pd.DataFrame(all_results).round(3)\n",
" print(df_out[[\"dataset\", \"zs_mse\", \"fs5_mse\", \"fs10_mse\"]])\n",
" df_out.to_csv(f\"{OUT_DIR}/results_zero_few.csv\")\n",
" df_out.to_csv(f\"{OUT_DIR}/results_zero_few.csv\")\n"
" df_out.to_csv(f\"{OUT_DIR}/results_zero_few.csv\")"
]
},
{
Expand Down Expand Up @@ -544,7 +534,7 @@
}
],
"source": [
"df_out\n"
"df_out"
]
},
{
Expand Down
Loading

0 comments on commit a48b04f

Please sign in to comment.