Skip to content

Commit

Permalink
serverless: fix template saves, loads and add new types (#137)
Browse files Browse the repository at this point in the history
  • Loading branch information
viseshrp authored Oct 8, 2024
1 parent 639c09e commit 8fe88ea
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 83 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci_cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ jobs:
uses: ansys/actions/code-style@v8
with:
python-version: ${{ env.MAIN_PYTHON_VERSION }}
show-diff-on-failure: false

docs-style:
name: Documentation style check
Expand Down
17 changes: 1 addition & 16 deletions src/ansys/dynamicreporting/core/serverless/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,6 @@
# serverless
from .adr import ADR
from .item import (
HTML,
Animation,
Dataset,
File,
Image,
Item,
Movie,
Plot,
Scene,
Session,
String,
Table,
Text,
Tree,
)
from .item import HTML, Animation, Dataset, File, Image, Item, Scene, Session, String, Table, Tree
from .template import (
BasicLayout,
BoxLayout,
Expand Down
12 changes: 9 additions & 3 deletions src/ansys/dynamicreporting/core/serverless/adr.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def get_report(self, **kwargs) -> Template:
self._logger.error(f"{e}")
raise e

def get_reports(self, fields: list = None, flat: bool = False) -> ObjectSet:
def get_reports(self, fields: list = None, flat: bool = False) -> Union[ObjectSet, list]:
# return list of reports by default.
# if fields are mentioned, return value list
try:
Expand All @@ -286,8 +286,14 @@ def get_reports(self, fields: list = None, flat: bool = False) -> ObjectSet:

return out

def get_list_reports(self, *fields) -> ObjectSet:
return self.get_reports(*fields)
def get_list_reports(self, r_type: Optional[str] = "name") -> Union[ObjectSet, list]:
supported_types = ["name", "report"]
if r_type not in supported_types:
raise ADRException(f"r_type must be one of {supported_types}")
if r_type == "name":
return self.get_reports([r_type], flat=True)
else:
return self.get_reports()

def render_report(self, context: dict = None, query: str = None, **kwargs: Any) -> str:
try:
Expand Down
53 changes: 34 additions & 19 deletions src/ansys/dynamicreporting/core/serverless/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,10 @@ def _get_all_field_names(cls):
property_fields.append(name)
return tuple(property_fields) + cls._get_field_names()

@property
def saved(self):
return self._saved

@classmethod
def from_db(cls, orm_instance, parent=None):
cls_fields = dict(cls._get_field_names(with_types=True, include_private=True))
Expand All @@ -242,7 +246,7 @@ def from_db(cls, orm_instance, parent=None):
type_ = cls._cls_registry[field_type]
else:
type_ = field_type
if issubclass(type_, cls):
if issubclass(cls, type_): # same hierarchy means there is a parent-child relation
value = parent
else:
value = type_.from_db(value)
Expand Down Expand Up @@ -280,28 +284,32 @@ def from_db(cls, orm_instance, parent=None):
obj._saved = True
return obj

@property
def saved(self):
return self._saved

@handle_field_errors
def save(self, **kwargs):
cls_fields = self._get_all_field_names()
model_fields = self._get_orm_field_names(self._orm_instance)
for field_ in cls_fields:
if field_ in model_fields:
value = getattr(self, field_, None)
if value is not None:
if isinstance(value, list):
obj_list = []
for obj in value:
obj_list.append(obj._orm_instance)
if obj_list:
getattr(self._orm_instance, field_).add(*obj_list)
else:
if isinstance(value, BaseModel):
value = value._orm_instance.__class__.objects.get(guid=value.guid)
setattr(self._orm_instance, field_, value)
if field_ not in model_fields:
continue
value = getattr(self, field_, None)
if value is None:
continue
if isinstance(value, list):
obj_list = []
for obj in value:
obj_list.append(obj._orm_instance)
getattr(self._orm_instance, field_).add(*obj_list)
else:
if isinstance(value, BaseModel): # relations
try:
value = value._orm_instance.__class__.objects.using(
kwargs.get("using", "default")
).get(guid=value.guid)
except ObjectDoesNotExist:
raise value.__class__.DoesNotExist
# for all others
setattr(self._orm_instance, field_, value)

self._orm_instance.save(**kwargs)
self._saved = True

Expand Down Expand Up @@ -337,6 +345,13 @@ def filter(cls, **kwargs):
qs = cls._orm_model_cls.objects.filter(**kwargs)
return ObjectSet(_model=cls, _orm_model=cls._orm_model_cls, _orm_queryset=qs)

@classmethod
@handle_field_errors
def bulk_create(cls, **kwargs):
objs = cls._orm_model_cls.objects.bulk_create(**kwargs)
qs = cls._orm_model_cls.objects.filter(pk__in=[obj.pk for obj in objs])
return ObjectSet(_model=cls, _orm_model=cls._orm_model_cls, _orm_queryset=qs)

@classmethod
@handle_field_errors
def find(cls, query="", reverse=False, sort_tag="date"):
Expand Down Expand Up @@ -422,7 +437,7 @@ def values_list(self, *fields, flat=False):
ret = []
for obj in self._obj_set:
ret.append(tuple(getattr(obj, f, None) for f in fields))
return chain.from_iterable(ret) if flat else ret
return list(chain.from_iterable(ret)) if flat else ret


class Validator(ABC):
Expand Down
37 changes: 21 additions & 16 deletions src/ansys/dynamicreporting/core/serverless/item.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class Session(BaseModel):
date: datetime = field(compare=False, kw_only=True, default_factory=timezone.now)
hostname: str = field(compare=False, kw_only=True, default=str(platform.node))
platform: str = field(compare=False, kw_only=True, default=str(report_utils.enve_arch))
application: str = field(compare=False, kw_only=True, default="Python API")
application: str = field(compare=False, kw_only=True, default="Serverless ADR Python API")
version: str = field(compare=False, kw_only=True, default="1.0")
_orm_model: str = "data.models.Session"

Expand Down Expand Up @@ -180,7 +180,7 @@ class SimplePayloadMixin:
def from_db(cls, orm_instance, **kwargs):
from data.extremely_ugly_hacks import safe_unpickle

obj = super().from_db(orm_instance)
obj = super().from_db(orm_instance, **kwargs)
obj.content = safe_unpickle(obj._orm_instance.payloaddata)
return obj

Expand All @@ -194,7 +194,7 @@ class FilePayloadMixin:

@classmethod
def from_db(cls, orm_instance, **kwargs):
obj = super().from_db(orm_instance)
obj = super().from_db(orm_instance, **kwargs)
obj.content = obj._orm_instance.payloadfile.path
return obj

Expand All @@ -217,6 +217,23 @@ class Item(BaseModel):
dataset: Dataset = field(compare=False, kw_only=True, default=None)
type: str = "none"
_orm_model: str = "data.models.Item"
# Class-level registry of subclasses keyed by type
_type_registry = {}

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
# Automatically register the subclass based on its type attribute
Item._type_registry[cls.type] = cls

@classmethod
def from_db(cls, orm_instance, **kwargs):
# Create a new instance of the correct subclass
if cls is Item:
# Get the class based on the type attribute
item_cls = cls._type_registry[orm_instance.type]
return item_cls.from_db(orm_instance, **kwargs)

return super().from_db(orm_instance, **kwargs)

def save(self, **kwargs):
if self.session is None or self.dataset is None:
Expand Down Expand Up @@ -280,10 +297,6 @@ class String(SimplePayloadMixin, Item):
type: str = "string"


class Text(String):
pass


class HTML(String):
content: HTMLContent = HTMLContent()
type: str = "html"
Expand All @@ -298,7 +311,7 @@ class Table(Item):
def from_db(cls, orm_instance, **kwargs):
from data.extremely_ugly_hacks import safe_unpickle

obj = super().from_db(orm_instance)
obj = super().from_db(orm_instance, **kwargs)
payload = safe_unpickle(obj._orm_instance.payloaddata)
obj.content = payload.pop("array", None)
for prop in cls._properties:
Expand All @@ -318,10 +331,6 @@ def save(self, **kwargs):
super().save(**kwargs)


class Plot(Table):
pass


class Tree(SimplePayloadMixin, Item):
content: TreeContent = TreeContent()
type: str = "tree"
Expand All @@ -347,10 +356,6 @@ class Animation(FilePayloadMixin, Item):
type: str = "anim"


class Movie(Animation):
pass


class Scene(FilePayloadMixin, Item):
content: SceneContent = SceneContent()
type: str = "scene"
Expand Down
88 changes: 59 additions & 29 deletions src/ansys/dynamicreporting/core/serverless/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,55 @@ class Template(BaseModel):
parent: "Template" = field(compare=False, kw_only=True, default=None)
children: list["Template"] = field(compare=False, kw_only=True, default_factory=list)
_children_order: str = field(
compare=False, init=False, default=None
compare=False, init=False, default=""
) # computed from self.children
_master: bool = field(compare=False, init=False, default=None) # computed from self.parent
report_type: str = ""
_properties: tuple = tuple()
_properties: tuple = tuple() # todo: add properties of each type ref: report_objects
_orm_model: str = "reports.models.Template"
# Class-level registry of subclasses keyed by type
_type_registry = {}

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
# Automatically register the subclass based on its type attribute
Template._type_registry[cls.report_type] = cls

@classmethod
def from_db(cls, orm_instance, **kwargs):
# Create a new instance of the correct subclass
if cls is Template:
# Get the class based on the type attribute
templ_cls = cls._type_registry[orm_instance.report_type]
obj = templ_cls.from_db(orm_instance, **kwargs)
else:
obj = super().from_db(orm_instance, **kwargs)
# add relevant props from property dict.
props = obj.get_property()
for prop in cls._properties:
if prop in props:
setattr(obj, prop, props[prop])
return obj

def save(self, **kwargs):
if self.parent is not None and not self.parent._saved:
raise Template.NotSaved(
extra_detail="Failed to save template because its parent is not saved"
)
for child in self.children:
if not child._saved:
raise Template.NotSaved(
extra_detail="Failed to save template because its children are not saved"
)
# set properties
prop_dict = {}
for prop in self._properties:
value = getattr(self, prop, None)
if value is not None:
prop_dict[prop] = value
if prop_dict:
self.add_property(prop_dict)
super().save(**kwargs)

@property
def type(self):
Expand All @@ -37,7 +80,7 @@ def type(self, value):

@property
def children_order(self):
return self._children_order
return ",".join([str(child.guid) for child in self.children])

@property
def master(self):
Expand Down Expand Up @@ -107,35 +150,10 @@ def add_property(self, new_props: dict):
params["properties"] = curr_props | new_props
self.params = json.dumps(params)

def save(self, **kwargs):
if self.parent is not None and not self.parent._saved:
raise Template.NotSaved(
extra_detail="Failed to save template because its parent is not saved"
)
for child in self.children:
if not child._saved:
raise Template.NotSaved(
extra_detail="Failed to save template because its children are not saved"
)
# set properties
prop_dict = {}
for prop in self._properties:
value = getattr(self, prop, None)
if value is not None:
prop_dict[prop] = value
if prop_dict:
self.add_property(prop_dict)
super().save(**kwargs)

@classmethod
def get(cls, **kwargs):
new_kwargs = {"report_type": cls.report_type, **kwargs} if cls.report_type else kwargs
obj = super().get(**new_kwargs)
props = obj.get_property()
for prop in cls._properties:
if prop in props:
setattr(obj, prop, props[prop])
return obj
return super().get(**new_kwargs)

@classmethod
def filter(cls, **kwargs):
Expand Down Expand Up @@ -278,3 +296,15 @@ class TreeMergeGenerator(Generator):

class SQLQueryGenerator(Generator):
report_type: str = "Generator:sqlqueries"


class ItemsComparisonGenerator(Generator):
report_type: str = "Generator:itemscomparison"


class StatisticalGenerator(Generator):
report_type: str = "Generator:statistical"


class IteratorGenerator(Generator):
report_type: str = "Generator:iterator"

0 comments on commit 8fe88ea

Please sign in to comment.