Skip to content

Commit

Permalink
docstring updates
Browse files Browse the repository at this point in the history
  • Loading branch information
wgifford committed Sep 9, 2024
1 parent c0716a2 commit 74fdcec
Showing 1 changed file with 23 additions and 10 deletions.
33 changes: 23 additions & 10 deletions tsfm_public/toolkit/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,26 +224,38 @@ def plot_predictions(
User should pass either:
- input_df and predictions_df: context will be extracted from input_df, and predictions will be extracted from predictions_df. Predictions_df is expected to have rows containing lists of predictions.
- input_df and exploded_predictions_df: context will be extracted from input_df, and predictions from exploded_predictions_df will be plotted
- input_df and predictions_df: context will be extracted from input_df, and predictions will be extracted from
predictions_df. Predictions_df is expected to have rows containing lists of predictions.
- input_df and exploded_predictions_df: context will be extracted from input_df, and predictions from
exploded_predictions_df will be plotted
- dset and model: model will be used to produce predictions from records selected from dset
If exploded_predictions_df is passed, indices and num_plots are ignored, the assumption is that there are only one
set of predictions passed for plotting.
Args:
input_df (Optional[pd.DataFrame], optional): The input dataframe from which the predictions are generated, containing timestamp and target columns. Defaults to None.
predictions_df (Optional[pd.DataFrame], optional): The predictions dataframe, where each row contains starting timestamp and a list of predictions for each target column. Defaults to None.
exploded_predictions_df (Optional[pd.DataFrame], optional): The predictions dataframe, containing timestamp and predicted target columns. Defaults to None.
dset (Optional[Dataset], optional): Torch dataset containing the context data to use as input for the model. Defaults to None.
input_df (Optional[pd.DataFrame], optional): The input dataframe from which the predictions are generated,
containing timestamp and target columns. Defaults to None.
predictions_df (Optional[pd.DataFrame], optional): The predictions dataframe, where each row contains starting
timestamp and a list of predictions for each target column. Defaults to None.
exploded_predictions_df (Optional[pd.DataFrame], optional): The predictions dataframe, containing timestamp
and predicted target columns. Defaults to None.
dset (Optional[Dataset], optional): Torch dataset containing the context data to use as input for the model.
Defaults to None.
model (Optional[PreTrainedModel], optional): The pre-trained time series model. Defaults to None.
freq (Optional[str], optional): Frequency of the time series data, using Pandas string abbreviations (https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases). Defaults to None.
freq (Optional[str], optional): Frequency of the time series data, using Pandas string abbreviations
(https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases). Defaults to None.
timestamp_column (Optional[str], optional): Name of timestamp column in the dataframe. Defaults to None.
id_columns (Optional[List[str]], optional): (For future use) List of id columns in the dataframe. Defaults to None.
plot_context (Optional[int], optional): Integer representing the number of time points of historical data to plot. Defaults to None.
id_columns (Optional[List[str]], optional): (For future use) List of id columns in the dataframe. Defaults to
None.
plot_context (Optional[int], optional): Integer representing the number of time points of historical data to
plot. Defaults to None.
plot_dir (str, optional): Directory where plots are saved. Defaults to None.
num_plots (int, optional): Number of subplots to plot in the figure. Defaults to 10.
plot_prefix (str, optional): Prefix to put on the plot file names. Defaults to "valid".
channel (Union[int, str], optional): Channel, i.e., target column or its index, to plot. Defaults to None.
indices (List[int], optional): List of indices to plot. If None, random examples will be chosen. Defaults to None.
indices (List[int], optional): List of indices to plot. If None, random examples will be chosen. Defaults to
None.
"""
if indices is not None:
num_plots = len(indices)
Expand All @@ -263,6 +275,7 @@ def plot_predictions(
plot_context = len(input_df)
using_pipeline = True
plot_test_data = False
indices = [-1] # indices not used in exploded case
elif input_df is not None and predictions_df is not None:
# 2) input_df and predictions plus column information is provided

Expand Down

0 comments on commit 74fdcec

Please sign in to comment.