From 60db648e23c28d9e22798b616747adac09fb49c2 Mon Sep 17 00:00:00 2001 From: NISH1001 Date: Tue, 27 Feb 2024 14:31:57 -0600 Subject: [PATCH 1/3] Remove operator and value in payload if they are None in AdvancedParams Sample: ```python { 'app': 'vanilla-search', 'query': { 'name': 'query', 'text': 'himawari', 'isFirstpage': False, 'strictRefine': False, 'removeDuplicates': False, 'action': 'search', 'page': 1, 'pageSize': 10, 'advanced': { 'collection': '/user_needs_database/snwg-assessments-2020/', # 'value': None, 'operator': None, } } } ``` --- pynequa/api/api.py | 41 +++++++++++++++++++++-------------------- pynequa/models.py | 13 ++++++++----- 2 files changed, 29 insertions(+), 25 deletions(-) diff --git a/pynequa/api/api.py b/pynequa/api/api.py index 30d0bf0..da5a115 100644 --- a/pynequa/api/api.py +++ b/pynequa/api/api.py @@ -5,13 +5,13 @@ class API: - ''' - API Class handles all HTTP Requests + """ + API Class handles all HTTP Requests - Attributes: - base_url(string): REST API base URL for Sinequa instance - access_token(string): token for Sinequa authentication - ''' + Attributes: + base_url(string): REST API base URL for Sinequa instance + access_token(string): token for Sinequa authentication + """ def __init__(self, access_token: str, base_url: str) -> None: if not access_token or not base_url: @@ -21,9 +21,7 @@ def __init__(self, access_token: str, base_url: str) -> None: self.base_url = base_url def _get_headers(self) -> Dict: - headers = { - "Authorization": f"Bearer {self.access_token}" - } + headers = {"Authorization": f"Bearer {self.access_token}"} return headers def _get_url(self, endpoint) -> str: @@ -31,19 +29,22 @@ def _get_url(self, endpoint) -> str: def get(self, endpoint) -> Dict: """ - This method handles GET method. + This method handles GET method. """ - session = requests.Session() - resp = session.get(self._get_url(endpoint=endpoint), - headers=self._get_headers()) - session.close - return resp.json() + with requests.Session() as session: + resp = session.get( + self._get_url(endpoint=endpoint), headers=self._get_headers() + ) + return resp.json() def post(self, endpoint, payload) -> Dict: """ - This method handles POST method. + This method handles POST method. """ - session = requests.Session() - resp = session.post(self._get_url(endpoint=endpoint), - headers=self._get_headers(), json=payload) - return resp.json() + with requests.Session() as session: + resp = session.post( + self._get_url(endpoint=endpoint), + headers=self._get_headers(), + json=payload, + ) + return resp.json() diff --git a/pynequa/models.py b/pynequa/models.py index 553cf13..36e0865 100644 --- a/pynequa/models.py +++ b/pynequa/models.py @@ -31,6 +31,7 @@ class TreeParams(AbstractParams): Possible values: '=', '!=', '<', '<=', '>=', '>', 'between', 'not between'. value (str): The filter value (required). """ + box: str = "" column: str = "" op: str = "" @@ -85,7 +86,7 @@ def generate_payload(self, **kwargs) -> Dict: class AdvancedParams(AbstractParams): col_name: str = "" col_value: str = None - value: str or int = None + value: str or int = None operator: str = None debug: bool = False @@ -96,10 +97,13 @@ def generate_payload(self, **kwargs) -> Dict: """ payload = { self.col_name: self.col_value, - "value": self.value, - "operator": self.operator } + if self.value: + payload["value"] = self.value + if self.operator: + payload["operator"] = self.operator + if self.debug: logger.debug(payload) @@ -111,8 +115,7 @@ class QueryParams(AbstractParams): name: str = "" # required action: Optional[str] = None search_text: str = "" # required - select_params: Optional[List[SelectParams] - ] = field(default_factory=lambda: []) + select_params: Optional[List[SelectParams]] = field(default_factory=lambda: []) additional_select_clause: Optional[str] = None additional_where_clause: Optional[str] = None open_params: Optional[List[OpenParams]] = field(default_factory=lambda: []) From a020f53b65cee0ba482c34714f0b50887198f409 Mon Sep 17 00:00:00 2001 From: NISH1001 Date: Tue, 27 Feb 2024 14:46:22 -0600 Subject: [PATCH 2/3] Bugfix payload generation for AdvancedParams Now col_name and values are checked to see if empty. --- pynequa/models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pynequa/models.py b/pynequa/models.py index 36e0865..9b5ef0b 100644 --- a/pynequa/models.py +++ b/pynequa/models.py @@ -95,10 +95,10 @@ def generate_payload(self, **kwargs) -> Dict: This method generates payload for AdvancedParams. """ - payload = { - self.col_name: self.col_value, - } - + payload = {} + # To prevent payloads with empty values + if self.col_name and self.col_value: + payload[self.col_name] = self.col_value if self.value: payload["value"] = self.value if self.operator: From 74105f8e7bb85f1fdac9b2acc9760a3dbbaf7bfe Mon Sep 17 00:00:00 2001 From: anisbhsl Date: Tue, 27 Feb 2024 16:03:17 -0600 Subject: [PATCH 3/3] added support for advanced params according to documentation --- README.md | 2 +- pynequa/models.py | 64 +++++++++++++++++++++++++++++++++++++------- tests/test_models.py | 59 ++++++++++++++++++++++++++++++++++++++-- 3 files changed, 112 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 12ed30b..5f6aedd 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ A python library to handle communication with Sinequa REST API. ``` ## Example Usage -``` +```python import pynequa from pynequa.models import QueryParams diff --git a/pynequa/models.py b/pynequa/models.py index 9b5ef0b..5065021 100644 --- a/pynequa/models.py +++ b/pynequa/models.py @@ -84,8 +84,36 @@ def generate_payload(self, **kwargs) -> Dict: @dataclass class AdvancedParams(AbstractParams): - col_name: str = "" - col_value: str = None + """ + AdvancedParams represents the elemental advanced params. + Remember following things: + + 1. col_name is required. + 2. col_value has to be either "str" or "List[str]". + 3. if col_value is not present, the value could be a dict of + "value" and "operator". + + + Example: + "advanced": { + "docformat": [ + "ppt", + "pdf" + ], + "modified": [ + { + "value": "2019-01-01", + "operator": ">=" + }, + { + "value": "2019-12-31", + "operator": "<=" + } + ] + } + """ + col_name: str + col_value: str or List[str] = None value: str or int = None operator: str = None debug: bool = False @@ -99,10 +127,11 @@ def generate_payload(self, **kwargs) -> Dict: # To prevent payloads with empty values if self.col_name and self.col_value: payload[self.col_name] = self.col_value - if self.value: - payload["value"] = self.value - if self.operator: - payload["operator"] = self.operator + if self.value and self.operator: + payload[self.col_name] = { + "value": self.value, + "operator": self.operator + } if self.debug: logger.debug(payload) @@ -115,7 +144,8 @@ class QueryParams(AbstractParams): name: str = "" # required action: Optional[str] = None search_text: str = "" # required - select_params: Optional[List[SelectParams]] = field(default_factory=lambda: []) + select_params: Optional[List[SelectParams] + ] = field(default_factory=lambda: []) additional_select_clause: Optional[str] = None additional_where_clause: Optional[str] = None open_params: Optional[List[OpenParams]] = field(default_factory=lambda: []) @@ -138,7 +168,8 @@ class QueryParams(AbstractParams): aggregations: Optional[List[str]] = field(default_factory=lambda: []) order_by: Optional[str] = None group_by: Optional[str] = None - advanced: Optional[AdvancedParams] = None + advanced: Optional[List[AdvancedParams]] = field( + default_factory=lambda: []) debug: bool = False def _prepare_query_args(self, query_name: str) -> Dict: @@ -225,8 +256,21 @@ def _prepare_query_args(self, query_name: str) -> Dict: if self.group_by is not None: params["groupBy"] = self.group_by - if self.advanced is not None: - params["advanced"] = self.advanced.generate_payload() + if len(self.advanced) > 0: + advanced_param_payload = {} + for advanced_param in self.advanced: + column_name = advanced_param.col_name + payload_value = advanced_param.generate_payload()[column_name] + if column_name in advanced_param_payload: + advanced_param_payload[column_name].append( + payload_value + ) + elif isinstance(payload_value, dict): + advanced_param_payload[column_name] = [payload_value] + else: + advanced_param_payload[column_name] = payload_value + + params["advanced"] = advanced_param_payload return params diff --git a/tests/test_models.py b/tests/test_models.py index d3d166f..1f1ee41 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,6 +1,7 @@ -from pynequa.models import QueryParams +from pynequa.models import QueryParams, AdvancedParams import unittest import logging +import json class TestQueryParams(unittest.TestCase): @@ -12,10 +13,12 @@ def test_query_params_payload(self): """ qp = QueryParams( name="query", - search_text="What was Landsat-9 launched?" + search_text="What was Landsat-9 launched?", + page_size=20, ) payload = qp.generate_payload() + print(payload) logging.debug(payload) keys_which_must_be_in_payload = [ @@ -30,6 +33,58 @@ def test_query_params_payload(self): if key not in payload: self.assertEqual(key, "test", f"{key} is mising in payload") + def test_query_params_with_advanced_params(self): + """ + Test if advanced params are correctly + generated in query param payload or not. + """ + + ap1 = AdvancedParams( + col_name="collection", + col_value="accounting" + ) + + ap2 = AdvancedParams( + col_name="docformat", + col_value=["pdf", "docx"] + ) + + ap3 = AdvancedParams( + col_name="modified", + value="2019-01-01", + operator=">=" + ) + + ap4 = AdvancedParams( + col_name="modified", + value="2019-12-31", + operator="<=" + ) + + qp = QueryParams( + name="query", + search_text="What was Landsat-9 launched?", + advanced=[ + ap1, + ap2, + ap3, + ap4 + ] + ) + + payload = qp.generate_payload() + + expected_payload = { + "collection": "accounting", + "docformat": ["pdf", "docx"], + "modified": [ + {"value": "2019-01-01", "operator": ">="}, + {"value": "2019-12-31", "operator": "<="} + ] + } + + assert payload["advanced"] == expected_payload + if __name__ == '__main__': unittest.main()