From ca7da6943ada9c429b3a9a60bc09dc7ae0db629c Mon Sep 17 00:00:00 2001 From: Anish Bhusal Date: Wed, 18 Oct 2023 13:12:17 -0500 Subject: [PATCH] Feature: Update datamodels (#21) * updated datamodels to use dataclass * resolved comments --- pynequa/core.py | 14 ++-- pynequa/models.py | 161 ++++++++++++++++++++++++++++++------------- requirements.txt | 1 + tests/test_models.py | 35 ++++++++++ 4 files changed, 158 insertions(+), 53 deletions(-) create mode 100644 tests/test_models.py diff --git a/pynequa/core.py b/pynequa/core.py index de330e1..09b56c3 100644 --- a/pynequa/core.py +++ b/pynequa/core.py @@ -111,7 +111,7 @@ def search_query(self, query_params: QueryParams) -> Dict: payload = { "app": self.app_name, - "query": query_params._prepare_query_args( + "query": query_params.generate_payload( query_name=self.query_name) } @@ -160,7 +160,7 @@ def search_profile(self, profile_name: str, query_params: QueryParams, payload = { "profile": profile_name, "responsetype": response_type, - "query": query_params._prepare_query_args(query_name=self.query_name), + "query": query_params.generate_payload(query_name=self.query_name), } return self.post(endpoint=endpoint, payload=payload) @@ -209,7 +209,7 @@ def search_preview(self, query_params: QueryParams, action: str = "get", "action": action, "id": id, "origin": origin, - "query": query_params._prepare_query_args( + "query": query_params.generate_payload( query_name=self.query_name) } @@ -281,7 +281,7 @@ def search_similardocuments(self, source_doc_id: str, payload = { "app": self.app_name, "sourceDocumentId": source_doc_id, - "query": query_params._prepare_query_args( + "query": query_params.generate_payload( query_name=self.query_name) } @@ -300,7 +300,7 @@ def search_query_links(self, web_sevice: str, query_params: QueryParams) -> Dict endpoint = "search.querylinks" payload = { "webService": web_sevice, - "query": query_params._prepare_query_args( + "query": query_params.generate_payload( query_name=self.query_name) } return self.post(endpoint=endpoint, payload=payload) @@ -355,9 +355,9 @@ def search_profile_subtree(self, profile: str, query_params: QueryParams, endpoint = "search.profile.subtree" payload = { "profile": profile, - "query": query_params._prepare_query_args( + "query": query_params.generate_payload( query_name=self.query_name), - "tree": tree_params._generate_tree_params_payload() + "tree": tree_params.generate_payload() } return self.post(endpoint=endpoint, payload=payload) diff --git a/pynequa/models.py b/pynequa/models.py index db5d746..39d855f 100644 --- a/pynequa/models.py +++ b/pynequa/models.py @@ -1,13 +1,46 @@ from typing import Dict, List, Optional - - -class TreeParams: +from abc import abstractmethod, ABC +from dataclasses import dataclass, field +from loguru import logger + + +class AbstractParams(ABC): + """ + Abstract base class for all Sinequa models. + """ + + @abstractmethod + def generate_payload(self, **kwargs) -> Dict: + """ + This is abstract method for AbstractParams. + Every child class should implement this method. + """ + raise NotImplementedError() + + +@dataclass +class TreeParams(AbstractParams): + """ + Represents the parameters for configuring a tree parameters. + + Attributes: + box (str): The name of the relevant tree navigation box (required). + column (str): The name of the index column associated with the + navigation box (required). + op (str, optional): The relational operator. Default is 'eq'. + Possible values: '=', '!=', '<', '<=', '>=', '>', 'between', 'not between'. + value (str): The filter value (required). + """ box: str = "" column: str = "" op: str = "" value: str = "" - def _generate_tree_params_payload(self) -> Dict: + def generate_payload(self, **kwargs) -> Dict: + """ + This method generates payload for + TreeParams. + """ return { "box": self.box, "column": self.column, @@ -16,73 +49,94 @@ def _generate_tree_params_payload(self) -> Dict: } -class SelectParams: +@dataclass +class SelectParams(AbstractParams): expression: str = "" facet: str = "" - def _generate_select_params_payload(self) -> Dict: + def generate_payload(self, **kwargs) -> Dict: + """ + This method generates payload for + SelectParams. + """ return { "expression": self.expression, "facet": self.facet, } -class OpenParams: +@dataclass +class OpenParams(AbstractParams): expression: str = "" facet: str = "" - def _generate_open_params_payload(self) -> Dict: + def generate_payload(self, **kwargs) -> Dict: + """ + This method generates payload for + OpenParams. + """ return { "expression": self.expression, "facet": self.facet, } -class AdvancedParams: +@dataclass +class AdvancedParams(AbstractParams): col_name: str = "" - col_value: str = "" - value: str = "" - operator: str = "" - - def _generate_advanced_params_payload(self) -> Dict: - return { + col_value: str = None + value: str or int = None + operator: str = None + debug: bool = False + + def generate_payload(self, **kwargs) -> Dict: + """ + This method generates payload for + AdvancedParams. + """ + payload = { self.col_name: self.col_value, "value": self.value, "operator": self.operator } + if self.debug: + logger.debug(payload) + + return payload -class QueryParams: + +@dataclass +class QueryParams(AbstractParams): name: str = "" # required - action: str = None + action: Optional[str] = None search_text: str = "" # required - select_params: List[SelectParams] = [] - additional_select_clause: str = None - additional_where_clause: str = None - open_params: List[OpenParams] = [] - page: int = 0 - page_size: int = 0 - tab: str = None - scope: str = None - basket: str = None - is_first_page: bool = False - strict_refine: bool = False - global_relevance: int = None - question_language: str = None - question_default_language: str = None - spelling_correction_mode: str = None - spelling_correction_filter: str = None - document_weight: str = None - text_part_weights: str = None - relevance_transforms: str = None - remove_duplicates: bool = False - aggregations: List[str] = [] - order_by: str = None - group_by: str = None + select_params: Optional[List[SelectParams] + ] = field(default_factory=lambda: []) + additional_select_clause: Optional[str] = None + additional_where_clause: Optional[str] = None + open_params: Optional[List[OpenParams]] = field(default_factory=lambda: []) + page: Optional[int] = 1 + page_size: Optional[int] = 10 + tab: Optional[str] = None + scope: Optional[str] = None + basket: Optional[str] = None + is_first_page: Optional[bool] = False + strict_refine: Optional[bool] = False + global_relevance: Optional[int] = None + question_language: Optional[str] = None + question_default_language: Optional[str] = None + spelling_correction_mode: Optional[str] = None + spelling_correction_filter: Optional[str] = None + document_weight: Optional[str] = None + text_part_weights: Optional[str] = None + relevance_transforms: Optional[str] = None + remove_duplicates: Optional[bool] = False + aggregations: Optional[List[str]] = field(default_factory=lambda: []) + order_by: Optional[str] = None + group_by: Optional[str] = None advanced: Optional[AdvancedParams] = None - - def __init__(self) -> None: - pass + debug: bool = False def _prepare_query_args(self, query_name: str) -> Dict: params = { @@ -101,7 +155,7 @@ def _prepare_query_args(self, query_name: str) -> Dict: if len(self.select_params) > 0: select_params = [] for item in self.select_params: - select_params.append(item._generate_select_params_payload()) + select_params.append(item.generate_payload()) params["select"] = self.select_params if self.additional_select_clause is not None: @@ -113,7 +167,7 @@ def _prepare_query_args(self, query_name: str) -> Dict: if len(self.open_params) > 0: open_params = [] for item in self.open_params: - open_params.append(item._generate_open_params_payload()) + open_params.append(item.generate_payload()) params["open"] = self.open_params if self.page is not None: @@ -169,6 +223,21 @@ def _prepare_query_args(self, query_name: str) -> Dict: params["groupBy"] = self.group_by if self.advanced is not None: - params["advanced"] = self.advanced._generate_advanced_params_payload() + params["advanced"] = self.advanced.generate_payload() return params + + def generate_payload(self, **kwargs) -> Dict: + """ + This method generates payload for + QueryParams. + + Args: + query_name(str): Name of query service to query for + """ + query_name = kwargs.get("query_name") + payload = self._prepare_query_args(query_name) + if self.debug: + logger.debug(payload) + + return payload diff --git a/requirements.txt b/requirements.txt index 320e5db..30b66b5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ Requests==2.31.0 setuptools==61.2.0 sphinx-click==4.4.0 +loguru==0.7.2 diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..d3d166f --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,35 @@ +from pynequa.models import QueryParams +import unittest +import logging + + +class TestQueryParams(unittest.TestCase): + + def test_query_params_payload(self): + """ + Test if query params payload is correctly + generated or not. + """ + qp = QueryParams( + name="query", + search_text="What was Landsat-9 launched?" + ) + + payload = qp.generate_payload() + logging.debug(payload) + + keys_which_must_be_in_payload = [ + "name", + "text", + "isFirstpage", + "strictRefine", + "removeDuplicates" + ] + + for key in keys_which_must_be_in_payload: + if key not in payload: + self.assertEqual(key, "test", f"{key} is mising in payload") + + +if __name__ == '__main__': + unittest.main()