Commit de5584f

include services in style, format
wgifford committed Aug 2, 2024
1 parent b060ec3 commit de5584f
Showing 7 changed files with 35 additions and 68 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -1,7 +1,7 @@
 # Adapted from HF Transformers: https://github.com/huggingface/transformers/tree/main
 .PHONY: quality style
 
-check_dirs := tests tsfm_public tsfmhfdemos notebooks
+check_dirs := tests tsfm_public tsfmhfdemos notebooks services
 
 
 # this target runs checks on all files
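
The quality and style targets that consume check_dirs are not shown in this diff. In Makefiles adapted from HF Transformers they commonly take roughly the shape sketched below; this is an assumption for illustration, including the choice of ruff as the tool, not the repository's actual targets:

# Hypothetical sketch of the targets that consume check_dirs (not from this diff)
quality:
	ruff check $(check_dirs)
	ruff format --check $(check_dirs)

# this target applies fixes in place
style:
	ruff check $(check_dirs) --fix
	ruff format $(check_dirs)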
10 changes: 1 addition & 9 deletions services/inference/tests/test_inference.py
@@ -1,8 +1,5 @@
 # Standard
-from datetime import datetime
-from io import BytesIO
-from typing import Any, Dict
 import json
 
 # Third Party
 import numpy as np
Expand All @@ -12,16 +9,11 @@

# First Party
from tsfm_public.toolkit.util import select_by_index
from tsfm_public import TinyTimeMixerForPrediction

from transformers import PatchTSTForPrediction, PatchTSMixerForPrediction


@pytest.fixture(scope="module")
def ts_data():
dataset_path = (
"https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh1.csv"
)
dataset_path = "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh1.csv"

forecast_length = 96
context_length = 512
23 changes: 11 additions & 12 deletions services/inference/tsfmservices/__init__.py
@@ -1,17 +1,18 @@
 import logging
 import os
 
+
 TSFM_PYTHON_LOGGING_LEVEL = os.getenv("TSFM_PYTHON_LOGGING_LEVEL", "INFO")
 
-LevelNamesMapping = dict(
-    INFO=logging.INFO,
-    WARN=logging.WARN,
-    WARNING=logging.WARNING,
-    ERROR=logging.ERROR,
-    CRITICAL=logging.CRITICAL,
-    DEBUG=logging.DEBUG,
-    FATAL=logging.FATAL,
-)
+LevelNamesMapping = {
+    "INFO": logging.INFO,
+    "WARN": logging.WARN,
+    "WARNING": logging.WARNING,
+    "ERROR": logging.ERROR,
+    "CRITICAL": logging.CRITICAL,
+    "DEBUG": logging.DEBUG,
+    "FATAL": logging.FATAL,
+}
 
 TSFM_PYTHON_LOGGING_LEVEL = (
     logging.getLevelNamesMapping()[TSFM_PYTHON_LOGGING_LEVEL]
@@ -41,7 +42,5 @@
 
 TSFM_CONFIG_FILE = os.getenv(
     "TSFM_CONFIG_FILE",
-    os.path.realpath(
-        os.path.join(os.path.dirname(__file__), "config", "default_config.yml")
-    ),
+    os.path.realpath(os.path.join(os.path.dirname(__file__), "config", "default_config.yml")),
 )
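
The first hunk above is cut off mid-expression by the page, but the visible code resolves the TSFM_PYTHON_LOGGING_LEVEL string through logging.getLevelNamesMapping(), which exists only on Python 3.11+; the literal LevelNamesMapping dict presumably serves as the fallback on older interpreters. A minimal sketch of that pattern, assuming the hidden lines guard on attribute availability:

import logging
import os

TSFM_PYTHON_LOGGING_LEVEL = os.getenv("TSFM_PYTHON_LOGGING_LEVEL", "INFO")

# Assumed completion of the truncated expression: use the stdlib mapping on
# Python 3.11+, otherwise fall back to a hand-written name-to-level dict.
TSFM_PYTHON_LOGGING_LEVEL = (
    logging.getLevelNamesMapping()[TSFM_PYTHON_LOGGING_LEVEL]
    if hasattr(logging, "getLevelNamesMapping")
    else {"INFO": logging.INFO, "DEBUG": logging.DEBUG}[TSFM_PYTHON_LOGGING_LEVEL]
)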
21 changes: 6 additions & 15 deletions services/inference/tsfmservices/common/util.py
@@ -11,6 +11,7 @@
 # Third Party
 from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
 
+
 LOGGER = logging.getLogger(__file__)
 
 
@@ -36,9 +37,7 @@ def register_config(model_type: str, model_config_name: str, module_path: str) -> None:
         mod = importlib.import_module(module_path)
         conf_class = getattr(mod, model_config_name, None)
     except ModuleNotFoundError as exc:  # modulenot found, key error ?
-        raise RuntimeError(
-            f"Could not load {model_config_name} from {module_path}"
-        ) from exc
+        raise RuntimeError(f"Could not load {model_config_name} from {module_path}") from exc
 
     if conf_class is not None:
         AutoConfig.register(model_type, conf_class)
@@ -73,19 +72,15 @@ def load_config(
         conf = AutoConfig.from_pretrained(model_path)
     except (KeyError, ValueError) as exc:  # determine error raised by autoconfig
         if model_type is None or model_config_name is None or module_path is None:
-            raise ValueError(
-                "model_type, model_config_name, and module_path should be specified."
-            ) from exc
+            raise ValueError("model_type, model_config_name, and module_path should be specified.") from exc
 
         register_config(model_type, model_config_name, module_path)
         conf = AutoConfig.from_pretrained(model_path)
 
     return conf
 
 
-def _get_model_class(
-    config: PretrainedConfig, module_path: Optional[str] = None
-) -> type:
+def _get_model_class(config: PretrainedConfig, module_path: Optional[str] = None) -> type:
     """Helper to find model class based on config object
 
     First the module_path will be checked if it can be loaded in the current environment. If not
@@ -120,9 +115,7 @@ def _get_model_class(
             return model_class
         except AttributeError as exc:
             # catch specific error import error or attribute error
-            raise AttributeError(
-                "Could not load model class for architecture '{arch}'."
-            ) from exc
+            raise AttributeError("Could not load model class for architecture '{arch}'.") from exc
 
 
 def load_model(
@@ -147,9 +140,7 @@
     """
 
     if module_path is not None and config is None:
-        raise ValueError(
-            "Config must be provided when loading from a custom module_path"
-        )
+        raise ValueError("Config must be provided when loading from a custom module_path")
 
     if config is not None:
         model_class = _get_model_class(config, module_path=module_path)
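
Taken together, register_config, load_config, and load_model let the service load models whose configuration classes are not pre-registered with transformers' AutoConfig. A usage sketch; the model directory, model_type, config class name, and module path below are illustrative assumptions, not values taken from this diff:

from tsfmservices.common.util import load_config, load_model

# Illustrative values; substitute a real model directory and module path.
config = load_config(
    "path/to/model",
    model_type="tinytimemixer",
    model_config_name="TinyTimeMixerConfig",
    module_path="tsfm_public.models.tinytimemixer",
)
model = load_model("path/to/model", config=config, module_path="tsfm_public.models.tinytimemixer")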
23 changes: 6 additions & 17 deletions services/inference/tsfmservices/inference/inference.py
@@ -2,27 +2,25 @@
 
 import datetime
 import logging
-from typing import Any, Dict, List
 from pathlib import Path
+from typing import Any, Dict, List
 
 import pandas as pd
 from fastapi import APIRouter, HTTPException
 
 from tsfm_public import TimeSeriesForecastingPipeline, TimeSeriesPreprocessor
+from tsfmservices import TSFM_ALLOW_LOAD_FROM_HF_HUB
+
 from ..common.constants import API_VERSION
 from ..common.util import load_config, load_model, register_config
 from .payloads import ForecastingInferenceInput, PredictOutput
 
-from tsfmservices import TSFM_ALLOW_LOAD_FROM_HF_HUB
-
 
 LOGGER = logging.getLogger(__file__)
 
 
 class InferenceRuntime:
-
     def __init__(self, config: Dict[str, Any] = {}):
-
         self.config = config
         model_map = {}
 
@@ -35,14 +33,11 @@ def __init__(self, config: Dict[str, Any] = {}):
             )
             LOGGER.info(f"registered {custom_module['model_type']}")
 
-            model_map[custom_module["model_config_name"]] = custom_module[
-                "module_path"
-            ]
+            model_map[custom_module["model_config_name"]] = custom_module["module_path"]
 
         self.model_to_module_map = model_map
 
     def add_routes(self, app):
-
         self.router = APIRouter(prefix=f"/{API_VERSION}/inference", tags=["inference"])
         self.router.add_api_route(
             "/forecasting",
@@ -53,7 +48,6 @@ def add_routes(self, app):
         app.include_router(self.router)
 
     def load(self, model_path: str):
-
         try:
             preprocessor = TimeSeriesPreprocessor.from_pretrained(model_path)
             LOGGER.info("Successfully loaded preprocessor")
@@ -84,10 +78,7 @@ def forecast(self, input: ForecastingInferenceInput):
             LOGGER.exception(e)
             raise HTTPException(status_code=500, detail=repr(e))
 
-    def _forecast_common(
-        self, input_payload: ForecastingInferenceInput
-    ) -> PredictOutput:
-
+    def _forecast_common(self, input_payload: ForecastingInferenceInput) -> PredictOutput:
         # we need some sort of model registry
         # payload = input_payload.model_dump()  # do we need?
 
@@ -151,9 +142,7 @@ def _forecast_common(
         future_data = preprocessor.preprocess(future_data)
         future_data.drop(columns=input_payload.target_columns)
 
-        forecasts = forecast_pipeline(
-            test_data, future_time_series=future_data, inverse_scale_outputs=True
-        )
+        forecasts = forecast_pipeline(test_data, future_time_series=future_data, inverse_scale_outputs=True)
 
         return PredictOutput(
             model_id=input_payload.model_id,
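
Once add_routes registers the endpoint, clients POST a ForecastingInferenceInput to /{API_VERSION}/inference/forecasting. A hedged sketch of such a call; the host, port, concrete API_VERSION value, and column names are assumptions, and only the model_id and data fields are visible in this commit:

import requests

# Illustrative request against a locally running service; "v1" stands in
# for whatever API_VERSION resolves to in common/constants.py.
resp = requests.post(
    "http://127.0.0.1:8000/v1/inference/forecasting",
    json={
        "model_id": "ibm/tinytimemixer-monash-fl_96",
        "data": {"date": ["2021-01-01T00:00:00"], "value": [10.2]},
    },
)
print(resp.status_code, resp.json())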
7 changes: 4 additions & 3 deletions services/inference/tsfmservices/inference/main.py
@@ -1,17 +1,19 @@
 """Primary entry point for inference services"""
 
 import logging
-from fastapi import FastAPI
+
 import yaml
+from fastapi import FastAPI
 
 from tsfmservices import (
+    TSFM_CONFIG_FILE,
     TSFM_PYTHON_LOGGING_FORMAT,
     TSFM_PYTHON_LOGGING_LEVEL,
-    TSFM_CONFIG_FILE,
 )
 from tsfmservices.common.constants import API_VERSION
 from tsfmservices.inference import InferenceRuntime
 
+
 logging.basicConfig(
     format=TSFM_PYTHON_LOGGING_FORMAT,
     level=TSFM_PYTHON_LOGGING_LEVEL,
@@ -42,7 +44,6 @@ def root():
 
 
 if __name__ == "__main__":
-
     # Third Party
     import uvicorn
 
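
The __main__ block is cut off by the page; presumably the hidden lines start the service with uvicorn, roughly as sketched here (the app name, host, and port are assumptions, not values from this diff):

import uvicorn

# "app" is assumed to be the FastAPI instance created earlier in main.py.
uvicorn.run(app, host="127.0.0.1", port=8000)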
17 changes: 6 additions & 11 deletions services/inference/tsfmservices/inference/payloads.py
@@ -6,11 +6,10 @@
 # Third Party
 from pydantic import BaseModel, ConfigDict, Field, model_validator
 
+
 # WARNING: DO NOT IMPORT util here or else you'll get a circular dependency
 
-EverythingPatternedString = Annotated[
-    str, Field(min_length=0, max_length=100, pattern=".*")
-]
+EverythingPatternedString = Annotated[str, Field(min_length=0, max_length=100, pattern=".*")]
 
 
 class BaseMetadataInput(BaseModel):
@@ -31,11 +30,9 @@ class BaseMetadataInput(BaseModel):
         min_length=0,
     )
     freq: Optional[str] = Field(
-        description="""A freqency indicator for the given timestamp_column.
-        See https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#period-aliases for a description of the allowed values.
-        If not provided, we will attempt to infer it from the data.""",
+        description="""A freqency indicator for the given timestamp_column. See https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#period-aliases for a description of the allowed values. If not provided, we will attempt to infer it from the data.""",
         default=None,
-        pattern="\d+[B|D|W|M|Q|Y|h|min|s|ms|us|ns]|^\s*$",
+        pattern=r"\d+[B|D|W|M|Q|Y|h|min|s|ms|us|ns]|^\s*$",
         min_length=0,
         max_length=100,
         example="1h",
@@ -94,7 +91,7 @@ class BaseInferenceInput(BaseModel):
 
     model_id: str = Field(
         description="A model identifier.",
-        pattern="^\S+$",
+        pattern=r"^\S+$",
         min_length=1,
         max_length=100,
         example="ibm/tinytimemixer-monash-fl_96",
@@ -110,9 +107,7 @@ class ForecastingInferenceInput(BaseInferenceInput):
         description="Data",
     )
 
-    future_data: Optional[Dict[str, List[Any]]] = Field(
-        description="Future data", default=None
-    )
+    future_data: Optional[Dict[str, List[Any]]] = Field(description="Future data", default=None)
 
     @model_validator(mode="before")
     @classmethod
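
Switching the freq and model_id patterns to raw strings silences Python's invalid-escape-sequence warnings without changing what the regexes match. A quick illustration of the freq pattern's behavior (values chosen for illustration, not from this diff):

import re

FREQ_PATTERN = r"\d+[B|D|W|M|Q|Y|h|min|s|ms|us|ns]|^\s*$"

# "1h" and "15min" match via the \d+[...] alternative (the bracket class
# holds single characters, so "15min" matches on "15m", not a whole alias),
# and the empty string matches the ^\s*$ alternative.
for value in ["1h", "15min", ""]:
    print(repr(value), bool(re.search(FREQ_PATTERN, value)))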
