From c71957f6c46207b2cf44aafd46654e48d7cc3607 Mon Sep 17 00:00:00 2001 From: Wesley Gifford <79663411+wgifford@users.noreply.github.com> Date: Fri, 2 Aug 2024 15:03:20 -0400 Subject: [PATCH] continue refactor --- .../ttm_benchmarking_1024_96.ipynb | 28 ++-- .../ttm_benchmarking_512_96.ipynb | 121 +++++++++++++++--- tsfm_public/toolkit/data_handling.py | 4 +- 3 files changed, 109 insertions(+), 44 deletions(-) diff --git a/notebooks/hfdemo/tinytimemixer/ttm_benchmarking_1024_96.ipynb b/notebooks/hfdemo/tinytimemixer/ttm_benchmarking_1024_96.ipynb index 69106742..ae2c1d71 100644 --- a/notebooks/hfdemo/tinytimemixer/ttm_benchmarking_1024_96.ipynb +++ b/notebooks/hfdemo/tinytimemixer/ttm_benchmarking_1024_96.ipynb @@ -32,27 +32,15 @@ } ], "source": [ - "# Standard\n", "import math\n", "\n", "import matplotlib.pyplot as plt\n", "import pandas as pd\n", - "\n", - "# Third Party\n", "from torch.optim import AdamW\n", "from torch.optim.lr_scheduler import OneCycleLR\n", "from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed\n", "\n", - "# Local\n", - "from tsfm_public.models.tinytimemixer import TinyTimeMixerForPrediction\n", - "from tsfm_public.models.tinytimemixer.utils import (\n", - " count_parameters,\n", - " get_data,\n", - " plot_preds,\n", - ")\n", - "\n", - "# First Party\n", - "from tsfm_public.toolkit.callbacks import TrackingCallback" + "from tsfm_public import TinyTimeMixerForPrediction, TrackingCallback, count_parameters, load_dataset, plot_predictions" ] }, { @@ -187,7 +175,7 @@ " BATCH_SIZE = 64\n", "\n", " # Data prep: Get dataset\n", - " _, _, dset_test = get_data(DATASET, context_length, forecast_length, data_root_path=DATA_ROOT_PATH)\n", + " _, _, dset_test = load_dataset(DATASET, context_length, forecast_length, dataset_root_path=DATA_ROOT_PATH)\n", "\n", " #############################################################\n", " ##### Use the pretrained model in zero-shot forecasting #####\n", @@ -213,8 +201,8 @@ " 
all_results[\"zs_eval_time\"].append(zeroshot_output[\"eval_runtime\"])\n", "\n", " # Plot\n", - " plot_preds(\n", - " zeroshot_trainer,\n", + " plot_predictions(\n", + " zeroshot_trainer.model,\n", " dset_test,\n", " SUBDIR,\n", " num_plots=10,\n", @@ -233,12 +221,12 @@ " for fewshot_percent in [5, 10]:\n", " print(\"-\" * 20, f\"Running few-shot {fewshot_percent}%\", \"-\" * 20)\n", " # Data prep: Get dataset\n", - " dset_train, dset_val, dset_test = get_data(\n", + " dset_train, dset_val, dset_test = load_dataset(\n", " DATASET,\n", " context_length,\n", " forecast_length,\n", " fewshot_fraction=fewshot_percent / 100,\n", - " data_root_path=DATA_ROOT_PATH,\n", + " dataset_root_path=DATA_ROOT_PATH,\n", " )\n", "\n", " # change head dropout to 0.7 for ett datasets\n", @@ -327,8 +315,8 @@ " print(\"+\" * 60)\n", "\n", " # Plot\n", - " plot_preds(\n", - " finetune_forecast_trainer,\n", + " plot_predictions(\n", + " finetune_forecast_trainer.model,\n", " dset_test,\n", " SUBDIR,\n", " num_plots=10,\n", diff --git a/notebooks/hfdemo/tinytimemixer/ttm_benchmarking_512_96.ipynb b/notebooks/hfdemo/tinytimemixer/ttm_benchmarking_512_96.ipynb index 0ab91550..f077330a 100644 --- a/notebooks/hfdemo/tinytimemixer/ttm_benchmarking_512_96.ipynb +++ b/notebooks/hfdemo/tinytimemixer/ttm_benchmarking_512_96.ipynb @@ -22,15 +22,7 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-04-08 13:35:32.541840: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" - ] - } - ], + "outputs": [], "source": [ "import math\n", "\n", @@ -40,7 +32,7 @@ "from torch.optim.lr_scheduler import OneCycleLR\n", "from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed\n", "\n", - "from tsfm_public import TinyTimeMixerForPrediction, TrackingCallback, count_parameters, get_data, plot_preds" + "from tsfm_public import 
TinyTimeMixerForPrediction, TrackingCallback, count_parameters, load_dataset, plot_predictions" ] }, { @@ -139,9 +131,96 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "====================================================================================================\n", + "Running zero-shot/few-shot for TTM-512 on dataset = etth1, forecast_len = 96\n", + "Model will be loaded from ibm/TTM/main\n", + "etth1 512 96\n", + "Data lengths: train = 8033, val = 2785, test = 2785\n", + "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d82abfe284a749f79144aa614219fefd", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/44 [00:00 160\u001b[0m \u001b[43mfinetune_forecast_trainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 162\u001b[0m \u001b[38;5;66;03m# Evaluation\u001b[39;00m\n\u001b[1;32m 163\u001b[0m \u001b[38;5;28mprint\u001b[39m(\n\u001b[1;32m 164\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m+\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m20\u001b[39m,\n\u001b[1;32m 165\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTest MSE after few-shot \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfewshot_percent\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m% fine-tuning\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 166\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m+\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m20\u001b[39m,\n\u001b[1;32m 167\u001b[0m )\n", + "File \u001b[0;32m~/miniconda3/envs/tsfm_public/lib/python3.10/site-packages/transformers/trainer.py:1885\u001b[0m, in 
\u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 1883\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n\u001b[1;32m 1884\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1885\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1886\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1887\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1888\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1889\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1890\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/tsfm_public/lib/python3.10/site-packages/transformers/trainer.py:2178\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 2175\u001b[0m rng_to_sync \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 2177\u001b[0m step \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m\n\u001b[0;32m-> 2178\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m step, inputs \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(epoch_iterator):\n\u001b[1;32m 2179\u001b[0m total_batched_samples \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 2181\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39minclude_num_input_tokens_seen:\n", + "File \u001b[0;32m~/miniconda3/envs/tsfm_public/lib/python3.10/site-packages/accelerate/data_loader.py:464\u001b[0m, in \u001b[0;36mDataLoaderShard.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 462\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 463\u001b[0m current_batch \u001b[38;5;241m=\u001b[39m send_to_device(current_batch, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice, non_blocking\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_non_blocking)\n\u001b[0;32m--> 464\u001b[0m next_batch \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mdataloader_iter\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 465\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m batch_index \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mskip_batches:\n\u001b[1;32m 466\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m current_batch\n", + "File \u001b[0;32m~/miniconda3/envs/tsfm_public/lib/python3.10/site-packages/torch/utils/data/dataloader.py:631\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 628\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 629\u001b[0m \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m 630\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset() \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 631\u001b[0m data \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 632\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 633\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 634\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 635\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called:\n", + "File \u001b[0;32m~/miniconda3/envs/tsfm_public/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1318\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1315\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1316\u001b[0m \u001b[38;5;66;03m# no valid `self._rcvd_idx` is found (i.e., didn't break)\u001b[39;00m\n\u001b[1;32m 1317\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_persistent_workers:\n\u001b[0;32m-> 1318\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_shutdown_workers\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1319\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m\n\u001b[1;32m 1321\u001b[0m \u001b[38;5;66;03m# Now `self._rcvd_idx` is the batch index we want to 
fetch\u001b[39;00m\n\u001b[1;32m 1322\u001b[0m \n\u001b[1;32m 1323\u001b[0m \u001b[38;5;66;03m# Check if the next sample has already been generated\u001b[39;00m\n", + "File \u001b[0;32m~/miniconda3/envs/tsfm_public/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1443\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._shutdown_workers\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1438\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_mark_worker_as_unavailable(worker_id, shutdown\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 1439\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_workers:\n\u001b[1;32m 1440\u001b[0m \u001b[38;5;66;03m# We should be able to join here, but in case anything went\u001b[39;00m\n\u001b[1;32m 1441\u001b[0m \u001b[38;5;66;03m# wrong, we set a timeout and if the workers fail to join,\u001b[39;00m\n\u001b[1;32m 1442\u001b[0m \u001b[38;5;66;03m# they are killed in the `finally` block.\u001b[39;00m\n\u001b[0;32m-> 1443\u001b[0m \u001b[43mw\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_utils\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mMP_STATUS_CHECK_INTERVAL\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1444\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m q \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_index_queues:\n\u001b[1;32m 1445\u001b[0m q\u001b[38;5;241m.\u001b[39mcancel_join_thread()\n", + "File \u001b[0;32m~/miniconda3/envs/tsfm_public/lib/python3.10/multiprocessing/process.py:149\u001b[0m, in \u001b[0;36mBaseProcess.join\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_parent_pid \u001b[38;5;241m==\u001b[39m 
os\u001b[38;5;241m.\u001b[39mgetpid(), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcan only join a child process\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 148\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_popen \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcan only join a started process\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m--> 149\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_popen\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwait\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m res \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 151\u001b[0m _children\u001b[38;5;241m.\u001b[39mdiscard(\u001b[38;5;28mself\u001b[39m)\n", + "File \u001b[0;32m~/miniconda3/envs/tsfm_public/lib/python3.10/multiprocessing/popen_fork.py:40\u001b[0m, in \u001b[0;36mPopen.wait\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m timeout \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 39\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmultiprocessing\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mconnection\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m wait\n\u001b[0;32m---> 40\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[43mwait\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msentinel\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m 41\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 42\u001b[0m \u001b[38;5;66;03m# This shouldn't block if wait() returned successfully.\u001b[39;00m\n", + "File \u001b[0;32m~/miniconda3/envs/tsfm_public/lib/python3.10/multiprocessing/connection.py:931\u001b[0m, in \u001b[0;36mwait\u001b[0;34m(object_list, timeout)\u001b[0m\n\u001b[1;32m 928\u001b[0m deadline \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mmonotonic() \u001b[38;5;241m+\u001b[39m timeout\n\u001b[1;32m 930\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[0;32m--> 931\u001b[0m ready \u001b[38;5;241m=\u001b[39m \u001b[43mselector\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mselect\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 932\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m ready:\n\u001b[1;32m 933\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m [key\u001b[38;5;241m.\u001b[39mfileobj \u001b[38;5;28;01mfor\u001b[39;00m (key, events) \u001b[38;5;129;01min\u001b[39;00m ready]\n", + "File \u001b[0;32m~/miniconda3/envs/tsfm_public/lib/python3.10/selectors.py:416\u001b[0m, in \u001b[0;36m_PollLikeSelector.select\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 414\u001b[0m ready \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 415\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 416\u001b[0m fd_event_list \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_selector\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpoll\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 417\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mInterruptedError\u001b[39;00m:\n\u001b[1;32m 418\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ready\n", + 
"\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], "source": [ "all_results = {\n", " \"dataset\": [],\n", @@ -175,7 +254,7 @@ " BATCH_SIZE = 64\n", "\n", " # Data prep: Get dataset\n", - " _, _, dset_test = get_data(DATASET, context_length, forecast_length, data_root_path=DATA_ROOT_PATH)\n", + " _, _, dset_test = load_dataset(DATASET, context_length, forecast_length, dataset_root_path=DATA_ROOT_PATH)\n", "\n", " #############################################################\n", " ##### Use the pretrained model in zero-shot forecasting #####\n", @@ -201,8 +280,8 @@ " all_results[\"zs_eval_time\"].append(zeroshot_output[\"eval_runtime\"])\n", "\n", " # Plot\n", - " plot_preds(\n", - " zeroshot_trainer,\n", + " plot_predictions(\n", + " zeroshot_trainer.model,\n", " dset_test,\n", " SUBDIR,\n", " num_plots=10,\n", @@ -221,12 +300,12 @@ " for fewshot_percent in [5, 10]:\n", " print(\"-\" * 20, f\"Running few-shot {fewshot_percent}%\", \"-\" * 20)\n", " # Data prep: Get dataset\n", - " dset_train, dset_val, dset_test = get_data(\n", + " dset_train, dset_val, dset_test = load_dataset(\n", " DATASET,\n", " context_length,\n", " forecast_length,\n", " fewshot_fraction=fewshot_percent / 100,\n", - " data_root_path=DATA_ROOT_PATH,\n", + " dataset_root_path=DATA_ROOT_PATH,\n", " )\n", "\n", " # change head dropout to 0.7 for ett datasets\n", @@ -315,8 +394,8 @@ " print(\"+\" * 60)\n", "\n", " # Plot\n", - " plot_preds(\n", - " finetune_forecast_trainer,\n", + " plot_predictions(\n", + " finetune_forecast_trainer.model,\n", " dset_test,\n", " SUBDIR,\n", " num_plots=10,\n", @@ -349,7 +428,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -544,7 +623,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/tsfm_public/toolkit/data_handling.py b/tsfm_public/toolkit/data_handling.py index 
a991932d..7bd984d2 100644 --- a/tsfm_public/toolkit/data_handling.py +++ b/tsfm_public/toolkit/data_handling.py @@ -9,9 +9,7 @@ import pandas as pd import yaml -from tsfm_public import get_datasets - -from .time_series_preprocessor import TimeSeriesPreprocessor +from .time_series_preprocessor import TimeSeriesPreprocessor, get_datasets def load_dataset(