From 8fe88ea5c040aefb4b7da879248e257507d74b5c Mon Sep 17 00:00:00 2001 From: Visesh Rajendraprasad Date: Tue, 8 Oct 2024 21:20:58 +0530 Subject: [PATCH] serverless: fix template saves, loads and add new types (#137) --- .github/workflows/ci_cd.yml | 1 + .../core/serverless/__init__.py | 17 +--- .../dynamicreporting/core/serverless/adr.py | 12 ++- .../dynamicreporting/core/serverless/base.py | 53 +++++++---- .../dynamicreporting/core/serverless/item.py | 37 ++++---- .../core/serverless/template.py | 88 +++++++++++++------ 6 files changed, 125 insertions(+), 83 deletions(-) diff --git a/.github/workflows/ci_cd.yml b/.github/workflows/ci_cd.yml index 44a85a13..d21e9d01 100644 --- a/.github/workflows/ci_cd.yml +++ b/.github/workflows/ci_cd.yml @@ -34,6 +34,7 @@ jobs: uses: ansys/actions/code-style@v8 with: python-version: ${{ env.MAIN_PYTHON_VERSION }} + show-diff-on-failure: false docs-style: name: Documentation style check diff --git a/src/ansys/dynamicreporting/core/serverless/__init__.py b/src/ansys/dynamicreporting/core/serverless/__init__.py index aa2a3ccc..a4f7b98c 100644 --- a/src/ansys/dynamicreporting/core/serverless/__init__.py +++ b/src/ansys/dynamicreporting/core/serverless/__init__.py @@ -1,21 +1,6 @@ # serverless from .adr import ADR -from .item import ( - HTML, - Animation, - Dataset, - File, - Image, - Item, - Movie, - Plot, - Scene, - Session, - String, - Table, - Text, - Tree, -) +from .item import HTML, Animation, Dataset, File, Image, Item, Scene, Session, String, Table, Tree from .template import ( BasicLayout, BoxLayout, diff --git a/src/ansys/dynamicreporting/core/serverless/adr.py b/src/ansys/dynamicreporting/core/serverless/adr.py index 569d29cc..ae9d384f 100644 --- a/src/ansys/dynamicreporting/core/serverless/adr.py +++ b/src/ansys/dynamicreporting/core/serverless/adr.py @@ -273,7 +273,7 @@ def get_report(self, **kwargs) -> Template: self._logger.error(f"{e}") raise e - def get_reports(self, fields: list = None, flat: bool = False) -> ObjectSet: + def get_reports(self, fields: list = None, flat: bool = False) -> Union[ObjectSet, list]: # return list of reports by default. # if fields are mentioned, return value list try: @@ -286,8 +286,14 @@ def get_reports(self, fields: list = None, flat: bool = False) -> ObjectSet: return out - def get_list_reports(self, *fields) -> ObjectSet: - return self.get_reports(*fields) + def get_list_reports(self, r_type: Optional[str] = "name") -> Union[ObjectSet, list]: + supported_types = ["name", "report"] + if r_type not in supported_types: + raise ADRException(f"r_type must be one of {supported_types}") + if r_type == "name": + return self.get_reports([r_type], flat=True) + else: + return self.get_reports() def render_report(self, context: dict = None, query: str = None, **kwargs: Any) -> str: try: diff --git a/src/ansys/dynamicreporting/core/serverless/base.py b/src/ansys/dynamicreporting/core/serverless/base.py index e5072fcd..0cc2b7bb 100644 --- a/src/ansys/dynamicreporting/core/serverless/base.py +++ b/src/ansys/dynamicreporting/core/serverless/base.py @@ -218,6 +218,10 @@ def _get_all_field_names(cls): property_fields.append(name) return tuple(property_fields) + cls._get_field_names() + @property + def saved(self): + return self._saved + @classmethod def from_db(cls, orm_instance, parent=None): cls_fields = dict(cls._get_field_names(with_types=True, include_private=True)) @@ -242,7 +246,7 @@ def from_db(cls, orm_instance, parent=None): type_ = cls._cls_registry[field_type] else: type_ = field_type - if issubclass(type_, cls): + if issubclass(cls, type_): # same hierarchy means there is a parent-child relation value = parent else: value = type_.from_db(value) @@ -280,28 +284,32 @@ def from_db(cls, orm_instance, parent=None): obj._saved = True return obj - @property - def saved(self): - return self._saved - @handle_field_errors def save(self, **kwargs): cls_fields = self._get_all_field_names() model_fields = self._get_orm_field_names(self._orm_instance) for field_ in cls_fields: - if field_ in model_fields: - value = getattr(self, field_, None) - if value is not None: - if isinstance(value, list): - obj_list = [] - for obj in value: - obj_list.append(obj._orm_instance) - if obj_list: - getattr(self._orm_instance, field_).add(*obj_list) - else: - if isinstance(value, BaseModel): - value = value._orm_instance.__class__.objects.get(guid=value.guid) - setattr(self._orm_instance, field_, value) + if field_ not in model_fields: + continue + value = getattr(self, field_, None) + if value is None: + continue + if isinstance(value, list): + obj_list = [] + for obj in value: + obj_list.append(obj._orm_instance) + getattr(self._orm_instance, field_).add(*obj_list) + else: + if isinstance(value, BaseModel): # relations + try: + value = value._orm_instance.__class__.objects.using( + kwargs.get("using", "default") + ).get(guid=value.guid) + except ObjectDoesNotExist: + raise value.__class__.DoesNotExist + # for all others + setattr(self._orm_instance, field_, value) + self._orm_instance.save(**kwargs) self._saved = True @@ -337,6 +345,13 @@ def filter(cls, **kwargs): qs = cls._orm_model_cls.objects.filter(**kwargs) return ObjectSet(_model=cls, _orm_model=cls._orm_model_cls, _orm_queryset=qs) + @classmethod + @handle_field_errors + def bulk_create(cls, **kwargs): + objs = cls._orm_model_cls.objects.bulk_create(**kwargs) + qs = cls._orm_model_cls.objects.filter(pk__in=[obj.pk for obj in objs]) + return ObjectSet(_model=cls, _orm_model=cls._orm_model_cls, _orm_queryset=qs) + @classmethod @handle_field_errors def find(cls, query="", reverse=False, sort_tag="date"): @@ -422,7 +437,7 @@ def values_list(self, *fields, flat=False): ret = [] for obj in self._obj_set: ret.append(tuple(getattr(obj, f, None) for f in fields)) - return chain.from_iterable(ret) if flat else ret + return list(chain.from_iterable(ret)) if flat else ret class Validator(ABC): diff --git a/src/ansys/dynamicreporting/core/serverless/item.py b/src/ansys/dynamicreporting/core/serverless/item.py index 8dbf7305..7dbe669c 100644 --- a/src/ansys/dynamicreporting/core/serverless/item.py +++ b/src/ansys/dynamicreporting/core/serverless/item.py @@ -25,7 +25,7 @@ class Session(BaseModel): date: datetime = field(compare=False, kw_only=True, default_factory=timezone.now) hostname: str = field(compare=False, kw_only=True, default=str(platform.node)) platform: str = field(compare=False, kw_only=True, default=str(report_utils.enve_arch)) - application: str = field(compare=False, kw_only=True, default="Python API") + application: str = field(compare=False, kw_only=True, default="Serverless ADR Python API") version: str = field(compare=False, kw_only=True, default="1.0") _orm_model: str = "data.models.Session" @@ -180,7 +180,7 @@ class SimplePayloadMixin: def from_db(cls, orm_instance, **kwargs): from data.extremely_ugly_hacks import safe_unpickle - obj = super().from_db(orm_instance) + obj = super().from_db(orm_instance, **kwargs) obj.content = safe_unpickle(obj._orm_instance.payloaddata) return obj @@ -194,7 +194,7 @@ class FilePayloadMixin: @classmethod def from_db(cls, orm_instance, **kwargs): - obj = super().from_db(orm_instance) + obj = super().from_db(orm_instance, **kwargs) obj.content = obj._orm_instance.payloadfile.path return obj @@ -217,6 +217,23 @@ class Item(BaseModel): dataset: Dataset = field(compare=False, kw_only=True, default=None) type: str = "none" _orm_model: str = "data.models.Item" + # Class-level registry of subclasses keyed by type + _type_registry = {} + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + # Automatically register the subclass based on its type attribute + Item._type_registry[cls.type] = cls + + @classmethod + def from_db(cls, orm_instance, **kwargs): + # Create a new instance of the correct subclass + if cls is Item: + # Get the class based on the type attribute + item_cls = cls._type_registry[orm_instance.type] + return item_cls.from_db(orm_instance, **kwargs) + + return super().from_db(orm_instance, **kwargs) def save(self, **kwargs): if self.session is None or self.dataset is None: @@ -280,10 +297,6 @@ class String(SimplePayloadMixin, Item): type: str = "string" -class Text(String): - pass - - class HTML(String): content: HTMLContent = HTMLContent() type: str = "html" @@ -298,7 +311,7 @@ class Table(Item): def from_db(cls, orm_instance, **kwargs): from data.extremely_ugly_hacks import safe_unpickle - obj = super().from_db(orm_instance) + obj = super().from_db(orm_instance, **kwargs) payload = safe_unpickle(obj._orm_instance.payloaddata) obj.content = payload.pop("array", None) for prop in cls._properties: @@ -318,10 +331,6 @@ def save(self, **kwargs): super().save(**kwargs) -class Plot(Table): - pass - - class Tree(SimplePayloadMixin, Item): content: TreeContent = TreeContent() type: str = "tree" @@ -347,10 +356,6 @@ class Animation(FilePayloadMixin, Item): type: str = "anim" -class Movie(Animation): - pass - - class Scene(FilePayloadMixin, Item): content: SceneContent = SceneContent() type: str = "scene" diff --git a/src/ansys/dynamicreporting/core/serverless/template.py b/src/ansys/dynamicreporting/core/serverless/template.py index c4cfd1a3..df1f1312 100644 --- a/src/ansys/dynamicreporting/core/serverless/template.py +++ b/src/ansys/dynamicreporting/core/serverless/template.py @@ -18,12 +18,55 @@ class Template(BaseModel): parent: "Template" = field(compare=False, kw_only=True, default=None) children: list["Template"] = field(compare=False, kw_only=True, default_factory=list) _children_order: str = field( - compare=False, init=False, default=None + compare=False, init=False, default="" ) # computed from self.children _master: bool = field(compare=False, init=False, default=None) # computed from self.parent report_type: str = "" - _properties: tuple = tuple() + _properties: tuple = tuple() # todo: add properties of each type ref: report_objects _orm_model: str = "reports.models.Template" + # Class-level registry of subclasses keyed by type + _type_registry = {} + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + # Automatically register the subclass based on its type attribute + Template._type_registry[cls.report_type] = cls + + @classmethod + def from_db(cls, orm_instance, **kwargs): + # Create a new instance of the correct subclass + if cls is Template: + # Get the class based on the type attribute + templ_cls = cls._type_registry[orm_instance.report_type] + obj = templ_cls.from_db(orm_instance, **kwargs) + else: + obj = super().from_db(orm_instance, **kwargs) + # add relevant props from property dict. + props = obj.get_property() + for prop in cls._properties: + if prop in props: + setattr(obj, prop, props[prop]) + return obj + + def save(self, **kwargs): + if self.parent is not None and not self.parent._saved: + raise Template.NotSaved( + extra_detail="Failed to save template because its parent is not saved" + ) + for child in self.children: + if not child._saved: + raise Template.NotSaved( + extra_detail="Failed to save template because its children are not saved" + ) + # set properties + prop_dict = {} + for prop in self._properties: + value = getattr(self, prop, None) + if value is not None: + prop_dict[prop] = value + if prop_dict: + self.add_property(prop_dict) + super().save(**kwargs) @property def type(self): @@ -37,7 +80,7 @@ def type(self, value): @property def children_order(self): - return self._children_order + return ",".join([str(child.guid) for child in self.children]) @property def master(self): @@ -107,35 +150,10 @@ def add_property(self, new_props: dict): params["properties"] = curr_props | new_props self.params = json.dumps(params) - def save(self, **kwargs): - if self.parent is not None and not self.parent._saved: - raise Template.NotSaved( - extra_detail="Failed to save template because its parent is not saved" - ) - for child in self.children: - if not child._saved: - raise Template.NotSaved( - extra_detail="Failed to save template because its children are not saved" - ) - # set properties - prop_dict = {} - for prop in self._properties: - value = getattr(self, prop, None) - if value is not None: - prop_dict[prop] = value - if prop_dict: - self.add_property(prop_dict) - super().save(**kwargs) - @classmethod def get(cls, **kwargs): new_kwargs = {"report_type": cls.report_type, **kwargs} if cls.report_type else kwargs - obj = super().get(**new_kwargs) - props = obj.get_property() - for prop in cls._properties: - if prop in props: - setattr(obj, prop, props[prop]) - return obj + return super().get(**new_kwargs) @classmethod def filter(cls, **kwargs): @@ -278,3 +296,15 @@ class TreeMergeGenerator(Generator): class SQLQueryGenerator(Generator): report_type: str = "Generator:sqlqueries" + + +class ItemsComparisonGenerator(Generator): + report_type: str = "Generator:itemscomparison" + + +class StatisticalGenerator(Generator): + report_type: str = "Generator:statistical" + + +class IteratorGenerator(Generator): + report_type: str = "Generator:iterator"