diff --git a/.github/workflows/build-2.0.yml b/.github/workflows/build-2.0.yml index 3e2ff7bf3..40bccc503 100644 --- a/.github/workflows/build-2.0.yml +++ b/.github/workflows/build-2.0.yml @@ -34,4 +34,4 @@ jobs: file: docker/Dockerfile platforms: linux/amd64 push: true - tags: ghcr.io/${{ github.repository }}/viadot:2.0-latest + tags: ghcr.io/${{ github.repository }}/viadot:2.0-latest \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index de4872e23..ac0e5ac46 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,27 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.4.15] - 2023-05-11 +### Added +- Added `BusinessCore` source class +- Added `BusinessCoreToParquet` task class +- Added `verify` parameter to `handle_api_response()`. +- Added `to_parquet()` in `base.py` +- Added new source class `SAPRFCV2` in `sap_rfc.py` with new approximation. +- Added new parameter `rfc_replacement` to `sap_rfc_to_adls.py` to replace +an extra separator character within a string column to avoid conflicts. +- Added `rfc_unique_id` in `SAPRFCV2` to merge chunks on this column. +- Added `close_connection()` to `SAPRFC` and `SAPRFCV2` + +### Fixed +- Removed the `try-except` clause and added new logic to remove extra separators in the `sap_rfc.py` +source file, to avoid a mismatch in column length between iterative connections to SAP tables. +- Fixed the case where `SAP` tables are updated while the `sap_rfc.py` script is running: if the data is downloaded in chunks, the +rows added in the meantime no longer end up as unrelated rows in the next chunk. +- Fixed the `sap_rfc.py` source file so it no longer breaks down when both an extra separator appears in a row +and new rows are added to the SAP table between iterations. + + ## [0.4.14] - 2023-04-13 ### Added - Added `anonymize_df` task function to `task_utils.py` to anonymize data in the dataframe in selected columns. @@ -56,7 +77,6 @@ This parameter enables user to decide whether or not filter should be validated. ### Changed - Changed data extraction logic for `Outlook` data.
- ## [0.4.10] - 2022-11-16 ### Added - Added `credentials_loader` function in utils diff --git a/tests/integration/flows/test_eurostat_to_adls.py b/tests/integration/flows/test_eurostat_to_adls.py new file mode 100644 index 000000000..d74023953 --- /dev/null +++ b/tests/integration/flows/test_eurostat_to_adls.py @@ -0,0 +1,29 @@ +from unittest import mock +import pytest +import pandas as pd +import os + +from viadot.flows import EurostatToADLS + +DATA = {"geo": ["PL", "DE", "NL"], "indicator": [35, 55, 77]} +ADLS_FILE_NAME = "test_eurostat.parquet" +ADLS_DIR_PATH = "raw/tests/" + + +@mock.patch( + "viadot.tasks.EurostatToDF.run", + return_value=pd.DataFrame(data=DATA), +) +@pytest.mark.run +def test_eurostat_to_adls_run_flow(mocked_class): + flow = EurostatToADLS( + "test_eurostat_to_adls_flow_run", + dataset_code="ILC_DI04", + overwrite_adls=True, + adls_dir_path=ADLS_DIR_PATH, + adls_file_name=ADLS_FILE_NAME, + ) + result = flow.run() + assert result.is_successful() + os.remove("test_eurostat_to_adls_flow_run.parquet") + os.remove("test_eurostat_to_adls_flow_run.json") diff --git a/tests/integration/tasks/test_eurostat.py b/tests/integration/tasks/test_eurostat.py new file mode 100644 index 000000000..430f8dfe3 --- /dev/null +++ b/tests/integration/tasks/test_eurostat.py @@ -0,0 +1,223 @@ +import pytest +import pandas as pd +import logging + +from viadot.tasks import eurostat + + +def test_and_validate_dataset_code_without_params(caplog): + """This function is designed to test the accuracy of the data retrieval feature in a program. + Specifically, it tests to ensure that the program returns a non-empty DataFrame when a correct + dataset code is provided without any parameters. The function is intended to be used in software + development to verify that the program is correctly retrieving data from the appropriate dataset. + """ + task = eurostat.EurostatToDF(dataset_code="ILC_DI04").run() + assert isinstance(task, pd.DataFrame) + assert not task.empty + assert caplog.text == "" + + +def test_wrong_dataset_code_logger(caplog): + """This function is designed to test the accuracy of the error logging feature in a program. + Specifically, it tests to ensure that the program is able to correctly identify and log errors + when provided with only incorrect dataset code. + The function is intended to be used in software development to identify correct type errors + and messages in the program's handling of codes. + """ + task = eurostat.EurostatToDF(dataset_code="ILC_DI04E") + + with pytest.raises(ValueError, match="DataFrame is empty!"): + with caplog.at_level(logging.ERROR): + task.run() + assert ( + f"Failed to fetch data for ILC_DI04E, please check correctness of dataset code!" + in caplog.text + ) + + +def test_wrong_parameters_codes_logger(caplog): + """This function is designed to test the accuracy of the error logging feature in a program. + Specifically, it tests to ensure that the program is able to correctly identify and log errors + when provided with a correct dataset_code and correct parameters are provided, but both parameters codes are incorrect. + The function is intended to be used in software development to identify correct type errors + and messages in the program's handling of codes. 
+ """ + task = eurostat.EurostatToDF( + dataset_code="ILC_DI04", + params={"hhtyp": "total1", "indic_il": "non_existing_code"}, + ) + + with pytest.raises(ValueError, match="DataFrame is empty!"): + with caplog.at_level(logging.ERROR): + task.run() + assert ( + f"Parameters codes: 'total1 | non_existing_code' are not available. Please check your spelling!" + in caplog.text + ) + assert ( + f"You can find everything via link: https://ec.europa.eu/eurostat/databrowser/view/ILC_DI04/default/table?lang=en" + in caplog.text + ) + + +def test_parameter_codes_as_list_logger(caplog): + """This function is designed to test the accuracy of the error logging feature in a program. + Specifically, it tests to ensure that the program is able to correctly identify and log errors + when provided with a correct dataset code, correct parameters, but incorrect parameters codes structure + (as a list with strings, instead of single string). + The function is intended to be used in software development to identify correct type errors + and messages in the program's handling of codes. + """ + + task = eurostat.EurostatToDF( + dataset_code="ILC_DI04", + params={"hhtyp": ["totale", "nottotale"], "indic_il": "med_e"}, + ) + with pytest.raises(ValueError, match="Wrong structure of params!"): + with caplog.at_level(logging.ERROR): + task.run() + assert ( + "You can provide only one code per one parameter as 'str' in params!\n" + in caplog.text + ) + assert ( + "CORRECT: params = {'unit': 'EUR'} | INCORRECT: params = {'unit': ['EUR', 'USD', 'PLN']}" + in caplog.text + ) + + +def test_wrong_parameters(caplog): + """This function is designed to test the accuracy of the error logging feature in a program. + Specifically, it tests to ensure that the program is able to correctly identify and log errors + when provided with a correct dataset_code, but incorrect parameters keys. + The function is intended to be used in software development to identify correct type errors + and messages in the program's handling of codes. + """ + + task = eurostat.EurostatToDF( + dataset_code="ILC_DI04", params={"hhhtyp": "total", "indic_ilx": "med_e"} + ) + with pytest.raises(ValueError, match="DataFrame is empty!"): + with caplog.at_level(logging.ERROR): + task.run() + assert ( + f"Parameters: 'hhhtyp | indic_ilx' are not in dataset. Please check your spelling!\n" + in caplog.text + ) + assert ( + f"Possible parameters: freq | hhtyp | indic_il | unit | geo | time" + in caplog.text + ) + + +def test_params_as_list(): + """This function is designed to test the accuracy of the error logging feature in a program. + Specifically, it tests to ensure that the program is able to correctly identify and log error + when provided with a correct dataset_code, but incorrect params structure (as list instead of dict). + The function is intended to be used in software development to identify correct type errors + and messages in the program's handling of codes. + """ + with pytest.raises(TypeError, match="Params should be a dictionary."): + eurostat.EurostatToDF(dataset_code="ILC_DI04", params=["total", "med_e"]).run() + + +def test_correct_params_and_dataset_code(caplog): + """This function is designed to test the accuracy of the data retrieval feature in a program. + Specifically, it tests to ensure that the program returns a non-empty DataFrame when a correct + dataset code is provided with correct params. The function is intended to be used in software + development to verify that the program is correctly retrieving data from the appropriate dataset. 
+ """ + + task = eurostat.EurostatToDF( + dataset_code="ILC_DI04", params={"hhtyp": "total", "indic_il": "med_e"} + ).run() + + assert isinstance(task, pd.DataFrame) + assert not task.empty + assert caplog.text == "" + + +def task_correct_requested_columns(caplog): + """This function is designed to test the accuracy of the data retrieval feature in a program. + Specifically, it tests to ensure that the program is able to correctly identify and log error + when provided with a correct dataset_code, correct params and correct requested_columns. + The function is intended to be used in software development to verify that the program is correctly + retrieving data from the appropriate dataset. + """ + task = eurostat.EurostatToDF( + dataset_code="ILC_DI04", + params={"hhtyp": "total", "indic_il": "med_e"}, + requested_columns=["updated", "geo", "indicator"], + ) + task.run() + + assert isinstance(task, pd.DataFrame) + assert not task.empty + assert caplog.text == "" + assert list(task.columns) == task.needed_columns + + +def test_wrong_needed_columns_names(caplog): + """This function is designed to test the accuracy of the error logging feature in a program. + Specifically, it tests to ensure that the program is able to correctly identify and log error + when provided with a correct dataset_code, correct parameters, but incorrect names of requested columns. + The function is intended to be used in software development to identify correct type errors + and messages in the program's handling of codes. + """ + task = eurostat.EurostatToDF( + dataset_code="ILC_DI04", + params={"hhtyp": "total", "indic_il": "med_e"}, + requested_columns=["updated1", "geo1", "indicator1"], + ) + with pytest.raises(ValueError, match="Provided columns are not available!"): + with caplog.at_level(logging.ERROR): + task.run() + assert ( + f"Name of the columns: 'updated1 | geo1 | indicator1' are not in DataFrame. Please check spelling!\n" + in caplog.text + ) + assert f"Available columns: geo | time | indicator | label | updated" in caplog.text + + +def test_wrong_params_and_wrong_requested_columns_names(caplog): + """This function is designed to test the accuracy of the error logging feature in a program. + Specifically, it tests to ensure that the program is able to correctly identify and log error + when provided with a correct dataset_code, incorrect parameters and incorrect names of requested columns. + Test should log errors only related with wrong params - we are trying to check if program will stop after + params validation. The function is intended to be used in software development to identify correct type errors + and messages in the program's handling of codes. + """ + task = eurostat.EurostatToDF( + dataset_code="ILC_DI04", + params={"hhhtyp": "total", "indic_ilx": "med_e"}, + requested_columns=["updated1", "geo1", "indicator1"], + ) + with pytest.raises(ValueError, match="DataFrame is empty!"): + with caplog.at_level(logging.ERROR): + task.run() + assert ( + f"Parameters: 'hhhtyp | indic_ilx' are not in dataset. Please check your spelling!\n" + in caplog.text + ) + assert ( + f"Possible parameters: freq | hhtyp | indic_il | unit | geo | time" + in caplog.text + ) + + +def test_requested_columns_not_in_list(): + """This function is designed to test the accuracy of the error logging feature in a program. 
+ Specifically, it tests to ensure that the program is able to correctly identify and log error + when provided with a correct dataset_code, correct params but incorrect requested_columns structure + (as single string instead of list with strings). + The function is intended to be used in software development to identify correct type errors + and messages in the program's handling of codes. + """ + with pytest.raises( + TypeError, match="Requested columns should be provided as list of strings." + ): + eurostat.EurostatToDF( + dataset_code="ILC_DI04", + params={"hhtyp": "total", "indic_il": "med_e"}, + requested_columns="updated", + ).run() diff --git a/tests/integration/test_business_core.py b/tests/integration/test_business_core.py new file mode 100644 index 000000000..d700f67db --- /dev/null +++ b/tests/integration/test_business_core.py @@ -0,0 +1,49 @@ +import pytest +from unittest.mock import patch, Mock +import pandas as pd +from viadot.sources import BusinessCore + + +@pytest.fixture(scope="module") +def business_core(): + return BusinessCore( + url="https://api.businesscore.ae/api/GetCustomerData", + filters_dict={ + "BucketCount": 10, + "BucketNo": 1, + "FromDate": None, + "ToDate": None, + }, + credentials={"username": "test", "password": "test123"}, + ) + + +@patch("viadot.sources.business_core.handle_api_response") +def test_generate_token(mock_api_response, business_core): + mock_api_response.return_value = Mock(text='{"access_token": "12345"}') + token = business_core.generate_token() + assert token == "12345" + + +def test_clean_filters_dict(business_core): + filters = business_core.clean_filters_dict() + assert filters == { + "BucketCount": 10, + "BucketNo": 1, + "FromDate": "&", + "ToDate": "&", + } + + +def test_to_df(business_core): + with patch.object( + business_core, + "get_data", + return_value={"MasterDataList": [{"id": 1, "name": "John Doe"}]}, + ): + df = business_core.to_df() + assert isinstance(df, pd.DataFrame) + assert len(df.columns) == 2 + assert len(df) == 1 + assert df["id"].tolist() == [1] + assert df["name"].tolist() == ["John Doe"] diff --git a/tests/integration/test_sap_rfc.py b/tests/integration/test_sap_rfc.py index 2f960066e..20078d312 100644 --- a/tests/integration/test_sap_rfc.py +++ b/tests/integration/test_sap_rfc.py @@ -1,8 +1,9 @@ from collections import OrderedDict -from viadot.sources import SAPRFC +from viadot.sources import SAPRFC, SAPRFCV2 sap = SAPRFC() +sap2 = SAPRFCV2() sql1 = "SELECT a AS a_renamed, b FROM table1 WHERE table1.c = 1" sql2 = "SELECT a FROM fake_schema.fake_table WHERE a=1 AND b=2 OR c LIKE 'a%' AND d IN (1, 2) LIMIT 5 OFFSET 3" @@ -103,3 +104,86 @@ def test___build_pandas_filter_query(): sap._build_pandas_filter_query(sap.client_side_filters) == "thirdlongcolname == 01234" ), sap._build_pandas_filter_query(sap.client_side_filters) + + +def test__get_table_name_v2(): + assert sap2._get_table_name(sql1) == "table1" + assert sap2._get_table_name(sql2) == "fake_schema.fake_table", sap2._get_table_name( + sql2 + ) + assert sap2._get_table_name(sql7) == "b" + + +def test__get_columns_v2(): + assert sap2._get_columns(sql1) == ["a", "b"] + assert sap2._get_columns(sql1, aliased=True) == [ + "a_renamed", + "b", + ], sap2._get_columns(sql1, aliased=True) + assert sap2._get_columns(sql2) == ["a"] + assert sap2._get_columns(sql7) == ["a", "b"] + + +def test__get_where_condition_v2(): + assert sap2._get_where_condition(sql1) == "table1.c = 1", sap2._get_where_condition( + sql1 + ) + assert ( + sap2._get_where_condition(sql2) == "a=1 
AND b=2 OR c LIKE 'a%' AND d IN (1, 2)" + ), sap2._get_where_condition(sql2) + assert ( + sap2._get_where_condition(sql3) + == "testORword=1 AND testANDword=2 AND testLIMITword=3 AND testOFFSETword=4" + ), sap2._get_where_condition(sql3) + assert ( + sap2._get_where_condition(sql4) + == "testLIMIT = 1 AND testOFFSET = 2 AND LIMITtest=3 AND OFFSETtest=4" + ), sap2._get_where_condition(sql4) + assert ( + sap2._get_where_condition(sql7) + == "c = 1 AND d = 2 AND longcolname = 12345 AND otherlongcolname = 6789" + ), sap2._get_where_condition(sql7) + + +def test__get_limit_v2(): + assert sap2._get_limit(sql1) is None + assert sap2._get_limit(sql2) == 5 + assert sap2._get_limit(sql7) == 5 + + +def test__get_offset_v2(): + assert sap2._get_offset(sql1) is None + assert sap2._get_offset(sql2) == 3 + assert sap2._get_offset(sql7) == 10 + + +def test_client_side_filters_simple_v2(): + _ = sap2._get_where_condition(sql5) + assert sap2.client_side_filters == OrderedDict( + {"AND": "longword123=5"} + ), sap2.client_side_filters + + +def test_client_side_filters_with_limit_offset_v2(): + _ = sap2._get_where_condition(sql6) + assert sap2.client_side_filters == OrderedDict( + {"AND": "otherlongcolname=5"} + ), sap2.client_side_filters + + _ = sap2._get_where_condition(sql7) + assert sap2.client_side_filters == OrderedDict( + {"AND": "thirdlongcolname = 01234"} + ), sap2.client_side_filters + + +def test___build_pandas_filter_query_v2(): + _ = sap2._get_where_condition(sql6) + assert ( + sap2._build_pandas_filter_query(sap2.client_side_filters) + == "otherlongcolname == 5" + ), sap2._build_pandas_filter_query(sap2.client_side_filters) + _ = sap2._get_where_condition(sql7) + assert ( + sap2._build_pandas_filter_query(sap2.client_side_filters) + == "thirdlongcolname == 01234" + ), sap2._build_pandas_filter_query(sap2.client_side_filters) diff --git a/tests/test_viadot.py b/tests/test_viadot.py index 0f0a56609..bcf154c18 100644 --- a/tests/test_viadot.py +++ b/tests/test_viadot.py @@ -2,4 +2,4 @@ def test_version(): - assert __version__ == "0.4.14" + assert __version__ == "0.4.15" diff --git a/viadot/__init__.py b/viadot/__init__.py index f658d0a64..5a4bb1d41 100644 --- a/viadot/__init__.py +++ b/viadot/__init__.py @@ -1 +1 @@ -__version__ = "0.4.14" +__version__ = "0.4.15" diff --git a/viadot/flows/__init__.py b/viadot/flows/__init__.py index 203a2447c..5b3ef3535 100644 --- a/viadot/flows/__init__.py +++ b/viadot/flows/__init__.py @@ -37,5 +37,8 @@ from .sftp_operations import SftpToAzureSQL, SftpToADLS from .mindful_to_adls import MindfulToADLS from .mediatool_to_adls import MediatoolToADLS + +from .eurostat_to_adls import EurostatToADLS from .hubspot_to_adls import HubspotToADLS -from .customer_gauge_to_adls import CustomerGaugeToADLS \ No newline at end of file +from .customer_gauge_to_adls import CustomerGaugeToADLS + diff --git a/viadot/flows/eurostat_to_adls.py b/viadot/flows/eurostat_to_adls.py new file mode 100644 index 000000000..45e4056a3 --- /dev/null +++ b/viadot/flows/eurostat_to_adls.py @@ -0,0 +1,171 @@ +import os +from pathlib import Path +from typing import Any, Dict, List + +import pendulum +from prefect import Flow +from prefect.backend import set_key_value +from prefect.utilities import logging + +from ..task_utils import ( + add_ingestion_metadata_task, + df_get_data_types_task, + df_map_mixed_dtypes_for_parquet, + df_to_csv, + df_to_parquet, + dtypes_to_json_task, + update_dtypes_dict, +) + +from ..tasks import AzureDataLakeUpload, EurostatToDF + +file_to_adls_task = 
AzureDataLakeUpload() +json_to_adls_task = AzureDataLakeUpload() + + +class EurostatToADLS(Flow): + """Flow for downloading data from the Eurostat platform via the HTTPS REST API (no credentials required) + to a CSV or Parquet file, and then uploading it to Azure Data Lake. + """ + + def __init__( + self, + name: str, + dataset_code: str, + params: dict = None, + base_url: str = "https://ec.europa.eu/eurostat/api/dissemination/statistics/1.0/data/", + requested_columns: list = None, + output_file_extension: str = ".parquet", + adls_dir_path: str = None, + local_file_path: str = None, + adls_file_name: str = None, + adls_sp_credentials_secret: str = None, + overwrite_adls: bool = False, + if_exists: str = "replace", + *args: List[Any], + **kwargs: Dict[str, Any], + ): + """ + Args: + name (str): The name of the flow. + dataset_code (str): The code of the Eurostat dataset to be uploaded. + params (Dict[str], optional): + A dictionary with optional URL parameters. The key represents the parameter id, while the value is the code + for a specific parameter, for example 'params = {'unit': 'EUR'}' where "unit" is the parameter that you would like to set + and "EUR" is the code of the specific parameter. You can add more than one parameter, but only one code per parameter! + So you CAN NOT provide a list of codes, as in the example 'params = {'unit': ['EUR', 'USD', 'PLN']}' + This parameter is REQUIRED in most cases to pull a specific dataset from the API. + Both the parameter and the code have to be provided as a string! + Defaults to None. + base_url (str): The base URL used to access the Eurostat API. This parameter specifies the root URL for all requests made to the API. + It should not be modified unless the API changes its URL scheme. + Defaults to "https://ec.europa.eu/eurostat/api/dissemination/statistics/1.0/data/" + requested_columns (List[str], optional): List of columns that are needed from the DataFrame - works as a filter. + The data downloaded from Eurostat has the same structure every time. The filter is applied after the data is fetched. + output_file_extension (str, optional): Output file extension - to allow selection of .csv for data + which is not easy to handle with parquet. Defaults to ".parquet". + adls_dir_path (str, optional): Azure Data Lake destination folder/catalog path. Defaults to None. + local_file_path (str, optional): Local destination path. Defaults to None. + adls_file_name (str, optional): Name of file in ADLS. Defaults to None. + adls_sp_credentials_secret (str, optional): The name of the Azure Key Vault secret containing a dictionary with + ACCOUNT_NAME and Service Principal credentials (TENANT_ID, CLIENT_ID, CLIENT_SECRET) for the Azure Data Lake. + Defaults to None. + overwrite_adls (bool, optional): Whether to overwrite files in the lake. Defaults to False. + if_exists (str, optional): What to do if the file exists. Defaults to "replace".
+ """ + + # EurostatToDF + self.dataset_code = dataset_code + self.params = params + self.base_url = base_url + self.requested_columns = requested_columns + + # AzureDataLakeUpload + self.overwrite = overwrite_adls + self.adls_sp_credentials_secret = adls_sp_credentials_secret + self.if_exists = if_exists + self.output_file_extension = output_file_extension + self.now = str(pendulum.now("utc")) + + self.local_file_path = ( + local_file_path or self.slugify(name) + self.output_file_extension + ) + self.local_json_path = self.slugify(name) + ".json" + self.adls_dir_path = adls_dir_path + + if adls_file_name is not None: + self.adls_file_path = os.path.join(adls_dir_path, adls_file_name) + self.adls_schema_file_dir_file = os.path.join( + adls_dir_path, "schema", Path(adls_file_name).stem + ".json" + ) + else: + self.adls_file_path = os.path.join( + adls_dir_path, self.now + self.output_file_extension + ) + self.adls_schema_file_dir_file = os.path.join( + adls_dir_path, "schema", self.now + ".json" + ) + + super().__init__(name=name, *args, **kwargs) + + self.gen_flow() + + @staticmethod + def slugify(name): + return name.replace(" ", "_").lower() + + def gen_flow(self) -> Flow: + df = EurostatToDF( + dataset_code=self.dataset_code, + params=self.params, + base_url=self.base_url, + requested_columns=self.requested_columns, + ) + + df = df.bind(flow=self) + + df_with_metadata = add_ingestion_metadata_task.bind(df, flow=self) + dtypes_dict = df_get_data_types_task.bind(df_with_metadata, flow=self) + + if self.output_file_extension == ".parquet": + df_to_be_loaded = df_map_mixed_dtypes_for_parquet( + df_with_metadata, dtypes_dict, flow=self + ) + df_to_file = df_to_parquet.bind( + df=df_to_be_loaded, + path=self.local_file_path, + if_exists=self.if_exists, + flow=self, + ) + else: + df_to_file = df_to_csv.bind( + df=df_with_metadata, + path=self.local_file_path, + if_exists=self.if_exists, + flow=self, + ) + + file_to_adls_task.bind( + from_path=self.local_file_path, + to_path=self.adls_file_path, + overwrite=self.overwrite, + sp_credentials_secret=self.adls_sp_credentials_secret, + flow=self, + ) + + dtypes_updated = update_dtypes_dict(dtypes_dict, flow=self) + dtypes_to_json_task.bind( + dtypes_dict=dtypes_updated, local_json_path=self.local_json_path, flow=self + ) + + json_to_adls_task.bind( + from_path=self.local_json_path, + to_path=self.adls_schema_file_dir_file, + overwrite=self.overwrite, + sp_credentials_secret=self.adls_sp_credentials_secret, + flow=self, + ) + + file_to_adls_task.set_upstream(df_to_file, flow=self) + json_to_adls_task.set_upstream(dtypes_to_json_task, flow=self) + set_key_value(key=self.adls_dir_path, value=self.adls_file_path) diff --git a/viadot/flows/sap_rfc_to_adls.py b/viadot/flows/sap_rfc_to_adls.py index d23ffc428..0d56efeec 100644 --- a/viadot/flows/sap_rfc_to_adls.py +++ b/viadot/flows/sap_rfc_to_adls.py @@ -1,6 +1,5 @@ from typing import Any, Dict, List, Literal -import pandas as pd from prefect import Flow, task, unmapped from viadot.task_utils import concat_dfs, df_to_csv, df_to_parquet, set_new_kv @@ -13,8 +12,10 @@ def __init__( name: str, query: str = None, rfc_sep: str = None, + rfc_replacement: str = "-", func: str = "RFC_READ_TABLE", rfc_total_col_width_character_limit: int = 400, + rfc_unique_id: List[str] = None, sap_credentials: dict = None, output_file_extension: str = ".parquet", local_file_path: str = None, @@ -27,6 +28,7 @@ def __init__( update_kv: bool = False, filter_column: str = None, timeout: int = 3600, + alternative_version: bool = 
False, *args: List[any], **kwargs: Dict[str, Any], ): @@ -48,10 +50,21 @@ def __init__( name (str): The name of the flow. query (str): Query to be executed with pyRFC. Defaults to None. rfc_sep(str, optional): Which separator to use when querying SAP. If not provided, multiple options are automatically tried. + rfc_replacement (str, optional): If the separator character appears inside a column value, the character used to replace + it within the string to avoid flow breakdowns. Defaults to "-". func (str, optional): SAP RFC function to use. Defaults to "RFC_READ_TABLE". rfc_total_col_width_character_limit (int, optional): Number of characters by which query will be split in chunks in case of too many columns - for RFC function. According to SAP documentation, the limit is 512 characters. However, we observed SAP raising an exception - even on a slightly lower number of characters, so we add a safety margin. Defaults to 400. + for RFC function. According to SAP documentation, the limit is 512 characters. However, we observed SAP raising an exception + even on a slightly lower number of characters, so we add a safety margin. Defaults to 400. + rfc_unique_id (List[str], optional): Reference columns used to merge the chunk DataFrames. These columns must be unique. If no columns are provided + in this parameter, all DataFrame columns will be concatenated. Defaults to None. + Example: + -------- + SAPRFCToADLS( + ... + rfc_unique_id=["VBELN", "LPRIO"], + ... + ) sap_credentials (dict, optional): The credentials to use to authenticate with SAP. By default, they're taken from the local viadot config. output_file_extension (str, optional): Output file extension - to allow selection of .csv for data which is not easy to handle with parquet. Defaults to ".parquet". local_file_path (str, optional): Local destination path. Defaults to None. @@ -66,11 +79,14 @@ filter_column (str, optional): Name of the field based on which key value will be updated. Defaults to None. timeout(int, optional): The amount of time (in seconds) to wait while running this task before a timeout occurs. Defaults to 3600. + alternative_version (bool, optional): Enable the use of version 2 of the source. Defaults to False.
""" self.query = query self.rfc_sep = rfc_sep + self.rfc_replacement = rfc_replacement self.func = func self.rfc_total_col_width_character_limit = rfc_total_col_width_character_limit + self.rfc_unique_id = rfc_unique_id self.sap_credentials = sap_credentials self.output_file_extension = output_file_extension self.local_file_path = local_file_path @@ -81,6 +97,7 @@ def __init__( self.adls_sp_credentials_secret = adls_sp_credentials_secret self.vault_name = vault_name self.timeout = timeout + self.alternative_version = alternative_version self.update_kv = update_kv self.filter_column = filter_column @@ -94,8 +111,11 @@ def gen_flow(self) -> Flow: df = download_sap_task( query=self.query, sep=self.rfc_sep, + replacement=self.rfc_replacement, func=self.func, rfc_total_col_width_character_limit=self.rfc_total_col_width_character_limit, + rfc_unique_id=self.rfc_unique_id, + alternative_version=self.alternative_version, credentials=self.sap_credentials, flow=self, ) diff --git a/viadot/sources/__init__.py b/viadot/sources/__init__.py index c9db9d79d..c4b857d47 100644 --- a/viadot/sources/__init__.py +++ b/viadot/sources/__init__.py @@ -12,7 +12,7 @@ from .mediatool import Mediatool try: - from .sap_rfc import SAPRFC + from .sap_rfc import SAPRFC, SAPRFCV2 except ImportError: pass @@ -25,4 +25,6 @@ # APIS from .uk_carbon_intensity import UKCarbonIntensity +from .eurostat import Eurostat from .hubspot import Hubspot +from .business_core import BusinessCore diff --git a/viadot/sources/base.py b/viadot/sources/base.py index 6e30db497..77e09b253 100644 --- a/viadot/sources/base.py +++ b/viadot/sources/base.py @@ -124,6 +124,45 @@ def to_excel( out_df.to_excel(path, index=False, encoding="utf8") return True + def to_parquet( + self, + path: str, + if_exists: Literal["append", "replace", "skip"] = "replace", + if_empty: Literal["warn", "fail", "skip"] = "warn", + **kwargs, + ) -> None: + """ + Write from source to a Parquet file. + + Args: + path (str): The destination path. + if_exists (Literal["append", "replace", "skip"], optional): What to do if the file exists. Defaults to "replace". + if_empty (Literal["warn", "fail", "skip"], optional): What to do if the source contains no data. Defaults to "warn". + + """ + try: + df = self.to_df(if_empty=if_empty) + except SKIP: + return False + if if_exists == "append" and os.path.isfile(path): + parquet_df = pd.read_parquet(path) + out_df = pd.concat([parquet_df, df]) + elif if_exists == "replace": + out_df = df + elif if_exists == "skip": + logger.info("Skipped.") + return + else: + out_df = df + + # create directories if they don't exist + + if not os.path.isfile(path): + directory = os.path.dirname(path) + os.makedirs(directory, exist_ok=True) + + out_df.to_parquet(path, index=False, **kwargs) + def _handle_if_empty(self, if_empty: str = None) -> NoReturn: """What to do if empty.""" if if_empty == "warn": diff --git a/viadot/sources/business_core.py b/viadot/sources/business_core.py new file mode 100644 index 000000000..a0deec80d --- /dev/null +++ b/viadot/sources/business_core.py @@ -0,0 +1,158 @@ +import pandas as pd +import json +from prefect.utilities import logging +from typing import Any, Dict + +from ..exceptions import CredentialError, APIError +from .base import Source +from ..config import local_config +from ..utils import handle_api_response + + +logger = logging.get_logger(__name__) + + +class BusinessCore(Source): + """ + Source for getting data from Bussines Core ERP API. 
+ + """ + + def __init__( + self, + url: str = None, + filters_dict: Dict[str, Any] = { + "BucketCount": None, + "BucketNo": None, + "FromDate": None, + "ToDate": None, + }, + verify: bool = True, + credentials: Dict[str, Any] = None, + config_key: str = "BusinessCore", + *args, + **kwargs, + ): + """ + Creating an instance of BusinessCore source class. + + Args: + url (str, optional): Base url to a view in Business Core API. Defaults to None. + filters_dict (Dict[str, Any], optional): Filters in form of dictionary. Available filters: 'BucketCount', + 'BucketNo', 'FromDate', 'ToDate'. Defaults to {"BucketCount": None,"BucketNo": None,"FromDate": None, + "ToDate": None,}. + verify (bool, optional): Whether or not verify certificates while connecting to an API. Defaults to True. + credentials (Dict[str, Any], optional): Credentials stored in a dictionary. Required credentials: username, + password. Defaults to None. + config_key (str, optional): Credential key to dictionary where details are stored. Defaults to "BusinessCore". + + Raises: + CredentialError: When credentials are not found. + """ + DEFAULT_CREDENTIALS = local_config.get(config_key) + credentials = credentials or DEFAULT_CREDENTIALS + + required_credentials = ["username", "password"] + if any([cred_key not in credentials for cred_key in required_credentials]): + not_found = [c for c in required_credentials if c not in credentials] + raise CredentialError(f"Missing credential(s): '{not_found}'.") + + self.config_key = config_key + self.url = url + self.filters_dict = filters_dict + self.verify = verify + + super().__init__(*args, credentials=credentials, **kwargs) + + def generate_token(self) -> str: + """ + Function for generating Business Core API token based on username and password. + + Returns: + string: Business Core API token. + + """ + url = "https://api.businesscore.ae/api/user/Login" + + payload = f'grant_type=password&username={self.credentials.get("username")}&password={self.credentials.get("password")}&scope=' + headers = {"Content-Type": "application/x-www-form-urlencoded"} + response = handle_api_response( + url=url, headers=headers, method="GET", body=payload, verify=self.verify + ) + token = json.loads(response.text).get("access_token") + self.token = token + return token + + def clean_filters_dict(self) -> Dict: + """ + Function for replacing 'None' with '&' in a dictionary. Needed for payload in 'x-www-form-urlencoded' from. + + Returns: + Dict: Dictionary with filters prepared for further use. + """ + return { + key: ("&" if val is None else val) for key, val in self.filters_dict.items() + } + + def get_data(self) -> Dict: + """ + Function for obtaining data in dictionary format from Business Core API. + + Returns: + Dict: Dictionary with data downloaded from Business Core API. 
+ """ + filters = self.clean_filters_dict() + + payload = ( + "BucketCount=" + + str(filters.get("BucketCount")) + + "BucketNo=" + + str(filters.get("BucketNo")) + + "FromDate=" + + str(filters.get("FromDate")) + + "ToDate" + + str(filters.get("ToDate")) + ) + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": "Bearer " + self.generate_token(), + } + logger.info("Downloading the data...") + response = handle_api_response( + url=self.url, + headers=headers, + method="GET", + body=payload, + verify=self.verify, + ) + logger.info("Data was downloaded successfully.") + return json.loads(response.text) + + def to_df(self, if_empty: str = "skip") -> pd.DataFrame: + """ + Function for transforming data from dictionary to pd.DataFrame. + + Args: + if_empty (str, optional): What to do if output DataFrame is empty. Defaults to "skip". + + Returns: + pd.DataFrame: DataFrame with data downloaded from Business Core API view. + Raises: + APIError: When selected API view is not available. + """ + view = self.url.split("/")[-1] + if view not in ["GetCustomerData", "GetItemMaster", "GetPendingSalesOrderData"]: + raise APIError(f"View {view} currently not available.") + if view in ("GetCustomerData", "GetItemMaster"): + data = self.get_data().get("MasterDataList") + df = pd.DataFrame.from_dict(data) + logger.info( + f"Data was successfully transformed into DataFrame: {len(df.columns)} columns and {len(df)} rows." + ) + if df.empty: + self._handle_if_empty(if_empty) + return df + + if view == "GetPendingSalesOrderData": + # todo waiting for schema + raise APIError(f"View {view} currently not available.") diff --git a/viadot/sources/eurostat.py b/viadot/sources/eurostat.py new file mode 100644 index 000000000..1204f15ce --- /dev/null +++ b/viadot/sources/eurostat.py @@ -0,0 +1,240 @@ +import pandas as pd + +from .base import Source +from viadot.utils import handle_api_response, APIError + + +class Eurostat(Source): + """ + Class for creating instance of Eurostat connector to REST API by HTTPS response (no credentials required). + """ + + def __init__( + self, + dataset_code: str, + params: dict = None, + base_url: str = "https://ec.europa.eu/eurostat/api/dissemination/statistics/1.0/data/", + *args, + **kwargs, + ): + """ + It is using HTTPS REST request to pull the data. + No API registration or API key are required. + Data will pull based on parameters provided in dynamic part of the url. + + Example of url: https://ec.europa.eu/eurostat/api/dissemination/statistics/1.0/data/TEIBS020/?format=JSON&lang=EN&indic=BS-CSMCI-BAL + Static part: https://ec.europa.eu/eurostat/api/dissemination/statistics/1.0/data + Dynamic part: /TEIBS020/?format=JSON&lang=EN&indic=BS-CSMCI-BAL + + Please note that for one dataset there are usually multiple data regarding different subjects. + In order to retrive data that you are interested in you have to provide parameters with codes into 'params'. + + Args: + dataset_code (str): The code of eurostat dataset that we would like to upload. + params (Dict[str], optional): + A dictionary with optional URL parameters. The key represents the parameter id, while the value is the code + for a specific parameter, for example: params = {'unit': 'EUR'} where "unit" is the parameter that you would like to set + and "EUR" is the code of the specific parameter. You can add more than one parameter, but only one code per parameter! 
+ So you CAN NOT provide a list of codes, as in the example 'params = {'unit': ['EUR', 'USD', 'PLN']}' + These parameters are REQUIRED in most cases to pull a specific dataset from the API. + Both the parameter and the code have to be provided as a string! Defaults to None. + base_url (str): The base URL used to access the Eurostat API. This parameter specifies the root URL for all requests made to the API. + It should not be modified unless the API changes its URL scheme. + Defaults to "https://ec.europa.eu/eurostat/api/dissemination/statistics/1.0/data/" + Raises: + TypeError: If self.params is a type other than a dictionary. + """ + + self.dataset_code = dataset_code + self.params = params + if not isinstance(self.params, dict) and self.params is not None: + raise TypeError("Params should be a dictionary.") + self.base_url = f"{base_url}{self.dataset_code}?format=JSON&lang=EN" + + super().__init__(*args, **kwargs) + + def get_parameters_codes(self) -> dict: + """Function for getting the available parameters and their codes from the dataset. + + Raises: + ValueError: If the response from the API is empty or invalid. + + Returns: + Dict: Key is the parameter and value is a list of available codes for the specific parameter. + """ + + try: + response = handle_api_response(self.base_url) + data = response.json() + except APIError: + self.logger.error( + f"Failed to fetch data for {self.dataset_code}, please check correctness of dataset code!" + ) + raise ValueError("DataFrame is empty!") + + # getting list of available parameters + available_params = data["id"] + + # dictionary from JSON with keys and related code values + dimension = data["dimension"] + + # Assigning list of available codes to specific parameters + params_and_codes = {} + for key in available_params: + if key in dimension: + codes = list(dimension[key]["category"]["index"].keys()) + params_and_codes[key] = codes + return params_and_codes + + def make_params_validation(self): + """Function for validating the given parameters against the + parameters and their codes from the JSON. + + Raises: + ValueError: If any of the self.params keys or values is not a string or + any of them is not available for the specific dataset. + """ + + # In order to make params validation, first we need to get params_and_codes. + key_codes = self.get_parameters_codes() + + # Validation of type of values + for key, val in self.params.items(): + if not isinstance(key, str) or not isinstance(val, str): + self.logger.error( + "You can provide only one code per one parameter as 'str' in params!\n" + "CORRECT: params = {'unit': 'EUR'} | INCORRECT: params = {'unit': ['EUR', 'USD', 'PLN']}" + ) + raise ValueError("Wrong structure of params!") + + if key_codes is not None: + # Convert keys and values to lower case by using casefold + key_codes_after_conversion = { + k.casefold(): [v_elem.casefold() for v_elem in v] + for k, v in key_codes.items() + } + params_after_conversion = { + k.casefold(): v.casefold() for k, v in self.params.items() + } + + # comparing keys and values + non_available_keys = [ + key + for key in params_after_conversion.keys() + if key not in key_codes_after_conversion + ] + non_available_codes = [ + value + for key, value in params_after_conversion.items() + if key in key_codes_after_conversion.keys() + and value not in key_codes_after_conversion[key] + ] + + # error loggers + if non_available_keys: + self.logger.error( + f"Parameters: '{' | '.join(non_available_keys)}' are not in dataset.
Please check your spelling!\n" + f"Possible parameters: {' | '.join(key_codes.keys())}" + ) + if non_available_codes: + self.logger.error( + f"Parameters codes: '{' | '.join(non_available_codes)}' are not available. Please check your spelling!\n" + f"You can find everything via link: https://ec.europa.eu/eurostat/databrowser/view/{self.dataset_code}/default/table?lang=en" + ) + raise ValueError("DataFrame is empty!") + + def eurostat_dictionary_to_df(self, *signals: list) -> pd.DataFrame: + """Function for creating DataFrame from JSON pulled from Eurostat. + + Returns: + pd.DataFrame: With 4 columns: index, geo, time, indicator. + """ + + class T_SIGNAL: + signal_keys_list: list + signal_index_list: list + signal_label_list: list + signal_name: str + + # Dataframe creation + columns0 = signals[0].copy() + columns0.append("indicator") + df = pd.DataFrame(columns=columns0) + indicator_list = [] + index_list = [] + signal_lists = [] + + # get the dictionary from the inputs + eurostat_dictionary = signals[-1] + + for signal in signals[0]: + signal_struct = T_SIGNAL() + signal_struct.signal_name = signal + signal_struct.signal_keys_list = list( + eurostat_dictionary["dimension"][signal]["category"]["index"].keys() + ) + signal_struct.signal_index_list = list( + eurostat_dictionary["dimension"][signal]["category"]["index"].values() + ) + signal_label_dict = eurostat_dictionary["dimension"][signal]["category"][ + "label" + ] + signal_struct.signal_label_list = [ + signal_label_dict[i] for i in signal_struct.signal_keys_list + ] + signal_lists.append(signal_struct) + + col_signal_temp = [] + row_signal_temp = [] + for row_index, row_label in zip( + signal_lists[0].signal_index_list, signal_lists[0].signal_label_list + ): # rows + for col_index, col_label in zip( + signal_lists[1].signal_index_list, signal_lists[1].signal_label_list + ): # cols + index = str( + col_index + row_index * len(signal_lists[1].signal_label_list) + ) + if index in eurostat_dictionary["value"].keys(): + index_list.append(index) + col_signal_temp.append(col_label) + row_signal_temp.append(row_label) + + indicator_list = [eurostat_dictionary["value"][i] for i in index_list] + df.indicator = indicator_list + df[signal_lists[1].signal_name] = col_signal_temp + df[signal_lists[0].signal_name] = row_signal_temp + + return df + + def get_data_frame_from_response(self) -> pd.DataFrame: + """Function responsible for getting response, creating DataFrame using method 'eurostat_dictionary_to_df' + with validation of provided parameters and their codes if needed. + + Raises: + APIError: If there is an error with the API request. + ValueError: If the resulting DataFrame is empty. + + Returns: + pd.DataFrame: Final DataFrame or raise prefect.logger.error, if issues occur. 
+ """ + + try: + response = handle_api_response(self.base_url, params=self.params) + data = response.json() + data_frame = self.eurostat_dictionary_to_df(["geo", "time"], data) + + if data_frame.empty: + raise ValueError + except (APIError, ValueError): + self.make_params_validation() + + # merging data_frame with label and last updated date + label_col = pd.Series(str(data["label"]), index=data_frame.index, name="label") + last_updated__col = pd.Series( + str(data["updated"]), + index=data_frame.index, + name="updated", + ) + data_frame = pd.concat([data_frame, label_col, last_updated__col], axis=1) + return data_frame diff --git a/viadot/sources/sap_rfc.py b/viadot/sources/sap_rfc.py index f350a5b7d..3f0f7b31c 100644 --- a/viadot/sources/sap_rfc.py +++ b/viadot/sources/sap_rfc.py @@ -1,11 +1,12 @@ import re from prefect.utilities import logging from collections import OrderedDict -from typing import List, Literal +from typing import List, Literal, Union from typing import OrderedDict as OrderedDictType from typing import Tuple, Union import pandas as pd +import numpy as np try: import pyrfc @@ -84,6 +85,137 @@ def trim_where(where: str) -> Tuple[str, OrderedDictType[str, str]]: return where_trimmed, wheres_to_add +def detect_extra_rows( + row_index: int, data_raw: np.array, chunk: int, fields: List[str] +) -> Union[int, np.array, bool]: + """Check if, in between calls to the SAP table, the number of rows have increased. + If so, remove the last rows added, to fit the size of the previous columns. + + Args: + row_index (int): Number of rows set it down in he first iteration with the SAP table. + data_raw (np.array): Array with the data retrieve from SAP table. + chunk (int): The part number in which a number of SAP table columns have been split. + fields (List[str]): A list with the names of the columns in a chunk. + + Returns: + Union[int, np.array, bool]: A tuple with the parameters "row_index", "data_raw", a new + boolean variable "start" to indicate when the for loop has to be restarted, + and "chunk" variable. + """ + start = False + if row_index == 0: + row_index = data_raw.shape[0] + if row_index == 0: + logger.warning( + f"Empty output was generated for chunk {chunk} in columns {fields}." + ) + start = True + elif data_raw.shape[0] != row_index: + data_raw = data_raw[:row_index] + logger.warning( + f"New rows were generated during the execution of the script. The table is truncated to the number of rows for the first chunk." + ) + + return row_index, data_raw, start + + +def replace_separator_in_data( + data_raw: np.array, + no_sep_index: np.array, + record_key: str, + pos_sep_index: np.array, + sep: str, + replacement: str, +) -> np.array: + """Function to replace the extra separator in every row of the data_raw numpy array. + + Args: + data_raw (np.array): Array with the data retrieve from SAP table. + no_sep_index (np.array): Array with indexes where are extra separators characters in rows. + record_key (str): Key word to extract the data from the numpy array "data_raw". + pos_sep_index (np.array): Array with indexes where are placed real separators. + sep (str): Which separator to use when querying SAP. + replacement (str): In case of sep is on a columns, set up a new character to replace + inside the string to avoid flow breakdowns. + + Returns: + np.array: the same data_raw numpy array with the "replacement" separator instead. 
+ """ + for no_sep in no_sep_index: + logger.warning( + "A separator character was found and replaced inside a string text that could produce future errors:" + ) + logger.warning("\n" + data_raw[no_sep][record_key]) + split_array = np.array([*data_raw[no_sep][record_key]]) + position = np.where(split_array == f"{sep}")[0] + index_sep_index = np.argwhere(np.in1d(position, pos_sep_index) == False) + index_sep_index = index_sep_index.reshape( + len(index_sep_index), + ) + split_array[position[index_sep_index]] = replacement + data_raw[no_sep][record_key] = "".join(split_array) + logger.warning("\n" + data_raw[no_sep][record_key]) + + return data_raw + + +def catch_extra_separators( + data_raw: np.array, record_key: str, sep: str, fields: List[str], replacement: str +) -> np.array: + """Function to replace extra separators in every row of the table. + + Args: + data_raw (np.array): Array with the data retrieve from SAP table. + record_key (str): Key word to extract the data from the numpy array "data_raw". + sep (str): Which separator to use when querying SAP. + fields (List[str]): A list with the names of the columns in a chunk. + replacement (str): In case of sep is on a columns, set up a new character to replace + inside the string to avoid flow breakdowns. + + Returns: + np.array: The argument "data_raw" with no extra delimiters. + """ + + # remove scape characters from data_raw ("\t") + for n, r in enumerate(data_raw): + if "\t" in r[record_key]: + data_raw[n][record_key] = r[record_key].replace("\t", " ") + + # first it is identified where the data has an extra separator in text columns. + sep_counts = np.array([], dtype=int) + for row in data_raw: + sep_counts = np.append(sep_counts, row[record_key].count(f"{sep}")) + + no_sep_index = np.argwhere(sep_counts != len(fields) - 1) + no_sep_index = no_sep_index.reshape( + len(no_sep_index), + ) + sep_index = np.argwhere(sep_counts == len(fields) - 1) + sep_index = sep_index.reshape( + len(sep_index), + ) + # indentifying "good" rows we obtain the index of separator positions. + pos_sep_index = np.array([], dtype=int) + for data in data_raw[sep_index]: + pos_sep_index = np.append( + pos_sep_index, + np.where(np.array([*data[record_key]]) == f"{sep}"), + ) + pos_sep_index = np.unique(pos_sep_index) + + # in rows with an extra separator, we replace them by another character: "-" by default + data_raw = replace_separator_in_data( + data_raw, + no_sep_index, + record_key, + pos_sep_index, + sep, + replacement, + ) + + return data_raw + + class SAPRFC(Source): """ A class for querying SAP with SQL using the RFC protocol. @@ -149,6 +281,11 @@ def check_connection(self) -> None: self.con.ping() self.logger.info("Connection has been validated successfully.") + def close_connection(self) -> None: + """Closing RFC connection.""" + self.con.close() + self.logger.info("Connection has been closed successfully.") + def get_function_parameters( self, function_name: str, @@ -507,5 +644,506 @@ def to_df(self): if col not in self.select_columns_aliased ] df.drop(cols_to_drop, axis=1, inplace=True) + self.close_connection() + return df + + +class SAPRFCV2(Source): + """ + A class for querying SAP with SQL using the RFC protocol. + + Note that only a very limited subset of SQL is supported: + - aliases + - where clauses combined using the AND operator + - limit & offset + + Unsupported: + - aggregations + - joins + - subqueries + - etc. 
+ """ + + def __init__( + self, + sep: str = None, + replacement: str = "-", + func: str = "RFC_READ_TABLE", + rfc_total_col_width_character_limit: int = 400, + rfc_unique_id: List[str] = None, + *args, + **kwargs, + ): + """Create an instance of the SAPRFC class. + + Args: + sep (str, optional): Which separator to use when querying SAP. If not provided, + multiple options are automatically tried. + replacement (str, optional): In case of separator is on a columns, set up a new character to replace + inside the string to avoid flow breakdowns. Defaults to "-". + func (str, optional): SAP RFC function to use. Defaults to "RFC_READ_TABLE". + rfc_total_col_width_character_limit (int, optional): Number of characters by which query will be split in chunks + in case of too many columns for RFC function. According to SAP documentation, the limit is + 512 characters. However, we observed SAP raising an exception even on a slightly lower number + of characters, so we add a safety margin. Defaults to 400. + rfc_unique_id (List[str], optional): Reference columns to merge chunks Data Frames. These columns must to be unique. Defaults to None. + + Raises: + CredentialError: If provided credentials are incorrect. + """ + + self._con = None + DEFAULT_CREDENTIALS = local_config.get("SAP").get("DEV") + credentials = kwargs.pop("credentials", None) or DEFAULT_CREDENTIALS + if credentials is None: + raise CredentialError("Missing credentials.") + + super().__init__(*args, credentials=credentials, **kwargs) + + self.sep = sep + self.replacement = replacement + self.client_side_filters = None + self.func = func + self.rfc_total_col_width_character_limit = rfc_total_col_width_character_limit + # remove repeated reference columns + if rfc_unique_id is not None: + self.rfc_unique_id = list(set(rfc_unique_id)) + else: + self.rfc_unique_id = rfc_unique_id + + @property + def con(self) -> pyrfc.Connection: + if self._con is not None: + return self._con + con = pyrfc.Connection(**self.credentials) + self._con = con + return con + + def check_connection(self) -> None: + self.logger.info("Checking the connection...") + self.con.ping() + self.logger.info("Connection has been validated successfully.") + + def close_connection(self) -> None: + """Closing RFC connection.""" + self.con.close() + self.logger.info("Connection has been closed successfully.") + def get_function_parameters( + self, + function_name: str, + description: Union[None, Literal["short", "long"]] = "short", + *args, + ) -> Union[List[str], pd.DataFrame]: + """Get the description for a SAP RFC function. + + Args: + function_name (str): The name of the function to detail. + description (Union[None, Literal[, optional): Whether to display + a short or a long description. Defaults to "short". + + Raises: + ValueError: If the argument for description is incorrect. + + Returns: + Union[List[str], pd.DataFrame]: Either a list of the function's + parameter names (if 'description' is set to None), + or a short or long description. + """ + if description is not None: + if description not in ["short", "long"]: + raise ValueError( + "Incorrect value for 'description'. 
Correct values: (None, 'short', 'long'" + ) + + descr = self.con.get_function_description(function_name, *args) + param_names = [param["name"] for param in descr.parameters] + detailed_params = descr.parameters + filtered_detailed_params = [ + { + "name": param["name"], + "parameter_type": param["parameter_type"], + "default_value": param["default_value"], + "optional": param["optional"], + "parameter_text": param["parameter_text"], + } + for param in descr.parameters + ] + + if description is not None: + if description == "long": + params = detailed_params + else: + params = filtered_detailed_params + params = pd.DataFrame.from_records(params) + else: + params = param_names + + return params + + def _get_where_condition(self, sql: str) -> str: + """Retrieve the WHERE conditions from a SQL query. + + Args: + sql (str): The input SQL query. + + Raises: + ValueError: Raised if the WHERE clause is longer than + 75 characters (SAP's limitation) and the condition for the + extra clause(s) is OR. + + Returns: + str: The where clause trimmed to <= 75 characters. + """ + + where_match = re.search("\\sWHERE ", sql.upper()) + if not where_match: + return None + + limit_match = re.search("\\sLIMIT ", sql.upper()) + limit_pos = limit_match.span()[0] if limit_match else len(sql) + + where = sql[where_match.span()[1] : limit_pos] + where_sanitized = remove_whitespaces(where) + where_trimmed, client_side_filters = trim_where(where_sanitized) + if client_side_filters: + self.logger.warning( + "A WHERE clause longer than 75 character limit detected." + ) + if "OR" in [key.upper() for key in client_side_filters.keys()]: + raise ValueError( + "WHERE conditions after the 75 character limit can only be combined with the AND keyword." + ) + else: + filters_pretty = list(client_side_filters.items()) + self.logger.warning( + f"Trimmed conditions ({filters_pretty}) will be applied client-side." + ) + self.logger.warning(f"See the documentation for caveats.") + + self.client_side_filters = client_side_filters + return where_trimmed + + @staticmethod + def _get_table_name(sql: str) -> str: + parsed = Parser(sql) + if len(parsed.tables) > 1: + raise ValueError("Querying more than one table is not supported.") + return parsed.tables[0] + + def _build_pandas_filter_query( + self, client_side_filters: OrderedDictType[str, str] + ) -> str: + """Build a WHERE clause that will be applied client-side. + This is required if the WHERE clause passed to query() is + longer than 75 characters. + + Args: + client_side_filters (OrderedDictType[str, str]): The + client-side filters to apply. + + Returns: + str: the WHERE clause reformatted to fit the format + required by DataFrame.query(). 
+ """ + for i, f in enumerate(client_side_filters.items()): + if i == 0: + # skip the first keyword; we assume it's "AND" + query = f[1] + else: + query += " " + f[0] + " " + f[1] + + filter_column_name = f[1].split()[0] + resolved_column_name = self._resolve_col_name(filter_column_name) + query = re.sub("\\s?=\\s?", " == ", query).replace( + filter_column_name, resolved_column_name + ) + return query + + def extract_values(self, sql: str) -> None: + """TODO: This should cover all values, not just columns""" + self.where = self._get_where_condition(sql) + self.select_columns = self._get_columns(sql, aliased=False) + self.select_columns_aliased = self._get_columns(sql, aliased=True) + + def _resolve_col_name(self, column: str) -> str: + """Get aliased column name if it exists, otherwise return column name.""" + return self.aliases_keyed_by_columns.get(column, column) + + def _get_columns(self, sql: str, aliased: bool = False) -> List[str]: + """Retrieve column names from a SQL query. + + Args: + sql (str): The SQL query to parse. + aliased (bool, optional): Whether to returned aliased + names. Defaults to False. + + Returns: + List[str]: A list of column names. + """ + parsed = Parser(sql) + columns = list(parsed.columns_dict["select"]) + if aliased: + aliases_keyed_by_alias = parsed.columns_aliases + aliases_keyed_by_columns = OrderedDict( + {val: key for key, val in aliases_keyed_by_alias.items()} + ) + + self.aliases_keyed_by_columns = aliases_keyed_by_columns + + columns = [ + aliases_keyed_by_columns[col] + if col in aliases_keyed_by_columns + else col + for col in columns + ] + + if self.client_side_filters: + # In case the WHERE clause is > 75 characters long, we execute the rest of the filters + # client-side. To do this, we need to pull all fields in the client-side WHERE conditions. + # Below code adds these columns to the list of SELECTed fields. + cols_to_add = [v.split()[0] for v in self.client_side_filters.values()] + if aliased: + cols_to_add = [aliases_keyed_by_columns[col] for col in cols_to_add] + columns.extend(cols_to_add) + columns = list(dict.fromkeys(columns)) # remove duplicates + + return columns + + @staticmethod + def _get_limit(sql: str) -> int: + """Get limit from the query""" + limit_match = re.search("\\sLIMIT ", sql.upper()) + if not limit_match: + return None + + return int(sql[limit_match.span()[1] :].split()[0]) + + @staticmethod + def _get_offset(sql: str) -> int: + """Get offset from the query""" + offset_match = re.search("\\sOFFSET ", sql.upper()) + if not offset_match: + return None + + return int(sql[offset_match.span()[1] :].split()[0]) + + def query(self, sql: str, sep: str = None) -> None: + """Parse an SQL query into pyRFC commands and save it into + an internal dictionary. + + Args: + sql (str): The SQL query to be ran. + sep (str, optional): The separator to be used + to split columns in the result blob. Defaults to self.sep. + + Raises: + ValueError: If the query is not a SELECT query. 
+ """ + + if not sql.strip().upper().startswith("SELECT"): + raise ValueError("Only SELECT queries are supported.") + + sep = sep if sep is not None else self.sep + + self.sql = sql + + self.extract_values(sql) + + table_name = self._get_table_name(sql) + # this has to be called before checking client_side_filters + where = self.where + columns = self.select_columns + lists_of_columns = [] + cols = [] + col_length_total = 0 + if isinstance(self.rfc_unique_id[0], str): + character_limit = self.rfc_total_col_width_character_limit + for ref_column in self.rfc_unique_id: + col_length_reference_column = int( + self.call( + "DDIF_FIELDINFO_GET", + TABNAME=table_name, + FIELDNAME=ref_column, + )["DFIES_TAB"][0]["LENG"] + ) + if col_length_reference_column > int( + self.rfc_total_col_width_character_limit / 4 + ): + raise ValueError( + f"{ref_column} can't be used as unique column, too large." + ) + local_limit = ( + self.rfc_total_col_width_character_limit + - col_length_reference_column + ) + if local_limit < character_limit: + character_limit = local_limit + else: + character_limit = self.rfc_total_col_width_character_limit + + for col in columns: + info = self.call("DDIF_FIELDINFO_GET", TABNAME=table_name, FIELDNAME=col) + col_length = info["DFIES_TAB"][0]["LENG"] + col_length_total += int(col_length) + if col_length_total <= character_limit: + cols.append(col) + else: + if isinstance(self.rfc_unique_id[0], str) and all( + [rfc_col not in cols for rfc_col in self.rfc_unique_id] + ): + for rfc_col in self.rfc_unique_id: + if rfc_col not in cols: + cols.append(rfc_col) + lists_of_columns.append(cols) + cols = [col] + col_length_total = int(col_length) + else: + if isinstance(self.rfc_unique_id[0], str) and all( + [rfc_col not in cols for rfc_col in self.rfc_unique_id] + ): + for rfc_col in self.rfc_unique_id: + if rfc_col not in cols: + cols.append(rfc_col) + lists_of_columns.append(cols) + + columns = lists_of_columns + options = [{"TEXT": where}] if where else None + limit = self._get_limit(sql) + offset = self._get_offset(sql) + query_json = dict( + QUERY_TABLE=table_name, + FIELDS=columns, + OPTIONS=options, + ROWCOUNT=limit, + ROWSKIPS=offset, + DELIMITER=sep, + ) + # SAP doesn't understand None, so we filter out non-specified parameters + query_json_filtered = { + key: query_json[key] for key in query_json if query_json[key] is not None + } + self._query = query_json_filtered + + def call(self, func: str, *args, **kwargs): + """Call a SAP RFC function""" + return self.con.call(func, *args, **kwargs) + + def _get_alias(self, column: str) -> str: + return self.aliases_keyed_by_columns.get(column, column) + + def _get_client_side_filter_cols(self): + return [f[1].split()[0] for f in self.client_side_filters.items()] + + def to_df(self): + """ + Load the results of a query into a pandas DataFrame. + + Due to SAP limitations, if the length of the WHERE clause is longer than 75 + characters, we trim whe WHERE clause and perform the rest of the filtering + on the resulting DataFrame. Eg. if the WHERE clause contains 4 conditions + and has 80 characters, we only perform 3 filters in the query, and perform + the last filter on the DataFrame. If characters per row limit will be exceeded, + data will be downloaded in chunks. 
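
        Example:
            A minimal sketch of the chunked case, assuming working SAP credentials
            and illustrative table and column names:

            sap = SAPRFCV2(rfc_unique_id=["VBELN"], rfc_total_col_width_character_limit=400)
            sap.query("SELECT VBELN, ERDAT, ERZET, NETWR, WAERK FROM VBAK LIMIT 10")
            # If the combined column width exceeds the limit, the columns are fetched
            # in several chunks and merged back on VBELN.
            df = sap.to_df()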
+ + Source: https://success.jitterbit.com/display/DOC/Guide+to+Using+RFC_READ_TABLE+to+Query+SAP+Tables#GuidetoUsingRFC_READ_TABLEtoQuerySAPTables-create-the-operation + - WHERE clause: 75 character limit + - SELECT: 512 character row limit + + Returns: + pd.DataFrame: A DataFrame representing the result of the query provided in `PyRFC.query()`. + """ + params = self._query + columns = self.select_columns_aliased + sep = self._query.get("DELIMITER") + fields_lists = self._query.get("FIELDS") + if len(fields_lists) > 1: + logger.info(f"Data will be downloaded in {len(fields_lists)} chunks.") + func = self.func + if sep is None: + # automatically find a working separator + SEPARATORS = [ + "|", + "/t", + "#", + ";", + "@", + "%", + "^", + "`", + "~", + "{", + "}", + "$", + ] + else: + SEPARATORS = [sep] + + for sep in SEPARATORS: + logger.info(f"Checking if separator '{sep}' works.") + if isinstance(self.rfc_unique_id[0], str): + # columns only for the first chunk and we add the rest later to avoid name conflicts + df = pd.DataFrame(columns=fields_lists[0]) + else: + df = pd.DataFrame() + self._query["DELIMITER"] = sep + chunk = 1 + row_index = 0 + for fields in fields_lists: + logger.info(f"Downloading {chunk} data chunk...") + self._query["FIELDS"] = fields + try: + response = self.call(func, **params) + except ABAPApplicationError as e: + if e.key == "DATA_BUFFER_EXCEEDED": + raise DataBufferExceeded( + "Character limit per row exceeded. Please select fewer columns." + ) + else: + raise e + record_key = "WA" + data_raw = np.array(response["DATA"]) + + # if the reference columns are provided not necessary to remove any extra row. + if not isinstance(self.rfc_unique_id[0], str): + row_index, data_raw, start = detect_extra_rows( + row_index, data_raw, chunk, fields + ) + else: + start = False + + data_raw = catch_extra_separators( + data_raw, record_key, sep, fields, self.replacement + ) + + records = np.array([row[record_key].split(sep) for row in data_raw]) + + if ( + isinstance(self.rfc_unique_id[0], str) + and not list(df.columns) == fields + ): + df_tmp = pd.DataFrame(columns=fields) + df_tmp[fields] = records + df = pd.merge(df, df_tmp, on=self.rfc_unique_id, how="outer") + else: + if not start: + df[fields] = records + else: + df[fields] = np.nan + chunk += 1 + df.columns = columns + + if self.client_side_filters: + filter_query = self._build_pandas_filter_query(self.client_side_filters) + df.query(filter_query, inplace=True) + client_side_filter_cols_aliased = [ + self._get_alias(col) for col in self._get_client_side_filter_cols() + ] + cols_to_drop = [ + col + for col in client_side_filter_cols_aliased + if col not in self.select_columns_aliased + ] + df.drop(cols_to_drop, axis=1, inplace=True) + self.close_connection() return df diff --git a/viadot/tasks/__init__.py b/viadot/tasks/__init__.py index 814085bfd..1eb3f878c 100644 --- a/viadot/tasks/__init__.py +++ b/viadot/tasks/__init__.py @@ -44,9 +44,11 @@ from .sql_server import SQLServerCreateTable, SQLServerToDF, SQLServerQuery from .epicor import EpicorOrdersToDF +from .eurostat import EurostatToDF from .sftp import SftpToDF, SftpList from .mindful import MindfulToCSV from .hubspot import HubspotToDF from .mediatool import MediatoolToDF from .customer_gauge import CustomerGaugeToDF +from .business_core import BusinessCoreToParquet diff --git a/viadot/tasks/business_core.py b/viadot/tasks/business_core.py new file mode 100644 index 000000000..f4da67bff --- /dev/null +++ b/viadot/tasks/business_core.py @@ -0,0 +1,103 @@ +import 
xml.etree.ElementTree as ET +from typing import Any, Dict, List, Optional +from xml.etree.ElementTree import fromstring + +import pandas as pd +from prefect import Task +from prefect.utilities.tasks import defaults_from_attrs + +from ..sources import BusinessCore + + +class BusinessCoreToParquet(Task): + def __init__( + self, + path: str, + url: str, + filters_dict: Dict[str, Any] = { + "BucketCount": None, + "BucketNo": None, + "FromDate": None, + "ToDate": None, + }, + verify: bool = True, + credentials: Dict[str, Any] = None, + config_key: str = "BusinessCore", + if_empty: str = "skip", + timeout=3600, + *args, + **kwargs, + ): + + """Task for downloading data from Business Core API to a Parquet file. + + Args: + path (str, required): Path where to save a Parquet file. + url (str, required): Base url to a view in Business Core API. + filters_dict (Dict[str, Any], optional): Filters in form of dictionary. Available filters: 'BucketCount', + 'BucketNo', 'FromDate', 'ToDate'. Defaults to {"BucketCount": None,"BucketNo": None,"FromDate": None, + "ToDate": None,}. + verify (bool, optional): Whether or not verify certificates while connecting to an API. Defaults to True. + credentials (Dict[str, Any], optional): Credentials stored in a dictionary. Required credentials: username, + password. Defaults to None. + config_key (str, optional): Credential key to dictionary where details are stored. Defaults to "BusinessCore". + if_empty (str, optional): What to do if output DataFrame is empty. Defaults to "skip". + timeout(int, optional): The amount of time (in seconds) to wait while running this task before + a timeout occurs. Defaults to 3600. + """ + + self.url = url + self.path = path + self.credentials = credentials + self.config_key = config_key + self.filters_dict = filters_dict + self.verify = verify + self.if_empty = if_empty + + super().__init__( + name="business_core_to_parquet", + timeout=timeout, + *args, + **kwargs, + ) + + def __call__(self, *args, **kwargs): + """Load Business Core data to Parquet""" + return super().__call__(*args, **kwargs) + + @defaults_from_attrs( + "url", "path", "credentials", "config_key", "filters_dict", "verify", "if_empty" + ) + def run( + self, + path: str = None, + url: str = None, + credentials: Dict[str, Any] = None, + config_key: str = None, + filters_dict: str = None, + verify: bool = True, + if_empty: str = None, + ): + """Run method for BusinessCoreToParquet task. Saves data from Business Core API to Parquet file. + + Args: + path (str, optional): Path where to save a Parquet file. Defaults to None. + url (str, optional): Base url to a view in Business Core API. Defaults to None. + filters_dict (Dict[str, Any], optional): Filters in form of dictionary. Available filters: 'BucketCount', + 'BucketNo', 'FromDate', 'ToDate'. Defaults to None. + verify (bool, optional): Whether or not verify certificates while connecting to an API. Defaults to True. + credentials (Dict[str, Any], optional): Credentials stored in a dictionary. Required credentials: username, + password. Defaults to None. + config_key (str, optional): Credential key to dictionary where details are stored. Defaults to None. + if_empty (str, optional): What to do if output DataFrame is empty. Defaults to None. 
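+
+        Example:
+            A minimal sketch; the URL and credentials below are placeholders:
+
+            task = BusinessCoreToParquet(path="business_core.parquet", url="https://<business_core_host>/api/view")
+            task.run(credentials={"username": "<username>", "password": "<password>"})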
+
+        """
+        bc = BusinessCore(
+            url=url,
+            credentials=credentials,
+            config_key=config_key,
+            filters_dict=filters_dict,
+            verify=verify,
+        )
+        return bc.to_parquet(path=path, if_empty=if_empty)
diff --git a/viadot/tasks/eurostat.py b/viadot/tasks/eurostat.py
new file mode 100644
index 000000000..2bfbf5bcc
--- /dev/null
+++ b/viadot/tasks/eurostat.py
@@ -0,0 +1,88 @@
+from prefect import Task
+from ..sources import Eurostat
+import pandas as pd
+
+
+class EurostatToDF(Task):
+    """Task for creating a pandas DataFrame from the Eurostat HTTPS REST API (no credentials required).
+
+    Args:
+        dataset_code (str): The code of the Eurostat dataset to be downloaded.
+        params (Dict[str, str], optional):
+            A dictionary of optional URL parameters. The key is the parameter id, while the value is the code
+            of a specific parameter, for example params = {'unit': 'EUR'}, where 'unit' is the parameter to set
+            and 'EUR' is the code of that parameter. You can add more than one parameter, but only one code per
+            parameter, so you CANNOT provide a list of codes, e.g. params = {'unit': ['EUR', 'USD', 'PLN']}.
+            This parameter is REQUIRED in most cases to pull a specific dataset from the API.
+            Both the parameter and the code have to be provided as strings! Defaults to None.
+        base_url (str): The base URL used to access the Eurostat API. This parameter specifies the root URL
+            for all requests made to the API. It should not be modified unless the API changes its URL scheme.
+            Defaults to "https://ec.europa.eu/eurostat/api/dissemination/statistics/1.0/data/".
+        requested_columns (List[str], optional): A list of column names to keep. Names should be provided
+            as strings. Defaults to None.
+
+    Raises:
+        TypeError: If requested_columns is not a list.
+    """
+
+    def __init__(
+        self,
+        dataset_code: str,
+        params: dict = None,
+        base_url: str = "https://ec.europa.eu/eurostat/api/dissemination/statistics/1.0/data/",
+        requested_columns: list = None,
+        *args,
+        **kwargs,
+    ):
+        self.dataset_code = dataset_code
+        self.params = params
+        self.base_url = base_url
+        self.requested_columns = requested_columns
+        if (
+            not isinstance(self.requested_columns, list)
+            and self.requested_columns is not None
+        ):
+            raise TypeError("Requested columns should be provided as a list of strings.")
+
+        super().__init__(name="eurostat_to_df", *args, **kwargs)
+
+    def run(self) -> pd.DataFrame:
+        """Return the full DataFrame, or only the requested columns if `requested_columns` is provided.
+
+        Raises:
+            ValueError: If requested_columns contains column names that do not exist in the DataFrame.
+
+        Returns:
+            pd.DataFrame: The unchanged DataFrame, or a DataFrame with only the chosen columns.
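+
+        Example:
+            A minimal sketch; the dataset code, parameter, and column names below are illustrative:
+
+            df = EurostatToDF(
+                dataset_code="ILC_DI04",
+                params={"unit": "EUR"},
+                requested_columns=["geo", "time"],
+            ).run()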
+        """
+
+        data_frame = Eurostat(
+            self.dataset_code, self.params
+        ).get_data_frame_from_response()
+
+        if self.requested_columns is None:
+            return data_frame
+        else:
+            columns_list = data_frame.columns.tolist()
+            columns_list = [str(column).casefold() for column in columns_list]
+            needed_column_after_validation = []
+            non_available_columns = []
+
+            for column in self.requested_columns:
+                # Check whether each requested column exists in the DataFrame column list.
+                column = str(column).casefold()
+
+                if column in columns_list:
+                    needed_column_after_validation.append(column)
+                else:
+                    non_available_columns.append(column)
+
+            # Log which requested columns are missing before raising.
+            if non_available_columns:
+                self.logger.error(
+                    f"Columns: '{' | '.join(non_available_columns)}' are not available in the DataFrame. Please check the spelling!\n"
+                    f"Available columns: {' | '.join(columns_list)}"
+                )
+                raise ValueError("Provided columns are not available!")
+
+            new_df = data_frame.loc[:, needed_column_after_validation]
+            return new_df
diff --git a/viadot/tasks/sap_rfc.py b/viadot/tasks/sap_rfc.py
index b863db2fb..66831bdc8 100644
--- a/viadot/tasks/sap_rfc.py
+++ b/viadot/tasks/sap_rfc.py
@@ -1,3 +1,4 @@
+from typing import List
 from datetime import timedelta
 
 import pandas as pd
@@ -5,7 +6,7 @@
 from prefect.utilities.tasks import defaults_from_attrs
 
 try:
-    from ..sources import SAPRFC
+    from viadot.sources import SAPRFC, SAPRFCV2
 except ImportError:
     raise
@@ -15,6 +16,7 @@ def __init__(
         self,
         query: str = None,
         sep: str = None,
+        replacement: str = "-",
         func: str = None,
         rfc_total_col_width_character_limit: int = 400,
         credentials: dict = None,
@@ -42,6 +44,8 @@ def __init__(
             query (str, optional): The query to be executed with pyRFC.
             sep (str, optional): The separator to use when reading query results. If not provided,
             multiple options are automatically tried. Defaults to None.
+            replacement (str, optional): If the separator character appears inside a column value, it is replaced
+            with this character to avoid breaking the extraction. Defaults to "-".
             func (str, optional): SAP RFC function to use. Defaults to None.
             rfc_total_col_width_character_limit (int, optional): Number of characters by which query will be split in chunks
             in case of too many columns for RFC function. According to SAP documentation, the limit is
@@ -52,6 +56,7 @@ def __init__(
         """
         self.query = query
         self.sep = sep
+        self.replacement = replacement
         self.credentials = credentials
         self.func = func
         self.rfc_total_col_width_character_limit = rfc_total_col_width_character_limit
@@ -68,21 +73,21 @@ def __init__(
     @defaults_from_attrs(
         "query",
         "sep",
+        "replacement",
         "func",
         "rfc_total_col_width_character_limit",
         "credentials",
-        "max_retries",
-        "retry_delay",
     )
     def run(
         self,
         query: str = None,
         sep: str = None,
+        replacement: str = "-",
         credentials: dict = None,
         func: str = None,
         rfc_total_col_width_character_limit: int = None,
-        max_retries: int = None,
-        retry_delay: timedelta = None,
+        rfc_unique_id: List[str] = None,
+        alternative_version: bool = False,
     ) -> pd.DataFrame:
         """Task run method.
@@ -90,20 +95,50 @@ def run(
             query (str, optional): The query to be executed with pyRFC.
             sep (str, optional): The separator to use when reading query results. If not provided,
             multiple options are automatically tried. Defaults to None.
+            replacement (str, optional): If the separator character appears inside a column value, it is replaced
+            with this character to avoid breaking the extraction. Defaults to "-".
             func (str, optional): SAP RFC function to use. Defaults to None.
             rfc_total_col_width_character_limit (int, optional): Number of characters by which query will be split in chunks
             in case of too many columns for RFC function. According to SAP documentation, the limit is
             512 characters. However, we observed SAP raising an exception even on a slightly lower number
             of characters, so we add a safety margin. Defaults to None.
+            rfc_unique_id (List[str], optional): Reference column(s) used to merge the chunked DataFrames.
+            These columns must be unique. If no columns are provided, all DataFrame columns will be concatenated.
+            Defaults to None.
+            Example:
+            --------
+            SAPRFCToADLS(
+                ...
+                rfc_unique_id=["VBELN", "LPRIO"],
+                ...
+            )
+            alternative_version (bool, optional): Enables the use of the version 2 source (SAPRFCV2). Defaults to False.
+
+        Returns:
+            pd.DataFrame: A DataFrame with the SAP data.
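+
+        Example:
+            An illustrative sketch, where `task` is an instance of this task class and
+            the table and column names are made up:
+
+            task.run(
+                query="SELECT VBELN, LPRIO FROM VBAK LIMIT 5",
+                alternative_version=True,
+                rfc_unique_id=["VBELN"],
+            )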
""" if query is None: raise ValueError("Please provide the query.") - sap = SAPRFC( - sep=sep, - credentials=credentials, - func=func, - rfc_total_col_width_character_limit=rfc_total_col_width_character_limit, - ) + + if alternative_version is True: + if rfc_unique_id: + self.logger.warning( + "If the column/set of columns are not unique the table will be malformed." + ) + sap = SAPRFCV2( + sep=sep, + replacement=replacement, + credentials=credentials, + func=func, + rfc_total_col_width_character_limit=rfc_total_col_width_character_limit, + rfc_unique_id=rfc_unique_id, + ) + else: + sap = SAPRFC( + sep=sep, + credentials=credentials, + func=func, + rfc_total_col_width_character_limit=rfc_total_col_width_character_limit, + ) sap.query(query) self.logger.info(f"Downloading data from SAP to a DataFrame...") self.logger.debug(f"Running query: \n{query}.") diff --git a/viadot/utils.py b/viadot/utils.py index f92b36002..ed0abb17e 100644 --- a/viadot/utils.py +++ b/viadot/utils.py @@ -33,21 +33,24 @@ def handle_api_response( timeout: tuple = (3.05, 60 * 30), method: Literal["GET", "POST", "DELETE"] = "GET", body: str = None, + verify: bool = True, ) -> requests.models.Response: - """Handle and raise Python exceptions during request with retry strategy for specyfic status. + """Handle and raise Python exceptions during request with retry strategy for specific status. Args: - url (str): the URL which trying to connect. - auth (tuple, optional): authorization information. Defaults to None. - params (Dict[str, Any], optional): the request params also includes parameters such as the content type. Defaults to None. - headers: (Dict[str, Any], optional): the request headers. Defaults to None. - timeout (tuple, optional): the request times out. Defaults to (3.05, 60 * 30). + url (str): The URL which trying to connect. + auth (tuple, optional): Authorization information. Defaults to None. + params (Dict[str, Any], optional): The request params also includes parameters such as the content type. Defaults to None. + headers: (Dict[str, Any], optional): The request headers. Defaults to None. + timeout (tuple, optional): The request times out. Defaults to (3.05, 60 * 30). method (Literal ["GET", "POST","DELETE"], optional): REST API method to use. Defaults to "GET". body (str, optional): Data to send using POST method. Defaults to None. + verify (bool, optional): Whether to verify cerificates. Defaults to True. Raises: - ValueError: raises when 'method' parameter value hasn't been specified - ReadTimeout: stop waiting for a response after a given number of seconds with the timeout parameter. - HTTPError: exception that indicates when HTTP status codes returned values different than 200. - ConnectionError: exception that indicates when client is unable to connect to the server. + ValueError: Raises when 'method' parameter value hasn't been specified + ReadTimeout: Stop waiting for a response after a given number of seconds with the timeout parameter. + HTTPError: Exception that indicates when HTTP status codes returned values different than 200. + ConnectionError: Exception that indicates when client is unable to connect to the server. + ProtocolError: Raised when something unexpected happens mid-request/response. APIError: defined by user. 
Returns: requests.models.Response @@ -74,6 +77,7 @@ def handle_api_response( timeout=timeout, data=body, method=method, + verify=verify, ) as response: response.raise_for_status() except ReadTimeout as e: @@ -83,7 +87,7 @@ def handle_api_response( raise APIError(msg) except HTTPError as e: raise APIError( - f"The API call to {url} failed. " + f"The API call to {url} failed. {e} " "Perhaps your account credentials need to be refreshed?", ) from e except (ConnectionError, Timeout) as e: