Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Chronos in TSFM Inference Service #203

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
23f5dc0
chronos experiment
gganapavarapu Nov 18, 2024
53455b1
fixed chronos input dim, and example univariate
gganapavarapu Nov 20, 2024
f37279e
chronos service handler
gganapavarapu Nov 21, 2024
7d83780
Merge remote-tracking branch 'origin/service_abstraction' into servic…
gganapavarapu Nov 21, 2024
b9e5a42
enable workflows on two more branchs
ssiegel95 Nov 21, 2024
c433455
Merge branch 'cve' into services_chronos_support
ssiegel95 Nov 21, 2024
66695df
Merge remote-tracking branch 'origin/main' into services_chronos_support
ssiegel95 Nov 21, 2024
d677abd
remove unnecessary comment
gganapavarapu Nov 22, 2024
8129434
enabling chrono small for now.
gganapavarapu Nov 22, 2024
82af8d2
enable chronos tiny only, and placeholder for _calculate_data_point_c…
gganapavarapu Nov 22, 2024
42f33d4
aiohttp cve
ssiegel95 Nov 25, 2024
23c9be8
Merge branch 'main' into services_chronos_support
ssiegel95 Nov 25, 2024
14f40ba
Merge branch 'service_abstraction_chronos' of https://github.com/ggan…
ssiegel95 Nov 25, 2024
612c564
fix style issue
ssiegel95 Nov 25, 2024
a951a3e
Merge pull request #2 from ssiegel95/main
gganapavarapu Nov 25, 2024
7572e88
Merge branch 'service_abstraction_chronos' into services_chronos_support
gganapavarapu Nov 25, 2024
5eb0637
inference service workflow install chronos for tests
gganapavarapu Nov 25, 2024
e9fcf9d
update poetry.lock
ssiegel95 Nov 25, 2024
3c74ec8
Merge pull request #3 from ssiegel95/services_chronos_support
gganapavarapu Nov 25, 2024
b5207ca
Merge branch 'main' into service_abstraction_chronos
gganapavarapu Nov 25, 2024
34e8b70
fix unpack issue in tests
gganapavarapu Nov 25, 2024
46b38c4
Merge remote-tracking branch 'origin/service_abstraction_chronos' int…
gganapavarapu Nov 25, 2024
6b088ad
fix neg scenarios in ttm tests
gganapavarapu Nov 25, 2024
e691605
chronos in required dependencies
gganapavarapu Nov 25, 2024
59337e0
Merge branch 'new_model_integrations' into service_abstraction_chronos
gganapavarapu Nov 25, 2024
4507b12
chronos from test repo, clean up notebook
gganapavarapu Nov 25, 2024
7e2b591
unnecessary chronos option
gganapavarapu Nov 25, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
295 changes: 295 additions & 0 deletions notebooks/tutorial/chronos_with_tsfm.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,295 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "523c80d7-878b-4ad0-9f9e-e7592a6aea33",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"from chronos import ChronosPipeline\n",
"\n",
"\n",
"pipeline = ChronosPipeline.from_pretrained(\n",
" \"amazon/chronos-t5-tiny\",\n",
" # device_map=\"cuda\",\n",
" # torch_dtype=torch.bfloat16,\n",
")\n",
"\n",
"df = pd.read_csv(\n",
" \"https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv\"\n",
")\n",
"\n",
"# context must be either a 1D tensor, a list of 1D tensors,\n",
"# or a left-padded 2D tensor with batch as the first dimension\n",
"context = torch.tensor(df[\"#Passengers\"])\n",
"prediction_length = 12\n",
"forecast = pipeline.predict(context, prediction_length) # shape [num_series, num_samples, prediction_length]\n",
"# visualize the forecast\n",
"forecast_index = range(len(df), len(df) + prediction_length)\n",
"low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)\n",
"\n",
"plt.figure(figsize=(8, 4))\n",
"plt.plot(df[\"#Passengers\"], color=\"royalblue\", label=\"historical data\")\n",
"plt.plot(forecast_index, median, color=\"tomato\", label=\"median forecast\")\n",
"plt.fill_between(forecast_index, low, high, color=\"tomato\", alpha=0.3, label=\"80% prediction interval\")\n",
"plt.legend()\n",
"plt.grid()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "188a262d-6f5d-4b99-9dfe-caf42d36e8de",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import torch\n",
"from chronos import ChronosPipeline"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dd74b7cf-62c5-42cc-8481-b67d3df926fb",
"metadata": {},
"outputs": [],
"source": [
"DATA_FILE_PATH = \"/energy_data/energy_dataset.csv\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cd36be62-6d7a-47e9-b849-fd0542f7a81e",
"metadata": {},
"outputs": [],
"source": [
"timestamp_column = \"time\"\n",
"target_columns = [\"total load actual\"]\n",
"context_length = 512\n",
"prediction_length = 96 # new param\n",
"batch_size = 16 # new param"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b4add6f9-a02d-47fb-97e8-ccb55a652a28",
"metadata": {},
"outputs": [],
"source": [
"# Read in the data from the downloaded file.\n",
"input_df = pd.read_csv(\n",
" DATA_FILE_PATH,\n",
" parse_dates=[timestamp_column], # Parse the timestamp values as dates.\n",
")\n",
"\n",
"# Fill NA/NaN values by propagating the last valid value.\n",
"input_df = input_df.ffill()\n",
"\n",
"# Only use the last `context_length` rows for prediction.\n",
"input_df = input_df.iloc[-context_length:,]\n",
"\n",
"# Show the last few rows of the dataset.\n",
"input_df.tail()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9c03a01e-b6d7-450d-b5fb-58c4dc286cde",
"metadata": {},
"outputs": [],
"source": [
"input_df.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b9a4095-7924-4b3c-9715-4935d2bebabe",
"metadata": {},
"outputs": [],
"source": [
"fig, axs = plt.subplots(len(target_columns), 1, figsize=(10, 2 * len(target_columns)), squeeze=False)\n",
"for ax, target_column in zip(axs, target_columns):\n",
" ax[0].plot(input_df[timestamp_column], input_df[target_column])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "488bc1be-1de7-4527-8ddd-7f6b162f5f6f",
"metadata": {},
"outputs": [],
"source": [
"zeroshot_model = ChronosPipeline.from_pretrained(\n",
" \"amazon/chronos-t5-tiny\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fc7be3e3-e030-405c-a997-b4c098fe27bf",
"metadata": {},
"outputs": [],
"source": [
"context_cols = input_df.columns.tolist()\n",
"context_cols.remove(timestamp_column)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dcea75e4-b4ac-41d9-806c-7659da723a47",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"context = torch.tensor(input_df[context_cols].values).transpose(1, 0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "530c69d2-c811-4352-a1eb-9e836229ea98",
"metadata": {},
"outputs": [],
"source": [
"context.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f10bd45-1c8d-4ec7-a510-8bd03d6f1e39",
"metadata": {},
"outputs": [],
"source": [
"prediction_length = 12\n",
"forecast = zeroshot_model.predict(\n",
" context, prediction_length, num_samples=20, temperature=None, top_k=None, top_p=None, limit_prediction_length=False\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "53ae5816-454a-42b3-a122-25e87d80d8eb",
"metadata": {},
"outputs": [],
"source": [
"context.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a190ef4d-61fc-4282-b983-8a66c1f19c27",
"metadata": {},
"outputs": [],
"source": [
"forecast.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aa8f0123-742b-4afe-adb7-72ba247cd7f6",
"metadata": {},
"outputs": [],
"source": [
"median_arr = []\n",
"for i in range(len(context_cols)):\n",
" median_arr.append(forecast[i].median())\n",
"median_arr"
]
},
{
"cell_type": "markdown",
"id": "0eac725f-0e1d-46e5-8336-5c757ff65458",
"metadata": {},
"source": [
"#### Univariate"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e99e652d-9fb1-473b-a74e-f08e06337951",
"metadata": {},
"outputs": [],
"source": [
"context = torch.tensor(input_df[target_columns].values).transpose(1, 0)\n",
"prediction_length = 24\n",
"zeroshot_model = ChronosPipeline.from_pretrained(\n",
" \"amazon/chronos-t5-tiny\",\n",
")\n",
"forecast = zeroshot_model.predict(context, prediction_length)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ff030b8a-d84e-44ae-b3bf-8daf8699c640",
"metadata": {},
"outputs": [],
"source": [
"context.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0cef489f-ea2d-4168-a7f1-2be10e4a9738",
"metadata": {},
"outputs": [],
"source": [
"forecast.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e830effa-ee4f-4a0b-8a7c-9a71b70af80b",
"metadata": {},
"outputs": [],
"source": [
"low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)\n",
"median"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading