Skip to content

Commit

Permalink
continue refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
wgifford committed Aug 2, 2024
1 parent a48b04f commit c71957f
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 44 deletions.
28 changes: 8 additions & 20 deletions notebooks/hfdemo/tinytimemixer/ttm_benchmarking_1024_96.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -32,27 +32,15 @@
}
],
"source": [
"# Standard\n",
"import math\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"\n",
"# Third Party\n",
"from torch.optim import AdamW\n",
"from torch.optim.lr_scheduler import OneCycleLR\n",
"from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed\n",
"\n",
"# Local\n",
"from tsfm_public.models.tinytimemixer import TinyTimeMixerForPrediction\n",
"from tsfm_public.models.tinytimemixer.utils import (\n",
" count_parameters,\n",
" get_data,\n",
" plot_preds,\n",
")\n",
"\n",
"# First Party\n",
"from tsfm_public.toolkit.callbacks import TrackingCallback"
"from tsfm_public import TinyTimeMixerForPrediction, TrackingCallback, count_parameters, load_dataset, plot_predictions"
]
},
{
Expand Down Expand Up @@ -187,7 +175,7 @@
" BATCH_SIZE = 64\n",
"\n",
" # Data prep: Get dataset\n",
" _, _, dset_test = get_data(DATASET, context_length, forecast_length, data_root_path=DATA_ROOT_PATH)\n",
" _, _, dset_test = load_dataset(DATASET, context_length, forecast_length, dataset_root_path=DATA_ROOT_PATH)\n",
"\n",
" #############################################################\n",
" ##### Use the pretrained model in zero-shot forecasting #####\n",
Expand All @@ -213,8 +201,8 @@
" all_results[\"zs_eval_time\"].append(zeroshot_output[\"eval_runtime\"])\n",
"\n",
" # Plot\n",
" plot_preds(\n",
" zeroshot_trainer,\n",
" plot_predictions(\n",
" zeroshot_trainer.models,\n",
" dset_test,\n",
" SUBDIR,\n",
" num_plots=10,\n",
Expand All @@ -233,12 +221,12 @@
" for fewshot_percent in [5, 10]:\n",
" print(\"-\" * 20, f\"Running few-shot {fewshot_percent}%\", \"-\" * 20)\n",
" # Data prep: Get dataset\n",
" dset_train, dset_val, dset_test = get_data(\n",
" dset_train, dset_val, dset_test = load_dataset(\n",
" DATASET,\n",
" context_length,\n",
" forecast_length,\n",
" fewshot_fraction=fewshot_percent / 100,\n",
" data_root_path=DATA_ROOT_PATH,\n",
" dataset_root_path=DATA_ROOT_PATH,\n",
" )\n",
"\n",
" # change head dropout to 0.7 for ett datasets\n",
Expand Down Expand Up @@ -327,8 +315,8 @@
" print(\"+\" * 60)\n",
"\n",
" # Plot\n",
" plot_preds(\n",
" finetune_forecast_trainer,\n",
" plot_predictions(\n",
" finetune_forecast_trainer.model,\n",
" dset_test,\n",
" SUBDIR,\n",
" num_plots=10,\n",
Expand Down
Loading

0 comments on commit c71957f

Please sign in to comment.