From 2ac5b9325ed3b54950c6c61fd5838ac6e55a9fe1 Mon Sep 17 00:00:00 2001 From: Gift Sinthong Date: Mon, 13 Nov 2023 10:06:32 -0800 Subject: [PATCH] [time series] Add PatchTST (#25927) * Initial commit of PatchTST model classes Co-authored-by: Phanwadee Sinthong Co-authored-by: Nam Nguyen Co-authored-by: Vijay Ekambaram Co-authored-by: Ngoc Diep Do <55230119+diepi@users.noreply.github.com> Co-authored-by: Wesley Gifford <79663411+wgifford@users.noreply.github.com> * Add PatchTSTForPretraining * update to include classification Co-authored-by: Phanwadee Sinthong Co-authored-by: Nam Nguyen Co-authored-by: Vijay Ekambaram Co-authored-by: Ngoc Diep Do <55230119+diepi@users.noreply.github.com> Co-authored-by: Wesley Gifford <79663411+wgifford@users.noreply.github.com> * clean up auto files * Add PatchTSTForPrediction * Fix relative import * Replace original PatchTSTEncoder with ChannelAttentionPatchTSTEncoder * temporary adding absolute path + add PatchTSTForForecasting class * Update base PatchTSTModel + Unittest * Update ForecastHead to use the config class * edit cv_random_masking, add mask to model output * Update configuration_patchtst.py * add masked_loss to the pretraining * add PatchEmbeddings * Update configuration_patchtst.py * edit loss which considers mask in the pretraining * remove patch_last option * Add commits from internal repo * Update ForecastHead * Add model weight initilization + unittest * Update PatchTST unittest to use local import * PatchTST integration tests for pretraining and prediction * Added PatchTSTForRegression + update unittest to include label generation * Revert unrelated model test file * Combine similar output classes * update PredictionHead * Update configuration_patchtst.py * Add Revin * small edit to PatchTSTModelOutputWithNoAttention * Update modeling_patchtst.py * Updating integration test for forecasting * Fix unittest after class structure changed * docstring updates * change input_size to num_input_channels * more formatting * Remove some unused params * Add a comment for pretrained models * add channel_attention option add channel_attention option and remove unused positional encoders. * Update PatchTST models to use HF's MultiHeadAttention module * Update paper + github urls * Fix hidden_state return value * Update integration test to use PatchTSTForForecasting * Adding dataclass decorator for model output classes * Run fixup script * Rename model repos for integration test * edit argument explanation * change individual option to shared_projection * style * Rename integration test + import cleanup * Fix outpu_hidden_states return value * removed unused mode * added std, mean and nops scaler * add initial distributional loss for predition * fix typo in docs * add generate function * formatting * add num_parallel_samples * Fix a typo * copy weighted_average function, edit PredictionHead * edit PredictionHead * add distribution head to forecasting * formatting * Add generate function for forecasting * Add generate function to prediction task * formatting * use argsort * add past_observed_mask ordering * fix arguments * docs * add back test_model_outputs_equivalence test * formatting * cleanup * formatting * use ACT2CLS * formatting * fix add_start_docstrings decorator * add distribution head and generate function to regression task add distribution head and generate function to regression task. Also made add PatchTSTForForecastingOutput, PatchTSTForRegressionOutput. 
* add distribution head and generate function to regression task add distribution head and generate function to regression task. Also made add PatchTSTForForecastingOutput, PatchTSTForRegressionOutput. * fix typos * add forecast_masking * fixed tests * use set_seed * fix doc test * formatting * Update docs/source/en/model_doc/patchtst.md Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * better var names * rename PatchTSTTranspose * fix argument names and docs string * remove compute_num_patches and unused class * remove assert * renamed to PatchTSTMasking * use num_labels for classification * use num_labels * use default num_labels from super class * move model_type after docstring * renamed PatchTSTForMaskPretraining * bs -> batch_size * more review fixes * use hidden_state * rename encoder layer and block class * remove commented seed_number * edit docstring * Add docstring * formatting * use past_observed_mask * doc suggestion * make fix-copies * use Args: * add docstring * add docstring * change some variable names and add PatchTST before some class names * formatting * fix argument types * fix tests * change x variable to patch_input * format * formatting * fix-copies * Update tests/models/patchtst/test_modeling_patchtst.py Co-authored-by: Patrick von Platen * move loss to forward * Update src/transformers/models/patchtst/modeling_patchtst.py Co-authored-by: Patrick von Platen * Update src/transformers/models/patchtst/modeling_patchtst.py Co-authored-by: Patrick von Platen * Update src/transformers/models/patchtst/modeling_patchtst.py Co-authored-by: Patrick von Platen * Update src/transformers/models/patchtst/modeling_patchtst.py Co-authored-by: Patrick von Platen * Update src/transformers/models/patchtst/modeling_patchtst.py Co-authored-by: Patrick von Platen * formatting * fix a bug when pre_norm is set to True * output_hidden_states is set to False as default * set pre_norm=True as default * format docstring * format * output_hidden_states is None by default * add missing docs * better var names * docstring: remove default to False in output_hidden_states * change labels name to target_values in regression task * format * fix tests * change to forecast_mask_ratios and random_mask_ratio * change mask names * change future_values to target_values param in the prediction class * remove nn.Sequential and make PatchTSTBatchNorm class * black * fix argument name for prediction * add output_attentions option * add output_attentions to PatchTSTEncoder * formatting * Add attention output option to all classes * Remove PatchTSTEncoderBlock * create PatchTSTEmbedding class * use config in PatchTSTPatchify * Use config in PatchTSTMasking class * add channel_attn_weights * Add PatchTSTScaler class * add output_attentions arg to test function * format * Update doc with image patchtst.md * fix-copies * rename Forecast <-> Prediction * change name of a few parameters to match with PatchTSMixer. * Remove *ForForecasting class to match with other time series models. * make style * Remove PatchTSTForForecasting in the test * remove PatchTSTForForecastingOutput class * change test_forecast_head to test_prediction_head * style * fix docs * fix tests * change num_labels to num_targets * Remove PatchTSTTranspose * remove arguments in PatchTSTMeanScaler * remove arguments in PatchTSTStdScaler * add config as an argument to all the scaler classes * reformat * Add norm_eps for batchnorm and layernorm * reformat. 
* reformat * edit docstring * update docstring * change variable name pooling to pooling_type * fix output_hidden_states as tuple * fix bug when calling PatchTSTBatchNorm * change stride to patch_stride * create PatchTSTPositionalEncoding class and restructure the PatchTSTEncoder * formatting * initialize scalers with configs * edit output_hidden_states * style * fix forecast_mask_patches doc string --------- Co-authored-by: Gift Sinthong Co-authored-by: Nam Nguyen Co-authored-by: Vijay Ekambaram Co-authored-by: Ngoc Diep Do <55230119+diepi@users.noreply.github.com> Co-authored-by: Wesley Gifford <79663411+wgifford@users.noreply.github.com> Co-authored-by: Wesley M. Gifford Co-authored-by: nnguyen Co-authored-by: Ngoc Diep Do Co-authored-by: Kashif Rasul Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Co-authored-by: Patrick von Platen --- README.md | 1 + README_es.md | 1 + README_hd.md | 1 + README_ja.md | 1 + README_ko.md | 1 + README_zh-hans.md | 1 + README_zh-hant.md | 1 + docs/source/en/_toctree.yml | 2 + docs/source/en/index.md | 1 + docs/source/en/model_doc/patchtst.md | 73 + src/transformers/__init__.py | 26 + src/transformers/models/__init__.py | 1 + src/transformers/models/auto/__init__.py | 4 + .../models/auto/configuration_auto.py | 3 + src/transformers/models/auto/modeling_auto.py | 21 + .../models/autoformer/modeling_autoformer.py | 118 +- .../models/informer/modeling_informer.py | 118 +- src/transformers/models/patchtst/__init__.py | 66 + .../models/patchtst/configuration_patchtst.py | 274 +++ .../models/patchtst/modeling_patchtst.py | 1913 +++++++++++++++++ .../modeling_time_series_transformer.py | 112 +- src/transformers/utils/dummy_pt_objects.py | 51 + tests/models/patchtst/__init__.py | 0 .../models/patchtst/test_modeling_patchtst.py | 353 +++ utils/check_repo.py | 2 + 25 files changed, 2974 insertions(+), 171 deletions(-) create mode 100644 docs/source/en/model_doc/patchtst.md create mode 100644 src/transformers/models/patchtst/__init__.py create mode 100644 src/transformers/models/patchtst/configuration_patchtst.py create mode 100755 src/transformers/models/patchtst/modeling_patchtst.py create mode 100644 tests/models/patchtst/__init__.py create mode 100644 tests/models/patchtst/test_modeling_patchtst.py diff --git a/README.md b/README.md index 12724e60a1881d..5096adcaef144e 100644 --- a/README.md +++ b/README.md @@ -439,6 +439,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al. 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby. 1. **[OWLv2](https://huggingface.co/docs/transformers/model_doc/owlv2)** (from Google AI) released with the paper [Scaling Open-Vocabulary Object Detection](https://arxiv.org/abs/2306.09683) by Matthias Minderer, Alexey Gritsenko, Neil Houlsby. +1. 
**[PatchTST](https://huggingface.co/docs/transformers/main/model_doc/patchtst)** (from IBM) released with the paper [A Time Series is Worth 64 Words: Long-term Forecasting with Transformers](https://arxiv.org/abs/2211.14730) by Yuqi Nie, Nam H. Nguyen, Phanwadee Sinthong, Jayant Kalagnanam. 1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. 1. **[PEGASUS-X](https://huggingface.co/docs/transformers/model_doc/pegasus_x)** (from Google) released with the paper [Investigating Efficiently Extending Transformers for Long Input Summarization](https://arxiv.org/abs/2208.04347) by Jason Phang, Yao Zhao, and Peter J. Liu. 1. **[Perceiver IO](https://huggingface.co/docs/transformers/model_doc/perceiver)** (from Deepmind) released with the paper [Perceiver IO: A General Architecture for Structured Inputs & Outputs](https://arxiv.org/abs/2107.14795) by Andrew Jaegle, Sebastian Borgeaud, Jean-Baptiste Alayrac, Carl Doersch, Catalin Ionescu, David Ding, Skanda Koppula, Daniel Zoran, Andrew Brock, Evan Shelhamer, Olivier Hénaff, Matthew M. Botvinick, Andrew Zisserman, Oriol Vinyals, João Carreira. diff --git a/README_es.md b/README_es.md index 5cdbc27ec7918d..0a3db02aedd74a 100644 --- a/README_es.md +++ b/README_es.md @@ -414,6 +414,7 @@ Número actual de puntos de control: ![](https://img.shields.io/endpoint?url=htt 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al. 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby. 1. **[OWLv2](https://huggingface.co/docs/transformers/model_doc/owlv2)** (from Google AI) released with the paper [Scaling Open-Vocabulary Object Detection](https://arxiv.org/abs/2306.09683) by Matthias Minderer, Alexey Gritsenko, Neil Houlsby. +1. **[PatchTST](https://huggingface.co/docs/transformers/main/model_doc/patchtst)** (from IBM) released with the paper [A Time Series is Worth 64 Words: Long-term Forecasting with Transformers](https://arxiv.org/pdf/2211.14730.pdf) by Yuqi Nie, Nam H. Nguyen, Phanwadee Sinthong, Jayant Kalagnanam. 1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. 1. **[PEGASUS-X](https://huggingface.co/docs/transformers/model_doc/pegasus_x)** (from Google) released with the paper [Investigating Efficiently Extending Transformers for Long Input Summarization](https://arxiv.org/abs/2208.04347) by Jason Phang, Yao Zhao, and Peter J. Liu. 1. 
**[Perceiver IO](https://huggingface.co/docs/transformers/model_doc/perceiver)** (from Deepmind) released with the paper [Perceiver IO: A General Architecture for Structured Inputs & Outputs](https://arxiv.org/abs/2107.14795) by Andrew Jaegle, Sebastian Borgeaud, Jean-Baptiste Alayrac, Carl Doersch, Catalin Ionescu, David Ding, Skanda Koppula, Daniel Zoran, Andrew Brock, Evan Shelhamer, Olivier Hénaff, Matthew M. Botvinick, Andrew Zisserman, Oriol Vinyals, João Carreira. diff --git a/README_hd.md b/README_hd.md index 01937532f967c4..36bd25cda1922e 100644 --- a/README_hd.md +++ b/README_hd.md @@ -388,6 +388,7 @@ conda install -c huggingface transformers 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al. 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (Google AI से) साथ में कागज [विज़न ट्रांसफॉर्मर्स के साथ सिंपल ओपन-वोकैबुलरी ऑब्जेक्ट डिटेक्शन](https:/ /arxiv.org/abs/2205.06230) मैथियास मिंडरर, एलेक्सी ग्रिट्सेंको, ऑस्टिन स्टोन, मैक्सिम न्यूमैन, डिर्क वीसेनबोर्न, एलेक्सी डोसोवित्स्की, अरविंद महेंद्रन, अनुराग अर्नब, मुस्तफा देहघानी, ज़ुओरन शेन, जिओ वांग, ज़ियाओहुआ झाई, थॉमस किफ़, और नील हॉल्सबी द्वारा पोस्ट किया गया। 1. **[OWLv2](https://huggingface.co/docs/transformers/model_doc/owlv2)** (Google AI से) Matthias Minderer, Alexey Gritsenko, Neil Houlsby. द्वाराअनुसंधान पत्र [Scaling Open-Vocabulary Object Detection](https://arxiv.org/abs/2306.09683) के साथ जारी किया गया +1. **[PatchTST](https://huggingface.co/docs/transformers/main/model_doc/patchtst)** (IBM से) Yuqi Nie, Nam H. Nguyen, Phanwadee Sinthong, Jayant Kalagnanam. द्वाराअनुसंधान पत्र [A Time Series is Worth 64 Words: Long-term Forecasting with Transformers](https://arxiv.org/pdf/2211.14730.pdf) के साथ जारी किया गया 1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. 1. **[PEGASUS-X](https://huggingface.co/docs/transformers/model_doc/pegasus_x)** (Google की ओर से) साथ में दिया गया पेपर [लंबे इनपुट सारांश के लिए ट्रांसफ़ॉर्मरों को बेहतर तरीके से एक्सटेंड करना](https://arxiv .org/abs/2208.04347) जेसन फांग, याओ झाओ, पीटर जे लियू द्वारा। 1. **[Perceiver IO](https://huggingface.co/docs/transformers/model_doc/perceiver)** (दीपमाइंड से) साथ में पेपर [पर्सीवर आईओ: संरचित इनपुट और आउटपुट के लिए एक सामान्य वास्तुकला] (https://arxiv.org/abs/2107.14795) एंड्रयू जेगल, सेबेस्टियन बोरग्यूड, जीन-बैप्टिस्ट अलायराक, कार्ल डोर्श, कैटलिन इओनेस्कु, डेविड द्वारा डिंग, स्कंद कोप्पुला, डैनियल ज़ोरान, एंड्रयू ब्रॉक, इवान शेलहैमर, ओलिवियर हेनाफ, मैथ्यू एम। बोट्विनिक, एंड्रयू ज़िसरमैन, ओरिओल विनियल्स, जोआओ कैरेरा द्वारा पोस्ट किया गया। diff --git a/README_ja.md b/README_ja.md index 5935da396bf165..06a6fdd5e7dc4c 100644 --- a/README_ja.md +++ b/README_ja.md @@ -448,6 +448,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (Meta AI から) Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al から公開された研究論文: [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) 1. 
**[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (Google AI から) Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby から公開された研究論文: [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) 1. **[OWLv2](https://huggingface.co/docs/transformers/model_doc/owlv2)** (Google AI から) Matthias Minderer, Alexey Gritsenko, Neil Houlsby. から公開された研究論文 [Scaling Open-Vocabulary Object Detection](https://arxiv.org/abs/2306.09683) +1. **[PatchTST](https://huggingface.co/docs/transformers/main/model_doc/patchtst)** (IBM から) Yuqi Nie, Nam H. Nguyen, Phanwadee Sinthong, Jayant Kalagnanam. から公開された研究論文 [A Time Series is Worth 64 Words: Long-term Forecasting with Transformers](https://arxiv.org/pdf/2211.14730.pdf) 1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (Google から) Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu から公開された研究論文: [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) 1. **[PEGASUS-X](https://huggingface.co/docs/transformers/model_doc/pegasus_x)** (Google から) Jason Phang, Yao Zhao, and Peter J. Liu から公開された研究論文: [Investigating Efficiently Extending Transformers for Long Input Summarization](https://arxiv.org/abs/2208.04347) 1. **[Perceiver IO](https://huggingface.co/docs/transformers/model_doc/perceiver)** (Deepmind から) Andrew Jaegle, Sebastian Borgeaud, Jean-Baptiste Alayrac, Carl Doersch, Catalin Ionescu, David Ding, Skanda Koppula, Daniel Zoran, Andrew Brock, Evan Shelhamer, Olivier Hénaff, Matthew M. Botvinick, Andrew Zisserman, Oriol Vinyals, João Carreira から公開された研究論文: [Perceiver IO: A General Architecture for Structured Inputs & Outputs](https://arxiv.org/abs/2107.14795) diff --git a/README_ko.md b/README_ko.md index e0c38472cc4606..db06296a72962d 100644 --- a/README_ko.md +++ b/README_ko.md @@ -363,6 +363,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (Meta AI 에서) Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al 의 [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) 논문과 함께 발표했습니다. 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (Google AI 에서) Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby 의 [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) 논문과 함께 발표했습니다. 1. **[OWLv2](https://huggingface.co/docs/transformers/model_doc/owlv2)** (Google AI 에서 제공)은 Matthias Minderer, Alexey Gritsenko, Neil Houlsby.의 [Scaling Open-Vocabulary Object Detection](https://arxiv.org/abs/2306.09683)논문과 함께 발표했습니다. +1. **[PatchTST](https://huggingface.co/docs/transformers/main/model_doc/patchtst)** (IBM 에서 제공)은 Yuqi Nie, Nam H. Nguyen, Phanwadee Sinthong, Jayant Kalagnanam.의 [A Time Series is Worth 64 Words: Long-term Forecasting with Transformers](https://arxiv.org/pdf/2211.14730.pdf)논문과 함께 발표했습니다. 1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (Google 에서) Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. 
Liu 의 [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) 논문과 함께 발표했습니다. 1. **[PEGASUS-X](https://huggingface.co/docs/transformers/model_doc/pegasus_x)** (Google 에서) Jason Phang, Yao Zhao, Peter J. Liu 의 [Investigating Efficiently Extending Transformers for Long Input Summarization](https://arxiv.org/abs/2208.04347) 논문과 함께 발표했습니다. 1. **[Perceiver IO](https://huggingface.co/docs/transformers/model_doc/perceiver)** (Deepmind 에서) Andrew Jaegle, Sebastian Borgeaud, Jean-Baptiste Alayrac, Carl Doersch, Catalin Ionescu, David Ding, Skanda Koppula, Daniel Zoran, Andrew Brock, Evan Shelhamer, Olivier Hénaff, Matthew M. Botvinick, Andrew Zisserman, Oriol Vinyals, João Carreira 의 [Perceiver IO: A General Architecture for Structured Inputs & Outputs](https://arxiv.org/abs/2107.14795) 논문과 함께 발표했습니다. diff --git a/README_zh-hans.md b/README_zh-hans.md index 3d84374d5561d5..5dd9f9b35a14de 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -387,6 +387,7 @@ conda install -c huggingface transformers 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (来自 Meta AI) 伴随论文 [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) 由 Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al 发布。 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (来自 Google AI) 伴随论文 [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) 由 Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby 发布。 1. **[OWLv2](https://huggingface.co/docs/transformers/model_doc/owlv2)** (来自 Google AI) 伴随论文 [Scaling Open-Vocabulary Object Detection](https://arxiv.org/abs/2306.09683) 由 Matthias Minderer, Alexey Gritsenko, Neil Houlsby 发布。 +1. **[PatchTST](https://huggingface.co/docs/transformers/main/model_doc/patchtst)** (来自 IBM) 伴随论文 [A Time Series is Worth 64 Words: Long-term Forecasting with Transformers](https://arxiv.org/pdf/2211.14730.pdf) 由 Yuqi Nie, Nam H. Nguyen, Phanwadee Sinthong, Jayant Kalagnanam 发布。 1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (来自 Google) 伴随论文 [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) 由 Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu 发布。 1. **[PEGASUS-X](https://huggingface.co/docs/transformers/model_doc/pegasus_x)** (来自 Google) 伴随论文 [Investigating Efficiently Extending Transformers for Long Input Summarization](https://arxiv.org/abs/2208.04347) 由 Jason Phang, Yao Zhao, Peter J. Liu 发布。 1. **[Perceiver IO](https://huggingface.co/docs/transformers/model_doc/perceiver)** (来自 Deepmind) 伴随论文 [Perceiver IO: A General Architecture for Structured Inputs & Outputs](https://arxiv.org/abs/2107.14795) 由 Andrew Jaegle, Sebastian Borgeaud, Jean-Baptiste Alayrac, Carl Doersch, Catalin Ionescu, David Ding, Skanda Koppula, Daniel Zoran, Andrew Brock, Evan Shelhamer, Olivier Hénaff, Matthew M. Botvinick, Andrew Zisserman, Oriol Vinyals, João Carreira 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index c095423cce15dd..f155fafe91f1d0 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -399,6 +399,7 @@ conda install -c huggingface transformers 1. 
**[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al. 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby. 1. **[OWLv2](https://huggingface.co/docs/transformers/model_doc/owlv2)** (from Google AI) released with the paper [Scaling Open-Vocabulary Object Detection](https://arxiv.org/abs/2306.09683) by Matthias Minderer, Alexey Gritsenko, Neil Houlsby. +1. **[PatchTST](https://huggingface.co/docs/transformers/main/model_doc/patchtst)** (from IBM) released with the paper [A Time Series is Worth 64 Words: Long-term Forecasting with Transformers](https://arxiv.org/pdf/2211.14730.pdf) by Yuqi Nie, Nam H. Nguyen, Phanwadee Sinthong, Jayant Kalagnanam. 1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. 1. **[PEGASUS-X](https://huggingface.co/docs/transformers/model_doc/pegasus_x)** (from Google) released with the paper [Investigating Efficiently Extending Transformers for Long Input Summarization](https://arxiv.org/abs/2208.04347) by Jason Phang, Yao Zhao, Peter J. Liu. 1. **[Perceiver IO](https://huggingface.co/docs/transformers/model_doc/perceiver)** (from Deepmind) released with the paper [Perceiver IO: A General Architecture for Structured Inputs & Outputs](https://arxiv.org/abs/2107.14795) by Andrew Jaegle, Sebastian Borgeaud, Jean-Baptiste Alayrac, Carl Doersch, Catalin Ionescu, David Ding, Skanda Koppula, Daniel Zoran, Andrew Brock, Evan Shelhamer, Olivier Hénaff, Matthew M. Botvinick, Andrew Zisserman, Oriol Vinyals, João Carreira. diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 4e0ce88c10af31..612f21ab38d341 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -747,6 +747,8 @@ title: Autoformer - local: model_doc/informer title: Informer + - local: model_doc/patchtst + title: PatchTST - local: model_doc/time_series_transformer title: Time Series Transformer title: Time series models diff --git a/docs/source/en/index.md b/docs/source/en/index.md index ae01569e970ce3..d962338becf884 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -213,6 +213,7 @@ Flax), PyTorch, and/or TensorFlow. 
| [OPT](model_doc/opt) | ✅ | ✅ | ✅ | | [OWL-ViT](model_doc/owlvit) | ✅ | ❌ | ❌ | | [OWLv2](model_doc/owlv2) | ✅ | ❌ | ❌ | +| [PatchTST](model_doc/patchtst) | ✅ | ❌ | ❌ | | [Pegasus](model_doc/pegasus) | ✅ | ✅ | ✅ | | [PEGASUS-X](model_doc/pegasus_x) | ✅ | ❌ | ❌ | | [Perceiver](model_doc/perceiver) | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/patchtst.md b/docs/source/en/model_doc/patchtst.md new file mode 100644 index 00000000000000..c18abeb20e64ef --- /dev/null +++ b/docs/source/en/model_doc/patchtst.md @@ -0,0 +1,73 @@ + + +# PatchTST + +## Overview + +The PatchTST model was proposed in [A Time Series is Worth 64 Words: Long-term Forecasting with Transformers](https://arxiv.org/abs/2211.14730) by Yuqi Nie, Nam H. Nguyen, Phanwadee Sinthong, Jayant Kalagnanam. + +The abstract from the paper is the following: + +*We propose an efficient design of Transformer-based models for multivariate time series forecasting and self-supervised representation learning. It is based on two key components: (i) segmentation of time series into subseries-level patches which are served as input tokens to Transformer; (ii) channel-independence where each channel contains a single univariate time series that shares the same embedding and Transformer weights across all the series. Patching design naturally has three-fold benefit: local semantic information is retained in the embedding; computation and memory usage of the attention maps are quadratically reduced given the same look-back window; and the model can attend longer history. Our channel-independent patch time series Transformer (PatchTST) can improve the long-term forecasting accuracy significantly when compared with that of SOTA Transformer-based models. We also apply our model to self-supervised pre-training tasks and attain excellent fine-tuning performance, which outperforms supervised training on large datasets. Transferring of masked pre-trained representation on one dataset to others also produces SOTA forecasting accuracy.* + +Tips: + +The model can also be used for time series classification and time series regression. See the respective [`PatchTSTForClassification`] and [`PatchTSTForRegression`] classes. + +At a high level the model vectorizes time series into patches of a given size and encodes them via a Transformer which then outputs the prediction length forecasts: + +![model](https://github.com/namctin/transformers/assets/8100/150af169-29de-419a-8d98-eb78251c21fa) + + +This model was contributed by [namctin](https://huggingface.co/namctin), [gsinthong](https://huggingface.co/gsinthong), [diepi](https://huggingface.co/diepi), [vijaye12](https://huggingface.co/vijaye12), [wmgifford](https://huggingface.co/wmgifford), and [kashif](https://huggingface.co/kashif). + +The original code can be found [here](https://github.com/yuqinie98/PatchTST). 
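As a quick illustration of the API this patch introduces, the sketch below builds a small `PatchTSTConfig` and runs a forward pass through `PatchTSTForPrediction`. It is a minimal sketch rather than code from the patch: the toy sizes, the `past_values` keyword, and the use of random tensors are assumptions based on the docstrings and commit messages above.

```python
import torch

from transformers import PatchTSTConfig, PatchTSTForPrediction

# Hypothetical toy setup: 7 input channels, a 32-step context window and a
# 24-step prediction horizon (sizes follow the configuration defaults documented below).
config = PatchTSTConfig(
    num_input_channels=7,
    context_length=32,
    patch_length=8,
    patch_stride=8,
    prediction_length=24,
)
model = PatchTSTForPrediction(config)

# past_values is expected to have shape (batch_size, context_length, num_input_channels).
past_values = torch.randn(4, config.context_length, config.num_input_channels)
outputs = model(past_values=past_values)
```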
+ + +## PatchTSTConfig + +[[autodoc]] PatchTSTConfig + + +## PatchTSTModel + +[[autodoc]] PatchTSTModel + - forward + + +## PatchTSTForPrediction + +[[autodoc]] PatchTSTForPrediction + - forward + + +## PatchTSTForClassification + +[[autodoc]] PatchTSTForClassification + - forward + + +## PatchTSTForPretraining + +[[autodoc]] PatchTSTForPretraining + - forward + + +## PatchTSTForRegression + +[[autodoc]] PatchTSTForRegression + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index cf89602b6597a3..9cbb988c53475c 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -493,6 +493,7 @@ "OwlViTTextConfig", "OwlViTVisionConfig", ], + "models.patchtst": ["PATCHTST_PRETRAINED_CONFIG_ARCHIVE_MAP", "PatchTSTConfig"], "models.pegasus": ["PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusConfig", "PegasusTokenizer"], "models.pegasus_x": ["PEGASUS_X_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusXConfig"], "models.perceiver": ["PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PerceiverConfig", "PerceiverTokenizer"], @@ -1167,6 +1168,8 @@ "MODEL_FOR_TEXT_ENCODING_MAPPING", "MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING", "MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING", + "MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING", + "MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING", "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", "MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING", @@ -2485,6 +2488,17 @@ "OwlViTVisionModel", ] ) + _import_structure["models.patchtst"].extend( + [ + "PATCHTST_PRETRAINED_MODEL_ARCHIVE_LIST", + "PatchTSTForClassification", + "PatchTSTForPrediction", + "PatchTSTForPretraining", + "PatchTSTForRegression", + "PatchTSTModel", + "PatchTSTPreTrainedModel", + ] + ) _import_structure["models.pegasus"].extend( ["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel", "PegasusPreTrainedModel"] ) @@ -4697,6 +4711,7 @@ OwlViTTextConfig, OwlViTVisionConfig, ) + from .models.patchtst import PATCHTST_PRETRAINED_CONFIG_ARCHIVE_MAP, PatchTSTConfig from .models.pegasus import PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, PegasusConfig, PegasusTokenizer from .models.pegasus_x import PEGASUS_X_PRETRAINED_CONFIG_ARCHIVE_MAP, PegasusXConfig from .models.perceiver import PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP, PerceiverConfig, PerceiverTokenizer @@ -5303,6 +5318,8 @@ MODEL_FOR_TEXT_ENCODING_MAPPING, MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING, MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING, + MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING, + MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING, MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING, @@ -6387,6 +6404,15 @@ OwlViTTextModel, OwlViTVisionModel, ) + from .models.patchtst import ( + PATCHTST_PRETRAINED_MODEL_ARCHIVE_LIST, + PatchTSTForClassification, + PatchTSTForPrediction, + PatchTSTForPretraining, + PatchTSTForRegression, + PatchTSTModel, + PatchTSTPreTrainedModel, + ) from .models.pegasus import ( PegasusForCausalLM, PegasusForConditionalGeneration, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 6132512688e6b1..968704c0bf8640 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -158,6 +158,7 @@ opt, owlv2, owlvit, + patchtst, pegasus, pegasus_x, perceiver, diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index dc01c93406b791..153f7f10def694 100644 --- a/src/transformers/models/auto/__init__.py +++ 
b/src/transformers/models/auto/__init__.py @@ -77,6 +77,8 @@ "MODEL_WITH_LM_HEAD_MAPPING", "MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING", "MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING", + "MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING", + "MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING", "AutoModel", "AutoBackbone", "AutoModelForAudioClassification", @@ -250,6 +252,8 @@ MODEL_FOR_TEXT_ENCODING_MAPPING, MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING, MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING, + MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING, + MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING, MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index c1c2387373b8bd..900f1da799d971 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -164,6 +164,7 @@ ("opt", "OPTConfig"), ("owlv2", "Owlv2Config"), ("owlvit", "OwlViTConfig"), + ("patchtst", "PatchTSTConfig"), ("pegasus", "PegasusConfig"), ("pegasus_x", "PegasusXConfig"), ("perceiver", "PerceiverConfig"), @@ -376,6 +377,7 @@ ("opt", "OPT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("owlv2", "OWLV2_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("owlvit", "OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("patchtst", "PATCHTST_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("pegasus", "PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("pegasus_x", "PEGASUS_X_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("perceiver", "PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -607,6 +609,7 @@ ("opt", "OPT"), ("owlv2", "OWLv2"), ("owlvit", "OWL-ViT"), + ("patchtst", "PatchTST"), ("pegasus", "Pegasus"), ("pegasus_x", "PEGASUS-X"), ("perceiver", "Perceiver"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index ffcae9a234942c..437aed60143c9b 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -157,6 +157,7 @@ ("opt", "OPTModel"), ("owlv2", "Owlv2Model"), ("owlvit", "OwlViTModel"), + ("patchtst", "PatchTSTModel"), ("pegasus", "PegasusModel"), ("pegasus_x", "PegasusXModel"), ("perceiver", "PerceiverModel"), @@ -1130,6 +1131,18 @@ ] ) +MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + ("patchtst", "PatchTSTForClassification"), + ] +) + +MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES = OrderedDict( + [ + ("patchtst", "PatchTSTForRegression"), + ] +) + MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = OrderedDict( [ ("swin2sr", "Swin2SRForImageSuperResolution"), @@ -1221,6 +1234,14 @@ MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES) +MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING_NAMES +) + +MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES +) + MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES) diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 92e9df2c7e5b1b..8f26274b44bcdb 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -208,71 +208,70 @@ def forward(self, features: torch.Tensor) -> torch.Tensor: ) -# Copied 
from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesStdScaler with TimeSeries->Autoformer +# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesStdScaler with TimeSeriesTransformer->Autoformer,TimeSeries->Autoformer class AutoformerStdScaler(nn.Module): """ - Standardize features by calculating the mean and scaling along some given dimension `dim`, and then normalizes it - by subtracting from the mean and dividing by the standard deviation. - - Args: - dim (`int`): - Dimension along which to calculate the mean and standard deviation. - keepdim (`bool`, *optional*, defaults to `False`): - Controls whether to retain dimension `dim` (of length 1) in the scale tensor, or suppress it. - minimum_scale (`float`, *optional*, defaults to 1e-5): - Default scale that is used for elements that are constantly zero along dimension `dim`. + Standardize features by calculating the mean and scaling along the first dimension, and then normalizes it by + subtracting from the mean and dividing by the standard deviation. """ - def __init__(self, dim: int, keepdim: bool = False, minimum_scale: float = 1e-5): + def __init__(self, config: AutoformerConfig): super().__init__() - if not dim > 0: - raise ValueError("Cannot compute scale along dim = 0 (batch dimension), please provide dim > 0") - self.dim = dim - self.keepdim = keepdim - self.minimum_scale = minimum_scale + self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 + self.keepdim = config.keepdim if hasattr(config, "keepdim") else True + self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-10 - @torch.no_grad() - def forward(self, data: torch.Tensor, weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - denominator = weights.sum(self.dim, keepdim=self.keepdim) + def forward( + self, data: torch.Tensor, observed_indicator: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + input for Batch norm calculation + observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Calculating the scale on the observed indicator. 
+ Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ + denominator = observed_indicator.sum(self.dim, keepdim=self.keepdim) denominator = denominator.clamp_min(1.0) - loc = (data * weights).sum(self.dim, keepdim=self.keepdim) / denominator + loc = (data * observed_indicator).sum(self.dim, keepdim=self.keepdim) / denominator - variance = (((data - loc) * weights) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator + variance = (((data - loc) * observed_indicator) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator scale = torch.sqrt(variance + self.minimum_scale) return (data - loc) / scale, loc, scale -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesMeanScaler with TimeSeries->Autoformer +# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesMeanScaler with TimeSeriesTransformer->Autoformer,TimeSeries->Autoformer class AutoformerMeanScaler(nn.Module): """ - Computes a scaling factor as the weighted average absolute value along dimension `dim`, and scales the data + Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data accordingly. - - Args: - dim (`int`): - Dimension along which to compute the scale. - keepdim (`bool`, *optional*, defaults to `False`): - Controls whether to retain dimension `dim` (of length 1) in the scale tensor, or suppress it. - default_scale (`float`, *optional*, defaults to `None`): - Default scale that is used for elements that are constantly zero. If `None`, we use the scale of the batch. - minimum_scale (`float`, *optional*, defaults to 1e-10): - Default minimum possible scale that is used for any item. """ - def __init__( - self, dim: int = -1, keepdim: bool = True, default_scale: Optional[float] = None, minimum_scale: float = 1e-10 - ): + def __init__(self, config: AutoformerConfig): super().__init__() - self.dim = dim - self.keepdim = keepdim - self.minimum_scale = minimum_scale - self.default_scale = default_scale + self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 + self.keepdim = config.keepdim if hasattr(config, "keepdim") else True + self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-10 + self.default_scale = config.default_scale if hasattr(config, "default_scale") else None - @torch.no_grad() def forward( self, data: torch.Tensor, observed_indicator: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # shape: (N, [C], T=1) + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + input for Batch norm calculation + observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Calculating the scale on the observed indicator. 
+ Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ ts_sum = (data * observed_indicator).abs().sum(self.dim, keepdim=True) num_observed = observed_indicator.sum(self.dim, keepdim=True) @@ -300,26 +299,29 @@ def forward( return scaled_data, torch.zeros_like(scale), scale -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesNOPScaler with TimeSeries->Autoformer +# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesNOPScaler with TimeSeriesTransformer->Autoformer,TimeSeries->Autoformer class AutoformerNOPScaler(nn.Module): """ - Assigns a scaling factor equal to 1 along dimension `dim`, and therefore applies no scaling to the input data. - - Args: - dim (`int`): - Dimension along which to compute the scale. - keepdim (`bool`, *optional*, defaults to `False`): - Controls whether to retain dimension `dim` (of length 1) in the scale tensor, or suppress it. + Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data. """ - def __init__(self, dim: int, keepdim: bool = False): + def __init__(self, config: AutoformerConfig): super().__init__() - self.dim = dim - self.keepdim = keepdim + self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 + self.keepdim = config.keepdim if hasattr(config, "keepdim") else True def forward( - self, data: torch.Tensor, observed_indicator: torch.Tensor + self, data: torch.Tensor, observed_indicator: torch.Tensor = None ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + input for Batch norm calculation + Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ scale = torch.ones_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim) loc = torch.zeros_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim) return data, loc, scale @@ -1433,11 +1435,11 @@ def __init__(self, config: AutoformerConfig): super().__init__(config) if config.scaling == "mean" or config.scaling is True: - self.scaler = AutoformerMeanScaler(dim=1, keepdim=True) + self.scaler = AutoformerMeanScaler(config) elif config.scaling == "std": - self.scaler = AutoformerStdScaler(dim=1, keepdim=True) + self.scaler = AutoformerStdScaler(config) else: - self.scaler = AutoformerNOPScaler(dim=1, keepdim=True) + self.scaler = AutoformerNOPScaler(config) if config.num_static_categorical_features > 0: self.embedder = AutoformerFeatureEmbedder( diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index c0a5a205950285..205c8ba22f743e 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -81,71 +81,70 @@ def forward(self, features: torch.Tensor) -> torch.Tensor: ) -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesStdScaler with TimeSeries->Informer +# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesStdScaler with TimeSeriesTransformer->Informer,TimeSeries->Informer class InformerStdScaler(nn.Module): """ 
- Standardize features by calculating the mean and scaling along some given dimension `dim`, and then normalizes it - by subtracting from the mean and dividing by the standard deviation. - - Args: - dim (`int`): - Dimension along which to calculate the mean and standard deviation. - keepdim (`bool`, *optional*, defaults to `False`): - Controls whether to retain dimension `dim` (of length 1) in the scale tensor, or suppress it. - minimum_scale (`float`, *optional*, defaults to 1e-5): - Default scale that is used for elements that are constantly zero along dimension `dim`. + Standardize features by calculating the mean and scaling along the first dimension, and then normalizes it by + subtracting from the mean and dividing by the standard deviation. """ - def __init__(self, dim: int, keepdim: bool = False, minimum_scale: float = 1e-5): + def __init__(self, config: InformerConfig): super().__init__() - if not dim > 0: - raise ValueError("Cannot compute scale along dim = 0 (batch dimension), please provide dim > 0") - self.dim = dim - self.keepdim = keepdim - self.minimum_scale = minimum_scale + self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 + self.keepdim = config.keepdim if hasattr(config, "keepdim") else True + self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-10 - @torch.no_grad() - def forward(self, data: torch.Tensor, weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - denominator = weights.sum(self.dim, keepdim=self.keepdim) + def forward( + self, data: torch.Tensor, observed_indicator: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + input for Batch norm calculation + observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Calculating the scale on the observed indicator. + Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ + denominator = observed_indicator.sum(self.dim, keepdim=self.keepdim) denominator = denominator.clamp_min(1.0) - loc = (data * weights).sum(self.dim, keepdim=self.keepdim) / denominator + loc = (data * observed_indicator).sum(self.dim, keepdim=self.keepdim) / denominator - variance = (((data - loc) * weights) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator + variance = (((data - loc) * observed_indicator) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator scale = torch.sqrt(variance + self.minimum_scale) return (data - loc) / scale, loc, scale -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesMeanScaler with TimeSeries->Informer +# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesMeanScaler with TimeSeriesTransformer->Informer,TimeSeries->Informer class InformerMeanScaler(nn.Module): """ - Computes a scaling factor as the weighted average absolute value along dimension `dim`, and scales the data + Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data accordingly. - - Args: - dim (`int`): - Dimension along which to compute the scale. - keepdim (`bool`, *optional*, defaults to `False`): - Controls whether to retain dimension `dim` (of length 1) in the scale tensor, or suppress it. 
- default_scale (`float`, *optional*, defaults to `None`): - Default scale that is used for elements that are constantly zero. If `None`, we use the scale of the batch. - minimum_scale (`float`, *optional*, defaults to 1e-10): - Default minimum possible scale that is used for any item. """ - def __init__( - self, dim: int = -1, keepdim: bool = True, default_scale: Optional[float] = None, minimum_scale: float = 1e-10 - ): + def __init__(self, config: InformerConfig): super().__init__() - self.dim = dim - self.keepdim = keepdim - self.minimum_scale = minimum_scale - self.default_scale = default_scale + self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 + self.keepdim = config.keepdim if hasattr(config, "keepdim") else True + self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-10 + self.default_scale = config.default_scale if hasattr(config, "default_scale") else None - @torch.no_grad() def forward( self, data: torch.Tensor, observed_indicator: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # shape: (N, [C], T=1) + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + input for Batch norm calculation + observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Calculating the scale on the observed indicator. + Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ ts_sum = (data * observed_indicator).abs().sum(self.dim, keepdim=True) num_observed = observed_indicator.sum(self.dim, keepdim=True) @@ -173,26 +172,29 @@ def forward( return scaled_data, torch.zeros_like(scale), scale -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesNOPScaler with TimeSeries->Informer +# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesNOPScaler with TimeSeriesTransformer->Informer,TimeSeries->Informer class InformerNOPScaler(nn.Module): """ - Assigns a scaling factor equal to 1 along dimension `dim`, and therefore applies no scaling to the input data. - - Args: - dim (`int`): - Dimension along which to compute the scale. - keepdim (`bool`, *optional*, defaults to `False`): - Controls whether to retain dimension `dim` (of length 1) in the scale tensor, or suppress it. + Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data. 
""" - def __init__(self, dim: int, keepdim: bool = False): + def __init__(self, config: InformerConfig): super().__init__() - self.dim = dim - self.keepdim = keepdim + self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 + self.keepdim = config.keepdim if hasattr(config, "keepdim") else True def forward( - self, data: torch.Tensor, observed_indicator: torch.Tensor + self, data: torch.Tensor, observed_indicator: torch.Tensor = None ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + input for Batch norm calculation + Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ scale = torch.ones_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim) loc = torch.zeros_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim) return data, loc, scale @@ -1446,11 +1448,11 @@ def __init__(self, config: InformerConfig): super().__init__(config) if config.scaling == "mean" or config.scaling is True: - self.scaler = InformerMeanScaler(dim=1, keepdim=True) + self.scaler = InformerMeanScaler(config) elif config.scaling == "std": - self.scaler = InformerStdScaler(dim=1, keepdim=True) + self.scaler = InformerStdScaler(config) else: - self.scaler = InformerNOPScaler(dim=1, keepdim=True) + self.scaler = InformerNOPScaler(config) if config.num_static_categorical_features > 0: self.embedder = InformerFeatureEmbedder( diff --git a/src/transformers/models/patchtst/__init__.py b/src/transformers/models/patchtst/__init__.py new file mode 100644 index 00000000000000..8c7db64c198406 --- /dev/null +++ b/src/transformers/models/patchtst/__init__.py @@ -0,0 +1,66 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +# rely on isort to merge the imports +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available + + +_import_structure = { + "configuration_patchtst": [ + "PATCHTST_PRETRAINED_CONFIG_ARCHIVE_MAP", + "PatchTSTConfig", + ], +} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_patchtst"] = [ + "PATCHTST_PRETRAINED_MODEL_ARCHIVE_LIST", + "PatchTSTModel", + "PatchTSTPreTrainedModel", + "PatchTSTForPrediction", + "PatchTSTForPretraining", + "PatchTSTForRegression", + "PatchTSTForClassification", + ] + + +if TYPE_CHECKING: + from .configuration_patchtst import PATCHTST_PRETRAINED_CONFIG_ARCHIVE_MAP, PatchTSTConfig + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_patchtst import ( + PATCHTST_PRETRAINED_MODEL_ARCHIVE_LIST, + PatchTSTForClassification, + PatchTSTForPrediction, + PatchTSTForPretraining, + PatchTSTForRegression, + PatchTSTModel, + PatchTSTPreTrainedModel, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/patchtst/configuration_patchtst.py b/src/transformers/models/patchtst/configuration_patchtst.py new file mode 100644 index 00000000000000..4ced00c3604600 --- /dev/null +++ b/src/transformers/models/patchtst/configuration_patchtst.py @@ -0,0 +1,274 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PatchTST model configuration""" + +from typing import List, Optional, Union + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + +PATCHTST_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "ibm/patchtst-base": "https://huggingface.co/ibm/patchtst-base/resolve/main/config.json", + # See all PatchTST models at https://huggingface.co/ibm/models?filter=patchtst +} + + +class PatchTSTConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of an [`PatchTSTModel`]. It is used to instantiate an + PatchTST model according to the specified arguments, defining the model architecture. + [ibm/patchtst](https://huggingface.co/ibm/patchtst) architecture. + + Configuration objects inherit from [`PretrainedConfig`] can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + num_input_channels (`int`, *optional*, defaults to 1): + The size of the target variable which by default is 1 for univariate targets. Would be > 1 in case of + multivariate targets. + context_length (`int`, *optional*, defaults to 32): + The context length for the encoder. 
+        distribution_output (`str`, *optional*, defaults to `"student_t"`):
+            The distribution emission head for the model when loss is "nll". Could be either "student_t", "normal" or
+            "negative_binomial".
+        loss (`str`, *optional*, defaults to `"mse"`):
+            The loss function for the model corresponding to the `distribution_output` head. For parametric
+            distributions it is the negative log likelihood ("nll") and for point estimates it is the mean squared
+            error "mse".
+        patch_length (`int`, *optional*, defaults to 1):
+            Define the patch length of the patchification process.
+        patch_stride (`int`, *optional*, defaults to 1):
+            Define the stride of the patchification process.
+        encoder_layers (`int`, *optional*, defaults to 3):
+            Number of encoder layers.
+        d_model (`int`, *optional*, defaults to 64):
+            Dimensionality of the transformer layers.
+        encoder_attention_heads (`int`, *optional*, defaults to 4):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        shared_embedding (`bool`, *optional*, defaults to `True`):
+            Whether to share the input embedding across all channels.
+        channel_attention (`bool`, *optional*, defaults to `False`):
+            Activate the channel attention block in the Transformer to allow channels to attend to each other.
+        encoder_ffn_dim (`int`, *optional*, defaults to 256):
+            Dimension of the "intermediate" (often named feed-forward) layer in the encoder.
+        norm (`str`, *optional*, defaults to `"BatchNorm"`):
+            Normalization at each Transformer layer. Can be `"BatchNorm"` or `"LayerNorm"`.
+        norm_eps (`float`, *optional*, defaults to 1e-05):
+            A value added to the denominator for numerical stability of normalization.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability for the attention probabilities.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the encoder and decoder.
+        positional_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability in the positional embedding layer.
+        dropout_path (`float`, *optional*, defaults to 0.0):
+            The dropout path probability in the residual block.
+        ff_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability used between the two layers of the feed-forward networks.
+        bias (`bool`, *optional*, defaults to `True`):
+            Whether to use bias in the feed-forward networks.
+        activation_function (`str`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (string) in the encoder. `"gelu"` and `"relu"` are supported.
+        pre_norm (`bool`, *optional*, defaults to `True`):
+            Normalization is applied before self-attention if pre_norm is set to `True`. Otherwise, normalization is
+            applied after the residual block.
+        positional_encoding_type (`str`, *optional*, defaults to `"sincos"`):
+            Positional encodings. `"zeros"`, `"normal"`, `"uniform"` and `"sincos"` are supported.
+        learn_pe (`bool`, *optional*, defaults to `False`):
+            Whether the positional encoding is updated during training.
+        use_cls_token (`bool`, *optional*, defaults to `False`):
+            Whether a cls token is used.
+        init_std (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated normal weight initialization distribution.
+        shared_projection (`bool`, *optional*, defaults to `True`):
+            Sharing the projection layer across different channels in the forecast head.
+        seed_number (`int`, *optional*):
+            Seed number used for random masking. If unset, no seed is set.
+        scaling (`Union`, *optional*, defaults to `"mean"`):
+            Whether to scale the input targets via "mean" scaler, "std" scaler or no scaler if `None`. If `True`, the
+            scaler is set to "mean".
+        mask_input (`bool`, *optional*, defaults to `False`):
+            Apply masking during the pretraining.
+        mask_type (`str`, *optional*, defaults to `"random"`):
+            Masking type. Only `"random"` and `"forecast"` are currently supported.
+        random_mask_ratio (`float`, *optional*, defaults to 0.5):
+            Masking ratio applied to mask the input data during random pretraining.
+        forecast_mask_patches (`List`, *optional*, defaults to `[2, 3]`):
+            List of patch lengths to mask at the end of the data.
+        forecast_mask_ratios (`List`, *optional*, defaults to `[1, 1]`):
+            List of weights to use for each patch length. For example, if `forecast_mask_patches` is [5, 4] and
+            `forecast_mask_ratios` is [1, 1], then equal weights are assigned to both patch lengths.
+        channel_consistent_masking (`bool`, *optional*, defaults to `False`):
+            If channel consistent masking is True, all the channels will have the same masking.
+        unmasked_channel_indices (`list`, *optional*):
+            Channels that are not masked during pretraining.
+        mask_value (`int`, *optional*, defaults to 0):
+            Define the value of entries to be masked when pretraining.
+        pooling_type (`str`, *optional*, defaults to `"mean"`):
+            Pooling of the embedding. `"mean"`, `"max"` and `None` are supported.
+        head_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability for the head.
+        prediction_length (`int`, *optional*, defaults to 24):
+            The prediction length for the encoder. In other words, the prediction horizon of the model.
+        num_targets (`int`, *optional*, defaults to 1):
+            Number of targets for regression and classification tasks. For classification, it is the number of
+            classes.
+        output_range (`list`, *optional*):
+            Output range for regression task. The range of output values can be set to enforce the model to produce
+            values within a range.
+        num_parallel_samples (`int`, *optional*, defaults to 100):
+            The number of samples generated in parallel for probabilistic prediction.
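As a quick orientation before the usage example that follows: the patching settings above determine how many patches the encoder sees, following the `_num_patches` formula defined later in this file. A minimal sketch with the documented defaults:

```python
# Number of patches implied by the documented defaults
# (mirrors the `_num_patches` formula defined further down in this configuration).
context_length, patch_length, patch_stride = 32, 1, 1
num_patches = (max(context_length, patch_length) - patch_length) // patch_stride + 1
print(num_patches)  # 32
```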
+ + + ```python + >>> from transformers import PatchTSTConfig, PatchTSTModel + + >>> # Initializing an PatchTST configuration with 12 time steps for prediction + >>> configuration = PatchTSTConfig(prediction_length=12) + + >>> # Randomly initializing a model (with random weights) from the configuration + >>> model = PatchTSTModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "patchtst" + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "encoder_attention_heads", + "num_hidden_layers": "encoder_layers", + } + + def __init__( + self, + # time series specific configuration + num_input_channels: int = 1, + context_length: int = 32, + distribution_output: str = "student_t", + loss: str = "mse", + # PatchTST arguments + patch_length: int = 1, + patch_stride: int = 1, + # Transformer architecture configuration + encoder_layers: int = 3, + d_model: int = 64, + encoder_attention_heads: int = 4, + shared_embedding: bool = True, + channel_attention: bool = False, + encoder_ffn_dim: int = 256, + norm: str = "BatchNorm", + norm_eps: float = 1e-5, + attention_dropout: float = 0.0, + dropout: float = 0.0, + positional_dropout: float = 0.0, + dropout_path: float = 0.0, + ff_dropout: float = 0.0, + bias: bool = True, + activation_function: str = "gelu", + pre_norm: bool = True, + positional_encoding_type: str = "sincos", + learn_pe: bool = False, + use_cls_token: bool = False, + init_std: float = 0.02, + shared_projection: bool = True, + seed_number: Optional[int] = None, + scaling: Optional[Union[str, bool]] = "mean", + # mask pretraining + mask_input: Optional[bool] = None, + mask_type: str = "random", + random_mask_ratio: float = 0.5, + forecast_mask_patches: List[int] = [2, 3], + forecast_mask_ratios: List[int] = [1, 1], + channel_consistent_masking: bool = False, + unmasked_channel_indices: Optional[List[int]] = None, + mask_value=0, + # head + pooling_type: str = "mean", + head_dropout: float = 0.0, + prediction_length: int = 24, + num_targets: int = 1, + output_range: List = None, + # distribution head + num_parallel_samples: int = 100, + **kwargs, + ): + # time series specific configuration + self.context_length = context_length + self.num_input_channels = num_input_channels # n_vars + self.loss = loss + self.distribution_output = distribution_output + self.num_parallel_samples = num_parallel_samples + + # Transformer architecture configuration + self.d_model = d_model + self.encoder_attention_heads = encoder_attention_heads + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.dropout = dropout + self.attention_dropout = attention_dropout + self.shared_embedding = shared_embedding + self.channel_attention = channel_attention + self.norm = norm + self.norm_eps = norm_eps + self.positional_dropout = positional_dropout + self.dropout_path = dropout_path + self.ff_dropout = ff_dropout + self.bias = bias + self.activation_function = activation_function + self.pre_norm = pre_norm + self.positional_encoding_type = positional_encoding_type + self.learn_pe = learn_pe + self.use_cls_token = use_cls_token + self.init_std = init_std + self.scaling = scaling + + # PatchTST parameters + self.patch_length = patch_length + self.patch_stride = patch_stride + self.num_patches = self._num_patches() + + # Mask pretraining + self.seed_number = seed_number + self.mask_input = mask_input + self.mask_type = mask_type + self.random_mask_ratio = random_mask_ratio # for random masking + 
self.forecast_mask_patches = forecast_mask_patches # for forecast masking + self.forecast_mask_ratios = forecast_mask_ratios + self.channel_consistent_masking = channel_consistent_masking + self.unmasked_channel_indices = unmasked_channel_indices + self.mask_value = mask_value + + # general head params + self.pooling_type = pooling_type + self.head_dropout = head_dropout + + # For prediction head + self.shared_projection = shared_projection + self.prediction_length = prediction_length + + # For prediction and regression head + self.num_parallel_samples = num_parallel_samples + + # Regression + self.num_targets = num_targets + self.output_range = output_range + + super().__init__(**kwargs) + + def _num_patches(self): + return (max(self.context_length, self.patch_length) - self.patch_length) // self.patch_stride + 1 diff --git a/src/transformers/models/patchtst/modeling_patchtst.py b/src/transformers/models/patchtst/modeling_patchtst.py new file mode 100755 index 00000000000000..30522a048f024d --- /dev/null +++ b/src/transformers/models/patchtst/modeling_patchtst.py @@ -0,0 +1,1913 @@ +# coding=utf-8 +# Copyright 2023 IBM & Hugging Face. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch PatchTST model.""" + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +from torch import nn + +from ...activations import ACT2CLS +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import PreTrainedModel +from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput +from ...trainer_utils import set_seed +from ...utils import ModelOutput, add_start_docstrings, logging +from .configuration_patchtst import PatchTSTConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "PatchTSTConfig" + +PATCHTST_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "ibm/patchtst-etth1-pretrain", + # See all PatchTST models at https://huggingface.co/models?filter=patchtst +] + + +# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PatchTST +class PatchTSTAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[PatchTSTConfig] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. 
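+        # At this point `attn_output` has shape (bsz, tgt_len, num_heads, head_dim); the reshape
+        # below merges the heads back into `embed_dim` before the output projection.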
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class PatchTSTBatchNorm(nn.Module): + """ + Parameters: + Compute batch normalization + d_model (`int`): model dimension + """ + + def __init__(self, config: PatchTSTConfig): + super().__init__() + self.batchnorm = nn.BatchNorm1d(config.d_model, eps=config.norm_eps) + + def forward(self, inputs: torch.Tensor): + """ + Parameters: + inputs (`torch.Tensor` of shape `(batch_size, sequence_length, d_model)`): + input for Batch norm calculation + Returns: + `torch.Tensor` of shape `(batch_size, sequence_length, d_model)` + """ + output = inputs.transpose(1, 2) # output: (batch_size, d_model, sequence_length) + output = self.batchnorm(output) + return output.transpose(1, 2) + + +def positional_encoding(positional_encoding_type, learned, q_len, d_model): + # Positional encoding + if positional_encoding_type is None: + # positional_encoding_type = None and learned = False can be used to measure impact of positional encoding + position_enc = torch.empty((q_len, d_model)) + nn.init.uniform_(position_enc, -0.02, 0.02) + learned = False + elif positional_encoding_type == "zeros": + position_enc = torch.empty((q_len, d_model)) + nn.init.uniform_(position_enc, -0.02, 0.02) + elif positional_encoding_type == "normal": + position_enc = torch.zeros((q_len, 1)) + nn.init.normal_(position_enc, mean=0.0, std=0.1) + elif positional_encoding_type == "uniform": + position_enc = torch.zeros((q_len, 1)) + nn.init.uniform_(position_enc, a=0.0, b=0.1) + elif positional_encoding_type == "sincos": + position_enc = torch.zeros(q_len, d_model) + position = torch.arange(0, q_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)) + position_enc[:, 0::2] = torch.sin(position * div_term) + position_enc[:, 1::2] = torch.cos(position * div_term) + position_enc = position_enc - position_enc.mean() + position_enc = position_enc / (position_enc.std() * 10) + else: + raise ValueError( + f"{positional_encoding_type} is not a valid positional encoder. Available types are 'normal', 'zeros', 'zero', uniform', 'sincos', None." + ) + return nn.Parameter(position_enc, requires_grad=learned) + + +def random_masking( + inputs: torch.Tensor, + mask_ratio: float, + unmasked_channel_indices: list = None, + channel_consistent_masking: bool = False, + mask_value: int = 0, + seed_number: Optional[int] = None, +): + """random_masking: Mask the input considering the control variables. + + Args: + inputs (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, num_features)`): + The input tensor to mask. + mask_ratio (`float`): + Mask ratio. + unmasked_channel_indices (list, *optional*): + indices of unmasked channels. These channels will not be masked. + channel_consistent_masking (bool, *optional* defaults to False): + When true, masking will be same across all channels of a timeseries. Otherwise, masking positions will vary + across channels. + mask_value (int, *optional*, defaults to 0): + Value to use for masking. + seed_number (int, *optional*): + Value to set for the random seed. 
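A minimal plain-PyTorch sketch (toy shapes, not the helper itself) of the keep/mask bookkeeping performed here: noise is drawn per position, the `len_keep` positions with the smallest noise are kept, and the rest are masked.

```python
import torch

batch_size, num_channels, num_patches = 2, 3, 8
mask_ratio = 0.5
len_keep = int(num_patches * (1 - mask_ratio))

noise = torch.rand(batch_size, num_channels, num_patches)  # noise in [0, 1]
mask = torch.ones(batch_size, num_channels, num_patches)
mask[:, :, :len_keep] = 0                                  # 0 = keep, 1 = mask

ids_shuffle = torch.argsort(noise, dim=-1)                 # ascending: small noise is kept
ids_restore = torch.argsort(ids_shuffle, dim=-1)
mask = torch.gather(mask, dim=-1, index=ids_restore)

print(mask[0, 0])  # four zeros (kept) and four ones (masked) per channel
```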
+ + Returns: + `tuple(torch.Tensor)`: inputs_mask, masked input, same shape as input Tensor and mask tensor of shape [bs x c x + n] + """ + if seed_number: + set_seed(seed_number) + + batch_size, num_channels, sequence_length, num_features = inputs.shape + device = inputs.device + + len_keep = int(sequence_length * (1 - mask_ratio)) + + if channel_consistent_masking: + noise = torch.rand(batch_size, 1, sequence_length, device=device) # noise in [0, 1], bs x 1 x L + noise = noise.repeat(1, num_channels, 1) # bs x num_channels x time + else: + # noise in [0, 1], bs x num_channels x L + noise = torch.rand(batch_size, num_channels, sequence_length, device=device) + + # mask: [bs x num_channels x num_patch] + mask = torch.ones(batch_size, num_channels, sequence_length, device=device) + mask[:, :, :len_keep] = 0 + + # sort noise for each sample + ids_shuffle = torch.argsort(noise, dim=-1) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=-1) # ids_restore: [bs x num_channels x L] + + mask = torch.gather(mask, dim=-1, index=ids_restore) + mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features) # mask: [bs x num_channels x num_patches x patch_length] + if unmasked_channel_indices is not None: + mask[:, unmasked_channel_indices, :, :] = 0 + + inputs_mask = inputs.masked_fill(mask.bool(), mask_value) + return inputs_mask, mask[..., 0] + + +def forecast_masking( + inputs: torch.Tensor, + forecast_mask_patches: list, + forecast_mask_ratios: list = None, + unmasked_channel_indices: list = None, + mask_value: int = 0, + seed_number: Optional[int] = None, +): + """Forecast masking that masks the last K patches where K is from the forecast_mask_patches list. + For every batch, distribute the patch lengths based on forecast_mask_ratios and ignore masks for column indices + mentioned in unmasked_channel_indices. + + Parameters: + inputs (`torch.Tensor`): + Input of shape `(bs, num_channels, num_patch, patch_len)` or `(bs, tsg1, tag2, num_channels, num_patch, + patch_len)` + forecast_mask_patches (`list`): + List of patch lengths to mask at the end of the data e.g. [2, 4]. + forecast_mask_ratios (`list`, *optional*): + List of weights to use for each patch length. For example if forecast_mask_patches is [5,4] and + forecast_mask_ratios is [1,1], then equal weights to both patch lengths. + unmasked_channel_indices (`list`, *optional*): + Control Variable channel indices. These channels will not be masked. + mask_value (`int`, *optional*, defaults to 0): + Value to use for masking. + seed_number (`int`, *optional*): + Value to set for the random seed. 
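A minimal sketch (toy numbers) of the ratio bookkeeping described above, mirroring the allocation formula used in the function body: each patch length receives a share of the batch proportional to its ratio.

```python
batch_size = 10
forecast_mask_patches = [2, 3]
forecast_mask_ratios = [1, 1]

total_ratio = sum(forecast_mask_ratios)
rows_per_length = [int(batch_size * ratio / total_ratio) for ratio in forecast_mask_ratios]
print(dict(zip(forecast_mask_patches, rows_per_length)))
# {2: 5, 3: 5} -> five rows mask the last 2 patches, five rows mask the last 3 patches
```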
+ + Returns: + `tuple(torch.Tensor)`: inputs_mask, masked input, same shape as inputs Tensor and Mask tensor of shape `(bs, + num_channels , num_patch)` or `(bs, tsg1, tsg2, num_channels, num_patch)` + """ + if seed_number: + set_seed(seed_number) + + if forecast_mask_ratios is None: + forecast_mask_ratios = [1 for _ in forecast_mask_patches] + + batch_size, num_channels, sequence_length, num_features = inputs.shape + mask = torch.zeros(batch_size, num_channels, sequence_length, device=inputs.device) + + t_list = [] + total_length = 0 + total_ratio = sum(forecast_mask_ratios) + + for patch_length, ratio in zip(forecast_mask_patches, forecast_mask_ratios): + if patch_length <= 0 or patch_length >= sequence_length: + raise Exception("masked_patch_len should be greater than 0 and less than total patches.") + temp_len = int(batch_size * ratio / total_ratio) + t_list.append([patch_length, ratio, temp_len]) + total_length += temp_len + + t_list = sorted(t_list, key=lambda x: x[2]) + + if total_length < batch_size: + t_list[0][2] = t_list[0][2] + (batch_size - total_length) + elif total_length > batch_size: + t_list[-1][2] = t_list[-1][2] + (total_length - batch_size) + + batch1 = 0 + for patch_len, _, temp_len in t_list: + batch2 = batch1 + temp_len + mask[batch1:batch2, :, -patch_len:] = 1 + batch1 = batch2 + + perm = torch.randperm(mask.shape[0]) + mask = mask[perm] + + mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features) # mask: [bs x num_channels x num_patch x patch_len] + if unmasked_channel_indices is not None: + mask[:, unmasked_channel_indices, :, :] = 0 + + inputs_mask = inputs.masked_fill(mask.bool(), mask_value) + return inputs_mask, mask[..., 0] + + +class PatchTSTPatchify(nn.Module): + """ + A class to patchify the time series sequence into different patches + + Returns: + `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)` + """ + + def __init__(self, config: PatchTSTConfig): + super().__init__() + + self.sequence_length = config.context_length + self.patch_length = config.patch_length + self.patch_stride = config.patch_stride + + if self.sequence_length <= self.patch_length: + raise ValueError( + f"Sequence length ({self.sequence_length}) has to be greater than the patch length ({self.patch_length})" + ) + + # get the number of patches + num_patches = (max(self.sequence_length, self.patch_length) - self.patch_length) // self.patch_stride + 1 + new_sequence_length = self.patch_length + self.patch_stride * (num_patches - 1) + self.sequence_start = self.sequence_length - new_sequence_length + + def forward(self, past_values: torch.Tensor): + """ + Parameters: + past_values (`torch.Tensor` of shape `(batch_size, sequence_length, num_channels)`, *required*): + Input to be patchified + + Returns: + `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)` + """ + sequence_length = past_values.shape[-2] + if sequence_length != self.sequence_length: + raise ValueError( + f"Input sequence length ({sequence_length}) doesn't match model configuration ({self.sequence_length})." 
+ ) + # output: [bs x new_sequence_length x num_channels] + output = past_values[:, self.sequence_start :, :] + # output: [bs x num_patches x num_input_channels x patch_length] + output = output.unfold(dimension=-2, size=self.patch_length, step=self.patch_stride) + # output: [bs x num_input_channels x num_patches x patch_length] + output = output.transpose(-2, -3).contiguous() + return output + + +class PatchTSTMasking(nn.Module): + """ + Class to perform random or forecast masking. + + Parameters: + config (`PatchTSTConfig`): model config + + Returns: + x_mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`) + Masked patched input + mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`) + Bool tensor indicating True on masked points + + """ + + def __init__(self, config: PatchTSTConfig): + super().__init__() + self.random_mask_ratio = config.random_mask_ratio + self.channel_consistent_masking = config.channel_consistent_masking + self.mask_type = config.mask_type + self.forecast_mask_patches = config.forecast_mask_patches + self.forecast_mask_ratios = config.forecast_mask_ratios + self.unmasked_channel_indices = config.unmasked_channel_indices + self.mask_value = config.mask_value + if self.unmasked_channel_indices is not None: + self.unmasked_channel_indices.sort() + self.seed_number = config.seed_number + + def forward(self, patch_input: torch.Tensor): + """ + Parameters: + patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*): + Patch input + + Return: + masked_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`) + Masked patched input + mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`) + Bool tensor indicating True on masked points + + """ + + if self.mask_type == "random": + masked_input, mask = random_masking( + inputs=patch_input, + mask_ratio=self.random_mask_ratio, + unmasked_channel_indices=self.unmasked_channel_indices, + channel_consistent_masking=self.channel_consistent_masking, + mask_value=self.mask_value, + seed_number=self.seed_number, + ) + elif self.mask_type == "forecast": + masked_input, mask = forecast_masking( + inputs=patch_input, + forecast_mask_patches=self.forecast_mask_patches, + forecast_mask_ratios=self.forecast_mask_ratios, + unmasked_channel_indices=self.unmasked_channel_indices, + mask_value=self.mask_value, + seed_number=self.seed_number, + ) + else: + raise Exception("Invalid mask type") + + mask = mask.bool() # mask: [bs x num_input_channels x num_patch] + + return masked_input, mask + + +class PatchTSTEncoderLayer(nn.Module): + """ + PatchTST encoder layer + """ + + def __init__(self, config: PatchTSTConfig): + super().__init__() + + self.channel_attention = config.channel_attention + + # Multi-Head attention + self.self_attn = PatchTSTAttention( + embed_dim=config.d_model, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + ) + + # Add & Norm of the sublayer 1 + self.dropout_path1 = nn.Dropout(config.dropout_path) if config.dropout_path > 0 else nn.Identity() + if "batch" in config.norm.lower(): + self.norm_sublayer1 = PatchTSTBatchNorm(config) + else: + self.norm_sublayer1 = nn.LayerNorm(config.d_model, eps=config.norm_eps) + + # Add & Norm of the sublayer 2 + if self.channel_attention: + self.dropout_path2 = nn.Dropout(config.dropout_path) if config.dropout_path > 0 else nn.Identity() + if "batch" in config.norm.lower(): + self.norm_sublayer2 = 
PatchTSTBatchNorm(config) + else: + self.norm_sublayer2 = nn.LayerNorm(config.d_model, eps=config.norm_eps) + + # Position-wise Feed-Forward + self.ff = nn.Sequential( + nn.Linear(config.d_model, config.encoder_ffn_dim, bias=config.bias), + ACT2CLS[config.activation_function](), + nn.Dropout(config.ff_dropout) if config.ff_dropout > 0 else nn.Identity(), + nn.Linear(config.encoder_ffn_dim, config.d_model, bias=config.bias), + ) + + # Add & Norm of sublayer 3 + self.dropout_path3 = nn.Dropout(config.dropout_path) if config.dropout_path > 0 else nn.Identity() + if "batch" in config.norm.lower(): + self.norm_sublayer3 = PatchTSTBatchNorm(config) + else: + self.norm_sublayer3 = nn.LayerNorm(config.d_model, eps=config.norm_eps) + + self.pre_norm = config.pre_norm + + def forward(self, hidden_state: torch.Tensor, output_attentions: Optional[bool] = None): + """ + Parameters: + hidden_state (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, d_model)`, *required*): + Past values of the time series + Return: + `torch.Tensor` of shape `(batch_size, num_channels, sequence_length, d_model)` + + """ + batch_size, num_input_channels, sequence_length, d_model = hidden_state.shape + + # First sublayer: attention across time + # hidden_states: [(bs*num_channels) x sequence_length x d_model] + hidden_state = hidden_state.view(batch_size * num_input_channels, sequence_length, d_model) + + if self.pre_norm: + ## Norm and Multi-Head attention and Add residual connection + attn_output, attn_weights, _ = self.self_attn( + hidden_states=self.norm_sublayer1(hidden_state), output_attentions=output_attentions + ) + # Add: residual connection with residual dropout + hidden_state = hidden_state + self.dropout_path1(attn_output) + else: + ## Multi-Head attention and Add residual connection and Norm - Standard Transformer from BERT + attn_output, attn_weights, _ = self.self_attn( + hidden_states=hidden_state, output_attentions=output_attentions + ) + # hidden_states: [(bs*num_channels) x sequence_length x d_model] + hidden_state = self.norm_sublayer1(hidden_state + self.dropout_path1(attn_output)) + + # hidden_state: [bs x num_channels x sequence_length x d_model] + hidden_state = hidden_state.reshape(batch_size, num_input_channels, sequence_length, d_model) + + # second sublayer: attention across variable at any given time + if self.channel_attention: + # hidden_state: [bs x sequence_length x num_channels x d_model] + hidden_state = hidden_state.transpose(2, 1).contiguous() + # hidden_state: [(bs*sequence_length) x num_channels x d_model] + hidden_state = hidden_state.view(batch_size * sequence_length, num_input_channels, d_model) + if self.pre_norm: + ## Norm and Multi-Head attention and Add residual connection + attn_output, channel_attn_weights, _ = self.self_attn( + hidden_states=self.norm_sublayer2(hidden_state), output_attentions=output_attentions + ) + # Add: residual connection with residual dropout + hidden_state = hidden_state + self.dropout_path2(attn_output) + else: + ## Multi-Head attention and Add residual connection and Norm + attn_output, channel_attn_weights, _ = self.self_attn( + hidden_states=hidden_state, output_attentions=output_attentions + ) + # hidden_states: [(bs*sequence_length) x num_channels x d_model] + hidden_state = self.norm_sublayer2(hidden_state + self.dropout_path2(attn_output)) + + # Reshape hidden state + # hidden_state: [bs x sequence_length x num_channels x d_model] + hidden_state = hidden_state.reshape(batch_size, sequence_length, num_input_channels, d_model) + 
# hidden_state: [bs x num_channels x sequence_length x d_model] + hidden_state = hidden_state.transpose(1, 2).contiguous() + + # Third sublayer: mixing across hidden + # hidden_state: [(batch_size*num_channels) x sequence_length x d_model] + hidden_state = hidden_state.view(batch_size * num_input_channels, sequence_length, d_model) + if self.pre_norm: + ## Norm and Position-wise Feed-Forward and Add residual connection + # Add: residual connection with residual dropout + hidden_state = hidden_state + self.dropout_path3(self.ff(self.norm_sublayer3(hidden_state))) + else: + ## Position-wise Feed-Forward and Add residual connection and Norm + # Add: residual connection with residual dropout + hidden_state = self.norm_sublayer3(hidden_state + self.dropout_path3(self.ff(hidden_state))) + + # [bs x num_channels x sequence_length x d_model] + hidden_state = hidden_state.reshape(batch_size, num_input_channels, sequence_length, d_model) + + outputs = (hidden_state,) + if output_attentions: + outputs += (attn_weights, channel_attn_weights) if self.channel_attention else (attn_weights,) + + return outputs + + +class PatchTSTPreTrainedModel(PreTrainedModel): + config_class = PatchTSTConfig + base_model_prefix = "model" + main_input_name = "past_values" + supports_gradient_checkpointing = False + + def _init_weights(self, module): + """Initialize weights""" + if self.config.use_cls_token: + nn.init.normal_(self.config.cls_token, std=0.02) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=self.config.init_std) + if module.bias is not None: + module.bias.data.zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (PatchTSTEncoder)): + module.gradient_checkpointing = value + + +class PatchTSTEmbedding(nn.Module): + def __init__(self, config: PatchTSTConfig): + super().__init__() + # Input encoding: projection of feature vectors onto a d-dim vector space + if not config.shared_embedding: + self.input_embedding = nn.ModuleList() + for _ in range(config.num_input_channels): + self.input_embedding.append(nn.Linear(config.patch_length, config.d_model)) + else: + self.input_embedding = nn.Linear(config.patch_length, config.d_model) + + def forward(self, patch_input: torch.Tensor): + """ + Parameters: + patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*): + Patch input for embedding + return: + `torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)` + """ + # Input encoding + num_input_channels = patch_input.shape[1] + if isinstance(self.input_embedding, nn.ModuleList): + embeddings = [self.input_embedding[i](patch_input[:, i, :, :]) for i in range(num_input_channels)] + embeddings = torch.stack(embeddings, dim=1) + else: + embeddings = self.input_embedding(patch_input) # x: [bs x num_channels x num_patches x d_model] + return embeddings + + +class PatchTSTPositionalEncoding(nn.Module): + """ + Class for positional encoding + """ + + def __init__(self, config: PatchTSTConfig): + super().__init__() + self.use_cls_token = config.use_cls_token + if config.use_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, 1, config.d_model)) + num_patches = config.num_patches + 1 + else: + num_patches = config.num_patches + # postional encoding + self.position_enc = positional_encoding( + config.positional_encoding_type, config.learn_pe, num_patches, config.d_model 
+ ) + # Positional dropout + self.positional_dropout = ( + nn.Dropout(config.positional_dropout) if config.positional_dropout > 0 else nn.Identity() + ) + + def forward(self, patch_input: torch.Tensor): + if self.use_cls_token: + # patch_input: [bs x num_channels x num_patches x d_model] + patch_input = self.positional_dropout(patch_input + self.position_enc[1:, :]) + # append cls token where cls_token: [1 x 1 x 1 x d_model] + cls_token = self.cls_token + self.position_enc[:1, :] + # get the same copy of cls_token for all the samples in batch + cls_tokens = cls_token.expand(patch_input.shape[0], -1, -1) + # hidden_state: [bs x num_channels x (num_patches+1) x d_model] + hidden_state = torch.cat((cls_tokens, patch_input), dim=1) + else: + # hidden_state: [bs x num_channels x num_patches x d_model] + hidden_state = self.positional_dropout(patch_input + self.position_enc) + return hidden_state + + +class PatchTSTEncoder(PatchTSTPreTrainedModel): + """ + PatchTST Encoder + """ + + def __init__(self, config: PatchTSTConfig): + super().__init__(config) + self.num_input_channels = config.num_input_channels + self.num_patches = config.num_patches + self.patch_length = config.patch_length + self.d_model = config.d_model + self.shared_embedding = config.shared_embedding + self.use_cls_token = config.use_cls_token + self.gradient_checkpointing = False + + # Input embedding: projection of feature vectors onto a d-dim vector space + self.embedder = PatchTSTEmbedding(config) + # Positional encoding + self.positional_encoder = PatchTSTPositionalEncoding(config) + # Encoder + self.layers = nn.ModuleList([PatchTSTEncoderLayer(config) for i in range(config.encoder_layers)]) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + patch_input: torch.Tensor, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + ) -> BaseModelOutput: + """ + Parameters: + patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*): + Past values of the time series + output_hidden_states (bool, optional): Indicates if hidden states should be outputted. + output_attentions (bool, optional): Indicates if attentions should be outputted. + + return: + `BaseModelOutput` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # Input embedding + patch_input = self.embedder(patch_input) + # Positional encoding + hidden_state = self.positional_encoder(patch_input) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_state,) + + layer_outputs = encoder_layer(hidden_state=hidden_state, output_attentions=output_attentions) + # get hidden state. hidden_state shape is [bs x num_channels x num_patches x d_model] + # or [bs x num_channels x (num_patches+1) x d_model] if use cls_token + hidden_state = layer_outputs[0] + # append attention matrix at each layer + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + # return past_values, hidden_states + return BaseModelOutput(last_hidden_state=hidden_state, hidden_states=encoder_states, attentions=all_attentions) + + +PATCHTST_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`PatchTSTConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +PATCHTST_INPUTS_DOCSTRING = r""" + Parameters: + past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, num_input_channels)`): + Past values of the time series, that serve as context in order to predict the future. The sequence size of + this tensor must be larger than the `context_length` of the model, since the model will use the larger size + to construct lag features, i.e. additional values from the past which are added in order to serve as "extra + context". + + The `sequence_length` here is equal to `config.context_length` + + The `past_values` is what the Transformer encoder gets as input (with optional additional features, such as + `static_categorical_features`, `static_real_features`). + + For multivariate time series, the `num_input_channels` > 1 dimension is required and corresponds to the + number of variates in the time series per time step. + + future_values (`torch.FloatTensor` of shape `(batch_size, prediction_length)` or `(batch_size, prediction_length, num_input_channels)`, *optional*): + Future values of the time series, that serve as labels for the model. The `future_values` is what the + Transformer needs during training to learn to output, given the `past_values`. + + The sequence length here is equal to `prediction_length`. + + See the demo notebook and code snippets for details. + + For multivariate time series, the `num_input_channels` > 1 dimension is required and corresponds to the + number of variates in the time series per time step. + + output_hidden_states (`bool`, *optional*, default to False): + Whether or not to return the hidden states of all layers. +""" + + +@dataclass +class PatchTSTModelOutput(ModelOutput): + """ + Base class for model's outputs, with potential hidden states. + + Parameters: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of + the model at the output of each layer plus the optional initial embedding outputs. 
+ patch_input (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`): + patched input to the Transformer + mask: (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches)`,*optional*) + Bool masked tensor indicating which patches are masked + loc: (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`,*optional*) + mean of the input data (batch_size, sequence_length, num_channels) over the sequence_length + scale: (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`,*optional*) + std of the input data (batch_size, sequence_length, num_channels) over the sequence_length + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + patch_input: torch.FloatTensor = None + mask: torch.FloatTensor = None + loc: torch.FloatTensor = None + scale: torch.FloatTensor = None + + +@dataclass +class PatchTSTForPretrainingOutput(ModelOutput): + """ + Output type of [`PatchTSTForPretraining`]. + + Parameters: + loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`): + MSE loss. + prediction_outputs (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction outputs of the time series modeling heads. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + prediction_output: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class PatchTSTForRegressionOutput(ModelOutput): + """ + Output type of [`PatchTSTForRegression`]. + + Parameters: + loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`): + MSE loss. + forecast_outputs (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction outputs of the time series modeling heads. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ """ + + loss: Optional[torch.FloatTensor] = None + forecast_outputs: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class PatchTSTForPredictionOutput(ModelOutput): + """ + Output type of [`PatchTSTForPrediction`]. + + Parameters: + loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`): + MSE loss. + prediction_outputs (`torch.FloatTensor` of shape `(batch_size, sequence_length, -1)`): + Prediction outputs of the time series modeling heads. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + prediction_outputs: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + loc: torch.FloatTensor = None + scale: torch.FloatTensor = None + + +@dataclass +class PatchTSTForClassificationOutput(ModelOutput): + """ + Output type of [`PatchTSTForClassification`]. + + Parameters: + loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`): + Total loss as the sum of the masked language modeling loss and the next sequence prediction + (classification) loss. + prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + prediction_logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class SamplePatchTSTPredictionOutput(ModelOutput): + """ + Base class for time series model's predictions outputs that contains the sampled values from the chosen + distribution. 
+ + Parameters: + sequences `(batch_size, num_samples, prediction_length, num_targets)`): + Sampled values from the chosen distribution. + """ + + sequences: torch.FloatTensor = None + + +@dataclass +class SamplePatchTSTRegressionOutput(ModelOutput): + """ + Base class for time series model's predictions outputs that contains the sampled values from the chosen + distribution. + + Parameters: + sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, num_targets)` + Sampled values from the chosen distribution. + """ + + sequences: torch.FloatTensor = None + + +# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.nll +def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor: + """ + Computes the negative log likelihood loss from input distribution with respect to target. + """ + return -input.log_prob(target) + + +# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.weighted_average +def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor: + """ + Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero, + meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`. + + Args: + input_tensor (`torch.FloatTensor`): + Input tensor, of which the average must be computed. + weights (`torch.FloatTensor`, *optional*): + Weights tensor, of the same shape as `input_tensor`. + dim (`int`, *optional*): + The dim along which to average `input_tensor`. + + Returns: + `torch.FloatTensor`: The tensor with values averaged along the specified `dim`. + """ + if weights is not None: + weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor)) + sum_weights = torch.clamp(weights.sum(dim=dim) if dim else weights.sum(), min=1.0) + return (weighted_tensor.sum(dim=dim) if dim else weighted_tensor.sum()) / sum_weights + else: + return input_tensor.mean(dim=dim) + + +# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesStdScaler with TimeSeriesTransformer->PatchTST,TimeSeries->PatchTST +class PatchTSTStdScaler(nn.Module): + """ + Standardize features by calculating the mean and scaling along the first dimension, and then normalizes it by + subtracting from the mean and dividing by the standard deviation. + """ + + def __init__(self, config: PatchTSTConfig): + super().__init__() + self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 + self.keepdim = config.keepdim if hasattr(config, "keepdim") else True + self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-10 + + def forward( + self, data: torch.Tensor, observed_indicator: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + input for Batch norm calculation + observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Calculating the scale on the observed indicator. 
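A minimal plain-PyTorch sketch (toy shapes) of the standardization performed in the forward pass below, with `minimum_scale` matching the fallback default above:

```python
import torch

data = torch.randn(2, 8, 3)       # (batch_size, sequence_length, num_input_channels)
observed = torch.ones_like(data)  # 1 = observed, 0 = missing
minimum_scale = 1e-10

denominator = observed.sum(1, keepdim=True).clamp_min(1.0)
loc = (data * observed).sum(1, keepdim=True) / denominator
variance = (((data - loc) * observed) ** 2).sum(1, keepdim=True) / denominator
scale = torch.sqrt(variance + minimum_scale)
scaled = (data - loc) / scale

print(loc.shape, scale.shape, scaled.shape)  # (2, 1, 3) (2, 1, 3) (2, 8, 3)
```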
+ Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ + denominator = observed_indicator.sum(self.dim, keepdim=self.keepdim) + denominator = denominator.clamp_min(1.0) + loc = (data * observed_indicator).sum(self.dim, keepdim=self.keepdim) / denominator + + variance = (((data - loc) * observed_indicator) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator + scale = torch.sqrt(variance + self.minimum_scale) + return (data - loc) / scale, loc, scale + + +# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesMeanScaler with TimeSeriesTransformer->PatchTST,TimeSeries->PatchTST +class PatchTSTMeanScaler(nn.Module): + """ + Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data + accordingly. + """ + + def __init__(self, config: PatchTSTConfig): + super().__init__() + self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 + self.keepdim = config.keepdim if hasattr(config, "keepdim") else True + self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-10 + self.default_scale = config.default_scale if hasattr(config, "default_scale") else None + + def forward( + self, data: torch.Tensor, observed_indicator: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + input for Batch norm calculation + observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Calculating the scale on the observed indicator. + Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ + ts_sum = (data * observed_indicator).abs().sum(self.dim, keepdim=True) + num_observed = observed_indicator.sum(self.dim, keepdim=True) + + scale = ts_sum / torch.clamp(num_observed, min=1) + + # If `default_scale` is provided, we use it, otherwise we use the scale + # of the batch. + if self.default_scale is None: + batch_sum = ts_sum.sum(dim=0) + batch_observations = torch.clamp(num_observed.sum(0), min=1) + default_scale = torch.squeeze(batch_sum / batch_observations) + else: + default_scale = self.default_scale * torch.ones_like(scale) + + # apply default scale where there are no observations + scale = torch.where(num_observed > 0, scale, default_scale) + + # ensure the scale is at least `self.minimum_scale` + scale = torch.clamp(scale, min=self.minimum_scale) + scaled_data = data / scale + + if not self.keepdim: + scale = scale.squeeze(dim=self.dim) + + return scaled_data, torch.zeros_like(scale), scale + + +# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesNOPScaler with TimeSeriesTransformer->PatchTST,TimeSeries->PatchTST +class PatchTSTNOPScaler(nn.Module): + """ + Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data. 
+ """ + + def __init__(self, config: PatchTSTConfig): + super().__init__() + self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 + self.keepdim = config.keepdim if hasattr(config, "keepdim") else True + + def forward( + self, data: torch.Tensor, observed_indicator: torch.Tensor = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + input for Batch norm calculation + Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ + scale = torch.ones_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim) + loc = torch.zeros_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim) + return data, loc, scale + + +class PatchTSTScaler(nn.Module): + def __init__(self, config: PatchTSTConfig): + super().__init__() + if config.scaling == "mean" or config.scaling is True: + self.scaler = PatchTSTMeanScaler(config) + elif config.scaling == "std": + self.scaler = PatchTSTStdScaler(config) + else: + self.scaler = PatchTSTNOPScaler(config) + + def forward( + self, data: torch.Tensor, observed_indicator: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + input for Batch norm calculation + observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Calculating the scale on the observed indicator. + Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, um_input_channels)`) + """ + data, loc, scale = self.scaler(data, observed_indicator) + return data, loc, scale + + +@add_start_docstrings( + "The bare PatchTST Model outputting raw hidden-states without any specific head.", + PATCHTST_START_DOCSTRING, +) +class PatchTSTModel(PatchTSTPreTrainedModel): + def __init__(self, config: PatchTSTConfig): + super().__init__(config) + + self.scaler = PatchTSTScaler(config) + self.patchifier = PatchTSTPatchify(config) + self.mask_input = config.mask_input + + if self.mask_input: + self.masking = PatchTSTMasking(config) + else: + self.masking = nn.Identity() + self.encoder = PatchTSTEncoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + past_values: torch.Tensor, + past_observed_mask: Optional[torch.Tensor] = None, + future_values: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, PatchTSTModelOutput]: + """ + Parameters: + past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*): + Input sequence to the model + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). 
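`PatchTSTScaler` simply dispatches on `config.scaling`: `"mean"` (or `True`) selects the mean scaler, `"std"` the standard scaler, and anything else the no-op scaler; the resulting `loc` and `scale` are carried through `PatchTSTModel`'s output so downstream heads can de-normalize. A hedged usage sketch, assuming the `PatchTSTConfig` and `PatchTSTModel` classes introduced by this patch are importable from `transformers`; the sizes are arbitrary:

```python
import torch
from transformers import PatchTSTConfig, PatchTSTModel

# standard scaling over the sequence dimension; "mean" and None are the other options
config = PatchTSTConfig(
    num_input_channels=3, context_length=32, patch_length=8, patch_stride=8, scaling="std"
)
model = PatchTSTModel(config)
model.eval()

past_values = torch.randn(4, config.context_length, config.num_input_channels)
with torch.no_grad():
    outputs = model(past_values=past_values)

# loc/scale computed by the scaler are returned alongside the encoder states
print(outputs.last_hidden_state.shape, outputs.loc.shape, outputs.scale.shape)
```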
+ output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers + output_attentions (`bool`, *optional*): + Whether or not to return the output attention of all layers + return_dict (`bool`, *optional*): Whether or not to return a `ModelOutput` instead of a plain tuple. + + Returns: + `PatchTSTModelOutput` or tuple of `torch.Tensor` (if `return_dict`=False or `config.return_dict`=False) + + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + if past_observed_mask is None: + past_observed_mask = torch.ones_like(past_values) + + # x: tensor [bs x sequence_length x num_input_channels] + scaled_past_values, loc, scale = self.scaler(past_values, past_observed_mask) + + # patched_values: [bs x num_input_channels x num_patches x patch_length] for pretrain + patched_values = self.patchifier(scaled_past_values) + if self.mask_input: + masked_values, mask = self.masking(patched_values) + else: + masked_values, mask = self.masking(patched_values), None + + encoder_output = self.encoder( + patch_input=masked_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions + ) + + if not return_dict: + outputs = (encoder_output.last_hidden_state, encoder_output.hidden_states, encoder_output.attentions) + outputs = outputs + (patched_values, mask, loc, scale) + return tuple(v for v in outputs if v is not None) + + return PatchTSTModelOutput( + last_hidden_state=encoder_output.last_hidden_state, + hidden_states=encoder_output.hidden_states, + attentions=encoder_output.attentions, + patch_input=patched_values, + mask=mask, + loc=loc, + scale=scale, + ) + + +class PatchTSTMaskPretrainHead(nn.Module): + """ + Pretraining head for mask modelling + """ + + def __init__(self, config: PatchTSTConfig): + super().__init__() + self.dropout = nn.Dropout(config.dropout) + self.linear = nn.Linear(config.d_model, config.patch_length) + self.use_cls_token = config.use_cls_token + + def forward(self, embedding: torch.Tensor) -> torch.Tensor: + """ + Parameters: + embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` + or `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*): + Embedding from the model + Returns: + `torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or + `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True + + """ + embedding = self.linear(self.dropout(embedding)) # [bs x num_channels x num_patches x patch_length] + if self.use_cls_token: + embedding = embedding[:, :, 1:, :] # remove the first cls token + return embedding + + +class PatchTSTForPretraining(PatchTSTPreTrainedModel): + """ + Mask pretrain model: PatchTST model + pretrain head + """ + + def __init__(self, config: PatchTSTConfig): + super().__init__(config) + + config.mask_input = True + self.model = PatchTSTModel(config=config) + self.head = PatchTSTMaskPretrainHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + past_values: torch.Tensor, + past_observed_mask: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, PatchTSTForPretrainingOutput]: + """ + Parameters: + past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, 
*required*): + Input sequence to the model + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers + return_dict (`bool`, *optional*): Whether or not to return a `ModelOutput` instead of a plain tuple. + + Returns: + `PatchTSTForPretrainingOutput` or tuple of `torch.Tensor` (if `return_dict`=False or + `config.return_dict`=False) + + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # past_values: [bs x num_channels x num_patches x d_model] or + # [bs x num_channels x (num_patches+1) x d_model] if use cls_token + model_output = self.model( + past_values=past_values, + past_observed_mask=past_observed_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + ) + + # model_output[0]: [bs x num_channels x num_patches x patch_length] or + # [bs x num_channels x (num_patches+1) x patch_length] if use cls_token + x_hat = self.head(model_output[0]) + + # calculate masked_loss + loss = nn.MSELoss(reduction="none") + loss_val = loss(x_hat, model_output.patch_input) + masked_loss = (loss_val.mean(dim=-1) * model_output.mask).sum() / (model_output.mask.sum() + 1e-10) + + encoder_states = model_output.hidden_states + if not return_dict: + outputs = (masked_loss, x_hat, model_output.hidden_states, model_output.attentions) + return tuple(v for v in outputs if v is not None) + return PatchTSTForPretrainingOutput( + loss=masked_loss, prediction_output=x_hat, hidden_states=encoder_states, attentions=model_output.attentions + ) + + +class PatchTSTForClassification(PatchTSTPreTrainedModel): + """ + PatchTST model for classification. The model contains PatchTST model + classification head + """ + + def __init__(self, config: PatchTSTConfig): + super().__init__(config) + + self.model = PatchTSTModel(config) + self.head = PatchTSTClassificationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + past_values: torch.Tensor, + target_values: torch.Tensor = None, + past_observed_mask: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, PatchTSTForClassificationOutput]: + """ + Parameters: + past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*): + Input sequence to the model + target_values (`torch.Tensor`, *optional*): labels associates with the `past_values` + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers + return_dict (`bool`, *optional*): Whether or not to return a `ModelOutput` instead of a plain tuple. 
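The pretraining loss computed by `PatchTSTForPretraining` above is a per-patch MSE that only counts masked patches: the element-wise error is averaged over the patch-length dimension and then re-weighted by the boolean patch mask. A small numeric sketch of that reduction (shapes and the random mask are illustrative only):

```python
import torch
import torch.nn as nn

bs, num_channels, num_patches, patch_length = 2, 3, 5, 4
x_hat = torch.randn(bs, num_channels, num_patches, patch_length)        # reconstruction from the head
patch_input = torch.randn(bs, num_channels, num_patches, patch_length)  # patchified ground truth
mask = (torch.rand(bs, num_channels, num_patches) > 0.6).float()        # 1 = patch was masked

loss_val = nn.MSELoss(reduction="none")(x_hat, patch_input)             # per-element squared error
masked_loss = (loss_val.mean(dim=-1) * mask).sum() / (mask.sum() + 1e-10)
print(masked_loss)
```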
+
+        Returns:
+            `PatchTSTForClassificationOutput` or tuple of `torch.Tensor` (if `return_dict`=False or
+            `config.return_dict`=False)
+
+        """
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        model_output = self.model(
+            past_values=past_values,
+            past_observed_mask=past_observed_mask,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+        )
+        y_hat = self.head(model_output[0])
+
+        loss_val = None
+        if target_values is not None:
+            loss = nn.CrossEntropyLoss()
+            loss_val = loss(y_hat, target_values)
+
+        if not return_dict:
+            outputs = (loss_val, y_hat, model_output.hidden_states, model_output.attentions)
+            return tuple(v for v in outputs if v is not None)
+        return PatchTSTForClassificationOutput(
+            loss=loss_val,
+            prediction_logits=y_hat,
+            hidden_states=model_output.hidden_states,
+            attentions=model_output.attentions,
+        )
+
+
+class PatchTSTClassificationHead(nn.Module):
+    def __init__(self, config: PatchTSTConfig):
+        super().__init__()
+        self.use_cls_token = config.use_cls_token
+        self.pooling_type = config.pooling_type
+        self.flatten = nn.Flatten(start_dim=1)
+        self.dropout = nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity()
+        self.linear = nn.Linear(config.num_input_channels * config.d_model, config.num_targets)
+
+    def forward(self, embedding: torch.Tensor):
+        """
+        Parameters:
+            embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)`
+                or `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
+                Embedding from the model
+        Returns:
+            `torch.Tensor` of shape `(bs, num_targets)`
+
+        """
+        if self.use_cls_token:
+            # use the first output token, pooled_embedding: [bs x num_channels x d_model]
+            pooled_embedding = embedding[:, :, 0, :]
+        elif self.pooling_type == "mean":
+            # pooled_embedding: [bs x num_channels x d_model]
+            pooled_embedding = embedding.mean(dim=2)
+        elif self.pooling_type == "max":
+            # `max` returns a (values, indices) tuple; keep only the values
+            # pooled_embedding: [bs x num_channels x d_model]
+            pooled_embedding = embedding.max(dim=2).values
+        else:
+            raise ValueError(f"pooling operator {self.pooling_type} is not implemented yet")
+        # pooled_embedding: [bs x (num_channels * d_model)]
+        pooled_embedding = self.flatten(pooled_embedding)
+        # output: [bs x num_targets]
+        output = self.linear(self.dropout(pooled_embedding))
+        return output
+
+
+class PatchTSTPredictionHead(nn.Module):
+    def __init__(self, config: PatchTSTConfig, distribution_output=None):
+        super().__init__()
+
+        self.shared_projection = config.shared_projection
+        self.num_input_channels = config.num_input_channels
+        self.use_cls_token = config.use_cls_token
+        self.pooling_type = config.pooling_type
+        head_dim = config.d_model if self.pooling_type else config.d_model * config.num_patches
+
+        if not self.shared_projection:
+            # if each channel has its own head
+            self.projections = nn.ModuleList()
+            self.dropouts = nn.ModuleList()
+            self.flattens = nn.ModuleList()
+            for i in range(self.num_input_channels):
+                self.flattens.append(nn.Flatten(start_dim=2))
+                if distribution_output is None:
+                    # use linear head
+                    self.projections.append(nn.Linear(head_dim, config.prediction_length))
+                else:
+                    # use distribution head
+                    self.projections.append(distribution_output.get_parameter_projection(head_dim))
+                self.dropouts.append(nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity())
+        else:
+            # all the channels share the same head
+            self.flatten = nn.Flatten(start_dim=2)
+            if distribution_output is None:
+                # use linear head
+                self.projection = nn.Linear(head_dim, config.prediction_length)
+            else:
+                # use distribution head
+                self.projection = distribution_output.get_parameter_projection(head_dim)
+            self.dropout = nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity()
+
+    def forward(self, embedding: torch.Tensor):
+        """
+        Parameters:
+            embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)`
+                or `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
+                Embedding from the model
+        Returns:
+            `torch.Tensor` of shape `(bs, forecast_len, num_channels)`
+
+        """
+        if self.use_cls_token:
+            # pooled_embedding: [bs x num_channels x d_model]
+            pooled_embedding = embedding[:, :, 0, :]
+        else:
+            if self.pooling_type == "mean":
+                # pooled_embedding: [bs x num_channels x d_model]
+                pooled_embedding = embedding.mean(dim=2)
+            elif self.pooling_type == "max":
+                # `max` returns a (values, indices) tuple; keep only the values
+                # pooled_embedding: [bs x num_channels x d_model]
+                pooled_embedding = embedding.max(dim=2).values
+            else:
+                # pooled_embedding: [bs x num_channels x num_patches x d_model]
+                pooled_embedding = embedding
+
+        if not self.shared_projection:
+            output = []
+            for i in range(self.num_input_channels):
+                # use a separate variable so `pooled_embedding` is not overwritten inside the loop
+                # channel_embedding: [bs x (d_model * num_patches)] or [bs x d_model]
+                channel_embedding = self.flattens[i](pooled_embedding[:, i, :])
+                channel_embedding = self.dropouts[i](channel_embedding)
+                # channel_embedding: [bs x forecast_len]
+                # or tuple ([bs x forecast_len], [bs x forecast_len]) if using distribution head
+                channel_embedding = self.projections[i](channel_embedding)
+                output.append(channel_embedding)
+            # output: [bs x num_channels x forecast_len]
+            output = torch.stack(output, dim=1)
+        else:
+            # pooled_embedding: [bs x num_channels x (d_model * num_patches)] or [bs x num_channels x d_model]
+            pooled_embedding = self.flatten(pooled_embedding)
+            pooled_embedding = self.dropout(pooled_embedding)
+            # output: [bs x num_channels x forecast_len] or
+            # tuple ([bs x num_channels x forecast_len], [bs x num_channels x forecast_len]) if using distribution head
+            output = self.projection(pooled_embedding)
+
+        if isinstance(output, tuple):
+            # output: ([bs x forecast_len x num_channels], [bs x forecast_len x num_channels])
+            output = tuple(z.transpose(2, 1) for z in output)
+        else:
+            output = output.transpose(2, 1)  # [bs x forecast_len x num_channels]
+        return output
+
+
+class PatchTSTForPrediction(PatchTSTPreTrainedModel):
+    """
+    PatchTST for forecasting.
The model contains PatchTST model + Forecasting head + """ + + def __init__(self, config: PatchTSTConfig): + super().__init__(config) + self.model = PatchTSTModel(config) + + if config.loss == "mse": + self.distribution_output = None + else: + if config.distribution_output == "student_t": + self.distribution_output = StudentTOutput(dim=config.prediction_length) + elif config.distribution_output == "normal": + self.distribution_output = NormalOutput(dim=config.prediction_length) + elif config.distribution_output == "negative_binomial": + self.distribution_output = NegativeBinomialOutput(dim=config.prediction_length) + else: + raise ValueError(f"Unknown distribution output {config.distribution_output}") + + self.head = PatchTSTPredictionHead(config, self.distribution_output) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + past_values: torch.Tensor, + past_observed_mask: Optional[torch.Tensor] = None, + future_values: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, PatchTSTForPredictionOutput]: + """ + Parameters: + past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*): + Input sequence to the model + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + future_values (`torch.Tensor` of shape `(bs, forecast_len, num_input_channels)`, *optional*): + future target values associated with the `past_values` + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers + return_dict (`bool`, *optional*): Whether or not to return a `ModelOutput` instead of a plain tuple. 
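As the constructor above shows, `PatchTSTForPrediction` has two training regimes: a plain MSE point forecast when `config.loss == "mse"`, or a negative log-likelihood fit of a parametric output (`student_t`, `normal`, or `negative_binomial`) otherwise. A hedged configuration sketch; the field names follow the code above and the values are arbitrary:

```python
from transformers import PatchTSTConfig

# point forecasts trained with MSE (no distribution head)
mse_config = PatchTSTConfig(loss="mse", prediction_length=24)

# probabilistic forecasts: NLL of a Student-T output parameterized by the prediction head
nll_config = PatchTSTConfig(loss="nll", distribution_output="student_t", prediction_length=24)
```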
+ + Returns: + `PatchTSTForPredictionOutput` or tuple of `torch.Tensor` (if `return_dict`=False or + `config.return_dict`=False) + + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # get model output + model_output = self.model( + past_values=past_values, + past_observed_mask=past_observed_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + ) + # get output head + y_hat = self.head(model_output.last_hidden_state) + + loss_val = None + + if future_values is not None: + if self.distribution_output: + distribution = self.distribution_output.distribution( + y_hat, loc=model_output.loc, scale=model_output.scale + ) + loss_val = nll(distribution, future_values) + # take average of the loss + loss_val = weighted_average(loss_val) + # for testing + # loss_val = nn.MSELoss(reduction='none')(distribution.mean, future_values) + # loss_val = weighted_average(loss_val) + else: + y_hat = y_hat * model_output.scale + model_output.loc + loss = nn.MSELoss(reduction="mean") + loss_val = loss(y_hat, future_values) + + loc = model_output.loc + scale = model_output.scale + + if not return_dict: + outputs = (loss_val, y_hat, model_output.hidden_states, model_output.attentions, loc, scale) + return tuple(v for v in outputs if v is not None) + return PatchTSTForPredictionOutput( + loss=loss_val, + prediction_outputs=y_hat, + hidden_states=model_output.hidden_states, + attentions=model_output.attentions, + loc=loc, + scale=scale, + ) + + def generate( + self, + past_values: torch.Tensor, + past_observed_mask: Optional[torch.Tensor] = None, + ) -> SamplePatchTSTPredictionOutput: + """ + Generate sequences of sample predictions from a model with a probability distribution head. + + Parameters: + past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Past values of the time series that serves as context in order to predict the future. + + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + + Return: + [`SamplePatchTSTPredictionOutput`] where the outputs `sequences` tensor will have shape `(batch_size, + number of samples, prediction_length, 1)` or `(batch_size, number of samples, prediction_length, + num_input_channels)` for multivariate predictions. 
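Since `generate` (documented above) draws `config.num_parallel_samples` trajectories per series, a point forecast and an uncertainty band are usually recovered by reducing over the sample dimension. A hedged usage sketch with a randomly initialized model; in practice the model would be loaded from a trained checkpoint:

```python
import torch
from transformers import PatchTSTConfig, PatchTSTForPrediction

config = PatchTSTConfig(
    loss="nll", distribution_output="student_t", prediction_length=24, num_parallel_samples=100
)
model = PatchTSTForPrediction(config)
model.eval()

past_values = torch.randn(8, config.context_length, config.num_input_channels)
with torch.no_grad():
    samples = model.generate(past_values=past_values).sequences
# samples: (batch_size, num_parallel_samples, prediction_length, num_input_channels)

point_forecast = samples.median(dim=1).values                     # 50th percentile over samples
lower, upper = samples.quantile(torch.tensor([0.1, 0.9]), dim=1)  # crude 80% interval
```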
+
+        """
+        # get number of samples
+        num_parallel_samples = self.config.num_parallel_samples
+
+        # get model output
+        outputs = self(
+            past_values=past_values,
+            future_values=None,
+            past_observed_mask=past_observed_mask,
+            output_hidden_states=False,
+        )
+
+        # get distribution
+        distribution = self.distribution_output.distribution(
+            outputs.prediction_outputs, loc=outputs.loc, scale=outputs.scale
+        )
+        # get samples: list of [bs x forecast_len x num_channels]
+        samples = [distribution.sample() for _ in range(num_parallel_samples)]
+        # stack tensors
+        samples = torch.stack(samples, dim=1)  # [bs x num_samples x forecast_len x num_channels]
+        return SamplePatchTSTPredictionOutput(sequences=samples)
+
+
+class PatchTSTRegressionHead(nn.Module):
+    """
+    Regression head
+    """
+
+    def __init__(self, config: PatchTSTConfig, distribution_output=None):
+        super().__init__()
+        self.y_range = config.output_range
+        self.use_cls_token = config.use_cls_token
+        self.pooling_type = config.pooling_type
+        self.distribution_output = distribution_output
+
+        head_dim = config.num_input_channels * config.d_model
+
+        self.flatten = nn.Flatten(start_dim=1)
+        self.dropout = nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity()
+
+        if distribution_output is None:
+            self.projection = nn.Linear(head_dim, config.num_targets)
+        else:
+            self.projection = distribution_output.get_parameter_projection(head_dim)
+
+    def forward(self, embedding: torch.Tensor):
+        """
+        Parameters:
+            embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)`
+                or `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
+                Embedding from the model
+        Returns:
+            `torch.Tensor` of shape `(bs, output_dim)`
+
+        """
+        if self.use_cls_token:
+            # use the first output token, pooled_embedding: [bs x num_channels x d_model]
+            pooled_embedding = embedding[:, :, 0, :]
+        elif self.pooling_type == "mean":
+            # pooled_embedding: [bs x num_channels x d_model]
+            pooled_embedding = embedding.mean(dim=2)
+        elif self.pooling_type == "max":
+            # `max` returns a (values, indices) tuple; keep only the values
+            # pooled_embedding: [bs x num_channels x d_model]
+            pooled_embedding = embedding.max(dim=2).values
+        else:
+            raise ValueError(f"pooling operator {self.pooling_type} is not implemented yet")
+        # flatten the input
+        # pooled_embedding: [bs x (num_channels * d_model)]
+        pooled_embedding = self.dropout(self.flatten(pooled_embedding))
+        # projection
+        # output: [bs x output_dim] or a tuple of this shape for distribution head
+        output = self.projection(pooled_embedding)
+        # if a linear head is used with a bounded output range, rescale through a sigmoid
+        if (self.distribution_output is None) and (self.y_range is not None):  # linear head
+            output = torch.sigmoid(output) * (self.y_range[1] - self.y_range[0]) + self.y_range[0]
+        return output
+
+
+class PatchTSTForRegression(PatchTSTPreTrainedModel):
+    """
+    PatchTST model for regression. The model contains PatchTST model + Regression head
+    """
+
+    def __init__(self, config: PatchTSTConfig):
+        super().__init__(config)
+
+        self.model = PatchTSTModel(config)
+        if config.loss == "mse":
+            self.distribution_output = None
+        else:
+            if config.distribution_output == "student_t":
+                self.distribution_output = StudentTOutput(dim=config.prediction_length * config.num_targets)
+            elif config.distribution_output == "normal":
+                self.distribution_output = NormalOutput(dim=config.prediction_length * config.num_targets)
+            elif config.distribution_output == "negative_binomial":
+                self.distribution_output = NegativeBinomialOutput(dim=config.prediction_length * config.num_targets)
+            else:
+                raise ValueError(f"Unknown distribution output 
{config.distribution_output}") + + self.head = PatchTSTRegressionHead(config, self.distribution_output) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + past_values: torch.Tensor, + target_values: torch.Tensor, + past_observed_mask: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, PatchTSTForRegressionOutput]: + """ + Parameters: + past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*): + Input sequence to the model + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + target_values (`torch.Tensor` of shape `(bs, num_input_channels)`): + target values associates with the `past_values` + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers + return_dict (`bool`, *optional*): Whether or not to return a `ModelOutput` instead of a plain tuple. + + Returns: + `PatchTSTForRegressionOutput` or tuple of `torch.Tensor` (if `return_dict`=False or + `config.return_dict`=False) + + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + model_output = self.model( + past_values=past_values, + past_observed_mask=past_observed_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + ) + # get output head. y_hat is of shape [bs x num_targets] or tuple of this shape + y_hat = self.head(model_output.last_hidden_state) + + loss_val = None + if target_values is not None: + if self.distribution_output: + distribution = self.distribution_output.distribution(y_hat) + loss_val = nll(distribution, target_values) + # take average of the loss + loss_val = weighted_average(loss_val) + else: + loss = nn.MSELoss(reduction="mean") + loss_val = loss(y_hat, target_values) + + if not return_dict: + outputs = (loss_val, y_hat, model_output.hidden_states, model_output.attentions) + return tuple(v for v in outputs if v is not None) + return PatchTSTForRegressionOutput( + loss=loss_val, + forecast_outputs=y_hat, + hidden_states=model_output.hidden_states, + attentions=model_output.attentions, + ) + + def generate( + self, + past_values: torch.Tensor, + past_observed_mask: Optional[torch.Tensor] = None, + ) -> SamplePatchTSTRegressionOutput: + """ + Generate sequences of sample predictions from a model with a probability distribution head. + + Parameters: + past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Past values of the time series that serves as context in order to predict the future. + + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + + Return: + [`SamplePatchTSTRegressionOutput`] where the outputs `sequences` tensor will have shape `(batch_size, + number of samples, num_targets)`. 
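Unlike prediction, regression targets one value (or a small vector) per series rather than a future window, so `target_values` is `(batch_size, num_targets)` and the head pools over patches before projecting across all channels. A hedged training-step sketch with made-up tensors, assuming the classes from this patch are importable from `transformers`:

```python
import torch
from transformers import PatchTSTConfig, PatchTSTForRegression

config = PatchTSTConfig(num_input_channels=3, num_targets=2, loss="mse")
model = PatchTSTForRegression(config)

past_values = torch.randn(16, config.context_length, config.num_input_channels)
target_values = torch.randn(16, config.num_targets)

outputs = model(past_values=past_values, target_values=target_values)
outputs.loss.backward()  # MSE between the pooled projection and target_values
```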
+ """ + # get number of samples + num_parallel_samples = self.config.num_parallel_samples + + # get model output + outputs = self( + past_values=past_values, + target_values=None, + past_observed_mask=past_observed_mask, + output_hidden_states=False, + ) + + # get distribution + distribution = self.distribution_output.distribution(outputs.forecast_outputs) + # get samples: list of [bs x num_targets] + samples = [distribution.sample() for _ in range(num_parallel_samples)] + # stack tensors + samples = torch.stack(samples, dim=1) # [bs x num_samples x num_targets] + return SamplePatchTSTRegressionOutput(sequences=samples) diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 904c02b4f04308..2c875dd56e1b08 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -83,67 +83,66 @@ def forward(self, features: torch.Tensor) -> torch.Tensor: class TimeSeriesStdScaler(nn.Module): """ - Standardize features by calculating the mean and scaling along some given dimension `dim`, and then normalizes it - by subtracting from the mean and dividing by the standard deviation. - - Args: - dim (`int`): - Dimension along which to calculate the mean and standard deviation. - keepdim (`bool`, *optional*, defaults to `False`): - Controls whether to retain dimension `dim` (of length 1) in the scale tensor, or suppress it. - minimum_scale (`float`, *optional*, defaults to 1e-5): - Default scale that is used for elements that are constantly zero along dimension `dim`. + Standardize features by calculating the mean and scaling along the first dimension, and then normalizes it by + subtracting from the mean and dividing by the standard deviation. """ - def __init__(self, dim: int, keepdim: bool = False, minimum_scale: float = 1e-5): + def __init__(self, config: TimeSeriesTransformerConfig): super().__init__() - if not dim > 0: - raise ValueError("Cannot compute scale along dim = 0 (batch dimension), please provide dim > 0") - self.dim = dim - self.keepdim = keepdim - self.minimum_scale = minimum_scale + self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 + self.keepdim = config.keepdim if hasattr(config, "keepdim") else True + self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-10 - @torch.no_grad() - def forward(self, data: torch.Tensor, weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - denominator = weights.sum(self.dim, keepdim=self.keepdim) + def forward( + self, data: torch.Tensor, observed_indicator: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + input for Batch norm calculation + observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Calculating the scale on the observed indicator. 
+ Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ + denominator = observed_indicator.sum(self.dim, keepdim=self.keepdim) denominator = denominator.clamp_min(1.0) - loc = (data * weights).sum(self.dim, keepdim=self.keepdim) / denominator + loc = (data * observed_indicator).sum(self.dim, keepdim=self.keepdim) / denominator - variance = (((data - loc) * weights) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator + variance = (((data - loc) * observed_indicator) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator scale = torch.sqrt(variance + self.minimum_scale) return (data - loc) / scale, loc, scale class TimeSeriesMeanScaler(nn.Module): """ - Computes a scaling factor as the weighted average absolute value along dimension `dim`, and scales the data + Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data accordingly. - - Args: - dim (`int`): - Dimension along which to compute the scale. - keepdim (`bool`, *optional*, defaults to `False`): - Controls whether to retain dimension `dim` (of length 1) in the scale tensor, or suppress it. - default_scale (`float`, *optional*, defaults to `None`): - Default scale that is used for elements that are constantly zero. If `None`, we use the scale of the batch. - minimum_scale (`float`, *optional*, defaults to 1e-10): - Default minimum possible scale that is used for any item. """ - def __init__( - self, dim: int = -1, keepdim: bool = True, default_scale: Optional[float] = None, minimum_scale: float = 1e-10 - ): + def __init__(self, config: TimeSeriesTransformerConfig): super().__init__() - self.dim = dim - self.keepdim = keepdim - self.minimum_scale = minimum_scale - self.default_scale = default_scale + self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 + self.keepdim = config.keepdim if hasattr(config, "keepdim") else True + self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-10 + self.default_scale = config.default_scale if hasattr(config, "default_scale") else None - @torch.no_grad() def forward( self, data: torch.Tensor, observed_indicator: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # shape: (N, [C], T=1) + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + input for Batch norm calculation + observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Calculating the scale on the observed indicator. + Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ ts_sum = (data * observed_indicator).abs().sum(self.dim, keepdim=True) num_observed = observed_indicator.sum(self.dim, keepdim=True) @@ -173,23 +172,26 @@ def forward( class TimeSeriesNOPScaler(nn.Module): """ - Assigns a scaling factor equal to 1 along dimension `dim`, and therefore applies no scaling to the input data. - - Args: - dim (`int`): - Dimension along which to compute the scale. - keepdim (`bool`, *optional*, defaults to `False`): - Controls whether to retain dimension `dim` (of length 1) in the scale tensor, or suppress it. + Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data. 
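The refactor above moves the scaler hyper-parameters from constructor arguments into the config, with `hasattr` guards supplying defaults when a config does not define them. The guard is equivalent to `getattr` with a default; a tiny, self-contained sketch of the idiom (the dummy config is invented for illustration):

```python
class DummyConfig:
    """Stand-in config that defines none of the scaler fields, to exercise the fallback path."""


config = DummyConfig()

# what the refactored scalers do in __init__
dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
keepdim = config.keepdim if hasattr(config, "keepdim") else True
minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-10

# the same thing, slightly more compact
dim = getattr(config, "scaling_dim", 1)

print(dim, keepdim, minimum_scale)  # 1 True 1e-10
```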
""" - def __init__(self, dim: int, keepdim: bool = False): + def __init__(self, config: TimeSeriesTransformerConfig): super().__init__() - self.dim = dim - self.keepdim = keepdim + self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 + self.keepdim = config.keepdim if hasattr(config, "keepdim") else True def forward( - self, data: torch.Tensor, observed_indicator: torch.Tensor + self, data: torch.Tensor, observed_indicator: torch.Tensor = None ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + input for Batch norm calculation + Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ scale = torch.ones_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim) loc = torch.zeros_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim) return data, loc, scale @@ -1180,11 +1182,11 @@ def __init__(self, config: TimeSeriesTransformerConfig): super().__init__(config) if config.scaling == "mean" or config.scaling is True: - self.scaler = TimeSeriesMeanScaler(dim=1, keepdim=True) + self.scaler = TimeSeriesMeanScaler(config) elif config.scaling == "std": - self.scaler = TimeSeriesStdScaler(dim=1, keepdim=True) + self.scaler = TimeSeriesStdScaler(config) else: - self.scaler = TimeSeriesNOPScaler(dim=1, keepdim=True) + self.scaler = TimeSeriesNOPScaler(config) if config.num_static_categorical_features > 0: self.embedder = TimeSeriesFeatureEmbedder( diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index c6b20c7e36746a..07bcf3867fb175 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -627,6 +627,12 @@ def __init__(self, *args, **kwargs): MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING = None +MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING = None + + +MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING = None + + MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None @@ -6019,6 +6025,51 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +PATCHTST_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class PatchTSTForClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class PatchTSTForPrediction(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class PatchTSTForPretraining(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class PatchTSTForRegression(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class PatchTSTModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class PatchTSTPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class PegasusForCausalLM(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/patchtst/__init__.py b/tests/models/patchtst/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/patchtst/test_modeling_patchtst.py 
b/tests/models/patchtst/test_modeling_patchtst.py new file mode 100644 index 00000000000000..8d6f2202ee81ce --- /dev/null +++ b/tests/models/patchtst/test_modeling_patchtst.py @@ -0,0 +1,353 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch PatchTST model. """ + +import inspect +import random +import tempfile +import unittest + +from huggingface_hub import hf_hub_download + +from transformers import is_torch_available +from transformers.models.auto import get_values +from transformers.testing_utils import is_flaky, require_torch, slow, torch_device + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +TOLERANCE = 1e-4 + +if is_torch_available(): + import torch + + from transformers import ( + MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING, + MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING, + PatchTSTConfig, + PatchTSTForClassification, + PatchTSTForPrediction, + PatchTSTForPretraining, + PatchTSTForRegression, + PatchTSTModel, + ) + + +@require_torch +class PatchTSTModelTester: + def __init__( + self, + parent, + batch_size=13, + prediction_length=7, + context_length=14, + patch_length=5, + patch_stride=5, + num_input_channels=1, + num_time_features=1, + is_training=True, + hidden_size=16, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=4, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + lags_sequence=[1, 2, 3, 4, 5], + distil=False, + seed_number=42, + num_targets=2, + num_output_channels=2, + ): + self.parent = parent + self.batch_size = batch_size + self.prediction_length = prediction_length + self.context_length = context_length + self.patch_length = patch_length + self.patch_stride = patch_stride + self.num_input_channels = num_input_channels + self.num_time_features = num_time_features + self.lags_sequence = lags_sequence + self.is_training = is_training + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + + self.seed_number = seed_number + self.num_targets = num_targets + self.num_output_channels = num_output_channels + self.distil = distil + self.num_patches = (max(self.context_length, self.patch_length) - self.patch_length) // self.patch_stride + 1 + + def get_config(self): + return PatchTSTConfig( + prediction_length=self.prediction_length, + patch_length=self.patch_length, + patch_stride=self.patch_stride, + num_input_channels=self.num_input_channels, + d_model=self.hidden_size, + encoder_layers=self.num_hidden_layers, + encoder_attention_heads=self.num_attention_heads, + encoder_ffn_dim=self.intermediate_size, + 
dropout=self.hidden_dropout_prob, + attention_dropout=self.attention_probs_dropout_prob, + context_length=self.context_length, + activation_function=self.hidden_act, + seed_number=self.seed_number, + num_targets=self.num_targets, + num_output_channels=self.num_output_channels, + ) + + def prepare_patchtst_inputs_dict(self, config): + _past_length = config.context_length + # bs, num_input_channels, num_patch, patch_len + + # [bs x seq_len x num_input_channels] + past_values = floats_tensor([self.batch_size, _past_length, self.num_input_channels]) + + future_values = floats_tensor([self.batch_size, config.prediction_length, self.num_input_channels]) + + inputs_dict = { + "past_values": past_values, + "future_values": future_values, + } + return inputs_dict + + def prepare_config_and_inputs(self): + config = self.get_config() + inputs_dict = self.prepare_patchtst_inputs_dict(config) + return config, inputs_dict + + def prepare_config_and_inputs_for_common(self): + config, inputs_dict = self.prepare_config_and_inputs() + return config, inputs_dict + + +@require_torch +class PatchTSTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = ( + ( + PatchTSTModel, + PatchTSTForPrediction, + PatchTSTForPretraining, + PatchTSTForClassification, + PatchTSTForRegression, + ) + if is_torch_available() + else () + ) + all_generative_model_classes = ( + (PatchTSTForPrediction, PatchTSTForRegression, PatchTSTForPretraining) if is_torch_available() else () + ) + pipeline_model_mapping = {"feature-extraction": PatchTSTModel} if is_torch_available() else {} + test_pruning = False + test_head_masking = False + test_missing_keys = False + test_torchscript = False + test_inputs_embeds = False + test_model_common_attributes = False + + test_resize_embeddings = True + test_resize_position_embeddings = False + test_mismatched_shapes = True + test_model_parallel = False + has_attentions = False + + def setUp(self): + self.model_tester = PatchTSTModelTester(self) + self.config_tester = ConfigTester( + self, + config_class=PatchTSTConfig, + has_text_modality=False, + prediction_length=self.model_tester.prediction_length, + ) + + def test_config(self): + self.config_tester.run_common_tests() + + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + # if PatchTSTForPretraining + if model_class == PatchTSTForPretraining: + inputs_dict.pop("future_values") + # else if classification model: + elif model_class in get_values(MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING): + rng = random.Random(self.model_tester.seed_number) + labels = ids_tensor([self.model_tester.batch_size], self.model_tester.num_targets, rng=rng) + inputs_dict["target_values"] = labels + inputs_dict.pop("future_values") + elif model_class in get_values(MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING): + rng = random.Random(self.model_tester.seed_number) + target_values = floats_tensor( + [self.model_tester.batch_size, self.model_tester.num_output_channels], rng=rng + ) + inputs_dict["target_values"] = target_values + inputs_dict.pop("future_values") + return inputs_dict + + def test_save_load_strict(self): + config, _ = self.model_tester.prepare_config_and_inputs() + for model_class in self.all_model_classes: + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model2, info = model_class.from_pretrained(tmpdirname, 
output_loading_info=True) + self.assertEqual(info["missing_keys"], []) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + ) + self.assertEqual(len(hidden_states), expected_num_layers) + + num_patch = self.model_tester.num_patches + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [num_patch, self.model_tester.hidden_size], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + print("model_class: ", model_class) + + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + @unittest.skip(reason="we have no tokens embeddings") + def test_resize_tokens_embeddings(self): + pass + + def test_model_main_input_name(self): + model_signature = inspect.signature(getattr(PatchTSTModel, "forward")) + # The main input is the name of the argument after `self` + observed_main_input_name = list(model_signature.parameters.keys())[1] + self.assertEqual(PatchTSTModel.main_input_name, observed_main_input_name) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = [ + "past_values", + "past_observed_mask", + "future_values", + ] + if model_class == PatchTSTForPretraining: + expected_arg_names.remove("future_values") + elif model_class in get_values(MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING) or model_class in get_values( + MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING + ): + expected_arg_names.remove("future_values") + expected_arg_names.remove("past_observed_mask") + expected_arg_names.append("target_values") if model_class in get_values( + MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING + ) else expected_arg_names.append("target_values") + expected_arg_names.append("past_observed_mask") + expected_arg_names.extend( + [ + "output_hidden_states", + "output_attentions", + "return_dict", + ] + ) + + self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) + + @is_flaky() + def test_retain_grad_hidden_states_attentions(self): + super().test_retain_grad_hidden_states_attentions() + + +# Note: Publishing of this dataset is under internal review. The dataset is not yet downloadable. +def prepare_batch(repo_id="ibm/etth1-forecast-test", file="train-batch.pt"): + file = hf_hub_download(repo_id=repo_id, filename=file, repo_type="dataset") + batch = torch.load(file, map_location=torch_device) + return batch + + +# Note: Publishing of pretrained weights is under internal review. Pretrained model is not yet downloadable. 
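The model tester above and the integration tests below recompute the number of patches with the same formula, `(max(context_length, patch_length) - patch_length) // patch_stride + 1`. With the tester defaults (`context_length=14`, `patch_length=5`, `patch_stride=5`) that is `(14 - 5) // 5 + 1 = 2`; a one-line sanity check:

```python
context_length, patch_length, patch_stride = 14, 5, 5  # PatchTSTModelTester defaults above
num_patches = (max(context_length, patch_length) - patch_length) // patch_stride + 1
assert num_patches == 2
```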
+@require_torch +@slow +class PatchTSTModelIntegrationTests(unittest.TestCase): + # Publishing of pretrained weights are under internal review. Pretrained model is not yet downloadable. + def test_pretrain_head(self): + model = PatchTSTForPretraining.from_pretrained("ibm/patchtst-etth1-pretrain").to(torch_device) + batch = prepare_batch() + + torch.manual_seed(0) + with torch.no_grad(): + output = model(past_values=batch["past_values"].to(torch_device)).prediction_output + num_patch = ( + max(model.config.context_length, model.config.patch_length) - model.config.patch_length + ) // model.config.patch_stride + 1 + expected_shape = torch.Size([64, model.config.num_input_channels, num_patch, model.config.patch_length]) + self.assertEqual(output.shape, expected_shape) + + expected_slice = torch.tensor( + [[[-0.5409]], [[0.3093]], [[-0.3759]], [[0.5068]], [[-0.8387]], [[0.0937]], [[0.2809]]], + device=torch_device, + ) + self.assertTrue(torch.allclose(output[0, :7, :1, :1], expected_slice, atol=TOLERANCE)) + + # Publishing of pretrained weights are under internal review. Pretrained model is not yet downloadable. + def test_prediction_head(self): + model = PatchTSTForPrediction.from_pretrained("ibm/patchtst-etth1-forecast").to(torch_device) + + batch = prepare_batch(file="test-batch.pt") + + torch.manual_seed(0) + with torch.no_grad(): + output = model( + past_values=batch["past_values"].to(torch_device), + future_values=batch["future_values"].to(torch_device), + ).prediction_outputs + expected_shape = torch.Size([64, model.config.prediction_length, model.config.num_input_channels]) + self.assertEqual(output.shape, expected_shape) + + expected_slice = torch.tensor( + [[0.3228, 0.4320, 0.4591, 0.4066, -0.3461, 0.3094, -0.8426]], + device=torch_device, + ) + self.assertTrue(torch.allclose(output[0, :1, :7], expected_slice, atol=TOLERANCE)) diff --git a/utils/check_repo.py b/utils/check_repo.py index d740eefed01936..390f4ca5cab5cd 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -185,6 +185,8 @@ "TimeSeriesTransformerForPrediction", "InformerForPrediction", "AutoformerForPrediction", + "PatchTSTForPretraining", + "PatchTSTForPrediction", "JukeboxVQVAE", "JukeboxPrior", "SamModel",