diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py
index 7e35691d54b..0bd99a585e0 100644
--- a/optimum/exporters/onnx/base.py
+++ b/optimum/exporters/onnx/base.py
@@ -177,6 +177,7 @@ class OnnxConfig(ExportConfig, ABC):
         "text2text-generation": OrderedDict({"logits": {0: "batch_size", 1: "decoder_sequence_length"}}),
         "text-classification": OrderedDict({"logits": {0: "batch_size"}}),
         "text-generation": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
+        "time-series-forecasting": OrderedDict({"prediction_outputs": {0: "batch_size"}}),
         "token-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
         "visual-question-answering": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
         "zero-shot-image-classification": OrderedDict(
diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index 1c838408807..a1d11728a4f 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -59,6 +59,7 @@
     NormalizedTextAndVisionConfig,
     NormalizedTextConfig,
     NormalizedTextConfigWithGQA,
+    NormalizedTimeSeriesForecastingConfig,
     NormalizedVisionConfig,
     check_if_diffusers_greater,
     check_if_transformers_greater,
@@ -2499,3 +2500,46 @@ class EncoderDecoderOnnxConfig(EncoderDecoderBaseOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig
 
     DEFAULT_ONNX_OPSET = 14  # uses SDPA in Transformers, hence opset>=14.
+
+
+
+
+class TimesFMDummyInputGenerator(DummyInputGenerator):
+    SUPPORTED_INPUT_NAMES = ("inputs",)
+
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        **kwargs,
+    ):
+        self.task = task
+        self.normalized_config = normalized_config
+
+        self.batch_size = batch_size
+        self.context_len = normalized_config.context_len
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        return self.random_float_tensor(
+            shape=[self.batch_size, self.context_len],
+            min_value=-1,
+            max_value=1,
+            framework=framework,
+            dtype=float_dtype,
+        )
+
+
+class TimesFMOnnxConfig(OnnxConfig):
+    NORMALIZED_CONFIG_CLASS = NormalizedTimeSeriesForecastingConfig
+    MIN_TRANSFORMERS_VERSION = version.parse("4.47.0")
+    DUMMY_INPUT_GENERATOR_CLASSES = (TimesFMDummyInputGenerator,)
+
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        return {"inputs": {0: "batch_size", 1: "sequence_length"}}
+
+    @property
+    def outputs(self) -> Dict[str, Dict[int, str]]:
+        return super().outputs
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index 4db4130302d..a9abf6556f5 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -320,6 +320,7 @@ class TasksManager:
         ("pt", "visual-bert", "question-answering"): ("transformers", "VisualBertForQuestionAnswering"),
         # VisionEncoderDecoderModel is not registered in AutoModelForDocumentQuestionAnswering
         ("pt", "vision-encoder-decoder", "document-question-answering"): ("transformers", "VisionEncoderDecoderModel"),
+        ("pt", "timesfm", "time-series-forecasting"): ("transformers", "TimesFMModelForPrediction"),
     }
 
     _ENCODER_DECODER_TASKS = (
@@ -939,6 +940,10 @@ class TasksManager:
             "text-classification",
             onnx="Qwen2OnnxConfig",
         ),
+        "timesfm": supported_tasks_mapping(
+            "time-series-forecasting",
+            onnx="TimesFMOnnxConfig",
+        ),
         "llama": supported_tasks_mapping(
             "feature-extraction",
             "feature-extraction-with-past",
diff --git a/optimum/utils/__init__.py b/optimum/utils/__init__.py
index 2aa90253d08..23ac9545ad9 100644
--- a/optimum/utils/__init__.py
+++ b/optimum/utils/__init__.py
@@ -91,5 +91,6 @@
     NormalizedTextAndVisionConfig,
     NormalizedTextConfig,
     NormalizedTextConfigWithGQA,
+    NormalizedTimeSeriesForecastingConfig,
     NormalizedVisionConfig,
 )
diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py
index 9ceed24c2dd..3dc63e54713 100644
--- a/optimum/utils/normalized_config.py
+++ b/optimum/utils/normalized_config.py
@@ -96,6 +96,10 @@ class NormalizedSeq2SeqConfig(NormalizedTextConfig):
     DECODER_NUM_ATTENTION_HEADS = NormalizedTextConfig.NUM_ATTENTION_HEADS
 
 
+class NormalizedTimeSeriesForecastingConfig(NormalizedConfig):
+    CONTEXT_LEN = "context_len"
+
+
 class NormalizedVisionConfig(NormalizedConfig):
     IMAGE_SIZE = "image_size"
     NUM_CHANNELS = "num_channels"
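Usage sketch (not part of the patch): a minimal way to exercise the new "time-series-forecasting" export path once the changes above are applied. The checkpoint id is an assumption (any TimesFM checkpoint compatible with transformers >= 4.47 would do), and main_export is optimum's existing programmatic export entry point.

# Sketch only: exports a TimesFM checkpoint with the new TimesFMOnnxConfig.
# The model id below is illustrative and not shipped or tested by this diff.
from optimum.exporters.onnx import main_export

main_export(
    "google/timesfm-2.0-500m-pytorch",  # assumed TimesFM checkpoint on the Hub
    output="timesfm_onnx",
    task="time-series-forecasting",
)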