tfp.sts.forecast only applies to future data. Is it possible to "predict" training data? #1846

Open
HuangKY opened this issue Oct 8, 2024 · 1 comment

@HuangKY

HuangKY commented Oct 8, 2024

I am using the Structural Time Series (STS) module of TensorFlow Probability.
I split my data (1 year) into training data (10 months) and test data (2 months).

After building the model and forecasting, I wondered whether it is possible to "predict" the training data as well, so that I can evaluate the model's performance over the training interval too.
The only forecast-specific argument of "tfp.sts.forecast" is "num_steps_forecast", which counts steps after the end of the observed series and offers no option for this, so I thought I would ask here.
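Once in-sample predictions are available, the train/test comparison described above comes down to computing the same error metric on both intervals. A minimal plain-Python sketch follows; the helper names, the choice of MAE and RMSE, and all the numbers are made up for illustration. Keep in mind that in-sample one-step-ahead predictions condition on the true previous observation, so they will usually look better than a multi-step forecast over the test period.

```python
import math


def mae(actual, predicted):
    """Mean absolute error between two equal-length, non-empty sequences."""
    if len(actual) != len(predicted) or not actual:
        raise ValueError("sequences must be non-empty and of equal length")
    return sum(abs(a - p) for a, p in zip(actual, predicted)) / len(actual)


def rmse(actual, predicted):
    """Root mean squared error between two equal-length, non-empty sequences."""
    if len(actual) != len(predicted) or not actual:
        raise ValueError("sequences must be non-empty and of equal length")
    return math.sqrt(sum((a - p) ** 2 for a, p in zip(actual, predicted)) / len(actual))


# Hypothetical toy values: in-sample (training) predictions vs. test forecasts.
train_actual = [1.0, 2.0, 3.0, 4.0]
train_pred = [1.5, 2.0, 2.5, 4.0]
test_actual = [5.0, 6.0]
test_pred = [4.0, 4.5]

print("train MAE:", mae(train_actual, train_pred))  # 0.25
print("test MAE:", mae(test_actual, test_pred))  # 1.25
print("train RMSE:", rmse(train_actual, train_pred))
print("test RMSE:", rmse(test_actual, test_pred))
```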

@jeffpollock9
Contributor

@HuangKY, if you run forward_filter on the state space model, the result contains the predicted (one-step-ahead) observation means and covariances for every training timestep. Is that what you are after? For example:

import tensorflow as tf
import tensorflow_probability as tfp

print(tf.__version__)
# 2.18.0

print(tfp.__version__)
# 0.25.0

sts = tfp.sts
tfd = tfp.distributions

num_timesteps = 10
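# param_vals follows the order of model.parameters:
# [observation_noise_scale, level scale of the LocalLevel component]
param_vals = [1.0, 2.0]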
observations = tf.ones([num_timesteps, 1])

local_level = sts.LocalLevel()

model = sts.Sum(components=[local_level])

ssm = model.make_state_space_model(num_timesteps=num_timesteps, param_vals=param_vals)

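# observation_means / observation_covs describe the one-step-ahead predictive
# distribution of each observation, given all earlier observations.
results = ssm.forward_filter(observations)

# Columns: observed value, predictive mean, predictive variance.
tf.concat(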
    [
        observations,
        results.observation_means,
        results.observation_covs[..., 0],
    ],
    axis=-1,
)
# <tf.Tensor: shape=(10, 3), dtype=float32, numpy=
# array([[1.        , 0.        , 1.9999999 ],
#        [1.        , 0.5       , 5.5       ],
#        [1.        , 0.9090909 , 5.818182  ],
#        [1.        , 0.984375  , 5.828125  ],
#        [1.        , 0.99731904, 5.8284183 ],
#        [1.        , 0.99954003, 5.8284264 ],
#        [1.        , 0.9999211 , 5.828428  ],
#        [1.        , 0.99998647, 5.828428  ],
#        [1.        , 0.9999977 , 5.828428  ],
#        [1.        , 0.9999996 , 5.828428  ]], dtype=float32)>
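To show where these numbers come from, here is a minimal pure-Python sketch of the Kalman filter recursion for a local level model. It computes the same one-step-ahead predictive means and variances as the table above. It assumes an initial level prior of Normal(0, 1), which is what the first row implies (predictive variance 1 + 1² = 2). The function name and its arguments are hypothetical and not part of TFP. For a fitted model with posterior parameter samples, TFP also provides tfp.sts.one_step_predictive, which returns this kind of in-sample predictive distribution.

```python
def local_level_one_step_predictive(observations, observation_noise_scale, level_scale,
                                    initial_mean=0.0, initial_var=1.0):
    """Hypothetical helper: one-step-ahead predictive means/variances of a local level model.

    Model (assumed): level_t = level_{t-1} + N(0, level_scale**2),
                     y_t     = level_t     + N(0, observation_noise_scale**2),
    with the level at the first step drawn from N(initial_mean, initial_var).
    """
    obs_var = observation_noise_scale ** 2
    level_var = level_scale ** 2
    mean, var = initial_mean, initial_var  # prior over the level at the current step
    pred_means, pred_vars = [], []
    for y in observations:
        # Predictive distribution of y_t given y_0 .. y_{t-1}.
        pred_mean = mean
        pred_var = var + obs_var
        pred_means.append(pred_mean)
        pred_vars.append(pred_var)
        # Kalman update: condition the level on the observed y_t.
        gain = var / pred_var
        mean = mean + gain * (y - pred_mean)
        var = (1.0 - gain) * var
        # Transition: the level follows a Gaussian random walk.
        var = var + level_var
    return pred_means, pred_vars


# Same setup as above: ten observations equal to 1, noise scale 1, level scale 2.
means, variances = local_level_one_step_predictive([1.0] * 10, 1.0, 2.0)
for m, v in zip(means, variances):
    print(f"{m:.6f}  {v:.6f}")
```

The recursion alternates between predicting the next observation and updating the level with the observed value. This is why the first predictive mean is the prior mean (0) and later means move toward the observed value of 1.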
