Skip to content

Commit

Permalink
feat: add data struct for inference schema (#494)
Browse files Browse the repository at this point in the history
<!--  Thanks for sending a pull request!  Here are some tips for you:

1. Run unit tests and ensure that they are passing
2. If your change introduces any API changes, make sure to update the
e2e tests
3. Make sure documentation is updated for your PR!

-->

**What this PR does / why we need it**:
This PR provides the data types required to support ML observability
within Merlin. The new data types are expected to be used for
#488 , flyte workflows, and
also in the future when we store an optional inference schema within
Merlin.

dataclasses_json dependency is introduced here as it is required if we
want to pass dataclass as an input to Flyte.

**Which issue(s) this PR fixes**:
<!--
*Automatically closes linked issue when PR is merged.
Usage: `Fixes #<issue number>`, or `Fixes (paste link of issue)`.
-->

Fixes #

**Does this PR introduce a user-facing change?**:
<!--
If no, just write "NONE" in the release-note block below.
If yes, a release note is required. Enter your extended release note in
the block below.
If the PR requires additional action from users switching to the new
release, include the string "action required".

For more information about release notes, see kubernetes' guide here:
http://git.k8s.io/community/contributors/guide/release-notes.md
-->

```release-note
NONE
```

**Checklist**

- [ ] Added unit test, integration, and/or e2e tests
- [ ] Tested locally
- [ ] Updated documentation
- [ ] Update Swagger spec if the PR introduce API changes
- [ ] Regenerated Golang and Python client if the PR introduce API
changes
  • Loading branch information
khorshuheng authored Nov 22, 2023
1 parent 9cc5432 commit 80cf911
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 0 deletions.
Empty file.
111 changes: 111 additions & 0 deletions python/sdk/merlin/observability/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from dataclasses import dataclass
from enum import unique, Enum
from typing import Dict, Optional, List

from dataclasses_json import dataclass_json


@unique
class ValueType(Enum):
FLOAT64 = 1
INT64 = 2
BOOLEAN = 3
STRING = 4


@dataclass_json
@dataclass
class RegressionOutput:
prediction_score_column: str

@property
def column_types(self) -> Dict[str, ValueType]:
return {self.prediction_score_column: ValueType.FLOAT64}


@dataclass_json
@dataclass
class BinaryClassificationOutput:
prediction_label_column: str
prediction_score_column: Optional[str] = None

@property
def column_types(self) -> Dict[str, ValueType]:
column_types_mapping = {self.prediction_label_column: ValueType.STRING}
if self.prediction_score_column is not None:
column_types_mapping[self.prediction_score_column] = ValueType.FLOAT64
return column_types_mapping


@dataclass_json
@dataclass
class MulticlassClassificationOutput:
prediction_label_columns: List[str]
prediction_score_columns: Optional[List[str]] = None

@property
def column_types(self) -> Dict[str, ValueType]:
column_types_mapping = {
label_column: ValueType.STRING
for label_column in self.prediction_label_columns
}
if self.prediction_score_columns is not None:
for column_name in self.prediction_score_columns:
column_types_mapping[column_name] = ValueType.FLOAT64
return column_types_mapping


@dataclass_json
@dataclass
class RankingOutput:
rank_column: str
prediction_group_id_column: str

@property
def column_types(self) -> Dict[str, ValueType]:
return {
self.rank_column: ValueType.INT64,
self.prediction_group_id_column: ValueType.STRING,
}


@unique
class InferenceType(Enum):
BINARY_CLASSIFICATION = 1
MULTICLASS_CLASSIFICATION = 2
REGRESSION = 3
RANKING = 4


@dataclass_json
@dataclass
class InferenceSchema:
feature_types: Dict[str, ValueType]
type: InferenceType
binary_classification: Optional[BinaryClassificationOutput] = None
multiclass_classification: Optional[MulticlassClassificationOutput] = None
regression: Optional[RegressionOutput] = None
ranking: Optional[RankingOutput] = None
prediction_id_column: Optional[str] = "prediction_id"
tag_columns: Optional[List[str]] = None

@property
def feature_columns(self) -> List[str]:
return list(self.feature_types.keys())

@property
def prediction_column_types(self) -> Dict[str, ValueType]:
if self.type == InferenceType.BINARY_CLASSIFICATION:
assert self.binary_classification is not None
return self.binary_classification.column_types
elif self.type == InferenceType.MULTICLASS_CLASSIFICATION:
assert self.multiclass_classification is not None
return self.multiclass_classification.column_types
elif self.type == InferenceType.REGRESSION:
assert self.regression is not None
return self.regression.column_types
elif self.type == InferenceType.RANKING:
assert self.ranking is not None
return self.ranking.column_types
else:
raise ValueError(f"Unknown prediction type: {self.type}")
1 change: 1 addition & 0 deletions python/sdk/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"Click>=7.0,<8.1.4",
"cloudpickle==2.0.0", # used by mlflow
"cookiecutter>=1.7.2",
"dataclasses-json>=0.5.2", # allow Flyte version 1.2.0 or above to import Merlin SDK
"docker>=4.2.1",
"google-cloud-storage>=1.19.0",
"protobuf>=3.12.0,<5.0.0", # Determined by the mlflow dependency
Expand Down

0 comments on commit 80cf911

Please sign in to comment.