Skip to content

Commit

Permalink
Fix format issues
Browse files Browse the repository at this point in the history
  • Loading branch information
SamGos93 committed Oct 28, 2024
1 parent c5ff12b commit 0ee9809
Showing 1 changed file with 20 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
"n_test_periods = forecast_horizon\n",
"\n",
"from helper import get_timeseries\n",
"\n",
"X_train, y_train, X_test, y_test = get_timeseries(\n",
" train_len=n_train_periods,\n",
" test_len=n_test_periods,\n",
Expand Down Expand Up @@ -167,6 +168,7 @@
"source": [
"# Plot the example time series\n",
"import matplotlib.pyplot as plt\n",
"\n",
"whole_data = X_train.copy()\n",
"target_label = \"y\"\n",
"whole_data[target_label] = y_train\n",
Expand Down Expand Up @@ -204,8 +206,8 @@
"outputs": [],
"source": [
"# For visualization of the time series\n",
"df_train['data_type'] = 'Training' # Add a column to label training data\n",
"df_test['data_type'] = 'Testing' # Add a column to label testing data\n",
"df_train[\"data_type\"] = \"Training\" # Add a column to label training data\n",
"df_test[\"data_type\"] = \"Testing\" # Add a column to label testing data\n",
"\n",
"# Concatenate the training and testing DataFrames\n",
"df_plot = pd.concat([df_train, df_test])\n",
Expand All @@ -215,13 +217,19 @@
"ax = plt.gca() # Get current axis\n",
"\n",
"# Group by both 'data_type' and 'time_series_id'\n",
"for (data_type, time_series_id), df in df_plot.groupby(['data_type', 'time_series_id']):\n",
" df.plot(x='date', y=TARGET_COLUMN_NAME, label=f\"{data_type} - {time_series_id}\", ax=ax, legend=False)\n",
"for (data_type, time_series_id), df in df_plot.groupby([\"data_type\", \"time_series_id\"]):\n",
" df.plot(\n",
" x=\"date\",\n",
" y=TARGET_COLUMN_NAME,\n",
" label=f\"{data_type} - {time_series_id}\",\n",
" ax=ax,\n",
" legend=False,\n",
" )\n",
"\n",
"# Customize the plot\n",
"plt.xlabel('Date')\n",
"plt.ylabel('Value')\n",
"plt.title('Train and Test Data')\n",
"plt.xlabel(\"Date\")\n",
"plt.ylabel(\"Value\")\n",
"plt.title(\"Train and Test Data\")\n",
"\n",
"# Manually create the legend after plotting\n",
"plt.legend(title=\"Data Type and Time Series ID\")\n",
Expand All @@ -239,6 +247,7 @@
"import mltable\n",
"import os\n",
"\n",
"\n",
"def create_ml_table(data_frame, file_name, output_folder):\n",
" os.makedirs(output_folder, exist_ok=True)\n",
" data_path = os.path.join(output_folder, file_name)\n",
Expand Down Expand Up @@ -270,10 +279,10 @@
"\n",
"my_training_data_input.__dict__\n",
"\n",
"#Test data\n",
"# Test data\n",
"os.makedirs(\"data\", exist_ok=True)\n",
"create_ml_table(\n",
" X_test, #df_test,\n",
" X_test, # df_test,\n",
" \"X_test.parquet\",\n",
" \"./data/testing-mltable-folder\",\n",
")\n",
Expand Down Expand Up @@ -429,6 +438,7 @@
"outputs": [],
"source": [
"import mlflow\n",
"\n",
"MLFLOW_TRACKING_URI = ml_client.workspaces.get(\n",
" name=ml_client.workspace_name\n",
").mlflow_tracking_uri\n",
Expand Down Expand Up @@ -540,6 +550,7 @@
"source": [
"# Create local folder\n",
"import os\n",
"\n",
"local_dir = \"./artifact_downloads\"\n",
"if not os.path.exists(local_dir):\n",
" os.mkdir(local_dir)"
Expand Down

0 comments on commit 0ee9809

Please sign in to comment.