Skip to content

Commit

Permalink
Linting, formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
sei-aderr committed Dec 10, 2024
1 parent 955f020 commit bb90883
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 75 deletions.
8 changes: 2 additions & 6 deletions mlte/qa_category/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,7 @@ def to_model(self) -> QACategoryModel:
)

@classmethod
def from_model(
cls, model: QACategoryModel
) -> QACategory:
def from_model(cls, model: QACategoryModel) -> QACategory:
"""
Load a QACategory instance from a model.
Expand All @@ -85,9 +83,7 @@ def from_model(
except Exception:
raise RuntimeError(f"Module {module_path} not found")
try:
class_: Type[QACategory] = getattr(
qa_category_module, classname
)
class_: Type[QACategory] = getattr(qa_category_module, classname)
except AttributeError:
raise RuntimeError(
f"QACategory {model.name} in module {module_path} not found"
Expand Down
46 changes: 15 additions & 31 deletions mlte/spec/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,7 @@ def to_model(self) -> ArtifactModel:
header=self.build_artifact_header(),
body=SpecModel(
qa_categories=[
self._to_qa_category_model(
qa_category
)
self._to_qa_category_model(qa_category)
for qa_category, _ in self.qa_categories.items()
],
),
Expand All @@ -83,23 +81,17 @@ def from_model(cls, model: ArtifactModel) -> Spec:
body = typing.cast(SpecModel, model.body)
return Spec(
identifier=model.header.identifier,
qa_categories=Spec.to_qa_category_dict(
body.qa_categories
),
qa_categories=Spec.to_qa_category_dict(body.qa_categories),
)

def _to_qa_category_model(
self, qa_category: QACategory
) -> QACategoryModel:
def _to_qa_category_model(self, qa_category: QACategory) -> QACategoryModel:
"""
Generate a qa category model. This just uses QACategory.to_model, but adds the list of conditions.
:param qa_category: The qa category of interest
:return: The qa category model
"""
qa_category_model: QACategoryModel = (
qa_category.to_model()
)
qa_category_model: QACategoryModel = qa_category.to_model()
qa_category_model.conditions = {
measurement_id: condition.to_model()
for measurement_id, condition in self.qa_categories[
Expand All @@ -115,9 +107,7 @@ def to_qa_category_dict(
) -> Dict[QACategory, Dict[str, Condition]]:
"""Converts a list of qa category models, into a dict of properties and conditions."""
return {
QACategory.from_model(
qa_category_model
): {
QACategory.from_model(qa_category_model): {
measurement_id: Condition.from_model(condition_model)
for measurement_id, condition_model in qa_category_model.conditions.items()
}
Expand All @@ -133,43 +123,37 @@ def get_default_id() -> str:
# Quality Attribute Category Manipulation
# -------------------------------------------------------------------------

def get_qa_category(
self, qa_category_id: str
) -> QACategory:
def get_qa_category(self, qa_category_id: str) -> QACategory:
"""
Returns a particular qa category with the given id.
:param qa_category: The qa category itself, or its identifier
:return: The qa category object.
"""
properties = [
prop
for prop in self.qa_categories
if prop.name == qa_category_id
qa_categories = [
category
for category in self.qa_categories
if category.name == qa_category_id
]
if len(properties) == 0:
if len(qa_categories) == 0:
raise RuntimeError(
f"QA category {qa_category_id} was not found in list."
)
if len(properties) > 1:
if len(qa_categories) > 1:
raise RuntimeError(
f"Multiple properties with same id were found: {qa_category_id}"
)
return properties[0]
return qa_categories[0]

def has_qa_category(
self, qa_category: Union[QACategory, str]
) -> bool:
def has_qa_category(self, qa_category: Union[QACategory, str]) -> bool:
"""
Determine if the spec contains a particular qa category.
:param qa_category: The qa category itself, or its identifier
:return: `True` if the spec has the qa category, `False` otherwise
"""
target_name = (
qa_category
if isinstance(qa_category, str)
else qa_category.name
qa_category if isinstance(qa_category, str) else qa_category.name
)
return any(
qa_category.name == target_name
Expand Down
10 changes: 2 additions & 8 deletions mlte/store/artifact/underlying/rdbs/factory_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,7 @@
from sqlalchemy.orm import Session

from mlte.evidence.metadata import EvidenceMetadata, Identifier
from mlte.spec.model import (
ConditionModel,
QACategoryModel,
SpecModel,
)
from mlte.spec.model import ConditionModel, QACategoryModel, SpecModel
from mlte.store.artifact.underlying.rdbs.metadata import DBArtifactHeader
from mlte.store.artifact.underlying.rdbs.metadata_spec import (
DBCondition,
Expand Down Expand Up @@ -45,9 +41,7 @@ def create_spec_db_from_model(
module=qa_category.module,
spec=spec_obj,
)
spec_obj.qa_categories.append(
qa_category_obj
)
spec_obj.qa_categories.append(qa_category_obj)

for (
measurement_id,
Expand Down
4 changes: 3 additions & 1 deletion mlte/store/artifact/underlying/rdbs/metadata_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ class DBCondition(DBBase):
value_class: Mapped[str]
qa_category_id: Mapped[int] = mapped_column(ForeignKey("qa_category.id"))

qa_category: Mapped[DBQACategory] = relationship(back_populates="conditions")
qa_category: Mapped[DBQACategory] = relationship(
back_populates="conditions"
)

def __repr__(self) -> str:
return f"Condition(id={self.id!r}, name={self.name!r}, arguments={self.arguments!r}, value_class={self.value_class!r}, qa_category={self.qa_category!r})"
Expand Down
4 changes: 1 addition & 3 deletions mlte/validation/validated_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,7 @@ def from_model(cls, model: ArtifactModel) -> ValidatedSpec:
identifier=model.header.identifier,
spec=Spec(
body.spec_identifier,
Spec.to_qa_category_dict(
body.spec.qa_categories
)
Spec.to_qa_category_dict(body.spec.qa_categories)
if body.spec is not None
else {},
),
Expand Down
6 changes: 1 addition & 5 deletions test/fixture/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,7 @@
QuantitiveAnalysisDescriptor,
ReportModel,
)
from mlte.spec.model import (
ConditionModel,
QACategoryModel,
SpecModel,
)
from mlte.spec.model import ConditionModel, QACategoryModel, SpecModel
from mlte.validation.model import ResultModel, ValidatedSpecModel
from mlte.value.model import IntegerValueModel, ValueModel
from mlte.value.types.integer import Integer
Expand Down
33 changes: 13 additions & 20 deletions test/qa_category/test_quality_attribute_categories.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,14 @@
"""

from mlte.qa_category.base import QACategory
from mlte.qa_category.costs.predicting_compute_cost import (
PredictingComputeCost,
)
from mlte.qa_category.costs.predicting_memory_cost import (
PredictingMemoryCost,
)
from mlte.qa_category.costs.predicting_compute_cost import PredictingComputeCost
from mlte.qa_category.costs.predicting_memory_cost import PredictingMemoryCost
from mlte.qa_category.costs.storage_cost import StorageCost
from mlte.qa_category.costs.training_compute_cost import (
TrainingComputeCost,
)
from mlte.qa_category.costs.training_memory_cost import (
TrainingMemoryCost,
)
from mlte.qa_category.costs.training_compute_cost import TrainingComputeCost
from mlte.qa_category.costs.training_memory_cost import TrainingMemoryCost
from mlte.qa_category.fairness.fairness import Fairness
from mlte.qa_category.functionality.task_efficacy import (
TaskEfficacy,
)
from mlte.qa_category.interpretability.interpretability import (
Interpretability,
)
from mlte.qa_category.functionality.task_efficacy import TaskEfficacy
from mlte.qa_category.interpretability.interpretability import Interpretability
from mlte.qa_category.robustness.robustness import Robustness


Expand Down Expand Up @@ -67,7 +55,10 @@ def test_training_memory_cost():
def test_task_efficacy():
p = TaskEfficacy("test")
assert_qa_category(
p, "TaskEfficacy", "test", "mlte.qa_category.functionality.task_efficacy"
p,
"TaskEfficacy",
"test",
"mlte.qa_category.functionality.task_efficacy",
)


Expand All @@ -93,7 +84,9 @@ def test_predicting_memory_cost():

def test_fairness():
p = Fairness("test")
assert_qa_category(p, "Fairness", "test", "mlte.qa_category.fairness.fairness")
assert_qa_category(
p, "Fairness", "test", "mlte.qa_category.fairness.fairness"
)


def test_robustness():
Expand Down
2 changes: 1 addition & 1 deletion test/spec/extended_qa_category.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

class ExtendedQACategory(QACategory):
"""
The ExtendedQACategory qa category is a
The ExtendedQACategory qa category is a
qa category not defined in the default qa category package.
"""

Expand Down

0 comments on commit bb90883

Please sign in to comment.